mirror of
				https://github.com/python/cpython.git
				synced 2025-10-26 11:14:33 +00:00 
			
		
		
		
	asyncio, Tulip issue 126: call_soon(), call_soon_threadsafe(), call_later(),
call_at() and run_in_executor() now raise a TypeError if the callback is a coroutine function.
This commit is contained in:
		
							parent
							
								
									1db2ba3a92
								
							
						
					
					
						commit
						a125497ea3
					
				
					 6 changed files with 39 additions and 13 deletions
				
			
		|  | @ -227,6 +227,8 @@ def call_later(self, delay, callback, *args): | ||||||
| 
 | 
 | ||||||
|     def call_at(self, when, callback, *args): |     def call_at(self, when, callback, *args): | ||||||
|         """Like call_later(), but uses an absolute time.""" |         """Like call_later(), but uses an absolute time.""" | ||||||
|  |         if tasks.iscoroutinefunction(callback): | ||||||
|  |             raise TypeError("coroutines cannot be used with call_at()") | ||||||
|         timer = events.TimerHandle(when, callback, args) |         timer = events.TimerHandle(when, callback, args) | ||||||
|         heapq.heappush(self._scheduled, timer) |         heapq.heappush(self._scheduled, timer) | ||||||
|         return timer |         return timer | ||||||
|  | @ -241,6 +243,8 @@ def call_soon(self, callback, *args): | ||||||
|         Any positional arguments after the callback will be passed to |         Any positional arguments after the callback will be passed to | ||||||
|         the callback when it is called. |         the callback when it is called. | ||||||
|         """ |         """ | ||||||
|  |         if tasks.iscoroutinefunction(callback): | ||||||
|  |             raise TypeError("coroutines cannot be used with call_soon()") | ||||||
|         handle = events.Handle(callback, args) |         handle = events.Handle(callback, args) | ||||||
|         self._ready.append(handle) |         self._ready.append(handle) | ||||||
|         return handle |         return handle | ||||||
|  | @ -252,6 +256,8 @@ def call_soon_threadsafe(self, callback, *args): | ||||||
|         return handle |         return handle | ||||||
| 
 | 
 | ||||||
|     def run_in_executor(self, executor, callback, *args): |     def run_in_executor(self, executor, callback, *args): | ||||||
|  |         if tasks.iscoroutinefunction(callback): | ||||||
|  |             raise TypeError("coroutines cannot be used with run_in_executor()") | ||||||
|         if isinstance(callback, events.Handle): |         if isinstance(callback, events.Handle): | ||||||
|             assert not args |             assert not args | ||||||
|             assert not isinstance(callback, events.TimerHandle) |             assert not isinstance(callback, events.TimerHandle) | ||||||
|  |  | ||||||
|  | @ -135,7 +135,7 @@ def make_test_protocol(base): | ||||||
|         if name.startswith('__') and name.endswith('__'): |         if name.startswith('__') and name.endswith('__'): | ||||||
|             # skip magic names |             # skip magic names | ||||||
|             continue |             continue | ||||||
|         dct[name] = unittest.mock.Mock(return_value=None) |         dct[name] = MockCallback(return_value=None) | ||||||
|     return type('TestProtocol', (base,) + base.__bases__, dct)() |     return type('TestProtocol', (base,) + base.__bases__, dct)() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -274,3 +274,6 @@ def _process_events(self, event_list): | ||||||
| 
 | 
 | ||||||
|     def _write_to_self(self): |     def _write_to_self(self): | ||||||
|         pass |         pass | ||||||
|  | 
 | ||||||
|  | def MockCallback(**kwargs): | ||||||
|  |     return unittest.mock.Mock(spec=['__call__'], **kwargs) | ||||||
|  |  | ||||||
|  | @ -567,6 +567,7 @@ class Err(OSError): | ||||||
| 
 | 
 | ||||||
|         m_socket.getaddrinfo.return_value = [ |         m_socket.getaddrinfo.return_value = [ | ||||||
|             (2, 1, 6, '', ('127.0.0.1', 10100))] |             (2, 1, 6, '', ('127.0.0.1', 10100))] | ||||||
|  |         m_socket.getaddrinfo._is_coroutine = False | ||||||
|         m_sock = m_socket.socket.return_value = unittest.mock.Mock() |         m_sock = m_socket.socket.return_value = unittest.mock.Mock() | ||||||
|         m_sock.bind.side_effect = Err |         m_sock.bind.side_effect = Err | ||||||
| 
 | 
 | ||||||
|  | @ -577,6 +578,7 @@ class Err(OSError): | ||||||
|     @unittest.mock.patch('asyncio.base_events.socket') |     @unittest.mock.patch('asyncio.base_events.socket') | ||||||
|     def test_create_datagram_endpoint_no_addrinfo(self, m_socket): |     def test_create_datagram_endpoint_no_addrinfo(self, m_socket): | ||||||
|         m_socket.getaddrinfo.return_value = [] |         m_socket.getaddrinfo.return_value = [] | ||||||
|  |         m_socket.getaddrinfo._is_coroutine = False | ||||||
| 
 | 
 | ||||||
|         coro = self.loop.create_datagram_endpoint( |         coro = self.loop.create_datagram_endpoint( | ||||||
|             MyDatagramProto, local_addr=('localhost', 0)) |             MyDatagramProto, local_addr=('localhost', 0)) | ||||||
|  | @ -681,6 +683,22 @@ def test_accept_connection_exception(self, m_log): | ||||||
|                                                 unittest.mock.ANY, |                                                 unittest.mock.ANY, | ||||||
|                                                 MyProto, sock, None, None) |                                                 MyProto, sock, None, None) | ||||||
| 
 | 
 | ||||||
|  |     def test_call_coroutine(self): | ||||||
|  |         @asyncio.coroutine | ||||||
|  |         def coroutine_function(): | ||||||
|  |             pass | ||||||
|  | 
 | ||||||
|  |         with self.assertRaises(TypeError): | ||||||
|  |             self.loop.call_soon(coroutine_function) | ||||||
|  |         with self.assertRaises(TypeError): | ||||||
|  |             self.loop.call_soon_threadsafe(coroutine_function) | ||||||
|  |         with self.assertRaises(TypeError): | ||||||
|  |             self.loop.call_later(60, coroutine_function) | ||||||
|  |         with self.assertRaises(TypeError): | ||||||
|  |             self.loop.call_at(self.loop.time() + 60, coroutine_function) | ||||||
|  |         with self.assertRaises(TypeError): | ||||||
|  |             self.loop.run_in_executor(None, coroutine_function) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|  |  | ||||||
|  | @ -402,7 +402,7 @@ def test_socketpair(self): | ||||||
|             NotImplementedError, BaseProactorEventLoop, self.proactor) |             NotImplementedError, BaseProactorEventLoop, self.proactor) | ||||||
| 
 | 
 | ||||||
|     def test_make_socket_transport(self): |     def test_make_socket_transport(self): | ||||||
|         tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) |         tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol()) | ||||||
|         self.assertIsInstance(tr, _ProactorSocketTransport) |         self.assertIsInstance(tr, _ProactorSocketTransport) | ||||||
| 
 | 
 | ||||||
|     def test_loop_self_reading(self): |     def test_loop_self_reading(self): | ||||||
|  |  | ||||||
|  | @ -44,8 +44,8 @@ def setUp(self): | ||||||
|     def test_make_socket_transport(self): |     def test_make_socket_transport(self): | ||||||
|         m = unittest.mock.Mock() |         m = unittest.mock.Mock() | ||||||
|         self.loop.add_reader = unittest.mock.Mock() |         self.loop.add_reader = unittest.mock.Mock() | ||||||
|         self.assertIsInstance( |         transport = self.loop._make_socket_transport(m, asyncio.Protocol()) | ||||||
|             self.loop._make_socket_transport(m, m), _SelectorSocketTransport) |         self.assertIsInstance(transport, _SelectorSocketTransport) | ||||||
| 
 | 
 | ||||||
|     @unittest.skipIf(ssl is None, 'No ssl module') |     @unittest.skipIf(ssl is None, 'No ssl module') | ||||||
|     def test_make_ssl_transport(self): |     def test_make_ssl_transport(self): | ||||||
|  | @ -54,8 +54,9 @@ def test_make_ssl_transport(self): | ||||||
|         self.loop.add_writer = unittest.mock.Mock() |         self.loop.add_writer = unittest.mock.Mock() | ||||||
|         self.loop.remove_reader = unittest.mock.Mock() |         self.loop.remove_reader = unittest.mock.Mock() | ||||||
|         self.loop.remove_writer = unittest.mock.Mock() |         self.loop.remove_writer = unittest.mock.Mock() | ||||||
|         self.assertIsInstance( |         waiter = asyncio.Future(loop=self.loop) | ||||||
|             self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport) |         transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter) | ||||||
|  |         self.assertIsInstance(transport, _SelectorSslTransport) | ||||||
| 
 | 
 | ||||||
|     @unittest.mock.patch('asyncio.selector_events.ssl', None) |     @unittest.mock.patch('asyncio.selector_events.ssl', None) | ||||||
|     def test_make_ssl_transport_without_ssl_error(self): |     def test_make_ssl_transport_without_ssl_error(self): | ||||||
|  |  | ||||||
|  | @ -2,8 +2,6 @@ | ||||||
| 
 | 
 | ||||||
| import gc | import gc | ||||||
| import unittest | import unittest | ||||||
| import unittest.mock |  | ||||||
| from unittest.mock import Mock |  | ||||||
| 
 | 
 | ||||||
| import asyncio | import asyncio | ||||||
| from asyncio import test_utils | from asyncio import test_utils | ||||||
|  | @ -1358,7 +1356,7 @@ def _run_loop(self, loop): | ||||||
|     def _check_success(self, **kwargs): |     def _check_success(self, **kwargs): | ||||||
|         a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)] |         a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)] | ||||||
|         fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs) |         fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs) | ||||||
|         cb = Mock() |         cb = test_utils.MockCallback() | ||||||
|         fut.add_done_callback(cb) |         fut.add_done_callback(cb) | ||||||
|         b.set_result(1) |         b.set_result(1) | ||||||
|         a.set_result(2) |         a.set_result(2) | ||||||
|  | @ -1380,7 +1378,7 @@ def test_result_exception_success(self): | ||||||
|     def test_one_exception(self): |     def test_one_exception(self): | ||||||
|         a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] |         a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] | ||||||
|         fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e)) |         fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e)) | ||||||
|         cb = Mock() |         cb = test_utils.MockCallback() | ||||||
|         fut.add_done_callback(cb) |         fut.add_done_callback(cb) | ||||||
|         exc = ZeroDivisionError() |         exc = ZeroDivisionError() | ||||||
|         a.set_result(1) |         a.set_result(1) | ||||||
|  | @ -1399,7 +1397,7 @@ def test_return_exceptions(self): | ||||||
|         a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)] |         a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)] | ||||||
|         fut = asyncio.gather(*self.wrap_futures(a, b, c, d), |         fut = asyncio.gather(*self.wrap_futures(a, b, c, d), | ||||||
|                              return_exceptions=True) |                              return_exceptions=True) | ||||||
|         cb = Mock() |         cb = test_utils.MockCallback() | ||||||
|         fut.add_done_callback(cb) |         fut.add_done_callback(cb) | ||||||
|         exc = ZeroDivisionError() |         exc = ZeroDivisionError() | ||||||
|         exc2 = RuntimeError() |         exc2 = RuntimeError() | ||||||
|  | @ -1460,7 +1458,7 @@ def test_constructor_homogenous_futures(self): | ||||||
|     def test_one_cancellation(self): |     def test_one_cancellation(self): | ||||||
|         a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] |         a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] | ||||||
|         fut = asyncio.gather(a, b, c, d, e) |         fut = asyncio.gather(a, b, c, d, e) | ||||||
|         cb = Mock() |         cb = test_utils.MockCallback() | ||||||
|         fut.add_done_callback(cb) |         fut.add_done_callback(cb) | ||||||
|         a.set_result(1) |         a.set_result(1) | ||||||
|         b.cancel() |         b.cancel() | ||||||
|  | @ -1479,7 +1477,7 @@ def test_result_exception_one_cancellation(self): | ||||||
|         a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop) |         a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop) | ||||||
|                             for i in range(6)] |                             for i in range(6)] | ||||||
|         fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True) |         fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True) | ||||||
|         cb = Mock() |         cb = test_utils.MockCallback() | ||||||
|         fut.add_done_callback(cb) |         fut.add_done_callback(cb) | ||||||
|         a.set_result(1) |         a.set_result(1) | ||||||
|         zde = ZeroDivisionError() |         zde = ZeroDivisionError() | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Victor Stinner
						Victor Stinner