Mirror of https://github.com/python/cpython.git (synced 2025-10-31 21:51:50 +00:00)
			
		
		
		
	bpo-42609: Check recursion depth in the AST validator and optimizer (GH-23744)
This commit is contained in:
		
parent b5adc8a7e5
commit face87c94e
5 changed files with 309 additions and 149 deletions
				
			
		|  | @ -28,6 +28,9 @@ extern PyObject* _Py_Mangle(PyObject *p, PyObject *name); | |||
| typedef struct { | ||||
|     int optimize; | ||||
|     int ff_features; | ||||
| 
 | ||||
|     int recursion_depth;            /* current recursion depth */ | ||||
|     int recursion_limit;            /* recursion limit */ | ||||
| } _PyASTOptimizeState; | ||||
| 
 | ||||
| extern int _PyAST_Optimize( | ||||
|  |  | |||
|  | @ -543,21 +543,26 @@ def test_compiler_recursion_limit(self): | |||
|         # XXX (ncoghlan): duplicating the scaling factor here is a little | ||||
|         # ugly. Perhaps it should be exposed somewhere... | ||||
|         fail_depth = sys.getrecursionlimit() * 3 | ||||
|         crash_depth = sys.getrecursionlimit() * 300 | ||||
|         success_depth = int(fail_depth * 0.75) | ||||
| 
 | ||||
|         def check_limit(prefix, repeated): | ||||
|         def check_limit(prefix, repeated, mode="single"): | ||||
|             expect_ok = prefix + repeated * success_depth | ||||
|             self.compile_single(expect_ok) | ||||
|             broken = prefix + repeated * fail_depth | ||||
|             details = "Compiling ({!r} + {!r} * {})".format( | ||||
|                          prefix, repeated, fail_depth) | ||||
|             with self.assertRaises(RecursionError, msg=details): | ||||
|                 self.compile_single(broken) | ||||
|             compile(expect_ok, '<test>', mode) | ||||
|             for depth in (fail_depth, crash_depth): | ||||
|                 broken = prefix + repeated * depth | ||||
|                 details = "Compiling ({!r} + {!r} * {})".format( | ||||
|                             prefix, repeated, depth) | ||||
|                 with self.assertRaises(RecursionError, msg=details): | ||||
|                     compile(broken, '<test>', mode) | ||||
| 
 | ||||
|         check_limit("a", "()") | ||||
|         check_limit("a", ".b") | ||||
|         check_limit("a", "[0]") | ||||
|         check_limit("a", "*a") | ||||
|         # XXX Crashes in the parser. | ||||
|         # check_limit("a", " if a else a") | ||||
|         # check_limit("if a: pass", "\nelif a: pass", mode="exec") | ||||
| 
 | ||||
|     def test_null_terminated(self): | ||||
|         # The source code is null-terminated internally, but bytes-like | ||||
|  |  | |||
|  | @ -0,0 +1,3 @@ | |||
| Prevented crashes in the AST validator and optimizer when compiling some | ||||
| absurdly long expressions like ``"+0"*1000000``. :exc:`RecursionError` is | ||||
| now raised instead. | ||||
							
								
								
									
										391
									
								
								Python/ast.c
									
										
									
									
									
								
							|  | @ -4,14 +4,20 @@ | |||
|  */ | ||||
| #include "Python.h" | ||||
| #include "pycore_ast.h"           // asdl_stmt_seq | ||||
| #include "pycore_pystate.h"       // _PyThreadState_GET() | ||||
| 
 | ||||
| #include <assert.h> | ||||
| 
 | ||||
| static int validate_stmts(asdl_stmt_seq *); | ||||
| static int validate_exprs(asdl_expr_seq*, expr_context_ty, int); | ||||
| struct validator { | ||||
|     int recursion_depth;            /* current recursion depth */ | ||||
|     int recursion_limit;            /* recursion limit */ | ||||
| }; | ||||
| 
 | ||||
| static int validate_stmts(struct validator *, asdl_stmt_seq *); | ||||
| static int validate_exprs(struct validator *, asdl_expr_seq*, expr_context_ty, int); | ||||
| static int _validate_nonempty_seq(asdl_seq *, const char *, const char *); | ||||
| static int validate_stmt(stmt_ty); | ||||
| static int validate_expr(expr_ty, expr_context_ty); | ||||
| static int validate_stmt(struct validator *, stmt_ty); | ||||
| static int validate_expr(struct validator *, expr_ty, expr_context_ty); | ||||
| 
 | ||||
| static int | ||||
| validate_name(PyObject *name) | ||||
|  | @ -33,7 +39,7 @@ validate_name(PyObject *name) | |||
| } | ||||
| 
 | ||||
| static int | ||||
| validate_comprehension(asdl_comprehension_seq *gens) | ||||
| validate_comprehension(struct validator *state, asdl_comprehension_seq *gens) | ||||
| { | ||||
|     Py_ssize_t i; | ||||
|     if (!asdl_seq_LEN(gens)) { | ||||
|  | @ -42,31 +48,31 @@ validate_comprehension(asdl_comprehension_seq *gens) | |||
|     } | ||||
|     for (i = 0; i < asdl_seq_LEN(gens); i++) { | ||||
|         comprehension_ty comp = asdl_seq_GET(gens, i); | ||||
|         if (!validate_expr(comp->target, Store) || | ||||
|             !validate_expr(comp->iter, Load) || | ||||
|             !validate_exprs(comp->ifs, Load, 0)) | ||||
|         if (!validate_expr(state, comp->target, Store) || | ||||
|             !validate_expr(state, comp->iter, Load) || | ||||
|             !validate_exprs(state, comp->ifs, Load, 0)) | ||||
|             return 0; | ||||
|     } | ||||
|     return 1; | ||||
| } | ||||
| 
 | ||||
| static int | ||||
| validate_keywords(asdl_keyword_seq *keywords) | ||||
| validate_keywords(struct validator *state, asdl_keyword_seq *keywords) | ||||
| { | ||||
|     Py_ssize_t i; | ||||
|     for (i = 0; i < asdl_seq_LEN(keywords); i++) | ||||
|         if (!validate_expr((asdl_seq_GET(keywords, i))->value, Load)) | ||||
|         if (!validate_expr(state, (asdl_seq_GET(keywords, i))->value, Load)) | ||||
|             return 0; | ||||
|     return 1; | ||||
| } | ||||
| 
 | ||||
| static int | ||||
| validate_args(asdl_arg_seq *args) | ||||
| validate_args(struct validator *state, asdl_arg_seq *args) | ||||
| { | ||||
|     Py_ssize_t i; | ||||
|     for (i = 0; i < asdl_seq_LEN(args); i++) { | ||||
|         arg_ty arg = asdl_seq_GET(args, i); | ||||
|         if (arg->annotation && !validate_expr(arg->annotation, Load)) | ||||
|         if (arg->annotation && !validate_expr(state, arg->annotation, Load)) | ||||
|             return 0; | ||||
|     } | ||||
|     return 1; | ||||
|  | @ -88,19 +94,19 @@ expr_context_name(expr_context_ty ctx) | |||
| } | ||||
| 
 | ||||
| static int | ||||
| validate_arguments(arguments_ty args) | ||||
| validate_arguments(struct validator *state, arguments_ty args) | ||||
| { | ||||
|     if (!validate_args(args->posonlyargs) || !validate_args(args->args)) { | ||||
|     if (!validate_args(state, args->posonlyargs) || !validate_args(state, args->args)) { | ||||
|         return 0; | ||||
|     } | ||||
|     if (args->vararg && args->vararg->annotation | ||||
|         && !validate_expr(args->vararg->annotation, Load)) { | ||||
|         && !validate_expr(state, args->vararg->annotation, Load)) { | ||||
|             return 0; | ||||
|     } | ||||
|     if (!validate_args(args->kwonlyargs)) | ||||
|     if (!validate_args(state, args->kwonlyargs)) | ||||
|         return 0; | ||||
|     if (args->kwarg && args->kwarg->annotation | ||||
|         && !validate_expr(args->kwarg->annotation, Load)) { | ||||
|         && !validate_expr(state, args->kwarg->annotation, Load)) { | ||||
|             return 0; | ||||
|     } | ||||
|     if (asdl_seq_LEN(args->defaults) > asdl_seq_LEN(args->posonlyargs) + asdl_seq_LEN(args->args)) { | ||||
|  | @ -112,11 +118,11 @@ validate_arguments(arguments_ty args) | |||
|                         "kw_defaults on arguments"); | ||||
|         return 0; | ||||
|     } | ||||
|     return validate_exprs(args->defaults, Load, 0) && validate_exprs(args->kw_defaults, Load, 1); | ||||
|     return validate_exprs(state, args->defaults, Load, 0) && validate_exprs(state, args->kw_defaults, Load, 1); | ||||
| } | ||||
| 
 | ||||
| static int | ||||
| validate_constant(PyObject *value) | ||||
| validate_constant(struct validator *state, PyObject *value) | ||||
| { | ||||
|     if (value == Py_None || value == Py_Ellipsis) | ||||
|         return 1; | ||||
|  | @ -130,9 +136,13 @@ validate_constant(PyObject *value) | |||
|         return 1; | ||||
| 
 | ||||
|     if (PyTuple_CheckExact(value) || PyFrozenSet_CheckExact(value)) { | ||||
|         PyObject *it; | ||||
|         if (++state->recursion_depth > state->recursion_limit) { | ||||
|             PyErr_SetString(PyExc_RecursionError, | ||||
|                             "maximum recursion depth exceeded during compilation"); | ||||
|             return 0; | ||||
|         } | ||||
| 
 | ||||
|         it = PyObject_GetIter(value); | ||||
|         PyObject *it = PyObject_GetIter(value); | ||||
|         if (it == NULL) | ||||
|             return 0; | ||||
| 
 | ||||
|  | @ -146,7 +156,7 @@ validate_constant(PyObject *value) | |||
|                 break; | ||||
|             } | ||||
| 
 | ||||
|             if (!validate_constant(item)) { | ||||
|             if (!validate_constant(state, item)) { | ||||
|                 Py_DECREF(it); | ||||
|                 Py_DECREF(item); | ||||
|                 return 0; | ||||
|  | @ -155,6 +165,7 @@ validate_constant(PyObject *value) | |||
|         } | ||||
| 
 | ||||
|         Py_DECREF(it); | ||||
|         --state->recursion_depth; | ||||
|         return 1; | ||||
|     } | ||||
| 
 | ||||
|  | @ -167,8 +178,14 @@ validate_constant(PyObject *value) | |||
| } | ||||
| 
 | ||||
| static int | ||||
| validate_expr(expr_ty exp, expr_context_ty ctx) | ||||
| validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx) | ||||
| { | ||||
|     int ret; | ||||
|     if (++state->recursion_depth > state->recursion_limit) { | ||||
|         PyErr_SetString(PyExc_RecursionError, | ||||
|                         "maximum recursion depth exceeded during compilation"); | ||||
|         return 0; | ||||
|     } | ||||
|     int check_ctx = 1; | ||||
|     expr_context_ty actual_ctx; | ||||
| 
 | ||||
|  | @ -218,19 +235,24 @@ validate_expr(expr_ty exp, expr_context_ty ctx) | |||
|             PyErr_SetString(PyExc_ValueError, "BoolOp with less than 2 values"); | ||||
|             return 0; | ||||
|         } | ||||
|         return validate_exprs(exp->v.BoolOp.values, Load, 0); | ||||
|         ret = validate_exprs(state, exp->v.BoolOp.values, Load, 0); | ||||
|         break; | ||||
|     case BinOp_kind: | ||||
|         return validate_expr(exp->v.BinOp.left, Load) && | ||||
|             validate_expr(exp->v.BinOp.right, Load); | ||||
|         ret = validate_expr(state, exp->v.BinOp.left, Load) && | ||||
|             validate_expr(state, exp->v.BinOp.right, Load); | ||||
|         break; | ||||
|     case UnaryOp_kind: | ||||
|         return validate_expr(exp->v.UnaryOp.operand, Load); | ||||
|         ret = validate_expr(state, exp->v.UnaryOp.operand, Load); | ||||
|         break; | ||||
|     case Lambda_kind: | ||||
|         return validate_arguments(exp->v.Lambda.args) && | ||||
|             validate_expr(exp->v.Lambda.body, Load); | ||||
|         ret = validate_arguments(state, exp->v.Lambda.args) && | ||||
|             validate_expr(state, exp->v.Lambda.body, Load); | ||||
|         break; | ||||
|     case IfExp_kind: | ||||
|         return validate_expr(exp->v.IfExp.test, Load) && | ||||
|             validate_expr(exp->v.IfExp.body, Load) && | ||||
|             validate_expr(exp->v.IfExp.orelse, Load); | ||||
|         ret = validate_expr(state, exp->v.IfExp.test, Load) && | ||||
|             validate_expr(state, exp->v.IfExp.body, Load) && | ||||
|             validate_expr(state, exp->v.IfExp.orelse, Load); | ||||
|         break; | ||||
|     case Dict_kind: | ||||
|         if (asdl_seq_LEN(exp->v.Dict.keys) != asdl_seq_LEN(exp->v.Dict.values)) { | ||||
|             PyErr_SetString(PyExc_ValueError, | ||||
|  | @ -239,28 +261,35 @@ validate_expr(expr_ty exp, expr_context_ty ctx) | |||
|         } | ||||
|         /* null_ok=1 for keys expressions to allow dict unpacking to work in
 | ||||
|            dict literals, i.e. ``{**{a:b}}`` */ | ||||
|         return validate_exprs(exp->v.Dict.keys, Load, /*null_ok=*/ 1) && | ||||
|             validate_exprs(exp->v.Dict.values, Load, /*null_ok=*/ 0); | ||||
|         ret = validate_exprs(state, exp->v.Dict.keys, Load, /*null_ok=*/ 1) && | ||||
|             validate_exprs(state, exp->v.Dict.values, Load, /*null_ok=*/ 0); | ||||
|         break; | ||||
|     case Set_kind: | ||||
|         return validate_exprs(exp->v.Set.elts, Load, 0); | ||||
|         ret = validate_exprs(state, exp->v.Set.elts, Load, 0); | ||||
|         break; | ||||
| #define COMP(NAME) \ | ||||
|         case NAME ## _kind: \ | ||||
|             return validate_comprehension(exp->v.NAME.generators) && \ | ||||
|                 validate_expr(exp->v.NAME.elt, Load); | ||||
|             ret = validate_comprehension(state, exp->v.NAME.generators) && \ | ||||
|                 validate_expr(state, exp->v.NAME.elt, Load); \ | ||||
|             break; | ||||
|     COMP(ListComp) | ||||
|     COMP(SetComp) | ||||
|     COMP(GeneratorExp) | ||||
| #undef COMP | ||||
|     case DictComp_kind: | ||||
|         return validate_comprehension(exp->v.DictComp.generators) && | ||||
|             validate_expr(exp->v.DictComp.key, Load) && | ||||
|             validate_expr(exp->v.DictComp.value, Load); | ||||
|         ret = validate_comprehension(state, exp->v.DictComp.generators) && | ||||
|             validate_expr(state, exp->v.DictComp.key, Load) && | ||||
|             validate_expr(state, exp->v.DictComp.value, Load); | ||||
|         break; | ||||
|     case Yield_kind: | ||||
|         return !exp->v.Yield.value || validate_expr(exp->v.Yield.value, Load); | ||||
|         ret = !exp->v.Yield.value || validate_expr(state, exp->v.Yield.value, Load); | ||||
|         break; | ||||
|     case YieldFrom_kind: | ||||
|         return validate_expr(exp->v.YieldFrom.value, Load); | ||||
|         ret = validate_expr(state, exp->v.YieldFrom.value, Load); | ||||
|         break; | ||||
|     case Await_kind: | ||||
|         return validate_expr(exp->v.Await.value, Load); | ||||
|         ret = validate_expr(state, exp->v.Await.value, Load); | ||||
|         break; | ||||
|     case Compare_kind: | ||||
|         if (!asdl_seq_LEN(exp->v.Compare.comparators)) { | ||||
|             PyErr_SetString(PyExc_ValueError, "Compare with no comparators"); | ||||
|  | @ -272,42 +301,56 @@ validate_expr(expr_ty exp, expr_context_ty ctx) | |||
|                             "of comparators and operands"); | ||||
|             return 0; | ||||
|         } | ||||
|         return validate_exprs(exp->v.Compare.comparators, Load, 0) && | ||||
|             validate_expr(exp->v.Compare.left, Load); | ||||
|         ret = validate_exprs(state, exp->v.Compare.comparators, Load, 0) && | ||||
|             validate_expr(state, exp->v.Compare.left, Load); | ||||
|         break; | ||||
|     case Call_kind: | ||||
|         return validate_expr(exp->v.Call.func, Load) && | ||||
|             validate_exprs(exp->v.Call.args, Load, 0) && | ||||
|             validate_keywords(exp->v.Call.keywords); | ||||
|         ret = validate_expr(state, exp->v.Call.func, Load) && | ||||
|             validate_exprs(state, exp->v.Call.args, Load, 0) && | ||||
|             validate_keywords(state, exp->v.Call.keywords); | ||||
|         break; | ||||
|     case Constant_kind: | ||||
|         if (!validate_constant(exp->v.Constant.value)) { | ||||
|         if (!validate_constant(state, exp->v.Constant.value)) { | ||||
|             return 0; | ||||
|         } | ||||
|         return 1; | ||||
|         ret = 1; | ||||
|         break; | ||||
|     case JoinedStr_kind: | ||||
|         return validate_exprs(exp->v.JoinedStr.values, Load, 0); | ||||
|         ret = validate_exprs(state, exp->v.JoinedStr.values, Load, 0); | ||||
|         break; | ||||
|     case FormattedValue_kind: | ||||
|         if (validate_expr(exp->v.FormattedValue.value, Load) == 0) | ||||
|         if (validate_expr(state, exp->v.FormattedValue.value, Load) == 0) | ||||
|             return 0; | ||||
|         if (exp->v.FormattedValue.format_spec) | ||||
|             return validate_expr(exp->v.FormattedValue.format_spec, Load); | ||||
|         return 1; | ||||
|         if (exp->v.FormattedValue.format_spec) { | ||||
|             ret = validate_expr(state, exp->v.FormattedValue.format_spec, Load); | ||||
|             break; | ||||
|         } | ||||
|         ret = 1; | ||||
|         break; | ||||
|     case Attribute_kind: | ||||
|         return validate_expr(exp->v.Attribute.value, Load); | ||||
|         ret = validate_expr(state, exp->v.Attribute.value, Load); | ||||
|         break; | ||||
|     case Subscript_kind: | ||||
|         return validate_expr(exp->v.Subscript.slice, Load) && | ||||
|             validate_expr(exp->v.Subscript.value, Load); | ||||
|         ret = validate_expr(state, exp->v.Subscript.slice, Load) && | ||||
|             validate_expr(state, exp->v.Subscript.value, Load); | ||||
|         break; | ||||
|     case Starred_kind: | ||||
|         return validate_expr(exp->v.Starred.value, ctx); | ||||
|         ret = validate_expr(state, exp->v.Starred.value, ctx); | ||||
|         break; | ||||
|     case Slice_kind: | ||||
|         return (!exp->v.Slice.lower || validate_expr(exp->v.Slice.lower, Load)) && | ||||
|             (!exp->v.Slice.upper || validate_expr(exp->v.Slice.upper, Load)) && | ||||
|             (!exp->v.Slice.step || validate_expr(exp->v.Slice.step, Load)); | ||||
|         ret = (!exp->v.Slice.lower || validate_expr(state, exp->v.Slice.lower, Load)) && | ||||
|             (!exp->v.Slice.upper || validate_expr(state, exp->v.Slice.upper, Load)) && | ||||
|             (!exp->v.Slice.step || validate_expr(state, exp->v.Slice.step, Load)); | ||||
|         break; | ||||
|     case List_kind: | ||||
|         return validate_exprs(exp->v.List.elts, ctx, 0); | ||||
|         ret = validate_exprs(state, exp->v.List.elts, ctx, 0); | ||||
|         break; | ||||
|     case Tuple_kind: | ||||
|         return validate_exprs(exp->v.Tuple.elts, ctx, 0); | ||||
|         ret = validate_exprs(state, exp->v.Tuple.elts, ctx, 0); | ||||
|         break; | ||||
|     case NamedExpr_kind: | ||||
|         return validate_expr(exp->v.NamedExpr.value, Load); | ||||
|         ret = validate_expr(state, exp->v.NamedExpr.value, Load); | ||||
|         break; | ||||
|     case MatchAs_kind: | ||||
|         PyErr_SetString(PyExc_ValueError, | ||||
|                         "MatchAs is only valid in match_case patterns"); | ||||
|  | @ -318,10 +361,14 @@ validate_expr(expr_ty exp, expr_context_ty ctx) | |||
|         return 0; | ||||
|     /* This last case doesn't have any checking. */ | ||||
|     case Name_kind: | ||||
|         return 1; | ||||
|         ret = 1; | ||||
|         break; | ||||
|     default: | ||||
|         PyErr_SetString(PyExc_SystemError, "unexpected expression"); | ||||
|         return 0; | ||||
|     } | ||||
|     PyErr_SetString(PyExc_SystemError, "unexpected expression"); | ||||
|     return 0; | ||||
|     state->recursion_depth--; | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| static int | ||||
|  | @ -342,44 +389,56 @@ _validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner) | |||
| #define validate_nonempty_seq(seq, what, owner) _validate_nonempty_seq((asdl_seq*)seq, what, owner) | ||||
| 
 | ||||
| static int | ||||
| validate_assignlist(asdl_expr_seq *targets, expr_context_ty ctx) | ||||
| validate_assignlist(struct validator *state, asdl_expr_seq *targets, expr_context_ty ctx) | ||||
| { | ||||
|     return validate_nonempty_seq(targets, "targets", ctx == Del ? "Delete" : "Assign") && | ||||
|         validate_exprs(targets, ctx, 0); | ||||
|         validate_exprs(state, targets, ctx, 0); | ||||
| } | ||||
| 
 | ||||
| static int | ||||
| validate_body(asdl_stmt_seq *body, const char *owner) | ||||
| validate_body(struct validator *state, asdl_stmt_seq *body, const char *owner) | ||||
| { | ||||
|     return validate_nonempty_seq(body, "body", owner) && validate_stmts(body); | ||||
|     return validate_nonempty_seq(body, "body", owner) && validate_stmts(state, body); | ||||
| } | ||||
| 
 | ||||
| static int | ||||
| validate_stmt(stmt_ty stmt) | ||||
| validate_stmt(struct validator *state, stmt_ty stmt) | ||||
| { | ||||
|     int ret; | ||||
|     Py_ssize_t i; | ||||
|     if (++state->recursion_depth > state->recursion_limit) { | ||||
|         PyErr_SetString(PyExc_RecursionError, | ||||
|                         "maximum recursion depth exceeded during compilation"); | ||||
|         return 0; | ||||
|     } | ||||
|     switch (stmt->kind) { | ||||
|     case FunctionDef_kind: | ||||
|         return validate_body(stmt->v.FunctionDef.body, "FunctionDef") && | ||||
|             validate_arguments(stmt->v.FunctionDef.args) && | ||||
|             validate_exprs(stmt->v.FunctionDef.decorator_list, Load, 0) && | ||||
|         ret = validate_body(state, stmt->v.FunctionDef.body, "FunctionDef") && | ||||
|             validate_arguments(state, stmt->v.FunctionDef.args) && | ||||
|             validate_exprs(state, stmt->v.FunctionDef.decorator_list, Load, 0) && | ||||
|             (!stmt->v.FunctionDef.returns || | ||||
|              validate_expr(stmt->v.FunctionDef.returns, Load)); | ||||
|              validate_expr(state, stmt->v.FunctionDef.returns, Load)); | ||||
|         break; | ||||
|     case ClassDef_kind: | ||||
|         return validate_body(stmt->v.ClassDef.body, "ClassDef") && | ||||
|             validate_exprs(stmt->v.ClassDef.bases, Load, 0) && | ||||
|             validate_keywords(stmt->v.ClassDef.keywords) && | ||||
|             validate_exprs(stmt->v.ClassDef.decorator_list, Load, 0); | ||||
|         ret = validate_body(state, stmt->v.ClassDef.body, "ClassDef") && | ||||
|             validate_exprs(state, stmt->v.ClassDef.bases, Load, 0) && | ||||
|             validate_keywords(state, stmt->v.ClassDef.keywords) && | ||||
|             validate_exprs(state, stmt->v.ClassDef.decorator_list, Load, 0); | ||||
|         break; | ||||
|     case Return_kind: | ||||
|         return !stmt->v.Return.value || validate_expr(stmt->v.Return.value, Load); | ||||
|         ret = !stmt->v.Return.value || validate_expr(state, stmt->v.Return.value, Load); | ||||
|         break; | ||||
|     case Delete_kind: | ||||
|         return validate_assignlist(stmt->v.Delete.targets, Del); | ||||
|         ret = validate_assignlist(state, stmt->v.Delete.targets, Del); | ||||
|         break; | ||||
|     case Assign_kind: | ||||
|         return validate_assignlist(stmt->v.Assign.targets, Store) && | ||||
|             validate_expr(stmt->v.Assign.value, Load); | ||||
|         ret = validate_assignlist(state, stmt->v.Assign.targets, Store) && | ||||
|             validate_expr(state, stmt->v.Assign.value, Load); | ||||
|         break; | ||||
|     case AugAssign_kind: | ||||
|         return validate_expr(stmt->v.AugAssign.target, Store) && | ||||
|             validate_expr(stmt->v.AugAssign.value, Load); | ||||
|         ret = validate_expr(state, stmt->v.AugAssign.target, Store) && | ||||
|             validate_expr(state, stmt->v.AugAssign.value, Load); | ||||
|         break; | ||||
|     case AnnAssign_kind: | ||||
|         if (stmt->v.AnnAssign.target->kind != Name_kind && | ||||
|             stmt->v.AnnAssign.simple) { | ||||
|  | @ -387,74 +446,84 @@ validate_stmt(stmt_ty stmt) | |||
|                             "AnnAssign with simple non-Name target"); | ||||
|             return 0; | ||||
|         } | ||||
|         return validate_expr(stmt->v.AnnAssign.target, Store) && | ||||
|         ret = validate_expr(state, stmt->v.AnnAssign.target, Store) && | ||||
|                (!stmt->v.AnnAssign.value || | ||||
|                 validate_expr(stmt->v.AnnAssign.value, Load)) && | ||||
|                validate_expr(stmt->v.AnnAssign.annotation, Load); | ||||
|                 validate_expr(state, stmt->v.AnnAssign.value, Load)) && | ||||
|                validate_expr(state, stmt->v.AnnAssign.annotation, Load); | ||||
|         break; | ||||
|     case For_kind: | ||||
|         return validate_expr(stmt->v.For.target, Store) && | ||||
|             validate_expr(stmt->v.For.iter, Load) && | ||||
|             validate_body(stmt->v.For.body, "For") && | ||||
|             validate_stmts(stmt->v.For.orelse); | ||||
|         ret = validate_expr(state, stmt->v.For.target, Store) && | ||||
|             validate_expr(state, stmt->v.For.iter, Load) && | ||||
|             validate_body(state, stmt->v.For.body, "For") && | ||||
|             validate_stmts(state, stmt->v.For.orelse); | ||||
|         break; | ||||
|     case AsyncFor_kind: | ||||
|         return validate_expr(stmt->v.AsyncFor.target, Store) && | ||||
|             validate_expr(stmt->v.AsyncFor.iter, Load) && | ||||
|             validate_body(stmt->v.AsyncFor.body, "AsyncFor") && | ||||
|             validate_stmts(stmt->v.AsyncFor.orelse); | ||||
|         ret = validate_expr(state, stmt->v.AsyncFor.target, Store) && | ||||
|             validate_expr(state, stmt->v.AsyncFor.iter, Load) && | ||||
|             validate_body(state, stmt->v.AsyncFor.body, "AsyncFor") && | ||||
|             validate_stmts(state, stmt->v.AsyncFor.orelse); | ||||
|         break; | ||||
|     case While_kind: | ||||
|         return validate_expr(stmt->v.While.test, Load) && | ||||
|             validate_body(stmt->v.While.body, "While") && | ||||
|             validate_stmts(stmt->v.While.orelse); | ||||
|         ret = validate_expr(state, stmt->v.While.test, Load) && | ||||
|             validate_body(state, stmt->v.While.body, "While") && | ||||
|             validate_stmts(state, stmt->v.While.orelse); | ||||
|         break; | ||||
|     case If_kind: | ||||
|         return validate_expr(stmt->v.If.test, Load) && | ||||
|             validate_body(stmt->v.If.body, "If") && | ||||
|             validate_stmts(stmt->v.If.orelse); | ||||
|         ret = validate_expr(state, stmt->v.If.test, Load) && | ||||
|             validate_body(state, stmt->v.If.body, "If") && | ||||
|             validate_stmts(state, stmt->v.If.orelse); | ||||
|         break; | ||||
|     case With_kind: | ||||
|         if (!validate_nonempty_seq(stmt->v.With.items, "items", "With")) | ||||
|             return 0; | ||||
|         for (i = 0; i < asdl_seq_LEN(stmt->v.With.items); i++) { | ||||
|             withitem_ty item = asdl_seq_GET(stmt->v.With.items, i); | ||||
|             if (!validate_expr(item->context_expr, Load) || | ||||
|                 (item->optional_vars && !validate_expr(item->optional_vars, Store))) | ||||
|             if (!validate_expr(state, item->context_expr, Load) || | ||||
|                 (item->optional_vars && !validate_expr(state, item->optional_vars, Store))) | ||||
|                 return 0; | ||||
|         } | ||||
|         return validate_body(stmt->v.With.body, "With"); | ||||
|         ret = validate_body(state, stmt->v.With.body, "With"); | ||||
|         break; | ||||
|     case AsyncWith_kind: | ||||
|         if (!validate_nonempty_seq(stmt->v.AsyncWith.items, "items", "AsyncWith")) | ||||
|             return 0; | ||||
|         for (i = 0; i < asdl_seq_LEN(stmt->v.AsyncWith.items); i++) { | ||||
|             withitem_ty item = asdl_seq_GET(stmt->v.AsyncWith.items, i); | ||||
|             if (!validate_expr(item->context_expr, Load) || | ||||
|                 (item->optional_vars && !validate_expr(item->optional_vars, Store))) | ||||
|             if (!validate_expr(state, item->context_expr, Load) || | ||||
|                 (item->optional_vars && !validate_expr(state, item->optional_vars, Store))) | ||||
|                 return 0; | ||||
|         } | ||||
|         return validate_body(stmt->v.AsyncWith.body, "AsyncWith"); | ||||
|         ret = validate_body(state, stmt->v.AsyncWith.body, "AsyncWith"); | ||||
|         break; | ||||
|     case Match_kind: | ||||
|         if (!validate_expr(stmt->v.Match.subject, Load) | ||||
|         if (!validate_expr(state, stmt->v.Match.subject, Load) | ||||
|             || !validate_nonempty_seq(stmt->v.Match.cases, "cases", "Match")) { | ||||
|             return 0; | ||||
|         } | ||||
|         for (i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) { | ||||
|             match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i); | ||||
|             if (!validate_pattern(m->pattern) | ||||
|                 || (m->guard && !validate_expr(m->guard, Load)) | ||||
|                 || !validate_body(m->body, "match_case")) { | ||||
|                 || (m->guard && !validate_expr(state, m->guard, Load)) | ||||
|                 || !validate_body(state, m->body, "match_case")) { | ||||
|                 return 0; | ||||
|             } | ||||
|         } | ||||
|         return 1; | ||||
|         ret = 1; | ||||
|         break; | ||||
|     case Raise_kind: | ||||
|         if (stmt->v.Raise.exc) { | ||||
|             return validate_expr(stmt->v.Raise.exc, Load) && | ||||
|                 (!stmt->v.Raise.cause || validate_expr(stmt->v.Raise.cause, Load)); | ||||
|             ret = validate_expr(state, stmt->v.Raise.exc, Load) && | ||||
|                 (!stmt->v.Raise.cause || validate_expr(state, stmt->v.Raise.cause, Load)); | ||||
|             break; | ||||
|         } | ||||
|         if (stmt->v.Raise.cause) { | ||||
|             PyErr_SetString(PyExc_ValueError, "Raise with cause but no exception"); | ||||
|             return 0; | ||||
|         } | ||||
|         return 1; | ||||
|         ret = 1; | ||||
|         break; | ||||
|     case Try_kind: | ||||
|         if (!validate_body(stmt->v.Try.body, "Try")) | ||||
|         if (!validate_body(state, stmt->v.Try.body, "Try")) | ||||
|             return 0; | ||||
|         if (!asdl_seq_LEN(stmt->v.Try.handlers) && | ||||
|             !asdl_seq_LEN(stmt->v.Try.finalbody)) { | ||||
|  | @ -469,55 +538,66 @@ validate_stmt(stmt_ty stmt) | |||
|         for (i = 0; i < asdl_seq_LEN(stmt->v.Try.handlers); i++) { | ||||
|             excepthandler_ty handler = asdl_seq_GET(stmt->v.Try.handlers, i); | ||||
|             if ((handler->v.ExceptHandler.type && | ||||
|                  !validate_expr(handler->v.ExceptHandler.type, Load)) || | ||||
|                 !validate_body(handler->v.ExceptHandler.body, "ExceptHandler")) | ||||
|                  !validate_expr(state, handler->v.ExceptHandler.type, Load)) || | ||||
|                 !validate_body(state, handler->v.ExceptHandler.body, "ExceptHandler")) | ||||
|                 return 0; | ||||
|         } | ||||
|         return (!asdl_seq_LEN(stmt->v.Try.finalbody) || | ||||
|                 validate_stmts(stmt->v.Try.finalbody)) && | ||||
|         ret = (!asdl_seq_LEN(stmt->v.Try.finalbody) || | ||||
|                 validate_stmts(state, stmt->v.Try.finalbody)) && | ||||
|             (!asdl_seq_LEN(stmt->v.Try.orelse) || | ||||
|              validate_stmts(stmt->v.Try.orelse)); | ||||
|              validate_stmts(state, stmt->v.Try.orelse)); | ||||
|         break; | ||||
|     case Assert_kind: | ||||
|         return validate_expr(stmt->v.Assert.test, Load) && | ||||
|             (!stmt->v.Assert.msg || validate_expr(stmt->v.Assert.msg, Load)); | ||||
|         ret = validate_expr(state, stmt->v.Assert.test, Load) && | ||||
|             (!stmt->v.Assert.msg || validate_expr(state, stmt->v.Assert.msg, Load)); | ||||
|         break; | ||||
|     case Import_kind: | ||||
|         return validate_nonempty_seq(stmt->v.Import.names, "names", "Import"); | ||||
|         ret = validate_nonempty_seq(stmt->v.Import.names, "names", "Import"); | ||||
|         break; | ||||
|     case ImportFrom_kind: | ||||
|         if (stmt->v.ImportFrom.level < 0) { | ||||
|             PyErr_SetString(PyExc_ValueError, "Negative ImportFrom level"); | ||||
|             return 0; | ||||
|         } | ||||
|         return validate_nonempty_seq(stmt->v.ImportFrom.names, "names", "ImportFrom"); | ||||
|         ret = validate_nonempty_seq(stmt->v.ImportFrom.names, "names", "ImportFrom"); | ||||
|         break; | ||||
|     case Global_kind: | ||||
|         return validate_nonempty_seq(stmt->v.Global.names, "names", "Global"); | ||||
|         ret = validate_nonempty_seq(stmt->v.Global.names, "names", "Global"); | ||||
|         break; | ||||
|     case Nonlocal_kind: | ||||
|         return validate_nonempty_seq(stmt->v.Nonlocal.names, "names", "Nonlocal"); | ||||
|         ret = validate_nonempty_seq(stmt->v.Nonlocal.names, "names", "Nonlocal"); | ||||
|         break; | ||||
|     case Expr_kind: | ||||
|         return validate_expr(stmt->v.Expr.value, Load); | ||||
|         ret = validate_expr(state, stmt->v.Expr.value, Load); | ||||
|         break; | ||||
|     case AsyncFunctionDef_kind: | ||||
|         return validate_body(stmt->v.AsyncFunctionDef.body, "AsyncFunctionDef") && | ||||
|             validate_arguments(stmt->v.AsyncFunctionDef.args) && | ||||
|             validate_exprs(stmt->v.AsyncFunctionDef.decorator_list, Load, 0) && | ||||
|         ret = validate_body(state, stmt->v.AsyncFunctionDef.body, "AsyncFunctionDef") && | ||||
|             validate_arguments(state, stmt->v.AsyncFunctionDef.args) && | ||||
|             validate_exprs(state, stmt->v.AsyncFunctionDef.decorator_list, Load, 0) && | ||||
|             (!stmt->v.AsyncFunctionDef.returns || | ||||
|              validate_expr(stmt->v.AsyncFunctionDef.returns, Load)); | ||||
|              validate_expr(state, stmt->v.AsyncFunctionDef.returns, Load)); | ||||
|         break; | ||||
|     case Pass_kind: | ||||
|     case Break_kind: | ||||
|     case Continue_kind: | ||||
|         return 1; | ||||
|         ret = 1; | ||||
|         break; | ||||
|     default: | ||||
|         PyErr_SetString(PyExc_SystemError, "unexpected statement"); | ||||
|         return 0; | ||||
|     } | ||||
|     state->recursion_depth--; | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| static int | ||||
| validate_stmts(asdl_stmt_seq *seq) | ||||
| validate_stmts(struct validator *state, asdl_stmt_seq *seq) | ||||
| { | ||||
|     Py_ssize_t i; | ||||
|     for (i = 0; i < asdl_seq_LEN(seq); i++) { | ||||
|         stmt_ty stmt = asdl_seq_GET(seq, i); | ||||
|         if (stmt) { | ||||
|             if (!validate_stmt(stmt)) | ||||
|             if (!validate_stmt(state, stmt)) | ||||
|                 return 0; | ||||
|         } | ||||
|         else { | ||||
|  | @ -530,13 +610,13 @@ validate_stmts(asdl_stmt_seq *seq) | |||
| } | ||||
| 
 | ||||
| static int | ||||
| validate_exprs(asdl_expr_seq *exprs, expr_context_ty ctx, int null_ok) | ||||
| validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ctx, int null_ok) | ||||
| { | ||||
|     Py_ssize_t i; | ||||
|     for (i = 0; i < asdl_seq_LEN(exprs); i++) { | ||||
|         expr_ty expr = asdl_seq_GET(exprs, i); | ||||
|         if (expr) { | ||||
|             if (!validate_expr(expr, ctx)) | ||||
|             if (!validate_expr(state, expr, ctx)) | ||||
|                 return 0; | ||||
|         } | ||||
|         else if (!null_ok) { | ||||
|  | @ -549,26 +629,53 @@ validate_exprs(asdl_expr_seq *exprs, expr_context_ty ctx, int null_ok) | |||
|     return 1; | ||||
| } | ||||
| 
 | ||||
| /* See comments in symtable.c. */ | ||||
| #define COMPILER_STACK_FRAME_SCALE 3 | ||||
| 
 | ||||
/* Validate a module-level AST node.
 *
 * Prepares a `struct validator` whose recursion_depth starts at the current
 * thread state's recursion_depth and whose recursion_limit comes from
 * Py_GetRecursionLimit(); both are multiplied by COMPILER_STACK_FRAME_SCALE
 * when that multiplication cannot exceed INT_MAX.  It then dispatches on
 * mod->kind:
 *   - Module_kind / Interactive_kind: validate_stmts() on the body;
 *   - Expression_kind: validate_expr() on the body with Load context;
 *   - anything else: sets SystemError("impossible module node").
 *
 * Returns the result of the dispatched validator, or 0 on failure.  On
 * success it additionally verifies that the depth counter returned to its
 * starting value and raises SystemError if it did not.
 *
 * NOTE(review): this text is a rendered diff, so each call in the switch
 * appears twice -- the variant without a state argument looks like the
 * pre-change line and the variant taking `&state` the post-change line.
 */
| int | ||||
| _PyAST_Validate(mod_ty mod) | ||||
| { | ||||
|     int res = 0; | ||||
|     struct validator state; | ||||
|     PyThreadState *tstate; | ||||
|     int recursion_limit = Py_GetRecursionLimit(); | ||||
|     int starting_recursion_depth; | ||||
| 
 | ||||
|     /* Set up the recursion depth counters carried in the validator state. */ | ||||
|     tstate = _PyThreadState_GET(); | ||||
      /* NOTE(review): this early exit returns 0 without setting an exception
         in this function; confirm a NULL thread state cannot occur here. */
|     if (!tstate) { | ||||
|         return 0; | ||||
|     } | ||||
|     /* Scale only when the product fits in an int; otherwise keep the unscaled value to avoid overflow. */ | ||||
|     starting_recursion_depth = (tstate->recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ? | ||||
|         tstate->recursion_depth * COMPILER_STACK_FRAME_SCALE : tstate->recursion_depth; | ||||
|     state.recursion_depth = starting_recursion_depth; | ||||
|     state.recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ? | ||||
|         recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit; | ||||
| 
 | ||||
      /* Dispatch on the kind of top-level node. */
|     switch (mod->kind) { | ||||
|     case Module_kind: | ||||
|         res = validate_stmts(mod->v.Module.body); | ||||
|         res = validate_stmts(&state, mod->v.Module.body); | ||||
|         break; | ||||
|     case Interactive_kind: | ||||
|         res = validate_stmts(mod->v.Interactive.body); | ||||
|         res = validate_stmts(&state, mod->v.Interactive.body); | ||||
|         break; | ||||
|     case Expression_kind: | ||||
|         res = validate_expr(mod->v.Expression.body, Load); | ||||
|         res = validate_expr(&state, mod->v.Expression.body, Load); | ||||
|         break; | ||||
|     default: | ||||
|         PyErr_SetString(PyExc_SystemError, "impossible module node"); | ||||
|         res = 0; | ||||
|         break; | ||||
|     } | ||||
| 
 | ||||
|     /* Only checked on success: verify the depth counter is back at its starting value. */ | ||||
|     if (res && state.recursion_depth != starting_recursion_depth) { | ||||
|         PyErr_Format(PyExc_SystemError, | ||||
|             "AST validator recursion depth mismatch (before=%d, after=%d)", | ||||
|             starting_recursion_depth, state.recursion_depth); | ||||
|         return 0; | ||||
|     } | ||||
|     return res; | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -2,6 +2,7 @@ | |||
| #include "Python.h" | ||||
| #include "pycore_ast.h"           // _PyAST_GetDocString() | ||||
| #include "pycore_compile.h"       // _PyASTOptimizeState | ||||
| #include "pycore_pystate.h"       // _PyThreadState_GET() | ||||
| 
 | ||||
| 
 | ||||
| static int | ||||
|  | @ -488,6 +489,11 @@ astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) | |||
| static int | ||||
| astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) | ||||
| { | ||||
|     if (++state->recursion_depth > state->recursion_limit) { | ||||
|         PyErr_SetString(PyExc_RecursionError, | ||||
|                         "maximum recursion depth exceeded during compilation"); | ||||
|         return 0; | ||||
|     } | ||||
|     switch (node_->kind) { | ||||
|     case BoolOp_kind: | ||||
|         CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values); | ||||
|  | @ -586,6 +592,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) | |||
|     case Name_kind: | ||||
|         if (node_->v.Name.ctx == Load && | ||||
|                 _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) { | ||||
|             state->recursion_depth--; | ||||
|             return make_const(node_, PyBool_FromLong(!state->optimize), ctx_); | ||||
|         } | ||||
|         break; | ||||
|  | @ -602,6 +609,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) | |||
|     // No default case, so the compiler will emit a warning if new expression
 | ||||
|     // kinds are added without being handled here
 | ||||
|     } | ||||
|     state->recursion_depth--; | ||||
|     return 1; | ||||
| } | ||||
| 
 | ||||
|  | @ -648,6 +656,11 @@ astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) | |||
| static int | ||||
| astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) | ||||
| { | ||||
|     if (++state->recursion_depth > state->recursion_limit) { | ||||
|         PyErr_SetString(PyExc_RecursionError, | ||||
|                         "maximum recursion depth exceeded during compilation"); | ||||
|         return 0; | ||||
|     } | ||||
|     switch (node_->kind) { | ||||
|     case FunctionDef_kind: | ||||
|         CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args); | ||||
|  | @ -757,6 +770,7 @@ astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) | |||
|     // No default case, so the compiler will emit a warning if new statement
 | ||||
|     // kinds are added without being handled here
 | ||||
|     } | ||||
|     state->recursion_depth--; | ||||
|     return 1; | ||||
| } | ||||
| 
 | ||||
|  | @ -906,10 +920,38 @@ astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *stat | |||
| #undef CALL_SEQ | ||||
| #undef CALL_INT_SEQ | ||||
| 
 | ||||
| /* See comments in symtable.c. */ | ||||
| #define COMPILER_STACK_FRAME_SCALE 3 | ||||
| 
 | ||||
/* Run the AST folding pass (astfold_mod) over `mod`.
 *
 * Before the pass, this function overwrites state->recursion_depth and
 * state->recursion_limit in the caller-supplied state: the depth starts at
 * the current thread state's recursion_depth and the limit comes from
 * Py_GetRecursionLimit(), each multiplied by COMPILER_STACK_FRAME_SCALE when
 * that multiplication cannot exceed INT_MAX.
 *
 * Returns the result of astfold_mod().  When that result is nonzero, the
 * depth counter must be back at its starting value; if it is not, a
 * SystemError is raised and 0 is returned.
 */
| int | ||||
| _PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state) | ||||
| { | ||||
|     PyThreadState *tstate; | ||||
|     int recursion_limit = Py_GetRecursionLimit(); | ||||
|     int starting_recursion_depth; | ||||
| 
 | ||||
|     /* Set up the recursion depth counters stored in the optimizer state. */ | ||||
|     tstate = _PyThreadState_GET(); | ||||
      /* NOTE(review): this early exit returns 0 without setting an exception
         in this function and bypasses the assert below; confirm a NULL
         thread state cannot occur here. */
|     if (!tstate) { | ||||
|         return 0; | ||||
|     } | ||||
|     /* Scale only when the product fits in an int; otherwise keep the unscaled value to avoid overflow. */ | ||||
|     starting_recursion_depth = (tstate->recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ? | ||||
|         tstate->recursion_depth * COMPILER_STACK_FRAME_SCALE : tstate->recursion_depth; | ||||
|     state->recursion_depth = starting_recursion_depth; | ||||
|     state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ? | ||||
|         recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit; | ||||
| 
 | ||||
|     int ret = astfold_mod(mod, arena, state); | ||||
      /* A zero result from the folding pass must come with an exception set. */
|     assert(ret || PyErr_Occurred()); | ||||
| 
 | ||||
|     /* Only checked on success: verify the depth counter is back at its starting value. */ | ||||
|     if (ret && state->recursion_depth != starting_recursion_depth) { | ||||
|         PyErr_Format(PyExc_SystemError, | ||||
|             "AST optimizer recursion depth mismatch (before=%d, after=%d)", | ||||
|             starting_recursion_depth, state->recursion_depth); | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|     return ret; | ||||
| } | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Serhiy Storchaka
						Serhiy Storchaka