mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 13:41:24 +00:00 
			
		
		
		
	 95bcbcb7c8
			
		
	
	
		95bcbcb7c8
		
			
		
	
	
	
	
		
			
			gh-118950: Fix SSLProtocol.connection_lost not being called when OSError is thrown (GH-118960)
(cherry picked from commit 3f24bde0b6)
Co-authored-by: Javad Shafique <javadshafique@hotmail.com>
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
		
	
			
		
			
				
	
	
		
			843 lines
		
	
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			843 lines
		
	
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Tests for asyncio/sslproto.py."""
 | |
| 
 | |
| import logging
 | |
| import socket
 | |
| import unittest
 | |
| import weakref
 | |
| from test import support
 | |
| from test.support import socket_helper
 | |
| from unittest import mock
 | |
| try:
 | |
|     import ssl
 | |
| except ImportError:
 | |
|     ssl = None
 | |
| 
 | |
| import asyncio
 | |
| from asyncio import log
 | |
| from asyncio import protocols
 | |
| from asyncio import sslproto
 | |
| from test.test_asyncio import utils as test_utils
 | |
| from test.test_asyncio import functional as func_tests
 | |
| 
 | |
| 
 | |
def tearDownModule():
    # Passing None restores asyncio's default event loop policy, so this
    # module does not leak a policy into test modules that run after it.
    asyncio.set_event_loop_policy(policy=None)
 | |
| 
 | |
| 
 | |
@unittest.skipIf(ssl is None, 'No ssl module')
class SslProtoHandshakeTests(test_utils.TestCase):
    """Unit tests for sslproto.SSLProtocol on a fresh event loop.

    Most tests replace the underlying transport and the SSL object with
    ``mock.Mock`` instances (see ``connection_made()``).
    ``test_connection_lost_when_busy`` instead uses a real socket transport
    around a mock socket.
    """

    def setUp(self):
        super().setUp()
        self.loop = asyncio.new_event_loop()
        # NOTE(review): set_event_loop() comes from test_utils.TestCase;
        # presumably it also arranges for the loop to be closed on cleanup.
        self.set_event_loop(self.loop)

    def ssl_protocol(self, *, waiter=None, proto=None):
        # Build an SSLProtocol around a dummy SSL context.
        # ``proto`` is the application protocol (a bare asyncio.Protocol by
        # default). ``waiter`` is passed through to SSLProtocol unchanged.
        # The app transport is closed on test cleanup.
        sslcontext = test_utils.dummy_ssl_context()
        if proto is None:  # app protocol
            proto = asyncio.Protocol()
        ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter,
                                         ssl_handshake_timeout=0.1)
        self.assertIs(ssl_proto._app_transport.get_protocol(), proto)
        self.addCleanup(ssl_proto._app_transport.close)
        return ssl_proto

    def connection_made(self, ssl_proto, *, do_handshake=None):
        # Attach a mock transport and a mock SSL object to ``ssl_proto``, then
        # call its connection_made(). read() and write() on the SSL object
        # raise SSLWantReadError. ``do_handshake``, if given, replaces the
        # SSL object's do_handshake. Returns the mock transport.
        transport = mock.Mock()
        sslobj = mock.Mock()
        # emulate reading decompressed data
        sslobj.read.side_effect = ssl.SSLWantReadError
        sslobj.write.side_effect = ssl.SSLWantReadError
        if do_handshake is not None:
            sslobj.do_handshake = do_handshake
        ssl_proto._sslobj = sslobj
        ssl_proto.connection_made(transport)
        return transport

    def test_handshake_timeout_zero(self):
        sslcontext = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
                                 ssl_handshake_timeout=0)

    def test_handshake_timeout_negative(self):
        sslcontext = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
                                 ssl_handshake_timeout=-10)

    def test_eof_received_waiter(self):
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.connection_made(
            ssl_proto,
            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
        )
        ssl_proto.eof_received()
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionResetError)

    def test_fatal_error_no_name_error(self):
        # From issue #363.
        # _fatal_error() generates a NameError if sslproto.py
        # does not import base_events.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        # Temporarily turn off error logging so as not to spoil test output.
        log_level = log.logger.getEffectiveLevel()
        log.logger.setLevel(logging.FATAL)
        try:
            ssl_proto._fatal_error(None)
        finally:
            # Restore error logging.
            log.logger.setLevel(log_level)

    def test_connection_lost(self):
        # From issue #472: awaiting the waiter used to hang forever if
        # connection_lost() was called while the handshake was still pending
        # (do_handshake keeps raising SSLWantReadError here).
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.connection_made(
            ssl_proto,
            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
        )
        # connection_lost() takes an exception instance (or None), not a class.
        ssl_proto.connection_lost(ConnectionAbortedError())
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionAbortedError)

    def test_connection_lost_when_busy(self):
        # gh-118950: SSLProtocol.connection_lost() was not called when the
        # underlying socket transport raised OSError while writing.
        sock = mock.Mock()
        sock.fileno = mock.Mock(return_value=12345)
        sock.send = mock.Mock(side_effect=BrokenPipeError)

        # Build the StreamReader -> StreamReaderProtocol -> SSLProtocol chain
        # by hand, including its loop-dependent logic. This emulates what
        # _make_ssl_transport() does in BaseSelectorEventLoop.
        reader = asyncio.StreamReader(limit=2 ** 16, loop=self.loop)
        protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop)
        ssl_proto = self.ssl_protocol(proto=protocol)

        # emulate reading decompressed data
        sslobj = mock.Mock()
        sslobj.read.side_effect = ssl.SSLWantReadError
        sslobj.write.side_effect = ssl.SSLWantReadError
        ssl_proto._sslobj = sslobj

        # emulate outgoing data
        data = b'An interesting message'

        outgoing = mock.Mock()
        outgoing.read = mock.Mock(return_value=data)
        outgoing.pending = len(data)
        ssl_proto._outgoing = outgoing

        # use correct socket transport to initialize the SSLProtocol
        self.loop._make_socket_transport(sock, ssl_proto)

        transport = ssl_proto._app_transport
        writer = asyncio.StreamWriter(transport, protocol, reader, self.loop)

        async def main():
            # writes data to transport
            async def write():
                writer.write(data)
                await writer.drain()

            # try to write for the first time
            await write()
            # The first write's BrokenPipeError should have caused
            # connection_lost() to be called, so the second write must fail
            # with ConnectionResetError.
            with self.assertRaises(ConnectionResetError):
                await write()

        self.loop.run_until_complete(main())

    def test_close_during_handshake(self):
        # bpo-29743 Closing transport during handshake process leaks socket
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)

        transport = self.connection_made(
            ssl_proto,
            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
        )
        test_utils.run_briefly(self.loop)

        ssl_proto._app_transport.close()
        self.assertTrue(transport._force_close.called)

    def test_close_during_ssl_over_ssl(self):
        # gh-113214: passing exceptions from the inner wrapped SSL protocol to the
        # shim transport provided by the outer SSL protocol should not raise
        # attribute errors
        outer = self.ssl_protocol(proto=self.ssl_protocol())
        self.connection_made(outer)
        # Closing the outer app transport should not raise an exception
        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
        outer._app_transport.close()
        self.assertEqual(messages, [])

    def test_get_extra_info_on_closed_connection(self):
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.assertIsNone(ssl_proto._get_extra_info('socket'))
        default = object()
        self.assertIs(ssl_proto._get_extra_info('socket', default), default)
        self.connection_made(ssl_proto)
        self.assertIsNotNone(ssl_proto._get_extra_info('socket'))
        ssl_proto.connection_lost(None)
        self.assertIsNone(ssl_proto._get_extra_info('socket'))

    def test_set_new_app_protocol(self):
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        new_app_proto = asyncio.Protocol()
        ssl_proto._app_transport.set_protocol(new_app_proto)
        self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
        self.assertIs(ssl_proto._app_protocol, new_app_proto)

    def test_data_received_after_closing(self):
        ssl_proto = self.ssl_protocol()
        self.connection_made(ssl_proto)
        transp = ssl_proto._app_transport

        transp.close()

        # should not raise
        self.assertIsNone(ssl_proto.buffer_updated(5))

    def test_write_after_closing(self):
        ssl_proto = self.ssl_protocol()
        self.connection_made(ssl_proto)
        transp = ssl_proto._app_transport
        transp.close()

        # should not raise
        self.assertIsNone(transp.write(b'data'))
 | |
| 
 | |
| 
 | |
| ##############################################################################
 | |
| # Start TLS Tests
 | |
| ##############################################################################
 | |
| 
 | |
| 
 | |
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
    """Functional TLS tests shared by all event loop implementations.

    Covers upgrading plain TCP connections with ``loop.start_tls()`` on both
    the client and the server side. Also covers ``create_connection(ssl=...)``
    and ``open_connection(ssl=...)``. The peer side is plain blocking socket
    code passed to the ``tcp_server()`` / ``tcp_client()`` helpers. Concrete
    subclasses choose the event loop implementation by overriding
    ``new_loop()``.
    """

    PAYLOAD_SIZE = 1024 * 100
    TIMEOUT = support.LONG_TIMEOUT

    def new_loop(self):
        # Concrete subclasses return the event loop implementation under test.
        raise NotImplementedError

    def test_buf_feed_data(self):
        # _feed_data_to_buffered_proto() must deliver the whole payload even
        # when the protocol's buffer is smaller than the data. This holds
        # whether get_buffer() returns a bytearray or a memoryview. A
        # zero-length buffer is rejected with RuntimeError.

        class Proto(asyncio.BufferedProtocol):

            def __init__(self, bufsize, usemv):
                self.buf = bytearray(bufsize)
                self.mv = memoryview(self.buf)
                self.data = b''
                self.usemv = usemv

            def get_buffer(self, sizehint):
                if self.usemv:
                    return self.mv
                else:
                    return self.buf

            def buffer_updated(self, nsize):
                if self.usemv:
                    self.data += self.mv[:nsize]
                else:
                    self.data += self.buf[:nsize]

        for usemv in [False, True]:
            proto = Proto(1, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(4, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(100, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(0, usemv)
            with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
                protocols._feed_data_to_buffered_proto(proto, b'12345')

    def test_start_tls_client_reg_proto_1(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                # ``proto`` is the protocol instance. ``self`` still names the
                # enclosing TestCase, so the callback can make assertions.
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data, b'O')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left if SSL is closed uncleanly
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())

    def test_create_connection_memory_leak(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())

    @socket_helper.skip_if_tcp_blackhole
    def test_start_tls_client_buf_proto_1(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        client_con_made_calls = 0

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            def __init__(self, on_data):
                self.on_data = on_data
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            def buffer_updated(slf, nsize):
                # ``slf`` is the protocol; ``self`` is the enclosing TestCase.
                self.assertEqual(nsize, 1)
                slf.on_data.set_result(bytes(slf.buf[:nsize]))

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))

    def test_start_tls_slow_client_cancel(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        client_context = test_utils.simple_client_sslcontext()
        server_waits_on_handshake = self.loop.create_future()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            try:
                # Tell the client we are now stalled in the TLS handshake;
                # the peer runs outside the loop, hence call_soon_threadsafe.
                self.loop.call_soon_threadsafe(
                    server_waits_on_handshake.set_result, None)
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)

            await server_waits_on_handshake

            with self.assertRaises(asyncio.TimeoutError):
                await asyncio.wait_for(
                    self.loop.start_tls(tr, proto, client_context),
                    0.5)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

    @socket_helper.skip_if_tcp_blackhole
    def test_start_tls_server_1(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        ANSWER = b'answer'

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        answer = None

        def client(sock, addr):
            nonlocal answer
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            answer = sock.recv_all(len(ANSWER))
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_con_lost, on_got_hello):
                self.on_con = on_con
                self.on_con_lost = on_con_lost
                self.on_got_hello = on_got_hello
                self.data = b''
                self.transport = None

            def connection_made(self, tr):
                self.transport = tr
                self.on_con.set_result(tr)

            def replace_transport(self, tr):
                self.transport = tr

            def data_received(self, data):
                self.data += data
                if len(self.data) >= len(HELLO_MSG):
                    self.on_got_hello.set_result(None)

            def connection_lost(self, exc):
                self.transport = None
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_con_lost, on_got_hello):
            tr = await on_con
            tr.write(HELLO_MSG)

            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)
            proto.replace_transport(new_tr)

            await on_got_hello
            new_tr.write(ANSWER)

            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            on_got_hello = self.loop.create_future()
            proto = ServerProto(on_con, on_con_lost, on_got_hello)

            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_con_lost, on_got_hello),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()
            self.assertEqual(answer, ANSWER)

        self.loop.run_until_complete(run_main())

    def test_start_tls_wrong_args(self):
        async def main():
            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
                await self.loop.start_tls(None, None, None)

            sslctx = test_utils.simple_server_sslcontext()
            with self.assertRaisesRegex(TypeError, 'is not supported'):
                await self.loop.start_tls(None, None, sslctx)

        self.loop.run_until_complete(main())

    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=support.SHORT_TIMEOUT),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

        # The pending handshake timeout (support.SHORT_TIMEOUT) should be
        # cancelled so that related objects are freed without waiting for
        # that timeout to expire.
        client_sslctx = weakref.ref(client_sslctx)
        support.gc_collect()
        self.assertIsNone(client_sslctx())

    def test_create_connection_ssl_slow_handshake(self):
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        def server(sock):
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            # Expected to fail: the server never answers the handshake.
            await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaisesRegex(
                    ConnectionAbortedError,
                    r'SSL handshake.*is taking longer'):

                self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(messages, [])

    def test_create_connection_ssl_failed_certificate(self):
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext(
            disable_verify=False)

        def server(sock):
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
            except OSError:
                # ssl.SSLError is a subclass of OSError, so this covers both
                # TLS-level failures (e.g. the client rejecting the
                # certificate) and plain socket errors.
                pass
            finally:
                sock.close()

        async def client(addr):
            # Expected to fail with SSLCertVerificationError.
            await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=support.LOOPBACK_TIMEOUT)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(client(srv.addr))

    def test_start_tls_client_corrupted_ssl(self):
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # Keep a duplicate of the raw socket so that plaintext can be
            # injected underneath the established TLS session.
            orig_sock = sock.dup()
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                orig_sock.close()
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()

            writer.close()
            return 'OK'

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')
 | |
| 
 | |
| 
 | |
@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the shared BaseStartTLS tests on asyncio.SelectorEventLoop."""

    def new_loop(self):
        selector_loop = asyncio.SelectorEventLoop()
        return selector_loop
 | |
| 
 | |
| 
 | |
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the shared BaseStartTLS tests on asyncio.ProactorEventLoop."""

    def new_loop(self):
        proactor_loop = asyncio.ProactorEventLoop()
        return proactor_loop
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
    # Allow running this test module directly as a script; '__main__' is
    # unittest.main()'s default module, spelled out here for clarity.
    unittest.main(module='__main__')
 |