[Python-checkins] r64291 - in python/branches/tlee-ast-optimize: Lib/test/test_optimizer.py Python/optimize.c

thomas.lee python-checkins at python.org
Sun Jun 15 14:00:59 CEST 2008


Author: thomas.lee
Date: Sun Jun 15 14:00:58 2008
New Revision: 64291

Log:
Simplify jumps to returns from if statements.

Modified:
   python/branches/tlee-ast-optimize/Lib/test/test_optimizer.py
   python/branches/tlee-ast-optimize/Python/optimize.c

Modified: python/branches/tlee-ast-optimize/Lib/test/test_optimizer.py
==============================================================================
--- python/branches/tlee-ast-optimize/Lib/test/test_optimizer.py	(original)
+++ python/branches/tlee-ast-optimize/Lib/test/test_optimizer.py	Sun Jun 15 14:00:58 2008
@@ -222,7 +222,24 @@
             self.assertEqual(tuple, ast.body[0].value.value.__class__)
             self.assertEqual(obj, ast.body[0].value.value)
 
-    def test_folding_of_constant_list_in_for_loop(self):
+    def test_skip_unreachable_for_loop(self):
+        code = """
+for i in []:
+    print i
+"""
+        ast = self.compileast(code)
+        self.assertEqual(_ast.Pass, ast.body[0].__class__)
+
+    def test_skip_unreachable_while_loop(self):
+        code = """
+while 0:
+    print 'foo'
+"""
+
+        ast = self.compileast(code)
+        self.assertEqual(_ast.Pass, ast.body[0].__class__)
+
+    def test_fold_constant_list_in_for_loop(self):
         code = """
 for i in [1, 2, 3]:
     print i
@@ -242,6 +259,20 @@
             self.assertEqual(_ast.Const, ast.body[0].value.__class__)
             self.assertEqual(obj, ast.body[0].value.value)
 
+    def test_jumps_to_returns_are_simplified(self):
+        code = """
+def foo(x):
+    if x:
+        y = 1
+    else:
+        y = 0
+    return y
+"""
+
+        ast = self.compileast(code)
+        self.assertEqual(_ast.Return, ast.body[0].body[0].body[1].__class__)
+        self.assertEqual('y', ast.body[0].body[0].body[1].value.id)
+
 def test_main():
     test_support.run_unittest(AstOptimizerTest)
 

Modified: python/branches/tlee-ast-optimize/Python/optimize.c
==============================================================================
--- python/branches/tlee-ast-optimize/Python/optimize.c	(original)
+++ python/branches/tlee-ast-optimize/Python/optimize.c	Sun Jun 15 14:00:58 2008
@@ -152,6 +152,34 @@
 }
 
 /**
+ */
+static asdl_seq*
+_asdl_seq_append(asdl_seq* seq1, int n1, asdl_seq* seq2, int n2,
+                    PyArena* arena)
+{
+    asdl_seq* new;
+    int newlen, i;
+    int len1, len2;
+
+    /* XXX: check this calculation */
+    len1 = asdl_seq_LEN(seq1) - n1;
+    len2 = asdl_seq_LEN(seq2) - n2;
+    newlen = len1 + len2;
+
+    new = asdl_seq_new(newlen, arena);
+    if (new == NULL)
+        return NULL;
+
+    for (i = 0; i < len1; i++)
+        asdl_seq_SET(new, i, asdl_seq_GET(seq1, n1 + i));
+
+    for (i = 0; i < len2; i++)
+        asdl_seq_SET(new, len1 + i, asdl_seq_GET(seq2, n2 + i));
+
+    return new;
+}
+
+/**
  * Replace an AST node at position `n' with the node(s) in `replacement'.
  */
 static asdl_seq*
@@ -203,50 +231,150 @@
     return seq;
 }
 
+#define LAST_IN_SEQ(seq) (asdl_seq_LEN((seq)) - 1)
+
 /**
- * Optimize a sequence of statements.
+ * Eliminate code that we can determine will never be executed.
  */
 static int
-optimize_stmt_seq(asdl_seq** seq_ptr, PySTEntryObject* ste, PyArena* arena)
+_eliminate_unreachable_code(asdl_seq** seq_ptr, int n, PySTEntryObject* ste,
+                                PyArena* arena)
 {
-    int n;
     asdl_seq* seq = *seq_ptr;
-    for (n = 0; n < asdl_seq_LEN(seq); n++) {
-        stmt_ty stmt = asdl_seq_GET(seq, n);
-        if (!optimize_stmt((stmt_ty*)&asdl_seq_GET(seq, n), ste, arena))
-            return 0;
+    stmt_ty stmt = asdl_seq_GET(seq, n);
 
-        if (stmt->kind == If_kind) {
-            PyObject* test = _expr_constant_value(stmt->v.If.test);
-            /* eliminate branches that can never be reached */
-            if (test != NULL) {
-                if (PyObject_IsTrue(test))
-                    seq = _asdl_seq_replace(seq, n, stmt->v.If.body, arena);
+    /* eliminate unreachable branches in an "if" statement? */
+    if (stmt->kind == If_kind) {
+        PyObject* test = _expr_constant_value(stmt->v.If.test);
+        if (test != NULL) {
+            if (PyObject_IsTrue(test))
+                seq = _asdl_seq_replace(seq, n, stmt->v.If.body, arena);
+            else {
+                if (stmt->v.If.orelse == NULL) {
+                    /* no "else:" body: use a Pass() */
+                    seq = _asdl_seq_replace_with_pass(seq, n, stmt->lineno,
+                            stmt->col_offset, arena);
+                }
                 else {
-                    if (stmt->v.If.orelse == NULL) {
-                        /* no "else:" body: use a Pass() */
-                        seq = _asdl_seq_replace_with_pass(seq, n,
-                                stmt->lineno, stmt->col_offset, arena);
-                    }
-                    else {
-                        seq = _asdl_seq_replace(seq, n, stmt->v.If.orelse,
-                                arena);
-                    }
+                    seq = _asdl_seq_replace(seq, n, stmt->v.If.orelse, arena);
                 }
-                if (seq == NULL)
-                    return 0;
-                *seq_ptr = seq;
             }
+            if (seq == NULL)
+                return 0;
+            *seq_ptr = seq;
+        }
+    }
+    /* eliminate unreachable while loops? */
+    else if (stmt->kind == While_kind) {
+        PyObject* test = _expr_constant_value(stmt->v.While.test);
+        if (test != NULL) {
+            if (!PyObject_IsTrue(test)) {
+                /* XXX: what about orelse? */
+                seq = _asdl_seq_replace_with_pass(seq, n, stmt->lineno,
+                        stmt->col_offset, arena);
+            }
+            if (seq == NULL)
+                return 0;
+            *seq_ptr = seq;
         }
-        else if (stmt->kind == Return_kind && n < (asdl_seq_LEN(seq) - 1)) {
-            /* eliminate all nodes after a return */
-            seq = _asdl_seq_replace_with_pass(seq, n + 1,
-                    stmt->lineno, stmt->col_offset, arena);
+    }
+    /* eliminate unreachable for loops? */
+    else if (stmt->kind == For_kind) {
+        PyObject* iter = _expr_constant_value(stmt->v.For.iter);
+        if (iter != NULL) {
+            if (PyObject_Size(iter) == 0) {
+                /* XXX: what about orelse? */
+                seq = _asdl_seq_replace_with_pass(seq, n, stmt->lineno,
+                        stmt->col_offset, arena);
+            }
             if (seq == NULL)
                 return 0;
             *seq_ptr = seq;
         }
     }
+    /* eliminate all code after a "return" statement */
+    else if (stmt->kind == Return_kind && n < LAST_IN_SEQ(seq)) {
+        /* eliminate all nodes after a return */
+        seq = _asdl_seq_replace_with_pass(seq, n + 1,
+                stmt->lineno, stmt->col_offset, arena);
+        if (seq == NULL)
+            return 0;
+        *seq_ptr = seq;
+    }
+
+    return 1;
+}
+
+static asdl_seq*
+_asdl_seq_append_return(asdl_seq* seq, expr_ty value, PyArena* arena)
+{
+    stmt_ty ret;
+    stmt_ty last;
+    asdl_seq* retseq = asdl_seq_new(1, arena);
+    if (retseq == NULL)
+        return NULL;
+    last = asdl_seq_GET(seq, asdl_seq_LEN(seq)-1);
+    ret = Return(value, last->lineno, last->col_offset, arena);
+    if (ret == NULL)
+        return NULL;
+    asdl_seq_SET(retseq, 0, ret);
+    return _asdl_seq_append(seq, 0, retseq, 0, arena);
+}
+
+/**
+ * Simplify any branches that converge on a "return" statement such that
+ * they immediately return rather than jump.
+ */
+static int
+_simplify_jumps_to_return(asdl_seq* seq, PySTEntryObject* ste,
+                            PyArena* arena)
+{
+    int n, len;
+
+    len = asdl_seq_LEN(seq);
+
+    for (n = 0; n < len - 1; n++) {
+        stmt_ty stmt = asdl_seq_GET(seq, n);
+        stmt_ty next = asdl_seq_GET(seq, n+1);
+        
+        if (next->kind == Return_kind) {
+            /* if the else body is not present, there will be no jump */
+            if (stmt->kind == If_kind && stmt->v.If.orelse != NULL) {
+                stmt_ty inner = asdl_seq_GET(stmt->v.If.body,
+                                            asdl_seq_LEN(stmt->v.If.body)-1);
+                
+                if (inner->kind != Return_kind) {
+                    stmt->v.If.body =
+                        _asdl_seq_append_return(stmt->v.If.body,
+                                                next->v.Return.value, arena);
+
+                    if (stmt->v.If.body == NULL)
+                        return 0;
+                }
+            }
+        }
+    }
+
+    return 1;
+}
+
+/**
+ * Optimize a sequence of statements.
+ */
+static int
+optimize_stmt_seq(asdl_seq** seq_ptr, PySTEntryObject* ste, PyArena* arena)
+{
+    int n;
+    asdl_seq* seq = *seq_ptr;
+    for (n = 0; n < asdl_seq_LEN(seq); n++) {
+        if (!optimize_stmt((stmt_ty*)&asdl_seq_GET(seq, n), ste, arena))
+            return 0;
+        if (!_eliminate_unreachable_code(seq_ptr, n, ste, arena))
+            return 0;
+        if (ste->ste_type == FunctionBlock)
+            if (!_simplify_jumps_to_return(*seq_ptr, ste, arena))
+                return 0;
+    }
     return 1;
 }
 
@@ -972,11 +1100,9 @@
 }
 
 static int
-optimize_for(stmt_ty* stmt_ptr, PySTEntryObject* ste, PyArena* arena)
+_optimize_for_iter(stmt_ty* stmt_ptr, PySTEntryObject* ste, PyArena* arena)
 {
     stmt_ty stmt = *stmt_ptr;
-    if (!optimize_expr(&stmt->v.For.target, ste, arena))
-        return 0;
     if (!optimize_expr(&stmt->v.For.iter, ste, arena))
         return 0;
     /* if the object we're iterating over is a list of constants,
@@ -997,6 +1123,17 @@
                 return 0;
         }
     }
+    return 1;
+}
+
+static int
+optimize_for(stmt_ty* stmt_ptr, PySTEntryObject* ste, PyArena* arena)
+{
+    stmt_ty stmt = *stmt_ptr;
+    if (!optimize_expr(&stmt->v.For.target, ste, arena))
+        return 0;
+    if (!_optimize_for_iter(&stmt, ste, arena))
+        return 0;
     if (!optimize_stmt_seq(&stmt->v.For.body, ste, arena))
         return 0;
     if (!optimize_stmt_seq(&stmt->v.For.orelse, ste, arena))


More information about the Python-checkins mailing list