mirror of
				https://github.com/python/cpython.git
				synced 2025-10-25 18:54:53 +00:00 
			
		
		
		
	Misc asyncio improvements from upstream (merge 3.5->3.6)
This commit is contained in:
		
						commit
						0035be3fee
					
				
					 8 changed files with 164 additions and 42 deletions
				
			
		|  | @ -115,24 +115,16 @@ def _ipaddr_info(host, port, family, type, proto): | ||||||
| 
 | 
 | ||||||
|     if port is None: |     if port is None: | ||||||
|         port = 0 |         port = 0 | ||||||
|     elif isinstance(port, bytes): |     elif isinstance(port, bytes) and port == b'': | ||||||
|         if port == b'': |         port = 0 | ||||||
|  |     elif isinstance(port, str) and port == '': | ||||||
|         port = 0 |         port = 0 | ||||||
|     else: |     else: | ||||||
|  |         # If port's a service name like "http", don't skip getaddrinfo. | ||||||
|         try: |         try: | ||||||
|             port = int(port) |             port = int(port) | ||||||
|             except ValueError: |         except (TypeError, ValueError): | ||||||
|                 # Might be a service name like b"http". |             return None | ||||||
|                 port = socket.getservbyname(port.decode('ascii')) |  | ||||||
|     elif isinstance(port, str): |  | ||||||
|         if port == '': |  | ||||||
|             port = 0 |  | ||||||
|         else: |  | ||||||
|             try: |  | ||||||
|                 port = int(port) |  | ||||||
|             except ValueError: |  | ||||||
|                 # Might be a service name like "http". |  | ||||||
|                 port = socket.getservbyname(port) |  | ||||||
| 
 | 
 | ||||||
|     if family == socket.AF_UNSPEC: |     if family == socket.AF_UNSPEC: | ||||||
|         afs = [socket.AF_INET, socket.AF_INET6] |         afs = [socket.AF_INET, socket.AF_INET6] | ||||||
|  |  | ||||||
|  | @ -3,7 +3,6 @@ | ||||||
| import warnings | import warnings | ||||||
| 
 | 
 | ||||||
| from . import compat | from . import compat | ||||||
| from . import futures |  | ||||||
| from . import protocols | from . import protocols | ||||||
| from . import transports | from . import transports | ||||||
| from .coroutines import coroutine | from .coroutines import coroutine | ||||||
|  |  | ||||||
|  | @ -120,8 +120,8 @@ def send(self, *value): | ||||||
|         def send(self, value): |         def send(self, value): | ||||||
|             return self.gen.send(value) |             return self.gen.send(value) | ||||||
| 
 | 
 | ||||||
|     def throw(self, exc): |     def throw(self, type, value=None, traceback=None): | ||||||
|         return self.gen.throw(exc) |         return self.gen.throw(type, value, traceback) | ||||||
| 
 | 
 | ||||||
|     def close(self): |     def close(self): | ||||||
|         return self.gen.close() |         return self.gen.close() | ||||||
|  |  | ||||||
|  | @ -7,7 +7,6 @@ | ||||||
| 
 | 
 | ||||||
| from . import compat | from . import compat | ||||||
| from . import events | from . import events | ||||||
| from . import futures |  | ||||||
| from . import locks | from . import locks | ||||||
| from .coroutines import coroutine | from .coroutines import coroutine | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -594,6 +594,10 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): | ||||||
|     """Return a future aggregating results from the given coroutines |     """Return a future aggregating results from the given coroutines | ||||||
|     or futures. |     or futures. | ||||||
| 
 | 
 | ||||||
|  |     Coroutines will be wrapped in a future and scheduled in the event | ||||||
|  |     loop. They will not necessarily be scheduled in the same order as | ||||||
|  |     passed in. | ||||||
|  | 
 | ||||||
|     All futures must share the same event loop.  If all the tasks are |     All futures must share the same event loop.  If all the tasks are | ||||||
|     done successfully, the returned future's result is the list of |     done successfully, the returned future's result is the list of | ||||||
|     results (in the order of the original sequence, not necessarily |     results (in the order of the original sequence, not necessarily | ||||||
|  |  | ||||||
|  | @ -142,26 +142,6 @@ def test_port_parameter_types(self): | ||||||
|             (INET, STREAM, TCP, '', ('1.2.3.4', 1)), |             (INET, STREAM, TCP, '', ('1.2.3.4', 1)), | ||||||
|             base_events._ipaddr_info('1.2.3.4', b'1', INET, STREAM, TCP)) |             base_events._ipaddr_info('1.2.3.4', b'1', INET, STREAM, TCP)) | ||||||
| 
 | 
 | ||||||
|     def test_getaddrinfo_servname(self): |  | ||||||
|         INET = socket.AF_INET |  | ||||||
|         STREAM = socket.SOCK_STREAM |  | ||||||
|         TCP = socket.IPPROTO_TCP |  | ||||||
| 
 |  | ||||||
|         self.assertEqual( |  | ||||||
|             (INET, STREAM, TCP, '', ('1.2.3.4', 80)), |  | ||||||
|             base_events._ipaddr_info('1.2.3.4', 'http', INET, STREAM, TCP)) |  | ||||||
| 
 |  | ||||||
|         self.assertEqual( |  | ||||||
|             (INET, STREAM, TCP, '', ('1.2.3.4', 80)), |  | ||||||
|             base_events._ipaddr_info('1.2.3.4', b'http', INET, STREAM, TCP)) |  | ||||||
| 
 |  | ||||||
|         # Raises "service/proto not found". |  | ||||||
|         with self.assertRaises(OSError): |  | ||||||
|             base_events._ipaddr_info('1.2.3.4', 'nonsense', INET, STREAM, TCP) |  | ||||||
| 
 |  | ||||||
|         with self.assertRaises(OSError): |  | ||||||
|             base_events._ipaddr_info('1.2.3.4', 'nonsense', INET, STREAM, TCP) |  | ||||||
| 
 |  | ||||||
|     @patch_socket |     @patch_socket | ||||||
|     def test_ipaddr_info_no_inet_pton(self, m_socket): |     def test_ipaddr_info_no_inet_pton(self, m_socket): | ||||||
|         del m_socket.inet_pton |         del m_socket.inet_pton | ||||||
|  | @ -1209,6 +1189,37 @@ def test_create_connection_ip_addr(self, m_socket): | ||||||
|     def test_create_connection_no_inet_pton(self, m_socket): |     def test_create_connection_no_inet_pton(self, m_socket): | ||||||
|         self._test_create_connection_ip_addr(m_socket, False) |         self._test_create_connection_ip_addr(m_socket, False) | ||||||
| 
 | 
 | ||||||
|  |     @patch_socket | ||||||
|  |     def test_create_connection_service_name(self, m_socket): | ||||||
|  |         m_socket.getaddrinfo = socket.getaddrinfo | ||||||
|  |         sock = m_socket.socket.return_value | ||||||
|  | 
 | ||||||
|  |         self.loop.add_reader = mock.Mock() | ||||||
|  |         self.loop.add_reader._is_coroutine = False | ||||||
|  |         self.loop.add_writer = mock.Mock() | ||||||
|  |         self.loop.add_writer._is_coroutine = False | ||||||
|  | 
 | ||||||
|  |         for service, port in ('http', 80), (b'http', 80): | ||||||
|  |             coro = self.loop.create_connection(asyncio.Protocol, | ||||||
|  |                                                '127.0.0.1', service) | ||||||
|  | 
 | ||||||
|  |             t, p = self.loop.run_until_complete(coro) | ||||||
|  |             try: | ||||||
|  |                 sock.connect.assert_called_with(('127.0.0.1', port)) | ||||||
|  |                 _, kwargs = m_socket.socket.call_args | ||||||
|  |                 self.assertEqual(kwargs['family'], m_socket.AF_INET) | ||||||
|  |                 self.assertEqual(kwargs['type'], m_socket.SOCK_STREAM) | ||||||
|  |             finally: | ||||||
|  |                 t.close() | ||||||
|  |                 test_utils.run_briefly(self.loop)  # allow transport to close | ||||||
|  | 
 | ||||||
|  |         for service in 'nonsense', b'nonsense': | ||||||
|  |             coro = self.loop.create_connection(asyncio.Protocol, | ||||||
|  |                                                '127.0.0.1', service) | ||||||
|  | 
 | ||||||
|  |             with self.assertRaises(OSError): | ||||||
|  |                 self.loop.run_until_complete(coro) | ||||||
|  | 
 | ||||||
|     def test_create_connection_no_local_addr(self): |     def test_create_connection_no_local_addr(self): | ||||||
|         @asyncio.coroutine |         @asyncio.coroutine | ||||||
|         def getaddrinfo(host, *args, **kw): |         def getaddrinfo(host, *args, **kw): | ||||||
|  |  | ||||||
|  | @ -2,6 +2,8 @@ | ||||||
| 
 | 
 | ||||||
| import errno | import errno | ||||||
| import socket | import socket | ||||||
|  | import threading | ||||||
|  | import time | ||||||
| import unittest | import unittest | ||||||
| from unittest import mock | from unittest import mock | ||||||
| try: | try: | ||||||
|  | @ -1784,5 +1786,89 @@ def test_fatal_error_connected(self, m_exc): | ||||||
|                 'Fatal error on transport\nprotocol:.*\ntransport:.*'), |                 'Fatal error on transport\nprotocol:.*\ntransport:.*'), | ||||||
|             exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY)) |             exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY)) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | class SelectorLoopFunctionalTests(unittest.TestCase): | ||||||
|  | 
 | ||||||
|  |     def setUp(self): | ||||||
|  |         self.loop = asyncio.new_event_loop() | ||||||
|  |         asyncio.set_event_loop(None) | ||||||
|  | 
 | ||||||
|  |     def tearDown(self): | ||||||
|  |         self.loop.close() | ||||||
|  | 
 | ||||||
|  |     @asyncio.coroutine | ||||||
|  |     def recv_all(self, sock, nbytes): | ||||||
|  |         buf = b'' | ||||||
|  |         while len(buf) < nbytes: | ||||||
|  |             buf += yield from self.loop.sock_recv(sock, nbytes - len(buf)) | ||||||
|  |         return buf | ||||||
|  | 
 | ||||||
|  |     def test_sock_connect_sock_write_race(self): | ||||||
|  |         TIMEOUT = 3.0 | ||||||
|  |         PAYLOAD = b'DATA' * 1024 * 1024 | ||||||
|  | 
 | ||||||
|  |         class Server(threading.Thread): | ||||||
|  |             def __init__(self, *args, srv_sock, **kwargs): | ||||||
|  |                 super().__init__(*args, **kwargs) | ||||||
|  |                 self.srv_sock = srv_sock | ||||||
|  | 
 | ||||||
|  |             def run(self): | ||||||
|  |                 with self.srv_sock: | ||||||
|  |                     srv_sock.listen(100) | ||||||
|  | 
 | ||||||
|  |                     sock, addr = self.srv_sock.accept() | ||||||
|  |                     sock.settimeout(TIMEOUT) | ||||||
|  | 
 | ||||||
|  |                     with sock: | ||||||
|  |                         sock.sendall(b'helo') | ||||||
|  | 
 | ||||||
|  |                         buf = bytearray() | ||||||
|  |                         while len(buf) < len(PAYLOAD): | ||||||
|  |                             pack = sock.recv(1024 * 65) | ||||||
|  |                             if not pack: | ||||||
|  |                                 break | ||||||
|  |                             buf.extend(pack) | ||||||
|  | 
 | ||||||
|  |         @asyncio.coroutine | ||||||
|  |         def client(addr): | ||||||
|  |             sock = socket.socket() | ||||||
|  |             with sock: | ||||||
|  |                 sock.setblocking(False) | ||||||
|  | 
 | ||||||
|  |                 started = time.monotonic() | ||||||
|  |                 while True: | ||||||
|  |                     if time.monotonic() - started > TIMEOUT: | ||||||
|  |                         self.fail('unable to connect to the socket') | ||||||
|  |                         return | ||||||
|  |                     try: | ||||||
|  |                         yield from self.loop.sock_connect(sock, addr) | ||||||
|  |                     except OSError: | ||||||
|  |                         yield from asyncio.sleep(0.05, loop=self.loop) | ||||||
|  |                     else: | ||||||
|  |                         break | ||||||
|  | 
 | ||||||
|  |                 # Give 'Server' thread a chance to accept and send b'helo' | ||||||
|  |                 time.sleep(0.1) | ||||||
|  | 
 | ||||||
|  |                 data = yield from self.recv_all(sock, 4) | ||||||
|  |                 self.assertEqual(data, b'helo') | ||||||
|  |                 yield from self.loop.sock_sendall(sock, PAYLOAD) | ||||||
|  | 
 | ||||||
|  |         srv_sock = socket.socket() | ||||||
|  |         srv_sock.settimeout(TIMEOUT) | ||||||
|  |         srv_sock.bind(('127.0.0.1', 0)) | ||||||
|  |         srv_addr = srv_sock.getsockname() | ||||||
|  | 
 | ||||||
|  |         srv = Server(srv_sock=srv_sock, daemon=True) | ||||||
|  |         srv.start() | ||||||
|  | 
 | ||||||
|  |         try: | ||||||
|  |             self.loop.run_until_complete( | ||||||
|  |                 asyncio.wait_for(client(srv_addr), loop=self.loop, | ||||||
|  |                                  timeout=TIMEOUT)) | ||||||
|  |         finally: | ||||||
|  |             srv.join() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|  |  | ||||||
|  | @ -1723,6 +1723,37 @@ def foo(): yield from [] | ||||||
|         wd['cw'] = cw  # Would fail without __weakref__ slot. |         wd['cw'] = cw  # Would fail without __weakref__ slot. | ||||||
|         cw.gen = None  # Suppress warning from __del__. |         cw.gen = None  # Suppress warning from __del__. | ||||||
| 
 | 
 | ||||||
|  |     def test_corowrapper_throw(self): | ||||||
|  |         # Issue 429: CoroWrapper.throw must be compatible with gen.throw | ||||||
|  |         def foo(): | ||||||
|  |             value = None | ||||||
|  |             while True: | ||||||
|  |                 try: | ||||||
|  |                     value = yield value | ||||||
|  |                 except Exception as e: | ||||||
|  |                     value = e | ||||||
|  | 
 | ||||||
|  |         exception = Exception("foo") | ||||||
|  |         cw = asyncio.coroutines.CoroWrapper(foo()) | ||||||
|  |         cw.send(None) | ||||||
|  |         self.assertIs(exception, cw.throw(exception)) | ||||||
|  | 
 | ||||||
|  |         cw = asyncio.coroutines.CoroWrapper(foo()) | ||||||
|  |         cw.send(None) | ||||||
|  |         self.assertIs(exception, cw.throw(Exception, exception)) | ||||||
|  | 
 | ||||||
|  |         cw = asyncio.coroutines.CoroWrapper(foo()) | ||||||
|  |         cw.send(None) | ||||||
|  |         exception = cw.throw(Exception, "foo") | ||||||
|  |         self.assertIsInstance(exception, Exception) | ||||||
|  |         self.assertEqual(exception.args, ("foo", )) | ||||||
|  | 
 | ||||||
|  |         cw = asyncio.coroutines.CoroWrapper(foo()) | ||||||
|  |         cw.send(None) | ||||||
|  |         exception = cw.throw(Exception, "foo", None) | ||||||
|  |         self.assertIsInstance(exception, Exception) | ||||||
|  |         self.assertEqual(exception.args, ("foo", )) | ||||||
|  | 
 | ||||||
|     @unittest.skipUnless(PY34, |     @unittest.skipUnless(PY34, | ||||||
|                          'need python 3.4 or later') |                          'need python 3.4 or later') | ||||||
|     def test_log_destroyed_pending_task(self): |     def test_log_destroyed_pending_task(self): | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Guido van Rossum
						Guido van Rossum