[pypy-commit] pypy default: (antocuni) merge autoreds, the automatic detection of red variables

fijal noreply at buildbot.pypy.org
Wed Nov 14 20:23:16 CET 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: 
Changeset: r58913:cf489f0cbc09
Date: 2012-11-14 14:11 +0100
http://bitbucket.org/pypy/pypy/changeset/cf489f0cbc09/

Log:	(antocuni) merge autoreds, the automatic detection of red variables

diff --git a/pypy/jit/codewriter/jtransform.py b/pypy/jit/codewriter/jtransform.py
--- a/pypy/jit/codewriter/jtransform.py
+++ b/pypy/jit/codewriter/jtransform.py
@@ -1317,7 +1317,7 @@
     def promote_greens(self, args, jitdriver):
         ops = []
         num_green_args = len(jitdriver.greens)
-        assert len(args) == num_green_args + len(jitdriver.reds)
+        assert len(args) == num_green_args + jitdriver.numreds
         for v in args[:num_green_args]:
             if isinstance(v, Variable) and v.concretetype is not lltype.Void:
                 kind = getkind(v.concretetype)
diff --git a/pypy/jit/codewriter/support.py b/pypy/jit/codewriter/support.py
--- a/pypy/jit/codewriter/support.py
+++ b/pypy/jit/codewriter/support.py
@@ -78,28 +78,32 @@
     link = split_block(None, portalblock, 0, greens_v + reds_v)
     return link.target
 
+def sort_vars(args_v):
+    from pypy.jit.metainterp.history import getkind
+    _kind2count = {'int': 1, 'ref': 2, 'float': 3}
+    return sorted(args_v, key=lambda v: _kind2count[getkind(v.concretetype)])
+
 def decode_hp_hint_args(op):
     # Returns (list-of-green-vars, list-of-red-vars) without Voids.
     # Both lists must be sorted: first INT, then REF, then FLOAT.
     assert op.opname == 'jit_marker'
     jitdriver = op.args[1].value
     numgreens = len(jitdriver.greens)
-    numreds = len(jitdriver.reds)
+    assert jitdriver.numreds is not None
+    numreds = jitdriver.numreds
     greens_v = op.args[2:2+numgreens]
     reds_v = op.args[2+numgreens:]
     assert len(reds_v) == numreds
     #
     def _sort(args_v, is_green):
-        from pypy.jit.metainterp.history import getkind
         lst = [v for v in args_v if v.concretetype is not lltype.Void]
         if is_green:
             assert len(lst) == len(args_v), (
                 "not supported so far: 'greens' variables contain Void")
-        _kind2count = {'int': 1, 'ref': 2, 'float': 3}
-        lst2 = sorted(lst, key=lambda v: _kind2count[getkind(v.concretetype)])
         # a crash here means that you have to reorder the variable named in
         # the JitDriver.  Indeed, greens and reds must both be sorted: first
         # all INTs, followed by all REFs, followed by all FLOATs.
+        lst2 = sort_vars(args_v)
         assert lst == lst2
         return lst
     #
diff --git a/pypy/jit/metainterp/test/test_ajit.py b/pypy/jit/metainterp/test/test_ajit.py
--- a/pypy/jit/metainterp/test/test_ajit.py
+++ b/pypy/jit/metainterp/test/test_ajit.py
@@ -3051,8 +3051,7 @@
                 i += 1
         res = self.meta_interp(f, [32])
         assert res == f(32)
-
-
+        
 class XXXDisabledTestOOtype(BasicTests, OOJitMixin):
 
     def test_oohash(self):
diff --git a/pypy/jit/metainterp/test/test_warmspot.py b/pypy/jit/metainterp/test/test_warmspot.py
--- a/pypy/jit/metainterp/test/test_warmspot.py
+++ b/pypy/jit/metainterp/test/test_warmspot.py
@@ -250,11 +250,11 @@
                            'int_sub': 2})
 
     def test_void_red_variable(self):
-        mydriver = JitDriver(greens=[], reds=['a', 'm'])
+        mydriver = JitDriver(greens=[], reds=['m'])
         def f1(m):
             a = None
             while m > 0:
-                mydriver.jit_merge_point(a=a, m=m)
+                mydriver.jit_merge_point(m=m)
                 m = m - 1
                 if m == 10:
                     pass   # other case
@@ -312,6 +312,117 @@
         self.meta_interp(f1, [18])
 
 
+    def test_loop_automatic_reds(self):
+        myjitdriver = JitDriver(greens = ['m'], reds = 'auto')
+        def f(n, m):
+            res = 0
+            # try to have lots of red vars, so that if there is an error in
+            # the ordering of reds, there are low chances that the test passes
+            # by chance
+            a = b = c = d = n
+            while n > 0:
+                myjitdriver.jit_merge_point(m=m)
+                n -= 1
+                a += 1 # dummy unused red
+                b += 2 # dummy unused red
+                c += 3 # dummy unused red
+                d += 4 # dummy unused red
+                res += m*2
+            return res
+        expected = f(21, 5)
+        res = self.meta_interp(f, [21, 5])
+        assert res == expected
+        self.check_resops(int_sub=2, int_mul=0, int_add=10)
+
+    def test_loop_automatic_reds_with_floats_and_refs(self):
+        myjitdriver = JitDriver(greens = ['m'], reds = 'auto')
+        class MyObj(object):
+            def __init__(self, val):
+                self.val = val
+        def f(n, m):
+            res = 0
+            # try to have lots of red vars, so that if there is an error in
+            # the ordering of reds, there are low chances that the test passes
+            # by chance
+            i1 = i2 = i3 = i4 = n
+            f1 = f2 = f3 = f4 = float(n)
+            r1 = r2 = r3 = r4 = MyObj(n)
+            while n > 0:
+                myjitdriver.jit_merge_point(m=m)
+                n -= 1
+                i1 += 1 # dummy unused red
+                i2 += 2 # dummy unused red
+                i3 += 3 # dummy unused red
+                i4 += 4 # dummy unused red
+                f1 += 1 # dummy unused red
+                f2 += 2 # dummy unused red
+                f3 += 3 # dummy unused red
+                f4 += 4 # dummy unused red
+                r1.val += 1 # dummy unused red
+                r2.val += 2 # dummy unused red
+                r3.val += 3 # dummy unused red
+                r4.val += 4 # dummy unused red
+                res += m*2
+            return res
+        expected = f(21, 5)
+        res = self.meta_interp(f, [21, 5])
+        assert res == expected
+        self.check_resops(int_sub=2, int_mul=0, int_add=18, float_add=8)
+
+    def test_loop_automatic_reds_livevars_before_jit_merge_point(self):
+        myjitdriver = JitDriver(greens = ['m'], reds = 'auto')
+        def f(n, m):
+            res = 0
+            while n > 0:
+                n -= 1
+                myjitdriver.jit_merge_point(m=m)
+                res += m*2
+            return res
+        expected = f(21, 5)
+        res = self.meta_interp(f, [21, 5])
+        assert res == expected
+        self.check_resops(int_sub=2, int_mul=0, int_add=2)
+
+    def test_inline_in_portal(self):
+        myjitdriver = JitDriver(greens = [], reds = 'auto')
+        class MyRange(object):
+            def __init__(self, n):
+                self.cur = 0
+                self.n = n
+
+            def __iter__(self):
+                return self
+
+            @myjitdriver.inline_in_portal
+            def next(self):
+                myjitdriver.jit_merge_point()
+                if self.cur == self.n:
+                    raise StopIteration
+                self.cur += 1
+                return self.cur
+
+        def one():
+            res = 0
+            for i in MyRange(10):
+                res += i
+            return res
+
+        def two():
+            res = 0
+            for i in MyRange(13):
+                res += i * 2
+            return res
+
+        def f(n, m):
+            res = one() * 100
+            res += two()
+            return res
+        expected = f(21, 5)
+        res = self.meta_interp(f, [21, 5])
+        assert res == expected
+        self.check_resops(int_eq=4, int_add=8)
+        self.check_trace_count(2)
+
 class TestLLWarmspot(WarmspotTests, LLJitMixin):
     CPUClass = runner.LLtypeCPU
     type_system = 'lltype'
diff --git a/pypy/jit/metainterp/warmspot.py b/pypy/jit/metainterp/warmspot.py
--- a/pypy/jit/metainterp/warmspot.py
+++ b/pypy/jit/metainterp/warmspot.py
@@ -186,6 +186,7 @@
         self.set_translator(translator)
         self.memory_manager = memmgr.MemoryManager()
         self.build_cpu(CPUClass, **kwds)
+        self.inline_inlineable_portals()
         self.find_portals()
         self.codewriter = codewriter.CodeWriter(self.cpu, self.jitdrivers_sd)
         if policy is None:
@@ -241,23 +242,106 @@
         self.rtyper = translator.rtyper
         self.gcdescr = gc.get_description(translator.config)
 
+    def inline_inlineable_portals(self):
+        """
+        Find all the graphs which have been decorated with
+        @jitdriver.inline_in_portal and inline them in the callers, making
+        them JIT portals. Then, create a fresh copy of the jitdriver for each
+        of those new portals, because they cannot share the same one.  See
+        test_warmspot::test_inline_in_portal.
+        """
+        from pypy.translator.backendopt import inline
+        lltype_to_classdef = self.translator.rtyper.lltype_to_classdef_mapping()
+        raise_analyzer = inline.RaiseAnalyzer(self.translator)
+        callgraph = inline.inlinable_static_callers(self.translator.graphs)
+        new_portals = set()
+        for caller, callee in callgraph:
+            func = getattr(callee, 'func', None)
+            _inline_in_portal_ = getattr(func, '_inline_in_portal_', False)
+            if _inline_in_portal_:
+                count = inline.inline_function(self.translator, callee, caller,
+                                               lltype_to_classdef, raise_analyzer)
+                assert count > 0, ('The function has been decorated with '
+                                   '@inline_in_portal, but it is not possible '
+                                   'to inline it')
+                new_portals.add(caller)
+        self.clone_inlined_jit_merge_points(new_portals)
+
+    def clone_inlined_jit_merge_points(self, graphs):
+        """
+        Find all the jit_merge_points in the given graphs, and replace the
+        original JitDriver with a fresh clone.
+        """
+        if not graphs:
+            return
+        for graph, block, pos in find_jit_merge_points(graphs):
+            op = block.operations[pos]
+            v_driver = op.args[1]
+            new_driver = v_driver.value.clone()
+            c_new_driver = Constant(new_driver, v_driver.concretetype)
+            op.args[1] = c_new_driver
+
+
     def find_portals(self):
         self.jitdrivers_sd = []
         graphs = self.translator.graphs
-        for jit_merge_point_pos in find_jit_merge_points(graphs):
-            self.split_graph_and_record_jitdriver(*jit_merge_point_pos)
+        for graph, block, pos in find_jit_merge_points(graphs):
+            self.autodetect_jit_markers_redvars(graph)
+            self.split_graph_and_record_jitdriver(graph, block, pos)
         #
         assert (len(set([jd.jitdriver for jd in self.jitdrivers_sd])) ==
                 len(self.jitdrivers_sd)), \
                 "there are multiple jit_merge_points with the same jitdriver"
 
+    def autodetect_jit_markers_redvars(self, graph):
+        # the idea is to find all the jit_merge_point and can_enter_jit and
+        # add all the variables across the links to the reds.
+        for block, op in graph.iterblockops():
+            if op.opname == 'jit_marker':
+                jitdriver = op.args[1].value
+                if not jitdriver.autoreds:
+                    continue
+                # if we want to support also can_enter_jit, we should find a
+                # way to detect a consistent set of red vars to pass *both* to
+                # jit_merge_point and can_enter_jit. The current simple
+                # solution doesn't work because can_enter_jit might be in
+                # another block, so the set of alive_v will be different.
+                methname = op.args[0].value
+                assert methname == 'jit_merge_point', (
+                    "reds='auto' is supported only for jit drivers which " 
+                    "calls only jit_merge_point. Found a call to %s" % methname)
+                #
+                assert jitdriver.confirm_enter_jit is None
+                # compute the set of live variables before the jit_marker
+                alive_v = set(block.inputargs)
+                for op1 in block.operations:
+                    if op1 is op:
+                        break # stop when we meet the jit_marker
+                    if op1.result.concretetype != lltype.Void:
+                        alive_v.add(op1.result)
+                greens_v = op.args[2:]
+                reds_v = alive_v - set(greens_v)
+                reds_v = support.sort_vars(reds_v)
+                op.args.extend(reds_v)
+                if jitdriver.numreds is None:
+                    jitdriver.numreds = len(reds_v)
+                else:
+                    assert jitdriver.numreds == len(reds_v), 'inconsistent number of reds_v'
+
+
     def split_graph_and_record_jitdriver(self, graph, block, pos):
         op = block.operations[pos]
         jd = JitDriverStaticData()
         jd._jit_merge_point_in = graph
         args = op.args[2:]
         s_binding = self.translator.annotator.binding
-        jd._portal_args_s = [s_binding(v) for v in args]
+        if op.args[1].value.autoreds:
+            # _portal_args_s is used only by _make_hook_graph, but for now we
+            # declare the various set_jitcell_at, get_printable_location,
+            # etc. as incompatible with autoreds
+            jd._portal_args_s = None
+        else:
+            jd._portal_args_s = [s_binding(v) for v in args]
         graph = copygraph(graph)
         [jmpp] = find_jit_merge_points([graph])
         graph.startblock = support.split_before_jit_merge_point(*jmpp)
@@ -509,6 +593,7 @@
         if func is None:
             return None
         #
+        assert not jitdriver_sd.jitdriver.autoreds
         extra_args_s = []
         if s_first_arg is not None:
             extra_args_s.append(s_first_arg)
diff --git a/pypy/rlib/jit.py b/pypy/rlib/jit.py
--- a/pypy/rlib/jit.py
+++ b/pypy/rlib/jit.py
@@ -443,6 +443,7 @@
     active = True          # if set to False, this JitDriver is ignored
     virtualizables = []
     name = 'jitdriver'
+    inlined_in_portal = False
 
     def __init__(self, greens=None, reds=None, virtualizables=None,
                  get_jitcell_at=None, set_jitcell_at=None,
@@ -452,16 +453,27 @@
         if greens is not None:
             self.greens = greens
         self.name = name
-        if reds is not None:
+        if reds == 'auto':
+            self.autoreds = True
+            self.reds = []
+            self.numreds = None # see warmspot.autodetect_jit_markers_redvars
+            assert confirm_enter_jit is None, 'cannot use automatic reds if confirm_enter_jit is given'
+        elif reds is not None:
+            self.autoreds = False
             self.reds = reds
+            self.numreds = len(reds)
         if not hasattr(self, 'greens') or not hasattr(self, 'reds'):
             raise AttributeError("no 'greens' or 'reds' supplied")
         if virtualizables is not None:
             self.virtualizables = virtualizables
         for v in self.virtualizables:
             assert v in self.reds
-        self._alllivevars = dict.fromkeys(
-            [name for name in self.greens + self.reds if '.' not in name])
+        # if reds are automatic, they won't be passed to jit_merge_point, so
+        # _check_arguments will receive only the green ones (i.e., the ones
+        # which are listed explicitly). So, it is fine to just ignore reds
+        self._somelivevars = set([name for name in
+                                  self.greens + (self.reds or [])
+                                  if '.' not in name])
         self._heuristic_order = {}   # check if 'reds' and 'greens' are ordered
         self._make_extregistryentries()
         self.get_jitcell_at = get_jitcell_at
@@ -475,7 +487,7 @@
         return True
 
     def _check_arguments(self, livevars):
-        assert dict.fromkeys(livevars) == self._alllivevars
+        assert set(livevars) == self._somelivevars
         # check heuristically that 'reds' and 'greens' are ordered as
         # the JIT will need them to be: first INTs, then REFs, then
         # FLOATs.
@@ -527,6 +539,8 @@
         _self._check_arguments(livevars)
 
     def can_enter_jit(_self, **livevars):
+        if _self.autoreds:
+            raise TypeError, "Cannot call can_enter_jit on a driver with reds='auto'"
         # special-cased by ExtRegistryEntry
         _self._check_arguments(livevars)
 
@@ -534,6 +548,18 @@
         # special-cased by ExtRegistryEntry
         pass
 
+    def inline_in_portal(self, func):
+        assert self.autoreds, "inline_in_portal works only with reds='auto'"
+        func._inline_in_portal_ = True
+        self.inlined_in_portal = True
+        return func
+
+    def clone(self):
+        assert self.inlined_in_portal, 'JitDriver.clone works only after @inline_in_portal'
+        newdriver = object.__new__(self.__class__)
+        newdriver.__dict__ = self.__dict__.copy()
+        return newdriver
+
     def _make_extregistryentries(self):
         # workaround: we cannot declare ExtRegistryEntries for functions
         # used as methods of a frozen object, but we can attach the
diff --git a/pypy/rlib/test/test_jit.py b/pypy/rlib/test/test_jit.py
--- a/pypy/rlib/test/test_jit.py
+++ b/pypy/rlib/test/test_jit.py
@@ -15,6 +15,33 @@
         pass
     assert fn.oopspec == 'foobar'
 
+def test_jitdriver_autoreds():
+    driver = JitDriver(greens=['foo'], reds='auto')
+    assert driver.autoreds
+    assert driver.reds == []
+    assert driver.numreds is None
+    py.test.raises(TypeError, "driver.can_enter_jit(foo='something')")
+    #
+    py.test.raises(AssertionError, "JitDriver(greens=['foo'], reds='auto', confirm_enter_jit='something')")
+
+def test_jitdriver_clone():
+    def foo():
+        pass
+    driver = JitDriver(greens=[], reds=[])
+    py.test.raises(AssertionError, "driver.inline_in_portal(foo)")
+    #
+    driver = JitDriver(greens=[], reds='auto')
+    py.test.raises(AssertionError, "driver.clone()")
+    foo = driver.inline_in_portal(foo)
+    assert foo._inline_in_portal_ == True
+    #
+    driver.foo = 'bar'
+    driver2 = driver.clone()
+    assert driver is not driver2
+    assert driver2.foo == 'bar'
+    driver.foo = 'xxx'
+    assert driver2.foo == 'bar'
+    
 
 class BaseTestJIT(BaseRtypingTest):
     def test_hint(self):


More information about the pypy-commit mailing list