mirror of
https://github.com/python/cpython.git
synced 2026-01-06 15:32:22 +00:00
bpo-34638: Store a weak reference to stream reader to break strong references loop (GH-9201)
Store a weak reference to stream readerfor breaking strong references It breaks the strong reference loop between reader and protocol and allows to detect and close the socket if the stream is deleted (garbage collected)
This commit is contained in:
parent
aca819fb49
commit
a5d1eb8d8b
4 changed files with 160 additions and 10 deletions
|
|
@ -3,6 +3,8 @@
|
|||
'open_connection', 'start_server')
|
||||
|
||||
import socket
|
||||
import sys
|
||||
import weakref
|
||||
|
||||
if hasattr(socket, 'AF_UNIX'):
|
||||
__all__ += ('open_unix_connection', 'start_unix_server')
|
||||
|
|
@ -10,6 +12,7 @@
|
|||
from . import coroutines
|
||||
from . import events
|
||||
from . import exceptions
|
||||
from . import format_helpers
|
||||
from . import protocols
|
||||
from .log import logger
|
||||
from .tasks import sleep
|
||||
|
|
@ -186,46 +189,106 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
|
|||
call inappropriate methods of the protocol.)
|
||||
"""
|
||||
|
||||
_source_traceback = None
|
||||
|
||||
def __init__(self, stream_reader, client_connected_cb=None, loop=None):
|
||||
super().__init__(loop=loop)
|
||||
self._stream_reader = stream_reader
|
||||
if stream_reader is not None:
|
||||
self._stream_reader_wr = weakref.ref(stream_reader,
|
||||
self._on_reader_gc)
|
||||
self._source_traceback = stream_reader._source_traceback
|
||||
else:
|
||||
self._stream_reader_wr = None
|
||||
if client_connected_cb is not None:
|
||||
# This is a stream created by the `create_server()` function.
|
||||
# Keep a strong reference to the reader until a connection
|
||||
# is established.
|
||||
self._strong_reader = stream_reader
|
||||
self._reject_connection = False
|
||||
self._stream_writer = None
|
||||
self._transport = None
|
||||
self._client_connected_cb = client_connected_cb
|
||||
self._over_ssl = False
|
||||
self._closed = self._loop.create_future()
|
||||
|
||||
def _on_reader_gc(self, wr):
|
||||
transport = self._transport
|
||||
if transport is not None:
|
||||
# connection_made was called
|
||||
context = {
|
||||
'message': ('An open stream object is being garbage '
|
||||
'collected; call "stream.close()" explicitly.')
|
||||
}
|
||||
if self._source_traceback:
|
||||
context['source_traceback'] = self._source_traceback
|
||||
self._loop.call_exception_handler(context)
|
||||
transport.abort()
|
||||
else:
|
||||
self._reject_connection = True
|
||||
self._stream_reader_wr = None
|
||||
|
||||
def _untrack_reader(self):
|
||||
self._stream_reader_wr = None
|
||||
|
||||
@property
|
||||
def _stream_reader(self):
|
||||
if self._stream_reader_wr is None:
|
||||
return None
|
||||
return self._stream_reader_wr()
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._stream_reader.set_transport(transport)
|
||||
if self._reject_connection:
|
||||
context = {
|
||||
'message': ('An open stream was garbage collected prior to '
|
||||
'establishing network connection; '
|
||||
'call "stream.close()" explicitly.')
|
||||
}
|
||||
if self._source_traceback:
|
||||
context['source_traceback'] = self._source_traceback
|
||||
self._loop.call_exception_handler(context)
|
||||
transport.abort()
|
||||
return
|
||||
self._transport = transport
|
||||
reader = self._stream_reader
|
||||
if reader is not None:
|
||||
reader.set_transport(transport)
|
||||
self._over_ssl = transport.get_extra_info('sslcontext') is not None
|
||||
if self._client_connected_cb is not None:
|
||||
self._stream_writer = StreamWriter(transport, self,
|
||||
self._stream_reader,
|
||||
reader,
|
||||
self._loop)
|
||||
res = self._client_connected_cb(self._stream_reader,
|
||||
res = self._client_connected_cb(reader,
|
||||
self._stream_writer)
|
||||
if coroutines.iscoroutine(res):
|
||||
self._loop.create_task(res)
|
||||
self._strong_reader = None
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if self._stream_reader is not None:
|
||||
reader = self._stream_reader
|
||||
if reader is not None:
|
||||
if exc is None:
|
||||
self._stream_reader.feed_eof()
|
||||
reader.feed_eof()
|
||||
else:
|
||||
self._stream_reader.set_exception(exc)
|
||||
reader.set_exception(exc)
|
||||
if not self._closed.done():
|
||||
if exc is None:
|
||||
self._closed.set_result(None)
|
||||
else:
|
||||
self._closed.set_exception(exc)
|
||||
super().connection_lost(exc)
|
||||
self._stream_reader = None
|
||||
self._stream_reader_wr = None
|
||||
self._stream_writer = None
|
||||
self._transport = None
|
||||
|
||||
def data_received(self, data):
|
||||
self._stream_reader.feed_data(data)
|
||||
reader = self._stream_reader
|
||||
if reader is not None:
|
||||
reader.feed_data(data)
|
||||
|
||||
def eof_received(self):
|
||||
self._stream_reader.feed_eof()
|
||||
reader = self._stream_reader
|
||||
if reader is not None:
|
||||
reader.feed_eof()
|
||||
if self._over_ssl:
|
||||
# Prevent a warning in SSLProtocol.eof_received:
|
||||
# "returning true from eof_received()
|
||||
|
|
@ -282,6 +345,9 @@ def can_write_eof(self):
|
|||
return self._transport.can_write_eof()
|
||||
|
||||
def close(self):
|
||||
# a reader can be garbage collected
|
||||
# after connection closing
|
||||
self._protocol._untrack_reader()
|
||||
return self._transport.close()
|
||||
|
||||
def is_closing(self):
|
||||
|
|
@ -318,6 +384,8 @@ async def drain(self):
|
|||
|
||||
class StreamReader:
|
||||
|
||||
_source_traceback = None
|
||||
|
||||
def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
|
||||
# The line length limit is a security feature;
|
||||
# it also doubles as half the buffer limit.
|
||||
|
|
@ -336,6 +404,9 @@ def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
|
|||
self._exception = None
|
||||
self._transport = None
|
||||
self._paused = False
|
||||
if self._loop.get_debug():
|
||||
self._source_traceback = format_helpers.extract_stack(
|
||||
sys._getframe(1))
|
||||
|
||||
def __repr__(self):
|
||||
info = ['StreamReader']
|
||||
|
|
|
|||
|
|
@ -36,6 +36,11 @@ def __repr__(self):
|
|||
info.append(f'stderr={self.stderr!r}')
|
||||
return '<{}>'.format(' '.join(info))
|
||||
|
||||
def _untrack_reader(self):
|
||||
# StreamWriter.close() expects the protocol
|
||||
# to have this method defined.
|
||||
pass
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ def test_ctor_global_loop(self, m_events):
|
|||
self.assertIs(stream._loop, m_events.get_event_loop.return_value)
|
||||
|
||||
def _basetest_open_connection(self, open_connection_fut):
|
||||
messages = []
|
||||
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
|
||||
reader, writer = self.loop.run_until_complete(open_connection_fut)
|
||||
writer.write(b'GET / HTTP/1.0\r\n\r\n')
|
||||
f = reader.readline()
|
||||
|
|
@ -55,6 +57,7 @@ def _basetest_open_connection(self, open_connection_fut):
|
|||
data = self.loop.run_until_complete(f)
|
||||
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
|
||||
writer.close()
|
||||
self.assertEqual(messages, [])
|
||||
|
||||
def test_open_connection(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
|
|
@ -70,6 +73,8 @@ def test_open_unix_connection(self):
|
|||
self._basetest_open_connection(conn_fut)
|
||||
|
||||
def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
|
||||
messages = []
|
||||
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
|
||||
try:
|
||||
reader, writer = self.loop.run_until_complete(open_connection_fut)
|
||||
finally:
|
||||
|
|
@ -80,6 +85,7 @@ def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
|
|||
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
|
||||
|
||||
writer.close()
|
||||
self.assertEqual(messages, [])
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
def test_open_connection_no_loop_ssl(self):
|
||||
|
|
@ -104,6 +110,8 @@ def test_open_unix_connection_no_loop_ssl(self):
|
|||
self._basetest_open_connection_no_loop_ssl(conn_fut)
|
||||
|
||||
def _basetest_open_connection_error(self, open_connection_fut):
|
||||
messages = []
|
||||
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
|
||||
reader, writer = self.loop.run_until_complete(open_connection_fut)
|
||||
writer._protocol.connection_lost(ZeroDivisionError())
|
||||
f = reader.read()
|
||||
|
|
@ -111,6 +119,7 @@ def _basetest_open_connection_error(self, open_connection_fut):
|
|||
self.loop.run_until_complete(f)
|
||||
writer.close()
|
||||
test_utils.run_briefly(self.loop)
|
||||
self.assertEqual(messages, [])
|
||||
|
||||
def test_open_connection_error(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
|
|
@ -621,6 +630,9 @@ async def client(addr):
|
|||
writer.close()
|
||||
return msgback
|
||||
|
||||
messages = []
|
||||
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
|
||||
|
||||
# test the server variant with a coroutine as client handler
|
||||
server = MyServer(self.loop)
|
||||
addr = server.start()
|
||||
|
|
@ -637,6 +649,8 @@ async def client(addr):
|
|||
server.stop()
|
||||
self.assertEqual(msg, b"hello world!\n")
|
||||
|
||||
self.assertEqual(messages, [])
|
||||
|
||||
@support.skip_unless_bind_unix_socket
|
||||
def test_start_unix_server(self):
|
||||
|
||||
|
|
@ -685,6 +699,9 @@ async def client(path):
|
|||
writer.close()
|
||||
return msgback
|
||||
|
||||
messages = []
|
||||
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
|
||||
|
||||
# test the server variant with a coroutine as client handler
|
||||
with test_utils.unix_socket_path() as path:
|
||||
server = MyServer(self.loop, path)
|
||||
|
|
@ -703,6 +720,8 @@ async def client(path):
|
|||
server.stop()
|
||||
self.assertEqual(msg, b"hello world!\n")
|
||||
|
||||
self.assertEqual(messages, [])
|
||||
|
||||
@unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
|
||||
def test_read_all_from_pipe_reader(self):
|
||||
# See asyncio issue 168. This test is derived from the example
|
||||
|
|
@ -893,6 +912,58 @@ def test_wait_closed_on_close_with_unread_data(self):
|
|||
wr.close()
|
||||
self.loop.run_until_complete(wr.wait_closed())
|
||||
|
||||
def test_del_stream_before_sock_closing(self):
|
||||
messages = []
|
||||
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
|
||||
|
||||
with test_utils.run_test_server() as httpd:
|
||||
rd, wr = self.loop.run_until_complete(
|
||||
asyncio.open_connection(*httpd.address, loop=self.loop))
|
||||
sock = wr.get_extra_info('socket')
|
||||
self.assertNotEqual(sock.fileno(), -1)
|
||||
|
||||
wr.write(b'GET / HTTP/1.0\r\n\r\n')
|
||||
f = rd.readline()
|
||||
data = self.loop.run_until_complete(f)
|
||||
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
|
||||
|
||||
# drop refs to reader/writer
|
||||
del rd
|
||||
del wr
|
||||
gc.collect()
|
||||
# make a chance to close the socket
|
||||
test_utils.run_briefly(self.loop)
|
||||
|
||||
self.assertEqual(1, len(messages))
|
||||
self.assertEqual(sock.fileno(), -1)
|
||||
|
||||
self.assertEqual(1, len(messages))
|
||||
self.assertEqual('An open stream object is being garbage '
|
||||
'collected; call "stream.close()" explicitly.',
|
||||
messages[0]['message'])
|
||||
|
||||
def test_del_stream_before_connection_made(self):
|
||||
messages = []
|
||||
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
|
||||
|
||||
with test_utils.run_test_server() as httpd:
|
||||
rd = asyncio.StreamReader(loop=self.loop)
|
||||
pr = asyncio.StreamReaderProtocol(rd, loop=self.loop)
|
||||
del rd
|
||||
gc.collect()
|
||||
tr, _ = self.loop.run_until_complete(
|
||||
self.loop.create_connection(
|
||||
lambda: pr, *httpd.address))
|
||||
|
||||
sock = tr.get_extra_info('socket')
|
||||
self.assertEqual(sock.fileno(), -1)
|
||||
|
||||
self.assertEqual(1, len(messages))
|
||||
self.assertEqual('An open stream was garbage collected prior to '
|
||||
'establishing network connection; '
|
||||
'call "stream.close()" explicitly.',
|
||||
messages[0]['message'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
Store a weak reference to stream reader to break strong references loop
|
||||
between reader and protocol. It allows to detect and close the socket if
|
||||
the stream is deleted (garbage collected) without ``close()`` call.
|
||||
Loading…
Add table
Add a link
Reference in a new issue