import asyncio
import contextvars
import sys
import unittest
from unittest import TestCase

try:
    import ssl
except ImportError:
    ssl = None

from test.test_asyncio import utils as test_utils


def tearDownModule():
    # Clear any event loop policy that was set while this module's tests ran.
    asyncio.events._set_event_loop_policy(None)


class ServerContextvarsTestCase:
    """Context-variable tests for asyncio servers.

    This is a mixin: concrete subclasses combine it with TestCase and
    provide the event loop factory and, for the SSL variants, the
    server/client SSL contexts.
    """

    loop_factory = None  # To be defined in subclasses
    server_ssl_context = None  # To be defined in subclasses for SSL tests
    client_ssl_context = None  # To be defined in subclasses for SSL tests

    def run_coro(self, coro):
        """Run *coro* with asyncio.run() using this class's loop factory."""
        return asyncio.run(coro, loop_factory=self.loop_factory)

    async def _read_reply(self, addr):
        """Connect to *addr*, read up to 100 bytes and return them decoded.

        The client connection is closed even when reading fails, so a
        failing test does not leave a dangling transport behind.
        """
        reader, writer = await asyncio.open_connection(
            *addr, ssl=self.client_ssl_context)
        try:
            data = await reader.read(100)
        finally:
            writer.close()
            await writer.wait_closed()
        return data.decode()

    def test_start_server1(self):
        # Test that asyncio.start_server captures the context at the time
        # of server creation
        async def test():
            var = contextvars.ContextVar("var", default="default")

            async def handle_client(reader, writer):
                value = var.get()
                writer.write(value.encode())
                await writer.drain()
                writer.close()

            server = await asyncio.start_server(
                handle_client, '127.0.0.1', 0, ssl=self.server_ssl_context)
            # change the value
            var.set("after_server")

            async with server:
                addr = server.sockets[0].getsockname()
                self.assertEqual(await self._read_reply(addr), "default")

            self.assertEqual(var.get(), "after_server")

        self.run_coro(test())

    def test_start_server2(self):
        # Test that mutations to the context in one handler don't affect
        # other handlers or the server's context
        async def test():
            var = contextvars.ContextVar("var", default="default")

            async def handle_client(reader, writer):
                value = var.get()
                writer.write(value.encode())
                # Must not be visible to later handlers or to the test task.
                var.set("in_handler")
                await writer.drain()
                writer.close()

            server = await asyncio.start_server(
                handle_client, '127.0.0.1', 0, ssl=self.server_ssl_context)
            var.set("after_server")

            async with server:
                addr = server.sockets[0].getsockname()
                # Sequential connections: each one must still see "default".
                for _ in range(3):
                    self.assertEqual(await self._read_reply(addr), "default")

            self.assertEqual(var.get(), "after_server")

        self.run_coro(test())

    def test_start_server3(self):
        # Test that concurrent handlers all run with the context captured
        # at server creation, not with later changes made by the task that
        # created the server
        async def test():
            var = contextvars.ContextVar("var", default="default")
            var.set("before_server")

            async def handle_client(reader, writer):
                writer.write(var.get().encode())
                await writer.drain()
                writer.close()

            server = await asyncio.start_server(
                handle_client, '127.0.0.1', 0, ssl=self.server_ssl_context)
            var.set("after_server")

            async def client(addr):
                self.assertEqual(
                    await self._read_reply(addr), "before_server")

            async with server:
                addr = server.sockets[0].getsockname()
                async with asyncio.TaskGroup() as tg:
                    for _ in range(100):
                        tg.create_task(client(addr))

            self.assertEqual(var.get(), "after_server")

        self.run_coro(test())

    def test_create_server1(self):
        # Test that loop.create_server captures the context at the time of
        # server creation and that mutations to the context in protocol
        # callbacks don't affect the server's context
        async def test():
            var = contextvars.ContextVar("var", default="default")

            class EchoProtocol(asyncio.Protocol):
                def connection_made(self, transport):
                    self.transport = transport
                    value = var.get()
                    var.set("in_handler")
                    self.transport.write(value.encode())
                    self.transport.close()

            server = await asyncio.get_running_loop().create_server(
                EchoProtocol, '127.0.0.1', 0, ssl=self.server_ssl_context)
            var.set("after_server")

            async with server:
                addr = server.sockets[0].getsockname()
                self.assertEqual(await self._read_reply(addr), "default")

            self.assertEqual(var.get(), "after_server")

        self.run_coro(test())

    def test_create_server2(self):
        # Test that mutations to context in one protocol instance don't
        # affect other instances or the server's context
        async def test():
            var = contextvars.ContextVar("var", default="default")

            class EchoProtocol(asyncio.Protocol):
                def __init__(self):
                    super().__init__()
                    # Explicit raise rather than ``assert`` so the check on
                    # the protocol factory's context also runs under -O.
                    value = var.get()
                    if value != "default":
                        raise AssertionError(value)

                def connection_made(self, transport):
                    self.transport = transport
                    value = var.get()
                    var.set("in_handler")
                    self.transport.write(value.encode())
                    self.transport.close()

            server = await asyncio.get_running_loop().create_server(
                EchoProtocol, '127.0.0.1', 0, ssl=self.server_ssl_context)
            var.set("after_server")

            async with server:
                addr = server.sockets[0].getsockname()
                # The second connection must not observe "in_handler" set
                # by the first protocol instance.
                self.assertEqual(await self._read_reply(addr), "default")
                self.assertEqual(await self._read_reply(addr), "default")

            self.assertEqual(var.get(), "after_server")

        self.run_coro(test())

    def test_gh140947(self):
        # See https://github.com/python/cpython/issues/140947

        cvar1 = contextvars.ContextVar("cvar1")
        cvar2 = contextvars.ContextVar("cvar2")
        cvar3 = contextvars.ContextVar("cvar3")
        results = {}
        is_ssl = self.server_ssl_context is not None

        def capture_context(meth):
            # Record, under the key *meth*, which of the cvar* variables
            # are set in the context the caller is currently running in.
            results[meth] = sorted(
                (k.name, v)
                for k, v in contextvars.copy_context().items()
                if k.name.startswith("cvar")
            )

        class DemoProtocol(asyncio.Protocol):
            def __init__(self, on_conn_lost):
                self.transport = None
                self.on_conn_lost = on_conn_lost
                # Keep strong references so the tasks are not collected
                # while they are still running.
                self.tasks = set()

            def connection_made(self, transport):
                capture_context("connection_made")
                self.transport = transport

            def data_received(self, data):
                capture_context("data_received")
                task = asyncio.create_task(self.asgi())
                self.tasks.add(task)
                task.add_done_callback(self.tasks.discard)
                self.transport.pause_reading()

            def connection_lost(self, exc):
                capture_context("connection_lost")
                if not self.on_conn_lost.done():
                    self.on_conn_lost.set_result(True)

            def reading_paused(self):
                # The paused flag is kept in a different private attribute
                # depending on whether the transport is an SSL one.
                if is_ssl:
                    return self.transport._ssl_protocol._app_reading_paused
                return self.transport._paused

            async def asgi(self):
                capture_context("asgi start")
                cvar1.set(True)
                # make sure that we only resume after the pause
                # otherwise the resume does nothing
                while not self.reading_paused():
                    await asyncio.sleep(0.01)
                cvar2.set(True)
                self.transport.resume_reading()
                cvar3.set(True)
                capture_context("asgi end")

        async def main():
            loop = asyncio.get_running_loop()
            on_conn_lost = loop.create_future()

            server = await loop.create_server(
                lambda: DemoProtocol(on_conn_lost), '127.0.0.1', 0,
                ssl=self.server_ssl_context)
            async with server:
                addr = server.sockets[0].getsockname()
                _reader, writer = await asyncio.open_connection(
                    *addr, ssl=self.client_ssl_context)
                writer.write(b"anything")
                await writer.drain()
                writer.close()
                await writer.wait_closed()
                await on_conn_lost

        self.run_coro(main())

        # Variables set inside the asgi() task must only be visible at the
        # end of that task, never in any of the protocol callbacks.
        self.assertDictEqual(results, {
            "connection_made": [],
            "data_received": [],
            "asgi start": [],
            "asgi end": [("cvar1", True), ("cvar2", True), ("cvar3", True)],
            "connection_lost": [],
        })


class _SSLContextsMixin:
    """Create server and client SSL contexts for each test.

    Must be listed before the TestCase-derived base class so that its
    setUp() runs and chains to the next setUp() via super().
    """

    def setUp(self):
        super().setUp()
        self.server_ssl_context = test_utils.simple_server_sslcontext()
        self.client_ssl_context = test_utils.simple_client_sslcontext()


class AsyncioEventLoopTests(TestCase, ServerContextvarsTestCase):
    # staticmethod() keeps the function from being bound as a method when
    # it is looked up through ``self.loop_factory``.
    loop_factory = staticmethod(asyncio.new_event_loop)


@unittest.skipUnless(ssl, "SSL not available")
class AsyncioEventLoopSSLTests(_SSLContextsMixin, AsyncioEventLoopTests):
    pass


if sys.platform == "win32":
    class AsyncioProactorEventLoopTests(TestCase, ServerContextvarsTestCase):
        loop_factory = asyncio.ProactorEventLoop

    class AsyncioSelectorEventLoopTests(TestCase, ServerContextvarsTestCase):
        loop_factory = asyncio.SelectorEventLoop

    @unittest.skipUnless(ssl, "SSL not available")
    class AsyncioProactorEventLoopSSLTests(_SSLContextsMixin,
                                           AsyncioProactorEventLoopTests):
        pass

    @unittest.skipUnless(ssl, "SSL not available")
    class AsyncioSelectorEventLoopSSLTests(_SSLContextsMixin,
                                           AsyncioSelectorEventLoopTests):
        pass


if __name__ == "__main__":
    unittest.main()