# Mirror of https://github.com/python/cpython.git
# (synced 2025-11-04 07:31:38 +00:00) -- Python source, 212 lines, 6.4 KiB.
import asyncio
 | 
						|
from contextlib import asynccontextmanager
 | 
						|
import functools
 | 
						|
from test import support
 | 
						|
import unittest
 | 
						|
 | 
						|
 | 
						|
def _async_test(func):
    """Decorator to turn an async function into a test case.

    The returned synchronous wrapper runs ``func(*args, **kwargs)`` to
    completion with ``asyncio.run()``. This uses a fresh event loop and
    afterwards:

    * cancels leftover tasks;
    * finalizes any asynchronous generators still suspended (for example,
      a context-manager generator a test deliberately left mid-iteration);
    * closes the loop and clears the current event loop.

    The coroutine's result is returned to the caller.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # asyncio.run() handles loop creation, async-generator shutdown and
        # loop closing; the previous hand-rolled version skipped
        # shutdown_asyncgens(), leaving suspended generators unfinalized.
        return asyncio.run(func(*args, **kwargs))
    return wrapper


class AsyncContextManagerTestCase(unittest.TestCase):
    """Tests for ``contextlib.asynccontextmanager``.

    Most tests are coroutines wrapped by ``_async_test`` so that they run on
    an event loop. Their assertions depend on the exact placement of the
    ``yield``, ``raise``, ``except`` and ``finally`` statements inside the
    decorated generators, so those bodies should be edited with care.
    """

    @_async_test
    async def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # ``as``, and code after the yield runs on a normal exit.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        # An exception escaping the ``async with`` body still runs the
        # generator's ``finally`` clause, and the exception propagates.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_no_reraise(self):
        @asynccontextmanager
        async def whee():
            yield
        ctx = whee()
        await ctx.__aenter__()
        # Calling __aexit__ should not result in an exception; it returns a
        # false value, meaning the exception was not suppressed.
        self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields again after an exception was thrown into
        # it must make __aexit__ raise RuntimeError.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except:  # deliberately bare: catch whatever __aexit__ throws in
                yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        # ``if False: yield`` makes this an async generator that never
        # yields, so __aenter__ must raise RuntimeError.
        @asynccontextmanager
        async def whoo():
            if False:
                yield
        ctx = whoo()
        with self.assertRaises(RuntimeError):
            await ctx.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        # A generator that yields a second time on a normal exit must make
        # __aexit__ raise RuntimeError.
        @asynccontextmanager
        async def whoo():
            yield
            yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        # __aexit__ receives an exception type with no instance (value None).
        # The generator must still catch it as RuntimeError, and the
        # SyntaxError it raises in response must propagate.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(SyntaxError):
            await ctx.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        # An exception raised in the body is delivered into the generator at
        # the yield. The generator handles it without re-raising, so the
        # ``async with`` statement completes without an exception.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        # StopIteration and StopAsyncIteration raised inside the body must
        # propagate as the very same exception object rather than being
        # suppressed by the context manager.
        @asynccontextmanager
        async def woohoo():
            yield

        for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        # The generator converts any exception thrown into it into a
        # RuntimeError chained from the original.
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        # An ordinary exception: the generator's RuntimeError propagates.
        with self.assertRaises(RuntimeError):
            async with woohoo():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with woohoo():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        # Build an asynccontextmanager factory from a function that carries
        # a custom attribute (foo='bar') and a docstring ("Whee!"). The
        # attribute/docstring tests in this class check that both survive
        # the decoration.
        def attribs(**kw):
            # Decorator factory: set each keyword argument as an attribute
            # on the decorated function and return the function unchanged.
            def decorate(func):
                for k,v in kw.items():
                    setattr(func,k,v)
                return func
            return decorate
        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield
        return baz

    def test_contextmanager_attribs(self):
        # The decorated factory keeps the wrapped function's name and its
        # custom attribute.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__,'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        # The decorated factory keeps the wrapped function's docstring.
        # requires_docstrings presumably skips this when docstrings are
        # unavailable (e.g. under -OO) -- confirm in test.support.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        # The context-manager object returned by calling the factory also
        # exposes the generator function's docstring.
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")
        async with baz:
            pass  # suppress warning (presumably about a never-run generator)

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited: parameters named self,
        # func, args and kwds must be passable by keyword through the
        # asynccontextmanager wrapper.
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))


if __name__ == "__main__":
    # Executed directly as a script: hand this module's tests to the
    # unittest command-line runner.
    unittest.main()