[Python-checkins] GH-98831: Refactor and fix cases generator (#99526)

gvanrossum webhook-mailer at python.org
Thu Nov 17 20:06:16 EST 2022


https://github.com/python/cpython/commit/4f5e1cb00a914692895c1c16e446c8d2ab3efb7e
commit: 4f5e1cb00a914692895c1c16e446c8d2ab3efb7e
branch: main
author: Guido van Rossum <guido at python.org>
committer: gvanrossum <gvanrossum at gmail.com>
date: 2022-11-17T17:06:07-08:00
summary:

GH-98831: Refactor and fix cases generator (#99526)

Also complete cache effects for BINARY_SUBSCR family.

files:
M Python/bytecodes.c
M Python/generated_cases.c.h
M Tools/cases_generator/generate_cases.py
M Tools/cases_generator/lexer.py
M Tools/cases_generator/parser.py
M Tools/cases_generator/plexer.py

diff --git a/Python/bytecodes.c b/Python/bytecodes.c
index a3e02674c290..78f7d4ac0616 100644
--- a/Python/bytecodes.c
+++ b/Python/bytecodes.c
@@ -71,7 +71,7 @@ do { \
 
 #define inst(name, ...) case name:
 #define super(name) static int SUPER_##name
-#define family(name) static int family_##name
+#define family(name, ...) static int family_##name
 
 #define NAME_ERROR_MSG \
     "name '%.200s' is not defined"
@@ -79,6 +79,7 @@ do { \
 // Dummy variables for stack effects.
 static PyObject *value, *value1, *value2, *left, *right, *res, *sum, *prod, *sub;
 static PyObject *container, *start, *stop, *v, *lhs, *rhs;
+static PyObject *list, *tuple, *dict;
 
 static PyObject *
 dummy_func(
@@ -322,7 +323,15 @@ dummy_func(
             ERROR_IF(sum == NULL, error);
         }
 
-        inst(BINARY_SUBSCR, (container, sub -- res)) {
+        family(binary_subscr, INLINE_CACHE_ENTRIES_BINARY_SUBSCR) = {
+            BINARY_SUBSCR,
+            BINARY_SUBSCR_DICT,
+            BINARY_SUBSCR_GETITEM,
+            BINARY_SUBSCR_LIST_INT,
+            BINARY_SUBSCR_TUPLE_INT,
+        };
+
+        inst(BINARY_SUBSCR, (container, sub, unused/4 -- res)) {
             _PyBinarySubscrCache *cache = (_PyBinarySubscrCache *)next_instr;
             if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) {
                 assert(cframe.use_tracing == 0);
@@ -336,7 +345,6 @@ dummy_func(
             Py_DECREF(container);
             Py_DECREF(sub);
             ERROR_IF(res == NULL, error);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_SUBSCR);
         }
 
         inst(BINARY_SLICE, (container, start, stop -- res)) {
@@ -369,11 +377,8 @@ dummy_func(
             ERROR_IF(err, error);
         }
 
-        // stack effect: (__0 -- )
-        inst(BINARY_SUBSCR_LIST_INT) {
+        inst(BINARY_SUBSCR_LIST_INT, (list, sub, unused/4 -- res)) {
             assert(cframe.use_tracing == 0);
-            PyObject *sub = TOP();
-            PyObject *list = SECOND();
             DEOPT_IF(!PyLong_CheckExact(sub), BINARY_SUBSCR);
             DEOPT_IF(!PyList_CheckExact(list), BINARY_SUBSCR);
 
@@ -384,21 +389,15 @@ dummy_func(
             Py_ssize_t index = ((PyLongObject*)sub)->ob_digit[0];
             DEOPT_IF(index >= PyList_GET_SIZE(list), BINARY_SUBSCR);
             STAT_INC(BINARY_SUBSCR, hit);
-            PyObject *res = PyList_GET_ITEM(list, index);
+            res = PyList_GET_ITEM(list, index);
             assert(res != NULL);
             Py_INCREF(res);
-            STACK_SHRINK(1);
             _Py_DECREF_SPECIALIZED(sub, (destructor)PyObject_Free);
-            SET_TOP(res);
             Py_DECREF(list);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_SUBSCR);
         }
 
-        // stack effect: (__0 -- )
-        inst(BINARY_SUBSCR_TUPLE_INT) {
+        inst(BINARY_SUBSCR_TUPLE_INT, (tuple, sub, unused/4 -- res)) {
             assert(cframe.use_tracing == 0);
-            PyObject *sub = TOP();
-            PyObject *tuple = SECOND();
             DEOPT_IF(!PyLong_CheckExact(sub), BINARY_SUBSCR);
             DEOPT_IF(!PyTuple_CheckExact(tuple), BINARY_SUBSCR);
 
@@ -409,51 +408,39 @@ dummy_func(
             Py_ssize_t index = ((PyLongObject*)sub)->ob_digit[0];
             DEOPT_IF(index >= PyTuple_GET_SIZE(tuple), BINARY_SUBSCR);
             STAT_INC(BINARY_SUBSCR, hit);
-            PyObject *res = PyTuple_GET_ITEM(tuple, index);
+            res = PyTuple_GET_ITEM(tuple, index);
             assert(res != NULL);
             Py_INCREF(res);
-            STACK_SHRINK(1);
             _Py_DECREF_SPECIALIZED(sub, (destructor)PyObject_Free);
-            SET_TOP(res);
             Py_DECREF(tuple);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_SUBSCR);
         }
 
-        // stack effect: (__0 -- )
-        inst(BINARY_SUBSCR_DICT) {
+        inst(BINARY_SUBSCR_DICT, (dict, sub, unused/4 -- res)) {
             assert(cframe.use_tracing == 0);
-            PyObject *dict = SECOND();
-            DEOPT_IF(!PyDict_CheckExact(SECOND()), BINARY_SUBSCR);
+            DEOPT_IF(!PyDict_CheckExact(dict), BINARY_SUBSCR);
             STAT_INC(BINARY_SUBSCR, hit);
-            PyObject *sub = TOP();
-            PyObject *res = PyDict_GetItemWithError(dict, sub);
+            res = PyDict_GetItemWithError(dict, sub);
             if (res == NULL) {
                 if (!_PyErr_Occurred(tstate)) {
                     _PyErr_SetKeyError(sub);
                 }
-                goto error;
+                Py_DECREF(dict);
+                Py_DECREF(sub);
+                ERROR_IF(1, error);
             }
-            Py_INCREF(res);
-            STACK_SHRINK(1);
-            Py_DECREF(sub);
-            SET_TOP(res);
+            Py_INCREF(res);  // Do this before DECREF'ing dict, sub
             Py_DECREF(dict);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_SUBSCR);
+            Py_DECREF(sub);
         }
 
-        // stack effect: (__0 -- )
-        inst(BINARY_SUBSCR_GETITEM) {
-            PyObject *sub = TOP();
-            PyObject *container = SECOND();
-            _PyBinarySubscrCache *cache = (_PyBinarySubscrCache *)next_instr;
-            uint32_t type_version = read_u32(cache->type_version);
+        inst(BINARY_SUBSCR_GETITEM, (container, sub, unused/1, type_version/2, func_version/1 -- unused)) {
             PyTypeObject *tp = Py_TYPE(container);
             DEOPT_IF(tp->tp_version_tag != type_version, BINARY_SUBSCR);
             assert(tp->tp_flags & Py_TPFLAGS_HEAPTYPE);
             PyObject *cached = ((PyHeapTypeObject *)tp)->_spec_cache.getitem;
             assert(PyFunction_Check(cached));
             PyFunctionObject *getitem = (PyFunctionObject *)cached;
-            DEOPT_IF(getitem->func_version != cache->func_version, BINARY_SUBSCR);
+            DEOPT_IF(getitem->func_version != func_version, BINARY_SUBSCR);
             PyCodeObject *code = (PyCodeObject *)getitem->func_code;
             assert(code->co_argcount == 2);
             DEOPT_IF(!_PyThreadState_HasStackSpace(tstate, code->co_framesize), BINARY_SUBSCR);
diff --git a/Python/generated_cases.c.h b/Python/generated_cases.c.h
index ba00203da301..2c6333f8e615 100644
--- a/Python/generated_cases.c.h
+++ b/Python/generated_cases.c.h
@@ -1,4 +1,5 @@
 // This file is generated by Tools/cases_generator/generate_cases.py
+// from Python/bytecodes.c
 // Do not edit!
 
         TARGET(NOP) {
@@ -300,6 +301,7 @@
 
         TARGET(BINARY_SUBSCR) {
             PREDICTED(BINARY_SUBSCR);
+            static_assert(INLINE_CACHE_ENTRIES_BINARY_SUBSCR == 4, "incorrect cache size");
             PyObject *sub = PEEK(1);
             PyObject *container = PEEK(2);
             PyObject *res;
@@ -316,9 +318,9 @@
             Py_DECREF(container);
             Py_DECREF(sub);
             if (res == NULL) goto pop_2_error;
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_SUBSCR);
             STACK_SHRINK(1);
             POKE(1, res);
+            next_instr += 4;
             DISPATCH();
         }
 
@@ -366,9 +368,10 @@
         }
 
         TARGET(BINARY_SUBSCR_LIST_INT) {
+            PyObject *sub = PEEK(1);
+            PyObject *list = PEEK(2);
+            PyObject *res;
             assert(cframe.use_tracing == 0);
-            PyObject *sub = TOP();
-            PyObject *list = SECOND();
             DEOPT_IF(!PyLong_CheckExact(sub), BINARY_SUBSCR);
             DEOPT_IF(!PyList_CheckExact(list), BINARY_SUBSCR);
 
@@ -379,21 +382,22 @@
             Py_ssize_t index = ((PyLongObject*)sub)->ob_digit[0];
             DEOPT_IF(index >= PyList_GET_SIZE(list), BINARY_SUBSCR);
             STAT_INC(BINARY_SUBSCR, hit);
-            PyObject *res = PyList_GET_ITEM(list, index);
+            res = PyList_GET_ITEM(list, index);
             assert(res != NULL);
             Py_INCREF(res);
-            STACK_SHRINK(1);
             _Py_DECREF_SPECIALIZED(sub, (destructor)PyObject_Free);
-            SET_TOP(res);
             Py_DECREF(list);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_SUBSCR);
+            STACK_SHRINK(1);
+            POKE(1, res);
+            next_instr += 4;
             DISPATCH();
         }
 
         TARGET(BINARY_SUBSCR_TUPLE_INT) {
+            PyObject *sub = PEEK(1);
+            PyObject *tuple = PEEK(2);
+            PyObject *res;
             assert(cframe.use_tracing == 0);
-            PyObject *sub = TOP();
-            PyObject *tuple = SECOND();
             DEOPT_IF(!PyLong_CheckExact(sub), BINARY_SUBSCR);
             DEOPT_IF(!PyTuple_CheckExact(tuple), BINARY_SUBSCR);
 
@@ -404,51 +408,54 @@
             Py_ssize_t index = ((PyLongObject*)sub)->ob_digit[0];
             DEOPT_IF(index >= PyTuple_GET_SIZE(tuple), BINARY_SUBSCR);
             STAT_INC(BINARY_SUBSCR, hit);
-            PyObject *res = PyTuple_GET_ITEM(tuple, index);
+            res = PyTuple_GET_ITEM(tuple, index);
             assert(res != NULL);
             Py_INCREF(res);
-            STACK_SHRINK(1);
             _Py_DECREF_SPECIALIZED(sub, (destructor)PyObject_Free);
-            SET_TOP(res);
             Py_DECREF(tuple);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_SUBSCR);
+            STACK_SHRINK(1);
+            POKE(1, res);
+            next_instr += 4;
             DISPATCH();
         }
 
         TARGET(BINARY_SUBSCR_DICT) {
+            PyObject *sub = PEEK(1);
+            PyObject *dict = PEEK(2);
+            PyObject *res;
             assert(cframe.use_tracing == 0);
-            PyObject *dict = SECOND();
-            DEOPT_IF(!PyDict_CheckExact(SECOND()), BINARY_SUBSCR);
+            DEOPT_IF(!PyDict_CheckExact(dict), BINARY_SUBSCR);
             STAT_INC(BINARY_SUBSCR, hit);
-            PyObject *sub = TOP();
-            PyObject *res = PyDict_GetItemWithError(dict, sub);
+            res = PyDict_GetItemWithError(dict, sub);
             if (res == NULL) {
                 if (!_PyErr_Occurred(tstate)) {
                     _PyErr_SetKeyError(sub);
                 }
-                goto error;
+                Py_DECREF(dict);
+                Py_DECREF(sub);
+                if (1) goto pop_2_error;
             }
-            Py_INCREF(res);
-            STACK_SHRINK(1);
-            Py_DECREF(sub);
-            SET_TOP(res);
+            Py_INCREF(res);  // Do this before DECREF'ing dict, sub
             Py_DECREF(dict);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_SUBSCR);
+            Py_DECREF(sub);
+            STACK_SHRINK(1);
+            POKE(1, res);
+            next_instr += 4;
             DISPATCH();
         }
 
         TARGET(BINARY_SUBSCR_GETITEM) {
-            PyObject *sub = TOP();
-            PyObject *container = SECOND();
-            _PyBinarySubscrCache *cache = (_PyBinarySubscrCache *)next_instr;
-            uint32_t type_version = read_u32(cache->type_version);
+            uint32_t type_version = read_u32(next_instr + 1);
+            uint16_t func_version = *(next_instr + 3);
+            PyObject *sub = PEEK(1);
+            PyObject *container = PEEK(2);
             PyTypeObject *tp = Py_TYPE(container);
             DEOPT_IF(tp->tp_version_tag != type_version, BINARY_SUBSCR);
             assert(tp->tp_flags & Py_TPFLAGS_HEAPTYPE);
             PyObject *cached = ((PyHeapTypeObject *)tp)->_spec_cache.getitem;
             assert(PyFunction_Check(cached));
             PyFunctionObject *getitem = (PyFunctionObject *)cached;
-            DEOPT_IF(getitem->func_version != cache->func_version, BINARY_SUBSCR);
+            DEOPT_IF(getitem->func_version != func_version, BINARY_SUBSCR);
             PyCodeObject *code = (PyCodeObject *)getitem->func_code;
             assert(code->co_argcount == 2);
             DEOPT_IF(!_PyThreadState_HasStackSpace(tstate, code->co_framesize), BINARY_SUBSCR);
@@ -3656,7 +3663,7 @@
 
         TARGET(BINARY_OP) {
             PREDICTED(BINARY_OP);
-            assert(INLINE_CACHE_ENTRIES_BINARY_OP == 1);
+            static_assert(INLINE_CACHE_ENTRIES_BINARY_OP == 1, "incorrect cache size");
             PyObject *rhs = PEEK(1);
             PyObject *lhs = PEEK(2);
             PyObject *res;
diff --git a/Tools/cases_generator/generate_cases.py b/Tools/cases_generator/generate_cases.py
index d01653175091..e11d0c77e99d 100644
--- a/Tools/cases_generator/generate_cases.py
+++ b/Tools/cases_generator/generate_cases.py
@@ -1,55 +1,326 @@
-"""Generate the main interpreter switch."""
+"""Generate the main interpreter switch.
 
-# Write the cases to generated_cases.c.h, which is #included in ceval.c.
-
-# TODO: Reuse C generation framework from deepfreeze.py?
+Reads the instruction definitions from bytecodes.c.
+Writes the cases to generated_cases.c.h, which is #included in ceval.c.
+"""
 
 import argparse
 import os
 import re
 import sys
-from typing import TextIO, cast
+import typing
 
 import parser
-from parser import InstDef  # TODO: Use parser.InstDef
 
+DEFAULT_INPUT = "Python/bytecodes.c"
+DEFAULT_OUTPUT = "Python/generated_cases.c.h"
+BEGIN_MARKER = "// BEGIN BYTECODES //"
+END_MARKER = "// END BYTECODES //"
 RE_PREDICTED = r"(?s)(?:PREDICT\(|GO_TO_INSTRUCTION\(|DEOPT_IF\(.*?,\s*)(\w+)\);"
 
 arg_parser = argparse.ArgumentParser()
-arg_parser.add_argument("-i", "--input", type=str, default="Python/bytecodes.c")
-arg_parser.add_argument("-o", "--output", type=str, default="Python/generated_cases.c.h")
-arg_parser.add_argument("-q", "--quiet", action="store_true")
+arg_parser.add_argument("-i", "--input", type=str, default=DEFAULT_INPUT)
+arg_parser.add_argument("-o", "--output", type=str, default=DEFAULT_OUTPUT)
 
 
-def eopen(filename: str, mode: str = "r") -> TextIO:
-    if filename == "-":
-        if "r" in mode:
-            return sys.stdin
-        else:
-            return sys.stdout
-    return cast(TextIO, open(filename, mode))
-
-
-def parse_cases(
-    src: str, filename: str|None = None
-) -> tuple[list[InstDef], list[parser.Super], list[parser.Family]]:
-    psr = parser.Parser(src, filename=filename)
-    instrs: list[InstDef] = []
-    supers: list[parser.Super] = []
-    families: list[parser.Family] = []
-    while not psr.eof():
-        if inst := psr.inst_def():
-            instrs.append(inst)
-        elif sup := psr.super_def():
-            supers.append(sup)
-        elif fam := psr.family_def():
-            families.append(fam)
+# This is not a data class
+class Instruction(parser.InstDef):
+    """An instruction with additional data and code."""
+
+    # Computed by constructor
+    always_exits: bool
+    cache_offset: int
+    cache_effects: list[parser.CacheEffect]
+    input_effects: list[parser.StackEffect]
+    output_effects: list[parser.StackEffect]
+
+    # Set later
+    family: parser.Family | None = None
+    predicted: bool = False
+
+    def __init__(self, inst: parser.InstDef):
+        super().__init__(inst.header, inst.block)
+        self.context = inst.context
+        self.always_exits = always_exits(self.block)
+        self.cache_effects = [
+            effect for effect in self.inputs if isinstance(effect, parser.CacheEffect)
+        ]
+        self.cache_offset = sum(c.size for c in self.cache_effects)
+        self.input_effects = [
+            effect for effect in self.inputs if isinstance(effect, parser.StackEffect)
+        ]
+        self.output_effects = self.outputs  # For consistency/completeness
+
+    def write(
+        self, f: typing.TextIO, indent: str, dedent: int = 0
+    ) -> None:
+        """Write one instruction, sans prologue and epilogue."""
+        if dedent < 0:
+            indent += " " * -dedent  # DO WE NEED THIS?
+
+        # Get cache offset and maybe assert that it is correct
+        if family := self.family:
+            if self.name == family.members[0]:
+                if cache_size := family.size:
+                    f.write(
+                        f"{indent}    static_assert({cache_size} == "
+                        f'{self.cache_offset}, "incorrect cache size");\n'
+                    )
+
+        # Write cache effect variable declarations
+        cache_offset = 0
+        for ceffect in self.cache_effects:
+            if ceffect.name != "unused":
+                # TODO: if name is 'descr' use PyObject *descr = read_obj(...)
+                bits = ceffect.size * 16
+                f.write(f"{indent}    uint{bits}_t {ceffect.name} = ")
+                if ceffect.size == 1:
+                    f.write(f"*(next_instr + {cache_offset});\n")
+                else:
+                    f.write(f"read_u{bits}(next_instr + {cache_offset});\n")
+            cache_offset += ceffect.size
+        assert cache_offset == self.cache_offset
+
+        # Write input stack effect variable declarations and initializations
+        for i, seffect in enumerate(reversed(self.input_effects), 1):
+            if seffect.name != "unused":
+                f.write(f"{indent}    PyObject *{seffect.name} = PEEK({i});\n")
+
+        # Write output stack effect variable declarations
+        for seffect in self.output_effects:
+            if seffect.name != "unused":
+                f.write(f"{indent}    PyObject *{seffect.name};\n")
+
+        self.write_body(f, indent, dedent)
+
+        # Skip the rest if the block always exits
+        if always_exits(self.block):
+            return
+
+        # Write net stack growth/shrinkage
+        diff = len(self.output_effects) - len(self.input_effects)
+        if diff > 0:
+            f.write(f"{indent}    STACK_GROW({diff});\n")
+        elif diff < 0:
+            f.write(f"{indent}    STACK_SHRINK({-diff});\n")
+
+        # Write output stack effect assignments
+        input_names = [seffect.name for seffect in self.input_effects]
+        for i, output in enumerate(reversed(self.output_effects), 1):
+            if output.name not in input_names and output.name != "unused":
+                f.write(f"{indent}    POKE({i}, {output.name});\n")
+
+        # Write cache effect
+        if self.cache_offset:
+            f.write(f"{indent}    next_instr += {self.cache_offset};\n")
+
+    def write_body(
+        self, f: typing.TextIO, ndent: str, dedent: int
+    ) -> None:
+        """Write the instruction body."""
+
+        # Get lines of text with proper dedent
+        blocklines = self.block.to_text(dedent=dedent).splitlines(True)
+
+        # Remove blank lines from both ends
+        while blocklines and not blocklines[0].strip():
+            blocklines.pop(0)
+        while blocklines and not blocklines[-1].strip():
+            blocklines.pop()
+
+        # Remove leading and trailing braces
+        assert blocklines and blocklines[0].strip() == "{"
+        assert blocklines and blocklines[-1].strip() == "}"
+        blocklines.pop()
+        blocklines.pop(0)
+
+        # Remove trailing blank lines
+        while blocklines and not blocklines[-1].strip():
+            blocklines.pop()
+
+        # Write the body, substituting a goto for ERROR_IF()
+        for line in blocklines:
+            if m := re.match(r"(\s*)ERROR_IF\((.+), (\w+)\);\s*$", line):
+                space, cond, label = m.groups()
+                # ERROR_IF() must pop the inputs from the stack.
+                # The code block is responsible for DECREF()ing them.
+                # NOTE: If the label doesn't exist, just add it to ceval.c.
+                ninputs = len(self.input_effects)
+                if ninputs:
+                    f.write(f"{space}if ({cond}) goto pop_{ninputs}_{label};\n")
+                else:
+                    f.write(f"{space}if ({cond}) goto {label};\n")
+            else:
+                f.write(line)
+
+
+class Analyzer:
+    """Parse input, analyze it, and write to output."""
+
+    filename: str
+    src: str
+    errors: int = 0
+
+    def __init__(self, filename: str):
+        """Read the input file."""
+        self.filename = filename
+        with open(filename) as f:
+            self.src = f.read()
+
+    instrs: dict[str, Instruction]
+    supers: dict[str, parser.Super]
+    families: dict[str, parser.Family]
+
+    def parse(self) -> None:
+        """Parse the source text."""
+        psr = parser.Parser(self.src, filename=self.filename)
+
+        # Skip until begin marker
+        while tkn := psr.next(raw=True):
+            if tkn.text == BEGIN_MARKER:
+                break
         else:
-            raise psr.make_syntax_error(f"Unexpected token")
-    return instrs, supers, families
+            raise psr.make_syntax_error(f"Couldn't find {BEGIN_MARKER!r} in {psr.filename}")
+
+        # Parse until end marker
+        self.instrs = {}
+        self.supers = {}
+        self.families = {}
+        while (tkn := psr.peek(raw=True)) and tkn.text != END_MARKER:
+            if inst := psr.inst_def():
+                self.instrs[inst.name] = instr = Instruction(inst)
+            elif super := psr.super_def():
+                self.supers[super.name] = super
+            elif family := psr.family_def():
+                self.families[family.name] = family
+            else:
+                raise psr.make_syntax_error(f"Unexpected token")
+
+        print(
+            f"Read {len(self.instrs)} instructions, "
+            f"{len(self.supers)} supers, "
+            f"and {len(self.families)} families from {self.filename}",
+            file=sys.stderr,
+        )
+
+    def analyze(self) -> None:
+        """Analyze the inputs.
+
+        Raises SystemExit if there is an error.
+        """
+        self.find_predictions()
+        self.map_families()
+        self.check_families()
+
+    def find_predictions(self) -> None:
+        """Find the instructions that need PREDICTED() labels."""
+        for instr in self.instrs.values():
+            for target in re.findall(RE_PREDICTED, instr.block.text):
+                if target_instr := self.instrs.get(target):
+                    target_instr.predicted = True
+                else:
+                    print(
+                        f"Unknown instruction {target!r} predicted in {instr.name!r}",
+                        file=sys.stderr,
+                    )
+                    self.errors += 1
+
+    def map_families(self) -> None:
+        """Map instruction names back to their family, if they have one."""
+        for family in self.families.values():
+            for member in family.members:
+                if member_instr := self.instrs.get(member):
+                    member_instr.family = family
+                else:
+                    print(
+                        f"Unknown instruction {member!r} referenced in family {family.name!r}",
+                        file=sys.stderr,
+                    )
+                    self.errors += 1
+
+    def check_families(self) -> None:
+        """Check each family:
+
+        - Must have at least 2 members
+        - All members must be known instructions
+        - All members must have the same cache, input and output effects
+        """
+        for family in self.families.values():
+            if len(family.members) < 2:
+                print(f"Family {family.name!r} has insufficient members")
+                self.errors += 1
+            members = [member for member in family.members if member in self.instrs]
+            if members != family.members:
+                unknown = set(family.members) - set(members)
+                print(f"Family {family.name!r} has unknown members: {unknown}")
+                self.errors += 1
+            if len(members) < 2:
+                continue
+            head = self.instrs[members[0]]
+            cache = head.cache_offset
+            input = len(head.input_effects)
+            output = len(head.output_effects)
+            for member in members[1:]:
+                instr = self.instrs[member]
+                c = instr.cache_offset
+                i = len(instr.input_effects)
+                o = len(instr.output_effects)
+                if (c, i, o) != (cache, input, output):
+                    self.errors += 1
+                    print(
+                        f"Family {family.name!r} has inconsistent "
+                        f"(cache, inputs, outputs) effects:",
+                        file=sys.stderr,
+                    )
+                    print(
+                        f"  {family.members[0]} = {(cache, input, output)}; "
+                        f"{member} = {(c, i, o)}",
+                        file=sys.stderr,
+                    )
+                    self.errors += 1
+
+    def write_instructions(self, filename: str) -> None:
+        """Write instructions to output file."""
+        indent = " " * 8
+        with open(filename, "w") as f:
+            # Write provenance header
+            f.write(f"// This file is generated by {os.path.relpath(__file__)}\n")
+            f.write(f"// from {os.path.relpath(self.filename)}\n")
+            f.write(f"// Do not edit!\n")
+
+            # Write regular instructions
+            for name, instr in self.instrs.items():
+                f.write(f"\n{indent}TARGET({name}) {{\n")
+                if instr.predicted:
+                    f.write(f"{indent}    PREDICTED({name});\n")
+                instr.write(f, indent)
+                if not always_exits(instr.block):
+                    f.write(f"{indent}    DISPATCH();\n")
+                f.write(f"{indent}}}\n")
+
+            # Write super-instructions
+            for name, sup in self.supers.items():
+                components = [self.instrs[name] for name in sup.ops]
+                f.write(f"\n{indent}TARGET({sup.name}) {{\n")
+                for i, instr in enumerate(components):
+                    if i > 0:
+                        f.write(f"{indent}    NEXTOPARG();\n")
+                        f.write(f"{indent}    next_instr++;\n")
+                    f.write(f"{indent}    {{\n")
+                    instr.write(f, indent, dedent=-4)
+                    f.write(f"    {indent}}}\n")
+                f.write(f"{indent}    DISPATCH();\n")
+                f.write(f"{indent}}}\n")
+
+            print(
+                f"Wrote {len(self.instrs)} instructions and "
+                f"{len(self.supers)} super-instructions to {filename}",
+                file=sys.stderr,
+            )
 
 
 def always_exits(block: parser.Block) -> bool:
+    """Determine whether a block always ends in a return/goto/etc."""
     text = block.text
     lines = text.splitlines()
     while lines and not lines[-1].strip():
@@ -61,181 +332,24 @@ def always_exits(block: parser.Block) -> bool:
         return False
     line = lines.pop().rstrip()
     # Indent must match exactly (TODO: Do something better)
-    if line[:12] != " "*12:
+    if line[:12] != " " * 12:
         return False
     line = line[12:]
-    return line.startswith(("goto ", "return ", "DISPATCH", "GO_TO_", "Py_UNREACHABLE()"))
-
-
-def find_cache_size(instr: InstDef, families: list[parser.Family]) -> str | None:
-    for family in families:
-        if instr.name == family.members[0]:
-            return family.size
-
-
-def write_instr(
-    instr: InstDef, predictions: set[str], indent: str, f: TextIO, dedent: int = 0, cache_size: str | None = None
-) -> int:
-    # Returns cache offset
-    if dedent < 0:
-        indent += " " * -dedent
-    # Separate stack inputs from cache inputs
-    input_names: set[str] = set()
-    stack: list[parser.StackEffect] = []
-    cache: list[parser.CacheEffect] = []
-    for input in instr.inputs:
-        if isinstance(input, parser.StackEffect):
-            stack.append(input)
-            input_names.add(input.name)
-        else:
-            assert isinstance(input, parser.CacheEffect), input
-            cache.append(input)
-    outputs = instr.outputs
-    cache_offset = 0
-    for ceffect in cache:
-        if ceffect.name != "unused":
-            bits = ceffect.size * 16
-            f.write(f"{indent}    PyObject *{ceffect.name} = read{bits}(next_instr + {cache_offset});\n")
-        cache_offset += ceffect.size
-    if cache_size:
-        f.write(f"{indent}    assert({cache_size} == {cache_offset});\n")
-    # TODO: Is it better to count forward or backward?
-    for i, effect in enumerate(reversed(stack), 1):
-        if effect.name != "unused":
-            f.write(f"{indent}    PyObject *{effect.name} = PEEK({i});\n")
-    for output in instr.outputs:
-        if output.name not in input_names and output.name != "unused":
-            f.write(f"{indent}    PyObject *{output.name};\n")
-    blocklines = instr.block.to_text(dedent=dedent).splitlines(True)
-    # Remove blank lines from ends
-    while blocklines and not blocklines[0].strip():
-        blocklines.pop(0)
-    while blocklines and not blocklines[-1].strip():
-        blocklines.pop()
-    # Remove leading '{' and trailing '}'
-    assert blocklines and blocklines[0].strip() == "{"
-    assert blocklines and blocklines[-1].strip() == "}"
-    blocklines.pop()
-    blocklines.pop(0)
-    # Remove trailing blank lines
-    while blocklines and not blocklines[-1].strip():
-        blocklines.pop()
-    # Write the body
-    ninputs = len(stack)
-    for line in blocklines:
-        if m := re.match(r"(\s*)ERROR_IF\((.+), (\w+)\);\s*$", line):
-            space, cond, label = m.groups()
-            # ERROR_IF() must remove the inputs from the stack.
-            # The code block is responsible for DECREF()ing them.
-            if ninputs:
-                f.write(f"{space}if ({cond}) goto pop_{ninputs}_{label};\n")
-            else:
-                f.write(f"{space}if ({cond}) goto {label};\n")
-        else:
-            f.write(line)
-    if always_exits(instr.block):
-        # None of the rest matters
-        return cache_offset
-    # Stack effect
-    noutputs = len(outputs)
-    diff = noutputs - ninputs
-    if diff > 0:
-        f.write(f"{indent}    STACK_GROW({diff});\n")
-    elif diff < 0:
-        f.write(f"{indent}    STACK_SHRINK({-diff});\n")
-    for i, output in enumerate(reversed(outputs), 1):
-        if output.name not in input_names and output.name != "unused":
-            f.write(f"{indent}    POKE({i}, {output.name});\n")
-    # Cache effect
-    if cache_offset:
-        f.write(f"{indent}    next_instr += {cache_offset};\n")
-    return cache_offset
-
-
-def write_cases(
-    f: TextIO, instrs: list[InstDef], supers: list[parser.Super], families: list[parser.Family]
-) -> dict[str, tuple[int, int, int]]:
-    predictions: set[str] = set()
-    for instr in instrs:
-        for target in re.findall(RE_PREDICTED, instr.block.text):
-            predictions.add(target)
-    indent = "        "
-    f.write(f"// This file is generated by {os.path.relpath(__file__)}\n")
-    f.write(f"// Do not edit!\n")
-    instr_index: dict[str, InstDef] = {}
-    effects_table: dict[str, tuple[int, int, int]] = {}  # name -> (ninputs, noutputs, cache_offset)
-    for instr in instrs:
-        instr_index[instr.name] = instr
-        f.write(f"\n{indent}TARGET({instr.name}) {{\n")
-        if instr.name in predictions:
-            f.write(f"{indent}    PREDICTED({instr.name});\n")
-        cache_offset = write_instr(
-            instr, predictions, indent, f,
-            cache_size=find_cache_size(instr, families)
-        )
-        effects_table[instr.name] = len(instr.inputs), len(instr.outputs), cache_offset
-        if not always_exits(instr.block):
-            f.write(f"{indent}    DISPATCH();\n")
-        # Write trailing '}'
-        f.write(f"{indent}}}\n")
-
-    for sup in supers:
-        components = [instr_index[name] for name in sup.ops]
-        f.write(f"\n{indent}TARGET({sup.name}) {{\n")
-        for i, instr in enumerate(components):
-            if i > 0:
-                f.write(f"{indent}    NEXTOPARG();\n")
-                f.write(f"{indent}    next_instr++;\n")
-            f.write(f"{indent}    {{\n")
-            write_instr(instr, predictions, indent, f, dedent=-4)
-            f.write(f"    {indent}}}\n")
-        f.write(f"{indent}    DISPATCH();\n")
-        f.write(f"{indent}}}\n")
-
-    return effects_table
+    return line.startswith(
+        ("goto ", "return ", "DISPATCH", "GO_TO_", "Py_UNREACHABLE()")
+    )
 
 
 def main():
-    args = arg_parser.parse_args()
-    with eopen(args.input) as f:
-        srclines = f.read().splitlines()
-    begin = srclines.index("// BEGIN BYTECODES //")
-    end = srclines.index("// END BYTECODES //")
-    src = "\n".join(srclines[begin+1 : end])
-    instrs, supers, families = parse_cases(src, filename=args.input)
-    ninstrs = nsupers = nfamilies = 0
-    if not args.quiet:
-        ninstrs = len(instrs)
-        nsupers = len(supers)
-        nfamilies = len(families)
-        print(
-            f"Read {ninstrs} instructions, {nsupers} supers, "
-            f"and {nfamilies} families from {args.input}",
-            file=sys.stderr,
-        )
-    with eopen(args.output, "w") as f:
-        effects_table = write_cases(f, instrs, supers, families)
-    if not args.quiet:
-        print(
-            f"Wrote {ninstrs + nsupers} instructions to {args.output}",
-            file=sys.stderr,
-        )
-    # Check that families have consistent effects
-    errors = 0
-    for family in families:
-        head = effects_table[family.members[0]]
-        for member in family.members:
-            if effects_table[member] != head:
-                errors += 1
-                print(
-                    f"Family {family.name!r} has inconsistent effects (inputs, outputs, cache units):",
-                    file=sys.stderr,
-                )
-                print(
-                    f"  {family.members[0]} = {head}; {member} = {effects_table[member]}",
-                )
-    if errors:
-        sys.exit(1)
+    """Parse command line, parse input, analyze, write output."""
+    args = arg_parser.parse_args()  # Prints message and sys.exit(2) on error
+    a = Analyzer(args.input)  # Raises OSError if file not found
+    a.parse()  # Raises SyntaxError on failure
+    a.analyze()  # Prints messages and raises SystemExit on failure
+    if a.errors:
+        sys.exit(f"Found {a.errors} errors")
+
+    a.write_instructions(args.output)  # Raises OSError if file can't be written
 
 
 if __name__ == "__main__":
diff --git a/Tools/cases_generator/lexer.py b/Tools/cases_generator/lexer.py
index c5320c03d546..493a32e38166 100644
--- a/Tools/cases_generator/lexer.py
+++ b/Tools/cases_generator/lexer.py
@@ -115,7 +115,7 @@ def choice(*opts):
 matcher = re.compile(choice(id_re, number_re, str_re, char, newline, macro, comment_re, *operators.values()))
 letter = re.compile(r'[a-zA-Z_]')
 
-keywords = (
+kwds = (
     'AUTO', 'BREAK', 'CASE', 'CHAR', 'CONST',
     'CONTINUE', 'DEFAULT', 'DO', 'DOUBLE', 'ELSE', 'ENUM', 'EXTERN',
     'FLOAT', 'FOR', 'GOTO', 'IF', 'INLINE', 'INT', 'LONG',
@@ -124,9 +124,9 @@ def choice(*opts):
     'SWITCH', 'TYPEDEF', 'UNION', 'UNSIGNED', 'VOID',
     'VOLATILE', 'WHILE'
 )
-for name in keywords:
+for name in kwds:
     globals()[name] = name
-keywords = { name.lower() : name for name in keywords }
+keywords = { name.lower() : name for name in kwds }
 
 
 def make_syntax_error(
diff --git a/Tools/cases_generator/parser.py b/Tools/cases_generator/parser.py
index 1f855312aeba..c511607fdf70 100644
--- a/Tools/cases_generator/parser.py
+++ b/Tools/cases_generator/parser.py
@@ -57,27 +57,26 @@ class Block(Node):
 
 
 @dataclass
-class Effect(Node):
-    pass
-
-
-@dataclass
-class StackEffect(Effect):
+class StackEffect(Node):
     name: str
     # TODO: type, condition
 
 
 @dataclass
-class CacheEffect(Effect):
+class CacheEffect(Node):
     name: str
     size: int
 
 
+InputEffect = StackEffect | CacheEffect
+OutputEffect = StackEffect
+
+
 @dataclass
 class InstHeader(Node):
     name: str
-    inputs: list[Effect]
-    outputs: list[Effect]
+    inputs: list[InputEffect]
+    outputs: list[OutputEffect]
 
 
 @dataclass
@@ -90,13 +89,12 @@ def name(self) -> str:
         return self.header.name
 
     @property
-    def inputs(self) -> list[Effect]:
+    def inputs(self) -> list[InputEffect]:
         return self.header.inputs
 
     @property
     def outputs(self) -> list[StackEffect]:
-        # This is always true
-        return [x for x in self.header.outputs if isinstance(x, StackEffect)]
+        return self.header.outputs
 
 
 @dataclass
@@ -126,7 +124,7 @@ def inst_def(self) -> InstDef | None:
     def inst_header(self) -> InstHeader | None:
         # inst(NAME) | inst(NAME, (inputs -- outputs))
         # TODO: Error out when there is something unexpected.
-        # TODO: Make INST a keyword in the lexer.``
+        # TODO: Make INST a keyword in the lexer.
         if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "inst":
             if (self.expect(lx.LPAREN)
                     and (tkn := self.expect(lx.IDENTIFIER))):
@@ -136,32 +134,22 @@ def inst_header(self) -> InstHeader | None:
                     if self.expect(lx.RPAREN):
                         if ((tkn := self.peek())
                                 and tkn.kind == lx.LBRACE):
-                            self.check_overlaps(inp, outp)
                             return InstHeader(name, inp, outp)
                 elif self.expect(lx.RPAREN):
                     return InstHeader(name, [], [])
         return None
 
-    def check_overlaps(self, inp: list[Effect], outp: list[Effect]):
-        for i, name in enumerate(inp):
-            for j, name2 in enumerate(outp):
-                if name == name2:
-                    if i != j:
-                        raise self.make_syntax_error(
-                            f"Input {name!r} at pos {i} repeated in output at different pos {j}")
-                    break
-
-    def stack_effect(self) -> tuple[list[Effect], list[Effect]]:
+    def stack_effect(self) -> tuple[list[InputEffect], list[OutputEffect]]:
         # '(' [inputs] '--' [outputs] ')'
         if self.expect(lx.LPAREN):
-            inp = self.inputs() or []
+            inputs = self.inputs() or []
             if self.expect(lx.MINUSMINUS):
-                outp = self.outputs() or []
+                outputs = self.outputs() or []
                 if self.expect(lx.RPAREN):
-                    return inp, outp
+                    return inputs, outputs
         raise self.make_syntax_error("Expected stack effect")
 
-    def inputs(self) -> list[Effect] | None:
+    def inputs(self) -> list[InputEffect] | None:
         # input (',' input)*
         here = self.getpos()
         if inp := self.input():
@@ -175,7 +163,7 @@ def inputs(self) -> list[Effect] | None:
         return None
 
     @contextual
-    def input(self) -> Effect | None:
+    def input(self) -> InputEffect | None:
         # IDENTIFIER '/' INTEGER (CacheEffect)
         # IDENTIFIER (StackEffect)
         if (tkn := self.expect(lx.IDENTIFIER)):
@@ -192,7 +180,7 @@ def input(self) -> Effect | None:
             else:
                 return StackEffect(tkn.text)
 
-    def outputs(self) -> list[Effect] | None:
+    def outputs(self) -> list[OutputEffect] | None:
         # output (, output)*
         here = self.getpos()
         if outp := self.output():
@@ -206,7 +194,7 @@ def outputs(self) -> list[Effect] | None:
         return None
 
     @contextual
-    def output(self) -> Effect | None:
+    def output(self) -> OutputEffect | None:
         if (tkn := self.expect(lx.IDENTIFIER)):
             return StackEffect(tkn.text)
 
diff --git a/Tools/cases_generator/plexer.py b/Tools/cases_generator/plexer.py
index 107d608152ce..a73254ed5b1d 100644
--- a/Tools/cases_generator/plexer.py
+++ b/Tools/cases_generator/plexer.py
@@ -3,7 +3,7 @@
 
 
 class PLexer:
-    def __init__(self, src: str, filename: str|None = None):
+    def __init__(self, src: str, filename: str):
         self.src = src
         self.filename = filename
         self.tokens = list(lx.tokenize(self.src, filename=filename))
@@ -89,16 +89,17 @@ def make_syntax_error(self, message: str, tkn: Token|None = None) -> SyntaxError
         filename = sys.argv[1]
         if filename == "-c" and sys.argv[2:]:
             src = sys.argv[2]
-            filename = None
+            filename = "<string>"
         else:
             with open(filename) as f:
                 src = f.read()
     else:
-        filename = None
+        filename = "<default>"
         src = "if (x) { x.foo; // comment\n}"
     p = PLexer(src, filename)
     while not p.eof():
         tok = p.next(raw=True)
+        assert tok
         left = repr(tok)
         right = lx.to_text([tok]).rstrip()
         print(f"{left:40.40} {right}")



More information about the Python-checkins mailing list