| 
									
										
										
										
											2015-01-14 16:56:20 +01:00
										 |  |  | """Tests for asyncio/sslproto.py.""" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-06-28 10:55:36 -04:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2018-05-28 14:31:28 -04:00
										 |  |  | import socket | 
					
						
							| 
									
										
										
										
											2015-01-14 16:56:20 +01:00
										 |  |  | import unittest | 
					
						
							| 
									
										
										
										
											2019-03-17 17:51:10 -05:00
										 |  |  | import weakref | 
					
						
							| 
									
										
										
										
											2022-03-10 21:36:22 +00:00
										 |  |  | from test import support | 
					
						
							| 
									
										
										
										
											2023-09-07 01:58:03 +02:00
										 |  |  | from test.support import socket_helper | 
					
						
							| 
									
										
										
										
											2015-01-14 16:56:20 +01:00
										 |  |  | from unittest import mock | 
					
						
							| 
									
										
										
										
											2015-01-28 00:30:40 +01:00
										 |  |  | try: | 
					
						
							|  |  |  |     import ssl | 
					
						
							|  |  |  | except ImportError: | 
					
						
							|  |  |  |     ssl = None | 
					
						
							| 
									
										
										
										
											2015-01-14 16:56:20 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | import asyncio | 
					
						
							| 
									
										
										
										
											2016-06-28 10:55:36 -04:00
										 |  |  | from asyncio import log | 
					
						
							| 
									
										
										
										
											2018-06-08 00:25:52 +02:00
										 |  |  | from asyncio import protocols | 
					
						
							| 
									
										
										
										
											2015-01-14 16:56:20 +01:00
										 |  |  | from asyncio import sslproto | 
					
						
							| 
									
										
										
										
											2017-12-11 10:04:40 -05:00
										 |  |  | from test.test_asyncio import utils as test_utils | 
					
						
							| 
									
										
										
										
											2017-12-30 00:35:36 -05:00
										 |  |  | from test.test_asyncio import functional as func_tests | 
					
						
							| 
									
										
										
										
											2015-01-14 16:56:20 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-06-01 20:34:09 -07:00
										 |  |  | def tearDownModule(): | 
					
						
							|  |  |  |     asyncio.set_event_loop_policy(None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-01-29 00:35:56 +01:00
										 |  |  | @unittest.skipIf(ssl is None, 'No ssl module') | 
					
						
							| 
									
										
										
										
											2015-01-14 16:56:20 +01:00
										 |  |  | class SslProtoHandshakeTests(test_utils.TestCase): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def setUp(self): | 
					
						
							| 
									
										
										
										
											2016-11-04 14:29:28 -04:00
										 |  |  |         super().setUp() | 
					
						
							| 
									
										
										
										
											2015-01-14 16:56:20 +01:00
										 |  |  |         self.loop = asyncio.new_event_loop() | 
					
						
							|  |  |  |         self.set_event_loop(self.loop) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-03-10 17:48:35 +02:00
										 |  |  |     def ssl_protocol(self, *, waiter=None, proto=None): | 
					
						
							| 
									
										
										
										
											2015-01-29 00:35:56 +01:00
										 |  |  |         sslcontext = test_utils.dummy_ssl_context() | 
					
						
							| 
									
										
										
										
											2018-03-10 17:48:35 +02:00
										 |  |  |         if proto is None:  # app protocol | 
					
						
							|  |  |  |             proto = asyncio.Protocol() | 
					
						
							|  |  |  |         ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter, | 
					
						
							|  |  |  |                                          ssl_handshake_timeout=0.1) | 
					
						
							|  |  |  |         self.assertIs(ssl_proto._app_transport.get_protocol(), proto) | 
					
						
							|  |  |  |         self.addCleanup(ssl_proto._app_transport.close) | 
					
						
							|  |  |  |         return ssl_proto | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def connection_made(self, ssl_proto, *, do_handshake=None): | 
					
						
							| 
									
										
										
										
											2015-01-29 00:35:56 +01:00
										 |  |  |         transport = mock.Mock() | 
					
						
							| 
									
										
										
										
											2022-02-15 18:34:00 +05:30
										 |  |  |         sslobj = mock.Mock() | 
					
						
							|  |  |  |         # emulate reading decompressed data | 
					
						
							|  |  |  |         sslobj.read.side_effect = ssl.SSLWantReadError | 
					
						
							| 
									
										
										
										
											2023-12-20 23:09:01 +00:00
										 |  |  |         sslobj.write.side_effect = ssl.SSLWantReadError | 
					
						
							| 
									
										
										
										
											2022-02-15 18:34:00 +05:30
										 |  |  |         if do_handshake is not None: | 
					
						
							|  |  |  |             sslobj.do_handshake = do_handshake | 
					
						
							|  |  |  |         ssl_proto._sslobj = sslobj | 
					
						
							|  |  |  |         ssl_proto.connection_made(transport) | 
					
						
							| 
									
										
										
										
											2017-06-09 14:46:14 -07:00
										 |  |  |         return transport | 
					
						
							| 
									
										
										
										
											2015-01-29 00:35:56 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-12-20 20:24:43 +02:00
										 |  |  |     def test_handshake_timeout_zero(self): | 
					
						
							|  |  |  |         sslcontext = test_utils.dummy_ssl_context() | 
					
						
							|  |  |  |         app_proto = mock.Mock() | 
					
						
							|  |  |  |         waiter = mock.Mock() | 
					
						
							|  |  |  |         with self.assertRaisesRegex(ValueError, 'a positive number'): | 
					
						
							|  |  |  |             sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter, | 
					
						
							|  |  |  |                                  ssl_handshake_timeout=0) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_handshake_timeout_negative(self): | 
					
						
							|  |  |  |         sslcontext = test_utils.dummy_ssl_context() | 
					
						
							|  |  |  |         app_proto = mock.Mock() | 
					
						
							|  |  |  |         waiter = mock.Mock() | 
					
						
							|  |  |  |         with self.assertRaisesRegex(ValueError, 'a positive number'): | 
					
						
							|  |  |  |             sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter, | 
					
						
							|  |  |  |                                  ssl_handshake_timeout=-10) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-01-29 00:35:56 +01:00
										 |  |  |     def test_eof_received_waiter(self): | 
					
						
							| 
									
										
										
										
											2019-09-11 16:07:37 +03:00
										 |  |  |         waiter = self.loop.create_future() | 
					
						
							| 
									
										
										
										
											2018-03-10 17:48:35 +02:00
										 |  |  |         ssl_proto = self.ssl_protocol(waiter=waiter) | 
					
						
							| 
									
										
										
										
											2022-02-15 18:34:00 +05:30
										 |  |  |         self.connection_made( | 
					
						
							|  |  |  |             ssl_proto, | 
					
						
							|  |  |  |             do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2015-01-29 00:35:56 +01:00
										 |  |  |         ssl_proto.eof_received() | 
					
						
							|  |  |  |         test_utils.run_briefly(self.loop) | 
					
						
							|  |  |  |         self.assertIsInstance(waiter.exception(), ConnectionResetError) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-06-28 10:55:36 -04:00
										 |  |  |     def test_fatal_error_no_name_error(self): | 
					
						
							|  |  |  |         # From issue #363. | 
					
						
							|  |  |  |         # _fatal_error() generates a NameError if sslproto.py | 
					
						
							|  |  |  |         # does not import base_events. | 
					
						
							| 
									
										
										
										
											2019-09-11 16:07:37 +03:00
										 |  |  |         waiter = self.loop.create_future() | 
					
						
							| 
									
										
										
										
											2018-03-10 17:48:35 +02:00
										 |  |  |         ssl_proto = self.ssl_protocol(waiter=waiter) | 
					
						
							| 
									
										
										
										
											2016-06-28 10:55:36 -04:00
										 |  |  |         # Temporarily turn off error logging so as not to spoil test output. | 
					
						
							|  |  |  |         log_level = log.logger.getEffectiveLevel() | 
					
						
							|  |  |  |         log.logger.setLevel(logging.FATAL) | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             ssl_proto._fatal_error(None) | 
					
						
							|  |  |  |         finally: | 
					
						
							|  |  |  |             # Restore error logging. | 
					
						
							|  |  |  |             log.logger.setLevel(log_level) | 
					
						
							| 
									
										
										
										
											2015-01-14 16:56:20 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-12-16 11:50:41 -05:00
										 |  |  |     def test_connection_lost(self): | 
					
						
							|  |  |  |         # From issue #472. | 
					
						
							|  |  |  |         # yield from waiter hang if lost_connection was called. | 
					
						
							| 
									
										
										
										
											2019-09-11 16:07:37 +03:00
										 |  |  |         waiter = self.loop.create_future() | 
					
						
							| 
									
										
										
										
											2018-03-10 17:48:35 +02:00
										 |  |  |         ssl_proto = self.ssl_protocol(waiter=waiter) | 
					
						
							| 
									
										
										
										
											2022-02-15 18:34:00 +05:30
										 |  |  |         self.connection_made( | 
					
						
							|  |  |  |             ssl_proto, | 
					
						
							|  |  |  |             do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2016-12-16 11:50:41 -05:00
										 |  |  |         ssl_proto.connection_lost(ConnectionAbortedError) | 
					
						
							|  |  |  |         test_utils.run_briefly(self.loop) | 
					
						
							|  |  |  |         self.assertIsInstance(waiter.exception(), ConnectionAbortedError) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-06-09 14:46:14 -07:00
										 |  |  |     def test_close_during_handshake(self): | 
					
						
							|  |  |  |         # bpo-29743 Closing transport during handshake process leaks socket | 
					
						
							| 
									
										
										
										
											2019-09-11 16:07:37 +03:00
										 |  |  |         waiter = self.loop.create_future() | 
					
						
							| 
									
										
										
										
											2018-03-10 17:48:35 +02:00
										 |  |  |         ssl_proto = self.ssl_protocol(waiter=waiter) | 
					
						
							| 
									
										
										
										
											2017-06-09 14:46:14 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-15 18:34:00 +05:30
										 |  |  |         transport = self.connection_made( | 
					
						
							|  |  |  |             ssl_proto, | 
					
						
							|  |  |  |             do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2017-06-09 14:46:14 -07:00
										 |  |  |         test_utils.run_briefly(self.loop) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ssl_proto._app_transport.close() | 
					
						
							| 
									
										
										
										
											2023-12-20 23:09:01 +00:00
										 |  |  |         self.assertTrue(transport._force_close.called) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_close_during_ssl_over_ssl(self): | 
					
						
							|  |  |  |         # gh-113214: passing exceptions from the inner wrapped SSL protocol to the | 
					
						
							|  |  |  |         # shim transport provided by the outer SSL protocol should not raise | 
					
						
							|  |  |  |         # attribute errors | 
					
						
							|  |  |  |         outer = self.ssl_protocol(proto=self.ssl_protocol()) | 
					
						
							|  |  |  |         self.connection_made(outer) | 
					
						
							|  |  |  |         # Closing the outer app transport should not raise an exception | 
					
						
							|  |  |  |         messages = [] | 
					
						
							|  |  |  |         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) | 
					
						
							|  |  |  |         outer._app_transport.close() | 
					
						
							|  |  |  |         self.assertEqual(messages, []) | 
					
						
							| 
									
										
										
										
											2017-06-09 14:46:14 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-03-12 12:23:30 -07:00
										 |  |  |     def test_get_extra_info_on_closed_connection(self): | 
					
						
							| 
									
										
										
										
											2019-09-11 16:07:37 +03:00
										 |  |  |         waiter = self.loop.create_future() | 
					
						
							| 
									
										
										
										
											2018-03-10 17:48:35 +02:00
										 |  |  |         ssl_proto = self.ssl_protocol(waiter=waiter) | 
					
						
							| 
									
										
										
										
											2017-03-12 12:23:30 -07:00
										 |  |  |         self.assertIsNone(ssl_proto._get_extra_info('socket')) | 
					
						
							|  |  |  |         default = object() | 
					
						
							|  |  |  |         self.assertIs(ssl_proto._get_extra_info('socket', default), default) | 
					
						
							|  |  |  |         self.connection_made(ssl_proto) | 
					
						
							|  |  |  |         self.assertIsNotNone(ssl_proto._get_extra_info('socket')) | 
					
						
							|  |  |  |         ssl_proto.connection_lost(None) | 
					
						
							|  |  |  |         self.assertIsNone(ssl_proto._get_extra_info('socket')) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-10-19 19:49:57 +02:00
										 |  |  |     def test_set_new_app_protocol(self): | 
					
						
							| 
									
										
										
										
											2019-09-11 16:07:37 +03:00
										 |  |  |         waiter = self.loop.create_future() | 
					
						
							| 
									
										
										
										
											2018-03-10 17:48:35 +02:00
										 |  |  |         ssl_proto = self.ssl_protocol(waiter=waiter) | 
					
						
							| 
									
										
										
										
											2017-10-19 19:49:57 +02:00
										 |  |  |         new_app_proto = asyncio.Protocol() | 
					
						
							|  |  |  |         ssl_proto._app_transport.set_protocol(new_app_proto) | 
					
						
							|  |  |  |         self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto) | 
					
						
							|  |  |  |         self.assertIs(ssl_proto._app_protocol, new_app_proto) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-03-10 17:48:35 +02:00
										 |  |  |     def test_data_received_after_closing(self): | 
					
						
							|  |  |  |         ssl_proto = self.ssl_protocol() | 
					
						
							|  |  |  |         self.connection_made(ssl_proto) | 
					
						
							|  |  |  |         transp = ssl_proto._app_transport | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         transp.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # should not raise | 
					
						
							| 
									
										
										
										
											2022-02-15 18:34:00 +05:30
										 |  |  |         self.assertIsNone(ssl_proto.buffer_updated(5)) | 
					
						
							| 
									
										
										
										
											2018-03-10 17:48:35 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def test_write_after_closing(self): | 
					
						
							|  |  |  |         ssl_proto = self.ssl_protocol() | 
					
						
							|  |  |  |         self.connection_made(ssl_proto) | 
					
						
							|  |  |  |         transp = ssl_proto._app_transport | 
					
						
							|  |  |  |         transp.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # should not raise | 
					
						
							|  |  |  |         self.assertIsNone(transp.write(b'data')) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-03-12 12:23:30 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-12-30 00:35:36 -05:00
										 |  |  | ############################################################################## | 
					
						
							|  |  |  | # Start TLS Tests | 
					
						
							|  |  |  | ############################################################################## | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class BaseStartTLS(func_tests.FunctionalTestCaseMixin): | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-05-28 14:31:28 -04:00
										 |  |  |     PAYLOAD_SIZE = 1024 * 100 | 
					
						
							| 
									
										
										
										
											2019-12-10 21:12:26 +01:00
										 |  |  |     TIMEOUT = support.LONG_TIMEOUT | 
					
						
							| 
									
										
										
										
											2018-05-28 14:31:28 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-12-30 00:35:36 -05:00
										 |  |  |     def new_loop(self): | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-05-28 14:31:28 -04:00
										 |  |  |     def test_buf_feed_data(self): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         class Proto(asyncio.BufferedProtocol): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def __init__(self, bufsize, usemv): | 
					
						
							|  |  |  |                 self.buf = bytearray(bufsize) | 
					
						
							|  |  |  |                 self.mv = memoryview(self.buf) | 
					
						
							|  |  |  |                 self.data = b'' | 
					
						
							|  |  |  |                 self.usemv = usemv | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def get_buffer(self, sizehint): | 
					
						
							|  |  |  |                 if self.usemv: | 
					
						
							|  |  |  |                     return self.mv | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     return self.buf | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def buffer_updated(self, nsize): | 
					
						
							|  |  |  |                 if self.usemv: | 
					
						
							|  |  |  |                     self.data += self.mv[:nsize] | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     self.data += self.buf[:nsize] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for usemv in [False, True]: | 
					
						
							|  |  |  |             proto = Proto(1, usemv) | 
					
						
							| 
									
										
										
										
											2018-06-08 10:32:06 +02:00
										 |  |  |             protocols._feed_data_to_buffered_proto(proto, b'12345') | 
					
						
							| 
									
										
										
										
											2018-05-28 14:31:28 -04:00
										 |  |  |             self.assertEqual(proto.data, b'12345') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             proto = Proto(2, usemv) | 
					
						
							| 
									
										
										
										
											2018-06-08 10:32:06 +02:00
										 |  |  |             protocols._feed_data_to_buffered_proto(proto, b'12345') | 
					
						
							| 
									
										
										
										
											2018-05-28 14:31:28 -04:00
										 |  |  |             self.assertEqual(proto.data, b'12345') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             proto = Proto(2, usemv) | 
					
						
							| 
									
										
										
										
											2018-06-08 10:32:06 +02:00
										 |  |  |             protocols._feed_data_to_buffered_proto(proto, b'1234') | 
					
						
							| 
									
										
										
										
											2018-05-28 14:31:28 -04:00
										 |  |  |             self.assertEqual(proto.data, b'1234') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             proto = Proto(4, usemv) | 
					
						
							| 
									
										
										
										
											2018-06-08 10:32:06 +02:00
										 |  |  |             protocols._feed_data_to_buffered_proto(proto, b'1234') | 
					
						
							| 
									
										
										
										
											2018-05-28 14:31:28 -04:00
										 |  |  |             self.assertEqual(proto.data, b'1234') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             proto = Proto(100, usemv) | 
					
						
							| 
									
										
										
										
											2018-06-08 10:32:06 +02:00
										 |  |  |             protocols._feed_data_to_buffered_proto(proto, b'12345') | 
					
						
							| 
									
										
										
										
											2018-05-28 14:31:28 -04:00
										 |  |  |             self.assertEqual(proto.data, b'12345') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             proto = Proto(0, usemv) | 
					
						
							|  |  |  |             with self.assertRaisesRegex(RuntimeError, 'empty buffer'): | 
					
						
							| 
									
										
										
										
											2018-06-08 10:32:06 +02:00
										 |  |  |                 protocols._feed_data_to_buffered_proto(proto, b'12345') | 
					
						
							| 
									
										
										
										
											2018-05-28 14:31:28 -04:00
										 |  |  | 
 | 
					
						
    def test_start_tls_client_reg_proto_1(self):
        # Client side of loop.start_tls() with a regular (non-buffered)
        # asyncio.Protocol: send HELLO_MSG in plaintext, upgrade the
        # connection to TLS, then exchange data over the TLS channel.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            # Server side, driven through the sock wrapper supplied by
            # self.tcp_server().
            sock.settimeout(self.TIMEOUT)

            # Plaintext phase: read the client's first HELLO_MSG.
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            # Upgrade this end of the connection to TLS.
            sock.start_tls(server_context, server_side=True)

            # TLS phase: send one byte, then read the second HELLO_MSG.
            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The instance is named 'proto' here so that 'self' still refers
            # to the enclosing test case (used for assertEqual below).
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            # NOTE(review): the purpose of this fixed delay is not visible
            # here (presumably lets the server thread get ready) -- confirm.
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            # Plaintext write, then upgrade the same transport to TLS.
            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            # After the upgrade, data flows through the new TLS transport.
            self.assertEqual(await on_data, b'O')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left if SSL is closed uncleanly
        # Rebinding the name replaces the strong reference the test holds.
        # 'client_context' is also a free variable of client(), so the
        # closure now sees only the weakref as well.  The context must then
        # be collectable.
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())
					
						
							|  |  |  | 
 | 
					
						
    def test_create_connection_memory_leak(self):
        # Check that an SSL client connection made with
        # loop.create_connection(ssl=...) leaves nothing behind that keeps the
        # client SSLContext alive, even when the protocol stores a reference
        # to its transport.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        # Server side, driven by self.tcp_server() (presumably on a helper
        # thread -- that helper is outside this view). It speaks TLS from the
        # first byte, sends b'O', reads the client's HELLO_MSG, then shuts
        # the socket down in both directions.
        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            # on_data / on_eof are futures resolved by data_received() and
            # eof_received() respectively.
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The first parameter is named `proto` so that `self` inside this
            # method still refers to the enclosing TestCase (for assertEqual).
            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                # (deliberately keeping a protocol -> transport reference; the
                # leak check at the end must still pass with it in place).
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            # NOTE(review): the initial delay presumably gives the server
            # side time to get ready -- confirm against tcp_server().
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            # The server's first TLS payload is b'O'; reply with HELLO_MSG and
            # wait for the EOF produced by the server's shutdown().
            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance.
        # Rebinding client_context replaces the value in the closure cell
        # shared with serve()/client(), so after gc_collect() the weakref is
        # dead unless something else still holds the context.
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-07 01:58:03 +02:00
    @socket_helper.skip_if_tcp_blackhole
    def test_start_tls_client_buf_proto_1(self):
        # Upgrade a plaintext client connection with loop.start_tls() while
        # the protocol is a BufferedProtocol with a 1-byte buffer. Then swap
        # in a regular Protocol via set_protocol(), and check that
        # connection_made() is only called for the original connection.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        # Incremented by connection_made() of both client protocols below.
        client_con_made_calls = 0

        # Server script: read a plaintext HELLO_MSG, switch to TLS, then twice
        # send one byte (b'O', then b'2') and read a HELLO_MSG back, and
        # finally shut the socket down.
        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            def __init__(self, on_data):
                self.on_data = on_data
                # Single-byte receive buffer handed out by get_buffer().
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                # sizehint is ignored: the same 1-byte buffer is always reused.
                return self.buf

            # First parameter is `slf` so that `self` refers to the TestCase.
            def buffer_updated(slf, nsize):
                self.assertEqual(nsize, 1)
                slf.on_data.set_result(bytes(slf.buf[:nsize]))

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                # NOTE(review): never read in this test.
                self.con_made_cnt = 0

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            # NOTE(review): the initial delay presumably gives the server
            # side time to get ready -- confirm against tcp_server().
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            # Plaintext greeting the server reads before its own start_tls().
            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            # b'O' arrives through the BufferedProtocol over TLS.
            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            # Switch the TLS transport to a plain Protocol; the server's b'2'
            # must now be delivered via data_received().
            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))
					
						
							| 
									
										
										
										
											2018-05-28 14:31:28 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-06-04 11:32:35 -04:00
										 |  |  |     def test_start_tls_slow_client_cancel(self): | 
					
						
							|  |  |  |         HELLO_MSG = b'1' * self.PAYLOAD_SIZE | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         client_context = test_utils.simple_client_sslcontext() | 
					
						
							|  |  |  |         server_waits_on_handshake = self.loop.create_future() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def serve(sock): | 
					
						
							|  |  |  |             sock.settimeout(self.TIMEOUT) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             data = sock.recv_all(len(HELLO_MSG)) | 
					
						
							|  |  |  |             self.assertEqual(len(data), len(HELLO_MSG)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 self.loop.call_soon_threadsafe( | 
					
						
							|  |  |  |                     server_waits_on_handshake.set_result, None) | 
					
						
							|  |  |  |                 data = sock.recv_all(1024 * 1024) | 
					
						
							|  |  |  |             except ConnectionAbortedError: | 
					
						
							|  |  |  |                 pass | 
					
						
							|  |  |  |             finally: | 
					
						
							|  |  |  |                 sock.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         class ClientProto(asyncio.Protocol): | 
					
						
							|  |  |  |             def __init__(self, on_data, on_eof): | 
					
						
							|  |  |  |                 self.on_data = on_data | 
					
						
							|  |  |  |                 self.on_eof = on_eof | 
					
						
							|  |  |  |                 self.con_made_cnt = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def connection_made(proto, tr): | 
					
						
							|  |  |  |                 proto.con_made_cnt += 1 | 
					
						
							|  |  |  |                 # Ensure connection_made gets called only once. | 
					
						
							|  |  |  |                 self.assertEqual(proto.con_made_cnt, 1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def data_received(self, data): | 
					
						
							|  |  |  |                 self.on_data.set_result(data) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def eof_received(self): | 
					
						
							|  |  |  |                 self.on_eof.set_result(True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         async def client(addr): | 
					
						
							| 
									
										
										
										
											2018-10-02 13:53:06 -04:00
										 |  |  |             await asyncio.sleep(0.5) | 
					
						
							| 
									
										
										
										
											2018-06-04 11:32:35 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |             on_data = self.loop.create_future() | 
					
						
							|  |  |  |             on_eof = self.loop.create_future() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             tr, proto = await self.loop.create_connection( | 
					
						
							|  |  |  |                 lambda: ClientProto(on_data, on_eof), *addr) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             tr.write(HELLO_MSG) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             await server_waits_on_handshake | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             with self.assertRaises(asyncio.TimeoutError): | 
					
						
							|  |  |  |                 await asyncio.wait_for( | 
					
						
							|  |  |  |                     self.loop.start_tls(tr, proto, client_context), | 
					
						
							| 
									
										
										
										
											2018-10-02 13:53:06 -04:00
										 |  |  |                     0.5) | 
					
						
							| 
									
										
										
										
											2018-06-04 11:32:35 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |         with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: | 
					
						
							|  |  |  |             self.loop.run_until_complete( | 
					
						
							| 
									
										
										
										
											2019-12-11 11:30:03 +01:00
										 |  |  |                 asyncio.wait_for(client(srv.addr), | 
					
						
							|  |  |  |                                  timeout=support.SHORT_TIMEOUT)) | 
					
						
							| 
									
										
										
										
											2018-06-04 11:32:35 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-07 01:58:03 +02:00
    @socket_helper.skip_if_tcp_blackhole
    def test_start_tls_server_1(self):
        # Server-side loop.start_tls(): the asyncio server sends a plaintext
        # HELLO_MSG, upgrades the accepted connection to TLS, receives the
        # client's HELLO_MSG over TLS and replies with ANSWER.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        ANSWER = b'answer'

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        # Filled in by the blocking client below; checked at the very end.
        answer = None

        # Blocking client driven by self.tcp_client(): read the plaintext
        # greeting, switch to TLS, send HELLO_MSG and read ANSWER back.
        def client(sock, addr):
            nonlocal answer
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            answer = sock.recv_all(len(ANSWER))
            sock.close()

        class ServerProto(asyncio.Protocol):
            # on_con: resolved with the plaintext transport on connect.
            # on_con_lost: resolved (or failed) on connection_lost().
            # on_got_hello: resolved once HELLO_MSG has been fully received.
            def __init__(self, on_con, on_con_lost, on_got_hello):
                self.on_con = on_con
                self.on_con_lost = on_con_lost
                self.on_got_hello = on_got_hello
                self.data = b''
                self.transport = None

            def connection_made(self, tr):
                self.transport = tr
                self.on_con.set_result(tr)

            # Called by main() after start_tls() returns the TLS transport.
            def replace_transport(self, tr):
                self.transport = tr

            def data_received(self, data):
                self.data += data
                # The client sends exactly one HELLO_MSG, so this threshold
                # is crossed once and the future is resolved a single time.
                if len(self.data) >= len(HELLO_MSG):
                    self.on_got_hello.set_result(None)

            def connection_lost(self, exc):
                self.transport = None
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_con_lost, on_got_hello):
            tr = await on_con
            tr.write(HELLO_MSG)

            # Nothing is received before the TLS upgrade.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)
            proto.replace_transport(new_tr)

            await on_got_hello
            new_tr.write(ANSWER)

            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            on_got_hello = self.loop.create_future()
            proto = ServerProto(on_con, on_con_lost, on_got_hello)

            # A single protocol instance is shared by the factory; the test
            # expects exactly one incoming connection.
            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_con_lost, on_got_hello),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()
            self.assertEqual(answer, ANSWER)

        self.loop.run_until_complete(run_main())
					
						
							| 
									
										
										
										
											2017-12-30 00:35:36 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def test_start_tls_wrong_args(self): | 
					
						
							|  |  |  |         async def main(): | 
					
						
							|  |  |  |             with self.assertRaisesRegex(TypeError, 'SSLContext, got'): | 
					
						
							|  |  |  |                 await self.loop.start_tls(None, None, None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             sslctx = test_utils.simple_server_sslcontext() | 
					
						
							|  |  |  |             with self.assertRaisesRegex(TypeError, 'is not supported'): | 
					
						
							|  |  |  |                 await self.loop.start_tls(None, None, sslctx) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.loop.run_until_complete(main()) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-06-04 11:32:35 -04:00
    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = test_utils.simple_client_sslcontext()

        # Collect anything reported to the loop's exception handler.
        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        # The server only reads and never answers the TLS ClientHello, so the
        # client's handshake cannot complete. NOTE(review): this assumes
        # recv_all() raises ConnectionAbortedError when the client gives up
        # early -- its implementation is outside this view.
        def server(sock):
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        # The outer wait_for() (0.5s) expires long before the
        # ssl_handshake_timeout, so the pending handshake is cancelled.
        async def client(addr):
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=support.SHORT_TIMEOUT),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

        # The ssl_handshake_timeout timer (support.SHORT_TIMEOUT here) should
        # be cancelled to free related objects without really waiting for it
        # to expire. Rebinding client_sslctx drops the strong reference held
        # in the closure cell shared with client().
        client_sslctx = weakref.ref(client_sslctx)
        support.gc_collect()
        self.assertIsNone(client_sslctx())
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-06-04 11:32:35 -04:00
										 |  |  |     def test_create_connection_ssl_slow_handshake(self): | 
					
						
							|  |  |  |         client_sslctx = test_utils.simple_client_sslcontext() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         messages = [] | 
					
						
							|  |  |  |         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def server(sock): | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 sock.recv_all(1024 * 1024) | 
					
						
							|  |  |  |             except ConnectionAbortedError: | 
					
						
							|  |  |  |                 pass | 
					
						
							|  |  |  |             finally: | 
					
						
							|  |  |  |                 sock.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         async def client(addr): | 
					
						
							| 
									
										
										
										
											2020-11-26 09:36:37 +02:00
										 |  |  |             reader, writer = await asyncio.open_connection( | 
					
						
							|  |  |  |                 *addr, | 
					
						
							|  |  |  |                 ssl=client_sslctx, | 
					
						
							|  |  |  |                 server_hostname='', | 
					
						
							|  |  |  |                 ssl_handshake_timeout=1.0) | 
					
						
							| 
									
										
										
										
											2018-06-04 11:32:35 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |         with self.tcp_server(server, | 
					
						
							|  |  |  |                              max_clients=1, | 
					
						
							|  |  |  |                              backlog=1) as srv: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             with self.assertRaisesRegex( | 
					
						
							|  |  |  |                     ConnectionAbortedError, | 
					
						
							|  |  |  |                     r'SSL handshake.*is taking longer'): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 self.loop.run_until_complete(client(srv.addr)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.assertEqual(messages, []) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_create_connection_ssl_failed_certificate(self): | 
					
						
							|  |  |  |         self.loop.set_exception_handler(lambda loop, ctx: None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         sslctx = test_utils.simple_server_sslcontext() | 
					
						
							|  |  |  |         client_sslctx = test_utils.simple_client_sslcontext( | 
					
						
							|  |  |  |             disable_verify=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def server(sock): | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 sock.start_tls( | 
					
						
							|  |  |  |                     sslctx, | 
					
						
							|  |  |  |                     server_side=True) | 
					
						
							|  |  |  |             except ssl.SSLError: | 
					
						
							|  |  |  |                 pass | 
					
						
							| 
									
										
										
										
											2018-06-07 01:12:38 +02:00
										 |  |  |             except OSError: | 
					
						
							|  |  |  |                 pass | 
					
						
							| 
									
										
										
										
											2018-06-04 11:32:35 -04:00
										 |  |  |             finally: | 
					
						
							|  |  |  |                 sock.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         async def client(addr): | 
					
						
							| 
									
										
										
										
											2020-11-26 09:36:37 +02:00
										 |  |  |             reader, writer = await asyncio.open_connection( | 
					
						
							|  |  |  |                 *addr, | 
					
						
							|  |  |  |                 ssl=client_sslctx, | 
					
						
							|  |  |  |                 server_hostname='', | 
					
						
							|  |  |  |                 ssl_handshake_timeout=support.LOOPBACK_TIMEOUT) | 
					
						
							| 
									
										
										
										
											2018-06-04 11:32:35 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |         with self.tcp_server(server, | 
					
						
							|  |  |  |                              max_clients=1, | 
					
						
							|  |  |  |                              backlog=1) as srv: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             with self.assertRaises(ssl.SSLCertVerificationError): | 
					
						
							|  |  |  |                 self.loop.run_until_complete(client(srv.addr)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_start_tls_client_corrupted_ssl(self): | 
					
						
							|  |  |  |         self.loop.set_exception_handler(lambda loop, ctx: None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         sslctx = test_utils.simple_server_sslcontext() | 
					
						
							|  |  |  |         client_sslctx = test_utils.simple_client_sslcontext() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def server(sock): | 
					
						
							|  |  |  |             orig_sock = sock.dup() | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 sock.start_tls( | 
					
						
							|  |  |  |                     sslctx, | 
					
						
							|  |  |  |                     server_side=True) | 
					
						
							|  |  |  |                 sock.sendall(b'A\n') | 
					
						
							|  |  |  |                 sock.recv_all(1) | 
					
						
							|  |  |  |                 orig_sock.send(b'please corrupt the SSL connection') | 
					
						
							|  |  |  |             except ssl.SSLError: | 
					
						
							|  |  |  |                 pass | 
					
						
							|  |  |  |             finally: | 
					
						
							| 
									
										
										
										
											2018-06-07 01:12:38 +02:00
										 |  |  |                 orig_sock.close() | 
					
						
							| 
									
										
										
										
											2018-06-04 11:32:35 -04:00
										 |  |  |                 sock.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         async def client(addr): | 
					
						
							| 
									
										
										
										
											2020-11-26 09:36:37 +02:00
										 |  |  |             reader, writer = await asyncio.open_connection( | 
					
						
							|  |  |  |                 *addr, | 
					
						
							|  |  |  |                 ssl=client_sslctx, | 
					
						
							|  |  |  |                 server_hostname='') | 
					
						
							| 
									
										
										
										
											2018-06-04 11:32:35 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |             self.assertEqual(await reader.readline(), b'A\n') | 
					
						
							|  |  |  |             writer.write(b'B') | 
					
						
							|  |  |  |             with self.assertRaises(ssl.SSLError): | 
					
						
							|  |  |  |                 await reader.readline() | 
					
						
							| 
									
										
										
										
											2018-06-07 01:12:38 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |             writer.close() | 
					
						
							| 
									
										
										
										
											2018-06-04 11:32:35 -04:00
										 |  |  |             return 'OK' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         with self.tcp_server(server, | 
					
						
							|  |  |  |                              max_clients=1, | 
					
						
							|  |  |  |                              backlog=1) as srv: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             res = self.loop.run_until_complete(client(srv.addr)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.assertEqual(res, 'OK') | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-12-30 00:35:36 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | @unittest.skipIf(ssl is None, 'No ssl module') | 
					
						
							| 
									
										
										
										
											2017-12-30 15:40:20 -05:00
										 |  |  | class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase): | 
					
						
							| 
									
										
										
										
											2017-12-30 00:35:36 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def new_loop(self): | 
					
						
							|  |  |  |         return asyncio.SelectorEventLoop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @unittest.skipIf(ssl is None, 'No ssl module') | 
					
						
							|  |  |  | @unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only') | 
					
						
							| 
									
										
										
										
											2017-12-30 15:40:20 -05:00
										 |  |  | class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase): | 
					
						
							| 
									
										
										
										
											2017-12-30 00:35:36 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def new_loop(self): | 
					
						
							|  |  |  |         return asyncio.ProactorEventLoop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-01-14 16:56:20 +01:00
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()