[py-svn] commit/pytest: gutworth: rewrite with proper short-circuiting on boolean operators (fixes #57)
Bitbucket
commits-noreply at bitbucket.org
Wed Jun 29 03:21:23 CEST 2011
1 new changeset in pytest:
http://bitbucket.org/hpk42/pytest/changeset/9dead0879575/
changeset: 9dead0879575
user: gutworth
date: 2011-06-29 03:21:22
summary: rewrite with proper short-circuiting on boolean operators (fixes #57)
affected #: 2 files (1.2 KB)
--- a/_pytest/assertion/rewrite.py Tue Jun 28 10:39:11 2011 -0500
+++ b/_pytest/assertion/rewrite.py Tue Jun 28 20:21:22 2011 -0500
@@ -17,13 +17,8 @@
_saferepr = py.io.saferepr
from _pytest.assertion.util import format_explanation as _format_explanation
-def _format_boolop(operands, explanations, is_or):
- show_explanations = []
- for operand, expl in zip(operands, explanations):
- show_explanations.append(expl)
- if operand == is_or:
- break
- return "(" + (is_or and " or " or " and ").join(show_explanations) + ")"
+def _format_boolop(explanations, is_or):
+ return "(" + (is_or and " or " or " and ").join(explanations) + ")"
def _call_reprcompare(ops, results, expls, each_obj):
for i, res, expl in zip(range(len(ops)), results, expls):
@@ -143,7 +138,7 @@
"""Get a new variable."""
# Use a character invalid in python identifiers to avoid clashing.
name = "@py_assert" + str(next(self.variable_counter))
- self.variables.add(name)
+ self.variables[self.cond_chain].add(name)
return name
def assign(self, expr):
@@ -172,6 +167,13 @@
self.explanation_specifiers[specifier] = expr
return "%(" + specifier + ")s"
+ def enter_cond(self, cond, body):
+ self.statements.append(ast.If(cond, body, []))
+ self.cond_chain += cond,
+
+ def leave_cond(self, n=1):
+ self.cond_chain = self.cond_chain[:-n]
+
def push_format_context(self):
self.explanation_specifiers = {}
self.stack.append(self.explanation_specifiers)
@@ -198,7 +200,8 @@
# There's already a message. Don't mess with it.
return [assert_]
self.statements = []
- self.variables = set()
+ self.cond_chain = ()
+ self.variables = collections.defaultdict(set)
self.variable_counter = itertools.count()
self.stack = []
self.on_failure = []
@@ -220,11 +223,22 @@
else:
raise_ = ast.Raise(exc, None, None)
body.append(raise_)
- # Delete temporary variables.
- names = [ast.Name(name, ast.Del()) for name in self.variables]
- if names:
- delete = ast.Delete(names)
- self.statements.append(delete)
+ # Delete temporary variables. This requires a bit cleverness about the
+ # order, so we don't delete variables that are themselves conditions for
+ # later variables.
+ for chain in sorted(self.variables, key=len, reverse=True):
+ if chain:
+ where = []
+ if len(chain) > 1:
+ cond = ast.Boolop(ast.And(), chain)
+ else:
+ cond = chain[0]
+ self.statements.append(ast.If(cond, where, []))
+ else:
+ where = self.statements
+ v = self.variables[chain]
+ names = [ast.Name(name, ast.Del()) for name in v]
+ where.append(ast.Delete(names))
# Fix line numbers.
for stmt in self.statements:
set_location(stmt, assert_.lineno, assert_.col_offset)
@@ -240,21 +254,32 @@
return name, self.explanation_param(expr)
def visit_BoolOp(self, boolop):
- operands = []
- explanations = []
+ res_var = self.variable()
+ expl_list = self.assign(ast.List([], ast.Load()))
+ app = ast.Attribute(expl_list, "append", ast.Load())
+ is_or = isinstance(boolop.op, ast.Or)
+ body = save = self.statements
+ levels = len(boolop.values) - 1
self.push_format_context()
- for operand in boolop.values:
- res, explanation = self.visit(operand)
- operands.append(res)
- explanations.append(explanation)
- expls = ast.Tuple([ast.Str(expl) for expl in explanations], ast.Load())
- is_or = ast.Num(isinstance(boolop.op, ast.Or))
- expl_template = self.helper("format_boolop",
- ast.Tuple(operands, ast.Load()), expls,
- is_or)
+ # Process each operand, short-circuting if needed.
+ for i, v in enumerate(boolop.values):
+ res, expl = self.visit(v)
+ body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
+ call = ast.Call(app, [ast.Str(expl)], [], None, None)
+ body.append(ast.Expr(call))
+ if i < levels:
+ inner = []
+ cond = res
+ if is_or:
+ cond = ast.UnaryOp(ast.Not(), cond)
+ self.enter_cond(cond, inner)
+ self.statements = body = inner
+ # Leave all conditions.
+ self.leave_cond(levels)
+ self.statements = save
+ expl_template = self.helper("format_boolop", expl_list, ast.Num(is_or))
expl = self.pop_format_context(expl_template)
- res = self.assign(ast.BoolOp(boolop.op, operands))
- return res, self.explanation_param(expl)
+ return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
def visit_UnaryOp(self, unary):
pattern = unary_map[unary.op.__class__]
--- a/testing/test_assertrewrite.py Tue Jun 28 10:39:11 2011 -0500
+++ b/testing/test_assertrewrite.py Tue Jun 28 20:21:22 2011 -0500
@@ -129,16 +129,23 @@
assert f or g
assert getmsg(f) == "assert (False or False)"
def f():
+ f = g = False
+ assert not f and not g
+ getmsg(f, must_pass=True)
+ def f():
f = True
g = False
assert f or g
getmsg(f, must_pass=True)
def test_short_circut_evaluation(self):
- pytest.xfail("complicated fix; I'm not sure if it's important")
def f():
assert True or explode
getmsg(f, must_pass=True)
+ def f():
+ x = 1
+ assert x == 1 or x == 2
+ getmsg(f, must_pass=True)
def test_unary_op(self):
def f():
Repository URL: https://bitbucket.org/hpk42/pytest/
--
This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.
More information about the pytest-commit
mailing list