mirror of
				https://github.com/python/cpython.git
				synced 2025-10-26 11:14:33 +00:00 
			
		
		
		
	 6ca78affc8
			
		
	
	
		6ca78affc8
		
			
		
	
	
	
	
		
			
			Keep checks active even when the tests are run with optimized Python (-O): either use dedicated assertion methods such as assertEqual() or raise an AssertionError explicitly.
		
			
				
	
	
		
			774 lines
		
	
	
	
		
			26 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			774 lines
		
	
	
	
		
			26 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Tests for asyncio/sslproto.py."""
 | |
| 
 | |
| import logging
 | |
| import socket
 | |
| from test import support
 | |
| import unittest
 | |
| import weakref
 | |
| from unittest import mock
 | |
| try:
 | |
|     import ssl
 | |
| except ImportError:
 | |
|     ssl = None
 | |
| 
 | |
| import asyncio
 | |
| from asyncio import log
 | |
| from asyncio import protocols
 | |
| from asyncio import sslproto
 | |
| from test import support
 | |
| from test.test_asyncio import utils as test_utils
 | |
| from test.test_asyncio import functional as func_tests
 | |
| 
 | |
| 
 | |
def tearDownModule():
    """Reset the global asyncio event loop policy after this module runs."""
    # Passing None tells asyncio to fall back to its default policy.
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
 | |
| 
 | |
| 
 | |
@unittest.skipIf(ssl is None, 'No ssl module')
class SslProtoHandshakeTests(test_utils.TestCase):
    """Unit tests for sslproto.SSLProtocol using mocked transports/pipes."""

    def setUp(self):
        super().setUp()
        self.loop = asyncio.new_event_loop()
        self.set_event_loop(self.loop)

    def ssl_protocol(self, *, waiter=None, proto=None):
        """Create an SSLProtocol wrapping *proto* with a 0.1s handshake timeout.

        A fresh asyncio.Protocol is used as the application protocol when
        *proto* is None.  The application transport is closed on cleanup.
        """
        ctx = test_utils.dummy_ssl_context()
        app_proto = asyncio.Protocol() if proto is None else proto
        protocol = sslproto.SSLProtocol(self.loop, app_proto, ctx, waiter,
                                        ssl_handshake_timeout=0.1)
        app_transport = protocol._app_transport
        self.assertIs(app_transport.get_protocol(), app_proto)
        self.addCleanup(app_transport.close)
        return protocol

    def connection_made(self, ssl_proto, *, do_handshake=None):
        """Call ssl_proto.connection_made() with a mock transport.

        asyncio.sslproto._SSLPipe is patched with a mock whose shutdown()
        returns b''.  Its do_handshake() uses *do_handshake* as side effect
        when that is truthy; otherwise it returns an empty list.

        Returns the mock transport that was handed to the protocol.
        """
        fake_transport = mock.Mock()
        fake_pipe = mock.Mock()
        fake_pipe.shutdown.return_value = b''
        fake_pipe.do_handshake.side_effect = (
            do_handshake or (lambda callback: []))
        with mock.patch('asyncio.sslproto._SSLPipe', return_value=fake_pipe):
            ssl_proto.connection_made(fake_transport)
        return fake_transport

    def _check_rejected_handshake_timeout(self, timeout):
        # SSLProtocol must refuse a non-positive ssl_handshake_timeout.
        ctx = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, ctx, waiter,
                                 ssl_handshake_timeout=timeout)

    def test_handshake_timeout_zero(self):
        self._check_rejected_handshake_timeout(0)

    def test_handshake_timeout_negative(self):
        self._check_rejected_handshake_timeout(-10)

    def test_eof_received_waiter(self):
        fut = self.loop.create_future()
        protocol = self.ssl_protocol(waiter=fut)
        self.connection_made(protocol)
        protocol.eof_received()
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(fut.exception(), ConnectionResetError)

    def test_fatal_error_no_name_error(self):
        # Issue #363: _fatal_error() raised NameError when sslproto.py
        # did not import base_events.
        fut = self.loop.create_future()
        protocol = self.ssl_protocol(waiter=fut)
        # Silence error logging during the call to keep test output clean,
        # then put the previous level back no matter what happens.
        saved_level = log.logger.getEffectiveLevel()
        log.logger.setLevel(logging.FATAL)
        try:
            protocol._fatal_error(None)
        finally:
            log.logger.setLevel(saved_level)

    def test_connection_lost(self):
        # Issue #472: awaiting the waiter hung once connection_lost()
        # had been called.
        fut = self.loop.create_future()
        protocol = self.ssl_protocol(waiter=fut)
        self.connection_made(protocol)
        protocol.connection_lost(ConnectionAbortedError)
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(fut.exception(), ConnectionAbortedError)

    def test_close_during_handshake(self):
        # bpo-29743: closing the transport mid-handshake leaked the socket.
        fut = self.loop.create_future()
        protocol = self.ssl_protocol(waiter=fut)

        raw_transport = self.connection_made(protocol)
        test_utils.run_briefly(self.loop)

        protocol._app_transport.close()
        self.assertTrue(raw_transport.abort.called)

    def test_get_extra_info_on_closed_connection(self):
        fut = self.loop.create_future()
        protocol = self.ssl_protocol(waiter=fut)
        self.assertIsNone(protocol._get_extra_info('socket'))
        sentinel = object()
        self.assertIs(protocol._get_extra_info('socket', sentinel), sentinel)
        self.connection_made(protocol)
        self.assertIsNotNone(protocol._get_extra_info('socket'))
        protocol.connection_lost(None)
        self.assertIsNone(protocol._get_extra_info('socket'))

    def test_set_new_app_protocol(self):
        fut = self.loop.create_future()
        protocol = self.ssl_protocol(waiter=fut)
        replacement = asyncio.Protocol()
        protocol._app_transport.set_protocol(replacement)
        self.assertIs(protocol._app_transport.get_protocol(), replacement)
        self.assertIs(protocol._app_protocol, replacement)

    def test_data_received_after_closing(self):
        protocol = self.ssl_protocol()
        self.connection_made(protocol)
        app_transport = protocol._app_transport

        app_transport.close()

        # Incoming data after close() must be ignored, not raise.
        self.assertIsNone(protocol.data_received(b'data'))

    def test_write_after_closing(self):
        protocol = self.ssl_protocol()
        self.connection_made(protocol)
        app_transport = protocol._app_transport
        app_transport.close()

        # Writing after close() must be a silent no-op, not raise.
        self.assertIsNone(app_transport.write(b'data'))
 | |
| 
 | |
| 
 | |
| ##############################################################################
 | |
| # Start TLS Tests
 | |
| ##############################################################################
 | |
| 
 | |
| 
 | |
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
    """Functional tests for loop.start_tls() and SSL client connections.

    The mixin does not depend on a particular event loop.  Concrete
    subclasses combine it with unittest.TestCase and implement new_loop().
    Checks made inside protocol callbacks and in the helper server/client
    threads use assertion methods such as self.assertEqual() rather than
    bare ``assert`` statements, so they survive ``python -O``.
    """

    PAYLOAD_SIZE = 1024 * 100
    TIMEOUT = support.LONG_TIMEOUT

    def new_loop(self):
        """Return the event loop under test; subclasses must override."""
        raise NotImplementedError

    def test_buf_feed_data(self):
        # _feed_data_to_buffered_proto() must deliver the whole payload even
        # when the protocol's buffer is smaller than it, for both bytearray
        # and memoryview buffers, and must reject an empty buffer.

        class Proto(asyncio.BufferedProtocol):

            def __init__(self, bufsize, usemv):
                self.buf = bytearray(bufsize)
                self.mv = memoryview(self.buf)
                self.data = b''
                self.usemv = usemv

            def get_buffer(self, sizehint):
                if self.usemv:
                    return self.mv
                else:
                    return self.buf

            def buffer_updated(self, nsize):
                if self.usemv:
                    self.data += self.mv[:nsize]
                else:
                    self.data += self.buf[:nsize]

        # (buffer size, payload): buffers smaller than the payload (evenly
        # dividing it or not), exactly its size, and larger than it.
        cases = [
            (1, b'12345'),
            (2, b'12345'),
            (2, b'1234'),
            (4, b'1234'),
            (100, b'12345'),
        ]
        for usemv in [False, True]:
            for bufsize, payload in cases:
                with self.subTest(usemv=usemv, bufsize=bufsize,
                                  payload=payload):
                    proto = Proto(bufsize, usemv)
                    protocols._feed_data_to_buffered_proto(proto, payload)
                    self.assertEqual(proto.data, payload)

            proto = Proto(0, usemv)
            with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
                protocols._feed_data_to_buffered_proto(proto, b'12345')

    def test_start_tls_client_reg_proto_1(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data, b'O')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left if SSL is closed uncleanly
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())

    def test_create_connection_memory_leak(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())

    def test_start_tls_client_buf_proto_1(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        client_con_made_calls = 0

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            def __init__(self, on_data):
                self.on_data = on_data
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            def buffer_updated(slf, nsize):
                self.assertEqual(nsize, 1)
                slf.on_data.set_result(bytes(slf.buf[:nsize]))

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))

    def test_start_tls_slow_client_cancel(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        client_context = test_utils.simple_client_sslcontext()
        server_waits_on_handshake = self.loop.create_future()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            try:
                self.loop.call_soon_threadsafe(
                    server_waits_on_handshake.set_result, None)
                # Block until the client gives up on the handshake.
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)

            await server_waits_on_handshake

            with self.assertRaises(asyncio.TimeoutError):
                await asyncio.wait_for(
                    self.loop.start_tls(tr, proto, client_context),
                    0.5)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

    def test_start_tls_server_1(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        ANSWER = b'answer'

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        answer = None

        def client(sock, addr):
            nonlocal answer
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            answer = sock.recv_all(len(ANSWER))
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_con_lost, on_got_hello):
                self.on_con = on_con
                self.on_con_lost = on_con_lost
                self.on_got_hello = on_got_hello
                self.data = b''
                self.transport = None

            def connection_made(self, tr):
                self.transport = tr
                self.on_con.set_result(tr)

            def replace_transport(self, tr):
                self.transport = tr

            def data_received(self, data):
                self.data += data
                # Resolve the future only once: any data arriving after the
                # full hello would otherwise raise InvalidStateError.
                if (len(self.data) >= len(HELLO_MSG)
                        and not self.on_got_hello.done()):
                    self.on_got_hello.set_result(None)

            def connection_lost(self, exc):
                self.transport = None
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_con_lost, on_got_hello):
            tr = await on_con
            tr.write(HELLO_MSG)

            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)
            proto.replace_transport(new_tr)

            await on_got_hello
            new_tr.write(ANSWER)

            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            on_got_hello = self.loop.create_future()
            proto = ServerProto(on_con, on_con_lost, on_got_hello)

            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_con_lost, on_got_hello),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()
            self.assertEqual(answer, ANSWER)

        self.loop.run_until_complete(run_main())

    def test_start_tls_wrong_args(self):
        async def main():
            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
                await self.loop.start_tls(None, None, None)

            sslctx = test_utils.simple_server_sslcontext()
            with self.assertRaisesRegex(TypeError, 'is not supported'):
                await self.loop.start_tls(None, None, sslctx)

        self.loop.run_until_complete(main())

    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=support.SHORT_TIMEOUT),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

        # The pending ssl_handshake_timeout (support.SHORT_TIMEOUT) must be
        # cancelled along with the connection attempt, so related objects
        # are freed without waiting for that timeout to expire.
        client_sslctx = weakref.ref(client_sslctx)
        support.gc_collect()
        self.assertIsNone(client_sslctx())

    def test_create_connection_ssl_slow_handshake(self):
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        def server(sock):
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaisesRegex(
                    ConnectionAbortedError,
                    r'SSL handshake.*is taking longer'):

                self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(messages, [])

    def test_create_connection_ssl_failed_certificate(self):
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext(
            disable_verify=False)

        def server(sock):
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
            except OSError:
                # Also covers ssl.SSLError, which subclasses OSError: the
                # client is expected to reject the certificate mid-handshake.
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=support.LOOPBACK_TIMEOUT)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(client(srv.addr))

    def test_start_tls_client_corrupted_ssl(self):
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # Keep a raw duplicate of the socket so plaintext garbage can be
            # injected underneath the TLS layer.
            orig_sock = sock.dup()
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                orig_sock.close()
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()

            writer.close()
            return 'OK'

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')
 | |
| 
 | |
| 
 | |
@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the BaseStartTLS suite against asyncio.SelectorEventLoop."""

    def new_loop(self):
        loop = asyncio.SelectorEventLoop()
        return loop
 | |
| 
 | |
| 
 | |
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the BaseStartTLS suite against asyncio.ProactorEventLoop.

    Skipped wherever asyncio does not provide ProactorEventLoop.
    """

    def new_loop(self):
        loop = asyncio.ProactorEventLoop()
        return loop
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
 |