asyncio, Tulip issue 126: call_soon(), call_soon_threadsafe(), call_later(),

call_at() and run_in_executor() now raise a TypeError if the callback is a
coroutine function.
This commit is contained in:
Victor Stinner 2014-02-11 11:34:30 +01:00
parent 1db2ba3a92
commit a125497ea3
6 changed files with 39 additions and 13 deletions

View file

@ -227,6 +227,8 @@ def call_later(self, delay, callback, *args):
def call_at(self, when, callback, *args): def call_at(self, when, callback, *args):
"""Like call_later(), but uses an absolute time.""" """Like call_later(), but uses an absolute time."""
if tasks.iscoroutinefunction(callback):
raise TypeError("coroutines cannot be used with call_at()")
timer = events.TimerHandle(when, callback, args) timer = events.TimerHandle(when, callback, args)
heapq.heappush(self._scheduled, timer) heapq.heappush(self._scheduled, timer)
return timer return timer
@ -241,6 +243,8 @@ def call_soon(self, callback, *args):
Any positional arguments after the callback will be passed to Any positional arguments after the callback will be passed to
the callback when it is called. the callback when it is called.
""" """
if tasks.iscoroutinefunction(callback):
raise TypeError("coroutines cannot be used with call_soon()")
handle = events.Handle(callback, args) handle = events.Handle(callback, args)
self._ready.append(handle) self._ready.append(handle)
return handle return handle
@ -252,6 +256,8 @@ def call_soon_threadsafe(self, callback, *args):
return handle return handle
def run_in_executor(self, executor, callback, *args): def run_in_executor(self, executor, callback, *args):
if tasks.iscoroutinefunction(callback):
raise TypeError("coroutines cannot be used with run_in_executor()")
if isinstance(callback, events.Handle): if isinstance(callback, events.Handle):
assert not args assert not args
assert not isinstance(callback, events.TimerHandle) assert not isinstance(callback, events.TimerHandle)

View file

@ -135,7 +135,7 @@ def make_test_protocol(base):
if name.startswith('__') and name.endswith('__'): if name.startswith('__') and name.endswith('__'):
# skip magic names # skip magic names
continue continue
dct[name] = unittest.mock.Mock(return_value=None) dct[name] = MockCallback(return_value=None)
return type('TestProtocol', (base,) + base.__bases__, dct)() return type('TestProtocol', (base,) + base.__bases__, dct)()
@ -274,3 +274,6 @@ def _process_events(self, event_list):
def _write_to_self(self): def _write_to_self(self):
pass pass
def MockCallback(**kwargs):
return unittest.mock.Mock(spec=['__call__'], **kwargs)

View file

@ -567,6 +567,7 @@ class Err(OSError):
m_socket.getaddrinfo.return_value = [ m_socket.getaddrinfo.return_value = [
(2, 1, 6, '', ('127.0.0.1', 10100))] (2, 1, 6, '', ('127.0.0.1', 10100))]
m_socket.getaddrinfo._is_coroutine = False
m_sock = m_socket.socket.return_value = unittest.mock.Mock() m_sock = m_socket.socket.return_value = unittest.mock.Mock()
m_sock.bind.side_effect = Err m_sock.bind.side_effect = Err
@ -577,6 +578,7 @@ class Err(OSError):
@unittest.mock.patch('asyncio.base_events.socket') @unittest.mock.patch('asyncio.base_events.socket')
def test_create_datagram_endpoint_no_addrinfo(self, m_socket): def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
m_socket.getaddrinfo.return_value = [] m_socket.getaddrinfo.return_value = []
m_socket.getaddrinfo._is_coroutine = False
coro = self.loop.create_datagram_endpoint( coro = self.loop.create_datagram_endpoint(
MyDatagramProto, local_addr=('localhost', 0)) MyDatagramProto, local_addr=('localhost', 0))
@ -681,6 +683,22 @@ def test_accept_connection_exception(self, m_log):
unittest.mock.ANY, unittest.mock.ANY,
MyProto, sock, None, None) MyProto, sock, None, None)
def test_call_coroutine(self):
@asyncio.coroutine
def coroutine_function():
pass
with self.assertRaises(TypeError):
self.loop.call_soon(coroutine_function)
with self.assertRaises(TypeError):
self.loop.call_soon_threadsafe(coroutine_function)
with self.assertRaises(TypeError):
self.loop.call_later(60, coroutine_function)
with self.assertRaises(TypeError):
self.loop.call_at(self.loop.time() + 60, coroutine_function)
with self.assertRaises(TypeError):
self.loop.run_in_executor(None, coroutine_function)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -402,7 +402,7 @@ def test_socketpair(self):
NotImplementedError, BaseProactorEventLoop, self.proactor) NotImplementedError, BaseProactorEventLoop, self.proactor)
def test_make_socket_transport(self): def test_make_socket_transport(self):
tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol())
self.assertIsInstance(tr, _ProactorSocketTransport) self.assertIsInstance(tr, _ProactorSocketTransport)
def test_loop_self_reading(self): def test_loop_self_reading(self):

View file

@ -44,8 +44,8 @@ def setUp(self):
def test_make_socket_transport(self): def test_make_socket_transport(self):
m = unittest.mock.Mock() m = unittest.mock.Mock()
self.loop.add_reader = unittest.mock.Mock() self.loop.add_reader = unittest.mock.Mock()
self.assertIsInstance( transport = self.loop._make_socket_transport(m, asyncio.Protocol())
self.loop._make_socket_transport(m, m), _SelectorSocketTransport) self.assertIsInstance(transport, _SelectorSocketTransport)
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
def test_make_ssl_transport(self): def test_make_ssl_transport(self):
@ -54,8 +54,9 @@ def test_make_ssl_transport(self):
self.loop.add_writer = unittest.mock.Mock() self.loop.add_writer = unittest.mock.Mock()
self.loop.remove_reader = unittest.mock.Mock() self.loop.remove_reader = unittest.mock.Mock()
self.loop.remove_writer = unittest.mock.Mock() self.loop.remove_writer = unittest.mock.Mock()
self.assertIsInstance( waiter = asyncio.Future(loop=self.loop)
self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport) transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter)
self.assertIsInstance(transport, _SelectorSslTransport)
@unittest.mock.patch('asyncio.selector_events.ssl', None) @unittest.mock.patch('asyncio.selector_events.ssl', None)
def test_make_ssl_transport_without_ssl_error(self): def test_make_ssl_transport_without_ssl_error(self):

View file

@ -2,8 +2,6 @@
import gc import gc
import unittest import unittest
import unittest.mock
from unittest.mock import Mock
import asyncio import asyncio
from asyncio import test_utils from asyncio import test_utils
@ -1358,7 +1356,7 @@ def _run_loop(self, loop):
def _check_success(self, **kwargs): def _check_success(self, **kwargs):
a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)] a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)]
fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs) fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs)
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
b.set_result(1) b.set_result(1)
a.set_result(2) a.set_result(2)
@ -1380,7 +1378,7 @@ def test_result_exception_success(self):
def test_one_exception(self): def test_one_exception(self):
a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e)) fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e))
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
exc = ZeroDivisionError() exc = ZeroDivisionError()
a.set_result(1) a.set_result(1)
@ -1399,7 +1397,7 @@ def test_return_exceptions(self):
a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)] a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)]
fut = asyncio.gather(*self.wrap_futures(a, b, c, d), fut = asyncio.gather(*self.wrap_futures(a, b, c, d),
return_exceptions=True) return_exceptions=True)
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
exc = ZeroDivisionError() exc = ZeroDivisionError()
exc2 = RuntimeError() exc2 = RuntimeError()
@ -1460,7 +1458,7 @@ def test_constructor_homogenous_futures(self):
def test_one_cancellation(self): def test_one_cancellation(self):
a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
fut = asyncio.gather(a, b, c, d, e) fut = asyncio.gather(a, b, c, d, e)
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
a.set_result(1) a.set_result(1)
b.cancel() b.cancel()
@ -1479,7 +1477,7 @@ def test_result_exception_one_cancellation(self):
a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop) a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop)
for i in range(6)] for i in range(6)]
fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True) fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True)
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
a.set_result(1) a.set_result(1)
zde = ZeroDivisionError() zde = ZeroDivisionError()