mirror of
https://github.com/python/cpython.git
synced 2026-04-14 07:41:00 +00:00
[3.14] gh-142352: Fix asyncio start_tls() to transfer buffered data from StreamReader (GH-142354) (#145363)
gh-142352: Fix `asyncio` `start_tls()` to transfer buffered data from StreamReader (GH-142354)
(cherry picked from commit 0598f4a899)
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
Co-authored-by: Maksym Kasimov <39828623+kasimov-maxim@users.noreply.github.com>
This commit is contained in:
parent
d76c56e958
commit
fc2a6cf0d3
3 changed files with 57 additions and 0 deletions
|
|
@ -1345,6 +1345,17 @@ async def start_tls(self, transport, protocol, sslcontext, *,
|
|||
# have a chance to get called before "ssl_protocol.connection_made()".
|
||||
transport.pause_reading()
|
||||
|
||||
# gh-142352: move buffered StreamReader data to SSLProtocol
|
||||
if server_side:
|
||||
from .streams import StreamReaderProtocol
|
||||
if isinstance(protocol, StreamReaderProtocol):
|
||||
stream_reader = getattr(protocol, '_stream_reader', None)
|
||||
if stream_reader is not None:
|
||||
buffer = stream_reader._buffer
|
||||
if buffer:
|
||||
ssl_protocol._incoming.write(buffer)
|
||||
buffer.clear()
|
||||
|
||||
transport.set_protocol(ssl_protocol)
|
||||
conmade_cb = self.call_soon(ssl_protocol.connection_made, transport)
|
||||
resume_cb = self.call_soon(transport.resume_reading)
|
||||
|
|
|
|||
|
|
@ -819,6 +819,48 @@ async def client(addr):
|
|||
self.assertEqual(msg1, b"hello world 1!\n")
|
||||
self.assertEqual(msg2, b"hello world 2!\n")
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
def test_start_tls_buffered_data(self):
|
||||
# gh-142352: test start_tls() with buffered data
|
||||
|
||||
async def server_handler(client_reader, client_writer):
|
||||
# Wait for TLS ClientHello to be buffered before start_tls().
|
||||
await client_reader._wait_for_data('test_start_tls_buffered_data'),
|
||||
self.assertTrue(client_reader._buffer)
|
||||
await client_writer.start_tls(test_utils.simple_server_sslcontext())
|
||||
|
||||
line = await client_reader.readline()
|
||||
self.assertEqual(line, b"ping\n")
|
||||
client_writer.write(b"pong\n")
|
||||
await client_writer.drain()
|
||||
client_writer.close()
|
||||
await client_writer.wait_closed()
|
||||
|
||||
async def client(addr):
|
||||
reader, writer = await asyncio.open_connection(*addr)
|
||||
await writer.start_tls(test_utils.simple_client_sslcontext())
|
||||
|
||||
writer.write(b"ping\n")
|
||||
await writer.drain()
|
||||
line = await reader.readline()
|
||||
self.assertEqual(line, b"pong\n")
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
async def run_test():
|
||||
server = await asyncio.start_server(
|
||||
server_handler, socket_helper.HOSTv4, 0)
|
||||
server_addr = server.sockets[0].getsockname()
|
||||
|
||||
await client(server_addr)
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
|
||||
messages = []
|
||||
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
|
||||
self.loop.run_until_complete(run_test())
|
||||
self.assertEqual(messages, [])
|
||||
|
||||
def test_streamreader_constructor_without_loop(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
|
||||
asyncio.StreamReader()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,4 @@
|
|||
Fix :meth:`asyncio.StreamWriter.start_tls` to transfer buffered data from
|
||||
:class:`~asyncio.StreamReader` to the SSL layer, preventing data loss when
|
||||
upgrading a connection to TLS mid-stream (e.g., when implementing PROXY
|
||||
protocol support).
|
||||
Loading…
Add table
Add a link
Reference in a new issue