[Python-checkins] GH-98831: Support cache effects in super- and macro instructions (#99601)

gvanrossum webhook-mailer at python.org
Fri Dec 2 22:57:48 EST 2022


https://github.com/python/cpython/commit/acf9184e6b68714cf7a756edefd02372dccd988b
commit: acf9184e6b68714cf7a756edefd02372dccd988b
branch: main
author: Guido van Rossum <guido at python.org>
committer: gvanrossum <gvanrossum at gmail.com>
date: 2022-12-02T19:57:30-08:00
summary:

GH-98831: Support cache effects in super- and macro instructions (#99601)

files:
M Python/generated_cases.c.h
M Tools/cases_generator/generate_cases.py
M Tools/cases_generator/lexer.py
M Tools/cases_generator/parser.py

diff --git a/Python/generated_cases.c.h b/Python/generated_cases.c.h
index 3af60b83d84e..3a403824b499 100644
--- a/Python/generated_cases.c.h
+++ b/Python/generated_cases.c.h
@@ -436,10 +436,10 @@
         }
 
         TARGET(BINARY_SUBSCR_GETITEM) {
-            uint32_t type_version = read_u32(next_instr + 1);
-            uint16_t func_version = read_u16(next_instr + 3);
             PyObject *sub = PEEK(1);
             PyObject *container = PEEK(2);
+            uint32_t type_version = read_u32(next_instr + 1);
+            uint16_t func_version = read_u16(next_instr + 3);
             PyTypeObject *tp = Py_TYPE(container);
             DEOPT_IF(tp->tp_version_tag != type_version, BINARY_SUBSCR);
             assert(tp->tp_flags & Py_TPFLAGS_HEAPTYPE);
diff --git a/Tools/cases_generator/generate_cases.py b/Tools/cases_generator/generate_cases.py
index 424b15ede2aa..2952634a3cda 100644
--- a/Tools/cases_generator/generate_cases.py
+++ b/Tools/cases_generator/generate_cases.py
@@ -14,23 +14,76 @@
 
 import parser
 
-DEFAULT_INPUT = "Python/bytecodes.c"
-DEFAULT_OUTPUT = "Python/generated_cases.c.h"
+DEFAULT_INPUT = os.path.relpath(
+    os.path.join(os.path.dirname(__file__), "../../Python/bytecodes.c")
+)
+DEFAULT_OUTPUT = os.path.relpath(
+    os.path.join(os.path.dirname(__file__), "../../Python/generated_cases.c.h")
+)
 BEGIN_MARKER = "// BEGIN BYTECODES //"
 END_MARKER = "// END BYTECODES //"
 RE_PREDICTED = r"(?s)(?:PREDICT\(|GO_TO_INSTRUCTION\(|DEOPT_IF\(.*?,\s*)(\w+)\);"
 UNUSED = "unused"
 BITS_PER_CODE_UNIT = 16
 
-arg_parser = argparse.ArgumentParser()
-arg_parser.add_argument("-i", "--input", type=str, default=DEFAULT_INPUT)
-arg_parser.add_argument("-o", "--output", type=str, default=DEFAULT_OUTPUT)
+arg_parser = argparse.ArgumentParser(
+    description="Generate the code for the interpreter switch.",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+)
+arg_parser.add_argument(
+    "-i", "--input", type=str, help="Instruction definitions", default=DEFAULT_INPUT
+)
+arg_parser.add_argument(
+    "-o", "--output", type=str, help="Generated code", default=DEFAULT_OUTPUT
+)
 
 
-# This is not a data class
-class Instruction(parser.InstDef):
+class Formatter:
+    """Wraps an output stream with the ability to indent etc."""
+
+    stream: typing.TextIO
+    prefix: str
+
+    def __init__(self, stream: typing.TextIO, indent: int) -> None:
+        self.stream = stream
+        self.prefix = " " * indent
+
+    def write_raw(self, s: str) -> None:
+        self.stream.write(s)
+
+    def emit(self, arg: str) -> None:
+        if arg:
+            self.write_raw(f"{self.prefix}{arg}\n")
+        else:
+            self.write_raw("\n")
+
+    @contextlib.contextmanager
+    def indent(self):
+        self.prefix += "    "
+        yield
+        self.prefix = self.prefix[:-4]
+
+    @contextlib.contextmanager
+    def block(self, head: str):
+        if head:
+            self.emit(head + " {")
+        else:
+            self.emit("{")
+        with self.indent():
+            yield
+        self.emit("}")
+
+
+@dataclasses.dataclass
+class Instruction:
     """An instruction with additional data and code."""
 
+    # Parts of the underlying instruction definition
+    inst: parser.InstDef
+    kind: typing.Literal["inst", "op"]
+    name: str
+    block: parser.Block
+
     # Computed by constructor
     always_exits: bool
     cache_offset: int
@@ -43,65 +96,44 @@ class Instruction(parser.InstDef):
     predicted: bool = False
 
     def __init__(self, inst: parser.InstDef):
-        super().__init__(inst.header, inst.block)
-        self.context = inst.context
+        self.inst = inst
+        self.kind = inst.kind
+        self.name = inst.name
+        self.block = inst.block
         self.always_exits = always_exits(self.block)
         self.cache_effects = [
-            effect for effect in self.inputs if isinstance(effect, parser.CacheEffect)
+            effect for effect in inst.inputs if isinstance(effect, parser.CacheEffect)
         ]
         self.cache_offset = sum(c.size for c in self.cache_effects)
         self.input_effects = [
-            effect for effect in self.inputs if isinstance(effect, parser.StackEffect)
+            effect for effect in inst.inputs if isinstance(effect, parser.StackEffect)
         ]
-        self.output_effects = self.outputs  # For consistency/completeness
+        self.output_effects = inst.outputs  # For consistency/completeness
 
-    def write(self, f: typing.TextIO, indent: str, dedent: int = 0) -> None:
+    def write(self, out: Formatter) -> None:
         """Write one instruction, sans prologue and epilogue."""
-        if dedent < 0:
-            indent += " " * -dedent  # DO WE NEED THIS?
-
-        # Get cache offset and maybe assert that it is correct
+        # Write a static assertion that a family's cache size is correct
         if family := self.family:
             if self.name == family.members[0]:
                 if cache_size := family.size:
-                    f.write(
-                        f"{indent}    static_assert({cache_size} == "
-                        f'{self.cache_offset}, "incorrect cache size");\n'
-                    )
-
-        # Write cache effect variable declarations
-        cache_offset = 0
-        for ceffect in self.cache_effects:
-            if ceffect.name != UNUSED:
-                bits = ceffect.size * BITS_PER_CODE_UNIT
-                if bits == 64:
-                    # NOTE: We assume that 64-bit data in the cache
-                    # is always an object pointer.
-                    # If this becomes false, we need a way to specify
-                    # syntactically what type the cache data is.
-                    f.write(
-                        f"{indent}    PyObject *{ceffect.name} = "
-                        f"read_obj(next_instr + {cache_offset});\n"
+                    out.emit(
+                        f"static_assert({cache_size} == "
+                        f'{self.cache_offset}, "incorrect cache size");'
                     )
-                else:
-                    f.write(f"{indent}    uint{bits}_t {ceffect.name} = "
-                        f"read_u{bits}(next_instr + {cache_offset});\n")
-            cache_offset += ceffect.size
-        assert cache_offset == self.cache_offset
 
         # Write input stack effect variable declarations and initializations
         for i, seffect in enumerate(reversed(self.input_effects), 1):
             if seffect.name != UNUSED:
-                f.write(f"{indent}    PyObject *{seffect.name} = PEEK({i});\n")
+                out.emit(f"PyObject *{seffect.name} = PEEK({i});")
 
         # Write output stack effect variable declarations
         input_names = {seffect.name for seffect in self.input_effects}
         input_names.add(UNUSED)
         for seffect in self.output_effects:
             if seffect.name not in input_names:
-                f.write(f"{indent}    PyObject *{seffect.name};\n")
+                out.emit(f"PyObject *{seffect.name};")
 
-        self.write_body(f, indent, dedent)
+        self.write_body(out, 0)
 
         # Skip the rest if the block always exits
         if always_exits(self.block):
@@ -110,9 +142,9 @@ def write(self, f: typing.TextIO, indent: str, dedent: int = 0) -> None:
         # Write net stack growth/shrinkage
         diff = len(self.output_effects) - len(self.input_effects)
         if diff > 0:
-            f.write(f"{indent}    STACK_GROW({diff});\n")
+            out.emit(f"STACK_GROW({diff});")
         elif diff < 0:
-            f.write(f"{indent}    STACK_SHRINK({-diff});\n")
+            out.emit(f"STACK_SHRINK({-diff});")
 
         # Write output stack effect assignments
         unmoved_names = {UNUSED}
@@ -121,14 +153,32 @@ def write(self, f: typing.TextIO, indent: str, dedent: int = 0) -> None:
                 unmoved_names.add(ieffect.name)
         for i, seffect in enumerate(reversed(self.output_effects)):
             if seffect.name not in unmoved_names:
-                f.write(f"{indent}    POKE({i+1}, {seffect.name});\n")
+                out.emit(f"POKE({i+1}, {seffect.name});")
 
         # Write cache effect
         if self.cache_offset:
-            f.write(f"{indent}    next_instr += {self.cache_offset};\n")
+            out.emit(f"next_instr += {self.cache_offset};")
 
-    def write_body(self, f: typing.TextIO, ndent: str, dedent: int) -> None:
+    def write_body(self, out: Formatter, dedent: int, cache_adjust: int = 0) -> None:
         """Write the instruction body."""
+        # Write cache effect variable declarations and initializations
+        cache_offset = cache_adjust
+        for ceffect in self.cache_effects:
+            if ceffect.name != UNUSED:
+                bits = ceffect.size * BITS_PER_CODE_UNIT
+                if bits == 64:
+                    # NOTE: We assume that 64-bit data in the cache
+                    # is always an object pointer.
+                    # If this becomes false, we need a way to specify
+                    # syntactically what type the cache data is.
+                    type = "PyObject *"
+                    func = "read_obj"
+                else:
+                    type = f"uint{bits}_t "
+                    func = f"read_u{bits}"
+                out.emit(f"{type}{ceffect.name} = {func}(next_instr + {cache_offset});")
+            cache_offset += ceffect.size
+        assert cache_offset == self.cache_offset + cache_adjust
 
         # Get lines of text with proper dedent
         blocklines = self.block.to_text(dedent=dedent).splitlines(True)
@@ -165,122 +215,101 @@ def write_body(self, f: typing.TextIO, ndent: str, dedent: int) -> None:
                     else:
                         break
                 if ninputs:
-                    f.write(f"{space}if ({cond}) goto pop_{ninputs}_{label};\n")
+                    out.write_raw(f"{space}if ({cond}) goto pop_{ninputs}_{label};\n")
                 else:
-                    f.write(f"{space}if ({cond}) goto {label};\n")
+                    out.write_raw(f"{space}if ({cond}) goto {label};\n")
             else:
-                f.write(line)
+                out.write_raw(line)
+
+
+InstructionOrCacheEffect = Instruction | parser.CacheEffect
 
 
 @dataclasses.dataclass
-class SuperComponent:
+class Component:
     instr: Instruction
     input_mapping: dict[str, parser.StackEffect]
     output_mapping: dict[str, parser.StackEffect]
 
+    def write_body(self, out: Formatter, cache_adjust: int) -> None:
+        with out.block(""):
+            for var, ieffect in self.input_mapping.items():
+                out.emit(f"PyObject *{ieffect.name} = {var};")
+            for oeffect in self.output_mapping.values():
+                out.emit(f"PyObject *{oeffect.name};")
+            self.instr.write_body(out, dedent=-4, cache_adjust=cache_adjust)
+            for var, oeffect in self.output_mapping.items():
+                out.emit(f"{var} = {oeffect.name};")
+
 
-class SuperInstruction(parser.Super):
+# TODO: Use a common base class for {Super,Macro}Instruction
 
+
+@dataclasses.dataclass
+class SuperOrMacroInstruction:
+    """Common fields for super- and macro instructions."""
+
+    name: str
     stack: list[str]
     initial_sp: int
     final_sp: int
-    parts: list[SuperComponent]
-
-    def __init__(self, sup: parser.Super):
-        super().__init__(sup.kind, sup.name, sup.ops)
-        self.context = sup.context
-
-    def analyze(self, a: "Analyzer") -> None:
-        components = self.check_components(a)
-        self.stack, self.initial_sp = self.super_macro_analysis(a, components)
-        sp = self.initial_sp
-        self.parts = []
-        for instr in components:
-            input_mapping = {}
-            for ieffect in reversed(instr.input_effects):
-                sp -= 1
-                if ieffect.name != UNUSED:
-                    input_mapping[self.stack[sp]] = ieffect
-            output_mapping = {}
-            for oeffect in instr.output_effects:
-                if oeffect.name != UNUSED:
-                    output_mapping[self.stack[sp]] = oeffect
-                sp += 1
-            self.parts.append(SuperComponent(instr, input_mapping, output_mapping))
-        self.final_sp = sp
-
-    def check_components(self, a: "Analyzer") -> list[Instruction]:
-        components: list[Instruction] = []
-        if not self.ops:
-            a.error(f"{self.kind.capitalize()}-instruction has no operands", self)
-        for name in self.ops:
-            if name not in a.instrs:
-                a.error(f"Unknown instruction {name!r}", self)
-            else:
-                instr = a.instrs[name]
-                if self.kind == "super" and instr.kind != "inst":
-                    a.error(f"Super-instruction operand {instr.name} must be inst, not op", instr)
-                components.append(instr)
-        return components
 
-    def super_macro_analysis(
-        self, a: "Analyzer", components: list[Instruction]
-    ) -> tuple[list[str], int]:
-        """Analyze a super-instruction or macro.
 
-        Print an error if there's a cache effect (which we don't support yet).
+@dataclasses.dataclass
+class SuperInstruction(SuperOrMacroInstruction):
+    """A super-instruction."""
 
-        Return the list of variable names and the initial stack pointer.
-        """
-        lowest = current = highest = 0
-        for instr in components:
-            if instr.cache_effects:
-                a.error(
-                    f"Super-instruction {self.name!r} has cache effects in {instr.name!r}",
-                    instr,
-                )
-            current -= len(instr.input_effects)
-            lowest = min(lowest, current)
-            current += len(instr.output_effects)
-            highest = max(highest, current)
-        # At this point, 'current' is the net stack effect,
-        # and 'lowest' and 'highest' are the extremes.
-        # Note that 'lowest' may be negative.
-        stack = [f"_tmp_{i+1}" for i in range(highest - lowest)]
-        return stack, -lowest
+    super: parser.Super
+    parts: list[Component]
+
+
+@dataclasses.dataclass
+class MacroInstruction(SuperOrMacroInstruction):
+    """A macro instruction."""
+
+    macro: parser.Macro
+    parts: list[Component | parser.CacheEffect]
 
 
 class Analyzer:
     """Parse input, analyze it, and write to output."""
 
     filename: str
+    output_filename: str
     src: str
     errors: int = 0
 
+    def __init__(self, filename: str, output_filename: str):
+        """Read the input file."""
+        self.filename = filename
+        self.output_filename = output_filename
+        with open(filename) as f:
+            self.src = f.read()
+
     def error(self, msg: str, node: parser.Node) -> None:
         lineno = 0
         if context := node.context:
             # Use line number of first non-comment in the node
-            for token in context.owner.tokens[context.begin :  context.end]:
+            for token in context.owner.tokens[context.begin : context.end]:
                 lineno = token.line
                 if token.kind != "COMMENT":
                     break
         print(f"{self.filename}:{lineno}: {msg}", file=sys.stderr)
         self.errors += 1
 
-    def __init__(self, filename: str):
-        """Read the input file."""
-        self.filename = filename
-        with open(filename) as f:
-            self.src = f.read()
-
     instrs: dict[str, Instruction]  # Includes ops
-    supers: dict[str, parser.Super]  # Includes macros
+    supers: dict[str, parser.Super]
     super_instrs: dict[str, SuperInstruction]
+    macros: dict[str, parser.Macro]
+    macro_instrs: dict[str, MacroInstruction]
     families: dict[str, parser.Family]
 
     def parse(self) -> None:
-        """Parse the source text."""
+        """Parse the source text.
+
+        We only want the parser to see the stuff between the
+        begin and end markers.
+        """
         psr = parser.Parser(self.src, filename=self.filename)
 
         # Skip until begin marker
@@ -291,24 +320,38 @@ def parse(self) -> None:
             raise psr.make_syntax_error(
                 f"Couldn't find {BEGIN_MARKER!r} in {psr.filename}"
             )
+        start = psr.getpos()
 
-        # Parse until end marker
+        # Find end marker, then delete everything after it
+        while tkn := psr.next(raw=True):
+            if tkn.text == END_MARKER:
+                break
+        del psr.tokens[psr.getpos() - 1 :]
+
+        # Parse from start
+        psr.setpos(start)
         self.instrs = {}
         self.supers = {}
+        self.macros = {}
         self.families = {}
-        while (tkn := psr.peek(raw=True)) and tkn.text != END_MARKER:
-            if inst := psr.inst_def():
-                self.instrs[inst.name] = instr = Instruction(inst)
-            elif super := psr.super_def():
-                self.supers[super.name] = super
-            elif family := psr.family_def():
-                self.families[family.name] = family
-            else:
-                raise psr.make_syntax_error(f"Unexpected token")
+        while thing := psr.definition():
+            match thing:
+                case parser.InstDef(name=name):
+                    self.instrs[name] = Instruction(thing)
+                case parser.Super(name):
+                    self.supers[name] = thing
+                case parser.Macro(name):
+                    self.macros[name] = thing
+                case parser.Family(name):
+                    self.families[name] = thing
+                case _:
+                    typing.assert_never(thing)
+        if not psr.eof():
+            raise psr.make_syntax_error("Extra stuff at the end")
 
         print(
-            f"Read {len(self.instrs)} instructions, "
-            f"{len(self.supers)} supers/macros, "
+            f"Read {len(self.instrs)} instructions/ops, "
+            f"{len(self.supers)} supers, {len(self.macros)} macros, "
             f"and {len(self.families)} families from {self.filename}",
             file=sys.stderr,
         )
@@ -321,7 +364,7 @@ def analyze(self) -> None:
         self.find_predictions()
         self.map_families()
         self.check_families()
-        self.analyze_supers()
+        self.analyze_supers_and_macros()
 
     def find_predictions(self) -> None:
         """Find the instructions that need PREDICTED() labels."""
@@ -332,7 +375,7 @@ def find_predictions(self) -> None:
                 else:
                     self.error(
                         f"Unknown instruction {target!r} predicted in {instr.name!r}",
-                        instr,  # TODO: Use better location
+                        instr.inst,  # TODO: Use better location
                     )
 
     def map_families(self) -> None:
@@ -360,7 +403,9 @@ def check_families(self) -> None:
             members = [member for member in family.members if member in self.instrs]
             if members != family.members:
                 unknown = set(family.members) - set(members)
-                self.error(f"Family {family.name!r} has unknown members: {unknown}", family)
+                self.error(
+                    f"Family {family.name!r} has unknown members: {unknown}", family
+                )
             if len(members) < 2:
                 continue
             head = self.instrs[members[0]]
@@ -381,105 +426,211 @@ def check_families(self) -> None:
                         family,
                     )
 
-    def analyze_supers(self) -> None:
-        """Analyze each super instruction."""
+    def analyze_supers_and_macros(self) -> None:
+        """Analyze each super- and macro instruction."""
         self.super_instrs = {}
-        for name, sup in self.supers.items():
-            dup = SuperInstruction(sup)
-            dup.analyze(self)
-            self.super_instrs[name] = dup
+        self.macro_instrs = {}
+        for name, super in self.supers.items():
+            self.super_instrs[name] = self.analyze_super(super)
+        for name, macro in self.macros.items():
+            self.macro_instrs[name] = self.analyze_macro(macro)
+
+    def analyze_super(self, super: parser.Super) -> SuperInstruction:
+        components = self.check_super_components(super)
+        stack, initial_sp = self.stack_analysis(components)
+        sp = initial_sp
+        parts: list[Component] = []
+        for component in components:
+            match component:
+                case parser.CacheEffect() as ceffect:
+                    parts.append(ceffect)
+                case Instruction() as instr:
+                    input_mapping = {}
+                    for ieffect in reversed(instr.input_effects):
+                        sp -= 1
+                        if ieffect.name != UNUSED:
+                            input_mapping[stack[sp]] = ieffect
+                    output_mapping = {}
+                    for oeffect in instr.output_effects:
+                        if oeffect.name != UNUSED:
+                            output_mapping[stack[sp]] = oeffect
+                        sp += 1
+                    parts.append(Component(instr, input_mapping, output_mapping))
+                case _:
+                    typing.assert_never(component)
+        final_sp = sp
+        return SuperInstruction(super.name, stack, initial_sp, final_sp, super, parts)
+
+    def analyze_macro(self, macro: parser.Macro) -> MacroInstruction:
+        components = self.check_macro_components(macro)
+        stack, initial_sp = self.stack_analysis(components)
+        sp = initial_sp
+        parts: list[Component | parser.CacheEffect] = []
+        for component in components:
+            match component:
+                case parser.CacheEffect() as ceffect:
+                    parts.append(ceffect)
+                case Instruction() as instr:
+                    input_mapping = {}
+                    for ieffect in reversed(instr.input_effects):
+                        sp -= 1
+                        if ieffect.name != UNUSED:
+                            input_mapping[stack[sp]] = ieffect
+                    output_mapping = {}
+                    for oeffect in instr.output_effects:
+                        if oeffect.name != UNUSED:
+                            output_mapping[stack[sp]] = oeffect
+                        sp += 1
+                    parts.append(Component(instr, input_mapping, output_mapping))
+                case _:
+                    typing.assert_never(component)
+        final_sp = sp
+        return MacroInstruction(macro.name, stack, initial_sp, final_sp, macro, parts)
+
+    def check_super_components(self, super: parser.Super) -> list[Instruction]:
+        components: list[Instruction] = []
+        for op in super.ops:
+            if op.name not in self.instrs:
+                self.error(f"Unknown instruction {op.name!r}", super)
+            else:
+                components.append(self.instrs[op.name])
+        return components
 
-    def write_instructions(self, filename: str) -> None:
+    def check_macro_components(
+        self, macro: parser.Macro
+    ) -> list[InstructionOrCacheEffect]:
+        components: list[InstructionOrCacheEffect] = []
+        for uop in macro.uops:
+            match uop:
+                case parser.OpName(name):
+                    if name not in self.instrs:
+                        self.error(f"Unknown instruction {name!r}", macro)
+                    components.append(self.instrs[name])
+                case parser.CacheEffect():
+                    components.append(uop)
+                case _:
+                    typing.assert_never(uop)
+        return components
+
+    def stack_analysis(
+        self, components: typing.Iterable[InstructionOrCacheEffect]
+    ) -> tuple[list[str], int]:
+        """Analyze a super-instruction or macro.
+
+        Print an error if there's a cache effect (which we don't support yet).
+
+        Return the list of variable names and the initial stack pointer.
+        """
+        lowest = current = highest = 0
+        for thing in components:
+            match thing:
+                case Instruction() as instr:
+                    current -= len(instr.input_effects)
+                    lowest = min(lowest, current)
+                    current += len(instr.output_effects)
+                    highest = max(highest, current)
+                case parser.CacheEffect():
+                    pass
+                case _:
+                    typing.assert_never(thing)
+        # At this point, 'current' is the net stack effect,
+        # and 'lowest' and 'highest' are the extremes.
+        # Note that 'lowest' may be negative.
+        stack = [f"_tmp_{i+1}" for i in range(highest - lowest)]
+        return stack, -lowest
+
+    def write_instructions(self) -> None:
         """Write instructions to output file."""
-        indent = " " * 8
-        with open(filename, "w") as f:
+        with open(self.output_filename, "w") as f:
             # Write provenance header
             f.write(f"// This file is generated by {os.path.relpath(__file__)}\n")
             f.write(f"// from {os.path.relpath(self.filename)}\n")
             f.write(f"// Do not edit!\n")
 
-            # Write regular instructions
+            # Create formatter; the rest of the code uses this.
+            self.out = Formatter(f, 8)
+
+            # Write and count regular instructions
             n_instrs = 0
             for name, instr in self.instrs.items():
                 if instr.kind != "inst":
                     continue  # ops are not real instructions
                 n_instrs += 1
-                f.write(f"\n{indent}TARGET({name}) {{\n")
-                if instr.predicted:
-                    f.write(f"{indent}    PREDICTED({name});\n")
-                instr.write(f, indent)
-                if not always_exits(instr.block):
-                    f.write(f"{indent}    DISPATCH();\n")
-                f.write(f"{indent}}}\n")
-
-            # Write super-instructions and macros
+                self.out.emit("")
+                with self.out.block(f"TARGET({name})"):
+                    if instr.predicted:
+                        self.out.emit(f"PREDICTED({name});")
+                    instr.write(self.out)
+                    if not always_exits(instr.block):
+                        self.out.emit(f"DISPATCH();")
+
+            # Write and count super-instructions
             n_supers = 0
-            n_macros = 0
             for sup in self.super_instrs.values():
-                if sup.kind == "super":
-                    n_supers += 1
-                elif sup.kind == "macro":
-                    n_macros += 1
-                self.write_super_macro(f, sup, indent)
-
-            print(
-                f"Wrote {n_instrs} instructions, {n_supers} supers, "
-                f"and {n_macros} macros to {filename}",
-                file=sys.stderr,
-            )
+                n_supers += 1
+                self.write_super(sup)
 
-    def write_super_macro(
-        self, f: typing.TextIO, sup: SuperInstruction, indent: str = ""
-    ) -> None:
+            # Write and count macro instructions
+            n_macros = 0
+            for macro in self.macro_instrs.values():
+                n_macros += 1
+                self.write_macro(macro)
 
-        # TODO: Make write() and block() methods of some Formatter class
-        def write(arg: str) -> None:
-            if arg:
-                f.write(f"{indent}{arg}\n")
-            else:
-                f.write("\n")
+        print(
+            f"Wrote {n_instrs} instructions, {n_supers} supers, "
+            f"and {n_macros} macros to {self.output_filename}",
+            file=sys.stderr,
+        )
 
-        @contextlib.contextmanager
-        def block(head: str):
-            if head:
-                write(head + " {")
-            else:
-                write("{")
-            nonlocal indent
-            indent += "    "
-            yield
-            indent = indent[:-4]
-            write("}")
-
-        write("")
-        with block(f"TARGET({sup.name})"):
-            for i, var in enumerate(sup.stack):
-                if i < sup.initial_sp:
-                    write(f"PyObject *{var} = PEEK({sup.initial_sp - i});")
+    def write_super(self, sup: SuperInstruction) -> None:
+        """Write code for a super-instruction."""
+        with self.wrap_super_or_macro(sup):
+            first = True
+            for comp in sup.parts:
+                if not first:
+                    self.out.emit("NEXTOPARG();")
+                    self.out.emit("next_instr++;")
+                first = False
+                comp.write_body(self.out, 0)
+                if comp.instr.cache_offset:
+                    self.out.emit(f"next_instr += {comp.instr.cache_offset};")
+
+    def write_macro(self, mac: MacroInstruction) -> None:
+        """Write code for a macro instruction."""
+        with self.wrap_super_or_macro(mac):
+            cache_adjust = 0
+            for part in mac.parts:
+                match part:
+                    case parser.CacheEffect(size=size):
+                        cache_adjust += size
+                    case Component() as comp:
+                        comp.write_body(self.out, cache_adjust)
+                        cache_adjust += comp.instr.cache_offset
+
+            if cache_adjust:
+                self.out.emit(f"next_instr += {cache_adjust};")
+
+    @contextlib.contextmanager
+    def wrap_super_or_macro(self, up: SuperOrMacroInstruction):
+        """Shared boilerplate for super- and macro instructions."""
+        self.out.emit("")
+        with self.out.block(f"TARGET({up.name})"):
+            for i, var in enumerate(up.stack):
+                if i < up.initial_sp:
+                    self.out.emit(f"PyObject *{var} = PEEK({up.initial_sp - i});")
                 else:
-                    write(f"PyObject *{var};")
-
-            for i, comp in enumerate(sup.parts):
-                if i > 0 and sup.kind == "super":
-                    write("NEXTOPARG();")
-                    write("next_instr++;")
-
-                with block(""):
-                    for var, ieffect in comp.input_mapping.items():
-                        write(f"PyObject *{ieffect.name} = {var};")
-                    for oeffect in comp.output_mapping.values():
-                        write(f"PyObject *{oeffect.name};")
-                    comp.instr.write_body(f, indent, dedent=-4)
-                    for var, oeffect in comp.output_mapping.items():
-                        write(f"{var} = {oeffect.name};")
-
-            if sup.final_sp > sup.initial_sp:
-                write(f"STACK_GROW({sup.final_sp - sup.initial_sp});")
-            elif sup.final_sp < sup.initial_sp:
-                write(f"STACK_SHRINK({sup.initial_sp - sup.final_sp});")
-            for i, var in enumerate(reversed(sup.stack[:sup.final_sp]), 1):
-                write(f"POKE({i}, {var});")
-            write("DISPATCH();")
+                    self.out.emit(f"PyObject *{var};")
+
+            yield
+
+            if up.final_sp > up.initial_sp:
+                self.out.emit(f"STACK_GROW({up.final_sp - up.initial_sp});")
+            elif up.final_sp < up.initial_sp:
+                self.out.emit(f"STACK_SHRINK({up.initial_sp - up.final_sp});")
+            for i, var in enumerate(reversed(up.stack[: up.final_sp]), 1):
+                self.out.emit(f"POKE({i}, {var});")
+
+            self.out.emit(f"DISPATCH();")
 
 
 def always_exits(block: parser.Block) -> bool:
@@ -506,13 +657,12 @@ def always_exits(block: parser.Block) -> bool:
 def main():
     """Parse command line, parse input, analyze, write output."""
     args = arg_parser.parse_args()  # Prints message and sys.exit(2) on error
-    a = Analyzer(args.input)  # Raises OSError if file not found
+    a = Analyzer(args.input, args.output)  # Raises OSError if input unreadable
     a.parse()  # Raises SyntaxError on failure
-    a.analyze()  # Prints messages and raises SystemExit on failure
+    a.analyze()  # Prints messages and sets a.errors on failure
     if a.errors:
         sys.exit(f"Found {a.errors} errors")
-
-    a.write_instructions(args.output)  # Raises OSError if file can't be written
+    a.write_instructions()  # Raises OSError if output can't be written
 
 
 if __name__ == "__main__":
diff --git a/Tools/cases_generator/lexer.py b/Tools/cases_generator/lexer.py
index 980c920bf357..39b6a212a67b 100644
--- a/Tools/cases_generator/lexer.py
+++ b/Tools/cases_generator/lexer.py
@@ -240,7 +240,12 @@ def to_text(tkns: list[Token], dedent: int = 0) -> str:
             res.append('\n')
             col = 1+dedent
         res.append(' '*(c-col))
-        res.append(tkn.text)
+        text = tkn.text
+        if dedent != 0 and tkn.kind == 'COMMENT' and '\n' in text:
+            if dedent < 0:
+                text = text.replace('\n', '\n' + ' '*-dedent)
+            # TODO: dedent > 0
+        res.append(text)
         line, col = tkn.end
     return ''.join(res)
 
diff --git a/Tools/cases_generator/parser.py b/Tools/cases_generator/parser.py
index ae5ef1e26ea1..02a7834d2215 100644
--- a/Tools/cases_generator/parser.py
+++ b/Tools/cases_generator/parser.py
@@ -9,10 +9,12 @@
 
 P = TypeVar("P", bound="Parser")
 N = TypeVar("N", bound="Node")
-def contextual(func: Callable[[P], N|None]) -> Callable[[P], N|None]:
+
+
+def contextual(func: Callable[[P], N | None]) -> Callable[[P], N | None]:
     # Decorator to wrap grammar methods.
     # Resets position if `func` returns None.
-    def contextual_wrapper(self: P) -> N|None:
+    def contextual_wrapper(self: P) -> N | None:
         begin = self.getpos()
         res = func(self)
         if res is None:
@@ -21,6 +23,7 @@ def contextual_wrapper(self: P) -> N|None:
         end = self.getpos()
         res.context = Context(begin, end, self)
         return res
+
     return contextual_wrapper
 
 
@@ -35,7 +38,7 @@ def __repr__(self):
 
 @dataclass
 class Node:
-    context: Context|None = field(init=False, default=None)
+    context: Context | None = field(init=False, default=None)
 
     @property
     def text(self) -> str:
@@ -68,8 +71,14 @@ class CacheEffect(Node):
     size: int
 
 
+@dataclass
+class OpName(Node):
+    name: str
+
+
 InputEffect = StackEffect | CacheEffect
 OutputEffect = StackEffect
+UOp = OpName | CacheEffect
 
 
 @dataclass
@@ -82,32 +91,23 @@ class InstHeader(Node):
 
 @dataclass
 class InstDef(Node):
-    # TODO: Merge InstHeader and InstDef
-    header: InstHeader
+    kind: Literal["inst", "op"]
+    name: str
+    inputs: list[InputEffect]
+    outputs: list[OutputEffect]
     block: Block
 
-    @property
-    def kind(self) -> str:
-        return self.header.kind
-
-    @property
-    def name(self) -> str:
-        return self.header.name
 
-    @property
-    def inputs(self) -> list[InputEffect]:
-        return self.header.inputs
-
-    @property
-    def outputs(self) -> list[OutputEffect]:
-        return self.header.outputs
+@dataclass
+class Super(Node):
+    name: str
+    ops: list[OpName]
 
 
 @dataclass
-class Super(Node):
-    kind: Literal["macro", "super"]
+class Macro(Node):
     name: str
-    ops: list[str]
+    uops: list[UOp]
 
 
 @dataclass
@@ -118,12 +118,22 @@ class Family(Node):
 
 
 class Parser(PLexer):
+    @contextual
+    def definition(self) -> InstDef | Super | Macro | Family | None:
+        if inst := self.inst_def():
+            return inst
+        if super := self.super_def():
+            return super
+        if macro := self.macro_def():
+            return macro
+        if family := self.family_def():
+            return family
 
     @contextual
     def inst_def(self) -> InstDef | None:
-        if header := self.inst_header():
+        if hdr := self.inst_header():
             if block := self.block():
-                return InstDef(header, block)
+                return InstDef(hdr.kind, hdr.name, hdr.inputs, hdr.outputs, block)
             raise self.make_syntax_error("Expected block")
         return None
 
@@ -132,17 +142,14 @@ def inst_header(self) -> InstHeader | None:
         # inst(NAME)
         #   | inst(NAME, (inputs -- outputs))
         #   | op(NAME, (inputs -- outputs))
-        # TODO: Error out when there is something unexpected.
         # TODO: Make INST a keyword in the lexer.
         if (tkn := self.expect(lx.IDENTIFIER)) and (kind := tkn.text) in ("inst", "op"):
-            if (self.expect(lx.LPAREN)
-                    and (tkn := self.expect(lx.IDENTIFIER))):
+            if self.expect(lx.LPAREN) and (tkn := self.expect(lx.IDENTIFIER)):
                 name = tkn.text
                 if self.expect(lx.COMMA):
                     inp, outp = self.stack_effect()
                     if self.expect(lx.RPAREN):
-                        if ((tkn := self.peek())
-                                and tkn.kind == lx.LBRACE):
+                        if (tkn := self.peek()) and tkn.kind == lx.LBRACE:
                             return InstHeader(kind, name, inp, outp)
                 elif self.expect(lx.RPAREN) and kind == "inst":
                     # No legacy stack effect if kind is "op".
@@ -176,18 +183,20 @@ def inputs(self) -> list[InputEffect] | None:
     def input(self) -> InputEffect | None:
         # IDENTIFIER '/' INTEGER (CacheEffect)
         # IDENTIFIER (StackEffect)
-        if (tkn := self.expect(lx.IDENTIFIER)):
+        if tkn := self.expect(lx.IDENTIFIER):
             if self.expect(lx.DIVIDE):
                 if num := self.expect(lx.NUMBER):
                     try:
                         size = int(num.text)
                     except ValueError:
                         raise self.make_syntax_error(
-                            f"Expected integer, got {num.text!r}")
+                            f"Expected integer, got {num.text!r}"
+                        )
                     else:
                         return CacheEffect(tkn.text, size)
                 raise self.make_syntax_error("Expected integer")
             else:
+                # TODO: Arrays, conditions
                 return StackEffect(tkn.text)
 
     def outputs(self) -> list[OutputEffect] | None:
@@ -205,46 +214,91 @@ def outputs(self) -> list[OutputEffect] | None:
 
     @contextual
     def output(self) -> OutputEffect | None:
-        if (tkn := self.expect(lx.IDENTIFIER)):
+        if tkn := self.expect(lx.IDENTIFIER):
             return StackEffect(tkn.text)
 
     @contextual
     def super_def(self) -> Super | None:
-        if (tkn := self.expect(lx.IDENTIFIER)) and (kind := tkn.text) in ("super", "macro"):
+        if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "super":
             if self.expect(lx.LPAREN):
-                if (tkn := self.expect(lx.IDENTIFIER)):
+                if tkn := self.expect(lx.IDENTIFIER):
                     if self.expect(lx.RPAREN):
                         if self.expect(lx.EQUALS):
                             if ops := self.ops():
-                                res = Super(kind, tkn.text, ops)
+                                self.require(lx.SEMI)
+                                res = Super(tkn.text, ops)
                                 return res
 
-    def ops(self) -> list[str] | None:
-        if tkn := self.expect(lx.IDENTIFIER):
-            ops = [tkn.text]
+    def ops(self) -> list[OpName] | None:
+        if op := self.op():
+            ops = [op]
             while self.expect(lx.PLUS):
-                if tkn := self.require(lx.IDENTIFIER):
-                    ops.append(tkn.text)
-            self.require(lx.SEMI)
+                if op := self.op():
+                    ops.append(op)
             return ops
 
+    @contextual
+    def op(self) -> OpName | None:
+        if tkn := self.expect(lx.IDENTIFIER):
+            return OpName(tkn.text)
+
+    @contextual
+    def macro_def(self) -> Macro | None:
+        if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "macro":
+            if self.expect(lx.LPAREN):
+                if tkn := self.expect(lx.IDENTIFIER):
+                    if self.expect(lx.RPAREN):
+                        if self.expect(lx.EQUALS):
+                            if uops := self.uops():
+                                self.require(lx.SEMI)
+                                res = Macro(tkn.text, uops)
+                                return res
+
+    def uops(self) -> list[UOp] | None:
+        if uop := self.uop():
+            uops = [uop]
+            while self.expect(lx.PLUS):
+                if uop := self.uop():
+                    uops.append(uop)
+                else:
+                    raise self.make_syntax_error("Expected op name or cache effect")
+            return uops
+
+    @contextual
+    def uop(self) -> UOp | None:
+        if tkn := self.expect(lx.IDENTIFIER):
+            if self.expect(lx.DIVIDE):
+                if num := self.expect(lx.NUMBER):
+                    try:
+                        size = int(num.text)
+                    except ValueError:
+                        raise self.make_syntax_error(
+                            f"Expected integer, got {num.text!r}"
+                        )
+                    else:
+                        return CacheEffect(tkn.text, size)
+                raise self.make_syntax_error("Expected integer")
+            else:
+                return OpName(tkn.text)
+
     @contextual
     def family_def(self) -> Family | None:
         if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "family":
             size = None
             if self.expect(lx.LPAREN):
-                if (tkn := self.expect(lx.IDENTIFIER)):
+                if tkn := self.expect(lx.IDENTIFIER):
                     if self.expect(lx.COMMA):
                         if not (size := self.expect(lx.IDENTIFIER)):
-                            raise self.make_syntax_error(
-                                "Expected identifier")
+                            raise self.make_syntax_error("Expected identifier")
                     if self.expect(lx.RPAREN):
                         if self.expect(lx.EQUALS):
                             if not self.expect(lx.LBRACE):
                                 raise self.make_syntax_error("Expected {")
                             if members := self.members():
                                 if self.expect(lx.RBRACE) and self.expect(lx.SEMI):
-                                    return Family(tkn.text, size.text if size else "", members)
+                                    return Family(
+                                        tkn.text, size.text if size else "", members
+                                    )
         return None
 
     def members(self) -> list[str] | None:
@@ -284,6 +338,7 @@ def c_blob(self) -> list[lx.Token]:
 
 if __name__ == "__main__":
     import sys
+
     if sys.argv[1:]:
         filename = sys.argv[1]
         if filename == "-c" and sys.argv[2:]:
@@ -295,10 +350,10 @@ def c_blob(self) -> list[lx.Token]:
             srclines = src.splitlines()
             begin = srclines.index("// BEGIN BYTECODES //")
             end = srclines.index("// END BYTECODES //")
-            src = "\n".join(srclines[begin+1 : end])
+            src = "\n".join(srclines[begin + 1 : end])
     else:
         filename = "<default>"
         src = "if (x) { x.foo; // comment\n}"
     parser = Parser(src, filename)
-    x = parser.inst_def() or parser.super_def() or parser.family_def()
+    x = parser.definition()
     print(x)



More information about the Python-checkins mailing list