[Python-checkins] bpo-43897: AST validation for pattern matching nodes (GH24771)

brandtbucher webhook-mailer at python.org
Wed Jul 28 13:14:58 EDT 2021


https://github.com/python/cpython/commit/31bec6f1b178dadec3cb43353274b4e958a8f015
commit: 31bec6f1b178dadec3cb43353274b4e958a8f015
branch: main
author: Batuhan Taskaya <batuhan at python.org>
committer: brandtbucher <brandtbucher at gmail.com>
date: 2021-07-28T10:14:45-07:00
summary:

bpo-43897: AST validation for pattern matching nodes (GH24771)

files:
M Lib/test/test_ast.py
M Python/ast.c

diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index 5f1ee75c8bddc..925bb883d63e4 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -696,7 +696,7 @@ def test_constant_as_name(self):
         for constant in "True", "False", "None":
             expr = ast.Expression(ast.Name(constant, ast.Load()))
             ast.fix_missing_locations(expr)
-            with self.assertRaisesRegex(ValueError, f"Name node can't be used with '{constant}' constant"):
+            with self.assertRaisesRegex(ValueError, f"identifier field can't represent '{constant}' constant"):
                 compile(expr, "<test>", "eval")
 
     def test_precedence_enum(self):
@@ -1507,6 +1507,147 @@ def test_stdlib_validates(self):
                 mod = ast.parse(source, fn)
                 compile(mod, fn, "exec")
 
+    constant_1 = ast.Constant(1)
+    pattern_1 = ast.MatchValue(constant_1)
+
+    constant_x = ast.Constant('x')
+    pattern_x = ast.MatchValue(constant_x)
+
+    constant_true = ast.Constant(True)
+    pattern_true = ast.MatchSingleton(True)
+
+    name_carter = ast.Name('carter', ast.Load())
+
+    _MATCH_PATTERNS = [
+        ast.MatchValue(
+            ast.Attribute(
+                ast.Attribute(
+                    ast.Name('x', ast.Store()),
+                    'y', ast.Load()
+                ),
+                'z', ast.Load()
+            )
+        ),
+        ast.MatchValue(
+            ast.Attribute(
+                ast.Attribute(
+                    ast.Name('x', ast.Load()),
+                    'y', ast.Store()
+                ),
+                'z', ast.Load()
+            )
+        ),
+        ast.MatchValue(
+            ast.Constant(...)
+        ),
+        ast.MatchValue(
+            ast.Constant(True)
+        ),
+        ast.MatchValue(
+            ast.Constant((1,2,3))
+        ),
+        ast.MatchSingleton('string'),
+        ast.MatchSequence([
+          ast.MatchSingleton('string')
+        ]),
+        ast.MatchSequence(
+            [
+                ast.MatchSequence(
+                    [
+                        ast.MatchSingleton('string')
+                    ]
+                )
+            ]
+        ),
+        ast.MatchMapping(
+            [constant_1, constant_true],
+            [pattern_x]
+        ),
+        ast.MatchMapping(
+            [constant_true, constant_1],
+            [pattern_x, pattern_1],
+            rest='True'
+        ),
+        ast.MatchMapping(
+            [constant_true, ast.Starred(ast.Name('lol', ast.Load()), ast.Load())],
+            [pattern_x, pattern_1],
+            rest='legit'
+        ),
+        ast.MatchClass(
+            ast.Attribute(
+                ast.Attribute(
+                    constant_x,
+                    'y', ast.Load()),
+                'z', ast.Load()),
+            patterns=[], kwd_attrs=[], kwd_patterns=[]
+        ),
+        ast.MatchClass(
+            name_carter,
+            patterns=[],
+            kwd_attrs=['True'],
+            kwd_patterns=[pattern_1]
+        ),
+        ast.MatchClass(
+            name_carter,
+            patterns=[],
+            kwd_attrs=[],
+            kwd_patterns=[pattern_1]
+        ),
+        ast.MatchClass(
+            name_carter,
+            patterns=[ast.MatchSingleton('string')],
+            kwd_attrs=[],
+            kwd_patterns=[]
+        ),
+        ast.MatchClass(
+            name_carter,
+            patterns=[ast.MatchStar()],
+            kwd_attrs=[],
+            kwd_patterns=[]
+        ),
+        ast.MatchClass(
+            name_carter,
+            patterns=[],
+            kwd_attrs=[],
+            kwd_patterns=[ast.MatchStar()]
+        ),
+        ast.MatchSequence(
+            [
+                ast.MatchStar("True")
+            ]
+        ),
+        ast.MatchAs(
+            name='False'
+        ),
+        ast.MatchOr(
+            []
+        ),
+        ast.MatchOr(
+            [pattern_1]
+        ),
+        ast.MatchOr(
+            [pattern_1, pattern_x, ast.MatchSingleton('xxx')]
+        )
+    ]
+
+    def test_match_validation_pattern(self):
+        name_x = ast.Name('x', ast.Load())
+        for pattern in self._MATCH_PATTERNS:
+            with self.subTest(ast.dump(pattern, indent=4)):
+                node = ast.Match(
+                    subject=name_x,
+                    cases = [
+                        ast.match_case(
+                            pattern=pattern,
+                            body = [ast.Pass()]
+                        )
+                    ]
+                )
+                node = ast.fix_missing_locations(node)
+                module = ast.Module([node], [])
+                with self.assertRaises(ValueError):
+                    compile(module, "<test>", "exec")
+
 
 class ConstantTests(unittest.TestCase):
     """Tests on the ast.Constant node type."""
diff --git a/Python/ast.c b/Python/ast.c
index 1fc83f6301962..0a306c0b0579d 100644
--- a/Python/ast.c
+++ b/Python/ast.c
@@ -15,7 +15,8 @@ struct validator {
 };
 
 static int validate_stmts(struct validator *, asdl_stmt_seq *);
-static int validate_exprs(struct validator *, asdl_expr_seq*, expr_context_ty, int);
+static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int);
+static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
 static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
 static int validate_stmt(struct validator *, stmt_ty);
 static int validate_expr(struct validator *, expr_ty, expr_context_ty);
@@ -33,7 +34,7 @@ validate_name(PyObject *name)
     };
     for (int i = 0; forbidden[i] != NULL; i++) {
         if (_PyUnicode_EqualToASCIIString(name, forbidden[i])) {
-            PyErr_Format(PyExc_ValueError, "Name node can't be used with '%s' constant", forbidden[i]);
+            PyErr_Format(PyExc_ValueError, "identifier field can't represent '%s' constant", forbidden[i]);
             return 0;
         }
     }
@@ -448,6 +449,21 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
     switch (exp->kind)
     {
         case Constant_kind:
+            /* Only int, float, complex, bytes and str constants are
+               accepted; Ellipsis and tuples (among others) are rejected.
+               For True, False and None, MatchSingleton() should be used */
+            if (!validate_expr(state, exp, Load)) {
+                return 0;
+            }
+            PyObject *literal = exp->v.Constant.value;
+            if (PyLong_CheckExact(literal) || PyFloat_CheckExact(literal) ||
+                PyBytes_CheckExact(literal) || PyComplex_CheckExact(literal) ||
+                PyUnicode_CheckExact(literal)) {
+                return 1;
+            }
+            PyErr_SetString(PyExc_ValueError,
+                            "unexpected constant inside of a literal pattern");
+            return 0;
         case Attribute_kind:
             // Constants and attribute lookups are always permitted
             return 1;
@@ -465,11 +481,14 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
                 return 1;
             }
             break;
+        case JoinedStr_kind:
+            // Handled in the later stages
+            return 1;
         default:
             break;
     }
-    PyErr_SetString(PyExc_SyntaxError,
-        "patterns may only match literals and attribute lookups");
+    PyErr_SetString(PyExc_ValueError,
+                    "patterns may only match literals and attribute lookups");
     return 0;
 }
 
@@ -489,51 +508,101 @@ validate_pattern(struct validator *state, pattern_ty p)
             ret = validate_pattern_match_value(state, p->v.MatchValue.value);
             break;
         case MatchSingleton_kind:
-            // TODO: Check constant is specifically None, True, or False
-            ret = validate_constant(state, p->v.MatchSingleton.value);
+            ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value);
+            if (!ret) {
+                PyErr_SetString(PyExc_ValueError,
+                                "MatchSingleton can only contain True, False and None");
+            }
             break;
         case MatchSequence_kind:
-            // TODO: Validate all subpatterns
-            // return validate_patterns(state, p->v.MatchSequence.patterns);
-            ret = 1;
+            ret = validate_patterns(state, p->v.MatchSequence.patterns, /*star_ok=*/1);
             break;
         case MatchMapping_kind:
-            // TODO: check "rest" target name is valid
             if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) {
                 PyErr_SetString(PyExc_ValueError,
                                 "MatchMapping doesn't have the same number of keys as patterns");
-                return 0;
+                ret = 0;
+                break;
             }
-            // null_ok=0 for key expressions, as rest-of-mapping is captured in "rest"
-            // TODO: replace with more restrictive expression validator, as per MatchValue above
-            if (!validate_exprs(state, p->v.MatchMapping.keys, Load, /*null_ok=*/ 0)) {
-                return 0;
+
+            if (p->v.MatchMapping.rest && !validate_name(p->v.MatchMapping.rest)) {
+                ret = 0;
+                break;
             }
-            // TODO: Validate all subpatterns
-            // ret = validate_patterns(state, p->v.MatchMapping.patterns);
-            ret = 1;
+
+            asdl_expr_seq *keys = p->v.MatchMapping.keys;
+            for (Py_ssize_t i = 0; i < asdl_seq_LEN(keys); i++) {
+                expr_ty key = asdl_seq_GET(keys, i);
+                if (key->kind == Constant_kind) {
+                    PyObject *literal = key->v.Constant.value;
+                    if (literal == Py_None || PyBool_Check(literal)) {
+                        /* validate_pattern_match_value rejects True, False
+                           and None constants, but they are syntactically
+                           valid as mapping keys, so accept them here as a
+                           special case. */
+                        continue;
+                    }
+                }
+                if (!validate_pattern_match_value(state, key)) {
+                    ret = 0;
+                    break;
+                }
+            }
+
+            ret = validate_patterns(state, p->v.MatchMapping.patterns, /*star_ok=*/0);
             break;
         case MatchClass_kind:
             if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) {
                 PyErr_SetString(PyExc_ValueError,
                                 "MatchClass doesn't have the same number of keyword attributes as patterns");
-                return 0;
+                ret = 0;
+                break;
             }
-            // TODO: Restrict cls lookup to being a name or attribute
             if (!validate_expr(state, p->v.MatchClass.cls, Load)) {
-                return 0;
+                ret = 0;
+                break;
             }
-            // TODO: Validate all subpatterns
-            // return validate_patterns(state, p->v.MatchClass.patterns) &&
-            //        validate_patterns(state, p->v.MatchClass.kwd_patterns);
-            ret = 1;
+
+            expr_ty cls = p->v.MatchClass.cls;
+            while (1) {
+                if (cls->kind == Name_kind) {
+                    break;
+                }
+                else if (cls->kind == Attribute_kind) {
+                    cls = cls->v.Attribute.value;
+                    continue;
+                }
+                else {
+                    PyErr_SetString(PyExc_ValueError,
+                                    "MatchClass cls field can only contain Name or Attribute nodes.");
+                    state->recursion_depth--;
+                    return 0;
+                }
+            }
+
+            for (Py_ssize_t i = 0; i < asdl_seq_LEN(p->v.MatchClass.kwd_attrs); i++) {
+                PyObject *identifier = asdl_seq_GET(p->v.MatchClass.kwd_attrs, i);
+                if (!validate_name(identifier)) {
+                    state->recursion_depth--;
+                    return 0;
+                }
+            }
+
+            if (!validate_patterns(state, p->v.MatchClass.patterns, /*star_ok=*/0)) {
+                ret = 0;
+                break;
+            }
+
+            ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0);
             break;
         case MatchStar_kind:
-            // TODO: check target name is valid
-            ret = 1;
+            ret = p->v.MatchStar.name == NULL || validate_name(p->v.MatchStar.name);
             break;
         case MatchAs_kind:
-            // TODO: check target name is valid
+            if (p->v.MatchAs.name && !validate_name(p->v.MatchAs.name)) {
+                ret = 0;
+                break;
+            }
             if (p->v.MatchAs.pattern == NULL) {
                 ret = 1;
             }
@@ -547,9 +616,13 @@ validate_pattern(struct validator *state, pattern_ty p)
             }
             break;
         case MatchOr_kind:
-            // TODO: Validate all subpatterns
-            // return validate_patterns(state, p->v.MatchOr.patterns);
-            ret = 1;
+            if (asdl_seq_LEN(p->v.MatchOr.patterns) < 2) {
+                PyErr_SetString(PyExc_ValueError,
+                                "MatchOr requires at least 2 patterns");
+                ret = 0;
+                break;
+            }
+            ret = validate_patterns(state, p->v.MatchOr.patterns, /*star_ok=*/0);
             break;
     // No default case, so the compiler will emit a warning if new pattern
     // kinds are added without being handled here
@@ -815,6 +888,25 @@ validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ct
     return 1;
 }
 
+static int
+validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ok)
+{
+    Py_ssize_t i;
+    for (i = 0; i < asdl_seq_LEN(patterns); i++) {
+        pattern_ty pattern = asdl_seq_GET(patterns, i);
+        if (pattern->kind == MatchStar_kind && !star_ok) {
+            PyErr_SetString(PyExc_ValueError,
+                            "Can't use MatchStar within this sequence of patterns");
+            return 0;
+        }
+        if (!validate_pattern(state, pattern)) {
+            return 0;
+        }
+    }
+    return 1;
+}
+
+
 /* See comments in symtable.c. */
 #define COMPILER_STACK_FRAME_SCALE 3
 



More information about the Python-checkins mailing list