| 
									
										
										
										
											2020-05-18 23:03:28 -04:00
										 |  |  | """Tests for asyncio/threads.py""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import asyncio | 
					
						
							|  |  |  | import unittest | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-21 01:20:43 -04:00
										 |  |  | from contextvars import ContextVar | 
					
						
							| 
									
										
										
										
											2020-05-18 23:03:28 -04:00
										 |  |  | from unittest import mock | 
					
						
							|  |  |  | from test.test_asyncio import utils as test_utils | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def tearDownModule(): | 
					
						
							|  |  |  |     asyncio.set_event_loop_policy(None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ToThreadTests(test_utils.TestCase): | 
					
						
							|  |  |  |     def setUp(self): | 
					
						
							|  |  |  |         super().setUp() | 
					
						
							|  |  |  |         self.loop = asyncio.new_event_loop() | 
					
						
							|  |  |  |         asyncio.set_event_loop(self.loop) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def tearDown(self): | 
					
						
							|  |  |  |         self.loop.run_until_complete( | 
					
						
							|  |  |  |             self.loop.shutdown_default_executor()) | 
					
						
							|  |  |  |         self.loop.close() | 
					
						
							|  |  |  |         asyncio.set_event_loop(None) | 
					
						
							|  |  |  |         self.loop = None | 
					
						
							|  |  |  |         super().tearDown() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_to_thread(self): | 
					
						
							|  |  |  |         async def main(): | 
					
						
							|  |  |  |             return await asyncio.to_thread(sum, [40, 2]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         result = self.loop.run_until_complete(main()) | 
					
						
							|  |  |  |         self.assertEqual(result, 42) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_to_thread_exception(self): | 
					
						
							|  |  |  |         def raise_runtime(): | 
					
						
							|  |  |  |             raise RuntimeError("test") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         async def main(): | 
					
						
							|  |  |  |             await asyncio.to_thread(raise_runtime) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         with self.assertRaisesRegex(RuntimeError, "test"): | 
					
						
							|  |  |  |             self.loop.run_until_complete(main()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_to_thread_once(self): | 
					
						
							|  |  |  |         func = mock.Mock() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         async def main(): | 
					
						
							|  |  |  |             await asyncio.to_thread(func) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.loop.run_until_complete(main()) | 
					
						
							|  |  |  |         func.assert_called_once() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_to_thread_concurrent(self): | 
					
						
							|  |  |  |         func = mock.Mock() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         async def main(): | 
					
						
							|  |  |  |             futs = [] | 
					
						
							|  |  |  |             for _ in range(10): | 
					
						
							|  |  |  |                 fut = asyncio.to_thread(func) | 
					
						
							|  |  |  |                 futs.append(fut) | 
					
						
							|  |  |  |             await asyncio.gather(*futs) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.loop.run_until_complete(main()) | 
					
						
							|  |  |  |         self.assertEqual(func.call_count, 10) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_to_thread_args_kwargs(self): | 
					
						
							|  |  |  |         # Unlike run_in_executor(), to_thread() should directly accept kwargs. | 
					
						
							|  |  |  |         func = mock.Mock() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         async def main(): | 
					
						
							|  |  |  |             await asyncio.to_thread(func, 'test', something=True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.loop.run_until_complete(main()) | 
					
						
							|  |  |  |         func.assert_called_once_with('test', something=True) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-21 01:20:43 -04:00
										 |  |  |     def test_to_thread_contextvars(self): | 
					
						
							|  |  |  |         test_ctx = ContextVar('test_ctx') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def get_ctx(): | 
					
						
							|  |  |  |             return test_ctx.get() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         async def main(): | 
					
						
							|  |  |  |             test_ctx.set('parrot') | 
					
						
							|  |  |  |             return await asyncio.to_thread(get_ctx) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         result = self.loop.run_until_complete(main()) | 
					
						
							|  |  |  |         self.assertEqual(result, 'parrot') | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-18 23:03:28 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     unittest.main() |