mirror of
				https://github.com/python/cpython.git
				synced 2025-11-01 06:01:29 +00:00 
			
		
		
		
	Issue 24342: Let wrapper set by sys.set_coroutine_wrapper fail gracefully
This commit is contained in:
		
							parent
							
								
									231d90609b
								
							
						
					
					
						commit
						aab3c4a211
					
				
					 6 changed files with 68 additions and 10 deletions
				
			
		|  | @ -1085,6 +1085,20 @@ always available. | |||
|    If called twice, the new wrapper replaces the previous one.  The function | ||||
|    is thread-specific. | ||||
| 
 | ||||
|    The *wrapper* callable cannot define new coroutines directly or indirectly:: | ||||
| 
 | ||||
|         def wrapper(coro): | ||||
|             async def wrap(coro): | ||||
|                 return await coro | ||||
|             return wrap(coro) | ||||
|         sys.set_coroutine_wrapper(wrapper) | ||||
| 
 | ||||
|         async def foo(): pass | ||||
| 
 | ||||
|         # The following line will fail with a RuntimeError, because | ||||
|         # `wrapper` creates a `wrap(coro)` coroutine: | ||||
|         foo() | ||||
| 
 | ||||
|    See also :func:`get_coroutine_wrapper`. | ||||
| 
 | ||||
|    .. versionadded:: 3.5 | ||||
|  |  | |||
|  | @ -23,8 +23,9 @@ PyAPI_FUNC(PyObject *) PyEval_CallMethod(PyObject *obj, | |||
| #ifndef Py_LIMITED_API | ||||
| PyAPI_FUNC(void) PyEval_SetProfile(Py_tracefunc, PyObject *); | ||||
| PyAPI_FUNC(void) PyEval_SetTrace(Py_tracefunc, PyObject *); | ||||
| PyAPI_FUNC(void) _PyEval_SetCoroutineWrapper(PyObject *wrapper); | ||||
| PyAPI_FUNC(void) _PyEval_SetCoroutineWrapper(PyObject *); | ||||
| PyAPI_FUNC(PyObject *) _PyEval_GetCoroutineWrapper(void); | ||||
| PyAPI_FUNC(PyObject *) _PyEval_ApplyCoroutineWrapper(PyObject *); | ||||
| #endif | ||||
| 
 | ||||
| struct _frame; /* Avoid including frameobject.h */ | ||||
|  |  | |||
|  | @ -135,6 +135,7 @@ typedef struct _ts { | |||
|     void *on_delete_data; | ||||
| 
 | ||||
|     PyObject *coroutine_wrapper; | ||||
|     int in_coroutine_wrapper; | ||||
| 
 | ||||
|     /* XXX signal handlers should also be here */ | ||||
| 
 | ||||
|  |  | |||
|  | @ -995,6 +995,26 @@ def test_set_wrapper_2(self): | |||
|             sys.set_coroutine_wrapper(1) | ||||
|         self.assertIsNone(sys.get_coroutine_wrapper()) | ||||
| 
 | ||||
|     def test_set_wrapper_3(self): | ||||
|         async def foo(): | ||||
|             return 'spam' | ||||
| 
 | ||||
|         def wrapper(coro): | ||||
|             async def wrap(coro): | ||||
|                 return await coro | ||||
|             return wrap(coro) | ||||
| 
 | ||||
|         sys.set_coroutine_wrapper(wrapper) | ||||
|         try: | ||||
|             with self.assertRaisesRegex( | ||||
|                 RuntimeError, | ||||
|                 "coroutine wrapper.*\.wrapper at 0x.*attempted to " | ||||
|                 "recursively wrap <coroutine.*\.wrap"): | ||||
| 
 | ||||
|                 foo() | ||||
|         finally: | ||||
|             sys.set_coroutine_wrapper(None) | ||||
| 
 | ||||
| 
 | ||||
| class CAPITest(unittest.TestCase): | ||||
| 
 | ||||
|  |  | |||
|  | @ -3921,7 +3921,6 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals, | |||
| 
 | ||||
|     if (co->co_flags & CO_GENERATOR) { | ||||
|         PyObject *gen; | ||||
|         PyObject *coroutine_wrapper; | ||||
| 
 | ||||
|         /* Don't need to keep the reference to f_back, it will be set
 | ||||
|          * when the generator is resumed. */ | ||||
|  | @ -3935,14 +3934,9 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals, | |||
|         if (gen == NULL) | ||||
|             return NULL; | ||||
| 
 | ||||
|         if (co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE)) { | ||||
|             coroutine_wrapper = _PyEval_GetCoroutineWrapper(); | ||||
|             if (coroutine_wrapper != NULL) { | ||||
|                 PyObject *wrapped = | ||||
|                             PyObject_CallFunction(coroutine_wrapper, "N", gen); | ||||
|                 gen = wrapped; | ||||
|             } | ||||
|         } | ||||
|         if (co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE)) | ||||
|             return _PyEval_ApplyCoroutineWrapper(gen); | ||||
| 
 | ||||
|         return gen; | ||||
|     } | ||||
| 
 | ||||
|  | @ -4407,6 +4401,33 @@ _PyEval_GetCoroutineWrapper(void) | |||
|     return tstate->coroutine_wrapper; | ||||
| } | ||||
| 
 | ||||
| PyObject * | ||||
| _PyEval_ApplyCoroutineWrapper(PyObject *gen) | ||||
| { | ||||
|     PyObject *wrapped; | ||||
|     PyThreadState *tstate = PyThreadState_GET(); | ||||
|     PyObject *wrapper = tstate->coroutine_wrapper; | ||||
| 
 | ||||
|     if (tstate->in_coroutine_wrapper) { | ||||
|         assert(wrapper != NULL); | ||||
|         PyErr_Format(PyExc_RuntimeError, | ||||
|                      "coroutine wrapper %.150R attempted " | ||||
|                      "to recursively wrap %.150R", | ||||
|                      wrapper, | ||||
|                      gen); | ||||
|         return NULL; | ||||
|     } | ||||
| 
 | ||||
|     if (wrapper == NULL) { | ||||
|         return gen; | ||||
|     } | ||||
| 
 | ||||
|     tstate->in_coroutine_wrapper = 1; | ||||
|     wrapped = PyObject_CallFunction(wrapper, "N", gen); | ||||
|     tstate->in_coroutine_wrapper = 0; | ||||
|     return wrapped; | ||||
| } | ||||
| 
 | ||||
| PyObject * | ||||
| PyEval_GetBuiltins(void) | ||||
| { | ||||
|  |  | |||
|  | @ -213,6 +213,7 @@ new_threadstate(PyInterpreterState *interp, int init) | |||
|         tstate->on_delete_data = NULL; | ||||
| 
 | ||||
|         tstate->coroutine_wrapper = NULL; | ||||
|         tstate->in_coroutine_wrapper = 0; | ||||
| 
 | ||||
|         if (init) | ||||
|             _PyThreadState_Init(tstate); | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Yury Selivanov
						Yury Selivanov