[Python-checkins] gh-100518: Add tests for `ast.NodeTransformer` (#100521)

pablogsal webhook-mailer at python.org
Sat Jan 21 16:44:49 EST 2023


https://github.com/python/cpython/commit/c1c5882359a2899b74c1685a0d4e61d6e232161f
commit: c1c5882359a2899b74c1685a0d4e61d6e232161f
branch: main
author: Nikita Sobolev <mail at sobolevn.me>
committer: pablogsal <Pablogsal at gmail.com>
date: 2023-01-21T21:44:41Z
summary:

gh-100518: Add tests for `ast.NodeTransformer` (#100521)

files:
A Lib/test/support/ast_helper.py
M Lib/test/test_ast.py
M Lib/test/test_unparse.py

diff --git a/Lib/test/support/ast_helper.py b/Lib/test/support/ast_helper.py
new file mode 100644
index 000000000000..8a0415b6aae3
--- /dev/null
+++ b/Lib/test/support/ast_helper.py
@@ -0,0 +1,43 @@
+import ast
+
+class ASTTestMixin:
+    """Test mixin providing basic assertions for AST nodes."""
+
+    def assertASTEqual(self, ast1, ast2):
+        # Ensure the comparisons start at an AST node
+        self.assertIsInstance(ast1, ast.AST)
+        self.assertIsInstance(ast2, ast.AST)
+
+        # An AST comparison routine modeled after ast.dump(), but
+        # instead of string building, it traverses the two trees in
+        # lock-step. Only `_fields` are compared; `_attributes` are not.
+        def traverse_compare(a, b, missing=object()):
+            if type(a) is not type(b):
+                self.fail(f"{type(a)!r} is not {type(b)!r}")
+            if isinstance(a, ast.AST):
+                for field in a._fields:
+                    value1 = getattr(a, field, missing)
+                    value2 = getattr(b, field, missing)
+                    # Identical objects (e.g. None, or `missing` on both
+                    # sides) are trivially equal, so skip the recursion.
+                    if value1 is not value2:
+                        traverse_compare(value1, value2)
+            elif isinstance(a, list):
+                try:
+                    for node1, node2 in zip(a, b, strict=True):
+                        traverse_compare(node1, node2)
+                except ValueError:
+                    # Lengths differ: report like assertSequenceEqual() does
+                    len1 = len(a)
+                    len2 = len(b)
+                    if len1 > len2:
+                        what = "First"
+                        diff = len1 - len2
+                    else:
+                        what = "Second"
+                        diff = len2 - len1
+                    msg = f"{what} list contains {diff} additional elements."
+                    raise self.failureException(msg) from None
+            elif a != b:
+                self.fail(f"{a!r} != {b!r}")
+        traverse_compare(ast1, ast2)
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index 53a6418329e5..c728d2b55e42 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -11,6 +11,7 @@
 from textwrap import dedent
 
 from test import support
+from test.support.ast_helper import ASTTestMixin
 
 def to_tuple(t):
     if t is None or isinstance(t, (str, int, complex)) or t is Ellipsis:
@@ -2290,9 +2291,10 @@ def test_source_segment_missing_info(self):
         self.assertIsNone(ast.get_source_segment(s, x))
         self.assertIsNone(ast.get_source_segment(s, y))
 
-class NodeVisitorTests(unittest.TestCase):
+class BaseNodeVisitorCases:
+    # Both `NodeVisitor` and `NodeTranformer` must raise these warnings:
     def test_old_constant_nodes(self):
-        class Visitor(ast.NodeVisitor):
+        class Visitor(self.visitor_class):
             def visit_Num(self, node):
                 log.append((node.lineno, 'Num', node.n))
             def visit_Str(self, node):
@@ -2340,6 +2342,128 @@ def visit_Ellipsis(self, node):
         ])
 
 
+class NodeVisitorTests(BaseNodeVisitorCases, unittest.TestCase):
+    visitor_class = ast.NodeVisitor  # used by BaseNodeVisitorCases
+
+
+class NodeTransformerTests(ASTTestMixin, BaseNodeVisitorCases, unittest.TestCase):
+    visitor_class = ast.NodeTransformer  # used by BaseNodeVisitorCases
+
+    def assertASTTransformation(self, tranformer_class,
+                                initial_code, expected_code):
+        initial_ast = ast.parse(dedent(initial_code))
+        expected_ast = ast.parse(dedent(expected_code))
+
+        tranformer = tranformer_class()
+        result_ast = ast.fix_missing_locations(tranformer.visit(initial_ast))
+
+        self.assertASTEqual(result_ast, expected_ast)
+
+    def test_node_remove_single(self):
+        code = 'def func(arg) -> SomeType: ...'
+        expected = 'def func(arg): ...'
+
+        # Since `FunctionDef.returns` holds a single optional node, we test
+        # the `isinstance(old_value, AST)` branch of generic_visit() here.
+        class SomeTypeRemover(ast.NodeTransformer):
+            def visit_Name(self, node: ast.Name):
+                self.generic_visit(node)
+                if node.id == 'SomeType':
+                    return None  # drops the `-> SomeType` annotation
+                return node
+
+        self.assertASTTransformation(SomeTypeRemover, code, expected)
+
+    def test_node_remove_from_list(self):
+        code = """
+        def func(arg):
+            print(arg)
+            yield arg
+        """
+        expected = """
+        def func(arg):
+            print(arg)
+        """
+
+        # Since `FunctionDef.body` is defined as a list, we test
+        # the `isinstance(old_value, list)` branch of generic_visit() here.
+        class YieldRemover(ast.NodeTransformer):
+            def visit_Expr(self, node: ast.Expr):
+                self.generic_visit(node)
+                if isinstance(node.value, ast.Yield):
+                    return None  # Remove `yield` from a function
+                return node
+
+        self.assertASTTransformation(YieldRemover, code, expected)
+
+    def test_node_return_list(self):
+        code = """
+        class DSL(Base, kw1=True): ...
+        """
+        expected = """
+        class DSL(Base, kw1=True, kw2=True, kw3=False): ...
+        """
+
+        class ExtendKeywords(ast.NodeTransformer):
+            def visit_keyword(self, node: ast.keyword):
+                self.generic_visit(node)
+                if node.arg == 'kw1':
+                    return [  # spliced into ClassDef.keywords
+                        node,
+                        ast.keyword('kw2', ast.Constant(True)),
+                        ast.keyword('kw3', ast.Constant(False)),
+                    ]
+                return node
+
+        self.assertASTTransformation(ExtendKeywords, code, expected)
+
+    def test_node_mutate(self):
+        code = """
+        def func(arg):
+            print(arg)
+        """
+        expected = """
+        def func(arg):
+            log(arg)
+        """
+
+        class PrintToLog(ast.NodeTransformer):
+            def visit_Call(self, node: ast.Call):
+                self.generic_visit(node)
+                if isinstance(node.func, ast.Name) and node.func.id == 'print':
+                    node.func.id = 'log'  # mutate in place
+                return node
+
+        self.assertASTTransformation(PrintToLog, code, expected)
+
+    def test_node_replace(self):
+        code = """
+        def func(arg):
+            print(arg)
+        """
+        expected = """
+        def func(arg):
+            logger.log(arg, debug=True)
+        """
+
+        class PrintToLog(ast.NodeTransformer):
+            def visit_Call(self, node: ast.Call):
+                self.generic_visit(node)
+                if isinstance(node.func, ast.Name) and node.func.id == 'print':
+                    return ast.Call(  # helper fills in missing locations
+                        func=ast.Attribute(
+                            ast.Name('logger', ctx=ast.Load()),
+                            attr='log',
+                            ctx=ast.Load(),
+                        ),
+                        args=node.args,
+                        keywords=[ast.keyword('debug', ast.Constant(True))],
+                    )
+                return node
+
+        self.assertASTTransformation(PrintToLog, code, expected)
+
+
 @support.cpython_only
 class ModuleStateTests(unittest.TestCase):
     # bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state.
diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py
index f1f1dd5dc26b..88c7c3a0af87 100644
--- a/Lib/test/test_unparse.py
+++ b/Lib/test/test_unparse.py
@@ -6,6 +6,7 @@
 import random
 import tokenize
 import ast
+from test.support.ast_helper import ASTTestMixin
 
 
 def read_pyfile(filename):
@@ -128,46 +129,7 @@ class Foo: pass
     "async def foo():\n    ",
 )
 
-class ASTTestCase(unittest.TestCase):
-    def assertASTEqual(self, ast1, ast2):
-        # Ensure the comparisons start at an AST node
-        self.assertIsInstance(ast1, ast.AST)
-        self.assertIsInstance(ast2, ast.AST)
-
-        # An AST comparison routine modeled after ast.dump(), but
-        # instead of string building, it traverses the two trees
-        # in lock-step.
-        def traverse_compare(a, b, missing=object()):
-            if type(a) is not type(b):
-                self.fail(f"{type(a)!r} is not {type(b)!r}")
-            if isinstance(a, ast.AST):
-                for field in a._fields:
-                    value1 = getattr(a, field, missing)
-                    value2 = getattr(b, field, missing)
-                    # Singletons are equal by definition, so further
-                    # testing can be skipped.
-                    if value1 is not value2:
-                        traverse_compare(value1, value2)
-            elif isinstance(a, list):
-                try:
-                    for node1, node2 in zip(a, b, strict=True):
-                        traverse_compare(node1, node2)
-                except ValueError:
-                    # Attempt a "pretty" error ala assertSequenceEqual()
-                    len1 = len(a)
-                    len2 = len(b)
-                    if len1 > len2:
-                        what = "First"
-                        diff = len1 - len2
-                    else:
-                        what = "Second"
-                        diff = len2 - len1
-                    msg = f"{what} list contains {diff} additional elements."
-                    raise self.failureException(msg) from None
-            elif a != b:
-                self.fail(f"{a!r} != {b!r}")
-        traverse_compare(ast1, ast2)
-
+class ASTTestCase(ASTTestMixin, unittest.TestCase):
     def check_ast_roundtrip(self, code1, **kwargs):
         with self.subTest(code1=code1, ast_parse_kwargs=kwargs):
             ast1 = ast.parse(code1, **kwargs)



More information about the Python-checkins mailing list