[py-svn] commit/pytest: gutworth: rewrite with proper short-circuiting on boolean operators (fixes #57)

Bitbucket commits-noreply at bitbucket.org
Wed Jun 29 03:21:23 CEST 2011


1 new changeset in pytest:

http://bitbucket.org/hpk42/pytest/changeset/9dead0879575/
changeset:   9dead0879575
user:        gutworth
date:        2011-06-29 03:21:22
summary:     rewrite with proper short-circuiting on boolean operators (fixes #57)
affected #:  2 files (1.2 KB)

--- a/_pytest/assertion/rewrite.py	Tue Jun 28 10:39:11 2011 -0500
+++ b/_pytest/assertion/rewrite.py	Tue Jun 28 20:21:22 2011 -0500
@@ -17,13 +17,8 @@
 _saferepr = py.io.saferepr
 from _pytest.assertion.util import format_explanation as _format_explanation
 
-def _format_boolop(operands, explanations, is_or):
-    show_explanations = []
-    for operand, expl in zip(operands, explanations):
-        show_explanations.append(expl)
-        if operand == is_or:
-            break
-    return "(" + (is_or and " or " or " and ").join(show_explanations) + ")"
+def _format_boolop(explanations, is_or):
+    return "(" + (is_or and " or " or " and ").join(explanations) + ")"
 
 def _call_reprcompare(ops, results, expls, each_obj):
     for i, res, expl in zip(range(len(ops)), results, expls):
@@ -143,7 +138,7 @@
         """Get a new variable."""
         # Use a character invalid in python identifiers to avoid clashing.
         name = "@py_assert" + str(next(self.variable_counter))
-        self.variables.add(name)
+        self.variables[self.cond_chain].add(name)
         return name
 
     def assign(self, expr):
@@ -172,6 +167,13 @@
         self.explanation_specifiers[specifier] = expr
         return "%(" + specifier + ")s"
 
+    def enter_cond(self, cond, body):
+        self.statements.append(ast.If(cond, body, []))
+        self.cond_chain += cond,
+
+    def leave_cond(self, n=1):
+        self.cond_chain = self.cond_chain[:-n]
+
     def push_format_context(self):
         self.explanation_specifiers = {}
         self.stack.append(self.explanation_specifiers)
@@ -198,7 +200,8 @@
             # There's already a message. Don't mess with it.
             return [assert_]
         self.statements = []
-        self.variables = set()
+        self.cond_chain = ()
+        self.variables = collections.defaultdict(set)
         self.variable_counter = itertools.count()
         self.stack = []
         self.on_failure = []
@@ -220,11 +223,22 @@
         else:
             raise_ = ast.Raise(exc, None, None)
         body.append(raise_)
-        # Delete temporary variables.
-        names = [ast.Name(name, ast.Del()) for name in self.variables]
-        if names:
-            delete = ast.Delete(names)
-            self.statements.append(delete)
+        # Delete temporary variables. This requires a bit cleverness about the
+        # order, so we don't delete variables that are themselves conditions for
+        # later variables.
+        for chain in sorted(self.variables, key=len, reverse=True):
+            if chain:
+                where = []
+                if len(chain) > 1:
+                    cond = ast.Boolop(ast.And(), chain)
+                else:
+                    cond = chain[0]
+                self.statements.append(ast.If(cond, where, []))
+            else:
+                where = self.statements
+            v = self.variables[chain]
+            names = [ast.Name(name, ast.Del()) for name in v]
+            where.append(ast.Delete(names))
         # Fix line numbers.
         for stmt in self.statements:
             set_location(stmt, assert_.lineno, assert_.col_offset)
@@ -240,21 +254,32 @@
         return name, self.explanation_param(expr)
 
     def visit_BoolOp(self, boolop):
-        operands = []
-        explanations = []
+        res_var = self.variable()
+        expl_list = self.assign(ast.List([], ast.Load()))
+        app = ast.Attribute(expl_list, "append", ast.Load())
+        is_or = isinstance(boolop.op, ast.Or)
+        body = save = self.statements
+        levels = len(boolop.values) - 1
         self.push_format_context()
-        for operand in boolop.values:
-            res, explanation = self.visit(operand)
-            operands.append(res)
-            explanations.append(explanation)
-        expls = ast.Tuple([ast.Str(expl) for expl in explanations], ast.Load())
-        is_or = ast.Num(isinstance(boolop.op, ast.Or))
-        expl_template = self.helper("format_boolop",
-                                    ast.Tuple(operands, ast.Load()), expls,
-                                    is_or)
+        # Process each operand, short-circuting if needed.
+        for i, v in enumerate(boolop.values):
+            res, expl = self.visit(v)
+            body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
+            call = ast.Call(app, [ast.Str(expl)], [], None, None)
+            body.append(ast.Expr(call))
+            if i < levels:
+                inner = []
+                cond = res
+                if is_or:
+                    cond = ast.UnaryOp(ast.Not(), cond)
+                self.enter_cond(cond, inner)
+                self.statements = body = inner
+        # Leave all conditions.
+        self.leave_cond(levels)
+        self.statements = save
+        expl_template = self.helper("format_boolop", expl_list, ast.Num(is_or))
         expl = self.pop_format_context(expl_template)
-        res = self.assign(ast.BoolOp(boolop.op, operands))
-        return res, self.explanation_param(expl)
+        return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
 
     def visit_UnaryOp(self, unary):
         pattern = unary_map[unary.op.__class__]


--- a/testing/test_assertrewrite.py	Tue Jun 28 10:39:11 2011 -0500
+++ b/testing/test_assertrewrite.py	Tue Jun 28 20:21:22 2011 -0500
@@ -129,16 +129,23 @@
             assert f or g
         assert getmsg(f) == "assert (False or False)"
         def f():
+            f = g = False
+            assert not f and not g
+        getmsg(f, must_pass=True)
+        def f():
             f = True
             g = False
             assert f or g
         getmsg(f, must_pass=True)
 
     def test_short_circut_evaluation(self):
-        pytest.xfail("complicated fix; I'm not sure if it's important")
         def f():
             assert True or explode
         getmsg(f, must_pass=True)
+        def f():
+            x = 1
+            assert x == 1 or x == 2
+        getmsg(f, must_pass=True)
 
     def test_unary_op(self):
         def f():

Repository URL: https://bitbucket.org/hpk42/pytest/

--

This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.



More information about the pytest-commit mailing list