gh-144386: Add support for descriptors in ExitStack and AsyncExitStack (#144420)

__enter__(), __exit__(), __aenter__(), and __aexit__() can now be
arbitrary descriptors, not only normal methods, for consistency with the
"with" and "async with" statements.
This commit is contained in:
Serhiy Storchaka 2026-02-04 13:20:18 +02:00 committed by GitHub
parent 34e5a63f14
commit f73d2e7003
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 209 additions and 58 deletions

View file

@ -564,6 +564,10 @@ Functions and classes provided:
Raises :exc:`TypeError` instead of :exc:`AttributeError` if *cm*
is not a context manager.
.. versionchanged:: next
Added support for arbitrary descriptors :meth:`!__enter__` and
:meth:`!__exit__`.
.. method:: push(exit)
Adds a context manager's :meth:`~object.__exit__` method to the callback stack.
@ -582,6 +586,9 @@ Functions and classes provided:
The passed in object is returned from the function, allowing this
method to be used as a function decorator.
.. versionchanged:: next
Added support for arbitrary descriptors :meth:`!__exit__`.
.. method:: callback(callback, /, *args, **kwds)
Accepts an arbitrary callback function and arguments and adds it to
@ -639,11 +646,17 @@ Functions and classes provided:
Raises :exc:`TypeError` instead of :exc:`AttributeError` if *cm*
is not an asynchronous context manager.
.. versionchanged:: next
Added support for arbitrary descriptors :meth:`!__aenter__` and :meth:`!__aexit__`.
.. method:: push_async_exit(exit)
Similar to :meth:`ExitStack.push` but expects either an asynchronous context manager
or a coroutine function.
.. versionchanged:: next
Added support for arbitrary descriptors :meth:`!__aexit__`.
.. method:: push_async_callback(callback, /, *args, **kwds)
Similar to :meth:`ExitStack.callback` but expects a coroutine function.

View file

@ -548,6 +548,16 @@ concurrent.futures
(Contributed by Jonathan Berg in :gh:`139486`.)
contextlib
----------
* Added support for arbitrary descriptors :meth:`!__enter__`,
:meth:`!__exit__`, :meth:`!__aenter__`, and :meth:`!__aexit__` in
:class:`~contextlib.ExitStack` and :class:`contextlib.AsyncExitStack`, for
consistency with the :keyword:`with` and :keyword:`async with` statements.
(Contributed by Serhiy Storchaka in :gh:`144386`.)
dataclasses
-----------

View file

@ -5,7 +5,7 @@
import _collections_abc
from collections import deque
from functools import wraps
from types import MethodType, GenericAlias
from types import GenericAlias
__all__ = ["asynccontextmanager", "contextmanager", "closing", "nullcontext",
"AbstractContextManager", "AbstractAsyncContextManager",
@ -469,13 +469,23 @@ def __exit__(self, exctype, excinst, exctb):
return False
def _lookup_special(obj, name, default):
# Follow the standard lookup behaviour for special methods.
from inspect import getattr_static, _descriptor_get
cls = type(obj)
try:
descr = getattr_static(cls, name)
except AttributeError:
return default
return _descriptor_get(descr, obj)
_sentinel = ['SENTINEL']
class _BaseExitStack:
"""A base class for ExitStack and AsyncExitStack."""
@staticmethod
def _create_exit_wrapper(cm, cm_exit):
return MethodType(cm_exit, cm)
@staticmethod
def _create_cb_wrapper(callback, /, *args, **kwds):
def _exit_wrapper(exc_type, exc, tb):
@ -499,17 +509,8 @@ def push(self, exit):
Also accepts any object with an __exit__ method (registering a call
to the method instead of the object itself).
"""
# We use an unbound method rather than a bound method to follow
# the standard lookup behaviour for special methods.
_cb_type = type(exit)
try:
exit_method = _cb_type.__exit__
except AttributeError:
# Not a context manager, so assume it's a callable.
self._push_exit_callback(exit)
else:
self._push_cm_exit(exit, exit_method)
exit_method = _lookup_special(exit, '__exit__', exit)
self._push_exit_callback(exit_method)
return exit # Allow use as a decorator.
def enter_context(self, cm):
@ -518,17 +519,18 @@ def enter_context(self, cm):
If successful, also pushes its __exit__ method as a callback and
returns the result of the __enter__ method.
"""
# We look up the special methods on the type to match the with
# statement.
cls = type(cm)
try:
_enter = cls.__enter__
_exit = cls.__exit__
except AttributeError:
_enter = _lookup_special(cm, '__enter__', _sentinel)
if _enter is _sentinel:
cls = type(cm)
raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
f"not support the context manager protocol") from None
result = _enter(cm)
self._push_cm_exit(cm, _exit)
f"not support the context manager protocol")
_exit = _lookup_special(cm, '__exit__', _sentinel)
if _exit is _sentinel:
cls = type(cm)
raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
f"not support the context manager protocol")
result = _enter()
self._push_exit_callback(_exit)
return result
def callback(self, callback, /, *args, **kwds):
@ -544,11 +546,6 @@ def callback(self, callback, /, *args, **kwds):
self._push_exit_callback(_exit_wrapper)
return callback # Allow use as a decorator
def _push_cm_exit(self, cm, cm_exit):
"""Helper to correctly register callbacks to __exit__ methods."""
_exit_wrapper = self._create_exit_wrapper(cm, cm_exit)
self._push_exit_callback(_exit_wrapper, True)
def _push_exit_callback(self, callback, is_sync=True):
self._exit_callbacks.append((is_sync, callback))
@ -641,10 +638,6 @@ class AsyncExitStack(_BaseExitStack, AbstractAsyncContextManager):
# connection later in the list raise an exception.
"""
@staticmethod
def _create_async_exit_wrapper(cm, cm_exit):
return MethodType(cm_exit, cm)
@staticmethod
def _create_async_cb_wrapper(callback, /, *args, **kwds):
async def _exit_wrapper(exc_type, exc, tb):
@ -657,16 +650,18 @@ async def enter_async_context(self, cm):
If successful, also pushes its __aexit__ method as a callback and
returns the result of the __aenter__ method.
"""
cls = type(cm)
try:
_enter = cls.__aenter__
_exit = cls.__aexit__
except AttributeError:
_enter = _lookup_special(cm, '__aenter__', _sentinel)
if _enter is _sentinel:
cls = type(cm)
raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
f"not support the asynchronous context manager protocol"
) from None
result = await _enter(cm)
self._push_async_cm_exit(cm, _exit)
f"not support the asynchronous context manager protocol")
_exit = _lookup_special(cm, '__aexit__', _sentinel)
if _exit is _sentinel:
cls = type(cm)
raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
f"not support the asynchronous context manager protocol")
result = await _enter()
self._push_exit_callback(_exit, False)
return result
def push_async_exit(self, exit):
@ -677,14 +672,8 @@ def push_async_exit(self, exit):
Also accepts any object with an __aexit__ method (registering a call
to the method instead of the object itself).
"""
_cb_type = type(exit)
try:
exit_method = _cb_type.__aexit__
except AttributeError:
# Not an async context manager, so assume it's a coroutine function
self._push_exit_callback(exit, False)
else:
self._push_async_cm_exit(exit, exit_method)
exit_method = _lookup_special(exit, '__aexit__', exit)
self._push_exit_callback(exit_method, False)
return exit # Allow use as a decorator
def push_async_callback(self, callback, /, *args, **kwds):
@ -704,12 +693,6 @@ async def aclose(self):
"""Immediately unwind the context stack."""
await self.__aexit__(None, None, None)
def _push_async_cm_exit(self, cm, cm_exit):
"""Helper to correctly register coroutine function to __aexit__
method."""
_exit_wrapper = self._create_async_exit_wrapper(cm, cm_exit)
self._push_exit_callback(_exit_wrapper, False)
async def __aenter__(self):
return self

View file

@ -788,6 +788,75 @@ def _exit():
result.append(2)
self.assertEqual(result, [1, 2, 3, 4])
def test_enter_context_classmethod(self):
class TestCM:
@classmethod
def __enter__(cls):
result.append(('enter', cls))
@classmethod
def __exit__(cls, *exc_details):
result.append(('exit', cls, *exc_details))
cm = TestCM()
result = []
with self.exit_stack() as stack:
stack.enter_context(cm)
self.assertEqual(result, [('enter', TestCM)])
self.assertEqual(result, [('enter', TestCM),
('exit', TestCM, None, None, None)])
result = []
with self.exit_stack() as stack:
stack.push(cm)
self.assertEqual(result, [])
self.assertEqual(result, [('exit', TestCM, None, None, None)])
def test_enter_context_staticmethod(self):
class TestCM:
@staticmethod
def __enter__():
result.append('enter')
@staticmethod
def __exit__(*exc_details):
result.append(('exit', *exc_details))
cm = TestCM()
result = []
with self.exit_stack() as stack:
stack.enter_context(cm)
self.assertEqual(result, ['enter'])
self.assertEqual(result, ['enter', ('exit', None, None, None)])
result = []
with self.exit_stack() as stack:
stack.push(cm)
self.assertEqual(result, [])
self.assertEqual(result, [('exit', None, None, None)])
def test_enter_context_slots(self):
class TestCM:
__slots__ = ('__enter__', '__exit__')
def __init__(self):
def enter():
result.append('enter')
def exit(*exc_details):
result.append(('exit', *exc_details))
self.__enter__ = enter
self.__exit__ = exit
cm = TestCM()
result = []
with self.exit_stack() as stack:
stack.enter_context(cm)
self.assertEqual(result, ['enter'])
self.assertEqual(result, ['enter', ('exit', None, None, None)])
result = []
with self.exit_stack() as stack:
stack.push(cm)
self.assertEqual(result, [])
self.assertEqual(result, [('exit', None, None, None)])
def test_enter_context_errors(self):
class LacksEnterAndExit:
pass

View file

@ -641,6 +641,78 @@ async def _exit():
self.assertEqual(result, [1, 2, 3, 4])
@_async_test
async def test_enter_async_context_classmethod(self):
class TestCM:
@classmethod
async def __aenter__(cls):
result.append(('enter', cls))
@classmethod
async def __aexit__(cls, *exc_details):
result.append(('exit', cls, *exc_details))
cm = TestCM()
result = []
async with self.exit_stack() as stack:
await stack.enter_async_context(cm)
self.assertEqual(result, [('enter', TestCM)])
self.assertEqual(result, [('enter', TestCM),
('exit', TestCM, None, None, None)])
result = []
async with self.exit_stack() as stack:
stack.push_async_exit(cm)
self.assertEqual(result, [])
self.assertEqual(result, [('exit', TestCM, None, None, None)])
@_async_test
async def test_enter_async_context_staticmethod(self):
class TestCM:
@staticmethod
async def __aenter__():
result.append('enter')
@staticmethod
async def __aexit__(*exc_details):
result.append(('exit', *exc_details))
cm = TestCM()
result = []
async with self.exit_stack() as stack:
await stack.enter_async_context(cm)
self.assertEqual(result, ['enter'])
self.assertEqual(result, ['enter', ('exit', None, None, None)])
result = []
async with self.exit_stack() as stack:
stack.push_async_exit(cm)
self.assertEqual(result, [])
self.assertEqual(result, [('exit', None, None, None)])
@_async_test
async def test_enter_async_context_slots(self):
class TestCM:
__slots__ = ('__aenter__', '__aexit__')
def __init__(self):
async def enter():
result.append('enter')
async def exit(*exc_details):
result.append(('exit', *exc_details))
self.__aenter__ = enter
self.__aexit__ = exit
cm = TestCM()
result = []
async with self.exit_stack() as stack:
await stack.enter_async_context(cm)
self.assertEqual(result, ['enter'])
self.assertEqual(result, ['enter', ('exit', None, None, None)])
result = []
async with self.exit_stack() as stack:
stack.push_async_exit(cm)
self.assertEqual(result, [])
self.assertEqual(result, [('exit', None, None, None)])
@_async_test
async def test_enter_async_context_errors(self):
class LacksEnterAndExit:

View file

@ -0,0 +1,4 @@
Add support for arbitrary descriptors :meth:`!__enter__`, :meth:`!__exit__`,
:meth:`!__aenter__`, and :meth:`!__aexit__` in :class:`contextlib.ExitStack`
and :class:`contextlib.AsyncExitStack`, for consistency with the
:keyword:`with` and :keyword:`async with` statements.