[Python-checkins] gh-98831: Support conditional effects; use for LOAD_ATTR (#101333)

gvanrossum webhook-mailer at python.org
Sun Jan 29 20:28:46 EST 2023


https://github.com/python/cpython/commit/f5a3d91b6c56ddff4644b5a5ac34d8c6d23d7c79
commit: f5a3d91b6c56ddff4644b5a5ac34d8c6d23d7c79
branch: main
author: Guido van Rossum <guido at python.org>
committer: gvanrossum <gvanrossum at gmail.com>
date: 2023-01-29T17:28:39-08:00
summary:

gh-98831: Support conditional effects; use for LOAD_ATTR (#101333)

files:
M Python/bytecodes.c
M Python/generated_cases.c.h
M Python/opcode_metadata.h
M Tools/cases_generator/generate_cases.py
M Tools/cases_generator/parser.py
M Tools/cases_generator/test_generator.py

diff --git a/Python/bytecodes.c b/Python/bytecodes.c
index e5769f61fc28..fb00b887732e 100644
--- a/Python/bytecodes.c
+++ b/Python/bytecodes.c
@@ -51,7 +51,7 @@
 
 // Dummy variables for stack effects.
 static PyObject *value, *value1, *value2, *left, *right, *res, *sum, *prod, *sub;
-static PyObject *container, *start, *stop, *v, *lhs, *rhs;
+static PyObject *container, *start, *stop, *v, *lhs, *rhs, *res2;
 static PyObject *list, *tuple, *dict, *owner, *set, *str, *tup, *map, *keys;
 static PyObject *exit_func, *lasti, *val, *retval, *obj, *iter;
 static PyObject *aiter, *awaitable, *iterable, *w, *exc_value, *bc;
@@ -1438,13 +1438,11 @@ dummy_func(
             PREDICT(JUMP_BACKWARD);
         }
 
-        // error: LOAD_ATTR has irregular stack effect
-        inst(LOAD_ATTR) {
+        inst(LOAD_ATTR, (unused/9, owner -- res2 if (oparg & 1), res)) {
             #if ENABLE_SPECIALIZATION
             _PyAttrCache *cache = (_PyAttrCache *)next_instr;
             if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) {
                 assert(cframe.use_tracing == 0);
-                PyObject *owner = TOP();
                 PyObject *name = GETITEM(names, oparg>>1);
                 next_instr--;
                 _Py_Specialize_LoadAttr(owner, next_instr, name);
@@ -1454,26 +1452,18 @@ dummy_func(
             DECREMENT_ADAPTIVE_COUNTER(cache->counter);
             #endif  /* ENABLE_SPECIALIZATION */
             PyObject *name = GETITEM(names, oparg >> 1);
-            PyObject *owner = TOP();
             if (oparg & 1) {
-                /* Designed to work in tandem with CALL. */
+                /* Designed to work in tandem with CALL, pushes two values. */
                 PyObject* meth = NULL;
-
-                int meth_found = _PyObject_GetMethod(owner, name, &meth);
-
-                if (meth == NULL) {
-                    /* Most likely attribute wasn't found. */
-                    goto error;
-                }
-
-                if (meth_found) {
+                if (_PyObject_GetMethod(owner, name, &meth)) {
                     /* We can bypass temporary bound method object.
                        meth is unbound method and obj is self.
 
                        meth | self | arg1 | ... | argN
                      */
-                    SET_TOP(meth);
-                    PUSH(owner);  // self
+                    assert(meth != NULL);  // No errors on this branch
+                    res2 = meth;
+                    res = owner;  // Transfer ownership
                 }
                 else {
                     /* meth is not an unbound method (but a regular attr, or
@@ -1483,20 +1473,18 @@ dummy_func(
 
                        NULL | meth | arg1 | ... | argN
                     */
-                    SET_TOP(NULL);
                     Py_DECREF(owner);
-                    PUSH(meth);
+                    ERROR_IF(meth == NULL, error);
+                    res2 = NULL;
+                    res = meth;
                 }
             }
             else {
-                PyObject *res = PyObject_GetAttr(owner, name);
-                if (res == NULL) {
-                    goto error;
-                }
+                /* Classic, pushes one value. */
+                res = PyObject_GetAttr(owner, name);
                 Py_DECREF(owner);
-                SET_TOP(res);
+                ERROR_IF(res == NULL, error);
             }
-            JUMPBY(INLINE_CACHE_ENTRIES_LOAD_ATTR);
         }
 
         // error: LOAD_ATTR has irregular stack effect
diff --git a/Python/generated_cases.c.h b/Python/generated_cases.c.h
index 287a1f1f0420..b5decf804ca6 100644
--- a/Python/generated_cases.c.h
+++ b/Python/generated_cases.c.h
@@ -1745,11 +1745,13 @@
 
         TARGET(LOAD_ATTR) {
             PREDICTED(LOAD_ATTR);
+            PyObject *owner = PEEK(1);
+            PyObject *res2 = NULL;
+            PyObject *res;
             #if ENABLE_SPECIALIZATION
             _PyAttrCache *cache = (_PyAttrCache *)next_instr;
             if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) {
                 assert(cframe.use_tracing == 0);
-                PyObject *owner = TOP();
                 PyObject *name = GETITEM(names, oparg>>1);
                 next_instr--;
                 _Py_Specialize_LoadAttr(owner, next_instr, name);
@@ -1759,26 +1761,18 @@
             DECREMENT_ADAPTIVE_COUNTER(cache->counter);
             #endif  /* ENABLE_SPECIALIZATION */
             PyObject *name = GETITEM(names, oparg >> 1);
-            PyObject *owner = TOP();
             if (oparg & 1) {
-                /* Designed to work in tandem with CALL. */
+                /* Designed to work in tandem with CALL, pushes two values. */
                 PyObject* meth = NULL;
-
-                int meth_found = _PyObject_GetMethod(owner, name, &meth);
-
-                if (meth == NULL) {
-                    /* Most likely attribute wasn't found. */
-                    goto error;
-                }
-
-                if (meth_found) {
+                if (_PyObject_GetMethod(owner, name, &meth)) {
                     /* We can bypass temporary bound method object.
                        meth is unbound method and obj is self.
 
                        meth | self | arg1 | ... | argN
                      */
-                    SET_TOP(meth);
-                    PUSH(owner);  // self
+                    assert(meth != NULL);  // No errors on this branch
+                    res2 = meth;
+                    res = owner;  // Transfer ownership
                 }
                 else {
                     /* meth is not an unbound method (but a regular attr, or
@@ -1788,20 +1782,22 @@
 
                        NULL | meth | arg1 | ... | argN
                     */
-                    SET_TOP(NULL);
                     Py_DECREF(owner);
-                    PUSH(meth);
+                    if (meth == NULL) goto pop_1_error;
+                    res2 = NULL;
+                    res = meth;
                 }
             }
             else {
-                PyObject *res = PyObject_GetAttr(owner, name);
-                if (res == NULL) {
-                    goto error;
-                }
+                /* Classic, pushes one value. */
+                res = PyObject_GetAttr(owner, name);
                 Py_DECREF(owner);
-                SET_TOP(res);
+                if (res == NULL) goto pop_1_error;
             }
-            JUMPBY(INLINE_CACHE_ENTRIES_LOAD_ATTR);
+            STACK_GROW(((oparg & 1) ? 1 : 0));
+            POKE(1, res);
+            if (oparg & 1) { POKE(1 + ((oparg & 1) ? 1 : 0), res2); }
+            JUMPBY(9);
             DISPATCH();
         }
 
diff --git a/Python/opcode_metadata.h b/Python/opcode_metadata.h
index cca86629e48d..e76ddda2f029 100644
--- a/Python/opcode_metadata.h
+++ b/Python/opcode_metadata.h
@@ -185,7 +185,7 @@ _PyOpcode_num_popped(int opcode, int oparg) {
         case MAP_ADD:
             return 2;
         case LOAD_ATTR:
-            return -1;
+            return 1;
         case LOAD_ATTR_INSTANCE_VALUE:
             return -1;
         case LOAD_ATTR_MODULE:
@@ -531,7 +531,7 @@ _PyOpcode_num_pushed(int opcode, int oparg) {
         case MAP_ADD:
             return 0;
         case LOAD_ATTR:
-            return -1;
+            return ((oparg & 1) ? 1 : 0) + 1;
         case LOAD_ATTR_INSTANCE_VALUE:
             return -1;
         case LOAD_ATTR_MODULE:
@@ -694,7 +694,7 @@ _PyOpcode_num_pushed(int opcode, int oparg) {
 }
 #endif
 enum Direction { DIR_NONE, DIR_READ, DIR_WRITE };
-enum InstructionFormat { INSTR_FMT_IB, INSTR_FMT_IBC, INSTR_FMT_IBC0, INSTR_FMT_IBC000, INSTR_FMT_IBIB, INSTR_FMT_IX, INSTR_FMT_IXC, INSTR_FMT_IXC000 };
+enum InstructionFormat { INSTR_FMT_IB, INSTR_FMT_IBC, INSTR_FMT_IBC0, INSTR_FMT_IBC000, INSTR_FMT_IBC00000000, INSTR_FMT_IBIB, INSTR_FMT_IX, INSTR_FMT_IXC, INSTR_FMT_IXC000 };
 struct opcode_metadata {
     enum Direction dir_op1;
     enum Direction dir_op2;
@@ -791,7 +791,7 @@ struct opcode_metadata {
     [DICT_UPDATE] = { DIR_NONE, DIR_NONE, DIR_NONE, true, INSTR_FMT_IB },
     [DICT_MERGE] = { DIR_NONE, DIR_NONE, DIR_NONE, true, INSTR_FMT_IB },
     [MAP_ADD] = { DIR_NONE, DIR_NONE, DIR_NONE, true, INSTR_FMT_IB },
-    [LOAD_ATTR] = { DIR_NONE, DIR_NONE, DIR_NONE, true, INSTR_FMT_IB },
+    [LOAD_ATTR] = { DIR_NONE, DIR_NONE, DIR_NONE, true, INSTR_FMT_IBC00000000 },
     [LOAD_ATTR_INSTANCE_VALUE] = { DIR_NONE, DIR_NONE, DIR_NONE, true, INSTR_FMT_IB },
     [LOAD_ATTR_MODULE] = { DIR_NONE, DIR_NONE, DIR_NONE, true, INSTR_FMT_IB },
     [LOAD_ATTR_WITH_HINT] = { DIR_NONE, DIR_NONE, DIR_NONE, true, INSTR_FMT_IB },
diff --git a/Tools/cases_generator/generate_cases.py b/Tools/cases_generator/generate_cases.py
index 9d894d2ff574..1d703a0a790e 100644
--- a/Tools/cases_generator/generate_cases.py
+++ b/Tools/cases_generator/generate_cases.py
@@ -26,7 +26,9 @@
 )
 BEGIN_MARKER = "// BEGIN BYTECODES //"
 END_MARKER = "// END BYTECODES //"
-RE_PREDICTED = r"^\s*(?:PREDICT\(|GO_TO_INSTRUCTION\(|DEOPT_IF\(.*?,\s*)(\w+)\);\s*(?://.*)?$"
+RE_PREDICTED = (
+    r"^\s*(?:PREDICT\(|GO_TO_INSTRUCTION\(|DEOPT_IF\(.*?,\s*)(\w+)\);\s*(?://.*)?$"
+)
 UNUSED = "unused"
 BITS_PER_CODE_UNIT = 16
 
@@ -59,7 +61,10 @@ def effect_size(effect: StackEffect) -> tuple[int, str]:
     At most one of these will be non-zero / non-empty.
     """
     if effect.size:
+        assert not effect.cond, "Array effects cannot have a condition"
         return 0, effect.size
+    elif effect.cond:
+        return 0, f"{maybe_parenthesize(effect.cond)} ? 1 : 0"
     else:
         return 1, ""
 
@@ -132,7 +137,12 @@ def block(self, head: str):
             yield
         self.emit("}")
 
-    def stack_adjust(self, diff: int, input_effects: list[StackEffect], output_effects: list[StackEffect]):
+    def stack_adjust(
+        self,
+        diff: int,
+        input_effects: list[StackEffect],
+        output_effects: list[StackEffect],
+    ):
         # TODO: Get rid of 'diff' parameter
         shrink, isym = list_effect_size(input_effects)
         grow, osym = list_effect_size(output_effects)
@@ -150,10 +160,13 @@ def declare(self, dst: StackEffect, src: StackEffect | None):
         if dst.name == UNUSED:
             return
         typ = f"{dst.type}" if dst.type else "PyObject *"
-        init = ""
         if src:
             cast = self.cast(dst, src)
             init = f" = {cast}{src.name}"
+        elif dst.cond:
+            init = " = NULL"
+        else:
+            init = ""
         sepa = "" if typ.endswith("*") else " "
         self.emit(f"{typ}{sepa}{dst.name}{init};")
 
@@ -162,7 +175,10 @@ def assign(self, dst: StackEffect, src: StackEffect):
             return
         cast = self.cast(dst, src)
         if m := re.match(r"^PEEK\((.*)\)$", dst.name):
-            self.emit(f"POKE({m.group(1)}, {cast}{src.name});")
+            stmt = f"POKE({m.group(1)}, {cast}{src.name});"
+            if src.cond:
+                stmt = f"if ({src.cond}) {{ {stmt} }}"
+            self.emit(stmt)
         elif m := re.match(r"^&PEEK\(.*\)$", dst.name):
             # NOTE: MOVE_ITEMS() does not actually exist.
             # The only supported output array forms are:
@@ -234,7 +250,7 @@ def __init__(self, inst: parser.InstDef):
         if self.register:
             num_regs = len(self.input_effects) + len(self.output_effects)
             num_dummies = (num_regs // 2) * 2 + 1 - num_regs
-            fmt = "I" + "B"*num_regs + "X"*num_dummies
+            fmt = "I" + "B" * num_regs + "X" * num_dummies
         else:
             if variable_used(inst, "oparg"):
                 fmt = "IB"
@@ -276,9 +292,13 @@ def write(self, out: Formatter) -> None:
             # Write input stack effect variable declarations and initializations
             ieffects = list(reversed(self.input_effects))
             for i, ieffect in enumerate(ieffects):
-                isize = string_effect_size(list_effect_size(ieffects[:i+1]))
+                isize = string_effect_size(
+                    list_effect_size([ieff for ieff in ieffects[: i + 1]])
+                )
                 if ieffect.size:
                     src = StackEffect(f"&PEEK({isize})", "PyObject **")
+                elif ieffect.cond:
+                    src = StackEffect(f"({ieffect.cond}) ? PEEK({isize}) : NULL", "")
                 else:
                     src = StackEffect(f"PEEK({isize})", "")
                 out.declare(ieffect, src)
@@ -304,14 +324,20 @@ def write(self, out: Formatter) -> None:
 
         if not self.register:
             # Write net stack growth/shrinkage
-            out.stack_adjust(0, self.input_effects, self.output_effects)
+            out.stack_adjust(
+                0,
+                [ieff for ieff in self.input_effects],
+                [oeff for oeff in self.output_effects],
+            )
 
             # Write output stack effect assignments
             oeffects = list(reversed(self.output_effects))
             for i, oeffect in enumerate(oeffects):
                 if oeffect.name in self.unmoved_names:
                     continue
-                osize = string_effect_size(list_effect_size(oeffects[:i+1]))
+                osize = string_effect_size(
+                    list_effect_size([oeff for oeff in oeffects[: i + 1]])
+                )
                 if oeffect.size:
                     dst = StackEffect(f"&PEEK({osize})", "PyObject **")
                 else:
@@ -438,6 +464,7 @@ class MacroInstruction(SuperOrMacroInstruction):
     parts: list[Component | parser.CacheEffect]
 
 
+AnyInstruction = Instruction | SuperInstruction | MacroInstruction
 INSTR_FMT_PREFIX = "INSTR_FMT_"
 
 
@@ -506,6 +533,7 @@ def parse(self) -> None:
         self.supers = {}
         self.macros = {}
         self.families = {}
+        thing: parser.InstDef | parser.Super | parser.Macro | parser.Family | None
         while thing := psr.definition():
             match thing:
                 case parser.InstDef(name=name):
@@ -631,7 +659,9 @@ def analyze_super(self, super: parser.Super) -> SuperInstruction:
             parts.append(part)
             format += instr.instr_fmt
         final_sp = sp
-        return SuperInstruction(super.name, stack, initial_sp, final_sp, format, super, parts)
+        return SuperInstruction(
+            super.name, stack, initial_sp, final_sp, format, super, parts
+        )
 
     def analyze_macro(self, macro: parser.Macro) -> MacroInstruction:
         components = self.check_macro_components(macro)
@@ -657,7 +687,9 @@ def analyze_macro(self, macro: parser.Macro) -> MacroInstruction:
                 case _:
                     typing.assert_never(component)
         final_sp = sp
-        return MacroInstruction(macro.name, stack, initial_sp, final_sp, format, macro, parts)
+        return MacroInstruction(
+            macro.name, stack, initial_sp, final_sp, format, macro, parts
+        )
 
     def analyze_instruction(
         self, instr: Instruction, stack: list[StackEffect], sp: int
@@ -710,7 +742,9 @@ def stack_analysis(
         for thing in components:
             match thing:
                 case Instruction() as instr:
-                    if any(eff.size for eff in instr.input_effects + instr.output_effects):
+                    if any(
+                        eff.size for eff in instr.input_effects + instr.output_effects
+                    ):
                         # TODO: Eventually this will be needed, at least for macros.
                         self.error(
                             f"Instruction {instr.name!r} has variable-sized stack effect, "
@@ -736,16 +770,16 @@ def stack_analysis(
 
     def get_stack_effect_info(
         self, thing: parser.InstDef | parser.Super | parser.Macro
-    ) -> tuple[Instruction|None, str, str]:
-
-        def effect_str(effect: list[StackEffect]) -> str:
-            if getattr(thing, 'kind', None) == 'legacy':
+    ) -> tuple[AnyInstruction | None, str, str]:
+        def effect_str(effects: list[StackEffect]) -> str:
+            if getattr(thing, "kind", None) == "legacy":
                 return str(-1)
-            n_effect, sym_effect = list_effect_size(effect)
+            n_effect, sym_effect = list_effect_size(effects)
             if sym_effect:
                 return f"{sym_effect} + {n_effect}" if n_effect else sym_effect
             return str(n_effect)
 
+        instr: AnyInstruction | None
         match thing:
             case parser.InstDef():
                 if thing.kind != "op":
@@ -754,34 +788,43 @@ def effect_str(effect: list[StackEffect]) -> str:
                     pushed = effect_str(instr.output_effects)
                 else:
                     instr = None
-                    popped = pushed = "", ""
+                    popped = ""
+                    pushed = ""
             case parser.Super():
                 instr = self.super_instrs[thing.name]
-                popped = '+'.join(effect_str(comp.instr.input_effects) for comp in instr.parts)
-                pushed = '+'.join(effect_str(comp.instr.output_effects) for comp in instr.parts)
+                popped = "+".join(
+                    effect_str(comp.instr.input_effects) for comp in instr.parts
+                )
+                pushed = "+".join(
+                    effect_str(comp.instr.output_effects) for comp in instr.parts
+                )
             case parser.Macro():
                 instr = self.macro_instrs[thing.name]
                 parts = [comp for comp in instr.parts if isinstance(comp, Component)]
-                popped = '+'.join(effect_str(comp.instr.input_effects) for comp in parts)
-                pushed = '+'.join(effect_str(comp.instr.output_effects) for comp in parts)
+                popped = "+".join(
+                    effect_str(comp.instr.input_effects) for comp in parts
+                )
+                pushed = "+".join(
+                    effect_str(comp.instr.output_effects) for comp in parts
+                )
             case _:
                 typing.assert_never(thing)
         return instr, popped, pushed
 
     def write_stack_effect_functions(self) -> None:
-        popped_data = []
-        pushed_data = []
+        popped_data: list[tuple[AnyInstruction, str]] = []
+        pushed_data: list[tuple[AnyInstruction, str]] = []
         for thing in self.everything:
             instr, popped, pushed = self.get_stack_effect_info(thing)
             if instr is not None:
-                popped_data.append( (instr, popped) )
-                pushed_data.append( (instr, pushed) )
+                popped_data.append((instr, popped))
+                pushed_data.append((instr, pushed))
 
-        def write_function(direction: str, data: list[tuple[Instruction, str]]) -> None:
-            self.out.emit("\n#ifndef NDEBUG");
-            self.out.emit("static int");
+        def write_function(direction: str, data: list[tuple[AnyInstruction, str]]) -> None:
+            self.out.emit("\n#ifndef NDEBUG")
+            self.out.emit("static int")
             self.out.emit(f"_PyOpcode_num_{direction}(int opcode, int oparg) {{")
-            self.out.emit("    switch(opcode) {");
+            self.out.emit("    switch(opcode) {")
             for instr, effect in data:
                 self.out.emit(f"        case {instr.name}:")
                 self.out.emit(f"            return {effect};")
@@ -789,10 +832,10 @@ def write_function(direction: str, data: list[tuple[Instruction, str]]) -> None:
             self.out.emit("            Py_UNREACHABLE();")
             self.out.emit("    }")
             self.out.emit("}")
-            self.out.emit("#endif");
+            self.out.emit("#endif")
 
-        write_function('popped', popped_data)
-        write_function('pushed', pushed_data)
+        write_function("popped", popped_data)
+        write_function("pushed", pushed_data)
 
     def write_metadata(self) -> None:
         """Write instruction metadata to output file."""
@@ -865,21 +908,21 @@ def write_metadata_for_inst(self, instr: Instruction) -> None:
                 directions.extend("DIR_NONE" for _ in range(3))
                 dir_op1, dir_op2, dir_op3 = directions[:3]
         self.out.emit(
-            f'    [{instr.name}] = {{ {dir_op1}, {dir_op2}, {dir_op3}, true, {INSTR_FMT_PREFIX}{instr.instr_fmt} }},'
+            f"    [{instr.name}] = {{ {dir_op1}, {dir_op2}, {dir_op3}, true, {INSTR_FMT_PREFIX}{instr.instr_fmt} }},"
         )
 
     def write_metadata_for_super(self, sup: SuperInstruction) -> None:
         """Write metadata for a super-instruction."""
         dir_op1 = dir_op2 = dir_op3 = "DIR_NONE"
         self.out.emit(
-            f'    [{sup.name}] = {{ {dir_op1}, {dir_op2}, {dir_op3}, true, {INSTR_FMT_PREFIX}{sup.instr_fmt} }},'
+            f"    [{sup.name}] = {{ {dir_op1}, {dir_op2}, {dir_op3}, true, {INSTR_FMT_PREFIX}{sup.instr_fmt} }},"
         )
 
     def write_metadata_for_macro(self, mac: MacroInstruction) -> None:
         """Write metadata for a macro-instruction."""
         dir_op1 = dir_op2 = dir_op3 = "DIR_NONE"
         self.out.emit(
-            f'    [{mac.name}] = {{ {dir_op1}, {dir_op2}, {dir_op3}, true, {INSTR_FMT_PREFIX}{mac.instr_fmt} }},'
+            f"    [{mac.name}] = {{ {dir_op1}, {dir_op2}, {dir_op3}, true, {INSTR_FMT_PREFIX}{mac.instr_fmt} }},"
         )
 
     def write_instructions(self) -> None:
@@ -1012,7 +1055,9 @@ def extract_block_text(block: parser.Block) -> tuple[list[str], list[str]]:
 
     # Separate PREDICT(...) macros from end
     predictions: list[str] = []
-    while blocklines and (m := re.match(r"^\s*PREDICT\((\w+)\);\s*(?://.*)?$", blocklines[-1])):
+    while blocklines and (
+        m := re.match(r"^\s*PREDICT\((\w+)\);\s*(?://.*)?$", blocklines[-1])
+    ):
         predictions.insert(0, m.group(1))
         blocklines.pop()
 
@@ -1029,13 +1074,22 @@ def always_exits(lines: list[str]) -> bool:
         return False
     line = line[12:]
     return line.startswith(
-        ("goto ", "return ", "DISPATCH", "GO_TO_", "Py_UNREACHABLE()", "ERROR_IF(true, ")
+        (
+            "goto ",
+            "return ",
+            "DISPATCH",
+            "GO_TO_",
+            "Py_UNREACHABLE()",
+            "ERROR_IF(true, ",
+        )
     )
 
 
 def variable_used(node: parser.Node, name: str) -> bool:
     """Determine whether a variable with a given name is used in a node."""
-    return any(token.kind == "IDENTIFIER" and token.text == name for token in node.tokens)
+    return any(
+        token.kind == "IDENTIFIER" and token.text == name for token in node.tokens
+    )
 
 
 def main():
diff --git a/Tools/cases_generator/parser.py b/Tools/cases_generator/parser.py
index c2cebe96ccd6..ced66faee493 100644
--- a/Tools/cases_generator/parser.py
+++ b/Tools/cases_generator/parser.py
@@ -48,9 +48,6 @@ def to_text(self, dedent: int = 0) -> str:
         context = self.context
         if not context:
             return ""
-        tokens = context.owner.tokens
-        begin = context.begin
-        end = context.end
         return lx.to_text(self.tokens, dedent)
 
     @property
@@ -74,13 +71,13 @@ class Block(Node):
 class StackEffect(Node):
     name: str
     type: str = ""  # Optional `:type`
+    cond: str = ""  # Optional `if (cond)`
     size: str = ""  # Optional `[size]`
-    # Note: we can have type or size but not both
-    # TODO: condition (can be combined with type but not size)
+    # Note: size cannot be combined with type or cond
 
 
 @dataclass
-class Dimension(Node):
+class Expression(Node):
     size: str
 
 
@@ -239,31 +236,39 @@ def cache_effect(self) -> CacheEffect | None:
 
     @contextual
     def stack_effect(self) -> StackEffect | None:
-        # IDENTIFIER
-        #   | IDENTIFIER ':' IDENTIFIER
-        #   | IDENTIFIER '[' dimension ']'
-        # TODO: Conditions
+        #   IDENTIFIER [':' IDENTIFIER] ['if' '(' expression ')']
+        # | IDENTIFIER '[' expression ']'
         if tkn := self.expect(lx.IDENTIFIER):
+            type_text = ""
             if self.expect(lx.COLON):
-                typ = self.require(lx.IDENTIFIER)
-                return StackEffect(tkn.text, typ.text)
-            elif self.expect(lx.LBRACKET):
-                if not (dim := self.dimension()):
-                    raise self.make_syntax_error("Expected dimension")
+                type_text = self.require(lx.IDENTIFIER).text.strip()
+            cond_text = ""
+            if self.expect(lx.IF):
+                self.require(lx.LPAREN)
+                if not (cond := self.expression()):
+                    raise self.make_syntax_error("Expected condition")
+                self.require(lx.RPAREN)
+                cond_text = cond.text.strip()
+            size_text = ""
+            if self.expect(lx.LBRACKET):
+                if type_text or cond_text:
+                    raise self.make_syntax_error("Unexpected [")
+                if not (size := self.expression()):
+                    raise self.make_syntax_error("Expected expression")
                 self.require(lx.RBRACKET)
-                return StackEffect(tkn.text, "PyObject **", dim.text.strip())
-            else:
-                return StackEffect(tkn.text)
+                type_text = "PyObject **"
+                size_text = size.text.strip()
+            return StackEffect(tkn.text, type_text, cond_text, size_text)
 
     @contextual
-    def dimension(self) -> Dimension | None:
+    def expression(self) -> Expression | None:
         tokens: list[lx.Token] = []
-        while (tkn := self.peek()) and tkn.kind != lx.RBRACKET:
+        while (tkn := self.peek()) and tkn.kind not in (lx.RBRACKET, lx.RPAREN):
             tokens.append(tkn)
             self.next()
         if not tokens:
             return None
-        return Dimension(lx.to_text(tokens).strip())
+        return Expression(lx.to_text(tokens).strip())
 
     @contextual
     def super_def(self) -> Super | None:
@@ -366,7 +371,7 @@ def members(self) -> list[str] | None:
         return None
 
     @contextual
-    def block(self) -> Block:
+    def block(self) -> Block | None:
         if self.c_blob():
             return Block()
 
diff --git a/Tools/cases_generator/test_generator.py b/Tools/cases_generator/test_generator.py
index bd1b974399ab..6e6c60782d73 100644
--- a/Tools/cases_generator/test_generator.py
+++ b/Tools/cases_generator/test_generator.py
@@ -9,19 +9,19 @@
 
 def test_effect_sizes():
     input_effects = [
-        x := StackEffect("x", "", ""),
-        y := StackEffect("y", "", "oparg"),
-        z := StackEffect("z", "", "oparg*2"),
+        x := StackEffect("x", "", "", ""),
+        y := StackEffect("y", "", "", "oparg"),
+        z := StackEffect("z", "", "", "oparg*2"),
     ]
     output_effects = [
-        a := StackEffect("a", "", ""),
-        b := StackEffect("b", "", "oparg*4"),
-        c := StackEffect("c", "", ""),
+        StackEffect("a", "", "", ""),
+        StackEffect("b", "", "", "oparg*4"),
+        StackEffect("c", "", "", ""),
     ]
     other_effects = [
-        p := StackEffect("p", "", "oparg<<1"),
-        q := StackEffect("q", "", ""),
-        r := StackEffect("r", "", ""),
+        StackEffect("p", "", "", "oparg<<1"),
+        StackEffect("q", "", "", ""),
+        StackEffect("r", "", "", ""),
     ]
     assert generate_cases.effect_size(x) == (1, "")
     assert generate_cases.effect_size(y) == (0, "oparg")
@@ -54,6 +54,12 @@ def run_cases_test(input: str, expected: str):
     while lines and lines[0].startswith("// "):
         lines.pop(0)
     actual = "".join(lines)
+    # if actual.rstrip() != expected.rstrip():
+    #     print("Actual:")
+    #     print(actual)
+    #     print("Expected:")
+    #     print(expected)
+    #     print("End")
     assert actual.rstrip() == expected.rstrip()
 
 def test_legacy():
@@ -475,3 +481,28 @@ def test_register():
         }
     """
     run_cases_test(input, output)
+
+def test_cond_effect():
+    input = """
+        inst(OP, (aa, input if (oparg & 1), cc -- xx, output if (oparg & 2), zz)) {
+            output = spam(oparg, input);
+        }
+    """
+    output = """
+        TARGET(OP) {
+            PyObject *cc = PEEK(1);
+            PyObject *input = (oparg & 1) ? PEEK(1 + ((oparg & 1) ? 1 : 0)) : NULL;
+            PyObject *aa = PEEK(2 + ((oparg & 1) ? 1 : 0));
+            PyObject *xx;
+            PyObject *output = NULL;
+            PyObject *zz;
+            output = spam(oparg, input);
+            STACK_SHRINK(((oparg & 1) ? 1 : 0));
+            STACK_GROW(((oparg & 2) ? 1 : 0));
+            POKE(1, zz);
+            if (oparg & 2) { POKE(1 + ((oparg & 2) ? 1 : 0), output); }
+            POKE(2 + ((oparg & 2) ? 1 : 0), xx);
+            DISPATCH();
+        }
+    """
+    run_cases_test(input, output)



More information about the Python-checkins mailing list