--- a/_pytest/assertion/__init__.py	Tue Jun 28 21:11:56 2011 -0500
+++ b/_pytest/assertion/__init__.py	Tue Jun 28 21:13:12 2011 -0500
@@ -2,20 +2,12 @@
 support for presenting detailed information in failing assertions.
 import py
-import imp
-import marshal
-import struct
 import sys
 import pytest
 from _pytest.monkeypatch import monkeypatch
-from _pytest.assertion import reinterpret, util
+from _pytest.assertion import util
-    from _pytest.assertion.rewrite import rewrite_asserts
-except ImportError:
-    rewrite_asserts = None
-    import ast
+# Assertion rewriting (rewrite.py) is built on the ast module, which needs
+# the interpreter's builtin _ast module; without it fall back to "old" mode.
+REWRITING_AVAILABLE = "_ast" in sys.builtin_module_names
 def pytest_addoption(parser):
     group = parser.getgroup("debugconfig")
@@ -38,9 +30,9 @@
     def __init__(self, config, mode):
         self.mode = mode
         self.trace = config.trace.root.get("assertion")
+        # Paths of rewritten pycs written by the import hook during this run;
+        # rewrite._drain_pycs() deletes them again.
+        self.pycs = []
 def pytest_configure(config):
-    warn_about_missing_assertion()
     mode = config.getvalue("assertmode")
     if config.getvalue("noassert") or config.getvalue("nomagic"):
         if mode not in ("off", "default"):
@@ -48,7 +40,10 @@
         mode = "off"
     elif mode == "default":
         mode = "on"
+    if mode == "on" and not REWRITING_AVAILABLE:
+        mode = "old"
     if mode != "off":
+        _load_modules(mode)
         def callbinrepr(op, left, right):
             hook_result = config.hook.pytest_assertrepr_compare(
                 config=config, op=op, left=left, right=right)
@@ -60,69 +55,55 @@
         m.setattr(py.builtin.builtins, 'AssertionError',
         m.setattr(util, '_reprcompare', callbinrepr)
-    if mode == "on" and rewrite_asserts is None:
-        mode = "old"
+    hook = None
+    if mode == "on":
+        hook = rewrite.AssertionRewritingHook()
+        sys.meta_path.append(hook)
+    warn_about_missing_assertion(mode)
     config._assertstate = AssertionState(config, mode)
+    config._assertstate.hook = hook
     config._assertstate.trace("configured with mode set to %r" % (mode,))
-def _write_pyc(co, source_path):
-    if hasattr(imp, "cache_from_source"):
-        # Handle PEP 3147 pycs.
-        pyc = py.path.local(imp.cache_from_source(str(source_path)))
-        pyc.ensure()
-    else:
-        pyc = source_path + "c"
-    mtime = int(source_path.mtime())
-    fp = pyc.open("wb")
-    try:
-        fp.write(imp.get_magic())
-        fp.write(struct.pack("<l", mtime))
-        marshal.dump(co, fp)
-    finally:
-        fp.close()
-    return pyc
+def pytest_unconfigure(config):
+    """Delete leftover rewritten pycs and uninstall the rewriting hook."""
+    state = config._assertstate
+    if state.mode == "on":
+        rewrite._drain_pycs(state)
+    installed = state.hook
+    if installed is not None:
+        sys.meta_path.remove(installed)
-def before_module_import(mod):
-    if mod.config._assertstate.mode != "on":
-        return
-    # Some deep magic: load the source, rewrite the asserts, and write a
-    # fake pyc, so that it'll be loaded when the module is imported.
-    source = mod.fspath.read()
-    try:
-        tree = ast.parse(source)
-    except SyntaxError:
-        # Let this pop up again in the real import.
-        mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,))
-        return
-    rewrite_asserts(tree)
-    try:
-        co = compile(tree, str(mod.fspath), "exec")
-    except SyntaxError:
-        # It's possible that this error is from some bug in the assertion
-        # rewriting, but I don't know of a fast way to tell.
-        mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,))
-        return
-    mod._pyc = _write_pyc(co, mod.fspath)
-    mod.config._assertstate.trace("wrote pyc: %r" % (mod._pyc,))
+def pytest_sessionstart(session):
+    """Give the new session to the rewriting hook, if one is installed."""
+    hook = session.config._assertstate.hook
+    if hook is None:
+        return
+    hook.set_session(session)
-def after_module_import(mod):
-    if not hasattr(mod, "_pyc"):
-        return
-    state = mod.config._assertstate
-    try:
-        mod._pyc.remove()
-    except py.error.ENOENT:
-        state.trace("couldn't find pyc: %r" % (mod._pyc,))
-    else:
-        state.trace("removed pyc: %r" % (mod._pyc,))
+def pytest_sessionfinish(session):
+    """Remove rewritten pycs and detach the hook from the finished session."""
+    state = session.config._assertstate
+    if state.mode == "on":
+        rewrite._drain_pycs(state)
+    if state.hook is not None:
+        state.hook.session = None
-def warn_about_missing_assertion():
+def _load_modules(mode):
+    """Lazily import assertion related code into module globals.
+
+    ``reinterpret`` is always imported; ``rewrite`` only when *mode* is "on".
+    """
+    global rewrite, reinterpret
+    from _pytest.assertion import reinterpret
+    if mode != "on":
+        return
+    from _pytest.assertion import rewrite
+def warn_about_missing_assertion(mode):
+    """Write a warning to stderr when ``assert`` statements are disabled.
+
+    ``assert False`` raises only while assertions are active; if it is
+    silently skipped (e.g. under ``python -O``) the ``else`` branch warns.
+    For *mode* "on" the message says asserts outside test modules are
+    ignored; for any other mode it says failing tests may report as passing.
+    """
     try:
         assert False
     except AssertionError:
         pass
     else:
-        sys.stderr.write("WARNING: failing tests may report as passing because "
-        "assertions are turned off!  (are you using python -O?)\n")
+        if mode == "on":
+            specifically = ("assertions which are not in test modules "
+                            "will be ignored")
+        else:
+            specifically = "failing tests may report as passing"
+        sys.stderr.write("WARNING: " + specifically +
+                        " because assertions are turned off "
+                        "(are you using python -O?)\n")
+# util.assertrepr_compare serves as this plugin's implementation of the
+# pytest_assertrepr_compare hook.
 pytest_assertrepr_compare = util.assertrepr_compare

--- a/_pytest/assertion/rewrite.py	Tue Jun 28 21:11:56 2011 -0500
+++ b/_pytest/assertion/rewrite.py	Tue Jun 28 21:13:12 2011 -0500
@@ -3,12 +3,173 @@
 import ast
 import collections
 import itertools
+import imp
+import marshal
+import os
+import struct
 import sys
 import py
 from _pytest.assertion import util
+# py.test caches rewritten pycs in __pycache__.
+if hasattr(imp, "get_tag"):
+    PYTEST_TAG = imp.get_tag() + "-PYTEST"
+else:
+    # imp.get_tag() does not exist before Python 3.2; spell out an
+    # equivalent CPython tag so cache file names stay version specific.
+    ver = sys.version_info
+    PYTEST_TAG = "cpython-" + str(ver[0]) + str(ver[1]) + "-PYTEST"
+    del ver
+class AssertionRewritingHook(object):
+    """Import hook which rewrites asserts.
+
+    Note this hook doesn't load modules itself. It uses find_module to write a
+    fake pyc, so the normal import system will find it.
+    """
+    def __init__(self):
+        self.session = None
+    def set_session(self, session):
+        self.fnpats = session.config.getini("python_files")
+        self.session = session
+    def find_module(self, name, path=None):
+        """Write a rewritten pyc for *name* if it is a test module.
+
+        Always returns None so the regular import machinery loads the
+        module (picking up the fake pyc when one was written).
+        """
+        if self.session is None:
+            return None
+        sess = self.session
+        state = sess.config._assertstate
+        names = name.rsplit(".", 1)
+        lastname = names[-1]
+        pth = None
+        if path is not None and len(path) == 1:
+            pth = path[0]
+        if pth is None:
+            try:
+                fd, fn, desc = imp.find_module(lastname, path)
+            except ImportError:
+                return None
+            if fd is not None:
+                fd.close()
+            tp = desc[2]
+            if tp == imp.PY_COMPILED:
+                if hasattr(imp, "source_from_cache"):
+                    fn = imp.source_from_cache(fn)
+                else:
+                    fn = fn[:-1]
+            elif tp != imp.PY_SOURCE:
+                # Don't know what this is.
+                return None
+        else:
+            # Inside the package directory only the last component of the
+            # dotted name is a file name ("pkg.test_x" -> "test_x.py").
+            fn = os.path.join(pth, lastname + ".py")
+            if not os.path.isfile(fn):
+                # Subpackage, extension module, etc.: no source to rewrite,
+                # so don't create placeholder pyc/cache files for it.
+                return None
+        fn_pypath = py.path.local(fn)
+        # Is this a test file?
+        if not sess.isinitpath(fn):
+            # We have to be very careful here because imports in this code can
+            # trigger a cycle.
+            self.session = None
+            try:
+                for pat in self.fnpats:
+                    if fn_pypath.fnmatch(pat):
+                        break
+                else:
+                    return None
+            finally:
+                self.session = sess
+        # This looks like a test file, so rewrite it. This is the most magical
+        # part of the process: load the source, rewrite the asserts, and write a
+        # fake pyc, so that it'll be loaded when the module is imported. This is
+        # complicated by the fact we cache rewritten pycs.
+        pyc = _compute_pyc_location(fn_pypath)
+        state.pycs.append(pyc)
+        cache_fn = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc"
+        cache = py.path.local(fn_pypath.dirname).join("__pycache__", cache_fn)
+        if _use_cached_pyc(fn_pypath, cache):
+            state.trace("found cached rewritten pyc for %r" % (fn,))
+            cache.copy(pyc)
+        else:
+            state.trace("rewriting %r" % (fn,))
+            _make_rewritten_pyc(state, fn_pypath, pyc)
+            # Try to cache it in the __pycache__ directory.
+            _cache_pyc(state, pyc, cache)
+        return None
+def _drain_pycs(state):
+    """Delete every rewritten pyc recorded in ``state.pycs`` and forget them.
+
+    The assertion plugin calls this from both pytest_sessionfinish and
+    pytest_unconfigure; emptying the list makes the second call a no-op
+    instead of re-tracing every already-removed path as missing.
+    """
+    for pyc in state.pycs:
+        try:
+            pyc.remove()
+        except py.error.ENOENT:
+            state.trace("couldn't find pyc: %r" % (pyc,))
+        else:
+            state.trace("removed pyc: %r" % (pyc,))
+    # Slice deletion instead of list.clear() keeps Python 2 compatibility.
+    del state.pycs[:]
+def _write_pyc(co, source_path, pyc):
+    """Write code object *co* to *pyc* in compiled-module format.
+
+    The file starts with the interpreter's magic number and the source's
+    mtime (truncated to an int, packed as little-endian "<l"), followed by
+    the marshalled code object.
+    """
+    header = imp.get_magic() + struct.pack("<l", int(source_path.mtime()))
+    fp = pyc.open("wb")
+    try:
+        fp.write(header)
+        marshal.dump(co, fp)
+    finally:
+        fp.close()
+def _make_rewritten_pyc(state, fn, pyc):
+    """Parse *fn*, rewrite its asserts and write the compiled code to *pyc*.
+
+    Always returns None.  If the source cannot be read, parsed or compiled,
+    nothing is written (parse and compile failures are traced) and the
+    problem is left for the real import to report.
+    """
+    try:
+        src = fn.read("rb")
+    except EnvironmentError:
+        return None
+    try:
+        mod_ast = ast.parse(src)
+    except SyntaxError:
+        # Leave it to the real import to raise the syntax error again.
+        state.trace("failed to parse: %r" % (fn,))
+        return None
+    rewrite_asserts(mod_ast)
+    try:
+        code = compile(mod_ast, fn.strpath, "exec")
+    except SyntaxError:
+        # Possibly caused by the rewriting itself; there is no cheap way to
+        # tell, so just skip writing the pyc.
+        state.trace("failed to compile: %r" % (fn,))
+    else:
+        _write_pyc(code, fn, pyc)
+def _compute_pyc_location(source_path):
+    """Return the path where the interpreter looks for *source_path*'s pyc.
+
+    Without PEP 3147 support this is the sibling ``.pyc`` path.  When
+    ``imp.cache_from_source`` exists it is the ``__pycache__`` location,
+    which is ensure()'d before being returned.
+    """
+    if not hasattr(imp, "cache_from_source"):
+        return source_path + "c"
+    pyc = py.path.local(imp.cache_from_source(str(source_path)))
+    pyc.ensure()
+    return pyc
+def _use_cached_pyc(source, cache):
+    """Return True if *cache* holds an up-to-date rewritten pyc for *source*.
+
+    The first 8 bytes of *cache* must be this interpreter's magic number
+    followed by the source mtime.  A missing/unreadable cache file or a
+    short, foreign or stale header yields False.
+    """
+    try:
+        # _write_pyc stores int(mtime); truncate here as well, otherwise a
+        # source with a sub-second mtime never matches and the cache is
+        # never reused.
+        mtime = int(source.mtime())
+        fp = cache.open("rb")
+        try:
+            data = fp.read(8)
+        finally:
+            fp.close()
+    except EnvironmentError:
+        return False
+    if (len(data) != 8 or
+        data[:4] != imp.get_magic() or
+        struct.unpack("<l", data[4:])[0] != mtime):
+        # Invalid or out of date.
+        return False
+    # The cached pyc exists and is up to date.
+    return True
+def _cache_pyc(state, pyc, cache):
+    """Best-effort copy of the rewritten *pyc* to *cache*.
+
+    The cache's parent directory is created if needed; any
+    EnvironmentError is traced rather than raised.
+    """
+    cache_dir = cache.dirpath()
+    try:
+        cache_dir.ensure(dir=True)
+        pyc.copy(cache)
+    except EnvironmentError:
+        state.trace("failed to cache %r as %r" % (pyc, cache))
 def rewrite_asserts(mod):
     """Rewrite the assert statements in mod."""

--- a/_pytest/python.py	Tue Jun 28 21:11:56 2011 -0500
+++ b/_pytest/python.py	Tue Jun 28 21:13:12 2011 -0500
@@ -226,13 +226,8 @@
     def _importtestmodule(self):
         # we assume we are only called once per module
-        from _pytest import assertion
-        assertion.before_module_import(self)
-            try:
-                mod = self.fspath.pyimport(ensuresyspath=True)
-            finally:
-                assertion.after_module_import(self)
+            mod = self.fspath.pyimport(ensuresyspath=True)
         except SyntaxError:
             excinfo = py.code.ExceptionInfo()
             raise self.CollectError(excinfo.getrepr(style="short"))

--- a/testing/test_assertion.py	Tue Jun 28 21:11:56 2011 -0500
+++ b/testing/test_assertion.py	Tue Jun 28 21:13:12 2011 -0500
@@ -250,8 +250,9 @@
 def test_load_fake_pyc(testdir):
+    # A pyc produced by _write_pyc must win over its source on import: the
+    # file says 'hello' but the compiled code object says 'bye'.
-    path = testdir.makepyfile("x = 'hello'")
+    # rewrite.py imports ast; skip where the module cannot be imported.
+    rewrite = pytest.importorskip("_pytest.assertion.rewrite")
+    path = testdir.makepyfile(a_random_module="x = 'hello'")
     co = compile("x = 'bye'", str(path), "exec")
-    plugin._write_pyc(co, path)
+    rewrite._write_pyc(co, path, rewrite._compute_pyc_location(path))
     mod = path.pyimport()
     assert mod.x == "bye"

