[Python-checkins] (no subject)

Batuhan Taşkaya webhook-mailer at python.org
Sun Mar 1 15:12:21 EST 2020




To: python-checkins at python.org
Subject: bpo-38870: Implement a precedence algorithm in ast.unparse (GH-17377)
Content-Type: text/plain; charset="utf-8"
Content-Transfer-Encoding: quoted-printable
MIME-Version: 1.0

https://github.com/python/cpython/commit/397b96f6d7a89f778ebc0591e32216a8183fe667
commit: 397b96f6d7a89f778ebc0591e32216a8183fe667
branch: master
author: Batuhan Taşkaya <47358913+isidentical at users.noreply.github.com>
committer: GitHub <noreply at github.com>
date: 2020-03-01T20:12:17Z
summary:

bpo-38870: Implement a precedence algorithm in ast.unparse (GH-17377)

Implement a simple precedence algorithm for ast.unparse in order to avoid redundant
parentheses for nested structures in the final output.

files:
M Lib/ast.py
M Lib/test/test_ast.py
M Lib/test/test_unparse.py

diff --git a/Lib/ast.py b/Lib/ast.py
index 511f0956a00b0..4839201e2e234 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -27,6 +27,7 @@
 import sys
 from _ast import *
 from contextlib import contextmanager, nullcontext
+from enum import IntEnum, auto
=20
=20
 def parse(source, filename=3D'<unknown>', mode=3D'exec', *,
@@ -560,6 +561,35 @@ def __new__(cls, *args, **kwargs):
 # We unparse those infinities to INFSTR.
 _INFSTR =3D "1e" + repr(sys.float_info.max_10_exp + 1)
=20
+class _Precedence(IntEnum):
+    """Precedence table that originated from python grammar."""
+
+    TUPLE =3D auto()
+    YIELD =3D auto()           # 'yield', 'yield from'
+    TEST =3D auto()            # 'if'-'else', 'lambda'
+    OR =3D auto()              # 'or'
+    AND =3D auto()             # 'and'
+    NOT =3D auto()             # 'not'
+    CMP =3D auto()             # '<', '>', '=3D=3D', '>=3D', '<=3D', '!=3D',
+                             # 'in', 'not in', 'is', 'is not'
+    EXPR =3D auto()
+    BOR =3D EXPR               # '|'
+    BXOR =3D auto()            # '^'
+    BAND =3D auto()            # '&'
+    SHIFT =3D auto()           # '<<', '>>'
+    ARITH =3D auto()           # '+', '-'
+    TERM =3D auto()            # '*', '@', '/', '%', '//'
+    FACTOR =3D auto()          # unary '+', '-', '~'
+    POWER =3D auto()           # '**'
+    AWAIT =3D auto()           # 'await'
+    ATOM =3D auto()
+
+    def next(self):
+        try:
+            return self.__class__(self + 1)
+        except ValueError:
+            return self
+
 class _Unparser(NodeVisitor):
     """Methods in this class recursively traverse an AST and
     output source code for the abstract syntax; original formatting
@@ -568,6 +598,7 @@ class _Unparser(NodeVisitor):
     def __init__(self):
         self._source =3D []
         self._buffer =3D []
+        self._precedences =3D {}
         self._indent =3D 0
=20
     def interleave(self, inter, f, seq):
@@ -625,6 +656,17 @@ def delimit_if(self, start, end, condition):
         else:
             return nullcontext()
=20
+    def require_parens(self, precedence, node):
+        """Shortcut to adding precedence related parens"""
+        return self.delimit_if("(", ")", self.get_precedence(node) > precede=
nce)
+
+    def get_precedence(self, node):
+        return self._precedences.get(node, _Precedence.TEST)
+
+    def set_precedence(self, precedence, *nodes):
+        for node in nodes:
+            self._precedences[node] =3D precedence
+
     def traverse(self, node):
         if isinstance(node, list):
             for item in node:
@@ -645,10 +687,12 @@ def visit_Module(self, node):
=20
     def visit_Expr(self, node):
         self.fill()
+        self.set_precedence(_Precedence.YIELD, node.value)
         self.traverse(node.value)
=20
     def visit_NamedExpr(self, node):
-        with self.delimit("(", ")"):
+        with self.require_parens(_Precedence.TUPLE, node):
+            self.set_precedence(_Precedence.ATOM, node.target, node.value)
             self.traverse(node.target)
             self.write(" :=3D ")
             self.traverse(node.value)
@@ -723,24 +767,27 @@ def visit_Nonlocal(self, node):
         self.interleave(lambda: self.write(", "), self.write, node.names)
=20
     def visit_Await(self, node):
-        with self.delimit("(", ")"):
+        with self.require_parens(_Precedence.AWAIT, node):
             self.write("await")
             if node.value:
                 self.write(" ")
+                self.set_precedence(_Precedence.ATOM, node.value)
                 self.traverse(node.value)
=20
     def visit_Yield(self, node):
-        with self.delimit("(", ")"):
+        with self.require_parens(_Precedence.YIELD, node):
             self.write("yield")
             if node.value:
                 self.write(" ")
+                self.set_precedence(_Precedence.ATOM, node.value)
                 self.traverse(node.value)
=20
     def visit_YieldFrom(self, node):
-        with self.delimit("(", ")"):
+        with self.require_parens(_Precedence.YIELD, node):
             self.write("yield from ")
             if not node.value:
                 raise ValueError("Node can't be used without a value attribu=
te.")
+            self.set_precedence(_Precedence.ATOM, node.value)
             self.traverse(node.value)
=20
     def visit_Raise(self, node):
@@ -907,7 +954,9 @@ def _fstring_Constant(self, node, write):
=20
     def _fstring_FormattedValue(self, node, write):
         write("{")
-        expr =3D type(self)().visit(node.value).rstrip("\n")
+        unparser =3D type(self)()
+        unparser.set_precedence(_Precedence.TEST.next(), node.value)
+        expr =3D unparser.visit(node.value).rstrip("\n")
         if expr.startswith("{"):
             write(" ")  # Separate pair of opening brackets as "{ {"
         write(expr)
@@ -983,19 +1032,23 @@ def visit_comprehension(self, node):
             self.write(" async for ")
         else:
             self.write(" for ")
+        self.set_precedence(_Precedence.TUPLE, node.target)
         self.traverse(node.target)
         self.write(" in ")
+        self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs)
         self.traverse(node.iter)
         for if_clause in node.ifs:
             self.write(" if ")
             self.traverse(if_clause)
=20
     def visit_IfExp(self, node):
-        with self.delimit("(", ")"):
+        with self.require_parens(_Precedence.TEST, node):
+            self.set_precedence(_Precedence.TEST.next(), node.body, node.tes=
t)
             self.traverse(node.body)
             self.write(" if ")
             self.traverse(node.test)
             self.write(" else ")
+            self.set_precedence(_Precedence.TEST, node.orelse)
             self.traverse(node.orelse)
=20
     def visit_Set(self, node):
@@ -1016,6 +1069,7 @@ def write_item(item):
                 # for dictionary unpacking operator in dicts {**{'y': 2}}
                 # see PEP 448 for details
                 self.write("**")
+                self.set_precedence(_Precedence.EXPR, v)
                 self.traverse(v)
             else:
                 write_key_value_pair(k, v)
@@ -1035,11 +1089,20 @@ def visit_Tuple(self, node):
                 self.interleave(lambda: self.write(", "), self.traverse, nod=
e.elts)
=20
     unop =3D {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
+    unop_precedence =3D {
+        "~": _Precedence.FACTOR,
+        "not": _Precedence.NOT,
+        "+": _Precedence.FACTOR,
+        "-": _Precedence.FACTOR
+    }
=20
     def visit_UnaryOp(self, node):
-        with self.delimit("(", ")"):
-            self.write(self.unop[node.op.__class__.__name__])
+        operator =3D self.unop[node.op.__class__.__name__]
+        operator_precedence =3D self.unop_precedence[operator]
+        with self.require_parens(operator_precedence, node):
+            self.write(operator)
             self.write(" ")
+            self.set_precedence(operator_precedence, node.operand)
             self.traverse(node.operand)
=20
     binop =3D {
@@ -1058,10 +1121,38 @@ def visit_UnaryOp(self, node):
         "Pow": "**",
     }
=20
+    binop_precedence =3D {
+        "+": _Precedence.ARITH,
+        "-": _Precedence.ARITH,
+        "*": _Precedence.TERM,
+        "@": _Precedence.TERM,
+        "/": _Precedence.TERM,
+        "%": _Precedence.TERM,
+        "<<": _Precedence.SHIFT,
+        ">>": _Precedence.SHIFT,
+        "|": _Precedence.BOR,
+        "^": _Precedence.BXOR,
+        "&": _Precedence.BAND,
+        "//": _Precedence.TERM,
+        "**": _Precedence.POWER,
+    }
+
+    binop_rassoc =3D frozenset(("**",))
     def visit_BinOp(self, node):
-        with self.delimit("(", ")"):
+        operator =3D self.binop[node.op.__class__.__name__]
+        operator_precedence =3D self.binop_precedence[operator]
+        with self.require_parens(operator_precedence, node):
+            if operator in self.binop_rassoc:
+                left_precedence =3D operator_precedence.next()
+                right_precedence =3D operator_precedence
+            else:
+                left_precedence =3D operator_precedence
+                right_precedence =3D operator_precedence.next()
+
+            self.set_precedence(left_precedence, node.left)
             self.traverse(node.left)
-            self.write(" " + self.binop[node.op.__class__.__name__] + " ")
+            self.write(f" {operator} ")
+            self.set_precedence(right_precedence, node.right)
             self.traverse(node.right)
=20
     cmpops =3D {
@@ -1078,20 +1169,32 @@ def visit_BinOp(self, node):
     }
=20
     def visit_Compare(self, node):
-        with self.delimit("(", ")"):
+        with self.require_parens(_Precedence.CMP, node):
+            self.set_precedence(_Precedence.CMP.next(), node.left, *node.com=
parators)
             self.traverse(node.left)
             for o, e in zip(node.ops, node.comparators):
                 self.write(" " + self.cmpops[o.__class__.__name__] + " ")
                 self.traverse(e)
=20
     boolops =3D {"And": "and", "Or": "or"}
+    boolop_precedence =3D {"and": _Precedence.AND, "or": _Precedence.OR}
=20
     def visit_BoolOp(self, node):
-        with self.delimit("(", ")"):
-            s =3D " %s " % self.boolops[node.op.__class__.__name__]
-            self.interleave(lambda: self.write(s), self.traverse, node.value=
s)
+        operator =3D self.boolops[node.op.__class__.__name__]
+        operator_precedence =3D self.boolop_precedence[operator]
+
+        def increasing_level_traverse(node):
+            nonlocal operator_precedence
+            operator_precedence =3D operator_precedence.next()
+            self.set_precedence(operator_precedence, node)
+            self.traverse(node)
+
+        with self.require_parens(operator_precedence, node):
+            s =3D f" {operator} "
+            self.interleave(lambda: self.write(s), increasing_level_traverse=
, node.values)
=20
     def visit_Attribute(self, node):
+        self.set_precedence(_Precedence.ATOM, node.value)
         self.traverse(node.value)
         # Special case: 3.__abs__() is a syntax error, so if node.value
         # is an integer literal then we need to either parenthesize
@@ -1102,6 +1205,7 @@ def visit_Attribute(self, node):
         self.write(node.attr)
=20
     def visit_Call(self, node):
+        self.set_precedence(_Precedence.ATOM, node.func)
         self.traverse(node.func)
         with self.delimit("(", ")"):
             comma =3D False
@@ -1119,18 +1223,21 @@ def visit_Call(self, node):
                 self.traverse(e)
=20
     def visit_Subscript(self, node):
+        self.set_precedence(_Precedence.ATOM, node.value)
         self.traverse(node.value)
         with self.delimit("[", "]"):
             self.traverse(node.slice)
=20
     def visit_Starred(self, node):
         self.write("*")
+        self.set_precedence(_Precedence.EXPR, node.value)
         self.traverse(node.value)
=20
     def visit_Ellipsis(self, node):
         self.write("...")
=20
     def visit_Index(self, node):
+        self.set_precedence(_Precedence.TUPLE, node.value)
         self.traverse(node.value)
=20
     def visit_Slice(self, node):
@@ -1212,10 +1319,11 @@ def visit_keyword(self, node):
         self.traverse(node.value)
=20
     def visit_Lambda(self, node):
-        with self.delimit("(", ")"):
+        with self.require_parens(_Precedence.TEST, node):
             self.write("lambda ")
             self.traverse(node.args)
             self.write(": ")
+            self.set_precedence(_Precedence.TEST, node.body)
             self.traverse(node.body)
=20
     def visit_alias(self, node):
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index 2ed4657822e54..e78848537d47a 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -247,6 +247,13 @@ def to_tuple(t):
=20
 class AST_Tests(unittest.TestCase):
=20
+    def _is_ast_node(self, name, node):
+        if not isinstance(node, type):
+            return False
+        if "ast" not in node.__module__:
+            return False
+        return name !=3D 'AST' and name[0].isupper()
+
     def _assertTrueorder(self, ast_node, parent_pos):
         if not isinstance(ast_node, ast.AST) or ast_node._fields is None:
             return
@@ -335,7 +342,7 @@ def test_base_classes(self):
=20
     def test_field_attr_existence(self):
         for name, item in ast.__dict__.items():
-            if isinstance(item, type) and name !=3D 'AST' and name[0].isuppe=
r():
+            if self._is_ast_node(name, item):
                 x =3D item()
                 if isinstance(x, ast.AST):
                     self.assertEqual(type(x._fields), tuple)
diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py
index e8b0d4b06f9e9..f7fcb2bffe891 100644
--- a/Lib/test/test_unparse.py
+++ b/Lib/test/test_unparse.py
@@ -125,6 +125,13 @@ def check_roundtrip(self, code1):
     def check_invalid(self, node, raises=3DValueError):
         self.assertRaises(raises, ast.unparse, node)
=20
+    def check_src_roundtrip(self, code1, code2=3DNone, strip=3DTrue):
+        code2 =3D code2 or code1
+        code1 =3D ast.unparse(ast.parse(code1))
+        if strip:
+            code1 =3D code1.strip()
+        self.assertEqual(code2, code1)
+
=20
 class UnparseTestCase(ASTTestCase):
     # Tests for specific bugs found in earlier versions of unparse
@@ -281,6 +288,40 @@ def test_invalid_set(self):
     def test_invalid_yield_from(self):
         self.check_invalid(ast.YieldFrom(value=3DNone))
=20
+class CosmeticTestCase(ASTTestCase):
+    """Test if there are cosmetic issues caused by unnecesary additions"""
+
+    def test_simple_expressions_parens(self):
+        self.check_src_roundtrip("(a :=3D b)")
+        self.check_src_roundtrip("await x")
+        self.check_src_roundtrip("x if x else y")
+        self.check_src_roundtrip("lambda x: x")
+        self.check_src_roundtrip("1 + 1")
+        self.check_src_roundtrip("1 + 2 / 3")
+        self.check_src_roundtrip("(1 + 2) / 3")
+        self.check_src_roundtrip("(1 + 2) * 3 + 4 * (5 + 2)")
+        self.check_src_roundtrip("(1 + 2) * 3 + 4 * (5 + 2) ** 2")
+        self.check_src_roundtrip("~ x")
+        self.check_src_roundtrip("x and y")
+        self.check_src_roundtrip("x and y and z")
+        self.check_src_roundtrip("x and (y and x)")
+        self.check_src_roundtrip("(x and y) and z")
+        self.check_src_roundtrip("(x ** y) ** z ** q")
+        self.check_src_roundtrip("x >> y")
+        self.check_src_roundtrip("x << y")
+        self.check_src_roundtrip("x >> y and x >> z")
+        self.check_src_roundtrip("x + y - z * q ^ t ** k")
+        self.check_src_roundtrip("P * V if P and V else n * R * T")
+        self.check_src_roundtrip("lambda P, V, n: P * V =3D=3D n * R * T")
+        self.check_src_roundtrip("flag & (other | foo)")
+        self.check_src_roundtrip("not x =3D=3D y")
+        self.check_src_roundtrip("x =3D=3D (not y)")
+        self.check_src_roundtrip("yield x")
+        self.check_src_roundtrip("yield from x")
+        self.check_src_roundtrip("call((yield x))")
+        self.check_src_roundtrip("return x + (yield x)")
+
+
 class DirectoryTestCase(ASTTestCase):
     """Test roundtrip behaviour on all files in Lib and Lib/test."""
=20



More information about the Python-checkins mailing list