[pypy-svn] r52479 - in pypy/branch/jit-hotpath/pypy: jit/hintannotator jit/rainbow jit/rainbow/test jit/tl rlib

arigo at codespeak.net arigo at codespeak.net
Fri Mar 14 11:22:29 CET 2008


Author: arigo
Date: Fri Mar 14 11:22:28 2008
New Revision: 52479

Modified:
   pypy/branch/jit-hotpath/pypy/jit/hintannotator/hotpath.py
   pypy/branch/jit-hotpath/pypy/jit/rainbow/codewriter.py
   pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py
   pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py
   pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_interpreter.py
   pypy/branch/jit-hotpath/pypy/jit/tl/tlr.py
   pypy/branch/jit-hotpath/pypy/rlib/jit.py
Log:
Support for on_enter_jit().  Not very clean: it relies on many graph
transformations and on at least three slightly different copies
of the portal graph :-/


Modified: pypy/branch/jit-hotpath/pypy/jit/hintannotator/hotpath.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/hintannotator/hotpath.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/hintannotator/hotpath.py	Fri Mar 14 11:22:28 2008
@@ -1,63 +1,47 @@
 from pypy.objspace.flow.model import checkgraph, copygraph
-from pypy.translator.unsimplify import split_block
+from pypy.objspace.flow.model import Block, Link, SpaceOperation, Variable
+from pypy.translator.unsimplify import split_block, varoftype
 from pypy.translator.simplify import join_blocks
 from pypy.jit.hintannotator.annotator import HintAnnotator
 from pypy.jit.hintannotator.model import SomeLLAbstractConstant, OriginFlags
+from pypy.annotation import model as annmodel
+from pypy.rpython.rtyper import LowLevelOpList
 from pypy.rlib.jit import JitHintError
 
 
 class HotPathHintAnnotator(HintAnnotator):
 
-    def find_jit_merge_point(self, graph):
-        found_at = []
-        for block in graph.iterblocks():
-            for op in block.operations:
-                if op.opname == 'jit_merge_point':
-                    found_at.append((graph, block, op))
-        if len(found_at) > 1:
-            raise JitHintError("multiple jit_merge_point() not supported")
-        if found_at:
-            return found_at[0]
-        else:
-            return None
-
     def build_hotpath_types(self):
         self.prepare_portal_graphs()
+        graph = self.portalgraph_with_on_enter_jit
         input_args_hs = [SomeLLAbstractConstant(v.concretetype,
                                                 {OriginFlags(): True})
-                         for v in self.portalgraph.getargs()]
-        return self.build_types(self.portalgraph, input_args_hs)
+                         for v in graph.getargs()]
+        return self.build_types(graph, input_args_hs)
 
     def prepare_portal_graphs(self):
         # find the graph with the jit_merge_point()
         found_at = []
         for graph in self.base_translator.graphs:
-            place = self.find_jit_merge_point(graph)
+            place = find_jit_merge_point(graph)
             if place is not None:
                 found_at.append(place)
         if len(found_at) != 1:
             raise JitHintError("found %d graphs with a jit_merge_point(),"
                                " expected 1 (for now)" % len(found_at))
-        origportalgraph, _, _ = found_at[0]
+        origportalgraph, _, origportalop = found_at[0]
+        drivercls = origportalop.args[0].value
         #
         # We make a copy of origportalgraph and mutate it to make it
         # the portal.  The portal really starts at the jit_merge_point()
         # without any block or operation before it.
         #
         portalgraph = copygraph(origportalgraph)
-        _, portalblock, portalop = self.find_jit_merge_point(portalgraph)
-        portalopindex = portalblock.operations.index(portalop)
-        # split the block just before the jit_merge_point()
-        link = split_block(None, portalblock, portalopindex)
-        # split again, this time enforcing the order of the live vars
-        # specified by the user in the jit_merge_point() call
-        _, portalblock, portalop = self.find_jit_merge_point(portalgraph)
-        assert portalop is portalblock.operations[0]
-        livevars = portalop.args[1:]
-        link = split_block(None, portalblock, 0, livevars)
+        block = split_before_jit_merge_point(None, portalgraph)
+        assert block is not None
         # rewire the graph to start at the global_merge_point
         portalgraph.startblock.isstartblock = False
-        portalgraph.startblock = link.target
+        portalgraph.startblock = block
         portalgraph.startblock.isstartblock = True
         self.portalgraph = portalgraph
         self.origportalgraph = origportalgraph
@@ -65,7 +49,93 @@
         # been listed in the jit_merge_point()
         # (XXX should give an explicit JitHintError explaining the problem)
         checkgraph(portalgraph)
-        join_blocks(portalgraph)
+        # insert the on_enter_jit() logic before the jit_merge_point()
+        # in a copy of the graph which will be the one that gets hint-annotated
+        # and turned into rainbow bytecode.  On the other hand, the
+        # 'self.portalgraph' is the copy that will run directly, in
+        # non-JITting mode, so it should not contain the on_enter_jit() call.
+        if hasattr(drivercls, 'on_enter_jit'):
+            anothercopy = copygraph(portalgraph)
+            anothercopy.tag = 'portal'
+            insert_on_enter_jit_handling(self.base_translator.rtyper,
+                                         anothercopy,
+                                         drivercls)
+            self.portalgraph_with_on_enter_jit = anothercopy
+        else:
+            self.portalgraph_with_on_enter_jit = portalgraph  # same is ok
         # put the new graph back in the base_translator
         portalgraph.tag = 'portal'
         self.base_translator.graphs.append(portalgraph)
+
+# ____________________________________________________________
+
+def find_jit_merge_point(graph):
+    found_at = []
+    for block in graph.iterblocks():
+        for op in block.operations:
+            if op.opname == 'jit_merge_point':
+                found_at.append((graph, block, op))
+    if len(found_at) > 1:
+        raise JitHintError("multiple jit_merge_point() not supported")
+    if found_at:
+        return found_at[0]
+    else:
+        return None
+
+def split_before_jit_merge_point(hannotator, graph):
+    """Find the block with 'jit_merge_point' and split just before,
+    making sure the input args are in the canonical order.  If
+    hannotator is not None, preserve the hint-annotations while doing so
+    (used by codewriter.py).
+    """
+    found_at = find_jit_merge_point(graph)
+    if found_at is not None:
+        _, portalblock, portalop = found_at
+        portalopindex = portalblock.operations.index(portalop)
+        # split the block just before the jit_merge_point()
+        if portalopindex > 0:
+            split_block(hannotator, portalblock, portalopindex)
+        # split again, this time enforcing the order of the live vars
+        # specified by the user in the jit_merge_point() call
+        _, portalblock, portalop = find_jit_merge_point(graph)
+        assert portalop is portalblock.operations[0]
+        livevars = portalop.args[1:]
+        link = split_block(hannotator, portalblock, 0, livevars)
+        return link.target
+    else:
+        return None
+
+def insert_on_enter_jit_handling(rtyper, graph, drivercls):
+    vars = [varoftype(v.concretetype, name=v) for v in graph.getargs()]
+    newblock = Block(vars)
+
+    llops = LowLevelOpList(rtyper)
+    # generate ops to make an instance of DriverCls
+    classdef = rtyper.annotator.bookkeeper.getuniqueclassdef(drivercls)
+    s_instance = annmodel.SomeInstance(classdef)
+    r_instance = rtyper.getrepr(s_instance)
+    v_self = r_instance.new_instance(llops)
+    # generate ops to store the 'reds' variables on 'self'
+    num_greens = len(drivercls.greens)
+    num_reds = len(drivercls.reds)
+    assert len(vars) == num_greens + num_reds
+    for name, v_value in zip(drivercls.reds, vars[num_greens:]):
+        r_instance.setfield(v_self, name, v_value, llops)
+    # generate a call to on_enter_jit(self)
+    on_enter_jit_func = drivercls.on_enter_jit.im_func
+    s_func = rtyper.annotator.bookkeeper.immutablevalue(on_enter_jit_func)
+    r_func = rtyper.getrepr(s_func)
+    c_func = r_func.get_unique_llfn()
+    llops.genop('direct_call', [c_func, v_self])
+    # generate ops to reload the 'reds' variables from 'self'
+    newvars = vars[:num_greens]
+    for name, v_value in zip(drivercls.reds, vars[num_greens:]):
+        v_value = r_instance.getfield(v_self, name, llops)
+        newvars.append(v_value)
+    # done, fill the block and link it to make it the startblock
+    newblock.operations[:] = llops
+    newblock.closeblock(Link(newvars, graph.startblock))
+    graph.startblock.isstartblock = False
+    graph.startblock = newblock
+    graph.startblock.isstartblock = True
+    checkgraph(graph)

Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/codewriter.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/codewriter.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/codewriter.py	Fri Mar 14 11:22:28 2008
@@ -1437,8 +1437,10 @@
         return greens_v, reds_v
 
     def serialize_op_jit_merge_point(self, op):
-        # by construction, the graph should have exactly the vars listed
-        # in the op as live vars.  Check this.
+        # If jit_merge_point is the first operation of its block, and if
+        # the block's input variables are in the right order, the graph
+        # should have exactly the vars listed in the op as live vars.
+        # Check this.  It should be ensured by the GraphTransformer.
         greens_v, reds_v = self.check_hp_hint_args(op)
         key = ()
         for i, v in enumerate(greens_v):
@@ -1466,6 +1468,11 @@
         # we want native red switch support in the hotpath policy
         if not self.hannotator.policy.hotpath:
             self.insert_splits()
+        # we need jit_merge_point to be at the start of its block
+        if self.hannotator.policy.hotpath:
+            from pypy.jit.hintannotator.hotpath import \
+                                             split_before_jit_merge_point
+            split_before_jit_merge_point(self.hannotator, graph)
 
     def insert_splits(self):
         hannotator = self.hannotator

Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py	Fri Mar 14 11:22:28 2008
@@ -167,7 +167,7 @@
         #
         portalgraph = self.hintannotator.portalgraph
         # ^^^ as computed by HotPathHintAnnotator.prepare_portal_graphs()
-        if origportalgraph is portalgraph:
+        if origportalgraph is not self.hintannotator.origportalgraph:
             return False      # only mutate the original portal graph,
                               # not its copy
 

Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py	Fri Mar 14 11:22:28 2008
@@ -278,7 +278,6 @@
         py.test.raises(JitHintError, self.run, ll_function, [5], 3)
 
     def test_on_enter_jit(self):
-        py.test.skip("in-progress")
         class MyJitDriver(JitDriver):
             greens = []
             reds = ['n']
@@ -299,6 +298,7 @@
             ])
 
     def test_hp_tlr(self):
+        py.test.skip("in-progress")
         from pypy.jit.tl import tlr
 
         def main(code, n):
@@ -332,7 +332,6 @@
             "fallback_interp",
             "fb_leave * stru...} 27 0 * array [ 0, 71, 5041 ]",
             ])
-        py.test.skip("XXX currently the 'regs' list is not virtual.")
         # We expect only the direct_call from the red split fallback point.
         # If we get e.g. 7 of them instead it probably means that we see
         # direct_calls to the ll helpers for the 'regs' list.

Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_interpreter.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_interpreter.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_interpreter.py	Fri Mar 14 11:22:28 2008
@@ -13,6 +13,7 @@
 from pypy.rpython.llinterp import LLInterpreter, LLException
 from pypy.rpython.module.support import LLSupport
 from pypy.annotation import model as annmodel
+from pypy.annotation.policy import AnnotatorPolicy
 from pypy.objspace.flow.model import summary, Variable
 from pypy.rlib.debug import ll_assert
 from pypy.rlib.jit import hint
@@ -48,7 +49,9 @@
               portal=None, type_system="lltype"):
     # build the normal ll graphs for ll_function
     t = TranslationContext()
-    a = t.buildannotator()
+    annpolicy = AnnotatorPolicy()
+    annpolicy.allow_someobjects = False
+    a = t.buildannotator(policy=annpolicy)
     argtypes = getargtypes(a, values)
     a.build_types(func, argtypes)
     rtyper = t.buildrtyper(type_system = type_system)

Modified: pypy/branch/jit-hotpath/pypy/jit/tl/tlr.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/tl/tlr.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/tl/tlr.py	Fri Mar 14 11:22:28 2008
@@ -54,12 +54,11 @@
     reds   = ['a', 'regs']
 
     def on_enter_jit(self):
-        xxx - "not called yet"
         # make a copy of the 'regs' list to make it a VirtualList for the JIT
         length = hint(len(self.regs), promote=True)
-        newregs = [None] * length
-        for i in range(length):
-            newregs[i] = self.regs[i]
+        newregs = []
+        for x in self.regs:
+            newregs.append(x)
         self.regs = newregs
 
 def hp_interpret(bytecode, a):

Modified: pypy/branch/jit-hotpath/pypy/rlib/jit.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/rlib/jit.py	(original)
+++ pypy/branch/jit-hotpath/pypy/rlib/jit.py	Fri Mar 14 11:22:28 2008
@@ -121,6 +121,7 @@
             raise JitHintError("%s.%s(): must give exactly the same keywords"
                                " as the 'greens' and 'reds'" % (
                 drivercls.__name__, self.instance.name))
+        drivercls._emulate_method_calls(self.bookkeeper, kwds_s)
         return annmodel.s_None
 
     def specialize_call(self, hop, **kwds_i):
@@ -174,3 +175,19 @@
                 raise JitHintError("%s: the 'greens' and 'reds' names should"
                                    " not start with an underscore" % (cls,))
     _check_class = classmethod(_check_class)
+
+    def _emulate_method_calls(cls, bookkeeper, livevars_s):
+        # annotate "cls.on_enter_jit()" if it is defined
+        from pypy.annotation import model as annmodel
+        if hasattr(cls, 'on_enter_jit'):
+            classdef = bookkeeper.getuniqueclassdef(cls)
+            s_arg = annmodel.SomeInstance(classdef)
+            for name, s_value in livevars_s.items():
+                assert name.startswith('s_')
+                name = name[2:]
+                s_arg.setattr(bookkeeper.immutablevalue(name), s_value)
+            key = "rlib.jit.JitDriver.on_enter_jit"
+            s_func = bookkeeper.immutablevalue(cls.on_enter_jit.im_func)
+            s_result = bookkeeper.emulate_pbc_call(key, s_func, [s_arg])
+            assert annmodel.s_None.contains(s_result)
+    _emulate_method_calls = classmethod(_emulate_method_calls)



More information about the Pypy-commit mailing list