mirror of
				https://github.com/python/cpython.git
				synced 2025-11-04 07:31:38 +00:00 
			
		
		
		
	gh-128552: fix refcycles in eager task creation (#128553)
(cherry picked from commit 61b9811ac6)
			
			
This commit is contained in:
		
							parent
							
								
									7e099c51b6
								
							
						
					
					
						commit
						13835888e6
					
				
					 4 changed files with 72 additions and 6 deletions
				
			
		| 
						 | 
				
			
			@ -1,6 +1,8 @@
 | 
			
		|||
# Adapted with permission from the EdgeDB project;
 | 
			
		||||
# license: PSFL.
 | 
			
		||||
 | 
			
		||||
import weakref
 | 
			
		||||
import sys
 | 
			
		||||
import gc
 | 
			
		||||
import asyncio
 | 
			
		||||
import contextvars
 | 
			
		||||
| 
						 | 
				
			
			@ -28,7 +30,25 @@ def get_error_types(eg):
 | 
			
		|||
    return {type(exc) for exc in eg.exceptions}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
 | 
			
		||||
def set_gc_state(enabled):
 | 
			
		||||
    was_enabled = gc.isenabled()
 | 
			
		||||
    if enabled:
 | 
			
		||||
        gc.enable()
 | 
			
		||||
    else:
 | 
			
		||||
        gc.disable()
 | 
			
		||||
    return was_enabled
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@contextlib.contextmanager
 | 
			
		||||
def disable_gc():
 | 
			
		||||
    was_enabled = set_gc_state(enabled=False)
 | 
			
		||||
    try:
 | 
			
		||||
        yield
 | 
			
		||||
    finally:
 | 
			
		||||
        set_gc_state(enabled=was_enabled)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseTestTaskGroup:
 | 
			
		||||
 | 
			
		||||
    async def test_taskgroup_01(self):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -822,15 +842,15 @@ async def test_taskgroup_without_parent_task(self):
 | 
			
		|||
        with self.assertRaisesRegex(RuntimeError, "has not been entered"):
 | 
			
		||||
            tg.create_task(coro)
 | 
			
		||||
 | 
			
		||||
    def test_coro_closed_when_tg_closed(self):
 | 
			
		||||
    async def test_coro_closed_when_tg_closed(self):
 | 
			
		||||
        async def run_coro_after_tg_closes():
 | 
			
		||||
            async with taskgroups.TaskGroup() as tg:
 | 
			
		||||
                pass
 | 
			
		||||
            coro = asyncio.sleep(0)
 | 
			
		||||
            with self.assertRaisesRegex(RuntimeError, "is finished"):
 | 
			
		||||
                tg.create_task(coro)
 | 
			
		||||
        loop = asyncio.get_event_loop()
 | 
			
		||||
        loop.run_until_complete(run_coro_after_tg_closes())
 | 
			
		||||
 | 
			
		||||
        await run_coro_after_tg_closes()
 | 
			
		||||
 | 
			
		||||
    async def test_cancelling_level_preserved(self):
 | 
			
		||||
        async def raise_after(t, e):
 | 
			
		||||
| 
						 | 
				
			
			@ -955,6 +975,30 @@ async def coro_fn():
 | 
			
		|||
        self.assertIsInstance(exc, _Done)
 | 
			
		||||
        self.assertListEqual(gc.get_referrers(exc), [])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    async def test_exception_refcycles_parent_task_wr(self):
 | 
			
		||||
        """Test that TaskGroup deletes self._parent_task and create_task() deletes task"""
 | 
			
		||||
        tg = asyncio.TaskGroup()
 | 
			
		||||
        exc = None
 | 
			
		||||
 | 
			
		||||
        class _Done(Exception):
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        async def coro_fn():
 | 
			
		||||
            async with tg:
 | 
			
		||||
                raise _Done
 | 
			
		||||
 | 
			
		||||
        with disable_gc():
 | 
			
		||||
            try:
 | 
			
		||||
                async with asyncio.TaskGroup() as tg2:
 | 
			
		||||
                    task_wr = weakref.ref(tg2.create_task(coro_fn()))
 | 
			
		||||
            except* _Done as excs:
 | 
			
		||||
                exc = excs.exceptions[0].exceptions[0]
 | 
			
		||||
 | 
			
		||||
        self.assertIsNone(task_wr())
 | 
			
		||||
        self.assertIsInstance(exc, _Done)
 | 
			
		||||
        self.assertListEqual(gc.get_referrers(exc), [])
 | 
			
		||||
 | 
			
		||||
    async def test_exception_refcycles_propagate_cancellation_error(self):
 | 
			
		||||
        """Test that TaskGroup deletes propagate_cancellation_error"""
 | 
			
		||||
        tg = asyncio.TaskGroup()
 | 
			
		||||
| 
						 | 
				
			
			@ -988,5 +1032,16 @@ class MyKeyboardInterrupt(KeyboardInterrupt):
 | 
			
		|||
        self.assertListEqual(gc.get_referrers(exc), [])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase):
 | 
			
		||||
    loop_factory = asyncio.EventLoop
 | 
			
		||||
 | 
			
		||||
class TestEagerTaskTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase):
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def loop_factory():
 | 
			
		||||
        loop = asyncio.EventLoop()
 | 
			
		||||
        loop.set_task_factory(asyncio.eager_task_factory)
 | 
			
		||||
        return loop
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue