[3.12] gh-124594: Create and reuse the same context for the entire asyncio REPL session (GH-124595) (#124849)

* gh-124594: Create and reuse the same context for the entire asyncio REPL session (GH-124595)
(cherry picked from commit 67e01a430f)

Co-authored-by: Bartosz Sławecki <bartoszpiotrslawecki@gmail.com>
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>


---------

Co-authored-by: Bartosz Sławecki <bartoszpiotrslawecki@gmail.com>
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
Miss Islington (bot) 2024-10-28 15:25:00 +01:00 committed by GitHub
parent 1e01dcf429
commit d89283b3e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 42 additions and 2 deletions

View file

@ -2,6 +2,7 @@
import asyncio
import code
import concurrent.futures
import contextvars
import inspect
import sys
import threading
@ -17,6 +18,7 @@ def __init__(self, locals, loop):
super().__init__(locals)
self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
self.loop = loop
self.context = contextvars.copy_context()
def runcode(self, code):
future = concurrent.futures.Future()
@ -46,12 +48,12 @@ def callback():
return
try:
repl_future = self.loop.create_task(coro)
repl_future = self.loop.create_task(coro, context=self.context)
futures._chain_future(repl_future, future)
except BaseException as exc:
future.set_exception(exc)
loop.call_soon_threadsafe(callback)
loop.call_soon_threadsafe(callback, context=self.context)
try:
return future.result()

View file

@ -146,5 +146,42 @@ def f():
self.assertEqual(traceback_lines, expected_lines)
class TestAsyncioREPLContextVars(unittest.TestCase):
def test_toplevel_contextvars_sync(self):
user_input = dedent("""\
from contextvars import ContextVar
var = ContextVar("var", default="failed")
var.set("ok")
""")
p = spawn_repl("-m", "asyncio")
p.stdin.write(user_input)
user_input2 = dedent("""
print(f"toplevel contextvar test: {var.get()}")
""")
p.stdin.write(user_input2)
output = kill_python(p)
self.assertEqual(p.returncode, 0)
expected = "toplevel contextvar test: ok"
self.assertIn(expected, output, expected)
def test_toplevel_contextvars_async(self):
user_input = dedent("""\
from contextvars import ContextVar
var = ContextVar('var', default='failed')
""")
p = spawn_repl("-m", "asyncio")
p.stdin.write(user_input+"\n")
user_input2 = "async def set_var(): var.set('ok')\n"
p.stdin.write(user_input2+"\n")
user_input3 = "await set_var()\n"
p.stdin.write(user_input3+"\n")
user_input4 = "print(f'toplevel contextvar test: {var.get()}')\n"
p.stdin.write(user_input4+"\n")
output = kill_python(p)
self.assertEqual(p.returncode, 0)
expected = "toplevel contextvar test: ok"
self.assertIn(expected, output, expected)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1 @@
All :mod:`asyncio` REPL prompts run in the same :class:`context <contextvars.Context>`. Contributed by Bartosz Sławecki.