mirror of
				https://github.com/python/cpython.git
				synced 2025-10-28 12:15:13 +00:00 
			
		
		
		
	 dce30c9cbc
			
		
	
	
		dce30c9cbc
		
			
		
	
	
	
	
		
			
			Some of the asyncio SSL changes in GH-31275 [1] were taken from v0.16.0 of the uvloop project [2]. In order to comply with the MIT license, we just need to document the copyright information. [1]: https://github.com/python/cpython/pull/31275 [2]: https://github.com/MagicStack/uvloop/tree/v0.16.0
		
			
				
	
	
		
			1742 lines
		
	
	
	
		
			53 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1742 lines
		
	
	
	
		
			53 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Contains code from https://github.com/MagicStack/uvloop/tree/v0.16.0
 | |
| # SPDX-License-Identifier: PSF-2.0 AND (MIT OR Apache-2.0)
 | |
| # SPDX-FileCopyrightText: Copyright (c) 2015-2021 MagicStack Inc.  http://magic.io
 | |
| 
 | |
| import asyncio
 | |
| import contextlib
 | |
| import gc
 | |
| import logging
 | |
| import select
 | |
| import socket
 | |
| import sys
 | |
| import tempfile
 | |
| import threading
 | |
| import time
 | |
| import weakref
 | |
| import unittest
 | |
| 
 | |
| try:
 | |
|     import ssl
 | |
| except ImportError:
 | |
|     ssl = None
 | |
| 
 | |
| from test import support
 | |
| from test.test_asyncio import utils as test_utils
 | |
| 
 | |
| 
 | |
# The test payloads are b'A' * 1024 * BUF_MULTIPLIER; macOS uses the
# smaller multiplier.
MACOS = sys.platform == 'darwin'
BUF_MULTIPLIER = 64 if MACOS else 1024
 | |
| 
 | |
| 
 | |
def tearDownModule():
    """Reset the global asyncio event loop policy after the module's tests."""
    asyncio.set_event_loop_policy(None)
 | |
| 
 | |
| 
 | |
class MyBaseProto(asyncio.Protocol):
    """Protocol that tracks its lifecycle state and counts received bytes.

    ``state`` moves INITIAL -> CONNECTED -> (EOF) -> CLOSED.  When a loop is
    passed to the constructor, ``connected`` and ``done`` are futures on that
    loop, resolved from ``connection_made()`` and ``connection_lost()``
    respectively; otherwise both remain ``None``.
    """

    # Class-level defaults used when no loop is given to __init__.
    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is None:
            return
        self.connected = asyncio.Future(loop=loop)
        self.done = asyncio.Future(loop=loop)

    def connection_made(self, transport):
        """Remember the transport and move to the CONNECTED state."""
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'
        waiter = self.connected
        if waiter:
            waiter.set_result(None)

    def data_received(self, data):
        """Add the size of *data* to the running byte count."""
        assert self.state == 'CONNECTED', self.state
        self.nbytes = self.nbytes + len(data)

    def eof_received(self):
        """Move to the EOF state."""
        assert self.state == 'CONNECTED', self.state
        self.state = 'EOF'

    def connection_lost(self, exc):
        """Move to the CLOSED state and resolve ``done`` if it exists."""
        assert self.state in ('CONNECTED', 'EOF'), self.state
        self.state = 'CLOSED'
        waiter = self.done
        if waiter:
            waiter.set_result(None)
 | |
| 
 | |
| 
 | |
class MessageOutFilter(logging.Filter):
    """Logging filter that drops records whose message contains *msg*."""

    def __init__(self, msg):
        # Initialise the base Filter so its own attributes (e.g. ``name``)
        # exist on the instance.
        super().__init__()
        self.msg = msg

    def filter(self, record):
        """Return False to drop *record*, True to let it through."""
        # record.msg may be an arbitrary object (e.g. an exception), so
        # coerce it to str before the substring test.
        return self.msg not in str(record.msg)
 | |
| 
 | |
| 
 | |
| @unittest.skipIf(ssl is None, 'No ssl module')
 | |
| class TestSSL(test_utils.TestCase):
 | |
| 
 | |
|     PAYLOAD_SIZE = 1024 * 100
 | |
|     TIMEOUT = support.LONG_TIMEOUT
 | |
| 
 | |
|     def setUp(self):
 | |
|         super().setUp()
 | |
|         self.loop = asyncio.new_event_loop()
 | |
|         self.set_event_loop(self.loop)
 | |
|         self.addCleanup(self.loop.close)
 | |
| 
 | |
|     def tearDown(self):
 | |
|         # just in case if we have transport close callbacks
 | |
|         if not self.loop.is_closed():
 | |
|             test_utils.run_briefly(self.loop)
 | |
| 
 | |
|         self.doCleanups()
 | |
|         support.gc_collect()
 | |
|         super().tearDown()
 | |
| 
 | |
|     def tcp_server(self, server_prog, *,
 | |
|                    family=socket.AF_INET,
 | |
|                    addr=None,
 | |
|                    timeout=support.SHORT_TIMEOUT,
 | |
|                    backlog=1,
 | |
|                    max_clients=10):
 | |
| 
 | |
|         if addr is None:
 | |
|             if family == getattr(socket, "AF_UNIX", None):
 | |
|                 with tempfile.NamedTemporaryFile() as tmp:
 | |
|                     addr = tmp.name
 | |
|             else:
 | |
|                 addr = ('127.0.0.1', 0)
 | |
| 
 | |
|         sock = socket.socket(family, socket.SOCK_STREAM)
 | |
| 
 | |
|         if timeout is None:
 | |
|             raise RuntimeError('timeout is required')
 | |
|         if timeout <= 0:
 | |
|             raise RuntimeError('only blocking sockets are supported')
 | |
|         sock.settimeout(timeout)
 | |
| 
 | |
|         try:
 | |
|             sock.bind(addr)
 | |
|             sock.listen(backlog)
 | |
|         except OSError as ex:
 | |
|             sock.close()
 | |
|             raise ex
 | |
| 
 | |
|         return TestThreadedServer(
 | |
|             self, sock, server_prog, timeout, max_clients)
 | |
| 
 | |
|     def tcp_client(self, client_prog,
 | |
|                    family=socket.AF_INET,
 | |
|                    timeout=support.SHORT_TIMEOUT):
 | |
| 
 | |
|         sock = socket.socket(family, socket.SOCK_STREAM)
 | |
| 
 | |
|         if timeout is None:
 | |
|             raise RuntimeError('timeout is required')
 | |
|         if timeout <= 0:
 | |
|             raise RuntimeError('only blocking sockets are supported')
 | |
|         sock.settimeout(timeout)
 | |
| 
 | |
|         return TestThreadedClient(
 | |
|             self, sock, client_prog, timeout)
 | |
| 
 | |
|     def unix_server(self, *args, **kwargs):
 | |
|         return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
 | |
| 
 | |
|     def unix_client(self, *args, **kwargs):
 | |
|         return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
 | |
| 
 | |
|     def _create_server_ssl_context(self, certfile, keyfile=None):
 | |
|         sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
 | |
|         sslcontext.options |= ssl.OP_NO_SSLv2
 | |
|         sslcontext.load_cert_chain(certfile, keyfile)
 | |
|         return sslcontext
 | |
| 
 | |
|     def _create_client_ssl_context(self, *, disable_verify=True):
 | |
|         sslcontext = ssl.create_default_context()
 | |
|         sslcontext.check_hostname = False
 | |
|         if disable_verify:
 | |
|             sslcontext.verify_mode = ssl.CERT_NONE
 | |
|         return sslcontext
 | |
| 
 | |
|     @contextlib.contextmanager
 | |
|     def _silence_eof_received_warning(self):
 | |
|         # TODO This warning has to be fixed in asyncio.
 | |
|         logger = logging.getLogger('asyncio')
 | |
|         filter = MessageOutFilter('has no effect when using ssl')
 | |
|         logger.addFilter(filter)
 | |
|         try:
 | |
|             yield
 | |
|         finally:
 | |
|             logger.removeFilter(filter)
 | |
| 
 | |
|     def _abort_socket_test(self, ex):
 | |
|         try:
 | |
|             self.loop.stop()
 | |
|         finally:
 | |
|             self.fail(ex)
 | |
| 
 | |
|     def new_loop(self):
 | |
|         return asyncio.new_event_loop()
 | |
| 
 | |
|     def new_policy(self):
 | |
|         return asyncio.DefaultEventLoopPolicy()
 | |
| 
 | |
|     async def wait_closed(self, obj):
 | |
|         if not isinstance(obj, asyncio.StreamWriter):
 | |
|             return
 | |
|         try:
 | |
|             await obj.wait_closed()
 | |
|         except (BrokenPipeError, ConnectionError):
 | |
|             pass
 | |
| 
 | |
|     def test_create_server_ssl_1(self):
 | |
|         CNT = 0           # number of clients that were successful
 | |
|         TOTAL_CNT = 25    # total number of clients that test will create
 | |
|         TIMEOUT = support.LONG_TIMEOUT  # timeout for this test
 | |
| 
 | |
|         A_DATA = b'A' * 1024 * BUF_MULTIPLIER
 | |
|         B_DATA = b'B' * 1024 * BUF_MULTIPLIER
 | |
| 
 | |
|         sslctx = self._create_server_ssl_context(
 | |
|             test_utils.ONLYCERT, test_utils.ONLYKEY
 | |
|         )
 | |
|         client_sslctx = self._create_client_ssl_context()
 | |
| 
 | |
|         clients = []
 | |
| 
 | |
|         async def handle_client(reader, writer):
 | |
|             nonlocal CNT
 | |
| 
 | |
|             data = await reader.readexactly(len(A_DATA))
 | |
|             self.assertEqual(data, A_DATA)
 | |
|             writer.write(b'OK')
 | |
| 
 | |
|             data = await reader.readexactly(len(B_DATA))
 | |
|             self.assertEqual(data, B_DATA)
 | |
|             writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
 | |
| 
 | |
|             await writer.drain()
 | |
|             writer.close()
 | |
| 
 | |
|             CNT += 1
 | |
| 
 | |
|         async def test_client(addr):
 | |
|             fut = asyncio.Future()
 | |
| 
 | |
|             def prog(sock):
 | |
|                 try:
 | |
|                     sock.starttls(client_sslctx)
 | |
|                     sock.connect(addr)
 | |
|                     sock.send(A_DATA)
 | |
| 
 | |
|                     data = sock.recv_all(2)
 | |
|                     self.assertEqual(data, b'OK')
 | |
| 
 | |
|                     sock.send(B_DATA)
 | |
|                     data = sock.recv_all(4)
 | |
|                     self.assertEqual(data, b'SPAM')
 | |
| 
 | |
|                     sock.close()
 | |
| 
 | |
|                 except Exception as ex:
 | |
|                     self.loop.call_soon_threadsafe(fut.set_exception, ex)
 | |
|                 else:
 | |
|                     self.loop.call_soon_threadsafe(fut.set_result, None)
 | |
| 
 | |
|             client = self.tcp_client(prog)
 | |
|             client.start()
 | |
|             clients.append(client)
 | |
| 
 | |
|             await fut
 | |
| 
 | |
|         async def start_server():
 | |
|             extras = {}
 | |
|             extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
 | |
| 
 | |
|             srv = await asyncio.start_server(
 | |
|                 handle_client,
 | |
|                 '127.0.0.1', 0,
 | |
|                 family=socket.AF_INET,
 | |
|                 ssl=sslctx,
 | |
|                 **extras)
 | |
| 
 | |
|             try:
 | |
|                 srv_socks = srv.sockets
 | |
|                 self.assertTrue(srv_socks)
 | |
| 
 | |
|                 addr = srv_socks[0].getsockname()
 | |
| 
 | |
|                 tasks = []
 | |
|                 for _ in range(TOTAL_CNT):
 | |
|                     tasks.append(test_client(addr))
 | |
| 
 | |
|                 await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
 | |
| 
 | |
|             finally:
 | |
|                 self.loop.call_soon(srv.close)
 | |
|                 await srv.wait_closed()
 | |
| 
 | |
|         with self._silence_eof_received_warning():
 | |
|             self.loop.run_until_complete(start_server())
 | |
| 
 | |
|         self.assertEqual(CNT, TOTAL_CNT)
 | |
| 
 | |
|         for client in clients:
 | |
|             client.stop()
 | |
| 
 | |
|     def test_create_connection_ssl_1(self):
 | |
|         self.loop.set_exception_handler(None)
 | |
| 
 | |
|         CNT = 0
 | |
|         TOTAL_CNT = 25
 | |
| 
 | |
|         A_DATA = b'A' * 1024 * BUF_MULTIPLIER
 | |
|         B_DATA = b'B' * 1024 * BUF_MULTIPLIER
 | |
| 
 | |
|         sslctx = self._create_server_ssl_context(
 | |
|             test_utils.ONLYCERT,
 | |
|             test_utils.ONLYKEY
 | |
|         )
 | |
|         client_sslctx = self._create_client_ssl_context()
 | |
| 
 | |
|         def server(sock):
 | |
|             sock.starttls(
 | |
|                 sslctx,
 | |
|                 server_side=True)
 | |
| 
 | |
|             data = sock.recv_all(len(A_DATA))
 | |
|             self.assertEqual(data, A_DATA)
 | |
|             sock.send(b'OK')
 | |
| 
 | |
|             data = sock.recv_all(len(B_DATA))
 | |
|             self.assertEqual(data, B_DATA)
 | |
|             sock.send(b'SPAM')
 | |
| 
 | |
|             sock.close()
 | |
| 
 | |
|         async def client(addr):
 | |
|             extras = {}
 | |
|             extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
 | |
| 
 | |
|             reader, writer = await asyncio.open_connection(
 | |
|                 *addr,
 | |
|                 ssl=client_sslctx,
 | |
|                 server_hostname='',
 | |
|                 **extras)
 | |
| 
 | |
|             writer.write(A_DATA)
 | |
|             self.assertEqual(await reader.readexactly(2), b'OK')
 | |
| 
 | |
|             writer.write(B_DATA)
 | |
|             self.assertEqual(await reader.readexactly(4), b'SPAM')
 | |
| 
 | |
|             nonlocal CNT
 | |
|             CNT += 1
 | |
| 
 | |
|             writer.close()
 | |
|             await self.wait_closed(writer)
 | |
| 
 | |
|         async def client_sock(addr):
 | |
|             sock = socket.socket()
 | |
|             sock.connect(addr)
 | |
|             reader, writer = await asyncio.open_connection(
 | |
|                 sock=sock,
 | |
|                 ssl=client_sslctx,
 | |
|                 server_hostname='')
 | |
| 
 | |
|             writer.write(A_DATA)
 | |
|             self.assertEqual(await reader.readexactly(2), b'OK')
 | |
| 
 | |
|             writer.write(B_DATA)
 | |
|             self.assertEqual(await reader.readexactly(4), b'SPAM')
 | |
| 
 | |
|             nonlocal CNT
 | |
|             CNT += 1
 | |
| 
 | |
|             writer.close()
 | |
|             await self.wait_closed(writer)
 | |
|             sock.close()
 | |
| 
 | |
|         def run(coro):
 | |
|             nonlocal CNT
 | |
|             CNT = 0
 | |
| 
 | |
|             async def _gather(*tasks):
 | |
|                 # trampoline
 | |
|                 return await asyncio.gather(*tasks)
 | |
| 
 | |
|             with self.tcp_server(server,
 | |
|                                  max_clients=TOTAL_CNT,
 | |
|                                  backlog=TOTAL_CNT) as srv:
 | |
|                 tasks = []
 | |
|                 for _ in range(TOTAL_CNT):
 | |
|                     tasks.append(coro(srv.addr))
 | |
| 
 | |
|                 self.loop.run_until_complete(_gather(*tasks))
 | |
| 
 | |
|             self.assertEqual(CNT, TOTAL_CNT)
 | |
| 
 | |
|         with self._silence_eof_received_warning():
 | |
|             run(client)
 | |
| 
 | |
|         with self._silence_eof_received_warning():
 | |
|             run(client_sock)
 | |
| 
 | |
|     def test_create_connection_ssl_slow_handshake(self):
 | |
|         client_sslctx = self._create_client_ssl_context()
 | |
| 
 | |
|         # silence error logger
 | |
|         self.loop.set_exception_handler(lambda *args: None)
 | |
| 
 | |
|         def server(sock):
 | |
|             try:
 | |
|                 sock.recv_all(1024 * 1024)
 | |
|             except ConnectionAbortedError:
 | |
|                 pass
 | |
|             finally:
 | |
|                 sock.close()
 | |
| 
 | |
|         async def client(addr):
 | |
|             reader, writer = await asyncio.open_connection(
 | |
|                 *addr,
 | |
|                 ssl=client_sslctx,
 | |
|                 server_hostname='',
 | |
|                 ssl_handshake_timeout=1.0)
 | |
|             writer.close()
 | |
|             await self.wait_closed(writer)
 | |
| 
 | |
|         with self.tcp_server(server,
 | |
|                              max_clients=1,
 | |
|                              backlog=1) as srv:
 | |
| 
 | |
|             with self.assertRaisesRegex(
 | |
|                     ConnectionAbortedError,
 | |
|                     r'SSL handshake.*is taking longer'):
 | |
| 
 | |
|                 self.loop.run_until_complete(client(srv.addr))
 | |
| 
 | |
|     def test_create_connection_ssl_failed_certificate(self):
 | |
|         # silence error logger
 | |
|         self.loop.set_exception_handler(lambda *args: None)
 | |
| 
 | |
|         sslctx = self._create_server_ssl_context(
 | |
|             test_utils.ONLYCERT,
 | |
|             test_utils.ONLYKEY
 | |
|         )
 | |
|         client_sslctx = self._create_client_ssl_context(disable_verify=False)
 | |
| 
 | |
|         def server(sock):
 | |
|             try:
 | |
|                 sock.starttls(
 | |
|                     sslctx,
 | |
|                     server_side=True)
 | |
|                 sock.connect()
 | |
|             except (ssl.SSLError, OSError):
 | |
|                 pass
 | |
|             finally:
 | |
|                 sock.close()
 | |
| 
 | |
|         async def client(addr):
 | |
|             reader, writer = await asyncio.open_connection(
 | |
|                 *addr,
 | |
|                 ssl=client_sslctx,
 | |
|                 server_hostname='',
 | |
|                 ssl_handshake_timeout=support.SHORT_TIMEOUT)
 | |
|             writer.close()
 | |
|             await self.wait_closed(writer)
 | |
| 
 | |
|         with self.tcp_server(server,
 | |
|                              max_clients=1,
 | |
|                              backlog=1) as srv:
 | |
| 
 | |
|             with self.assertRaises(ssl.SSLCertVerificationError):
 | |
|                 self.loop.run_until_complete(client(srv.addr))
 | |
| 
 | |
|     def test_ssl_handshake_timeout(self):
 | |
|         # bpo-29970: Check that a connection is aborted if handshake is not
 | |
|         # completed in timeout period, instead of remaining open indefinitely
 | |
|         client_sslctx = test_utils.simple_client_sslcontext()
 | |
| 
 | |
|         # silence error logger
 | |
|         messages = []
 | |
|         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
 | |
| 
 | |
|         server_side_aborted = False
 | |
| 
 | |
|         def server(sock):
 | |
|             nonlocal server_side_aborted
 | |
|             try:
 | |
|                 sock.recv_all(1024 * 1024)
 | |
|             except ConnectionAbortedError:
 | |
|                 server_side_aborted = True
 | |
|             finally:
 | |
|                 sock.close()
 | |
| 
 | |
|         async def client(addr):
 | |
|             await asyncio.wait_for(
 | |
|                 self.loop.create_connection(
 | |
|                     asyncio.Protocol,
 | |
|                     *addr,
 | |
|                     ssl=client_sslctx,
 | |
|                     server_hostname='',
 | |
|                     ssl_handshake_timeout=10.0),
 | |
|                 0.5)
 | |
| 
 | |
|         with self.tcp_server(server,
 | |
|                              max_clients=1,
 | |
|                              backlog=1) as srv:
 | |
| 
 | |
|             with self.assertRaises(asyncio.TimeoutError):
 | |
|                 self.loop.run_until_complete(client(srv.addr))
 | |
| 
 | |
|         self.assertTrue(server_side_aborted)
 | |
| 
 | |
|         # Python issue #23197: cancelling a handshake must not raise an
 | |
|         # exception or log an error, even if the handshake failed
 | |
|         self.assertEqual(messages, [])
 | |
| 
 | |
|     def test_ssl_handshake_connection_lost(self):
 | |
|         # #246: make sure that no connection_lost() is called before
 | |
|         # connection_made() is called first
 | |
| 
 | |
|         client_sslctx = test_utils.simple_client_sslcontext()
 | |
| 
 | |
|         # silence error logger
 | |
|         self.loop.set_exception_handler(lambda loop, ctx: None)
 | |
| 
 | |
|         connection_made_called = False
 | |
|         connection_lost_called = False
 | |
| 
 | |
|         def server(sock):
 | |
|             sock.recv(1024)
 | |
|             # break the connection during handshake
 | |
|             sock.close()
 | |
| 
 | |
|         class ClientProto(asyncio.Protocol):
 | |
|             def connection_made(self, transport):
 | |
|                 nonlocal connection_made_called
 | |
|                 connection_made_called = True
 | |
| 
 | |
|             def connection_lost(self, exc):
 | |
|                 nonlocal connection_lost_called
 | |
|                 connection_lost_called = True
 | |
| 
 | |
|         async def client(addr):
 | |
|             await self.loop.create_connection(
 | |
|                 ClientProto,
 | |
|                 *addr,
 | |
|                 ssl=client_sslctx,
 | |
|                 server_hostname=''),
 | |
| 
 | |
|         with self.tcp_server(server,
 | |
|                              max_clients=1,
 | |
|                              backlog=1) as srv:
 | |
| 
 | |
|             with self.assertRaises(ConnectionResetError):
 | |
|                 self.loop.run_until_complete(client(srv.addr))
 | |
| 
 | |
|         if connection_lost_called:
 | |
|             if connection_made_called:
 | |
|                 self.fail("unexpected call to connection_lost()")
 | |
|             else:
 | |
|                 self.fail("unexpected call to connection_lost() without"
 | |
|                           "calling connection_made()")
 | |
|         elif connection_made_called:
 | |
|             self.fail("unexpected call to connection_made()")
 | |
| 
 | |
|     def test_ssl_connect_accepted_socket(self):
 | |
|         proto = ssl.PROTOCOL_TLS_SERVER
 | |
|         server_context = ssl.SSLContext(proto)
 | |
|         server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY)
 | |
|         if hasattr(server_context, 'check_hostname'):
 | |
|             server_context.check_hostname = False
 | |
|         server_context.verify_mode = ssl.CERT_NONE
 | |
| 
 | |
|         client_context = ssl.SSLContext(proto)
 | |
|         if hasattr(server_context, 'check_hostname'):
 | |
|             client_context.check_hostname = False
 | |
|         client_context.verify_mode = ssl.CERT_NONE
 | |
| 
 | |
|     def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
 | |
|         loop = self.loop
 | |
| 
 | |
|         class MyProto(MyBaseProto):
 | |
| 
 | |
|             def connection_lost(self, exc):
 | |
|                 super().connection_lost(exc)
 | |
|                 loop.call_soon(loop.stop)
 | |
| 
 | |
|             def data_received(self, data):
 | |
|                 super().data_received(data)
 | |
|                 self.transport.write(expected_response)
 | |
| 
 | |
|         lsock = socket.socket(socket.AF_INET)
 | |
|         lsock.bind(('127.0.0.1', 0))
 | |
|         lsock.listen(1)
 | |
|         addr = lsock.getsockname()
 | |
| 
 | |
|         message = b'test data'
 | |
|         response = None
 | |
|         expected_response = b'roger'
 | |
| 
 | |
|         def client():
 | |
|             nonlocal response
 | |
|             try:
 | |
|                 csock = socket.socket(socket.AF_INET)
 | |
|                 if client_ssl is not None:
 | |
|                     csock = client_ssl.wrap_socket(csock)
 | |
|                 csock.connect(addr)
 | |
|                 csock.sendall(message)
 | |
|                 response = csock.recv(99)
 | |
|                 csock.close()
 | |
|             except Exception as exc:
 | |
|                 print(
 | |
|                     "Failure in client thread in test_connect_accepted_socket",
 | |
|                     exc)
 | |
| 
 | |
|         thread = threading.Thread(target=client, daemon=True)
 | |
|         thread.start()
 | |
| 
 | |
|         conn, _ = lsock.accept()
 | |
|         proto = MyProto(loop=loop)
 | |
|         proto.loop = loop
 | |
| 
 | |
|         extras = {}
 | |
|         if server_ssl:
 | |
|             extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
 | |
| 
 | |
|         f = loop.create_task(
 | |
|             loop.connect_accepted_socket(
 | |
|                 (lambda: proto), conn, ssl=server_ssl,
 | |
|                 **extras))
 | |
|         loop.run_forever()
 | |
|         conn.close()
 | |
|         lsock.close()
 | |
| 
 | |
|         thread.join(1)
 | |
|         self.assertFalse(thread.is_alive())
 | |
|         self.assertEqual(proto.state, 'CLOSED')
 | |
|         self.assertEqual(proto.nbytes, len(message))
 | |
|         self.assertEqual(response, expected_response)
 | |
|         tr, _ = f.result()
 | |
| 
 | |
|         if server_ssl:
 | |
|             self.assertIn('SSL', tr.__class__.__name__)
 | |
| 
 | |
|         tr.close()
 | |
|         # let it close
 | |
|         self.loop.run_until_complete(asyncio.sleep(0.1))
 | |
| 
 | |
|     def test_start_tls_client_corrupted_ssl(self):
 | |
|         self.loop.set_exception_handler(lambda loop, ctx: None)
 | |
| 
 | |
|         sslctx = test_utils.simple_server_sslcontext()
 | |
|         client_sslctx = test_utils.simple_client_sslcontext()
 | |
| 
 | |
|         def server(sock):
 | |
|             orig_sock = sock.dup()
 | |
|             try:
 | |
|                 sock.starttls(
 | |
|                     sslctx,
 | |
|                     server_side=True)
 | |
|                 sock.sendall(b'A\n')
 | |
|                 sock.recv_all(1)
 | |
|                 orig_sock.send(b'please corrupt the SSL connection')
 | |
|             except ssl.SSLError:
 | |
|                 pass
 | |
|             finally:
 | |
|                 sock.close()
 | |
|                 orig_sock.close()
 | |
| 
 | |
|         async def client(addr):
 | |
|             reader, writer = await asyncio.open_connection(
 | |
|                 *addr,
 | |
|                 ssl=client_sslctx,
 | |
|                 server_hostname='')
 | |
| 
 | |
|             self.assertEqual(await reader.readline(), b'A\n')
 | |
|             writer.write(b'B')
 | |
|             with self.assertRaises(ssl.SSLError):
 | |
|                 await reader.readline()
 | |
|             writer.close()
 | |
|             try:
 | |
|                 await self.wait_closed(writer)
 | |
|             except ssl.SSLError:
 | |
|                 pass
 | |
|             return 'OK'
 | |
| 
 | |
|         with self.tcp_server(server,
 | |
|                              max_clients=1,
 | |
|                              backlog=1) as srv:
 | |
| 
 | |
|             res = self.loop.run_until_complete(client(srv.addr))
 | |
| 
 | |
|         self.assertEqual(res, 'OK')
 | |
| 
 | |
|     def test_start_tls_client_reg_proto_1(self):
 | |
|         HELLO_MSG = b'1' * self.PAYLOAD_SIZE
 | |
| 
 | |
|         server_context = test_utils.simple_server_sslcontext()
 | |
|         client_context = test_utils.simple_client_sslcontext()
 | |
| 
 | |
|         def serve(sock):
 | |
|             sock.settimeout(self.TIMEOUT)
 | |
| 
 | |
|             data = sock.recv_all(len(HELLO_MSG))
 | |
|             self.assertEqual(len(data), len(HELLO_MSG))
 | |
| 
 | |
|             sock.starttls(server_context, server_side=True)
 | |
| 
 | |
|             sock.sendall(b'O')
 | |
|             data = sock.recv_all(len(HELLO_MSG))
 | |
|             self.assertEqual(len(data), len(HELLO_MSG))
 | |
| 
 | |
|             sock.unwrap()
 | |
|             sock.close()
 | |
| 
 | |
|         class ClientProto(asyncio.Protocol):
 | |
|             def __init__(self, on_data, on_eof):
 | |
|                 self.on_data = on_data
 | |
|                 self.on_eof = on_eof
 | |
|                 self.con_made_cnt = 0
 | |
| 
 | |
|             def connection_made(proto, tr):
 | |
|                 proto.con_made_cnt += 1
 | |
|                 # Ensure connection_made gets called only once.
 | |
|                 self.assertEqual(proto.con_made_cnt, 1)
 | |
| 
 | |
|             def data_received(self, data):
 | |
|                 self.on_data.set_result(data)
 | |
| 
 | |
|             def eof_received(self):
 | |
|                 self.on_eof.set_result(True)
 | |
| 
 | |
|         async def client(addr):
 | |
|             await asyncio.sleep(0.5)
 | |
| 
 | |
|             on_data = self.loop.create_future()
 | |
|             on_eof = self.loop.create_future()
 | |
| 
 | |
|             tr, proto = await self.loop.create_connection(
 | |
|                 lambda: ClientProto(on_data, on_eof), *addr)
 | |
| 
 | |
|             tr.write(HELLO_MSG)
 | |
|             new_tr = await self.loop.start_tls(tr, proto, client_context)
 | |
| 
 | |
|             self.assertEqual(await on_data, b'O')
 | |
|             new_tr.write(HELLO_MSG)
 | |
|             await on_eof
 | |
| 
 | |
|             new_tr.close()
 | |
| 
 | |
|         with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
 | |
|             self.loop.run_until_complete(
 | |
|                 asyncio.wait_for(client(srv.addr),
 | |
|                                  timeout=support.SHORT_TIMEOUT))
 | |
| 
 | |
|     def test_create_connection_memory_leak(self):
 | |
|         HELLO_MSG = b'1' * self.PAYLOAD_SIZE
 | |
| 
 | |
|         server_context = self._create_server_ssl_context(
 | |
|             test_utils.ONLYCERT, test_utils.ONLYKEY)
 | |
|         client_context = self._create_client_ssl_context()
 | |
| 
 | |
|         def serve(sock):
 | |
|             sock.settimeout(self.TIMEOUT)
 | |
| 
 | |
|             sock.starttls(server_context, server_side=True)
 | |
| 
 | |
|             sock.sendall(b'O')
 | |
|             data = sock.recv_all(len(HELLO_MSG))
 | |
|             self.assertEqual(len(data), len(HELLO_MSG))
 | |
| 
 | |
|             sock.unwrap()
 | |
|             sock.close()
 | |
| 
 | |
|         class ClientProto(asyncio.Protocol):
 | |
|             def __init__(self, on_data, on_eof):
 | |
|                 self.on_data = on_data
 | |
|                 self.on_eof = on_eof
 | |
|                 self.con_made_cnt = 0
 | |
| 
 | |
|             def connection_made(proto, tr):
 | |
|                 # XXX: We assume user stores the transport in protocol
 | |
|                 proto.tr = tr
 | |
|                 proto.con_made_cnt += 1
 | |
|                 # Ensure connection_made gets called only once.
 | |
|                 self.assertEqual(proto.con_made_cnt, 1)
 | |
| 
 | |
|             def data_received(self, data):
 | |
|                 self.on_data.set_result(data)
 | |
| 
 | |
|             def eof_received(self):
 | |
|                 self.on_eof.set_result(True)
 | |
| 
 | |
|         async def client(addr):
 | |
|             await asyncio.sleep(0.5)
 | |
| 
 | |
|             on_data = self.loop.create_future()
 | |
|             on_eof = self.loop.create_future()
 | |
| 
 | |
|             tr, proto = await self.loop.create_connection(
 | |
|                 lambda: ClientProto(on_data, on_eof), *addr,
 | |
|                 ssl=client_context)
 | |
| 
 | |
|             self.assertEqual(await on_data, b'O')
 | |
|             tr.write(HELLO_MSG)
 | |
|             await on_eof
 | |
| 
 | |
|             tr.close()
 | |
| 
 | |
|         with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
 | |
|             self.loop.run_until_complete(
 | |
|                 asyncio.wait_for(client(srv.addr),
 | |
|                                  timeout=support.SHORT_TIMEOUT))
 | |
| 
 | |
|         # No garbage is left for SSL client from loop.create_connection, even
 | |
|         # if user stores the SSLTransport in corresponding protocol instance
 | |
|         client_context = weakref.ref(client_context)
 | |
|         self.assertIsNone(client_context())
 | |
| 
 | |
    def test_start_tls_client_buf_proto_1(self):
        # Upgrade a plain client connection to TLS with loop.start_tls()
        # while the application protocol is a BufferedProtocol, then swap
        # in a streaming Protocol on the TLS transport via set_protocol().
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        # Incremented by connection_made() of both client protocols below.
        client_con_made_calls = 0

        def serve(sock):
            # Blocking server side, run by self.tcp_server().
            sock.settimeout(self.TIMEOUT)

            # First payload arrives before the TLS upgrade.
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.starttls(server_context, server_side=True)

            # b'O' is what the client's first (buffered) protocol waits for.
            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            # b'2' is what the client's second (streaming) protocol waits for.
            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.unwrap()
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            # Buffered protocol used before and right after the TLS upgrade.
            def __init__(self, on_data):
                self.on_data = on_data
                # One-byte receive buffer handed out by get_buffer().
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            def buffer_updated(self, nsize):
                assert nsize == 1
                self.on_data.set_result(bytes(self.buf[:nsize]))

            def eof_received(self):
                pass

        class ClientProtoSecond(asyncio.Protocol):
            # Streaming protocol installed on the TLS transport afterwards.
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            # Plain (non-TLS) connection first.
            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            # Switch from the buffered to the streaming protocol.
            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))
 | |
| 
 | |
    def test_start_tls_slow_client_cancel(self):
        # start_tls() against a peer that never starts a TLS handshake:
        # the server below only reads raw bytes, and the client expects
        # wait_for() to raise TimeoutError after 0.5 seconds.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        client_context = test_utils.simple_client_sslcontext()
        # Resolved (from the server thread) once the server has received
        # the initial payload.
        server_waits_on_handshake = self.loop.create_future()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            try:
                # Futures must be completed from the loop's thread, hence
                # call_soon_threadsafe().
                self.loop.call_soon_threadsafe(
                    server_waits_on_handshake.set_result, None)
                # Keep reading without ever calling starttls(); a
                # ConnectionAbortedError raised here is tolerated.
                data = sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The instance is named ``proto`` here so that ``self`` still
            # refers to the enclosing test case inside the method.
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)

            await server_waits_on_handshake

            # The handshake cannot complete, so the 0.5s timeout must fire.
            with self.assertRaises(asyncio.TimeoutError):
                await asyncio.wait_for(
                    self.loop.start_tls(tr, proto, client_context),
                    0.5)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))
 | |
| 
 | |
    def test_start_tls_server_1(self):
        # Server-side start_tls(): an asyncio server accepts a plain
        # connection, writes HELLO_MSG in clear text, upgrades the
        # transport to TLS and then expects HELLO_MSG back over TLS from
        # a blocking client running in a thread.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def client(sock, addr):
            # Blocking client side, run by self.tcp_client().
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            # Clear-text greeting sent by the server before the upgrade.
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.starttls(client_context)
            sock.sendall(HELLO_MSG)

            sock.unwrap()
            sock.close()

        class ServerProto(asyncio.Protocol):
            # Exposes connection events as futures and accumulates all
            # received bytes in ``self.data``.
            def __init__(self, on_con, on_eof, on_con_lost):
                self.on_con = on_con
                self.on_eof = on_eof
                self.on_con_lost = on_con_lost
                self.data = b''

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

            def connection_lost(self, exc):
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_eof, on_con_lost):
            tr = await on_con
            tr.write(HELLO_MSG)

            # Nothing has been received from the client yet.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)

            await on_eof
            await on_con_lost
            # Everything the client sent after the upgrade reached the
            # protocol.
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_eof = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            proto = ServerProto(on_con, on_eof, on_con_lost)

            # The factory returns the same protocol instance every time.
            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_eof, on_con_lost),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()

        self.loop.run_until_complete(run_main())
 | |
| 
 | |
    def test_create_server_ssl_over_ssl(self):
        # TLS tunnelled inside TLS: the server is created with
        # ssl=sslctx_1 and every accepted connection is upgraded a second
        # time with start_tls(sslctx_2).  TOTAL_CNT threaded clients do the
        # outer handshake with a socket and drive the inner TLS layer by
        # hand through a pair of MemoryBIO objects.
        CNT = 0           # number of clients that were successful
        TOTAL_CNT = 25    # total number of clients that test will create
        TIMEOUT = support.LONG_TIMEOUT  # timeout for this test

        A_DATA = b'A' * 1024 * BUF_MULTIPLIER
        B_DATA = b'B' * 1024 * BUF_MULTIPLIER

        # Contexts suffixed _1 are for the outer layer, _2 for the inner.
        sslctx_1 = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx_1 = self._create_client_ssl_context()
        sslctx_2 = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx_2 = self._create_client_ssl_context()

        # Client threads, stopped at the end of the test.
        clients = []

        async def handle_client(reader, writer):
            # Stream-level server handler, run once per connection.
            nonlocal CNT

            data = await reader.readexactly(len(A_DATA))
            self.assertEqual(data, A_DATA)
            writer.write(b'OK')

            data = await reader.readexactly(len(B_DATA))
            self.assertEqual(data, B_DATA)
            # The three pieces add up to b'SPAM', which the client checks.
            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

            await writer.drain()
            writer.close()

            CNT += 1

        class ServerProtocol(asyncio.StreamReaderProtocol):
            def connection_made(self, transport):
                # Bound here so that the nested callback below can reach
                # the base-class methods.
                super_ = super()
                transport.pause_reading()
                # Start the second (inner) TLS layer on top of the
                # transport passed in.
                fut = self._loop.create_task(self._loop.start_tls(
                    transport, self, sslctx_2, server_side=True))

                def cb(_):
                    # Forward the outcome of start_tls() to
                    # StreamReaderProtocol.
                    try:
                        tr = fut.result()
                    except Exception as ex:
                        super_.connection_lost(ex)
                    else:
                        super_.connection_made(tr)
                fut.add_done_callback(cb)

        def server_protocol_factory():
            reader = asyncio.StreamReader()
            protocol = ServerProtocol(reader, handle_client)
            return protocol

        async def test_client(addr):
            # Completed from the client thread with the outcome of prog().
            fut = asyncio.Future()

            def prog(sock):
                try:
                    sock.connect(addr)
                    sock.starttls(client_sslctx_1)

                    # because wrap_socket() doesn't work correctly on
                    # SSLSocket, we have to do the 2nd level SSL manually
                    incoming = ssl.MemoryBIO()
                    outgoing = ssl.MemoryBIO()
                    sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)

                    def do(func, *args):
                        # Call func(*args) until it stops raising
                        # SSLWantReadError.  On each retry, pending
                        # outgoing bytes are sent over the outer socket
                        # and newly received bytes are fed to ``incoming``.
                        while True:
                            try:
                                rv = func(*args)
                                break
                            except ssl.SSLWantReadError:
                                if outgoing.pending:
                                    sock.send(outgoing.read())
                                incoming.write(sock.recv(65536))
                        # Flush whatever the successful call produced.
                        if outgoing.pending:
                            sock.send(outgoing.read())
                        return rv

                    do(sslobj.do_handshake)

                    do(sslobj.write, A_DATA)
                    data = do(sslobj.read, 2)
                    self.assertEqual(data, b'OK')

                    do(sslobj.write, B_DATA)
                    data = b''
                    # Read until an empty chunk is returned.
                    while True:
                        chunk = do(sslobj.read, 4)
                        if not chunk:
                            break
                        data += chunk
                    self.assertEqual(data, b'SPAM')

                    do(sslobj.unwrap)
                    sock.close()

                except Exception as ex:
                    self.loop.call_soon_threadsafe(fut.set_exception, ex)
                    sock.close()
                else:
                    self.loop.call_soon_threadsafe(fut.set_result, None)

            client = self.tcp_client(prog)
            client.start()
            clients.append(client)

            await fut

        async def start_server():
            extras = {}

            srv = await self.loop.create_server(
                server_protocol_factory,
                '127.0.0.1', 0,
                family=socket.AF_INET,
                ssl=sslctx_1,
                **extras)

            try:
                srv_socks = srv.sockets
                self.assertTrue(srv_socks)

                addr = srv_socks[0].getsockname()

                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(test_client(addr))

                await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

            finally:
                self.loop.call_soon(srv.close)
                await srv.wait_closed()

        with self._silence_eof_received_warning():
            self.loop.run_until_complete(start_server())

        # Every client must have been served to completion.
        self.assertEqual(CNT, TOTAL_CNT)

        for client in clients:
            client.stop()
 | |
| 
 | |
|     def test_shutdown_cleanly(self):
 | |
|         CNT = 0
 | |
|         TOTAL_CNT = 25
 | |
| 
 | |
|         A_DATA = b'A' * 1024 * BUF_MULTIPLIER
 | |
| 
 | |
|         sslctx = self._create_server_ssl_context(
 | |
|             test_utils.ONLYCERT, test_utils.ONLYKEY)
 | |
|         client_sslctx = self._create_client_ssl_context()
 | |
| 
 | |
|         def server(sock):
 | |
|             sock.starttls(
 | |
|                 sslctx,
 | |
|                 server_side=True)
 | |
| 
 | |
|             data = sock.recv_all(len(A_DATA))
 | |
|             self.assertEqual(data, A_DATA)
 | |
|             sock.send(b'OK')
 | |
| 
 | |
|             sock.unwrap()
 | |
| 
 | |
|             sock.close()
 | |
| 
 | |
|         async def client(addr):
 | |
|             extras = {}
 | |
|             extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
 | |
| 
 | |
|             reader, writer = await asyncio.open_connection(
 | |
|                 *addr,
 | |
|                 ssl=client_sslctx,
 | |
|                 server_hostname='',
 | |
|                 **extras)
 | |
| 
 | |
|             writer.write(A_DATA)
 | |
|             self.assertEqual(await reader.readexactly(2), b'OK')
 | |
| 
 | |
|             self.assertEqual(await reader.read(), b'')
 | |
| 
 | |
|             nonlocal CNT
 | |
|             CNT += 1
 | |
| 
 | |
|             writer.close()
 | |
|             await self.wait_closed(writer)
 | |
| 
 | |
|         def run(coro):
 | |
|             nonlocal CNT
 | |
|             CNT = 0
 | |
| 
 | |
|             async def _gather(*tasks):
 | |
|                 return await asyncio.gather(*tasks)
 | |
| 
 | |
|             with self.tcp_server(server,
 | |
|                                  max_clients=TOTAL_CNT,
 | |
|                                  backlog=TOTAL_CNT) as srv:
 | |
|                 tasks = []
 | |
|                 for _ in range(TOTAL_CNT):
 | |
|                     tasks.append(coro(srv.addr))
 | |
| 
 | |
|                 self.loop.run_until_complete(
 | |
|                     _gather(*tasks))
 | |
| 
 | |
|             self.assertEqual(CNT, TOTAL_CNT)
 | |
| 
 | |
|         with self._silence_eof_received_warning():
 | |
|             run(client)
 | |
| 
 | |
    def test_flush_before_shutdown(self):
        # Data written while the SSL protocol's writing is paused must
        # still reach the peer when the writer is closed right afterwards:
        # the server expects to receive all CHUNK * SIZE bytes.
        CHUNK = 1024 * 128
        SIZE = 32

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        # Created inside client(); completed from the server thread with
        # the outcome of server().
        future = None

        def server(sock):
            sock.starttls(sslctx, server_side=True)
            self.assertEqual(sock.recv_all(4), b'ping')
            sock.send(b'pong')
            time.sleep(0.5)  # hopefully stuck the TCP buffer
            data = sock.recv_all(CHUNK * SIZE)
            self.assertEqual(len(data), CHUNK * SIZE)
            sock.close()

        def run(meth):
            # Wrap *meth* so that its success or failure is reported to
            # the event loop through ``future``.
            def wrapper(sock):
                try:
                    meth(sock)
                except Exception as ex:
                    self.loop.call_soon_threadsafe(future.set_exception, ex)
                else:
                    self.loop.call_soon_threadsafe(future.set_result, None)
            return wrapper

        async def client(addr):
            nonlocal future
            future = self.loop.create_future()
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')
            # Private attribute access: the SSL protocol object is driven
            # directly below.
            sslprotocol = writer.transport._ssl_protocol
            writer.write(b'ping')
            data = await reader.readexactly(4)
            self.assertEqual(data, b'pong')

            # Queue CHUNK * SIZE bytes while writing is paused ...
            sslprotocol.pause_writing()
            for _ in range(SIZE):
                writer.write(b'x' * CHUNK)

            # ... then close before resuming.
            writer.close()
            sslprotocol.resume_writing()

            await self.wait_closed(writer)
            try:
                data = await reader.read()
                self.assertEqual(data, b'')
            except ConnectionResetError:
                pass
            # Propagates any failure raised in the server thread.
            await future

        with self.tcp_server(run(server)) as srv:
            self.loop.run_until_complete(client(srv.addr))
 | |
| 
 | |
    def test_remote_shutdown_receives_trailing_data(self):
        # After the peer starts shutting down (TLS close_notify in
        # server(), socket-level SHUT_WR in eof_server()), the client's
        # pending write backlog of CHUNK * SIZE bytes must still be
        # delivered to it.
        CHUNK = 1024 * 128
        SIZE = 32

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT,
            test_utils.ONLYKEY
        )
        client_sslctx = self._create_client_ssl_context()
        # Created inside client(); completed from the server thread.
        future = None

        def server(sock):
            # TLS is driven by hand through MemoryBIO objects so that
            # unwrap() can be issued without blocking on the peer's reply.
            incoming = ssl.MemoryBIO()
            outgoing = ssl.MemoryBIO()
            sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)

            # Handshake: shuttle bytes between the socket and the BIOs
            # until do_handshake() stops raising SSLWantReadError.
            while True:
                try:
                    sslobj.do_handshake()
                except ssl.SSLWantReadError:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    incoming.write(sock.recv(16384))
                else:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    break

            # Read the 4-byte request.
            while True:
                try:
                    data = sslobj.read(4)
                except ssl.SSLWantReadError:
                    incoming.write(sock.recv(16384))
                else:
                    break

            self.assertEqual(data, b'ping')
            sslobj.write(b'pong')
            sock.send(outgoing.read())

            time.sleep(0.2)  # wait for the peer to fill its backlog

            # send close_notify but don't wait for response
            with self.assertRaises(ssl.SSLWantReadError):
                sslobj.unwrap()
            sock.send(outgoing.read())

            # should receive all data
            data_len = 0
            while True:
                try:
                    chunk = len(sslobj.read(16384))
                    data_len += chunk
                except ssl.SSLWantReadError:
                    incoming.write(sock.recv(16384))
                except ssl.SSLZeroReturnError:
                    break

            self.assertEqual(data_len, CHUNK * SIZE)

            # verify that close_notify is received
            sslobj.unwrap()

            sock.close()

        def eof_server(sock):
            # Variant that shuts down the write side of the socket
            # instead of calling unwrap().
            sock.starttls(sslctx, server_side=True)
            self.assertEqual(sock.recv_all(4), b'ping')
            sock.send(b'pong')

            time.sleep(0.2)  # wait for the peer to fill its backlog

            # send EOF
            sock.shutdown(socket.SHUT_WR)

            # should receive all data
            data = sock.recv_all(CHUNK * SIZE)
            self.assertEqual(len(data), CHUNK * SIZE)

            sock.close()

        async def client(addr):
            nonlocal future
            future = self.loop.create_future()

            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')
            writer.write(b'ping')
            data = await reader.readexactly(4)
            self.assertEqual(data, b'pong')

            # fill write backlog in a hacky way - renegotiation won't help
            for _ in range(SIZE):
                writer.transport._test__append_write_backlog(b'x' * CHUNK)

            try:
                data = await reader.read()
                self.assertEqual(data, b'')
            except (BrokenPipeError, ConnectionResetError):
                pass

            # Propagates any failure raised in the server thread.
            await future

            writer.close()
            await self.wait_closed(writer)

        def run(meth):
            # Wrap *meth* so that its success or failure is reported to
            # the event loop through ``future``.
            def wrapper(sock):
                try:
                    meth(sock)
                except Exception as ex:
                    self.loop.call_soon_threadsafe(future.set_exception, ex)
                else:
                    self.loop.call_soon_threadsafe(future.set_result, None)
            return wrapper

        with self.tcp_server(run(server)) as srv:
            self.loop.run_until_complete(client(srv.addr))

        with self.tcp_server(run(eof_server)) as srv:
            self.loop.run_until_complete(client(srv.addr))
 | |
| 
 | |
|     def test_connect_timeout_warning(self):
 | |
|         s = socket.socket(socket.AF_INET)
 | |
|         s.bind(('127.0.0.1', 0))
 | |
|         addr = s.getsockname()
 | |
| 
 | |
|         async def test():
 | |
|             try:
 | |
|                 await asyncio.wait_for(
 | |
|                     self.loop.create_connection(asyncio.Protocol,
 | |
|                                                 *addr, ssl=True),
 | |
|                     0.1)
 | |
|             except (ConnectionRefusedError, asyncio.TimeoutError):
 | |
|                 pass
 | |
|             else:
 | |
|                 self.fail('TimeoutError is not raised')
 | |
| 
 | |
|         with s:
 | |
|             try:
 | |
|                 with self.assertWarns(ResourceWarning) as cm:
 | |
|                     self.loop.run_until_complete(test())
 | |
|                     gc.collect()
 | |
|                     gc.collect()
 | |
|                     gc.collect()
 | |
|             except AssertionError as e:
 | |
|                 self.assertEqual(str(e), 'ResourceWarning not triggered')
 | |
|             else:
 | |
|                 self.fail('Unexpected ResourceWarning: {}'.format(cm.warning))
 | |
| 
 | |
|     def test_handshake_timeout_handler_leak(self):
 | |
|         s = socket.socket(socket.AF_INET)
 | |
|         s.bind(('127.0.0.1', 0))
 | |
|         s.listen(1)
 | |
|         addr = s.getsockname()
 | |
| 
 | |
|         async def test(ctx):
 | |
|             try:
 | |
|                 await asyncio.wait_for(
 | |
|                     self.loop.create_connection(asyncio.Protocol, *addr,
 | |
|                                                 ssl=ctx),
 | |
|                     0.1)
 | |
|             except (ConnectionRefusedError, asyncio.TimeoutError):
 | |
|                 pass
 | |
|             else:
 | |
|                 self.fail('TimeoutError is not raised')
 | |
| 
 | |
|         with s:
 | |
|             ctx = ssl.create_default_context()
 | |
|             self.loop.run_until_complete(test(ctx))
 | |
|             ctx = weakref.ref(ctx)
 | |
| 
 | |
|         # SSLProtocol should be DECREF to 0
 | |
|         self.assertIsNone(ctx())
 | |
| 
 | |
|     def test_shutdown_timeout_handler_leak(self):
 | |
|         loop = self.loop
 | |
| 
 | |
|         def server(sock):
 | |
|             sslctx = self._create_server_ssl_context(
 | |
|                 test_utils.ONLYCERT,
 | |
|                 test_utils.ONLYKEY
 | |
|             )
 | |
|             sock = sslctx.wrap_socket(sock, server_side=True)
 | |
|             sock.recv(32)
 | |
|             sock.close()
 | |
| 
 | |
|         class Protocol(asyncio.Protocol):
 | |
|             def __init__(self):
 | |
|                 self.fut = asyncio.Future(loop=loop)
 | |
| 
 | |
|             def connection_lost(self, exc):
 | |
|                 self.fut.set_result(None)
 | |
| 
 | |
|         async def client(addr, ctx):
 | |
|             tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
 | |
|             tr.close()
 | |
|             await pr.fut
 | |
| 
 | |
|         with self.tcp_server(server) as srv:
 | |
|             ctx = self._create_client_ssl_context()
 | |
|             loop.run_until_complete(client(srv.addr, ctx))
 | |
|             ctx = weakref.ref(ctx)
 | |
| 
 | |
|         # asyncio has no shutdown timeout, but it ends up with a circular
 | |
|         # reference loop - not ideal (introduces gc glitches), but at least
 | |
|         # not leaking
 | |
|         gc.collect()
 | |
|         gc.collect()
 | |
|         gc.collect()
 | |
| 
 | |
|         # SSLProtocol should be DECREF to 0
 | |
|         self.assertIsNone(ctx())
 | |
| 
 | |
|     def test_shutdown_timeout_handler_not_set(self):
 | |
|         loop = self.loop
 | |
|         eof = asyncio.Event()
 | |
|         extra = None
 | |
| 
 | |
|         def server(sock):
 | |
|             sslctx = self._create_server_ssl_context(
 | |
|                 test_utils.ONLYCERT,
 | |
|                 test_utils.ONLYKEY
 | |
|             )
 | |
|             sock = sslctx.wrap_socket(sock, server_side=True)
 | |
|             sock.send(b'hello')
 | |
|             assert sock.recv(1024) == b'world'
 | |
|             sock.send(b'extra bytes')
 | |
|             # sending EOF here
 | |
|             sock.shutdown(socket.SHUT_WR)
 | |
|             loop.call_soon_threadsafe(eof.set)
 | |
|             # make sure we have enough time to reproduce the issue
 | |
|             assert sock.recv(1024) == b''
 | |
|             sock.close()
 | |
| 
 | |
|         class Protocol(asyncio.Protocol):
 | |
|             def __init__(self):
 | |
|                 self.fut = asyncio.Future(loop=loop)
 | |
|                 self.transport = None
 | |
| 
 | |
|             def connection_made(self, transport):
 | |
|                 self.transport = transport
 | |
| 
 | |
|             def data_received(self, data):
 | |
|                 if data == b'hello':
 | |
|                     self.transport.write(b'world')
 | |
|                     # pause reading would make incoming data stay in the sslobj
 | |
|                     self.transport.pause_reading()
 | |
|                 else:
 | |
|                     nonlocal extra
 | |
|                     extra = data
 | |
| 
 | |
|             def connection_lost(self, exc):
 | |
|                 if exc is None:
 | |
|                     self.fut.set_result(None)
 | |
|                 else:
 | |
|                     self.fut.set_exception(exc)
 | |
| 
 | |
|         async def client(addr):
 | |
|             ctx = self._create_client_ssl_context()
 | |
|             tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
 | |
|             await eof.wait()
 | |
|             tr.resume_reading()
 | |
|             await pr.fut
 | |
|             tr.close()
 | |
|             assert extra == b'extra bytes'
 | |
| 
 | |
|         with self.tcp_server(server) as srv:
 | |
|             loop.run_until_complete(client(srv.addr))
 | |
| 
 | |
| 
 | |
| ###############################################################################
 | |
| # Socket Testing Utilities
 | |
| ###############################################################################
 | |
| 
 | |
| 
 | |
class TestSocketWrapper:
    """Thin proxy around a socket with helpers used by the threaded tests.

    Attributes that are not defined on the wrapper are looked up on the
    wrapped socket (see ``__getattr__``), so the wrapper can be used like
    the socket itself.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes and return them as ``bytes``.

        Raises ConnectionAbortedError if ``recv()`` returns an empty
        result before *n* bytes have been collected.
        """
        # Collect the pieces and join once at the end; growing a bytes
        # object with += copies the whole buffer on every iteration.
        chunks = []
        remaining = n
        while remaining > 0:
            data = self.recv(remaining)
            if not data:
                raise ConnectionAbortedError
            chunks.append(data)
            remaining -= len(data)
        return b''.join(chunks)

    def starttls(self, ssl_context, *,
                 server_side=False,
                 server_hostname=None,
                 do_handshake_on_connect=True):
        """Wrap the underlying socket with *ssl_context*.

        After this call the wrapper delegates to the SSL socket returned
        by ``ssl_context.wrap_socket()``.  The keyword arguments are
        passed through to ``wrap_socket()``.
        """

        assert isinstance(ssl_context, ssl.SSLContext)

        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=do_handshake_on_connect)

        # Server side: run the handshake explicitly, whatever the value
        # of do_handshake_on_connect.
        if server_side:
            ssl_sock.do_handshake()

        # Close the pre-TLS socket object and switch to the SSL socket.
        self.__sock.close()
        self.__sock = ssl_sock

    def __getattr__(self, name):
        # Only called for attributes not found on the wrapper itself.
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)
 | |
| 
 | |
| 
 | |
class SocketThread(threading.Thread):
    """Thread usable as a context manager.

    Entering the ``with`` block starts the thread; leaving it calls
    ``stop()``.
    """

    def stop(self):
        """Clear the ``_active`` flag and wait for the thread to finish."""
        self._active = False
        self.join()

    def __enter__(self):
        # Start running and hand the thread object to the ``as`` target.
        self.start()
        return self

    def __exit__(self, *exc):
        # Exceptions raised inside the ``with`` block are not suppressed.
        self.stop()
 | |
| 
 | |
| 
 | |
class TestThreadedClient(SocketThread):
    """Daemon thread that runs a client program against a socket.

    ``prog`` is called with the socket wrapped in a TestSocketWrapper.
    Exceptions other than KeyboardInterrupt and SystemExit are passed to
    ``test._abort_socket_test()``.
    """

    def __init__(self, test, sock, prog, timeout):
        super().__init__(name='test-client', daemon=True)

        self._test = test
        self._sock = sock
        self._prog = prog
        self._timeout = timeout
        self._active = True

    def run(self):
        try:
            self._prog(TestSocketWrapper(self._sock))
        except BaseException as ex:
            # Interpreter-exit requests propagate unchanged; everything
            # else is reported to the owning test.
            if isinstance(ex, (KeyboardInterrupt, SystemExit)):
                raise
            self._test._abort_socket_test(ex)
 | |
| 
 | |
| 
 | |
class TestThreadedServer(SocketThread):
    """Daemon thread that accepts connections and runs a server program.

    Each accepted connection is wrapped in a ``TestSocketWrapper`` and
    passed to *prog*.  The accept loop ends once *max_clients* connections
    have been accepted, when ``stop()`` is called, or when a client handler
    raises.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        super().__init__(name='test-server', daemon=True)

        self._test = test
        self._prog = prog
        self._sock = sock
        self._timeout = timeout
        self._max_clients = max_clients

        self._clients = 0
        self._finished_clients = 0
        self._active = True

        # Wake-up channel: stop() writes to _s2, the accept loop selects
        # on _s1 alongside the listening socket.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

    def stop(self):
        """Wake the accept loop, then clear ``_active`` and join."""
        try:
            waker = self._s2
            if waker and waker.fileno() != -1:
                # Best effort: the thread may already have closed the pair.
                with contextlib.suppress(OSError):
                    waker.send(b'stop')
        finally:
            super().stop()

    def run(self):
        """Thread body: serve on the listening socket, then clean up."""
        try:
            with self._sock as listener:
                listener.setblocking(False)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        """Accept and handle clients until told to stop."""
        while self._active:
            if self._clients >= self._max_clients:
                return

            watched = [self._sock, self._s1]
            readable, _, _ = select.select(watched, [], [], self._timeout)

            if self._s1 in readable:
                # stop() poked the wake-up socket.
                return

            if self._sock not in readable:
                # select() timed out with nothing to accept; poll again.
                continue

            try:
                conn, addr = self._sock.accept()
            except BlockingIOError:
                continue
            except socket.timeout:
                if not self._active:
                    return
                raise

            self._clients += 1
            conn.settimeout(self._timeout)
            try:
                with conn:
                    self._handle_client(conn)
            except (KeyboardInterrupt, SystemExit):
                raise
            except BaseException as exc:
                self._active = False
                # Re-raise the error, but make sure the owning test is
                # notified before it leaves this frame.
                try:
                    raise
                finally:
                    self._test._abort_socket_test(exc)

    def _handle_client(self, sock):
        """Run the server program on one accepted connection."""
        wrapped = TestSocketWrapper(sock)
        self._prog(wrapped)

    @property
    def addr(self):
        """Address the listening socket is bound to (``getsockname()``)."""
        return self._sock.getsockname()
 |