| 
									
										
										
										
											2022-08-11 16:12:06 -07:00
										 |  |  | # Adapted with permission from the EdgeDB project; | 
					
						
							|  |  |  | # license: PSFL. | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
# Names exported by ``from <module> import *``; TaskGroup is the only one.
__all__ = ["TaskGroup"]
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from . import events | 
					
						
							|  |  |  | from . import exceptions | 
					
						
							|  |  |  | from . import tasks | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-20 12:07:00 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  | class TaskGroup: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-17 21:30:44 -08:00
										 |  |  |     def __init__(self): | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |         self._entered = False | 
					
						
							|  |  |  |         self._exiting = False | 
					
						
							|  |  |  |         self._aborting = False | 
					
						
							|  |  |  |         self._loop = None | 
					
						
							|  |  |  |         self._parent_task = None | 
					
						
							|  |  |  |         self._parent_cancel_requested = False | 
					
						
							| 
									
										
										
										
											2022-05-27 15:20:21 -07:00
										 |  |  |         self._tasks = set() | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |         self._errors = [] | 
					
						
							|  |  |  |         self._base_error = None | 
					
						
							|  |  |  |         self._on_completed_fut = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __repr__(self): | 
					
						
							| 
									
										
										
										
											2022-02-20 12:07:00 +02:00
										 |  |  |         info = [''] | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |         if self._tasks: | 
					
						
							| 
									
										
										
										
											2022-02-20 12:07:00 +02:00
										 |  |  |             info.append(f'tasks={len(self._tasks)}') | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |         if self._errors: | 
					
						
							| 
									
										
										
										
											2022-02-20 12:07:00 +02:00
										 |  |  |             info.append(f'errors={len(self._errors)}') | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |         if self._aborting: | 
					
						
							| 
									
										
										
										
											2022-02-20 12:07:00 +02:00
										 |  |  |             info.append('cancelling') | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |         elif self._entered: | 
					
						
							| 
									
										
										
										
											2022-02-20 12:07:00 +02:00
										 |  |  |             info.append('entered') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         info_str = ' '.join(info) | 
					
						
							|  |  |  |         return f'<TaskGroup{info_str}>' | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     async def __aenter__(self): | 
					
						
							|  |  |  |         if self._entered: | 
					
						
							|  |  |  |             raise RuntimeError( | 
					
						
							|  |  |  |                 f"TaskGroup {self!r} has been already entered") | 
					
						
							|  |  |  |         self._entered = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if self._loop is None: | 
					
						
							|  |  |  |             self._loop = events.get_running_loop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self._parent_task = tasks.current_task(self._loop) | 
					
						
							|  |  |  |         if self._parent_task is None: | 
					
						
							|  |  |  |             raise RuntimeError( | 
					
						
							|  |  |  |                 f'TaskGroup {self!r} cannot determine the parent task') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return self | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def __aexit__(self, et, exc, tb): | 
					
						
							|  |  |  |         self._exiting = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (exc is not None and | 
					
						
							|  |  |  |                 self._is_base_error(exc) and | 
					
						
							|  |  |  |                 self._base_error is None): | 
					
						
							|  |  |  |             self._base_error = exc | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-08-04 19:27:44 +05:30
										 |  |  |         propagate_cancellation_error = \ | 
					
						
							|  |  |  |             exc if et is exceptions.CancelledError else None | 
					
						
							|  |  |  |         if self._parent_cancel_requested: | 
					
						
							|  |  |  |             # If this flag is set we *must* call uncancel(). | 
					
						
							|  |  |  |             if self._parent_task.uncancel() == 0: | 
					
						
							|  |  |  |                 # If there are no pending cancellations left, | 
					
						
							|  |  |  |                 # don't propagate CancelledError. | 
					
						
							|  |  |  |                 propagate_cancellation_error = None | 
					
						
							| 
									
										
										
										
											2022-02-26 17:18:48 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-08-04 19:27:44 +05:30
										 |  |  |         if et is not None: | 
					
						
							| 
									
										
										
										
											2022-02-26 17:18:48 +01:00
										 |  |  |             if not self._aborting: | 
					
						
							|  |  |  |                 # Our parent task is being cancelled: | 
					
						
							|  |  |  |                 # | 
					
						
							|  |  |  |                 #    async with TaskGroup() as g: | 
					
						
							|  |  |  |                 #        g.create_task(...) | 
					
						
							|  |  |  |                 #        await ...  # <- CancelledError | 
					
						
							|  |  |  |                 # | 
					
						
							|  |  |  |                 # or there's an exception in "async with": | 
					
						
							|  |  |  |                 # | 
					
						
							|  |  |  |                 #    async with TaskGroup() as g: | 
					
						
							|  |  |  |                 #        g.create_task(...) | 
					
						
							|  |  |  |                 #        1 / 0 | 
					
						
							|  |  |  |                 # | 
					
						
							|  |  |  |                 self._abort() | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # We use while-loop here because "self._on_completed_fut" | 
					
						
							|  |  |  |         # can be cancelled multiple times if our parent task | 
					
						
							|  |  |  |         # is being cancelled repeatedly (or even once, when | 
					
						
							|  |  |  |         # our own cancellation is already in progress) | 
					
						
							| 
									
										
										
										
											2022-05-27 15:20:21 -07:00
										 |  |  |         while self._tasks: | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |             if self._on_completed_fut is None: | 
					
						
							|  |  |  |                 self._on_completed_fut = self._loop.create_future() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 await self._on_completed_fut | 
					
						
							|  |  |  |             except exceptions.CancelledError as ex: | 
					
						
							|  |  |  |                 if not self._aborting: | 
					
						
							|  |  |  |                     # Our parent task is being cancelled: | 
					
						
							|  |  |  |                     # | 
					
						
							|  |  |  |                     #    async def wrapper(): | 
					
						
							|  |  |  |                     #        async with TaskGroup() as g: | 
					
						
							|  |  |  |                     #            g.create_task(foo) | 
					
						
							|  |  |  |                     # | 
					
						
							|  |  |  |                     # "wrapper" is being cancelled while "foo" is | 
					
						
							|  |  |  |                     # still running. | 
					
						
							|  |  |  |                     propagate_cancellation_error = ex | 
					
						
							|  |  |  |                     self._abort() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             self._on_completed_fut = None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-05-27 15:20:21 -07:00
										 |  |  |         assert not self._tasks | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if self._base_error is not None: | 
					
						
							|  |  |  |             raise self._base_error | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-08-16 18:23:06 -07:00
										 |  |  |         # Propagate CancelledError if there is one, except if there | 
					
						
							|  |  |  |         # are other errors -- those have priority. | 
					
						
							|  |  |  |         if propagate_cancellation_error and not self._errors: | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |             raise propagate_cancellation_error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if et is not None and et is not exceptions.CancelledError: | 
					
						
							|  |  |  |             self._errors.append(exc) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if self._errors: | 
					
						
							|  |  |  |             # Exceptions are heavy objects that can have object | 
					
						
							|  |  |  |             # cycles (bad for GC); let's not keep a reference to | 
					
						
							|  |  |  |             # a bunch of them. | 
					
						
							| 
									
										
										
										
											2022-10-22 21:35:11 +05:30
										 |  |  |             try: | 
					
						
							|  |  |  |                 me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors) | 
					
						
							|  |  |  |                 raise me from None | 
					
						
							|  |  |  |             finally: | 
					
						
							|  |  |  |                 self._errors = None | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-03-14 13:54:13 +02:00
										 |  |  |     def create_task(self, coro, *, name=None, context=None): | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |         if not self._entered: | 
					
						
							|  |  |  |             raise RuntimeError(f"TaskGroup {self!r} has not been entered") | 
					
						
							| 
									
										
										
										
											2022-05-27 15:20:21 -07:00
										 |  |  |         if self._exiting and not self._tasks: | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |             raise RuntimeError(f"TaskGroup {self!r} is finished") | 
					
						
							| 
									
										
										
										
											2022-06-30 10:10:46 -07:00
										 |  |  |         if self._aborting: | 
					
						
							|  |  |  |             raise RuntimeError(f"TaskGroup {self!r} is shutting down") | 
					
						
							| 
									
										
										
										
											2022-03-14 13:54:13 +02:00
										 |  |  |         if context is None: | 
					
						
							|  |  |  |             task = self._loop.create_task(coro) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             task = self._loop.create_task(coro, context=context) | 
					
						
							| 
									
										
										
										
											2022-02-17 21:30:44 -08:00
										 |  |  |         tasks._set_task_name(task, name) | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |         task.add_done_callback(self._on_task_done) | 
					
						
							|  |  |  |         self._tasks.add(task) | 
					
						
							|  |  |  |         return task | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Since Python 3.8 Tasks propagate all exceptions correctly, | 
					
						
							|  |  |  |     # except for KeyboardInterrupt and SystemExit which are | 
					
						
							|  |  |  |     # still considered special. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _is_base_error(self, exc: BaseException) -> bool: | 
					
						
							|  |  |  |         assert isinstance(exc, BaseException) | 
					
						
							|  |  |  |         return isinstance(exc, (SystemExit, KeyboardInterrupt)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _abort(self): | 
					
						
							|  |  |  |         self._aborting = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for t in self._tasks: | 
					
						
							|  |  |  |             if not t.done(): | 
					
						
							|  |  |  |                 t.cancel() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _on_task_done(self, task): | 
					
						
							| 
									
										
										
										
											2022-05-27 15:20:21 -07:00
										 |  |  |         self._tasks.discard(task) | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-05-27 15:20:21 -07:00
										 |  |  |         if self._on_completed_fut is not None and not self._tasks: | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |             if not self._on_completed_fut.done(): | 
					
						
							|  |  |  |                 self._on_completed_fut.set_result(True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if task.cancelled(): | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         exc = task.exception() | 
					
						
							|  |  |  |         if exc is None: | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self._errors.append(exc) | 
					
						
							|  |  |  |         if self._is_base_error(exc) and self._base_error is None: | 
					
						
							|  |  |  |             self._base_error = exc | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if self._parent_task.done(): | 
					
						
							|  |  |  |             # Not sure if this case is possible, but we want to handle | 
					
						
							|  |  |  |             # it anyways. | 
					
						
							|  |  |  |             self._loop.call_exception_handler({ | 
					
						
							|  |  |  |                 'message': f'Task {task!r} has errored out but its parent ' | 
					
						
							|  |  |  |                            f'task {self._parent_task} is already completed', | 
					
						
							|  |  |  |                 'exception': exc, | 
					
						
							|  |  |  |                 'task': task, | 
					
						
							|  |  |  |             }) | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-26 17:18:48 +01:00
										 |  |  |         if not self._aborting and not self._parent_cancel_requested: | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |             # If parent task *is not* being cancelled, it means that we want | 
					
						
							|  |  |  |             # to manually cancel it to abort whatever is being run right now | 
					
						
							|  |  |  |             # in the TaskGroup.  But we want to mark parent task as | 
					
						
							|  |  |  |             # "not cancelled" later in __aexit__.  Example situation that | 
					
						
							|  |  |  |             # we need to handle: | 
					
						
							|  |  |  |             # | 
					
						
							|  |  |  |             #    async def foo(): | 
					
						
							|  |  |  |             #        try: | 
					
						
							|  |  |  |             #            async with TaskGroup() as g: | 
					
						
							|  |  |  |             #                g.create_task(crash_soon()) | 
					
						
							|  |  |  |             #                await something  # <- this needs to be canceled | 
					
						
							|  |  |  |             #                                 #    by the TaskGroup, e.g. | 
					
						
							|  |  |  |             #                                 #    foo() needs to be cancelled | 
					
						
							|  |  |  |             #        except Exception: | 
					
						
							|  |  |  |             #            # Ignore any exceptions raised in the TaskGroup | 
					
						
							|  |  |  |             #            pass | 
					
						
							|  |  |  |             #        await something_else     # this line has to be called | 
					
						
							|  |  |  |             #                                 # after TaskGroup is finished. | 
					
						
							| 
									
										
										
										
											2022-02-26 17:18:48 +01:00
										 |  |  |             self._abort() | 
					
						
							| 
									
										
										
										
											2022-02-15 15:42:04 -08:00
										 |  |  |             self._parent_cancel_requested = True | 
					
						
							|  |  |  |             self._parent_task.cancel() |