mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 13:41:24 +00:00 
			
		
		
		
	asyncio: Refactor ssl transport ready loop (Nikolay Kim).
This commit is contained in:
		
							parent
							
								
									21c85a7124
								
							
						
					
					
						commit
						2b57016458
					
				
					 2 changed files with 136 additions and 92 deletions
				
			
		|  | @ -286,7 +286,7 @@ def _sock_connect(self, fut, registered, sock, address): | |||
|                 err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) | ||||
|                 if err != 0: | ||||
|                     # Jump to the except clause below. | ||||
|                     raise OSError(err, 'Connect call failed') | ||||
|                     raise OSError(err, 'Connect call failed %s' % (address,)) | ||||
|         except (BlockingIOError, InterruptedError): | ||||
|             self.add_writer(fd, self._sock_connect, fut, True, sock, address) | ||||
|         except Exception as exc: | ||||
|  | @ -413,7 +413,7 @@ def _maybe_pause_protocol(self): | |||
|             try: | ||||
|                 self._protocol.pause_writing() | ||||
|             except Exception: | ||||
|                 tulip_log.exception('pause_writing() failed') | ||||
|                 logger.exception('pause_writing() failed') | ||||
| 
 | ||||
|     def _maybe_resume_protocol(self): | ||||
|         if (self._protocol_paused and | ||||
|  | @ -422,7 +422,7 @@ def _maybe_resume_protocol(self): | |||
|             try: | ||||
|                 self._protocol.resume_writing() | ||||
|             except Exception: | ||||
|                 tulip_log.exception('resume_writing() failed') | ||||
|                 logger.exception('resume_writing() failed') | ||||
| 
 | ||||
|     def set_write_buffer_limits(self, high=None, low=None): | ||||
|         if high is None: | ||||
|  | @ -635,15 +635,16 @@ def _on_handshake(self): | |||
|                            compression=self._sock.compression(), | ||||
|                            ) | ||||
| 
 | ||||
|         self._loop.add_reader(self._sock_fd, self._on_ready) | ||||
|         self._loop.add_writer(self._sock_fd, self._on_ready) | ||||
|         self._read_wants_write = False | ||||
|         self._write_wants_read = False | ||||
|         self._loop.add_reader(self._sock_fd, self._read_ready) | ||||
|         self._loop.call_soon(self._protocol.connection_made, self) | ||||
|         if self._waiter is not None: | ||||
|             self._loop.call_soon(self._waiter.set_result, None) | ||||
| 
 | ||||
|     def pause_reading(self): | ||||
|         # XXX This is a bit icky, given the comment at the top of | ||||
|         # _on_ready().  Is it possible to evoke a deadlock?  I don't | ||||
|         # _read_ready().  Is it possible to evoke a deadlock?  I don't | ||||
|         # know, although it doesn't look like it; write() will still | ||||
|         # accept more data for the buffer and eventually the app will | ||||
|         # call resume_reading() again, and things will flow again. | ||||
|  | @ -658,41 +659,55 @@ def resume_reading(self): | |||
|         self._paused = False | ||||
|         if self._closing: | ||||
|             return | ||||
|         self._loop.add_reader(self._sock_fd, self._on_ready) | ||||
|         self._loop.add_reader(self._sock_fd, self._read_ready) | ||||
| 
 | ||||
|     def _on_ready(self): | ||||
|         # Because of renegotiations (?), there's no difference between | ||||
|         # readable and writable.  We just try both.  XXX This may be | ||||
|         # incorrect; we probably need to keep state about what we | ||||
|         # should do next. | ||||
|     def _read_ready(self): | ||||
|         if self._write_wants_read: | ||||
|             self._write_wants_read = False | ||||
|             self._write_ready() | ||||
| 
 | ||||
|         # First try reading. | ||||
|         if not self._closing and not self._paused: | ||||
|             try: | ||||
|                 data = self._sock.recv(self.max_size) | ||||
|             except (BlockingIOError, InterruptedError, | ||||
|                     ssl.SSLWantReadError, ssl.SSLWantWriteError): | ||||
|                 pass | ||||
|             except Exception as exc: | ||||
|                 self._fatal_error(exc) | ||||
|             if self._buffer: | ||||
|                 self._loop.add_writer(self._sock_fd, self._write_ready) | ||||
| 
 | ||||
|         try: | ||||
|             data = self._sock.recv(self.max_size) | ||||
|         except (BlockingIOError, InterruptedError, ssl.SSLWantReadError): | ||||
|             pass | ||||
|         except ssl.SSLWantWriteError: | ||||
|             self._read_wants_write = True | ||||
|             self._loop.remove_reader(self._sock_fd) | ||||
|             self._loop.add_writer(self._sock_fd, self._write_ready) | ||||
|         except Exception as exc: | ||||
|             self._fatal_error(exc) | ||||
|         else: | ||||
|             if data: | ||||
|                 self._protocol.data_received(data) | ||||
|             else: | ||||
|                 if data: | ||||
|                     self._protocol.data_received(data) | ||||
|                 else: | ||||
|                     try: | ||||
|                         self._protocol.eof_received() | ||||
|                     finally: | ||||
|                         self.close() | ||||
|                 try: | ||||
|                     self._protocol.eof_received() | ||||
|                 finally: | ||||
|                     self.close() | ||||
| 
 | ||||
|     def _write_ready(self): | ||||
|         if self._read_wants_write: | ||||
|             self._read_wants_write = False | ||||
|             self._read_ready() | ||||
| 
 | ||||
|             if not (self._paused or self._closing): | ||||
|                 self._loop.add_reader(self._sock_fd, self._read_ready) | ||||
| 
 | ||||
|         # Now try writing, if there's anything to write. | ||||
|         if self._buffer: | ||||
|             data = b''.join(self._buffer) | ||||
|             self._buffer.clear() | ||||
|             try: | ||||
|                 n = self._sock.send(data) | ||||
|             except (BlockingIOError, InterruptedError, | ||||
|                     ssl.SSLWantReadError, ssl.SSLWantWriteError): | ||||
|                     ssl.SSLWantWriteError): | ||||
|                 n = 0 | ||||
|             except ssl.SSLWantReadError: | ||||
|                 n = 0 | ||||
|                 self._loop.remove_writer(self._sock_fd) | ||||
|                 self._write_wants_read = True | ||||
|             except Exception as exc: | ||||
|                 self._loop.remove_writer(self._sock_fd) | ||||
|                 self._fatal_error(exc) | ||||
|  | @ -701,11 +716,12 @@ def _on_ready(self): | |||
|             if n < len(data): | ||||
|                 self._buffer.append(data[n:]) | ||||
| 
 | ||||
|             self._maybe_resume_protocol()  # May append to buffer. | ||||
|         self._maybe_resume_protocol()  # May append to buffer. | ||||
| 
 | ||||
|         if self._closing and not self._buffer: | ||||
|         if not self._buffer: | ||||
|             self._loop.remove_writer(self._sock_fd) | ||||
|             self._call_connection_lost(None) | ||||
|             if self._closing: | ||||
|                 self._call_connection_lost(None) | ||||
| 
 | ||||
|     def write(self, data): | ||||
|         assert isinstance(data, bytes), repr(type(data)) | ||||
|  | @ -718,20 +734,16 @@ def write(self, data): | |||
|             self._conn_lost += 1 | ||||
|             return | ||||
| 
 | ||||
|         # We could optimize, but the callback can do this for now. | ||||
|         if not self._buffer: | ||||
|             self._loop.add_writer(self._sock_fd, self._write_ready) | ||||
| 
 | ||||
|         # Add it to the buffer. | ||||
|         self._buffer.append(data) | ||||
|         self._maybe_pause_protocol() | ||||
| 
 | ||||
|     def can_write_eof(self): | ||||
|         return False | ||||
| 
 | ||||
|     def close(self): | ||||
|         if self._closing: | ||||
|             return | ||||
|         self._closing = True | ||||
|         self._conn_lost += 1 | ||||
|         self._loop.remove_reader(self._sock_fd) | ||||
| 
 | ||||
| 
 | ||||
| class _SelectorDatagramTransport(_SelectorTransport): | ||||
| 
 | ||||
|  |  | |||
|  | @ -1003,8 +1003,7 @@ def test_on_handshake(self): | |||
|             self.loop, self.sock, self.protocol, self.sslcontext, | ||||
|             waiter=waiter) | ||||
|         self.assertTrue(self.sslsock.do_handshake.called) | ||||
|         self.loop.assert_reader(1, tr._on_ready) | ||||
|         self.loop.assert_writer(1, tr._on_ready) | ||||
|         self.loop.assert_reader(1, tr._read_ready) | ||||
|         test_utils.run_briefly(self.loop) | ||||
|         self.assertIsNone(waiter.result()) | ||||
| 
 | ||||
|  | @ -1047,13 +1046,13 @@ def test_on_handshake_base_exc(self): | |||
|     def test_pause_resume_reading(self): | ||||
|         tr = self._make_one() | ||||
|         self.assertFalse(tr._paused) | ||||
|         self.loop.assert_reader(1, tr._on_ready) | ||||
|         self.loop.assert_reader(1, tr._read_ready) | ||||
|         tr.pause_reading() | ||||
|         self.assertTrue(tr._paused) | ||||
|         self.assertFalse(1 in self.loop.readers) | ||||
|         tr.resume_reading() | ||||
|         self.assertFalse(tr._paused) | ||||
|         self.loop.assert_reader(1, tr._on_ready) | ||||
|         self.loop.assert_reader(1, tr._read_ready) | ||||
| 
 | ||||
|     def test_write_no_data(self): | ||||
|         transport = self._make_one() | ||||
|  | @ -1084,140 +1083,173 @@ def test_write_exception(self, m_log): | |||
|         transport.write(b'data') | ||||
|         m_log.warning.assert_called_with('socket.send() raised exception.') | ||||
| 
 | ||||
|     def test_on_ready_recv(self): | ||||
|     def test_read_ready_recv(self): | ||||
|         self.sslsock.recv.return_value = b'data' | ||||
|         transport = self._make_one() | ||||
|         transport._on_ready() | ||||
|         transport._read_ready() | ||||
|         self.assertTrue(self.sslsock.recv.called) | ||||
|         self.assertEqual((b'data',), self.protocol.data_received.call_args[0]) | ||||
| 
 | ||||
|     def test_on_ready_recv_eof(self): | ||||
|     def test_read_ready_write_wants_read(self): | ||||
|         self.loop.add_writer = unittest.mock.Mock() | ||||
|         self.sslsock.recv.side_effect = BlockingIOError | ||||
|         transport = self._make_one() | ||||
|         transport._write_wants_read = True | ||||
|         transport._write_ready = unittest.mock.Mock() | ||||
|         transport._buffer.append(b'data') | ||||
|         transport._read_ready() | ||||
| 
 | ||||
|         self.assertFalse(transport._write_wants_read) | ||||
|         transport._write_ready.assert_called_with() | ||||
|         self.loop.add_writer.assert_called_with( | ||||
|             transport._sock_fd, transport._write_ready) | ||||
| 
 | ||||
|     def test_read_ready_recv_eof(self): | ||||
|         self.sslsock.recv.return_value = b'' | ||||
|         transport = self._make_one() | ||||
|         transport.close = unittest.mock.Mock() | ||||
|         transport._on_ready() | ||||
|         transport._read_ready() | ||||
|         transport.close.assert_called_with() | ||||
|         self.protocol.eof_received.assert_called_with() | ||||
| 
 | ||||
|     def test_on_ready_recv_conn_reset(self): | ||||
|     def test_read_ready_recv_conn_reset(self): | ||||
|         err = self.sslsock.recv.side_effect = ConnectionResetError() | ||||
|         transport = self._make_one() | ||||
|         transport._force_close = unittest.mock.Mock() | ||||
|         transport._on_ready() | ||||
|         transport._read_ready() | ||||
|         transport._force_close.assert_called_with(err) | ||||
| 
 | ||||
|     def test_on_ready_recv_retry(self): | ||||
|     def test_read_ready_recv_retry(self): | ||||
|         self.sslsock.recv.side_effect = ssl.SSLWantReadError | ||||
|         transport = self._make_one() | ||||
|         transport._on_ready() | ||||
|         transport._read_ready() | ||||
|         self.assertTrue(self.sslsock.recv.called) | ||||
|         self.assertFalse(self.protocol.data_received.called) | ||||
| 
 | ||||
|         self.sslsock.recv.side_effect = ssl.SSLWantWriteError | ||||
|         transport._on_ready() | ||||
|         self.assertFalse(self.protocol.data_received.called) | ||||
| 
 | ||||
|         self.sslsock.recv.side_effect = BlockingIOError | ||||
|         transport._on_ready() | ||||
|         transport._read_ready() | ||||
|         self.assertFalse(self.protocol.data_received.called) | ||||
| 
 | ||||
|         self.sslsock.recv.side_effect = InterruptedError | ||||
|         transport._on_ready() | ||||
|         transport._read_ready() | ||||
|         self.assertFalse(self.protocol.data_received.called) | ||||
| 
 | ||||
|     def test_on_ready_recv_exc(self): | ||||
|     def test_read_ready_recv_write(self): | ||||
|         self.loop.remove_reader = unittest.mock.Mock() | ||||
|         self.loop.add_writer = unittest.mock.Mock() | ||||
|         self.sslsock.recv.side_effect = ssl.SSLWantWriteError | ||||
|         transport = self._make_one() | ||||
|         transport._read_ready() | ||||
|         self.assertFalse(self.protocol.data_received.called) | ||||
|         self.assertTrue(transport._read_wants_write) | ||||
| 
 | ||||
|         self.loop.remove_reader.assert_called_with(transport._sock_fd) | ||||
|         self.loop.add_writer.assert_called_with( | ||||
|             transport._sock_fd, transport._write_ready) | ||||
| 
 | ||||
|     def test_read_ready_recv_exc(self): | ||||
|         err = self.sslsock.recv.side_effect = OSError() | ||||
|         transport = self._make_one() | ||||
|         transport._fatal_error = unittest.mock.Mock() | ||||
|         transport._on_ready() | ||||
|         transport._read_ready() | ||||
|         transport._fatal_error.assert_called_with(err) | ||||
| 
 | ||||
|     def test_on_ready_send(self): | ||||
|         self.sslsock.recv.side_effect = ssl.SSLWantReadError | ||||
|     def test_write_ready_send(self): | ||||
|         self.sslsock.send.return_value = 4 | ||||
|         transport = self._make_one() | ||||
|         transport._buffer = collections.deque([b'data']) | ||||
|         transport._on_ready() | ||||
|         transport._write_ready() | ||||
|         self.assertEqual(collections.deque(), transport._buffer) | ||||
|         self.assertTrue(self.sslsock.send.called) | ||||
| 
 | ||||
|     def test_on_ready_send_none(self): | ||||
|         self.sslsock.recv.side_effect = ssl.SSLWantReadError | ||||
|     def test_write_ready_send_none(self): | ||||
|         self.sslsock.send.return_value = 0 | ||||
|         transport = self._make_one() | ||||
|         transport._buffer = collections.deque([b'data1', b'data2']) | ||||
|         transport._on_ready() | ||||
|         transport._write_ready() | ||||
|         self.assertTrue(self.sslsock.send.called) | ||||
|         self.assertEqual(collections.deque([b'data1data2']), transport._buffer) | ||||
| 
 | ||||
|     def test_on_ready_send_partial(self): | ||||
|         self.sslsock.recv.side_effect = ssl.SSLWantReadError | ||||
|     def test_write_ready_send_partial(self): | ||||
|         self.sslsock.send.return_value = 2 | ||||
|         transport = self._make_one() | ||||
|         transport._buffer = collections.deque([b'data1', b'data2']) | ||||
|         transport._on_ready() | ||||
|         transport._write_ready() | ||||
|         self.assertTrue(self.sslsock.send.called) | ||||
|         self.assertEqual(collections.deque([b'ta1data2']), transport._buffer) | ||||
| 
 | ||||
|     def test_on_ready_send_closing_partial(self): | ||||
|         self.sslsock.recv.side_effect = ssl.SSLWantReadError | ||||
|     def test_write_ready_send_closing_partial(self): | ||||
|         self.sslsock.send.return_value = 2 | ||||
|         transport = self._make_one() | ||||
|         transport._buffer = collections.deque([b'data1', b'data2']) | ||||
|         transport._on_ready() | ||||
|         transport._write_ready() | ||||
|         self.assertTrue(self.sslsock.send.called) | ||||
|         self.assertFalse(self.sslsock.close.called) | ||||
| 
 | ||||
|     def test_on_ready_send_closing(self): | ||||
|         self.sslsock.recv.side_effect = ssl.SSLWantReadError | ||||
|     def test_write_ready_send_closing(self): | ||||
|         self.sslsock.send.return_value = 4 | ||||
|         transport = self._make_one() | ||||
|         transport.close() | ||||
|         transport._buffer = collections.deque([b'data']) | ||||
|         transport._on_ready() | ||||
|         transport._write_ready() | ||||
|         self.assertFalse(self.loop.writers) | ||||
|         self.protocol.connection_lost.assert_called_with(None) | ||||
| 
 | ||||
|     def test_on_ready_send_closing_empty_buffer(self): | ||||
|         self.sslsock.recv.side_effect = ssl.SSLWantReadError | ||||
|     def test_write_ready_send_closing_empty_buffer(self): | ||||
|         self.sslsock.send.return_value = 4 | ||||
|         transport = self._make_one() | ||||
|         transport.close() | ||||
|         transport._buffer = collections.deque() | ||||
|         transport._on_ready() | ||||
|         transport._write_ready() | ||||
|         self.assertFalse(self.loop.writers) | ||||
|         self.protocol.connection_lost.assert_called_with(None) | ||||
| 
 | ||||
|     def test_on_ready_send_retry(self): | ||||
|         self.sslsock.recv.side_effect = ssl.SSLWantReadError | ||||
| 
 | ||||
|     def test_write_ready_send_retry(self): | ||||
|         transport = self._make_one() | ||||
|         transport._buffer = collections.deque([b'data']) | ||||
| 
 | ||||
|         self.sslsock.send.side_effect = ssl.SSLWantReadError | ||||
|         transport._on_ready() | ||||
|         self.assertTrue(self.sslsock.send.called) | ||||
|         self.assertEqual(collections.deque([b'data']), transport._buffer) | ||||
| 
 | ||||
|         self.sslsock.send.side_effect = ssl.SSLWantWriteError | ||||
|         transport._on_ready() | ||||
|         transport._write_ready() | ||||
|         self.assertEqual(collections.deque([b'data']), transport._buffer) | ||||
| 
 | ||||
|         self.sslsock.send.side_effect = BlockingIOError() | ||||
|         transport._on_ready() | ||||
|         transport._write_ready() | ||||
|         self.assertEqual(collections.deque([b'data']), transport._buffer) | ||||
| 
 | ||||
|     def test_on_ready_send_exc(self): | ||||
|         self.sslsock.recv.side_effect = ssl.SSLWantReadError | ||||
|     def test_write_ready_send_read(self): | ||||
|         transport = self._make_one() | ||||
|         transport._buffer = collections.deque([b'data']) | ||||
| 
 | ||||
|         self.loop.remove_writer = unittest.mock.Mock() | ||||
|         self.sslsock.send.side_effect = ssl.SSLWantReadError | ||||
|         transport._write_ready() | ||||
|         self.assertFalse(self.protocol.data_received.called) | ||||
|         self.assertTrue(transport._write_wants_read) | ||||
|         self.loop.remove_writer.assert_called_with(transport._sock_fd) | ||||
| 
 | ||||
|     def test_write_ready_send_exc(self): | ||||
|         err = self.sslsock.send.side_effect = OSError() | ||||
| 
 | ||||
|         transport = self._make_one() | ||||
|         transport._buffer = collections.deque([b'data']) | ||||
|         transport._fatal_error = unittest.mock.Mock() | ||||
|         transport._on_ready() | ||||
|         transport._write_ready() | ||||
|         transport._fatal_error.assert_called_with(err) | ||||
|         self.assertEqual(collections.deque(), transport._buffer) | ||||
| 
 | ||||
|     def test_write_ready_read_wants_write(self): | ||||
|         self.loop.add_reader = unittest.mock.Mock() | ||||
|         self.sslsock.send.side_effect = BlockingIOError | ||||
|         transport = self._make_one() | ||||
|         transport._read_wants_write = True | ||||
|         transport._read_ready = unittest.mock.Mock() | ||||
|         transport._write_ready() | ||||
| 
 | ||||
|         self.assertFalse(transport._read_wants_write) | ||||
|         transport._read_ready.assert_called_with() | ||||
|         self.loop.add_reader.assert_called_with( | ||||
|             transport._sock_fd, transport._read_ready) | ||||
| 
 | ||||
|     def test_write_eof(self): | ||||
|         tr = self._make_one() | ||||
|         self.assertFalse(tr.can_write_eof()) | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Guido van Rossum
						Guido van Rossum