mirror of
				https://github.com/python/cpython.git
				synced 2025-10-30 13:11:29 +00:00 
			
		
		
		
	
		
			
	
	
		
			213 lines
		
	
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			213 lines
		
	
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | import asyncio | ||
|  | from contextlib import asynccontextmanager | ||
|  | import functools | ||
|  | from test import support | ||
|  | import unittest | ||
|  | 
 | ||
|  | 
 | ||
def _async_test(func):
    """Decorator to turn an async function into a test case.

    The returned synchronous wrapper calls *func* with the given arguments
    to obtain a coroutine, runs it to completion on a brand-new event loop
    (installed as the current loop for the duration of the call), and
    returns the coroutine's result.

    Before the loop is closed, any async generators that the test left
    suspended (e.g. an async context manager whose __aexit__ raised before
    the generator finished) are finalized with ``loop.shutdown_asyncgens()``.
    Without this step such generators never get ``aclose()`` run on the loop
    that drove them, so their cleanup code (``finally`` clauses) may be
    skipped once the loop is closed.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        coro = func(*args, **kwargs)
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(coro)
        finally:
            try:
                # Close every async generator still open on this loop.
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                # Always release the loop and unset it as current, even if
                # finalizing the async generators failed.
                loop.close()
                asyncio.set_event_loop(None)
    return wrapper
|  | 
 | ||
|  | 
 | ||
class AsyncContextManagerTestCase(unittest.TestCase):
    """Tests for contextlib.asynccontextmanager.

    The async test methods are wrapped with _async_test so that unittest can
    call them synchronously; each call drives the coroutine on a new event
    loop.
    """

    # Code before the yield runs on entry, the yielded value is bound by
    # ``as``, and code after the yield runs on a normal exit.
    @_async_test
    async def test_contextmanager_plain(self):
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    # An exception raised in the body propagates out of ``async with``,
    # and the generator's ``finally`` clause still runs.
    @_async_test
    async def test_contextmanager_finally(self):
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    # When the generator does not handle the exception passed to __aexit__,
    # __aexit__ returns a false value (leaving re-raising to the caller)
    # instead of raising that exception itself.
    @_async_test
    async def test_contextmanager_no_reraise(self):
        @asynccontextmanager
        async def whee():
            yield
        ctx = whee()
        await ctx.__aenter__()
        # Calling __aexit__ should not result in an exception
        self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

    # A generator that yields again after the exception has been thrown
    # into it makes __aexit__ raise RuntimeError.
    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except:
                yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(TypeError, TypeError('foo'), None)

    # A generator that finishes without ever yielding makes __aenter__
    # raise RuntimeError.
    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        @asynccontextmanager
        async def whoo():
            # The unreachable yield only makes whoo an async generator
            # function.
            if False:
                yield
        ctx = whoo()
        with self.assertRaises(RuntimeError):
            await ctx.__aenter__()

    # A generator that yields a second time on a normal (exception-free)
    # exit makes __aexit__ raise RuntimeError.
    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        @asynccontextmanager
        async def whoo():
            yield
            yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(None, None, None)

    # __aexit__ receives an exception type with no instance (a
    # "non-normalised" exception).  The generator must still catch it as
    # RuntimeError, and the SyntaxError it raises instead propagates out
    # of __aexit__.
    @_async_test
    async def test_contextmanager_non_normalised(self):
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(SyntaxError):
            await ctx.__aexit__(RuntimeError, None, None)

    # If the generator catches the body's exception and then finishes
    # normally, the exception is suppressed and execution continues after
    # the ``async with`` statement.
    @_async_test
    async def test_contextmanager_except(self):
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    # StopIteration and StopAsyncIteration raised in the body must
    # propagate as the very same exception object, not be swallowed.
    @_async_test
    async def test_contextmanager_except_stopiter(self):
        @asynccontextmanager
        async def woohoo():
            yield

        for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    # A generator that converts the body's exception into RuntimeError
    # makes that RuntimeError propagate, except when the original exception
    # was StopAsyncIteration (see the comment inside).
    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        with self.assertRaises(RuntimeError):
            async with woohoo():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with woohoo():
                raise StopAsyncIteration

    # Helper: build an asynccontextmanager-decorated function that carries a
    # custom attribute (foo='bar') and the docstring "Whee!", so the tests
    # below can check that this metadata survives the decoration.
    def _create_contextmanager_attribs(self):
        # Decorator factory: sets each keyword argument as an attribute on
        # the decorated function and returns the function unchanged.
        def attribs(**kw):
            def decorate(func):
                for k,v in kw.items():
                    setattr(func,k,v)
                return func
            return decorate
        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield
        return baz

    # The decorated function keeps its __name__ and custom attributes.
    def test_contextmanager_attribs(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__,'baz')
        self.assertEqual(baz.foo, 'bar')

    # The decorated function keeps its docstring (presumably skipped by
    # support.requires_docstrings when docstrings are unavailable).
    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    # The context manager object returned by calling the decorated function
    # also exposes the generator function's docstring.
    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")
        async with baz:
            pass  # suppress warning

    # Parameter names such as self/func/args/kwds must be accepted as
    # keyword arguments of the decorated function.
    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))
|  | 
 | ||
|  | 
 | ||
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()