mirror of
				https://github.com/python/cpython.git
				synced 2025-11-04 07:31:38 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			357 lines
		
	
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			357 lines
		
	
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""Tests for base_events.py"""
 | 
						|
 | 
						|
import asyncio
 | 
						|
import contextvars
 | 
						|
import unittest
 | 
						|
 | 
						|
from unittest import mock
 | 
						|
from asyncio import tasks
 | 
						|
from test.test_asyncio import utils as test_utils
 | 
						|
import test.support
 | 
						|
from test.support.script_helper import assert_python_ok
 | 
						|
 | 
						|
# Module-level alias for mock.ANY.
# NOTE(review): this alias appears unused in this module; confirm before removing.
MOCK_ANY = mock.ANY
 | 
						|
 | 
						|
 | 
						|
def tearDownModule():
    """Restore asyncio's default event loop policy after this module's tests.

    Passing ``None`` to ``set_event_loop_policy`` resets the policy to the
    default one, undoing anything the tests may have installed.
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
 | 
						|
 | 
						|
 | 
						|
class EagerTaskFactoryLoopTests:
 | 
						|
 | 
						|
    Task = None
 | 
						|
 | 
						|
    def run_coro(self, coro):
 | 
						|
        """
 | 
						|
        Helper method to run the `coro` coroutine in the test event loop.
 | 
						|
        It helps with making sure the event loop is running before starting
 | 
						|
        to execute `coro`. This is important for testing the eager step
 | 
						|
        functionality, since an eager step is taken only if the event loop
 | 
						|
        is already running.
 | 
						|
        """
 | 
						|
 | 
						|
        async def coro_runner():
 | 
						|
            self.assertTrue(asyncio.get_event_loop().is_running())
 | 
						|
            return await coro
 | 
						|
 | 
						|
        return self.loop.run_until_complete(coro)
 | 
						|
 | 
						|
    def setUp(self):
 | 
						|
        super().setUp()
 | 
						|
        self.loop = asyncio.new_event_loop()
 | 
						|
        self.eager_task_factory = asyncio.create_eager_task_factory(self.Task)
 | 
						|
        self.loop.set_task_factory(self.eager_task_factory)
 | 
						|
        self.set_event_loop(self.loop)
 | 
						|
 | 
						|
    def test_eager_task_factory_set(self):
 | 
						|
        self.assertIsNotNone(self.eager_task_factory)
 | 
						|
        self.assertIs(self.loop.get_task_factory(), self.eager_task_factory)
 | 
						|
 | 
						|
        async def noop(): pass
 | 
						|
 | 
						|
        async def run():
 | 
						|
            t = self.loop.create_task(noop())
 | 
						|
            self.assertIsInstance(t, self.Task)
 | 
						|
            await t
 | 
						|
 | 
						|
        self.run_coro(run())
 | 
						|
 | 
						|
    def test_await_future_during_eager_step(self):
 | 
						|
 | 
						|
        async def set_result(fut, val):
 | 
						|
            fut.set_result(val)
 | 
						|
 | 
						|
        async def run():
 | 
						|
            fut = self.loop.create_future()
 | 
						|
            t = self.loop.create_task(set_result(fut, 'my message'))
 | 
						|
            # assert the eager step completed the task
 | 
						|
            self.assertTrue(t.done())
 | 
						|
            return await fut
 | 
						|
 | 
						|
        self.assertEqual(self.run_coro(run()), 'my message')
 | 
						|
 | 
						|
    def test_eager_completion(self):
 | 
						|
 | 
						|
        async def coro():
 | 
						|
            return 'hello'
 | 
						|
 | 
						|
        async def run():
 | 
						|
            t = self.loop.create_task(coro())
 | 
						|
            # assert the eager step completed the task
 | 
						|
            self.assertTrue(t.done())
 | 
						|
            return await t
 | 
						|
 | 
						|
        self.assertEqual(self.run_coro(run()), 'hello')
 | 
						|
 | 
						|
    def test_block_after_eager_step(self):
 | 
						|
 | 
						|
        async def coro():
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            return 'finished after blocking'
 | 
						|
 | 
						|
        async def run():
 | 
						|
            t = self.loop.create_task(coro())
 | 
						|
            self.assertFalse(t.done())
 | 
						|
            result = await t
 | 
						|
            self.assertTrue(t.done())
 | 
						|
            return result
 | 
						|
 | 
						|
        self.assertEqual(self.run_coro(run()), 'finished after blocking')
 | 
						|
 | 
						|
    def test_cancellation_after_eager_completion(self):
 | 
						|
 | 
						|
        async def coro():
 | 
						|
            return 'finished without blocking'
 | 
						|
 | 
						|
        async def run():
 | 
						|
            t = self.loop.create_task(coro())
 | 
						|
            t.cancel()
 | 
						|
            result = await t
 | 
						|
            # finished task can't be cancelled
 | 
						|
            self.assertFalse(t.cancelled())
 | 
						|
            return result
 | 
						|
 | 
						|
        self.assertEqual(self.run_coro(run()), 'finished without blocking')
 | 
						|
 | 
						|
    def test_cancellation_after_eager_step_blocks(self):
 | 
						|
 | 
						|
        async def coro():
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            return 'finished after blocking'
 | 
						|
 | 
						|
        async def run():
 | 
						|
            t = self.loop.create_task(coro())
 | 
						|
            t.cancel('cancellation message')
 | 
						|
            self.assertGreater(t.cancelling(), 0)
 | 
						|
            result = await t
 | 
						|
 | 
						|
        with self.assertRaises(asyncio.CancelledError) as cm:
 | 
						|
            self.run_coro(run())
 | 
						|
 | 
						|
        self.assertEqual('cancellation message', cm.exception.args[0])
 | 
						|
 | 
						|
    def test_current_task(self):
 | 
						|
        captured_current_task = None
 | 
						|
 | 
						|
        async def coro():
 | 
						|
            nonlocal captured_current_task
 | 
						|
            captured_current_task = asyncio.current_task()
 | 
						|
            # verify the task before and after blocking is identical
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            self.assertIs(asyncio.current_task(), captured_current_task)
 | 
						|
 | 
						|
        async def run():
 | 
						|
            t = self.loop.create_task(coro())
 | 
						|
            self.assertIs(captured_current_task, t)
 | 
						|
            await t
 | 
						|
 | 
						|
        self.run_coro(run())
 | 
						|
        captured_current_task = None
 | 
						|
 | 
						|
    def test_all_tasks_with_eager_completion(self):
 | 
						|
        captured_all_tasks = None
 | 
						|
 | 
						|
        async def coro():
 | 
						|
            nonlocal captured_all_tasks
 | 
						|
            captured_all_tasks = asyncio.all_tasks()
 | 
						|
 | 
						|
        async def run():
 | 
						|
            t = self.loop.create_task(coro())
 | 
						|
            self.assertIn(t, captured_all_tasks)
 | 
						|
            self.assertNotIn(t, asyncio.all_tasks())
 | 
						|
 | 
						|
        self.run_coro(run())
 | 
						|
 | 
						|
    def test_all_tasks_with_blocking(self):
 | 
						|
        captured_eager_all_tasks = None
 | 
						|
 | 
						|
        async def coro(fut1, fut2):
 | 
						|
            nonlocal captured_eager_all_tasks
 | 
						|
            captured_eager_all_tasks = asyncio.all_tasks()
 | 
						|
            await fut1
 | 
						|
            fut2.set_result(None)
 | 
						|
 | 
						|
        async def run():
 | 
						|
            fut1 = self.loop.create_future()
 | 
						|
            fut2 = self.loop.create_future()
 | 
						|
            t = self.loop.create_task(coro(fut1, fut2))
 | 
						|
            self.assertIn(t, captured_eager_all_tasks)
 | 
						|
            self.assertIn(t, asyncio.all_tasks())
 | 
						|
            fut1.set_result(None)
 | 
						|
            await fut2
 | 
						|
            self.assertNotIn(t, asyncio.all_tasks())
 | 
						|
 | 
						|
        self.run_coro(run())
 | 
						|
 | 
						|
    def test_context_vars(self):
 | 
						|
        cv = contextvars.ContextVar('cv', default=0)
 | 
						|
 | 
						|
        coro_first_step_ran = False
 | 
						|
        coro_second_step_ran = False
 | 
						|
 | 
						|
        async def coro():
 | 
						|
            nonlocal coro_first_step_ran
 | 
						|
            nonlocal coro_second_step_ran
 | 
						|
            self.assertEqual(cv.get(), 1)
 | 
						|
            cv.set(2)
 | 
						|
            self.assertEqual(cv.get(), 2)
 | 
						|
            coro_first_step_ran = True
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            self.assertEqual(cv.get(), 2)
 | 
						|
            cv.set(3)
 | 
						|
            self.assertEqual(cv.get(), 3)
 | 
						|
            coro_second_step_ran = True
 | 
						|
 | 
						|
        async def run():
 | 
						|
            cv.set(1)
 | 
						|
            t = self.loop.create_task(coro())
 | 
						|
            self.assertTrue(coro_first_step_ran)
 | 
						|
            self.assertFalse(coro_second_step_ran)
 | 
						|
            self.assertEqual(cv.get(), 1)
 | 
						|
            await t
 | 
						|
            self.assertTrue(coro_second_step_ran)
 | 
						|
            self.assertEqual(cv.get(), 1)
 | 
						|
 | 
						|
        self.run_coro(run())
 | 
						|
 | 
						|
 | 
						|
class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
    """Eager task factory loop tests using the pure-Python Task class."""

    Task = tasks._PyTask
 | 
						|
 | 
						|
 | 
						|
@unittest.skipUnless(hasattr(tasks, '_CTask'), 'requires the C _asyncio module')
class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
    """Eager task factory loop tests using the C-accelerated Task class."""

    Task = getattr(tasks, '_CTask', None)

    def test_issue105987(self):
        """Regression test for gh-105987.

        Calls ``_asyncio._swap_current_task`` with plain dummy objects in a
        subprocess; the subprocess must succeed (via ``assert_python_ok``)
        and write nothing to stderr.
        """
        script = """if 1:
        from _asyncio import _swap_current_task

        class DummyTask:
            pass

        class DummyLoop:
            pass

        l = DummyLoop()
        _swap_current_task(l, DummyTask())
        t = _swap_current_task(l, None)
        """

        _rc, _stdout, stderr = assert_python_ok("-c", script)
        self.assertFalse(stderr)
 | 
						|
 | 
						|
class AsyncTaskCounter:
 | 
						|
    def __init__(self, loop, *, task_class, eager):
 | 
						|
        self.suspense_count = 0
 | 
						|
        self.task_count = 0
 | 
						|
 | 
						|
        def CountingTask(*args, eager_start=False, **kwargs):
 | 
						|
            if not eager_start:
 | 
						|
                self.task_count += 1
 | 
						|
            kwargs["eager_start"] = eager_start
 | 
						|
            return task_class(*args, **kwargs)
 | 
						|
 | 
						|
        if eager:
 | 
						|
            factory = asyncio.create_eager_task_factory(CountingTask)
 | 
						|
        else:
 | 
						|
            def factory(loop, coro, **kwargs):
 | 
						|
                return CountingTask(coro, loop=loop, **kwargs)
 | 
						|
        loop.set_task_factory(factory)
 | 
						|
 | 
						|
    def get(self):
 | 
						|
        return self.task_count
 | 
						|
 | 
						|
 | 
						|
async def awaitable_chain(depth):
    """Await *depth* nested coroutines directly and return *depth*.

    Each level adds 1 to the result of the level below it; level 0 yields 0.
    No tasks are created here, only plain coroutine awaits. Intended for
    non-negative integer depths.
    """
    return 0 if depth == 0 else 1 + await awaitable_chain(depth - 1)
 | 
						|
 | 
						|
 | 
						|
async def recursive_taskgroups(width, depth):
 | 
						|
    if depth == 0:
 | 
						|
        return
 | 
						|
 | 
						|
    async with asyncio.TaskGroup() as tg:
 | 
						|
        futures = [
 | 
						|
            tg.create_task(recursive_taskgroups(width, depth - 1))
 | 
						|
            for _ in range(width)
 | 
						|
        ]
 | 
						|
 | 
						|
 | 
						|
async def recursive_gather(width, depth):
    """Build a tree of tasks by recursively gathering child coroutines.

    For a non-zero *depth*, *width* coroutines of
    ``recursive_gather(width, depth - 1)`` are passed to ``asyncio.gather``,
    which wraps each one in a task. Depth 0 is a leaf. Returns None.
    """
    if depth != 0:
        children = [recursive_gather(width, depth - 1) for _ in range(width)]
        await asyncio.gather(*children)
 | 
						|
 | 
						|
 | 
						|
class BaseTaskCountingTests:
    """Mixin counting how many tasks are created non-eagerly.

    Concrete subclasses set ``Task`` (task class to construct), ``eager``
    (whether ``AsyncTaskCounter`` installs an eager task factory) and
    ``expected_task_count`` (the counter value expected after the recursive
    workloads with width 5 and depth 4).
    """

    Task = None
    eager = None
    expected_task_count = None

    def setUp(self):
        super().setUp()
        self.loop = asyncio.new_event_loop()
        self.counter = AsyncTaskCounter(self.loop, task_class=self.Task, eager=self.eager)
        self.set_event_loop(self.loop)

    def test_awaitables_chain(self):
        observed_depth = self.loop.run_until_complete(awaitable_chain(100))
        self.assertEqual(observed_depth, 100)
        # Only run_until_complete() creates a task here; AsyncTaskCounter
        # counts it only when the factory is not eager.
        self.assertEqual(self.counter.get(), 0 if self.eager else 1)

    def test_recursive_taskgroups(self):
        self.loop.run_until_complete(recursive_taskgroups(5, 4))
        self.assertEqual(self.counter.get(), self.expected_task_count)

    def test_recursive_gather(self):
        self.loop.run_until_complete(recursive_gather(5, 4))
        self.assertEqual(self.counter.get(), self.expected_task_count)
 | 
						|
 | 
						|
 | 
						|
class BaseNonEagerTaskFactoryTests(BaseTaskCountingTests):
    """Task-counting tests with a non-eager factory: every task is counted."""

    eager = False
    # One task per node of a width-5, depth-4 tree: 1 + 5 + 5^2 + 5^3 + 5^4 == 781
    expected_task_count = sum(5 ** level for level in range(5))
 | 
						|
 | 
						|
 | 
						|
class BaseEagerTaskFactoryTests(BaseTaskCountingTests):
    """Task-counting tests with an eager factory: no lazily-started tasks."""

    eager = True
    # Every task is constructed with eager_start=True, so none are counted.
    expected_task_count = 0
 | 
						|
 | 
						|
 | 
						|
class NonEagerTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
    """Non-eager task counting using the default ``asyncio.Task``."""

    Task = asyncio.Task
 | 
						|
 | 
						|
 | 
						|
class EagerTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
    """Eager task counting using the default ``asyncio.Task``."""

    Task = asyncio.Task
 | 
						|
 | 
						|
 | 
						|
class NonEagerPyTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
    """Non-eager task counting using the pure-Python Task class."""

    Task = tasks._PyTask
 | 
						|
 | 
						|
 | 
						|
class EagerPyTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
    """Eager task counting using the pure-Python Task class."""

    Task = tasks._PyTask
 | 
						|
 | 
						|
 | 
						|
@unittest.skipUnless(hasattr(tasks, '_CTask'), 'requires the C _asyncio module')
class NonEagerCTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
    """Non-eager task counting using the C-accelerated Task class."""

    Task = getattr(tasks, '_CTask', None)
 | 
						|
 | 
						|
 | 
						|
@unittest.skipUnless(hasattr(tasks, '_CTask'), 'requires the C _asyncio module')
class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
    """Eager task counting using the C-accelerated Task class."""

    Task = getattr(tasks, '_CTask', None)
 | 
						|
 | 
						|
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
 |