mirror of
				https://github.com/python/cpython.git
				synced 2025-10-25 10:44:55 +00:00 
			
		
		
		
	bpo-43892: Make match patterns explicit in the AST (GH-25585)
Co-authored-by: Brandt Bucher <brandtbucher@gmail.com>
This commit is contained in:
		
							parent
							
								
									e52ab42ced
								
							
						
					
					
						commit
						1e7b858575
					
				
					 20 changed files with 3460 additions and 1377 deletions
				
			
		
							
								
								
									
										238
									
								
								Python/ast.c
									
										
									
									
									
								
							
							
						
						
									
										238
									
								
								Python/ast.c
									
										
									
									
									
								
							|  | @ -7,6 +7,7 @@ | |||
| #include "pycore_pystate.h"       // _PyThreadState_GET() | ||||
| 
 | ||||
| #include <assert.h> | ||||
| #include <stdbool.h> | ||||
| 
 | ||||
| struct validator { | ||||
|     int recursion_depth;            /* current recursion depth */ | ||||
|  | @ -18,6 +19,7 @@ static int validate_exprs(struct validator *, asdl_expr_seq*, expr_context_ty, i | |||
| static int _validate_nonempty_seq(asdl_seq *, const char *, const char *); | ||||
| static int validate_stmt(struct validator *, stmt_ty); | ||||
| static int validate_expr(struct validator *, expr_ty, expr_context_ty); | ||||
| static int validate_pattern(struct validator *, pattern_ty); | ||||
| 
 | ||||
| static int | ||||
| validate_name(PyObject *name) | ||||
|  | @ -88,9 +90,9 @@ expr_context_name(expr_context_ty ctx) | |||
|         return "Store"; | ||||
|     case Del: | ||||
|         return "Del"; | ||||
|     default: | ||||
|         Py_UNREACHABLE(); | ||||
|     // No default case so compiler emits warning for unhandled cases
 | ||||
|     } | ||||
|     Py_UNREACHABLE(); | ||||
| } | ||||
| 
 | ||||
| static int | ||||
|  | @ -180,7 +182,7 @@ validate_constant(struct validator *state, PyObject *value) | |||
| static int | ||||
| validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx) | ||||
| { | ||||
|     int ret; | ||||
|     int ret = -1; | ||||
|     if (++state->recursion_depth > state->recursion_limit) { | ||||
|         PyErr_SetString(PyExc_RecursionError, | ||||
|                         "maximum recursion depth exceeded during compilation"); | ||||
|  | @ -351,33 +353,215 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx) | |||
|     case NamedExpr_kind: | ||||
|         ret = validate_expr(state, exp->v.NamedExpr.value, Load); | ||||
|         break; | ||||
|     case MatchAs_kind: | ||||
|         PyErr_SetString(PyExc_ValueError, | ||||
|                         "MatchAs is only valid in match_case patterns"); | ||||
|         return 0; | ||||
|     case MatchOr_kind: | ||||
|         PyErr_SetString(PyExc_ValueError, | ||||
|                         "MatchOr is only valid in match_case patterns"); | ||||
|         return 0; | ||||
|     /* This last case doesn't have any checking. */ | ||||
|     case Name_kind: | ||||
|         ret = 1; | ||||
|         break; | ||||
|     default: | ||||
|     // No default case so compiler emits warning for unhandled cases
 | ||||
|     } | ||||
|     if (ret < 0) { | ||||
|         PyErr_SetString(PyExc_SystemError, "unexpected expression"); | ||||
|         return 0; | ||||
|         ret = 0; | ||||
|     } | ||||
|     state->recursion_depth--; | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| // Note: the ensure_literal_* functions are only used to validate a restricted
 | ||||
| //       set of non-recursive literals that have already been checked with
 | ||||
| //       validate_expr, so they don't accept the validator state
 | ||||
| static int | ||||
| validate_pattern(expr_ty p) | ||||
| ensure_literal_number(expr_ty exp, bool allow_real, bool allow_imaginary) | ||||
| { | ||||
|     // Coming soon (thanks Batuhan)!
 | ||||
|     assert(exp->kind == Constant_kind); | ||||
|     PyObject *value = exp->v.Constant.value; | ||||
|     return (allow_real && PyFloat_CheckExact(value)) || | ||||
|            (allow_real && PyLong_CheckExact(value)) || | ||||
|            (allow_imaginary && PyComplex_CheckExact(value)); | ||||
| } | ||||
| 
 | ||||
| static int | ||||
| ensure_literal_negative(expr_ty exp, bool allow_real, bool allow_imaginary) | ||||
| { | ||||
|     assert(exp->kind == UnaryOp_kind); | ||||
|     // Must be negation ...
 | ||||
|     if (exp->v.UnaryOp.op != USub) { | ||||
|         return 0; | ||||
|     } | ||||
|     // ... of a constant ...
 | ||||
|     expr_ty operand = exp->v.UnaryOp.operand; | ||||
|     if (operand->kind != Constant_kind) { | ||||
|         return 0; | ||||
|     } | ||||
|     // ... number
 | ||||
|     return ensure_literal_number(operand, allow_real, allow_imaginary); | ||||
| } | ||||
| 
 | ||||
| static int | ||||
| ensure_literal_complex(expr_ty exp) | ||||
| { | ||||
|     assert(exp->kind == BinOp_kind); | ||||
|     expr_ty left = exp->v.BinOp.left; | ||||
|     expr_ty right = exp->v.BinOp.right; | ||||
|     // Ensure op is addition or subtraction
 | ||||
|     if (exp->v.BinOp.op != Add && exp->v.BinOp.op != Sub) { | ||||
|         return 0; | ||||
|     } | ||||
|     // Check LHS is a real number (potentially signed)
 | ||||
|     switch (left->kind) | ||||
|     { | ||||
|         case Constant_kind: | ||||
|             if (!ensure_literal_number(left, /*real=*/true, /*imaginary=*/false)) { | ||||
|                 return 0; | ||||
|             } | ||||
|             break; | ||||
|         case UnaryOp_kind: | ||||
|             if (!ensure_literal_negative(left, /*real=*/true, /*imaginary=*/false)) { | ||||
|                 return 0; | ||||
|             } | ||||
|             break; | ||||
|         default: | ||||
|             return 0; | ||||
|     } | ||||
|     // Check RHS is an imaginary number (no separate sign allowed)
 | ||||
|     switch (right->kind) | ||||
|     { | ||||
|         case Constant_kind: | ||||
|             if (!ensure_literal_number(right, /*real=*/false, /*imaginary=*/true)) { | ||||
|                 return 0; | ||||
|             } | ||||
|             break; | ||||
|         default: | ||||
|             return 0; | ||||
|     } | ||||
|     return 1; | ||||
| } | ||||
| 
 | ||||
| static int | ||||
| validate_pattern_match_value(struct validator *state, expr_ty exp) | ||||
| { | ||||
|     if (!validate_expr(state, exp, Load)) { | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|     switch (exp->kind) | ||||
|     { | ||||
|         case Constant_kind: | ||||
|         case Attribute_kind: | ||||
|             // Constants and attribute lookups are always permitted
 | ||||
|             return 1; | ||||
|         case UnaryOp_kind: | ||||
|             // Negated numbers are permitted (whether real or imaginary)
 | ||||
|             // Compiler will complain if AST folding doesn't create a constant
 | ||||
|             if (ensure_literal_negative(exp, /*real=*/true, /*imaginary=*/true)) { | ||||
|                 return 1; | ||||
|             } | ||||
|             break; | ||||
|         case BinOp_kind: | ||||
|             // Complex literals are permitted
 | ||||
|             // Compiler will complain if AST folding doesn't create a constant
 | ||||
|             if (ensure_literal_complex(exp)) { | ||||
|                 return 1; | ||||
|             } | ||||
|             break; | ||||
|         default: | ||||
|             break; | ||||
|     } | ||||
|     PyErr_SetString(PyExc_SyntaxError, | ||||
|         "patterns may only match literals and attribute lookups"); | ||||
|     return 0; | ||||
| } | ||||
| 
 | ||||
| static int | ||||
| validate_pattern(struct validator *state, pattern_ty p) | ||||
| { | ||||
|     int ret = -1; | ||||
|     if (++state->recursion_depth > state->recursion_limit) { | ||||
|         PyErr_SetString(PyExc_RecursionError, | ||||
|                         "maximum recursion depth exceeded during compilation"); | ||||
|         return 0; | ||||
|     } | ||||
|     // Coming soon: https://bugs.python.org/issue43897 (thanks Batuhan)!
 | ||||
|     // TODO: Ensure no subnodes use "_" as an ordinary identifier
 | ||||
|     switch (p->kind) { | ||||
|         case MatchValue_kind: | ||||
|             ret = validate_pattern_match_value(state, p->v.MatchValue.value); | ||||
|             break; | ||||
|         case MatchSingleton_kind: | ||||
|             // TODO: Check constant is specifically None, True, or False
 | ||||
|             ret = validate_constant(state, p->v.MatchSingleton.value); | ||||
|             break; | ||||
|         case MatchSequence_kind: | ||||
|             // TODO: Validate all subpatterns
 | ||||
|             // return validate_patterns(state, p->v.MatchSequence.patterns);
 | ||||
|             ret = 1; | ||||
|             break; | ||||
|         case MatchMapping_kind: | ||||
|             // TODO: check "rest" target name is valid
 | ||||
|             if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) { | ||||
|                 PyErr_SetString(PyExc_ValueError, | ||||
|                                 "MatchMapping doesn't have the same number of keys as patterns"); | ||||
|                 return 0; | ||||
|             } | ||||
|             // null_ok=0 for key expressions, as rest-of-mapping is captured in "rest"
 | ||||
|             // TODO: replace with more restrictive expression validator, as per MatchValue above
 | ||||
|             if (!validate_exprs(state, p->v.MatchMapping.keys, Load, /*null_ok=*/ 0)) { | ||||
|                 return 0; | ||||
|             } | ||||
|             // TODO: Validate all subpatterns
 | ||||
|             // ret = validate_patterns(state, p->v.MatchMapping.patterns);
 | ||||
|             ret = 1; | ||||
|             break; | ||||
|         case MatchClass_kind: | ||||
|             if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) { | ||||
|                 PyErr_SetString(PyExc_ValueError, | ||||
|                                 "MatchClass doesn't have the same number of keyword attributes as patterns"); | ||||
|                 return 0; | ||||
|             } | ||||
|             // TODO: Restrict cls lookup to being a name or attribute
 | ||||
|             if (!validate_expr(state, p->v.MatchClass.cls, Load)) { | ||||
|                 return 0; | ||||
|             } | ||||
|             // TODO: Validate all subpatterns
 | ||||
|             // return validate_patterns(state, p->v.MatchClass.patterns) &&
 | ||||
|             //        validate_patterns(state, p->v.MatchClass.kwd_patterns);
 | ||||
|             ret = 1; | ||||
|             break; | ||||
|         case MatchStar_kind: | ||||
|             // TODO: check target name is valid
 | ||||
|             ret = 1; | ||||
|             break; | ||||
|         case MatchAs_kind: | ||||
|             // TODO: check target name is valid
 | ||||
|             if (p->v.MatchAs.pattern == NULL) { | ||||
|                 ret = 1; | ||||
|             } | ||||
|             else if (p->v.MatchAs.name == NULL) { | ||||
|                 PyErr_SetString(PyExc_ValueError, | ||||
|                                 "MatchAs must specify a target name if a pattern is given"); | ||||
|                 return 0; | ||||
|             } | ||||
|             else { | ||||
|                 ret = validate_pattern(state, p->v.MatchAs.pattern); | ||||
|             } | ||||
|             break; | ||||
|         case MatchOr_kind: | ||||
|             // TODO: Validate all subpatterns
 | ||||
|             // return validate_patterns(state, p->v.MatchOr.patterns);
 | ||||
|             ret = 1; | ||||
|             break; | ||||
|     // No default case, so the compiler will emit a warning if new pattern
 | ||||
|     // kinds are added without being handled here
 | ||||
|     } | ||||
|     if (ret < 0) { | ||||
|         PyErr_SetString(PyExc_SystemError, "unexpected pattern"); | ||||
|         ret = 0; | ||||
|     } | ||||
|     state->recursion_depth--; | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| static int | ||||
| _validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner) | ||||
| { | ||||
|  | @ -404,7 +588,7 @@ validate_body(struct validator *state, asdl_stmt_seq *body, const char *owner) | |||
| static int | ||||
| validate_stmt(struct validator *state, stmt_ty stmt) | ||||
| { | ||||
|     int ret; | ||||
|     int ret = -1; | ||||
|     Py_ssize_t i; | ||||
|     if (++state->recursion_depth > state->recursion_limit) { | ||||
|         PyErr_SetString(PyExc_RecursionError, | ||||
|  | @ -502,7 +686,7 @@ validate_stmt(struct validator *state, stmt_ty stmt) | |||
|         } | ||||
|         for (i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) { | ||||
|             match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i); | ||||
|             if (!validate_pattern(m->pattern) | ||||
|             if (!validate_pattern(state, m->pattern) | ||||
|                 || (m->guard && !validate_expr(state, m->guard, Load)) | ||||
|                 || !validate_body(state, m->body, "match_case")) { | ||||
|                 return 0; | ||||
|  | @ -582,9 +766,11 @@ validate_stmt(struct validator *state, stmt_ty stmt) | |||
|     case Continue_kind: | ||||
|         ret = 1; | ||||
|         break; | ||||
|     default: | ||||
|     // No default case so compiler emits warning for unhandled cases
 | ||||
|     } | ||||
|     if (ret < 0) { | ||||
|         PyErr_SetString(PyExc_SystemError, "unexpected statement"); | ||||
|         return 0; | ||||
|         ret = 0; | ||||
|     } | ||||
|     state->recursion_depth--; | ||||
|     return ret; | ||||
|  | @ -635,7 +821,7 @@ validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ct | |||
| int | ||||
| _PyAST_Validate(mod_ty mod) | ||||
| { | ||||
|     int res = 0; | ||||
|     int res = -1; | ||||
|     struct validator state; | ||||
|     PyThreadState *tstate; | ||||
|     int recursion_limit = Py_GetRecursionLimit(); | ||||
|  | @ -663,10 +849,16 @@ _PyAST_Validate(mod_ty mod) | |||
|     case Expression_kind: | ||||
|         res = validate_expr(&state, mod->v.Expression.body, Load); | ||||
|         break; | ||||
|     default: | ||||
|         PyErr_SetString(PyExc_SystemError, "impossible module node"); | ||||
|         res = 0; | ||||
|     case FunctionType_kind: | ||||
|         res = validate_exprs(&state, mod->v.FunctionType.argtypes, Load, /*null_ok=*/0) && | ||||
|               validate_expr(&state, mod->v.FunctionType.returns, Load); | ||||
|         break; | ||||
|     // No default case so compiler emits warning for unhandled cases
 | ||||
|     } | ||||
| 
 | ||||
|     if (res < 0) { | ||||
|         PyErr_SetString(PyExc_SystemError, "impossible module node"); | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|     /* Check that the recursion depth counting balanced correctly */ | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Nick Coghlan
						Nick Coghlan