mirror of
https://github.com/python/cpython.git
synced 2026-04-13 23:31:02 +00:00
gh-140947: fix contextvars handling for server tasks in asyncio (#141158)
This commit is contained in:
parent
897fa231a7
commit
60fbc20ef9
7 changed files with 380 additions and 54 deletions
|
|
@ -14,6 +14,7 @@
|
|||
"""
|
||||
|
||||
import collections
|
||||
import contextvars
|
||||
import collections.abc
|
||||
import concurrent.futures
|
||||
import errno
|
||||
|
|
@ -290,6 +291,7 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
|
|||
self._ssl_shutdown_timeout = ssl_shutdown_timeout
|
||||
self._serving = False
|
||||
self._serving_forever_fut = None
|
||||
self._context = contextvars.copy_context()
|
||||
|
||||
def __repr__(self):
|
||||
return f'<{self.__class__.__name__} sockets={self.sockets!r}>'
|
||||
|
|
@ -319,7 +321,7 @@ def _start_serving(self):
|
|||
self._loop._start_serving(
|
||||
self._protocol_factory, sock, self._ssl_context,
|
||||
self, self._backlog, self._ssl_handshake_timeout,
|
||||
self._ssl_shutdown_timeout)
|
||||
self._ssl_shutdown_timeout, context=self._context)
|
||||
|
||||
def get_loop(self):
|
||||
return self._loop
|
||||
|
|
@ -509,7 +511,8 @@ def _make_ssl_transport(
|
|||
extra=None, server=None,
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None,
|
||||
call_connection_made=True):
|
||||
call_connection_made=True,
|
||||
context=None):
|
||||
"""Create SSL transport."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -1213,9 +1216,10 @@ async def _create_connection_transport(
|
|||
self, sock, protocol_factory, ssl,
|
||||
server_hostname, server_side=False,
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None):
|
||||
ssl_shutdown_timeout=None, context=None):
|
||||
|
||||
sock.setblocking(False)
|
||||
context = context if context is not None else contextvars.copy_context()
|
||||
|
||||
protocol = protocol_factory()
|
||||
waiter = self.create_future()
|
||||
|
|
@ -1225,9 +1229,10 @@ async def _create_connection_transport(
|
|||
sock, protocol, sslcontext, waiter,
|
||||
server_side=server_side, server_hostname=server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout)
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout,
|
||||
context=context)
|
||||
else:
|
||||
transport = self._make_socket_transport(sock, protocol, waiter)
|
||||
transport = self._make_socket_transport(sock, protocol, waiter, context=context)
|
||||
|
||||
try:
|
||||
await waiter
|
||||
|
|
|
|||
|
|
@ -642,7 +642,7 @@ def __init__(self, proactor):
|
|||
signal.set_wakeup_fd(self._csock.fileno())
|
||||
|
||||
def _make_socket_transport(self, sock, protocol, waiter=None,
|
||||
extra=None, server=None):
|
||||
extra=None, server=None, context=None):
|
||||
return _ProactorSocketTransport(self, sock, protocol, waiter,
|
||||
extra, server)
|
||||
|
||||
|
|
@ -651,7 +651,7 @@ def _make_ssl_transport(
|
|||
*, server_side=False, server_hostname=None,
|
||||
extra=None, server=None,
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None):
|
||||
ssl_shutdown_timeout=None, context=None):
|
||||
ssl_protocol = sslproto.SSLProtocol(
|
||||
self, protocol, sslcontext, waiter,
|
||||
server_side, server_hostname,
|
||||
|
|
@ -837,7 +837,7 @@ def _write_to_self(self):
|
|||
def _start_serving(self, protocol_factory, sock,
|
||||
sslcontext=None, server=None, backlog=100,
|
||||
ssl_handshake_timeout=None,
|
||||
ssl_shutdown_timeout=None):
|
||||
ssl_shutdown_timeout=None, context=None):
|
||||
|
||||
def loop(f=None):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -67,10 +67,10 @@ def __init__(self, selector=None):
|
|||
self._transports = weakref.WeakValueDictionary()
|
||||
|
||||
def _make_socket_transport(self, sock, protocol, waiter=None, *,
|
||||
extra=None, server=None):
|
||||
extra=None, server=None, context=None):
|
||||
self._ensure_fd_no_transport(sock)
|
||||
return _SelectorSocketTransport(self, sock, protocol, waiter,
|
||||
extra, server)
|
||||
extra, server, context=context)
|
||||
|
||||
def _make_ssl_transport(
|
||||
self, rawsock, protocol, sslcontext, waiter=None,
|
||||
|
|
@ -78,16 +78,17 @@ def _make_ssl_transport(
|
|||
extra=None, server=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
|
||||
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT,
|
||||
context=None,
|
||||
):
|
||||
self._ensure_fd_no_transport(rawsock)
|
||||
ssl_protocol = sslproto.SSLProtocol(
|
||||
self, protocol, sslcontext, waiter,
|
||||
server_side, server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout,
|
||||
)
|
||||
_SelectorSocketTransport(self, rawsock, ssl_protocol,
|
||||
extra=extra, server=server)
|
||||
extra=extra, server=server, context=context)
|
||||
return ssl_protocol._app_transport
|
||||
|
||||
def _make_datagram_transport(self, sock, protocol,
|
||||
|
|
@ -159,16 +160,16 @@ def _write_to_self(self):
|
|||
def _start_serving(self, protocol_factory, sock,
|
||||
sslcontext=None, server=None, backlog=100,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
|
||||
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
|
||||
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
|
||||
self._add_reader(sock.fileno(), self._accept_connection,
|
||||
protocol_factory, sock, sslcontext, server, backlog,
|
||||
ssl_handshake_timeout, ssl_shutdown_timeout)
|
||||
ssl_handshake_timeout, ssl_shutdown_timeout, context)
|
||||
|
||||
def _accept_connection(
|
||||
self, protocol_factory, sock,
|
||||
sslcontext=None, server=None, backlog=100,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
|
||||
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
|
||||
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
|
||||
# This method is only called once for each event loop tick where the
|
||||
# listening socket has triggered an EVENT_READ. There may be multiple
|
||||
# connections waiting for an .accept() so it is called in a loop.
|
||||
|
|
@ -204,21 +205,22 @@ def _accept_connection(
|
|||
self._start_serving,
|
||||
protocol_factory, sock, sslcontext, server,
|
||||
backlog, ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout)
|
||||
ssl_shutdown_timeout, context)
|
||||
else:
|
||||
raise # The event loop will catch, log and ignore it.
|
||||
else:
|
||||
extra = {'peername': addr}
|
||||
conn_context = context.copy() if context is not None else None
|
||||
accept = self._accept_connection2(
|
||||
protocol_factory, conn, extra, sslcontext, server,
|
||||
ssl_handshake_timeout, ssl_shutdown_timeout)
|
||||
self.create_task(accept)
|
||||
ssl_handshake_timeout, ssl_shutdown_timeout, context=conn_context)
|
||||
self.create_task(accept, context=conn_context)
|
||||
|
||||
async def _accept_connection2(
|
||||
self, protocol_factory, conn, extra,
|
||||
sslcontext=None, server=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
|
||||
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
|
||||
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
|
||||
protocol = None
|
||||
transport = None
|
||||
try:
|
||||
|
|
@ -229,11 +231,12 @@ async def _accept_connection2(
|
|||
conn, protocol, sslcontext, waiter=waiter,
|
||||
server_side=True, extra=extra, server=server,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout)
|
||||
ssl_shutdown_timeout=ssl_shutdown_timeout,
|
||||
context=context)
|
||||
else:
|
||||
transport = self._make_socket_transport(
|
||||
conn, protocol, waiter=waiter, extra=extra,
|
||||
server=server)
|
||||
server=server, context=context)
|
||||
|
||||
try:
|
||||
await waiter
|
||||
|
|
@ -275,9 +278,9 @@ def _ensure_fd_no_transport(self, fd):
|
|||
f'File descriptor {fd!r} is used by transport '
|
||||
f'{transport!r}')
|
||||
|
||||
def _add_reader(self, fd, callback, *args):
|
||||
def _add_reader(self, fd, callback, *args, context=None):
|
||||
self._check_closed()
|
||||
handle = events.Handle(callback, args, self, None)
|
||||
handle = events.Handle(callback, args, self, context=context)
|
||||
key = self._selector.get_map().get(fd)
|
||||
if key is None:
|
||||
self._selector.register(fd, selectors.EVENT_READ,
|
||||
|
|
@ -309,9 +312,9 @@ def _remove_reader(self, fd):
|
|||
else:
|
||||
return False
|
||||
|
||||
def _add_writer(self, fd, callback, *args):
|
||||
def _add_writer(self, fd, callback, *args, context=None):
|
||||
self._check_closed()
|
||||
handle = events.Handle(callback, args, self, None)
|
||||
handle = events.Handle(callback, args, self, context=context)
|
||||
key = self._selector.get_map().get(fd)
|
||||
if key is None:
|
||||
self._selector.register(fd, selectors.EVENT_WRITE,
|
||||
|
|
@ -770,7 +773,7 @@ class _SelectorTransport(transports._FlowControlMixin,
|
|||
# exception)
|
||||
_sock = None
|
||||
|
||||
def __init__(self, loop, sock, protocol, extra=None, server=None):
|
||||
def __init__(self, loop, sock, protocol, extra=None, server=None, context=None):
|
||||
super().__init__(extra, loop)
|
||||
self._extra['socket'] = trsock.TransportSocket(sock)
|
||||
try:
|
||||
|
|
@ -784,7 +787,7 @@ def __init__(self, loop, sock, protocol, extra=None, server=None):
|
|||
self._extra['peername'] = None
|
||||
self._sock = sock
|
||||
self._sock_fd = sock.fileno()
|
||||
|
||||
self._context = context
|
||||
self._protocol_connected = False
|
||||
self.set_protocol(protocol)
|
||||
|
||||
|
|
@ -866,7 +869,7 @@ def close(self):
|
|||
if not self._buffer:
|
||||
self._conn_lost += 1
|
||||
self._loop._remove_writer(self._sock_fd)
|
||||
self._loop.call_soon(self._call_connection_lost, None)
|
||||
self._call_soon(self._call_connection_lost, None)
|
||||
|
||||
def __del__(self, _warn=warnings.warn):
|
||||
if self._sock is not None:
|
||||
|
|
@ -899,7 +902,7 @@ def _force_close(self, exc):
|
|||
self._closing = True
|
||||
self._loop._remove_reader(self._sock_fd)
|
||||
self._conn_lost += 1
|
||||
self._loop.call_soon(self._call_connection_lost, exc)
|
||||
self._call_soon(self._call_connection_lost, exc)
|
||||
|
||||
def _call_connection_lost(self, exc):
|
||||
try:
|
||||
|
|
@ -921,8 +924,13 @@ def get_write_buffer_size(self):
|
|||
def _add_reader(self, fd, callback, *args):
|
||||
if not self.is_reading():
|
||||
return
|
||||
self._loop._add_reader(fd, callback, *args)
|
||||
self._loop._add_reader(fd, callback, *args, context=self._context)
|
||||
|
||||
def _add_writer(self, fd, callback, *args):
|
||||
self._loop._add_writer(fd, callback, *args, context=self._context)
|
||||
|
||||
def _call_soon(self, callback, *args):
|
||||
self._loop.call_soon(callback, *args, context=self._context)
|
||||
|
||||
class _SelectorSocketTransport(_SelectorTransport):
|
||||
|
||||
|
|
@ -930,10 +938,9 @@ class _SelectorSocketTransport(_SelectorTransport):
|
|||
_sendfile_compatible = constants._SendfileMode.TRY_NATIVE
|
||||
|
||||
def __init__(self, loop, sock, protocol, waiter=None,
|
||||
extra=None, server=None):
|
||||
|
||||
extra=None, server=None, context=None):
|
||||
self._read_ready_cb = None
|
||||
super().__init__(loop, sock, protocol, extra, server)
|
||||
super().__init__(loop, sock, protocol, extra, server, context)
|
||||
self._eof = False
|
||||
self._empty_waiter = None
|
||||
if _HAS_SENDMSG:
|
||||
|
|
@ -945,14 +952,12 @@ def __init__(self, loop, sock, protocol, waiter=None,
|
|||
# decreases the latency (in some cases significantly.)
|
||||
base_events._set_nodelay(self._sock)
|
||||
|
||||
self._loop.call_soon(self._protocol.connection_made, self)
|
||||
self._call_soon(self._protocol.connection_made, self)
|
||||
# only start reading when connection_made() has been called
|
||||
self._loop.call_soon(self._add_reader,
|
||||
self._sock_fd, self._read_ready)
|
||||
self._call_soon(self._add_reader, self._sock_fd, self._read_ready)
|
||||
if waiter is not None:
|
||||
# only wake up the waiter when connection_made() has been called
|
||||
self._loop.call_soon(futures._set_result_unless_cancelled,
|
||||
waiter, None)
|
||||
self._call_soon(futures._set_result_unless_cancelled, waiter, None)
|
||||
|
||||
def set_protocol(self, protocol):
|
||||
if isinstance(protocol, protocols.BufferedProtocol):
|
||||
|
|
@ -1081,7 +1086,7 @@ def write(self, data):
|
|||
if not data:
|
||||
return
|
||||
# Not all was written; register write handler.
|
||||
self._loop._add_writer(self._sock_fd, self._write_ready)
|
||||
self._add_writer(self._sock_fd, self._write_ready)
|
||||
|
||||
# Add it to the buffer.
|
||||
self._buffer.append(data)
|
||||
|
|
@ -1185,7 +1190,7 @@ def writelines(self, list_of_data):
|
|||
self._write_ready()
|
||||
# If the entire buffer couldn't be written, register a write handler
|
||||
if self._buffer:
|
||||
self._loop._add_writer(self._sock_fd, self._write_ready)
|
||||
self._add_writer(self._sock_fd, self._write_ready)
|
||||
self._maybe_pause_protocol()
|
||||
|
||||
def can_write_eof(self):
|
||||
|
|
@ -1226,14 +1231,12 @@ def __init__(self, loop, sock, protocol, address=None,
|
|||
super().__init__(loop, sock, protocol, extra)
|
||||
self._address = address
|
||||
self._buffer_size = 0
|
||||
self._loop.call_soon(self._protocol.connection_made, self)
|
||||
self._call_soon(self._protocol.connection_made, self)
|
||||
# only start reading when connection_made() has been called
|
||||
self._loop.call_soon(self._add_reader,
|
||||
self._sock_fd, self._read_ready)
|
||||
self._call_soon(self._add_reader, self._sock_fd, self._read_ready)
|
||||
if waiter is not None:
|
||||
# only wake up the waiter when connection_made() has been called
|
||||
self._loop.call_soon(futures._set_result_unless_cancelled,
|
||||
waiter, None)
|
||||
self._call_soon(futures._set_result_unless_cancelled, waiter, None)
|
||||
|
||||
def get_write_buffer_size(self):
|
||||
return self._buffer_size
|
||||
|
|
@ -1280,7 +1283,7 @@ def sendto(self, data, addr=None):
|
|||
self._sock.sendto(data, addr)
|
||||
return
|
||||
except (BlockingIOError, InterruptedError):
|
||||
self._loop._add_writer(self._sock_fd, self._sendto_ready)
|
||||
self._add_writer(self._sock_fd, self._sendto_ready)
|
||||
except OSError as exc:
|
||||
self._protocol.error_received(exc)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1696,7 +1696,8 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter,
|
|||
server_side=False,
|
||||
server_hostname='python.org',
|
||||
ssl_handshake_timeout=handshake_timeout,
|
||||
ssl_shutdown_timeout=shutdown_timeout)
|
||||
ssl_shutdown_timeout=shutdown_timeout,
|
||||
context=ANY)
|
||||
# Next try an explicit server_hostname.
|
||||
self.loop._make_ssl_transport.reset_mock()
|
||||
coro = self.loop.create_connection(
|
||||
|
|
@ -1711,7 +1712,8 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter,
|
|||
server_side=False,
|
||||
server_hostname='perl.com',
|
||||
ssl_handshake_timeout=handshake_timeout,
|
||||
ssl_shutdown_timeout=shutdown_timeout)
|
||||
ssl_shutdown_timeout=shutdown_timeout,
|
||||
context=ANY)
|
||||
# Finally try an explicit empty server_hostname.
|
||||
self.loop._make_ssl_transport.reset_mock()
|
||||
coro = self.loop.create_connection(
|
||||
|
|
@ -1726,7 +1728,8 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter,
|
|||
server_side=False,
|
||||
server_hostname='',
|
||||
ssl_handshake_timeout=handshake_timeout,
|
||||
ssl_shutdown_timeout=shutdown_timeout)
|
||||
ssl_shutdown_timeout=shutdown_timeout,
|
||||
context=ANY)
|
||||
|
||||
def test_create_connection_no_ssl_server_hostname_errors(self):
|
||||
# When not using ssl, server_hostname must be None.
|
||||
|
|
@ -2104,7 +2107,7 @@ def test_accept_connection_exception(self, m_log):
|
|||
constants.ACCEPT_RETRY_DELAY,
|
||||
# self.loop._start_serving
|
||||
mock.ANY,
|
||||
MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY)
|
||||
MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY, mock.ANY)
|
||||
|
||||
def test_call_coroutine(self):
|
||||
async def simple_coroutine():
|
||||
|
|
|
|||
314
Lib/test/test_asyncio/test_server_context.py
Normal file
314
Lib/test/test_asyncio/test_server_context.py
Normal file
|
|
@ -0,0 +1,314 @@
|
|||
import asyncio
|
||||
import contextvars
|
||||
import unittest
|
||||
import sys
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
ssl = None
|
||||
|
||||
from test.test_asyncio import utils as test_utils
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.events._set_event_loop_policy(None)
|
||||
|
||||
class ServerContextvarsTestCase:
|
||||
loop_factory = None # To be defined in subclasses
|
||||
server_ssl_context = None # To be defined in subclasses for SSL tests
|
||||
client_ssl_context = None # To be defined in subclasses for SSL tests
|
||||
|
||||
def run_coro(self, coro):
|
||||
return asyncio.run(coro, loop_factory=self.loop_factory)
|
||||
|
||||
def test_start_server1(self):
|
||||
# Test that asyncio.start_server captures the context at the time of server creation
|
||||
async def test():
|
||||
var = contextvars.ContextVar("var", default="default")
|
||||
|
||||
async def handle_client(reader, writer):
|
||||
value = var.get()
|
||||
writer.write(value.encode())
|
||||
await writer.drain()
|
||||
writer.close()
|
||||
|
||||
server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
|
||||
ssl=self.server_ssl_context)
|
||||
# change the value
|
||||
var.set("after_server")
|
||||
|
||||
async def client(addr):
|
||||
reader, writer = await asyncio.open_connection(*addr,
|
||||
ssl=self.client_ssl_context)
|
||||
data = await reader.read(100)
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
return data.decode()
|
||||
|
||||
async with server:
|
||||
addr = server.sockets[0].getsockname()
|
||||
self.assertEqual(await client(addr), "default")
|
||||
|
||||
self.assertEqual(var.get(), "after_server")
|
||||
|
||||
self.run_coro(test())
|
||||
|
||||
def test_start_server2(self):
|
||||
# Test that mutations to the context in one handler don't affect other handlers or the server's context
|
||||
async def test():
|
||||
var = contextvars.ContextVar("var", default="default")
|
||||
|
||||
async def handle_client(reader, writer):
|
||||
value = var.get()
|
||||
writer.write(value.encode())
|
||||
var.set("in_handler")
|
||||
await writer.drain()
|
||||
writer.close()
|
||||
|
||||
server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
|
||||
ssl=self.server_ssl_context)
|
||||
var.set("after_server")
|
||||
|
||||
async def client(addr):
|
||||
reader, writer = await asyncio.open_connection(*addr,
|
||||
ssl=self.client_ssl_context)
|
||||
data = await reader.read(100)
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
return data.decode()
|
||||
|
||||
async with server:
|
||||
addr = server.sockets[0].getsockname()
|
||||
self.assertEqual(await client(addr), "default")
|
||||
self.assertEqual(await client(addr), "default")
|
||||
self.assertEqual(await client(addr), "default")
|
||||
|
||||
self.assertEqual(var.get(), "after_server")
|
||||
|
||||
self.run_coro(test())
|
||||
|
||||
def test_start_server3(self):
|
||||
# Test that mutations to context in concurrent handlers don't affect each other or the server's context
|
||||
async def test():
|
||||
var = contextvars.ContextVar("var", default="default")
|
||||
var.set("before_server")
|
||||
|
||||
async def handle_client(reader, writer):
|
||||
writer.write(var.get().encode())
|
||||
await writer.drain()
|
||||
writer.close()
|
||||
|
||||
server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
|
||||
ssl=self.server_ssl_context)
|
||||
var.set("after_server")
|
||||
|
||||
async def client(addr):
|
||||
reader, writer = await asyncio.open_connection(*addr,
|
||||
ssl=self.client_ssl_context)
|
||||
data = await reader.read(100)
|
||||
self.assertEqual(data.decode(), "before_server")
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
async with server:
|
||||
addr = server.sockets[0].getsockname()
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
for _ in range(100):
|
||||
tg.create_task(client(addr))
|
||||
|
||||
self.assertEqual(var.get(), "after_server")
|
||||
|
||||
self.run_coro(test())
|
||||
|
||||
def test_create_server1(self):
|
||||
# Test that loop.create_server captures the context at the time of server creation
|
||||
# and that mutations to the context in protocol callbacks don't affect the server's context
|
||||
async def test():
|
||||
var = contextvars.ContextVar("var", default="default")
|
||||
|
||||
class EchoProtocol(asyncio.Protocol):
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
value = var.get()
|
||||
var.set("in_handler")
|
||||
self.transport.write(value.encode())
|
||||
self.transport.close()
|
||||
|
||||
server = await asyncio.get_running_loop().create_server(
|
||||
lambda: EchoProtocol(), '127.0.0.1', 0,
|
||||
ssl=self.server_ssl_context)
|
||||
var.set("after_server")
|
||||
|
||||
async def client(addr):
|
||||
reader, writer = await asyncio.open_connection(*addr,
|
||||
ssl=self.client_ssl_context)
|
||||
data = await reader.read(100)
|
||||
self.assertEqual(data.decode(), "default")
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
async with server:
|
||||
addr = server.sockets[0].getsockname()
|
||||
await client(addr)
|
||||
|
||||
self.assertEqual(var.get(), "after_server")
|
||||
|
||||
self.run_coro(test())
|
||||
|
||||
def test_create_server2(self):
|
||||
# Test that mutations to context in one protocol instance don't affect other instances or the server's context
|
||||
async def test():
|
||||
var = contextvars.ContextVar("var", default="default")
|
||||
|
||||
class EchoProtocol(asyncio.Protocol):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
assert var.get() == "default", var.get()
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
value = var.get()
|
||||
var.set("in_handler")
|
||||
self.transport.write(value.encode())
|
||||
self.transport.close()
|
||||
|
||||
server = await asyncio.get_running_loop().create_server(
|
||||
lambda: EchoProtocol(), '127.0.0.1', 0,
|
||||
ssl=self.server_ssl_context)
|
||||
|
||||
var.set("after_server")
|
||||
|
||||
async def client(addr, expected):
|
||||
reader, writer = await asyncio.open_connection(*addr,
|
||||
ssl=self.client_ssl_context)
|
||||
data = await reader.read(100)
|
||||
self.assertEqual(data.decode(), expected)
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
async with server:
|
||||
addr = server.sockets[0].getsockname()
|
||||
await client(addr, "default")
|
||||
await client(addr, "default")
|
||||
|
||||
self.assertEqual(var.get(), "after_server")
|
||||
|
||||
self.run_coro(test())
|
||||
|
||||
def test_gh140947(self):
|
||||
# See https://github.com/python/cpython/issues/140947
|
||||
|
||||
cvar1 = contextvars.ContextVar("cvar1")
|
||||
cvar2 = contextvars.ContextVar("cvar2")
|
||||
cvar3 = contextvars.ContextVar("cvar3")
|
||||
results = {}
|
||||
is_ssl = self.server_ssl_context is not None
|
||||
|
||||
def capture_context(meth):
|
||||
result = []
|
||||
for k,v in contextvars.copy_context().items():
|
||||
if k.name.startswith("cvar"):
|
||||
result.append((k.name, v))
|
||||
results[meth] = sorted(result)
|
||||
|
||||
class DemoProtocol(asyncio.Protocol):
|
||||
def __init__(self, on_conn_lost):
|
||||
self.transport = None
|
||||
self.on_conn_lost = on_conn_lost
|
||||
self.tasks = set()
|
||||
|
||||
def connection_made(self, transport):
|
||||
capture_context("connection_made")
|
||||
self.transport = transport
|
||||
|
||||
def data_received(self, data):
|
||||
capture_context("data_received")
|
||||
|
||||
task = asyncio.create_task(self.asgi())
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
|
||||
self.transport.pause_reading()
|
||||
|
||||
def connection_lost(self, exc):
|
||||
capture_context("connection_lost")
|
||||
if not self.on_conn_lost.done():
|
||||
self.on_conn_lost.set_result(True)
|
||||
|
||||
async def asgi(self):
|
||||
capture_context("asgi start")
|
||||
cvar1.set(True)
|
||||
# make sure that we only resume after the pause
|
||||
# otherwise the resume does nothing
|
||||
if is_ssl:
|
||||
while not self.transport._ssl_protocol._app_reading_paused:
|
||||
await asyncio.sleep(0.01)
|
||||
else:
|
||||
while not self.transport._paused:
|
||||
await asyncio.sleep(0.01)
|
||||
cvar2.set(True)
|
||||
self.transport.resume_reading()
|
||||
cvar3.set(True)
|
||||
capture_context("asgi end")
|
||||
|
||||
async def main():
|
||||
loop = asyncio.get_running_loop()
|
||||
on_conn_lost = loop.create_future()
|
||||
|
||||
server = await loop.create_server(
|
||||
lambda: DemoProtocol(on_conn_lost), '127.0.0.1', 0,
|
||||
ssl=self.server_ssl_context)
|
||||
async with server:
|
||||
addr = server.sockets[0].getsockname()
|
||||
reader, writer = await asyncio.open_connection(*addr,
|
||||
ssl=self.client_ssl_context)
|
||||
writer.write(b"anything")
|
||||
await writer.drain()
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
await on_conn_lost
|
||||
|
||||
self.run_coro(main())
|
||||
self.assertDictEqual(results, {
|
||||
"connection_made": [],
|
||||
"data_received": [],
|
||||
"asgi start": [],
|
||||
"asgi end": [("cvar1", True), ("cvar2", True), ("cvar3", True)],
|
||||
"connection_lost": [],
|
||||
})
|
||||
|
||||
|
||||
class AsyncioEventLoopTests(TestCase, ServerContextvarsTestCase):
|
||||
loop_factory = staticmethod(asyncio.new_event_loop)
|
||||
|
||||
@unittest.skipUnless(ssl, "SSL not available")
|
||||
class AsyncioEventLoopSSLTests(AsyncioEventLoopTests):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.server_ssl_context = test_utils.simple_server_sslcontext()
|
||||
self.client_ssl_context = test_utils.simple_client_sslcontext()
|
||||
|
||||
if sys.platform == "win32":
|
||||
class AsyncioProactorEventLoopTests(TestCase, ServerContextvarsTestCase):
|
||||
loop_factory = asyncio.ProactorEventLoop
|
||||
|
||||
class AsyncioSelectorEventLoopTests(TestCase, ServerContextvarsTestCase):
|
||||
loop_factory = asyncio.SelectorEventLoop
|
||||
|
||||
@unittest.skipUnless(ssl, "SSL not available")
|
||||
class AsyncioProactorEventLoopSSLTests(AsyncioProactorEventLoopTests):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.server_ssl_context = test_utils.simple_server_sslcontext()
|
||||
self.client_ssl_context = test_utils.simple_client_sslcontext()
|
||||
|
||||
@unittest.skipUnless(ssl, "SSL not available")
|
||||
class AsyncioSelectorEventLoopSSLTests(AsyncioSelectorEventLoopTests):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.server_ssl_context = test_utils.simple_server_sslcontext()
|
||||
self.client_ssl_context = test_utils.simple_client_sslcontext()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -388,8 +388,8 @@ def close(self):
|
|||
else: # pragma: no cover
|
||||
raise AssertionError("Time generator is not finished")
|
||||
|
||||
def _add_reader(self, fd, callback, *args):
|
||||
self.readers[fd] = events.Handle(callback, args, self, None)
|
||||
def _add_reader(self, fd, callback, *args, context=None):
|
||||
self.readers[fd] = events.Handle(callback, args, self, context)
|
||||
|
||||
def _remove_reader(self, fd):
|
||||
self.remove_reader_count[fd] += 1
|
||||
|
|
@ -414,8 +414,8 @@ def assert_no_reader(self, fd):
|
|||
if fd in self.readers:
|
||||
raise AssertionError(f'fd {fd} is registered')
|
||||
|
||||
def _add_writer(self, fd, callback, *args):
|
||||
self.writers[fd] = events.Handle(callback, args, self, None)
|
||||
def _add_writer(self, fd, callback, *args, context=None):
|
||||
self.writers[fd] = events.Handle(callback, args, self, context)
|
||||
|
||||
def _remove_writer(self, fd):
|
||||
self.remove_writer_count[fd] += 1
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
Fix incorrect contextvars handling in server tasks created by :mod:`asyncio`. Patch by Kumar Aditya.
|
||||
Loading…
Add table
Add a link
Reference in a new issue