mirror of
				https://github.com/python/cpython.git
				synced 2025-11-04 07:31:38 +00:00 
			
		
		
		
	When a task catches CancelledError and raises some other error, the other error should not silently be suppressed. Any scenario where a task crashes in cleanup upon cancellation will now result in an ExceptionGroup wrapping the crash(es) instead of propagating CancelledError and ignoring the side errors. NOTE: This represents a change in behavior (hence the need to change several tests). But it is only an edge case. Co-authored-by: Thomas Grainger <tagrain@gmail.com>
		
			
				
	
	
		
			784 lines
		
	
	
	
		
			22 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			784 lines
		
	
	
	
		
			22 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Adapted with permission from the EdgeDB project;
# license: PSFL.
 | 
						|
 | 
						|
 | 
						|
import asyncio
 | 
						|
import contextvars
 | 
						|
import contextlib
 | 
						|
from asyncio import taskgroups
 | 
						|
import unittest
 | 
						|
 | 
						|
 | 
						|
# To prevent a warning "test altered the execution environment"
 | 
						|
def tearDownModule():
    """Reset the asyncio event loop policy to its default.

    Passing ``None`` to ``asyncio.set_event_loop_policy`` restores the
    default policy, so the test runner does not report that this module
    altered the execution environment.
    """
    asyncio.set_event_loop_policy(None)
 | 
						|
 | 
						|
 | 
						|
class MyExc(Exception):
    """Distinct ``Exception`` subclass raised by tests in this module."""
 | 
						|
 | 
						|
 | 
						|
class MyBaseExc(BaseException):
    """Distinct ``BaseException`` subclass that is not an ``Exception``."""
 | 
						|
 | 
						|
 | 
						|
def get_error_types(eg):
 | 
						|
    return {type(exc) for exc in eg.exceptions}
 | 
						|
 | 
						|
 | 
						|
class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
 | 
						|
 | 
						|
    async def test_taskgroup_01(self):
 | 
						|
 | 
						|
        async def foo1():
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            return 42
 | 
						|
 | 
						|
        async def foo2():
 | 
						|
            await asyncio.sleep(0.2)
 | 
						|
            return 11
 | 
						|
 | 
						|
        async with taskgroups.TaskGroup() as g:
 | 
						|
            t1 = g.create_task(foo1())
 | 
						|
            t2 = g.create_task(foo2())
 | 
						|
 | 
						|
        self.assertEqual(t1.result(), 42)
 | 
						|
        self.assertEqual(t2.result(), 11)
 | 
						|
 | 
						|
    async def test_taskgroup_02(self):
 | 
						|
 | 
						|
        async def foo1():
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            return 42
 | 
						|
 | 
						|
        async def foo2():
 | 
						|
            await asyncio.sleep(0.2)
 | 
						|
            return 11
 | 
						|
 | 
						|
        async with taskgroups.TaskGroup() as g:
 | 
						|
            t1 = g.create_task(foo1())
 | 
						|
            await asyncio.sleep(0.15)
 | 
						|
            t2 = g.create_task(foo2())
 | 
						|
 | 
						|
        self.assertEqual(t1.result(), 42)
 | 
						|
        self.assertEqual(t2.result(), 11)
 | 
						|
 | 
						|
    async def test_taskgroup_03(self):
 | 
						|
 | 
						|
        async def foo1():
 | 
						|
            await asyncio.sleep(1)
 | 
						|
            return 42
 | 
						|
 | 
						|
        async def foo2():
 | 
						|
            await asyncio.sleep(0.2)
 | 
						|
            return 11
 | 
						|
 | 
						|
        async with taskgroups.TaskGroup() as g:
 | 
						|
            t1 = g.create_task(foo1())
 | 
						|
            await asyncio.sleep(0.15)
 | 
						|
            # cancel t1 explicitly, i.e. everything should continue
 | 
						|
            # working as expected.
 | 
						|
            t1.cancel()
 | 
						|
 | 
						|
            t2 = g.create_task(foo2())
 | 
						|
 | 
						|
        self.assertTrue(t1.cancelled())
 | 
						|
        self.assertEqual(t2.result(), 11)
 | 
						|
 | 
						|
    async def test_taskgroup_04(self):
 | 
						|
 | 
						|
        NUM = 0
 | 
						|
        t2_cancel = False
 | 
						|
        t2 = None
 | 
						|
 | 
						|
        async def foo1():
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            1 / 0
 | 
						|
 | 
						|
        async def foo2():
 | 
						|
            nonlocal NUM, t2_cancel
 | 
						|
            try:
 | 
						|
                await asyncio.sleep(1)
 | 
						|
            except asyncio.CancelledError:
 | 
						|
                t2_cancel = True
 | 
						|
                raise
 | 
						|
            NUM += 1
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            nonlocal NUM, t2
 | 
						|
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                g.create_task(foo1())
 | 
						|
                t2 = g.create_task(foo2())
 | 
						|
 | 
						|
            NUM += 10
 | 
						|
 | 
						|
        with self.assertRaises(ExceptionGroup) as cm:
 | 
						|
            await asyncio.create_task(runner())
 | 
						|
 | 
						|
        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | 
						|
 | 
						|
        self.assertEqual(NUM, 0)
 | 
						|
        self.assertTrue(t2_cancel)
 | 
						|
        self.assertTrue(t2.cancelled())
 | 
						|
 | 
						|
    async def test_cancel_children_on_child_error(self):
 | 
						|
        # When a child task raises an error, the rest of the children
 | 
						|
        # are cancelled and the errors are gathered into an EG.
 | 
						|
 | 
						|
        NUM = 0
 | 
						|
        t2_cancel = False
 | 
						|
        runner_cancel = False
 | 
						|
 | 
						|
        async def foo1():
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            1 / 0
 | 
						|
 | 
						|
        async def foo2():
 | 
						|
            nonlocal NUM, t2_cancel
 | 
						|
            try:
 | 
						|
                await asyncio.sleep(5)
 | 
						|
            except asyncio.CancelledError:
 | 
						|
                t2_cancel = True
 | 
						|
                raise
 | 
						|
            NUM += 1
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            nonlocal NUM, runner_cancel
 | 
						|
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                g.create_task(foo1())
 | 
						|
                g.create_task(foo1())
 | 
						|
                g.create_task(foo1())
 | 
						|
                g.create_task(foo2())
 | 
						|
                try:
 | 
						|
                    await asyncio.sleep(10)
 | 
						|
                except asyncio.CancelledError:
 | 
						|
                    runner_cancel = True
 | 
						|
                    raise
 | 
						|
 | 
						|
            NUM += 10
 | 
						|
 | 
						|
        # The 3 foo1 sub tasks can be racy when the host is busy - if the
 | 
						|
        # cancellation happens in the middle, we'll see partial sub errors here
 | 
						|
        with self.assertRaises(ExceptionGroup) as cm:
 | 
						|
            await asyncio.create_task(runner())
 | 
						|
 | 
						|
        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | 
						|
        self.assertEqual(NUM, 0)
 | 
						|
        self.assertTrue(t2_cancel)
 | 
						|
        self.assertTrue(runner_cancel)
 | 
						|
 | 
						|
    async def test_cancellation(self):
 | 
						|
 | 
						|
        NUM = 0
 | 
						|
 | 
						|
        async def foo():
 | 
						|
            nonlocal NUM
 | 
						|
            try:
 | 
						|
                await asyncio.sleep(5)
 | 
						|
            except asyncio.CancelledError:
 | 
						|
                NUM += 1
 | 
						|
                raise
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                for _ in range(5):
 | 
						|
                    g.create_task(foo())
 | 
						|
 | 
						|
        r = asyncio.create_task(runner())
 | 
						|
        await asyncio.sleep(0.1)
 | 
						|
 | 
						|
        self.assertFalse(r.done())
 | 
						|
        r.cancel()
 | 
						|
        with self.assertRaises(asyncio.CancelledError) as cm:
 | 
						|
            await r
 | 
						|
 | 
						|
        self.assertEqual(NUM, 5)
 | 
						|
 | 
						|
    async def test_taskgroup_07(self):
 | 
						|
 | 
						|
        NUM = 0
 | 
						|
 | 
						|
        async def foo():
 | 
						|
            nonlocal NUM
 | 
						|
            try:
 | 
						|
                await asyncio.sleep(5)
 | 
						|
            except asyncio.CancelledError:
 | 
						|
                NUM += 1
 | 
						|
                raise
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            nonlocal NUM
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                for _ in range(5):
 | 
						|
                    g.create_task(foo())
 | 
						|
 | 
						|
                try:
 | 
						|
                    await asyncio.sleep(10)
 | 
						|
                except asyncio.CancelledError:
 | 
						|
                    NUM += 10
 | 
						|
                    raise
 | 
						|
 | 
						|
        r = asyncio.create_task(runner())
 | 
						|
        await asyncio.sleep(0.1)
 | 
						|
 | 
						|
        self.assertFalse(r.done())
 | 
						|
        r.cancel()
 | 
						|
        with self.assertRaises(asyncio.CancelledError):
 | 
						|
            await r
 | 
						|
 | 
						|
        self.assertEqual(NUM, 15)
 | 
						|
 | 
						|
    async def test_taskgroup_08(self):
 | 
						|
 | 
						|
        async def foo():
 | 
						|
            try:
 | 
						|
                await asyncio.sleep(10)
 | 
						|
            finally:
 | 
						|
                1 / 0
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                for _ in range(5):
 | 
						|
                    g.create_task(foo())
 | 
						|
 | 
						|
                await asyncio.sleep(10)
 | 
						|
 | 
						|
        r = asyncio.create_task(runner())
 | 
						|
        await asyncio.sleep(0.1)
 | 
						|
 | 
						|
        self.assertFalse(r.done())
 | 
						|
        r.cancel()
 | 
						|
        with self.assertRaises(ExceptionGroup) as cm:
 | 
						|
            await r
 | 
						|
        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | 
						|
 | 
						|
    async def test_taskgroup_09(self):
 | 
						|
 | 
						|
        t1 = t2 = None
 | 
						|
 | 
						|
        async def foo1():
 | 
						|
            await asyncio.sleep(1)
 | 
						|
            return 42
 | 
						|
 | 
						|
        async def foo2():
 | 
						|
            await asyncio.sleep(2)
 | 
						|
            return 11
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            nonlocal t1, t2
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                t1 = g.create_task(foo1())
 | 
						|
                t2 = g.create_task(foo2())
 | 
						|
                await asyncio.sleep(0.1)
 | 
						|
                1 / 0
 | 
						|
 | 
						|
        try:
 | 
						|
            await runner()
 | 
						|
        except ExceptionGroup as t:
 | 
						|
            self.assertEqual(get_error_types(t), {ZeroDivisionError})
 | 
						|
        else:
 | 
						|
            self.fail('ExceptionGroup was not raised')
 | 
						|
 | 
						|
        self.assertTrue(t1.cancelled())
 | 
						|
        self.assertTrue(t2.cancelled())
 | 
						|
 | 
						|
    async def test_taskgroup_10(self):
 | 
						|
 | 
						|
        t1 = t2 = None
 | 
						|
 | 
						|
        async def foo1():
 | 
						|
            await asyncio.sleep(1)
 | 
						|
            return 42
 | 
						|
 | 
						|
        async def foo2():
 | 
						|
            await asyncio.sleep(2)
 | 
						|
            return 11
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            nonlocal t1, t2
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                t1 = g.create_task(foo1())
 | 
						|
                t2 = g.create_task(foo2())
 | 
						|
                1 / 0
 | 
						|
 | 
						|
        try:
 | 
						|
            await runner()
 | 
						|
        except ExceptionGroup as t:
 | 
						|
            self.assertEqual(get_error_types(t), {ZeroDivisionError})
 | 
						|
        else:
 | 
						|
            self.fail('ExceptionGroup was not raised')
 | 
						|
 | 
						|
        self.assertTrue(t1.cancelled())
 | 
						|
        self.assertTrue(t2.cancelled())
 | 
						|
 | 
						|
    async def test_taskgroup_11(self):
 | 
						|
 | 
						|
        async def foo():
 | 
						|
            try:
 | 
						|
                await asyncio.sleep(10)
 | 
						|
            finally:
 | 
						|
                1 / 0
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup():
 | 
						|
                async with taskgroups.TaskGroup() as g2:
 | 
						|
                    for _ in range(5):
 | 
						|
                        g2.create_task(foo())
 | 
						|
 | 
						|
                    await asyncio.sleep(10)
 | 
						|
 | 
						|
        r = asyncio.create_task(runner())
 | 
						|
        await asyncio.sleep(0.1)
 | 
						|
 | 
						|
        self.assertFalse(r.done())
 | 
						|
        r.cancel()
 | 
						|
        with self.assertRaises(ExceptionGroup) as cm:
 | 
						|
            await r
 | 
						|
 | 
						|
        self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
 | 
						|
        self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
 | 
						|
 | 
						|
    async def test_taskgroup_12(self):
 | 
						|
 | 
						|
        async def foo():
 | 
						|
            try:
 | 
						|
                await asyncio.sleep(10)
 | 
						|
            finally:
 | 
						|
                1 / 0
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g1:
 | 
						|
                g1.create_task(asyncio.sleep(10))
 | 
						|
 | 
						|
                async with taskgroups.TaskGroup() as g2:
 | 
						|
                    for _ in range(5):
 | 
						|
                        g2.create_task(foo())
 | 
						|
 | 
						|
                    await asyncio.sleep(10)
 | 
						|
 | 
						|
        r = asyncio.create_task(runner())
 | 
						|
        await asyncio.sleep(0.1)
 | 
						|
 | 
						|
        self.assertFalse(r.done())
 | 
						|
        r.cancel()
 | 
						|
        with self.assertRaises(ExceptionGroup) as cm:
 | 
						|
            await r
 | 
						|
 | 
						|
        self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
 | 
						|
        self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
 | 
						|
 | 
						|
    async def test_taskgroup_13(self):
 | 
						|
 | 
						|
        async def crash_after(t):
 | 
						|
            await asyncio.sleep(t)
 | 
						|
            raise ValueError(t)
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g1:
 | 
						|
                g1.create_task(crash_after(0.1))
 | 
						|
 | 
						|
                async with taskgroups.TaskGroup() as g2:
 | 
						|
                    g2.create_task(crash_after(10))
 | 
						|
 | 
						|
        r = asyncio.create_task(runner())
 | 
						|
        with self.assertRaises(ExceptionGroup) as cm:
 | 
						|
            await r
 | 
						|
 | 
						|
        self.assertEqual(get_error_types(cm.exception), {ValueError})
 | 
						|
 | 
						|
    async def test_taskgroup_14(self):
 | 
						|
 | 
						|
        async def crash_after(t):
 | 
						|
            await asyncio.sleep(t)
 | 
						|
            raise ValueError(t)
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g1:
 | 
						|
                g1.create_task(crash_after(10))
 | 
						|
 | 
						|
                async with taskgroups.TaskGroup() as g2:
 | 
						|
                    g2.create_task(crash_after(0.1))
 | 
						|
 | 
						|
        r = asyncio.create_task(runner())
 | 
						|
        with self.assertRaises(ExceptionGroup) as cm:
 | 
						|
            await r
 | 
						|
 | 
						|
        self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
 | 
						|
        self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError})
 | 
						|
 | 
						|
    async def test_taskgroup_15(self):
 | 
						|
 | 
						|
        async def crash_soon():
 | 
						|
            await asyncio.sleep(0.3)
 | 
						|
            1 / 0
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g1:
 | 
						|
                g1.create_task(crash_soon())
 | 
						|
                try:
 | 
						|
                    await asyncio.sleep(10)
 | 
						|
                except asyncio.CancelledError:
 | 
						|
                    await asyncio.sleep(0.5)
 | 
						|
                    raise
 | 
						|
 | 
						|
        r = asyncio.create_task(runner())
 | 
						|
        await asyncio.sleep(0.1)
 | 
						|
 | 
						|
        self.assertFalse(r.done())
 | 
						|
        r.cancel()
 | 
						|
        with self.assertRaises(ExceptionGroup) as cm:
 | 
						|
            await r
 | 
						|
        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | 
						|
 | 
						|
    async def test_taskgroup_16(self):
 | 
						|
 | 
						|
        async def crash_soon():
 | 
						|
            await asyncio.sleep(0.3)
 | 
						|
            1 / 0
 | 
						|
 | 
						|
        async def nested_runner():
 | 
						|
            async with taskgroups.TaskGroup() as g1:
 | 
						|
                g1.create_task(crash_soon())
 | 
						|
                try:
 | 
						|
                    await asyncio.sleep(10)
 | 
						|
                except asyncio.CancelledError:
 | 
						|
                    await asyncio.sleep(0.5)
 | 
						|
                    raise
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            t = asyncio.create_task(nested_runner())
 | 
						|
            await t
 | 
						|
 | 
						|
        r = asyncio.create_task(runner())
 | 
						|
        await asyncio.sleep(0.1)
 | 
						|
 | 
						|
        self.assertFalse(r.done())
 | 
						|
        r.cancel()
 | 
						|
        with self.assertRaises(ExceptionGroup) as cm:
 | 
						|
            await r
 | 
						|
        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | 
						|
 | 
						|
    async def test_taskgroup_17(self):
 | 
						|
        NUM = 0
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            nonlocal NUM
 | 
						|
            async with taskgroups.TaskGroup():
 | 
						|
                try:
 | 
						|
                    await asyncio.sleep(10)
 | 
						|
                except asyncio.CancelledError:
 | 
						|
                    NUM += 10
 | 
						|
                    raise
 | 
						|
 | 
						|
        r = asyncio.create_task(runner())
 | 
						|
        await asyncio.sleep(0.1)
 | 
						|
 | 
						|
        self.assertFalse(r.done())
 | 
						|
        r.cancel()
 | 
						|
        with self.assertRaises(asyncio.CancelledError):
 | 
						|
            await r
 | 
						|
 | 
						|
        self.assertEqual(NUM, 10)
 | 
						|
 | 
						|
    async def test_taskgroup_18(self):
 | 
						|
        NUM = 0
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            nonlocal NUM
 | 
						|
            async with taskgroups.TaskGroup():
 | 
						|
                try:
 | 
						|
                    await asyncio.sleep(10)
 | 
						|
                except asyncio.CancelledError:
 | 
						|
                    NUM += 10
 | 
						|
                    # This isn't a good idea, but we have to support
 | 
						|
                    # this weird case.
 | 
						|
                    raise MyExc
 | 
						|
 | 
						|
        r = asyncio.create_task(runner())
 | 
						|
        await asyncio.sleep(0.1)
 | 
						|
 | 
						|
        self.assertFalse(r.done())
 | 
						|
        r.cancel()
 | 
						|
 | 
						|
        try:
 | 
						|
            await r
 | 
						|
        except ExceptionGroup as t:
 | 
						|
            self.assertEqual(get_error_types(t),{MyExc})
 | 
						|
        else:
 | 
						|
            self.fail('ExceptionGroup was not raised')
 | 
						|
 | 
						|
        self.assertEqual(NUM, 10)
 | 
						|
 | 
						|
    async def test_taskgroup_19(self):
 | 
						|
        async def crash_soon():
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            1 / 0
 | 
						|
 | 
						|
        async def nested():
 | 
						|
            try:
 | 
						|
                await asyncio.sleep(10)
 | 
						|
            finally:
 | 
						|
                raise MyExc
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                g.create_task(crash_soon())
 | 
						|
                await nested()
 | 
						|
 | 
						|
        r = asyncio.create_task(runner())
 | 
						|
        try:
 | 
						|
            await r
 | 
						|
        except ExceptionGroup as t:
 | 
						|
            self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError})
 | 
						|
        else:
 | 
						|
            self.fail('TasgGroupError was not raised')
 | 
						|
 | 
						|
    async def test_taskgroup_20(self):
 | 
						|
        async def crash_soon():
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            1 / 0
 | 
						|
 | 
						|
        async def nested():
 | 
						|
            try:
 | 
						|
                await asyncio.sleep(10)
 | 
						|
            finally:
 | 
						|
                raise KeyboardInterrupt
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                g.create_task(crash_soon())
 | 
						|
                await nested()
 | 
						|
 | 
						|
        with self.assertRaises(KeyboardInterrupt):
 | 
						|
            await runner()
 | 
						|
 | 
						|
    async def test_taskgroup_20a(self):
 | 
						|
        async def crash_soon():
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            1 / 0
 | 
						|
 | 
						|
        async def nested():
 | 
						|
            try:
 | 
						|
                await asyncio.sleep(10)
 | 
						|
            finally:
 | 
						|
                raise MyBaseExc
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                g.create_task(crash_soon())
 | 
						|
                await nested()
 | 
						|
 | 
						|
        with self.assertRaises(BaseExceptionGroup) as cm:
 | 
						|
            await runner()
 | 
						|
 | 
						|
        self.assertEqual(
 | 
						|
            get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError}
 | 
						|
        )
 | 
						|
 | 
						|
    async def _test_taskgroup_21(self):
 | 
						|
        # This test doesn't work as asyncio, currently, doesn't
 | 
						|
        # correctly propagate KeyboardInterrupt (or SystemExit) --
 | 
						|
        # those cause the event loop itself to crash.
 | 
						|
        # (Compare to the previous (passing) test -- that one raises
 | 
						|
        # a plain exception but raises KeyboardInterrupt in nested();
 | 
						|
        # this test does it the other way around.)
 | 
						|
 | 
						|
        async def crash_soon():
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            raise KeyboardInterrupt
 | 
						|
 | 
						|
        async def nested():
 | 
						|
            try:
 | 
						|
                await asyncio.sleep(10)
 | 
						|
            finally:
 | 
						|
                raise TypeError
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                g.create_task(crash_soon())
 | 
						|
                await nested()
 | 
						|
 | 
						|
        with self.assertRaises(KeyboardInterrupt):
 | 
						|
            await runner()
 | 
						|
 | 
						|
    async def test_taskgroup_21a(self):
 | 
						|
 | 
						|
        async def crash_soon():
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            raise MyBaseExc
 | 
						|
 | 
						|
        async def nested():
 | 
						|
            try:
 | 
						|
                await asyncio.sleep(10)
 | 
						|
            finally:
 | 
						|
                raise TypeError
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                g.create_task(crash_soon())
 | 
						|
                await nested()
 | 
						|
 | 
						|
        with self.assertRaises(BaseExceptionGroup) as cm:
 | 
						|
            await runner()
 | 
						|
 | 
						|
        self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError})
 | 
						|
 | 
						|
    async def test_taskgroup_22(self):
 | 
						|
 | 
						|
        async def foo1():
 | 
						|
            await asyncio.sleep(1)
 | 
						|
            return 42
 | 
						|
 | 
						|
        async def foo2():
 | 
						|
            await asyncio.sleep(2)
 | 
						|
            return 11
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                g.create_task(foo1())
 | 
						|
                g.create_task(foo2())
 | 
						|
 | 
						|
        r = asyncio.create_task(runner())
 | 
						|
        await asyncio.sleep(0.05)
 | 
						|
        r.cancel()
 | 
						|
 | 
						|
        with self.assertRaises(asyncio.CancelledError):
 | 
						|
            await r
 | 
						|
 | 
						|
    async def test_taskgroup_23(self):
 | 
						|
 | 
						|
        async def do_job(delay):
 | 
						|
            await asyncio.sleep(delay)
 | 
						|
 | 
						|
        async with taskgroups.TaskGroup() as g:
 | 
						|
            for count in range(10):
 | 
						|
                await asyncio.sleep(0.1)
 | 
						|
                g.create_task(do_job(0.3))
 | 
						|
                if count == 5:
 | 
						|
                    self.assertLess(len(g._tasks), 5)
 | 
						|
            await asyncio.sleep(1.35)
 | 
						|
            self.assertEqual(len(g._tasks), 0)
 | 
						|
 | 
						|
    async def test_taskgroup_24(self):
 | 
						|
 | 
						|
        async def root(g):
 | 
						|
            await asyncio.sleep(0.1)
 | 
						|
            g.create_task(coro1(0.1))
 | 
						|
            g.create_task(coro1(0.2))
 | 
						|
 | 
						|
        async def coro1(delay):
 | 
						|
            await asyncio.sleep(delay)
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                g.create_task(root(g))
 | 
						|
 | 
						|
        await runner()
 | 
						|
 | 
						|
    async def test_taskgroup_25(self):
 | 
						|
        nhydras = 0
 | 
						|
 | 
						|
        async def hydra(g):
 | 
						|
            nonlocal nhydras
 | 
						|
            nhydras += 1
 | 
						|
            await asyncio.sleep(0.01)
 | 
						|
            g.create_task(hydra(g))
 | 
						|
            g.create_task(hydra(g))
 | 
						|
 | 
						|
        async def hercules():
 | 
						|
            while nhydras < 10:
 | 
						|
                await asyncio.sleep(0.015)
 | 
						|
            1 / 0
 | 
						|
 | 
						|
        async def runner():
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                g.create_task(hydra(g))
 | 
						|
                g.create_task(hercules())
 | 
						|
 | 
						|
        with self.assertRaises(ExceptionGroup) as cm:
 | 
						|
            await runner()
 | 
						|
 | 
						|
        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | 
						|
        self.assertGreaterEqual(nhydras, 10)
 | 
						|
 | 
						|
    async def test_taskgroup_task_name(self):
 | 
						|
        async def coro():
 | 
						|
            await asyncio.sleep(0)
 | 
						|
        async with taskgroups.TaskGroup() as g:
 | 
						|
            t = g.create_task(coro(), name="yolo")
 | 
						|
            self.assertEqual(t.get_name(), "yolo")
 | 
						|
 | 
						|
    async def test_taskgroup_task_context(self):
 | 
						|
        cvar = contextvars.ContextVar('cvar')
 | 
						|
 | 
						|
        async def coro(val):
 | 
						|
            await asyncio.sleep(0)
 | 
						|
            cvar.set(val)
 | 
						|
 | 
						|
        async with taskgroups.TaskGroup() as g:
 | 
						|
            ctx = contextvars.copy_context()
 | 
						|
            self.assertIsNone(ctx.get(cvar))
 | 
						|
            t1 = g.create_task(coro(1), context=ctx)
 | 
						|
            await t1
 | 
						|
            self.assertEqual(1, ctx.get(cvar))
 | 
						|
            t2 = g.create_task(coro(2), context=ctx)
 | 
						|
            await t2
 | 
						|
            self.assertEqual(2, ctx.get(cvar))
 | 
						|
 | 
						|
    async def test_taskgroup_no_create_task_after_failure(self):
 | 
						|
        async def coro1():
 | 
						|
            await asyncio.sleep(0.001)
 | 
						|
            1 / 0
 | 
						|
        async def coro2(g):
 | 
						|
            try:
 | 
						|
                await asyncio.sleep(1)
 | 
						|
            except asyncio.CancelledError:
 | 
						|
                with self.assertRaises(RuntimeError):
 | 
						|
                    g.create_task(c1 := coro1())
 | 
						|
                # We still have to await c1 to avoid a warning
 | 
						|
                with self.assertRaises(ZeroDivisionError):
 | 
						|
                    await c1
 | 
						|
 | 
						|
        with self.assertRaises(ExceptionGroup) as cm:
 | 
						|
            async with taskgroups.TaskGroup() as g:
 | 
						|
                g.create_task(coro1())
 | 
						|
                g.create_task(coro2(g))
 | 
						|
 | 
						|
        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | 
						|
 | 
						|
    async def test_taskgroup_context_manager_exit_raises(self):
 | 
						|
        # See https://github.com/python/cpython/issues/95289
 | 
						|
        class CustomException(Exception):
 | 
						|
            pass
 | 
						|
 | 
						|
        async def raise_exc():
 | 
						|
            raise CustomException
 | 
						|
 | 
						|
        @contextlib.asynccontextmanager
 | 
						|
        async def database():
 | 
						|
            try:
 | 
						|
                yield
 | 
						|
            finally:
 | 
						|
                raise CustomException
 | 
						|
 | 
						|
        async def main():
 | 
						|
            task = asyncio.current_task()
 | 
						|
            try:
 | 
						|
                async with taskgroups.TaskGroup() as tg:
 | 
						|
                    async with database():
 | 
						|
                        tg.create_task(raise_exc())
 | 
						|
                        await asyncio.sleep(1)
 | 
						|
            except* CustomException as err:
 | 
						|
                self.assertEqual(task.cancelling(), 0)
 | 
						|
                self.assertEqual(len(err.exceptions), 2)
 | 
						|
 | 
						|
            else:
 | 
						|
                self.fail('CustomException not raised')
 | 
						|
 | 
						|
        await asyncio.create_task(main())
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Run this module's tests when the file is executed as a script.
    unittest.main()
 |