[pypy-commit] pypy apptest-file: Backport new files from py3tests branch

rlamy pypy.commits at gmail.com
Fri Mar 23 10:19:38 EDT 2018


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: apptest-file
Changeset: r94119:ba2f4ba5e8c4
Date: 2018-03-23 14:37 +0100
http://bitbucket.org/pypy/pypy/changeset/ba2f4ba5e8c4/

Log:	Backport new files from py3tests branch

diff --git a/pypy/tool/pytest/app_rewrite.py b/pypy/tool/pytest/app_rewrite.py
new file mode 100644
--- /dev/null
+++ b/pypy/tool/pytest/app_rewrite.py
@@ -0,0 +1,39 @@
+import re
+
+# When True, _prepare_source() rejects sources that contain non-ASCII bytes
+# but have neither a UTF-8 BOM nor a coding declaration on their first two
+# lines (mimicking Python 2's import behaviour). Disabled here.
+ASCII_IS_DEFAULT_ENCODING = False
+
+# Matches a coding declaration comment such as "# -*- coding: utf-8 -*-".
+cookie_re = re.compile(r"^[ \t\f]*#.*coding[:=][ \t]*[-\w.]+")
+# UTF-8 byte-order mark (a byte string under Python 2).
+BOM_UTF8 = '\xef\xbb\xbf'
+
+def _prepare_source(fn):
+    """Read the source code for re-writing."""
+    try:
+        stat = fn.stat()
+        source = fn.read("rb")
+    except EnvironmentError:
+        return None, None
+    if ASCII_IS_DEFAULT_ENCODING:
+        # ASCII is the default encoding in Python 2. Without a coding
+        # declaration, Python 2 will complain about any bytes in the file
+        # outside the ASCII range. Sadly, this behavior does not extend to
+        # compile() or ast.parse(), which prefer to interpret the bytes as
+        # latin-1. (At least they properly handle explicit coding cookies.) To
+        # preserve this error behavior, we could force ast.parse() to use ASCII
+        # as the encoding by inserting a coding cookie. Unfortunately, that
+        # messes up line numbers. Thus, we have to check ourselves if anything
+        # is outside the ASCII range in the case no encoding is explicitly
+        # declared. For more context, see issue #269. Yay for Python 3 which
+        # gets this right.
+        end1 = source.find("\n")
+        end2 = source.find("\n", end1 + 1)
+        if (not source.startswith(BOM_UTF8) and
+            cookie_re.match(source[0:end1]) is None and
+            cookie_re.match(source[end1 + 1:end2]) is None):
+            try:
+                source.decode("ascii")
+            except UnicodeDecodeError:
+                # Let it fail in real import.
+                return None, None
+    # On Python versions which are not 2.7 and less than or equal to 3.1, the
+    # parser expects *nix newlines.
+    return stat, source
diff --git a/pypy/tool/pytest/apptest2.py b/pypy/tool/pytest/apptest2.py
new file mode 100644
--- /dev/null
+++ b/pypy/tool/pytest/apptest2.py
@@ -0,0 +1,119 @@
+import sys
+import os
+
+import pytest
+from pypy import pypydir
+import pypy.interpreter.function
+from pypy.tool.pytest import app_rewrite
+from pypy.interpreter.error import OperationError
+from pypy.interpreter.module import Module
+from pypy.tool.pytest import objspace
+from pypy.tool.pytest import appsupport
+
+
+class AppTestModule(pytest.Module):
+    def __init__(self, path, parent, rewrite_asserts=False):
+        super(AppTestModule, self).__init__(path, parent)
+        self.rewrite_asserts = rewrite_asserts
+
+    def collect(self):
+        _, source = app_rewrite._prepare_source(self.fspath)
+        space = objspace.gettestobjspace()
+        w_rootdir = space.newtext(
+            os.path.join(pypydir, 'tool', 'pytest', 'ast-rewriter'))
+        w_source = space.newtext(source)
+        fname = str(self.fspath)
+        w_fname = space.newtext(fname)
+        if self.rewrite_asserts:
+            w_mod = space.appexec([w_rootdir, w_source, w_fname],
+                                """(rootdir, source, fname):
+                import sys
+                sys.path.insert(0, rootdir)
+                from ast_rewrite import rewrite_asserts, create_module
+
+                co = rewrite_asserts(source, fname)
+                mod = create_module(fname, co)
+                return mod
+            """)
+        else:
+            w_mod = create_module(space, w_fname, fname, source)
+        mod_dict = w_mod.getdict(space).unwrap(space)
+        items = []
+        for name, w_obj in mod_dict.items():
+            if not name.startswith('test_'):
+                continue
+            if not isinstance(w_obj, pypy.interpreter.function.Function):
+                continue
+            items.append(AppTestFunction(name, self, w_obj))
+        return items
+
+    def setup(self):
+        pass
+
+def create_module(space, w_name, filename, source):
+    w_mod = Module(space, w_name)
+    w_dict = w_mod.getdict(space)
+    space.setitem(w_dict, space.newtext('__file__'), space.newtext(filename))
+    space.exec_(source, w_dict, w_dict, filename=filename)
+    return w_mod
+
+
+class AppError(Exception):
+    """Carries the AppExceptionInfo of a failing app-level test.
+
+    AppTestFunction.execute_appex() raises it, and
+    AppTestFunction.repr_failure() unwraps it so that the app-level
+    exception info is what gets reported.
+    """
+
+    def __init__(self, excinfo):
+        # excinfo: appsupport.AppExceptionInfo built from the OperationError.
+        self.excinfo = excinfo
+
+
+class AppTestFunction(pytest.Item):
+
+    def __init__(self, name, parent, w_obj):
+        super(AppTestFunction, self).__init__(name, parent)
+        self.w_obj = w_obj
+
+    def runtest(self):
+        target = self.w_obj
+        space = target.space
+        self.execute_appex(space, target)
+
+    def repr_failure(self, excinfo):
+        if excinfo.errisinstance(AppError):
+            excinfo = excinfo.value.excinfo
+        return super(AppTestFunction, self).repr_failure(excinfo)
+
+    def execute_appex(self, space, w_func):
+        space.getexecutioncontext().set_sys_exc_info(None)
+        sig = w_func.code._signature
+        if sig.varargname or sig.kwargname or sig.kwonlyargnames:
+            raise ValueError(
+                'Test functions may not use *args, **kwargs or '
+                'keyword-only args')
+        args_w = self.get_fixtures(space, sig.argnames)
+        try:
+            space.call_function(w_func, *args_w)
+        except OperationError as e:
+            if self.config.option.raise_operr:
+                raise
+            tb = sys.exc_info()[2]
+            if e.match(space, space.w_KeyboardInterrupt):
+                raise KeyboardInterrupt, KeyboardInterrupt(), tb
+            appexcinfo = appsupport.AppExceptionInfo(space, e)
+            if appexcinfo.traceback:
+                raise AppError, AppError(appexcinfo), tb
+            raise
+
+    def reportinfo(self):
+        """Must return a triple (fspath, lineno, test_name)"""
+        lineno = self.w_obj.code.co_firstlineno
+        return self.parent.fspath, lineno, self.w_obj.name
+
+    def get_fixtures(self, space, fixtures):
+        if not fixtures:
+            return []
+        import imp
+        fixtures_mod = imp.load_source(
+            'fixtures', str(self.parent.fspath.new(basename='fixtures.py')))
+        result = []
+        for name in fixtures:
+            arg = getattr(fixtures_mod, name)(space, self.parent.config)
+            result.append(arg)
+        return result
diff --git a/pypy/tool/pytest/ast-rewriter/ast_rewrite.py b/pypy/tool/pytest/ast-rewriter/ast_rewrite.py
new file mode 100644
--- /dev/null
+++ b/pypy/tool/pytest/ast-rewriter/ast_rewrite.py
@@ -0,0 +1,648 @@
+"""Rewrite assertion AST to produce nice error messages"""
+from __future__ import absolute_import, division, print_function
+import ast
+import itertools
+import marshal
+import struct
+import sys
+
+from ast_util import assertrepr_compare, format_explanation as _format_explanation
+
+
+# pytest caches rewritten pycs in __pycache__.
+PYTEST_TAG = sys.implementation.cache_tag + "-PYTEST"
+
+PYC_EXT = ".py" + (__debug__ and "c" or "o")
+PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
+
+
+if sys.version_info >= (3, 5):
+    ast_Call = ast.Call
+else:
+    def ast_Call(a, b, c):
+        return ast.Call(a, b, c, None, None)
+
+
+def _write_pyc(state, co, source_stat, pyc):
+    # Technically, we don't have to have the same pyc format as
+    # (C)Python, since these "pycs" should never be seen by builtin
+    # import. However, there's little reason deviate, and I hope
+    # sometime to be able to use imp.load_compiled to load them. (See
+    # the comment in load_module above.)
+    try:
+        fp = open(pyc, "wb")
+    except IOError:
+        err = sys.exc_info()[1].errno
+        state.trace("error writing pyc file at %s: errno=%s" % (pyc, err))
+        # we ignore any failure to write the cache file
+        # there are many reasons, permission-denied, __pycache__ being a
+        # file etc.
+        return False
+    try:
+        fp.write(imp.get_magic())
+        mtime = int(source_stat.mtime)
+        size = source_stat.size & 0xFFFFFFFF
+        fp.write(struct.pack("<ll", mtime, size))
+        marshal.dump(co, fp)
+    finally:
+        fp.close()
+    return True
+
+
+def rewrite_asserts(source, filename):
+    """Parse the source code and rewrite asserts statements
+
+    Returns a module object.
+    """
+    tree = ast.parse(source)
+    AssertionRewriter(filename).run(tree)
+    co = compile(tree, filename, 'exec')
+    return co
+
+
+def create_module(filename, co, pyc=None):
+    """Create a module from a code object created by rewrite_asserts()"""
+    mod = type(sys)(filename)
+    mod.__file__ = co.co_filename
+    if pyc is not None:
+        mod.__cached__ = pyc
+    mod.__loader__ = None
+    exec(co, mod.__dict__)
+    return mod
+
+
+def _make_rewritten_pyc(state, source_stat, pyc, co):
+    """Try to dump rewritten code to *pyc*."""
+    import os
+    if sys.platform.startswith("win"):
+        # Windows grants exclusive access to open files and doesn't have atomic
+        # rename, so just write into the final file.
+        _write_pyc(state, co, source_stat, pyc)
+    else:
+        # When not on windows, assume rename is atomic. Dump the code object
+        # into a file specific to this process and atomically replace it.
+        proc_pyc = pyc + "." + str(os.getpid())
+        if _write_pyc(state, co, source_stat, proc_pyc):
+            os.rename(proc_pyc, pyc)
+
+
+def _read_pyc(source, pyc, trace=lambda x: None):
+    """Possibly read a pytest pyc containing rewritten code.
+
+    Return rewritten code if successful or None if not.
+    """
+    try:
+        fp = open(pyc, "rb")
+    except IOError:
+        return None
+    with fp:
+        try:
+            mtime = int(source.mtime())
+            size = source.size()
+            data = fp.read(12)
+        except EnvironmentError as e:
+            trace('_read_pyc(%s): EnvironmentError %s' % (source, e))
+            return None
+        # Check for invalid or out of date pyc file.
+        if (len(data) != 12 or data[:4] != imp.get_magic() or
+                struct.unpack("<ll", data[4:]) != (mtime, size)):
+            trace('_read_pyc(%s): invalid or out of date pyc' % source)
+            return None
+        try:
+            co = marshal.load(fp)
+        except Exception as e:
+            trace('_read_pyc(%s): marshal.load error %s' % (source, e))
+            return None
+        if not isinstance(co, types.CodeType):
+            trace('_read_pyc(%s): not a code object' % source)
+            return None
+        return co
+
+
+def _saferepr(obj):
+    """Get a safe repr of an object for assertion error messages.
+
+    The assertion formatting (util.format_explanation()) requires
+    newlines to be escaped since they are a special character for it.
+    Normally assertion.util.format_explanation() does this but for a
+    custom repr it is possible to contain one of the special escape
+    sequences, especially '\n{' and '\n}' are likely to be present in
+    JSON reprs.
+
+    """
+    return repr(obj).replace('\n', '\\n')
+
+
+def _format_assertmsg(obj):
+    """Format the custom assertion message given.
+
+    For strings this simply replaces newlines with '\n~' so that
+    util.format_explanation() will preserve them instead of escaping
+    newlines.  For other objects py.io.saferepr() is used first.
+
+    """
+    return obj.replace("\n", "\n~").replace("%", "%%")
+
+
+def _should_repr_global_name(obj):
+    return not hasattr(obj, "__name__") and not callable(obj)
+
+
+def _format_boolop(explanations, is_or):
+    explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
+    return explanation.replace('%', '%%')
+
+
+def _call_reprcompare(ops, results, expls, each_obj):
+    for i, res, expl in zip(range(len(ops)), results, expls):
+        try:
+            done = not res
+        except Exception:
+            done = True
+        if done:
+            break
+    custom = assertrepr_compare(ops[i], each_obj[i], each_obj[i + 1])
+    if custom is not None:
+        return custom
+    return expl
+
+
+# Display patterns for unary operators, keyed by AST operator class; the
+# operand's explanation is substituted for %s.
+unary_map = {
+    ast.Not: "not %s",
+    ast.Invert: "~%s",
+    ast.USub: "-%s",
+    ast.UAdd: "+%s"
+}
+
+# Symbols for binary and comparison operators, keyed by AST node class,
+# used when building explanation strings.  Explanations end up in
+# %-formatted templates, which is why Mod is stored as "%%".
+binop_map = {
+    ast.BitOr: "|",
+    ast.BitXor: "^",
+    ast.BitAnd: "&",
+    ast.LShift: "<<",
+    ast.RShift: ">>",
+    ast.Add: "+",
+    ast.Sub: "-",
+    ast.Mult: "*",
+    ast.Div: "/",
+    ast.FloorDiv: "//",
+    ast.Mod: "%%",  # escaped for string formatting
+    ast.Eq: "==",
+    ast.NotEq: "!=",
+    ast.Lt: "<",
+    ast.LtE: "<=",
+    ast.Gt: ">",
+    ast.GtE: ">=",
+    ast.Pow: "**",
+    ast.Is: "is",
+    ast.IsNot: "is not",
+    ast.In: "in",
+    ast.NotIn: "not in"
+}
+# Python 3.5+ compatibility
+try:
+    binop_map[ast.MatMult] = "@"
+except AttributeError:
+    pass
+
+# Python 3.4+ compatibility
+# Factory for None/True/False literal nodes; before ast.NameConstant existed
+# these constants were represented as plain ast.Name loads.
+if hasattr(ast, "NameConstant"):
+    _NameConstant = ast.NameConstant
+else:
+    def _NameConstant(c):
+        return ast.Name(str(c), ast.Load())
+
+
+def set_location(node, lineno, col_offset):
+    """Set node location information recursively."""
+    def _fix(node, lineno, col_offset):
+        if "lineno" in node._attributes:
+            node.lineno = lineno
+        if "col_offset" in node._attributes:
+            node.col_offset = col_offset
+        for child in ast.iter_child_nodes(node):
+            _fix(child, lineno, col_offset)
+    _fix(node, lineno, col_offset)
+    return node
+
+
+class AssertionRewriter(ast.NodeVisitor):
+    """Assertion rewriting implementation.
+
+    The main entrypoint is to call .run() with an ast.Module instance,
+    this will then find all the assert statements and rewrite them to
+    provide intermediate values and a detailed assertion error.  See
+    http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
+    for an overview of how this works.
+
+    The entry point here is .run() which will iterate over all the
+    statements in an ast.Module and for each ast.Assert statement it
+    finds call .visit() with it.  Then .visit_Assert() takes over and
+    is responsible for creating new ast statements to replace the
+    original assert statement: it rewrites the test of an assertion
+    to provide intermediate values and replace it with an if statement
+    which raises an assertion error with a detailed explanation in
+    case the expression is false.
+
+    For this .visit_Assert() uses the visitor pattern to visit all the
+    AST nodes of the ast.Assert.test field, each visit call returning
+    an AST node and the corresponding explanation string.  During this
+    state is kept in several instance attributes:
+
+    :statements: All the AST statements which will replace the assert
+       statement.
+
+    :variables: This is populated by .variable() with each variable
+       used by the statements so that they can all be set to None at
+       the end of the statements.
+
+    :variable_counter: Counter to create new unique variables needed
+       by statements.  Variables are created using .variable() and
+       have the form of "@py_assert0".
+
+    :on_failure: The AST statements which will be executed if the
+       assertion test fails.  This is the code which will construct
+       the failure message and raises the AssertionError.
+
+    :explanation_specifiers: A dict filled by .explanation_param()
+       with %-formatting placeholders and their corresponding
+       expressions to use in the building of an assertion message.
+       This is used by .pop_format_context() to build a message.
+
+    :stack: A stack of the explanation_specifiers dicts maintained by
+       .push_format_context() and .pop_format_context() which allows
+       to build another %-formatted string while already building one.
+
+    This state is reset on every new assert statement visited and used
+    by the other visitors.
+
+    """
+
+    def __init__(self, module_path):
+        super(AssertionRewriter, self).__init__()
+        self.module_path = module_path
+
+    def run(self, mod):
+        """Find all assert statements in *mod* and rewrite them."""
+        if not mod.body:
+            # Nothing to do.
+            return
+        # Insert some special imports at the top of the module but after any
+        # docstrings and __future__ imports.
+        if sys.version_info[0] >= 3:
+            builtin_name = 'builtins'
+        else:
+            builtin_name = '__builtin__'
+        aliases = [ast.alias(builtin_name, "@py_builtins"),
+                   ast.alias("ast_rewrite", "@pytest_ar")]
+        doc = getattr(mod, "docstring", None)
+        expect_docstring = doc is None
+        if doc is not None and self.is_rewrite_disabled(doc):
+            return
+        pos = 0
+        lineno = 1
+        for item in mod.body:
+            if (expect_docstring and isinstance(item, ast.Expr) and
+                    isinstance(item.value, ast.Str)):
+                doc = item.value.s
+                if self.is_rewrite_disabled(doc):
+                    return
+                expect_docstring = False
+            elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
+                  item.module != "__future__"):
+                lineno = item.lineno
+                break
+            pos += 1
+        else:
+            lineno = item.lineno
+        imports = [ast.Import([alias], lineno=lineno, col_offset=0)
+                   for alias in aliases]
+        mod.body[pos:pos] = imports
+        # Collect asserts.
+        nodes = [mod]
+        while nodes:
+            node = nodes.pop()
+            for name, field in ast.iter_fields(node):
+                if isinstance(field, list):
+                    new = []
+                    for i, child in enumerate(field):
+                        if isinstance(child, ast.Assert):
+                            # Transform assert.
+                            new.extend(self.visit(child))
+                        else:
+                            new.append(child)
+                            if isinstance(child, ast.AST):
+                                nodes.append(child)
+                    setattr(node, name, new)
+                elif (isinstance(field, ast.AST) and
+                      # Don't recurse into expressions as they can't contain
+                      # asserts.
+                      not isinstance(field, ast.expr)):
+                    nodes.append(field)
+
+    @staticmethod
+    def is_rewrite_disabled(docstring):
+        return "PYTEST_DONT_REWRITE" in docstring
+
+    def variable(self):
+        """Get a new variable."""
+        # Use a character invalid in python identifiers to avoid clashing.
+        name = "@py_assert" + str(next(self.variable_counter))
+        self.variables.append(name)
+        return name
+
+    def assign(self, expr):
+        """Give *expr* a name."""
+        name = self.variable()
+        self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
+        return ast.Name(name, ast.Load())
+
+    def display(self, expr):
+        """Call py.io.saferepr on the expression."""
+        return self.helper("saferepr", expr)
+
+    def helper(self, name, *args):
+        """Call a helper in this module."""
+        py_name = ast.Name("@pytest_ar", ast.Load())
+        attr = ast.Attribute(py_name, "_" + name, ast.Load())
+        return ast_Call(attr, list(args), [])
+
+    def builtin(self, name):
+        """Return the builtin called *name*."""
+        builtin_name = ast.Name("@py_builtins", ast.Load())
+        return ast.Attribute(builtin_name, name, ast.Load())
+
+    def explanation_param(self, expr):
+        """Return a new named %-formatting placeholder for expr.
+
+        This creates a %-formatting placeholder for expr in the
+        current formatting context, e.g. ``%(py0)s``.  The placeholder
+        and expr are placed in the current format context so that it
+        can be used on the next call to .pop_format_context().
+
+        """
+        specifier = "py" + str(next(self.variable_counter))
+        self.explanation_specifiers[specifier] = expr
+        return "%(" + specifier + ")s"
+
+    def push_format_context(self):
+        """Create a new formatting context.
+
+        The format context is used for when an explanation wants to
+        have a variable value formatted in the assertion message.  In
+        this case the value required can be added using
+        .explanation_param().  Finally .pop_format_context() is used
+        to format a string of %-formatted values as added by
+        .explanation_param().
+
+        """
+        self.explanation_specifiers = {}
+        self.stack.append(self.explanation_specifiers)
+
+    def pop_format_context(self, expl_expr):
+        """Format the %-formatted string with current format context.
+
+        The expl_expr should be an ast.Str instance constructed from
+        the %-placeholders created by .explanation_param().  This will
+        add the required code to format said string to .on_failure and
+        return the ast.Name instance of the formatted string.
+
+        """
+        current = self.stack.pop()
+        if self.stack:
+            self.explanation_specifiers = self.stack[-1]
+        keys = [ast.Str(key) for key in current.keys()]
+        format_dict = ast.Dict(keys, list(current.values()))
+        form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
+        name = "@py_format" + str(next(self.variable_counter))
+        self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
+        return ast.Name(name, ast.Load())
+
+    def generic_visit(self, node):
+        """Handle expressions we don't have custom code for."""
+        assert isinstance(node, ast.expr)
+        res = self.assign(node)
+        return res, self.explanation_param(self.display(res))
+
+    def visit_Assert(self, assert_):
+        """Return the AST statements to replace the ast.Assert instance.
+
+        This rewrites the test of an assertion to provide
+        intermediate values and replace it with an if statement which
+        raises an assertion error with a detailed explanation in case
+        the expression is false.
+
+        """
+        self.statements = []
+        self.variables = []
+        self.variable_counter = itertools.count()
+        self.stack = []
+        self.on_failure = []
+        self.push_format_context()
+        # Rewrite assert into a bunch of statements.
+        top_condition, explanation = self.visit(assert_.test)
+        # Create failure message.
+        body = self.on_failure
+        negation = ast.UnaryOp(ast.Not(), top_condition)
+        self.statements.append(ast.If(negation, body, []))
+        if assert_.msg:
+            assertmsg = self.helper('format_assertmsg', assert_.msg)
+            explanation = "\n>assert " + explanation
+        else:
+            assertmsg = ast.Str("")
+            explanation = "assert " + explanation
+        template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
+        msg = self.pop_format_context(template)
+        fmt = self.helper("format_explanation", msg)
+        err_name = ast.Name("AssertionError", ast.Load())
+        exc = ast_Call(err_name, [fmt], [])
+        if sys.version_info[0] >= 3:
+            raise_ = ast.Raise(exc, None)
+        else:
+            raise_ = ast.Raise(exc, None, None)
+        body.append(raise_)
+        # Clear temporary variables by setting them to None.
+        if self.variables:
+            variables = [ast.Name(name, ast.Store())
+                         for name in self.variables]
+            clear = ast.Assign(variables, _NameConstant(None))
+            self.statements.append(clear)
+        # Fix line numbers.
+        for stmt in self.statements:
+            set_location(stmt, assert_.lineno, assert_.col_offset)
+        return self.statements
+
+    def visit_Name(self, name):
+        # Display the repr of the name if it's a local variable or
+        # _should_repr_global_name() thinks it's acceptable.
+        locs = ast_Call(self.builtin("locals"), [], [])
+        inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
+        dorepr = self.helper("should_repr_global_name", name)
+        test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
+        expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
+        return name, self.explanation_param(expr)
+
+    def visit_BoolOp(self, boolop):
+        res_var = self.variable()
+        expl_list = self.assign(ast.List([], ast.Load()))
+        app = ast.Attribute(expl_list, "append", ast.Load())
+        is_or = int(isinstance(boolop.op, ast.Or))
+        body = save = self.statements
+        fail_save = self.on_failure
+        levels = len(boolop.values) - 1
+        self.push_format_context()
+        # Process each operand, short-circuting if needed.
+        for i, v in enumerate(boolop.values):
+            if i:
+                fail_inner = []
+                # cond is set in a prior loop iteration below
+                self.on_failure.append(ast.If(cond, fail_inner, []))  # noqa
+                self.on_failure = fail_inner
+            self.push_format_context()
+            res, expl = self.visit(v)
+            body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
+            expl_format = self.pop_format_context(ast.Str(expl))
+            call = ast_Call(app, [expl_format], [])
+            self.on_failure.append(ast.Expr(call))
+            if i < levels:
+                cond = res
+                if is_or:
+                    cond = ast.UnaryOp(ast.Not(), cond)
+                inner = []
+                self.statements.append(ast.If(cond, inner, []))
+                self.statements = body = inner
+        self.statements = save
+        self.on_failure = fail_save
+        expl_template = self.helper("format_boolop", expl_list, ast.Num(is_or))
+        expl = self.pop_format_context(expl_template)
+        return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
+
+    def visit_UnaryOp(self, unary):
+        pattern = unary_map[unary.op.__class__]
+        operand_res, operand_expl = self.visit(unary.operand)
+        res = self.assign(ast.UnaryOp(unary.op, operand_res))
+        return res, pattern % (operand_expl,)
+
+    def visit_BinOp(self, binop):
+        symbol = binop_map[binop.op.__class__]
+        left_expr, left_expl = self.visit(binop.left)
+        right_expr, right_expl = self.visit(binop.right)
+        explanation = "(%s %s %s)" % (left_expl, symbol, right_expl)
+        res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
+        return res, explanation
+
+    def visit_Call_35(self, call):
+        """
+        visit `ast.Call` nodes on Python3.5 and after
+        """
+        new_func, func_expl = self.visit(call.func)
+        arg_expls = []
+        new_args = []
+        new_kwargs = []
+        for arg in call.args:
+            res, expl = self.visit(arg)
+            arg_expls.append(expl)
+            new_args.append(res)
+        for keyword in call.keywords:
+            res, expl = self.visit(keyword.value)
+            new_kwargs.append(ast.keyword(keyword.arg, res))
+            if keyword.arg:
+                arg_expls.append(keyword.arg + "=" + expl)
+            else:  # **args have `arg` keywords with an .arg of None
+                arg_expls.append("**" + expl)
+
+        expl = "%s(%s)" % (func_expl, ', '.join(arg_expls))
+        new_call = ast.Call(new_func, new_args, new_kwargs)
+        res = self.assign(new_call)
+        res_expl = self.explanation_param(self.display(res))
+        outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
+        return res, outer_expl
+
+    def visit_Starred(self, starred):
+        # From Python 3.5, a Starred node can appear in a function call
+        res, expl = self.visit(starred.value)
+        return starred, '*' + expl
+
+    def visit_Call_legacy(self, call):
+        """
+        visit `ast.Call nodes on 3.4 and below`
+        """
+        new_func, func_expl = self.visit(call.func)
+        arg_expls = []
+        new_args = []
+        new_kwargs = []
+        new_star = new_kwarg = None
+        for arg in call.args:
+            res, expl = self.visit(arg)
+            new_args.append(res)
+            arg_expls.append(expl)
+        for keyword in call.keywords:
+            res, expl = self.visit(keyword.value)
+            new_kwargs.append(ast.keyword(keyword.arg, res))
+            arg_expls.append(keyword.arg + "=" + expl)
+        if call.starargs:
+            new_star, expl = self.visit(call.starargs)
+            arg_expls.append("*" + expl)
+        if call.kwargs:
+            new_kwarg, expl = self.visit(call.kwargs)
+            arg_expls.append("**" + expl)
+        expl = "%s(%s)" % (func_expl, ', '.join(arg_expls))
+        new_call = ast.Call(new_func, new_args, new_kwargs,
+                            new_star, new_kwarg)
+        res = self.assign(new_call)
+        res_expl = self.explanation_param(self.display(res))
+        outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
+        return res, outer_expl
+
+    # ast.Call signature changed on 3.5,
+    # conditionally change  which methods is named
+    # visit_Call depending on Python version
+    if sys.version_info >= (3, 5):
+        visit_Call = visit_Call_35
+    else:
+        visit_Call = visit_Call_legacy
+
+    def visit_Attribute(self, attr):
+        if not isinstance(attr.ctx, ast.Load):
+            return self.generic_visit(attr)
+        value, value_expl = self.visit(attr.value)
+        res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
+        res_expl = self.explanation_param(self.display(res))
+        pat = "%s\n{%s = %s.%s\n}"
+        expl = pat % (res_expl, res_expl, value_expl, attr.attr)
+        return res, expl
+
+    def visit_Compare(self, comp):
+        self.push_format_context()
+        left_res, left_expl = self.visit(comp.left)
+        if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
+            left_expl = "({0})".format(left_expl)
+        res_variables = [self.variable() for i in range(len(comp.ops))]
+        load_names = [ast.Name(v, ast.Load()) for v in res_variables]
+        store_names = [ast.Name(v, ast.Store()) for v in res_variables]
+        it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
+        expls = []
+        syms = []
+        results = [left_res]
+        for i, op, next_operand in it:
+            next_res, next_expl = self.visit(next_operand)
+            if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
+                next_expl = "({0})".format(next_expl)
+            results.append(next_res)
+            sym = binop_map[op.__class__]
+            syms.append(ast.Str(sym))
+            expl = "%s %s %s" % (left_expl, sym, next_expl)
+            expls.append(ast.Str(expl))
+            res_expr = ast.Compare(left_res, [op], [next_res])
+            self.statements.append(ast.Assign([store_names[i]], res_expr))
+            left_res, left_expl = next_res, next_expl
+        # Use pytest.assertion.util._reprcompare if that's available.
+        expl_call = self.helper("call_reprcompare",
+                                ast.Tuple(syms, ast.Load()),
+                                ast.Tuple(load_names, ast.Load()),
+                                ast.Tuple(expls, ast.Load()),
+                                ast.Tuple(results, ast.Load()))
+        if len(comp.ops) > 1:
+            res = ast.BoolOp(ast.And(), load_names)
+        else:
+            res = load_names[0]
+        return res, self.explanation_param(self.pop_format_context(expl_call))
diff --git a/pypy/tool/pytest/ast-rewriter/ast_util.py b/pypy/tool/pytest/ast-rewriter/ast_util.py
new file mode 100644
--- /dev/null
+++ b/pypy/tool/pytest/ast-rewriter/ast_util.py
@@ -0,0 +1,289 @@
+"""Utilities for assertion debugging"""
+from __future__ import absolute_import, division, print_function
+
+# Hook slot consulted by the assertion rewriting code (compare the
+# "Use pytest.assertion.util._reprcompare if that's available" note in the
+# rewriter) to find custom comparison representations; None here.
+_reprcompare = None
+
+# Text constructor used to build every explanation string in this module.
+u = str
+
+
+def format_explanation(explanation):
+    """Turn an encoded explanation string into printable text.
+
+    Embedded newlines are normally escaped (shown as a literal
+    backslash-n).  A newline stays a real line break only when the next
+    character is one of the mini-language markers:
+
+    * ``{`` opens a nested explanation (rendered as "where ..."/"and ...");
+      see visit_Call() and visit_Attribute() for producers,
+    * ``}`` closes a nested explanation,
+    * ``~`` starts a continuation line, e.g. one line of a diff,
+    * ``>`` starts a line indented one level less than ``~``.
+    """
+    return '\n'.join(_format_lines(_split_explanation(explanation)))
+
+
+def _split_explanation(explanation):
+    """Split *explanation* into logical lines.
+
+    A newline followed by one of the markers ``{``, ``}``, ``~`` or ``>``
+    starts a new logical line, with the marker kept at its start.  Every
+    other newline is replaced by the two characters backslash and ``n``
+    and glued onto the current line.  A falsy *explanation* is treated
+    as ``''``.
+    """
+    pieces = (explanation or '').split('\n')
+    merged = pieces[:1]
+    for piece in pieces[1:]:
+        # piece[:1] is '' for an empty piece, which matches no marker.
+        if piece[:1] in ('{', '}', '~', '>'):
+            merged.append(piece)
+        else:
+            merged[-1] = merged[-1] + '\\n' + piece
+    return merged
+
+
+def _format_lines(lines):
+    """Render the logical lines produced by ``_split_explanation``.
+
+    The first line is copied verbatim.  For each following line the
+    leading marker decides the layout:
+
+    * ``{`` opens a nested level, emitting `` +`` plus indentation and
+      ``where `` (first child at that level) or ``and   `` (later ones);
+    * ``}`` closes the level, appending any remaining text to the line
+      tracked for the enclosing level;
+    * ``~`` / ``>`` emit an indented line (``>`` one level shallower).
+
+    Returns the list of formatted lines.  An unclosed ``{`` trips the
+    final ``assert``.
+    """
+    formatted = lines[:1]
+    # One entry per open nesting level; after a '}' pops a level, its
+    # trailing text is appended to formatted[anchors[-1]].
+    anchors = [0]
+    # children[k]: number of '{' blocks already opened at depth k, which
+    # selects "where" (first) versus "and" (subsequent).
+    children = [0]
+    for line in lines[1:]:
+        marker, rest = line[:1], line[1:]
+        if marker == '{':
+            word = 'and   ' if children[-1] else 'where '
+            anchors.append(len(formatted))
+            children[-1] += 1
+            children.append(0)
+            formatted.append(' +' + '  ' * (len(anchors) - 1) + word + rest)
+        elif marker == '}':
+            anchors.pop()
+            children.pop()
+            formatted[anchors[-1]] += rest
+        else:
+            assert line[0] in ['~', '>']
+            anchors[-1] += 1
+            depth = len(anchors) - (0 if marker == '~' else 1)
+            formatted.append('  ' * depth + rest)
+    assert len(anchors) == 1
+    return formatted
+
+
+# Provide a module-level ``basestring`` on both Python versions: on Python 2
+# the lookup on the right finds the builtin (base of str and unicode); on
+# Python 3 it raises NameError and plain ``str`` is used instead.
+try:
+    basestring = basestring  # Python 2 builtin; NameError on Python 3
+except NameError:
+    basestring = str
+
+
+def saferepr(obj, maxsize=None):
+    """Return ``repr(obj)``, optionally truncated, tolerating broken reprs.
+
+    :param obj: any object.
+    :param maxsize: if not ``None``, keep at most this many characters.
+    :return: the (possibly truncated) repr string.  If ``repr()`` raises,
+        a placeholder naming the exception type and the object's type is
+        returned instead, so a faulty ``__repr__`` cannot break assertion
+        reporting (callers such as assertrepr_compare call this outside
+        their own error handling).
+    """
+    try:
+        s = repr(obj)
+    except Exception as exc:
+        # Deliberately broad: arbitrary user __repr__ code runs above.
+        # Only type names and id() are used, so building the placeholder
+        # does not run more user code.
+        s = '<[%s raised in repr()] %s object at 0x%x>' % (
+            type(exc).__name__, type(obj).__name__, id(obj))
+    if maxsize is not None:
+        s = s[:maxsize]
+    return s
+
+
+def assertrepr_compare(op, left, right, verbose=False):
+    """Return specialised explanations for some operators/operands.
+
+    :param op: the comparison operator as source text, e.g. ``'=='``.
+    :param left: left operand of the failed comparison.
+    :param right: right operand of the failed comparison.
+    :param verbose: verbosity level; when falsy, long diffs are abbreviated,
+        and below 2 identical dict items are only counted.
+    :return: a list of explanation lines starting with a one-line summary,
+        or ``None`` when there is nothing specific to add.
+    """
+    width = 80 - 15 - len(op) - 2  # 15 chars indentation, 1 space around op
+    left_repr = saferepr(left, maxsize=int(width // 2))
+    right_repr = saferepr(right, maxsize=width - len(left_repr))
+
+    summary = u('%s %s %s') % (left_repr, op, right_repr)
+
+    def istext(x):
+        return isinstance(x, basestring)
+
+    def isdict(x):
+        return isinstance(x, dict)
+
+    def isset(x):
+        return isinstance(x, (set, frozenset))
+
+    def issequence(x):
+        # _compare_eq_sequence needs len() and positional indexing, so
+        # require both.  Checking for __iter__ alone would also accept sets
+        # (not indexable) and dicts (indexed by key, not position), which
+        # have their own comparisons below.
+        return (hasattr(x, '__getitem__') and hasattr(x, '__len__') and
+                not istext(x) and not isdict(x) and not isset(x))
+
+    def isiterable(obj):
+        try:
+            iter(obj)
+            return not istext(obj)
+        except TypeError:
+            return False
+
+    explanation = None
+    try:
+        if op == '==':
+            if istext(left) and istext(right):
+                explanation = _diff_text(left, right, verbose)
+            else:
+                if issequence(left) and issequence(right):
+                    explanation = _compare_eq_sequence(left, right, verbose)
+                elif isset(left) and isset(right):
+                    explanation = _compare_eq_set(left, right, verbose)
+                elif isdict(left) and isdict(right):
+                    explanation = _compare_eq_dict(left, right, verbose)
+                if isiterable(left) and isiterable(right):
+                    expl = _compare_eq_iterable(left, right, verbose)
+                    if explanation is not None:
+                        explanation.extend(expl)
+                    else:
+                        explanation = expl
+        elif op == 'not in':
+            if istext(left) and istext(right):
+                explanation = _notin_text(left, right, verbose)
+    except Exception as err:
+        explanation = [
+            '(pytest assertion: representation of details failed.  '
+            'Probably an object has a faulty __repr__.)', str(err)]
+
+    if not explanation:
+        return None
+
+    return [summary] + explanation
+
+
+def _diff_text(left, right, verbose=False):
+    """Return the explanation for the diff between two strings.
+
+    Unless *verbose* is true, a long run of identical leading characters
+    is skipped.  When both strings then have the same length, a long run
+    of identical trailing characters is skipped too.  About 10 characters
+    of context are kept at each cut, and a note records how many were
+    skipped.
+
+    Both arguments must be the same string type; no bytes-to-text
+    conversion is done here.
+
+    :return: list of explanation lines (skip notes, then ndiff output with
+        newlines stripped).
+    """
+    from difflib import ndiff
+    explanation = []
+    if not verbose:
+        i = 0  # just in case left or right has zero length
+        for i in range(min(len(left), len(right))):
+            if left[i] != right[i]:
+                break
+        if i > 42:
+            i -= 10                 # Provide some context
+            explanation = ['Skipping %s identical leading '
+                           'characters in diff, use -v to show' % i]
+            left = left[i:]
+            right = right[i:]
+        if len(left) == len(right):
+            # Count identical trailing characters.  Indexing with -1 - n
+            # makes the first comparison the last character; ``left[-n]``
+            # would start at left[0] and stop at once whenever the strings
+            # differ in their first character.
+            n = 0
+            while n < len(left) and left[-1 - n] == right[-1 - n]:
+                n += 1
+            if n > 42:
+                n -= 10     # Provide some context
+                explanation += ['Skipping %s identical trailing '
+                                'characters in diff, use -v to show' % n]
+                left = left[:-n]
+                right = right[:-n]
+    keepends = True
+    explanation += [line.strip('\n')
+                    for line in ndiff(left.splitlines(keepends),
+                                      right.splitlines(keepends))]
+    return explanation
+
+
+def _compare_eq_iterable(left, right, verbose=False):
+    """Diff the pretty-printed forms of two iterables line by line.
+
+    Without *verbose*, only a hint to rerun with ``-v`` is returned.
+    """
+    if not verbose:
+        return ['Use -v to get the full diff']
+    # dynamic import to speedup pytest
+    import difflib
+    import pprint
+    try:
+        left_lines = pprint.pformat(left).splitlines()
+        right_lines = pprint.pformat(right).splitlines()
+    except Exception:
+        # PrettyPrinter.pformat() on Python 2 fails for containers whose
+        # items cannot be sorted (see issue #718).  Fall back to diffing
+        # the sorted repr() of each item instead.
+        header = 'Full diff (fallback to calling repr on each item):'
+        left_lines = sorted(repr(x) for x in left)
+        right_lines = sorted(repr(x) for x in right)
+    else:
+        header = 'Full diff:'
+    return [header] + [line.strip()
+                       for line in difflib.ndiff(left_lines, right_lines)]
+
+
+def _compare_eq_sequence(left, right, verbose=False):
+    """Explain how two sequences differ.
+
+    Reports the first index within the shorter length whose items differ.
+    If the lengths differ, it also reports the first extra item of the
+    longer sequence.  *verbose* is unused here.
+    """
+    explanation = []
+    shared = min(len(left), len(right))
+    for idx in range(shared):
+        if left[idx] != right[idx]:
+            explanation.append(u('At index %s diff: %r != %r')
+                               % (idx, left[idx], right[idx]))
+            break
+    if len(left) > len(right):
+        explanation.append(u('Left contains more items, first extra item: %s')
+                           % saferepr(left[len(right)]))
+    elif len(left) < len(right):
+        explanation.append(u('Right contains more items, first extra item: %s')
+                           % saferepr(right[len(left)]))
+    return explanation
+
+
+def _compare_eq_set(left, right, verbose=False):
+    """List the items present in only one of two sets.
+
+    Each non-empty side gets a heading followed by one repr per item.
+    *verbose* is unused here.
+    """
+    explanation = []
+    # Both differences are computed before anything is appended.
+    sides = (('Extra items in the left set:', left - right),
+             ('Extra items in the right set:', right - left))
+    for title, extra in sides:
+        if extra:
+            explanation.append(u(title))
+            explanation.extend(saferepr(item) for item in extra)
+    return explanation
+
+
+def _compare_eq_dict(left, right, verbose=False):
+    """Explain how two dicts differ.
+
+    Reports identical items (only their count unless ``verbose >= 2``),
+    keys whose values differ, and keys present on one side only.
+    """
+    import pprint
+    left_keys = set(left)
+    right_keys = set(right)
+    common = left_keys.intersection(right_keys)
+    explanation = []
+    same = dict((k, left[k]) for k in common if left[k] == right[k])
+    if same:
+        if verbose < 2:
+            explanation.append(
+                u('Omitting %s identical items, use -vv to show') % len(same))
+        else:
+            explanation.append(u('Common items:'))
+            explanation.extend(pprint.pformat(same).splitlines())
+    diff = set(k for k in common if left[k] != right[k])
+    if diff:
+        explanation.append(u('Differing items:'))
+        explanation.extend(saferepr({k: left[k]}) + ' != ' +
+                           saferepr({k: right[k]}) for k in diff)
+    one_sided = (
+        ('Left contains more items:', left, left_keys - right_keys),
+        ('Right contains more items:', right, right_keys - left_keys),
+    )
+    for title, source, only_here in one_sided:
+        if only_here:
+            explanation.append(u(title))
+            explanation.extend(pprint.pformat(
+                dict((k, source[k]) for k in only_here)).splitlines())
+    return explanation
+
+
+def _notin_text(term, text, verbose=False):
+    """Explain a failed ``term not in text`` assertion.
+
+    *text* is diffed against a copy with the first occurrence of *term*
+    removed.  In the diff output:
+
+    * ``Skipping`` notes and ``- `` lines are dropped;
+    * ``+ `` lines are shown indented by two spaces;
+    * any other line (such as ndiff ``?`` guide lines) is kept unchanged.
+
+    This assumes *term* occurs in *text*; otherwise ``find`` returns -1
+    and the result is not meaningful.
+    """
+    pos = text.find(term)
+    stripped = text[:pos] + text[pos + len(term):]
+    diff = _diff_text(stripped, text, verbose)
+    result = [u('%s is contained here:') % saferepr(term, maxsize=42)]
+    for line in diff:
+        if line.startswith(u('Skipping')) or line.startswith(u('- ')):
+            continue
+        result.append(u('  ') + line[2:] if line.startswith(u('+ ')) else line)
+    return result
diff --git a/pypy/tool/pytest/fake_pytest/__init__.py b/pypy/tool/pytest/fake_pytest/__init__.py
new file mode 100644
--- /dev/null
+++ b/pypy/tool/pytest/fake_pytest/__init__.py
@@ -0,0 +1,12 @@
+from pypy.interpreter.mixedmodule import MixedModule
+
+class Module(MixedModule):
+    # App-level module installed under the name ``pytest``: a small subset
+    # of the pytest API backed by the implementations listed below.
+    applevel_name = 'pytest'
+
+    # Names implemented at interp-level, in interp_pytest.py.
+    interpleveldefs = dict(
+        raises='interp_pytest.pypyraises',
+        skip='interp_pytest.pypyskip',
+        fixture='interp_pytest.fake_fixture',
+    )
+
+    # Names implemented in app-level Python, in app_pytest.py.
+    appleveldefs = dict(
+        importorskip='app_pytest.importorskip',
+    )
diff --git a/pypy/tool/pytest/fake_pytest/app_pytest.py b/pypy/tool/pytest/fake_pytest/app_pytest.py
new file mode 100644
--- /dev/null
+++ b/pypy/tool/pytest/fake_pytest/app_pytest.py
@@ -0,0 +1,8 @@
+import pytest
+
+
+def importorskip(name):
+    """Import and return the module *name*, or skip the current test.
+
+    :param name: module name.  Dotted names such as ``'os.path'`` return
+        the submodule itself, not its top-level package.
+    :return: the imported module.
+
+    Calls ``pytest.skip`` when the import raises ImportError.
+    """
+    import sys
+    try:
+        __import__(name)
+    except ImportError:
+        pytest.skip('Module %s not available' % name)
+    else:
+        # __import__('a.b') returns the top-level package ``a``; the
+        # requested (sub)module is the entry registered in sys.modules.
+        return sys.modules[name]
diff --git a/pypy/tool/pytest/fake_pytest/interp_pytest.py b/pypy/tool/pytest/fake_pytest/interp_pytest.py
new file mode 100644
--- /dev/null
+++ b/pypy/tool/pytest/fake_pytest/interp_pytest.py
@@ -0,0 +1,4 @@
+from pypy.tool.pytest.appsupport import pypyraises, pypyskip
+
+def fake_fixture(space, w_arg):
+    # Interp-level stand-in for ``pytest.fixture`` (registered as 'fixture'
+    # in fake_pytest's interpleveldefs).  It returns the wrapped argument
+    # unchanged, so decorating with it leaves the decorated object as-is.
+    # NOTE(review): the signature takes a single argument, so presumably only
+    # the bare ``@pytest.fixture`` form works, not ``@pytest.fixture(...)``.
+    return w_arg


More information about the pypy-commit mailing list