[pypy-svn] r76285 - in pypy/trunk/pypy/interpreter: astcompiler astcompiler/tools test

agaynor at codespeak.net agaynor at codespeak.net
Mon Jul 19 21:33:50 CEST 2010


Author: agaynor
Date: Mon Jul 19 21:33:47 2010
New Revision: 76285

Modified:
   pypy/trunk/pypy/interpreter/astcompiler/ast.py
   pypy/trunk/pypy/interpreter/astcompiler/codegen.py
   pypy/trunk/pypy/interpreter/astcompiler/symtable.py
   pypy/trunk/pypy/interpreter/astcompiler/tools/asdl_py.py
   pypy/trunk/pypy/interpreter/test/test_compiler.py
Log:
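Ensure that expressions in the if clauses of list comprehensions are optimized by the AST optimizer. To do this, mutate_over now also visits previously skipped sub-nodes (try/except handlers, import names, comprehension generators, call keywords, function and lambda arguments) and is now generated for product types too (comprehension, excepthandler, arguments, keyword, alias).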


Modified: pypy/trunk/pypy/interpreter/astcompiler/ast.py
==============================================================================
--- pypy/trunk/pypy/interpreter/astcompiler/ast.py	(original)
+++ pypy/trunk/pypy/interpreter/astcompiler/ast.py	Mon Jul 19 21:33:47 2010
@@ -230,6 +230,7 @@
         visitor.visit_FunctionDef(self)
 
     def mutate_over(self, visitor):
+        self.args = self.args.mutate_over(visitor)
         if self.body:
             visitor._mutate_sequence(self.body)
         if self.decorators:
@@ -784,6 +785,8 @@
     def mutate_over(self, visitor):
         if self.body:
             visitor._mutate_sequence(self.body)
+        if self.handlers:
+            visitor._mutate_sequence(self.handlers)
         if self.orelse:
             visitor._mutate_sequence(self.orelse)
         return visitor.visit_TryExcept(self)
@@ -927,6 +930,8 @@
         visitor.visit_Import(self)
 
     def mutate_over(self, visitor):
+        if self.names:
+            visitor._mutate_sequence(self.names)
         return visitor.visit_Import(self)
 
     def sync_app_attrs(self, space):
@@ -965,6 +970,8 @@
         visitor.visit_ImportFrom(self)
 
     def mutate_over(self, visitor):
+        if self.names:
+            visitor._mutate_sequence(self.names)
         return visitor.visit_ImportFrom(self)
 
     def sync_app_attrs(self, space):
@@ -1282,6 +1289,7 @@
         visitor.visit_Lambda(self)
 
     def mutate_over(self, visitor):
+        self.args = self.args.mutate_over(visitor)
         self.body = self.body.mutate_over(visitor)
         return visitor.visit_Lambda(self)
 
@@ -1398,6 +1406,8 @@
 
     def mutate_over(self, visitor):
         self.elt = self.elt.mutate_over(visitor)
+        if self.generators:
+            visitor._mutate_sequence(self.generators)
         return visitor.visit_ListComp(self)
 
     def sync_app_attrs(self, space):
@@ -1437,6 +1447,8 @@
 
     def mutate_over(self, visitor):
         self.elt = self.elt.mutate_over(visitor)
+        if self.generators:
+            visitor._mutate_sequence(self.generators)
         return visitor.visit_GeneratorExp(self)
 
     def sync_app_attrs(self, space):
@@ -1562,6 +1574,8 @@
         self.func = self.func.mutate_over(visitor)
         if self.args:
             visitor._mutate_sequence(self.args)
+        if self.keywords:
+            visitor._mutate_sequence(self.keywords)
         if self.starargs:
             self.starargs = self.starargs.mutate_over(visitor)
         if self.kwargs:
@@ -2293,6 +2307,13 @@
         self.w_ifs = None
         self.initialization_state = 7
 
+    def mutate_over(self, visitor):
+        self.target = self.target.mutate_over(visitor)
+        self.iter = self.iter.mutate_over(visitor)
+        if self.ifs:
+            visitor._mutate_sequence(self.ifs)
+        return visitor.visit_comprehension(self)
+
     def walkabout(self, visitor):
         visitor.visit_comprehension(self)
 
@@ -2327,6 +2348,15 @@
         self.col_offset = col_offset
         self.initialization_state = 31
 
+    def mutate_over(self, visitor):
+        if self.type:
+            self.type = self.type.mutate_over(visitor)
+        if self.name:
+            self.name = self.name.mutate_over(visitor)
+        if self.body:
+            visitor._mutate_sequence(self.body)
+        return visitor.visit_excepthandler(self)
+
     def walkabout(self, visitor):
         visitor.visit_excepthandler(self)
 
@@ -2366,6 +2396,13 @@
         self.w_defaults = None
         self.initialization_state = 15
 
+    def mutate_over(self, visitor):
+        if self.args:
+            visitor._mutate_sequence(self.args)
+        if self.defaults:
+            visitor._mutate_sequence(self.defaults)
+        return visitor.visit_arguments(self)
+
     def walkabout(self, visitor):
         visitor.visit_arguments(self)
 
@@ -2407,6 +2444,10 @@
         self.value = value
         self.initialization_state = 3
 
+    def mutate_over(self, visitor):
+        self.value = self.value.mutate_over(visitor)
+        return visitor.visit_keyword(self)
+
     def walkabout(self, visitor):
         visitor.visit_keyword(self)
 
@@ -2426,6 +2467,9 @@
         self.asname = asname
         self.initialization_state = 3
 
+    def mutate_over(self, visitor):
+        return visitor.visit_alias(self)
+
     def walkabout(self, visitor):
         visitor.visit_alias(self)
 

Modified: pypy/trunk/pypy/interpreter/astcompiler/codegen.py
==============================================================================
--- pypy/trunk/pypy/interpreter/astcompiler/codegen.py	(original)
+++ pypy/trunk/pypy/interpreter/astcompiler/codegen.py	Mon Jul 19 21:33:47 2010
@@ -258,9 +258,11 @@
         # Load decorators first, but apply them after the function is created.
         if func.decorators:
             self.visit_sequence(func.decorators)
-        if func.args.defaults:
-            self.visit_sequence(func.args.defaults)
-            num_defaults = len(func.args.defaults)
+        args = func.args
+        assert isinstance(args, ast.arguments)
+        if args.defaults:
+            self.visit_sequence(args.defaults)
+            num_defaults = len(args.defaults)
         else:
             num_defaults = 0
         code = self.sub_scope(FunctionCodeGenerator, func.name, func,
@@ -274,9 +276,11 @@
 
     def visit_Lambda(self, lam):
         self.update_position(lam.lineno)
-        if lam.args.defaults:
-            self.visit_sequence(lam.args.defaults)
-            default_count = len(lam.args.defaults)
+        args = lam.args
+        assert isinstance(args, ast.arguments)
+        if args.defaults:
+            self.visit_sequence(args.defaults)
+            default_count = len(args.defaults)
         else:
             default_count = 0
         code = self.sub_scope(LambdaCodeGenerator, "<lambda>", lam, lam.lineno)
@@ -1275,9 +1279,11 @@
         else:
             self.add_const(self.space.w_None)
             start = 0
-        if func.args.args:
-            self._handle_nested_args(func.args.args)
-            self.argcount = len(func.args.args)
+        args = func.args
+        assert isinstance(args, ast.arguments)
+        if args.args:
+            self._handle_nested_args(args.args)
+            self.argcount = len(args.args)
         for i in range(start, len(func.body)):
             func.body[i].walkabout(self)
 
@@ -1286,9 +1292,11 @@
 
     def _compile(self, lam):
         assert isinstance(lam, ast.Lambda)
-        if lam.args.args:
-            self._handle_nested_args(lam.args.args)
-            self.argcount = len(lam.args.args)
+        args = lam.args
+        assert isinstance(args, ast.arguments)
+        if args.args:
+            self._handle_nested_args(args.args)
+            self.argcount = len(args.args)
         # Prevent a string from being the first constant and thus a docstring.
         self.add_const(self.space.w_None)
         lam.body.walkabout(self)

Modified: pypy/trunk/pypy/interpreter/astcompiler/symtable.py
==============================================================================
--- pypy/trunk/pypy/interpreter/astcompiler/symtable.py	(original)
+++ pypy/trunk/pypy/interpreter/astcompiler/symtable.py	Mon Jul 19 21:33:47 2010
@@ -353,8 +353,10 @@
     def visit_FunctionDef(self, func):
         self.note_symbol(func.name, SYM_ASSIGNED)
         # Function defaults and decorators happen in the outer scope.
-        if func.args.defaults:
-            self.visit_sequence(func.args.defaults)
+        args = func.args
+        assert isinstance(args, ast.arguments)
+        if args.defaults:
+            self.visit_sequence(args.defaults)
         if func.decorators:
             self.visit_sequence(func.decorators)
         new_scope = FunctionScope(func.name, func.lineno, func.col_offset)
@@ -420,8 +422,10 @@
             self.note_symbol(name, SYM_GLOBAL)
 
     def visit_Lambda(self, lamb):
-        if lamb.args.defaults:
-            self.visit_sequence(lamb.args.defaults)
+        args = lamb.args
+        assert isinstance(args, ast.arguments)
+        if args.defaults:
+            self.visit_sequence(args.defaults)
         new_scope = FunctionScope("lambda", lamb.lineno, lamb.col_offset)
         self.push_scope(new_scope, lamb)
         lamb.args.walkabout(self)

Modified: pypy/trunk/pypy/interpreter/astcompiler/tools/asdl_py.py
==============================================================================
--- pypy/trunk/pypy/interpreter/astcompiler/tools/asdl_py.py	(original)
+++ pypy/trunk/pypy/interpreter/astcompiler/tools/asdl_py.py	Mon Jul 19 21:33:47 2010
@@ -100,6 +100,7 @@
         self.emit("")
         self.make_constructor(product.fields, product)
         self.emit("")
+        self.make_mutate_over(product, name)
         self.emit("def walkabout(self, visitor):", 1)
         self.emit("visitor.visit_%s(self)" % (name,), 2)
         self.emit("")
@@ -183,6 +184,26 @@
         have_everything = self.data.required_masks[node] | \
             self.data.optional_masks[node]
         self.emit("self.initialization_state = %i" % (have_everything,), 2)
+    
+    def make_mutate_over(self, cons, name):
+        self.emit("def mutate_over(self, visitor):", 1)
+        for field in cons.fields:
+            if (field.type.value not in asdl.builtin_types and
+                field.type.value not in self.data.simple_types):
+                if field.opt or field.seq:
+                    level = 3
+                    self.emit("if self.%s:" % (field.name,), 2)
+                else:
+                    level = 2
+                if field.seq:
+                    sub = (field.name,)
+                    self.emit("visitor._mutate_sequence(self.%s)" % sub, level)
+                else:
+                    sub = (field.name, field.name)
+                    self.emit("self.%s = self.%s.mutate_over(visitor)" % sub,
+                              level)
+        self.emit("return visitor.visit_%s(self)" % (name,), 2)
+        self.emit("")
 
     def visitConstructor(self, cons, base, extra_attributes):
         self.emit("class %s(%s):" % (cons.name, base))
@@ -199,24 +220,7 @@
         self.emit("def walkabout(self, visitor):", 1)
         self.emit("visitor.visit_%s(self)" % (cons.name,), 2)
         self.emit("")
-        self.emit("def mutate_over(self, visitor):", 1)
-        for field in cons.fields:
-            if field.type.value not in asdl.builtin_types and \
-                    field.type.value not in self.data.prod_simple:
-                if field.opt or field.seq:
-                    level = 3
-                    self.emit("if self.%s:" % (field.name,), 2)
-                else:
-                    level = 2
-                if field.seq:
-                    sub = (field.name,)
-                    self.emit("visitor._mutate_sequence(self.%s)" % sub, level)
-                else:
-                    sub = (field.name, field.name)
-                    self.emit("self.%s = self.%s.mutate_over(visitor)" % sub,
-                              level)
-        self.emit("return visitor.visit_%s(self)" % (cons.name,), 2)
-        self.emit("")
+        self.make_mutate_over(cons, cons.name)
         self.make_var_syncer(cons.fields + self.data.cons_attributes[cons],
                              cons, cons.name)
 

Modified: pypy/trunk/pypy/interpreter/test/test_compiler.py
==============================================================================
--- pypy/trunk/pypy/interpreter/test/test_compiler.py	(original)
+++ pypy/trunk/pypy/interpreter/test/test_compiler.py	Mon Jul 19 21:33:47 2010
@@ -838,6 +838,23 @@
             sys.stdout = save_stdout
         output = s.getvalue()
         assert "STOP_CODE" not in output
+    
+    def test_optimize_list_comp(self):
+        source = """def _f(a):
+            return [x for x in a if None]
+        """
+        exec source
+        code = _f.func_code
+        
+        import StringIO, sys, dis
+        s = StringIO.StringIO()
+        sys.stdout = s
+        try:
+            dis.dis(code)
+        finally:
+            sys.stdout = sys.__stdout__
+        output = s.getvalue()
+        assert "LOAD_GLOBAL" not in output
 
 class AppTestExceptions:
     def test_indentation_error(self):



More information about the Pypy-commit mailing list