[Python-checkins] gh-95185: Check recursion depth in the AST constructor (#95186)

pablogsal webhook-mailer at python.org
Sun Jul 24 10:59:01 EDT 2022


https://github.com/python/cpython/commit/00474472944944b346d8409cfded84bb299f601a
commit: 00474472944944b346d8409cfded84bb299f601a
branch: main
author: Pablo Galindo Salgado <Pablogsal at gmail.com>
committer: pablogsal <Pablogsal at gmail.com>
date: 2022-07-24T15:58:52+01:00
summary:

gh-95185: Check recursion depth in the AST constructor (#95186)

Co-authored-by: Serhiy Storchaka <storchaka at gmail.com>

files:
A Misc/NEWS.d/next/Core and Builtins/2022-07-24-00-27-47.gh-issue-95185.ghYTZx.rst
M Include/internal/pycore_ast_state.h
M Lib/test/test_ast.py
M Parser/asdl_c.py
M Python/Python-ast.c

diff --git a/Include/internal/pycore_ast_state.h b/Include/internal/pycore_ast_state.h
index da78bba3b69bd..f15b4905eed14 100644
--- a/Include/internal/pycore_ast_state.h
+++ b/Include/internal/pycore_ast_state.h
@@ -12,6 +12,8 @@ extern "C" {
 
 struct ast_state {
     int initialized;
+    int recursion_depth;
+    int recursion_limit;
     PyObject *AST_type;
     PyObject *Add_singleton;
     PyObject *Add_type;
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index 480089aa8af44..9734218c21be3 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -793,6 +793,27 @@ def next(self):
                     return self
         enum._test_simple_enum(_Precedence, ast._Precedence)
 
+    @support.cpython_only
+    def test_ast_recursion_limit(self):
+        fail_depth = sys.getrecursionlimit() * 3
+        crash_depth = sys.getrecursionlimit() * 300
+        success_depth = int(fail_depth * 0.75)
+
+        def check_limit(prefix, repeated):
+            expect_ok = prefix + repeated * success_depth
+            ast.parse(expect_ok)
+            for depth in (fail_depth, crash_depth):
+                broken = prefix + repeated * depth
+                details = "Compiling ({!r} + {!r} * {})".format(
+                            prefix, repeated, depth)
+                with self.assertRaises(RecursionError, msg=details):
+                    ast.parse(broken)
+
+        check_limit("a", "()")
+        check_limit("a", ".b")
+        check_limit("a", "[0]")
+        check_limit("a", "*a")
+
 
 class ASTHelpers_Test(unittest.TestCase):
     maxDiff = None
diff --git a/Misc/NEWS.d/next/Core and Builtins/2022-07-24-00-27-47.gh-issue-95185.ghYTZx.rst b/Misc/NEWS.d/next/Core and Builtins/2022-07-24-00-27-47.gh-issue-95185.ghYTZx.rst
new file mode 100644
index 0000000000000..de156bab2f51f
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2022-07-24-00-27-47.gh-issue-95185.ghYTZx.rst	
@@ -0,0 +1,3 @@
+Prevented crashes in the AST constructor when compiling some absurdly long
+expressions like ``"+0"*1000000``. :exc:`RecursionError` is now raised
+instead. Patch by Pablo Galindo.
diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py
index bf391a3ae1653..13dd44ca0cdc3 100755
--- a/Parser/asdl_c.py
+++ b/Parser/asdl_c.py
@@ -1112,6 +1112,8 @@ def visitModule(self, mod):
         for dfn in mod.dfns:
             self.visit(dfn)
         self.file.write(textwrap.dedent('''
+                state->recursion_depth = 0;
+                state->recursion_limit = 0;
                 state->initialized = 1;
                 return 1;
             }
@@ -1259,8 +1261,14 @@ def func_begin(self, name):
         self.emit('if (!o) {', 1)
         self.emit("Py_RETURN_NONE;", 2)
         self.emit("}", 1)
+        self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1)
+        self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
+        self.emit('"maximum recursion depth exceeded during ast construction");', 3)
+        self.emit("return 0;", 2)
+        self.emit("}", 1)
 
     def func_end(self):
+        self.emit("state->recursion_depth--;", 1)
         self.emit("return result;", 1)
         self.emit("failed:", 0)
         self.emit("Py_XDECREF(value);", 1)
@@ -1371,7 +1379,32 @@ class PartingShots(StaticVisitor):
     if (state == NULL) {
         return NULL;
     }
-    return ast2obj_mod(state, t);
+
+    int recursion_limit = Py_GetRecursionLimit();
+    int starting_recursion_depth;
+    /* Be careful here to prevent overflow. */
+    int COMPILER_STACK_FRAME_SCALE = 3;
+    PyThreadState *tstate = _PyThreadState_GET();
+    if (!tstate) {
+        return 0;
+    }
+    state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
+        recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
+    int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
+    starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
+        recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
+    state->recursion_depth = starting_recursion_depth;
+
+    PyObject *result = ast2obj_mod(state, t);
+
+    /* Check that the recursion depth counting balanced correctly */
+    if (result && state->recursion_depth != starting_recursion_depth) {
+        PyErr_Format(PyExc_SystemError,
+            "AST constructor recursion depth mismatch (before=%d, after=%d)",
+            starting_recursion_depth, state->recursion_depth);
+        return 0;
+    }
+    return result;
 }
 
 /* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */
@@ -1437,6 +1470,8 @@ def visit(self, object):
 def generate_ast_state(module_state, f):
     f.write('struct ast_state {\n')
     f.write('    int initialized;\n')
+    f.write('    int recursion_depth;\n')
+    f.write('    int recursion_limit;\n')
     for s in module_state:
         f.write('    PyObject *' + s + ';\n')
     f.write('};')
diff --git a/Python/Python-ast.c b/Python/Python-ast.c
index e52a72d43bcbd..f485af675ccff 100644
--- a/Python/Python-ast.c
+++ b/Python/Python-ast.c
@@ -1851,6 +1851,8 @@ init_types(struct ast_state *state)
         "TypeIgnore(int lineno, string tag)");
     if (!state->TypeIgnore_type) return 0;
 
+    state->recursion_depth = 0;
+    state->recursion_limit = 0;
     state->initialized = 1;
     return 1;
 }
@@ -3610,6 +3612,11 @@ ast2obj_mod(struct ast_state *state, void* _o)
     if (!o) {
         Py_RETURN_NONE;
     }
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+            "maximum recursion depth exceeded during ast construction");
+        return 0;
+    }
     switch (o->kind) {
     case Module_kind:
         tp = (PyTypeObject *)state->Module_type;
@@ -3665,6 +3672,7 @@ ast2obj_mod(struct ast_state *state, void* _o)
         Py_DECREF(value);
         break;
     }
+    state->recursion_depth--;
     return result;
 failed:
     Py_XDECREF(value);
@@ -3681,6 +3689,11 @@ ast2obj_stmt(struct ast_state *state, void* _o)
     if (!o) {
         Py_RETURN_NONE;
     }
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+            "maximum recursion depth exceeded during ast construction");
+        return 0;
+    }
     switch (o->kind) {
     case FunctionDef_kind:
         tp = (PyTypeObject *)state->FunctionDef_type;
@@ -4224,6 +4237,7 @@ ast2obj_stmt(struct ast_state *state, void* _o)
     if (PyObject_SetAttr(result, state->end_col_offset, value) < 0)
         goto failed;
     Py_DECREF(value);
+    state->recursion_depth--;
     return result;
 failed:
     Py_XDECREF(value);
@@ -4240,6 +4254,11 @@ ast2obj_expr(struct ast_state *state, void* _o)
     if (!o) {
         Py_RETURN_NONE;
     }
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+            "maximum recursion depth exceeded during ast construction");
+        return 0;
+    }
     switch (o->kind) {
     case BoolOp_kind:
         tp = (PyTypeObject *)state->BoolOp_type;
@@ -4701,6 +4720,7 @@ ast2obj_expr(struct ast_state *state, void* _o)
     if (PyObject_SetAttr(result, state->end_col_offset, value) < 0)
         goto failed;
     Py_DECREF(value);
+    state->recursion_depth--;
     return result;
 failed:
     Py_XDECREF(value);
@@ -4843,6 +4863,11 @@ ast2obj_comprehension(struct ast_state *state, void* _o)
     if (!o) {
         Py_RETURN_NONE;
     }
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+            "maximum recursion depth exceeded during ast construction");
+        return 0;
+    }
     tp = (PyTypeObject *)state->comprehension_type;
     result = PyType_GenericNew(tp, NULL, NULL);
     if (!result) return NULL;
@@ -4866,6 +4891,7 @@ ast2obj_comprehension(struct ast_state *state, void* _o)
     if (PyObject_SetAttr(result, state->is_async, value) == -1)
         goto failed;
     Py_DECREF(value);
+    state->recursion_depth--;
     return result;
 failed:
     Py_XDECREF(value);
@@ -4882,6 +4908,11 @@ ast2obj_excepthandler(struct ast_state *state, void* _o)
     if (!o) {
         Py_RETURN_NONE;
     }
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+            "maximum recursion depth exceeded during ast construction");
+        return 0;
+    }
     switch (o->kind) {
     case ExceptHandler_kind:
         tp = (PyTypeObject *)state->ExceptHandler_type;
@@ -4925,6 +4956,7 @@ ast2obj_excepthandler(struct ast_state *state, void* _o)
     if (PyObject_SetAttr(result, state->end_col_offset, value) < 0)
         goto failed;
     Py_DECREF(value);
+    state->recursion_depth--;
     return result;
 failed:
     Py_XDECREF(value);
@@ -4941,6 +4973,11 @@ ast2obj_arguments(struct ast_state *state, void* _o)
     if (!o) {
         Py_RETURN_NONE;
     }
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+            "maximum recursion depth exceeded during ast construction");
+        return 0;
+    }
     tp = (PyTypeObject *)state->arguments_type;
     result = PyType_GenericNew(tp, NULL, NULL);
     if (!result) return NULL;
@@ -4979,6 +5016,7 @@ ast2obj_arguments(struct ast_state *state, void* _o)
     if (PyObject_SetAttr(result, state->defaults, value) == -1)
         goto failed;
     Py_DECREF(value);
+    state->recursion_depth--;
     return result;
 failed:
     Py_XDECREF(value);
@@ -4995,6 +5033,11 @@ ast2obj_arg(struct ast_state *state, void* _o)
     if (!o) {
         Py_RETURN_NONE;
     }
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+            "maximum recursion depth exceeded during ast construction");
+        return 0;
+    }
     tp = (PyTypeObject *)state->arg_type;
     result = PyType_GenericNew(tp, NULL, NULL);
     if (!result) return NULL;
@@ -5033,6 +5076,7 @@ ast2obj_arg(struct ast_state *state, void* _o)
     if (PyObject_SetAttr(result, state->end_col_offset, value) < 0)
         goto failed;
     Py_DECREF(value);
+    state->recursion_depth--;
     return result;
 failed:
     Py_XDECREF(value);
@@ -5049,6 +5093,11 @@ ast2obj_keyword(struct ast_state *state, void* _o)
     if (!o) {
         Py_RETURN_NONE;
     }
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+            "maximum recursion depth exceeded during ast construction");
+        return 0;
+    }
     tp = (PyTypeObject *)state->keyword_type;
     result = PyType_GenericNew(tp, NULL, NULL);
     if (!result) return NULL;
@@ -5082,6 +5131,7 @@ ast2obj_keyword(struct ast_state *state, void* _o)
     if (PyObject_SetAttr(result, state->end_col_offset, value) < 0)
         goto failed;
     Py_DECREF(value);
+    state->recursion_depth--;
     return result;
 failed:
     Py_XDECREF(value);
@@ -5098,6 +5148,11 @@ ast2obj_alias(struct ast_state *state, void* _o)
     if (!o) {
         Py_RETURN_NONE;
     }
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+            "maximum recursion depth exceeded during ast construction");
+        return 0;
+    }
     tp = (PyTypeObject *)state->alias_type;
     result = PyType_GenericNew(tp, NULL, NULL);
     if (!result) return NULL;
@@ -5131,6 +5186,7 @@ ast2obj_alias(struct ast_state *state, void* _o)
     if (PyObject_SetAttr(result, state->end_col_offset, value) < 0)
         goto failed;
     Py_DECREF(value);
+    state->recursion_depth--;
     return result;
 failed:
     Py_XDECREF(value);
@@ -5147,6 +5203,11 @@ ast2obj_withitem(struct ast_state *state, void* _o)
     if (!o) {
         Py_RETURN_NONE;
     }
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+            "maximum recursion depth exceeded during ast construction");
+        return 0;
+    }
     tp = (PyTypeObject *)state->withitem_type;
     result = PyType_GenericNew(tp, NULL, NULL);
     if (!result) return NULL;
@@ -5160,6 +5221,7 @@ ast2obj_withitem(struct ast_state *state, void* _o)
     if (PyObject_SetAttr(result, state->optional_vars, value) == -1)
         goto failed;
     Py_DECREF(value);
+    state->recursion_depth--;
     return result;
 failed:
     Py_XDECREF(value);
@@ -5176,6 +5238,11 @@ ast2obj_match_case(struct ast_state *state, void* _o)
     if (!o) {
         Py_RETURN_NONE;
     }
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+            "maximum recursion depth exceeded during ast construction");
+        return 0;
+    }
     tp = (PyTypeObject *)state->match_case_type;
     result = PyType_GenericNew(tp, NULL, NULL);
     if (!result) return NULL;
@@ -5194,6 +5261,7 @@ ast2obj_match_case(struct ast_state *state, void* _o)
     if (PyObject_SetAttr(result, state->body, value) == -1)
         goto failed;
     Py_DECREF(value);
+    state->recursion_depth--;
     return result;
 failed:
     Py_XDECREF(value);
@@ -5210,6 +5278,11 @@ ast2obj_pattern(struct ast_state *state, void* _o)
     if (!o) {
         Py_RETURN_NONE;
     }
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+            "maximum recursion depth exceeded during ast construction");
+        return 0;
+    }
     switch (o->kind) {
     case MatchValue_kind:
         tp = (PyTypeObject *)state->MatchValue_type;
@@ -5349,6 +5422,7 @@ ast2obj_pattern(struct ast_state *state, void* _o)
     if (PyObject_SetAttr(result, state->end_col_offset, value) < 0)
         goto failed;
     Py_DECREF(value);
+    state->recursion_depth--;
     return result;
 failed:
     Py_XDECREF(value);
@@ -5365,6 +5439,11 @@ ast2obj_type_ignore(struct ast_state *state, void* _o)
     if (!o) {
         Py_RETURN_NONE;
     }
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+            "maximum recursion depth exceeded during ast construction");
+        return 0;
+    }
     switch (o->kind) {
     case TypeIgnore_kind:
         tp = (PyTypeObject *)state->TypeIgnore_type;
@@ -5382,6 +5461,7 @@ ast2obj_type_ignore(struct ast_state *state, void* _o)
         Py_DECREF(value);
         break;
     }
+    state->recursion_depth--;
     return result;
 failed:
     Py_XDECREF(value);
@@ -12234,7 +12314,32 @@ PyObject* PyAST_mod2obj(mod_ty t)
     if (state == NULL) {
         return NULL;
     }
-    return ast2obj_mod(state, t);
+
+    int recursion_limit = Py_GetRecursionLimit();
+    int starting_recursion_depth;
+    /* Be careful here to prevent overflow. */
+    int COMPILER_STACK_FRAME_SCALE = 3;
+    PyThreadState *tstate = _PyThreadState_GET();
+    if (!tstate) {
+        return 0;
+    }
+    state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
+        recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
+    int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
+    starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
+        recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
+    state->recursion_depth = starting_recursion_depth;
+
+    PyObject *result = ast2obj_mod(state, t);
+
+    /* Check that the recursion depth counting balanced correctly */
+    if (result && state->recursion_depth != starting_recursion_depth) {
+        PyErr_Format(PyExc_SystemError,
+            "AST constructor recursion depth mismatch (before=%d, after=%d)",
+            starting_recursion_depth, state->recursion_depth);
+        return 0;
+    }
+    return result;
 }
 
 /* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */



More information about the Python-checkins mailing list