[pypy-svn] r53091 - in pypy/branch/jit-hotpath/pypy: jit/rainbow jit/rainbow/test rlib rlib/test

arigo at codespeak.net arigo at codespeak.net
Sat Mar 29 14:12:34 CET 2008


Author: arigo
Date: Sat Mar 29 14:12:33 2008
New Revision: 53091

Modified:
   pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py
   pypy/branch/jit-hotpath/pypy/jit/rainbow/rhotpath.py
   pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py
   pypy/branch/jit-hotpath/pypy/rlib/jit.py
   pypy/branch/jit-hotpath/pypy/rlib/test/test_jit.py
Log:
A more flexible way to configure parameters of the JIT
from the running program.  A bit indirect implementation-wise.


Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py	Sat Mar 29 14:12:33 2008
@@ -78,6 +78,8 @@
     def make_enter_function(self):
         HotEnterState = make_state_class(self)
         state = HotEnterState()
+        self.state = state
+        self._set_param_fn_cache = {}
         exceptiondesc = self.exceptiondesc
         interpreter = self.interpreter
         num_green_args = len(self.green_args_spec)
@@ -113,12 +115,12 @@
             for op in block.operations:
                 if op.opname == 'jit_marker':
                     index = block.operations.index(op)
-                    meth = getattr(self, 'rewrite_' + op.args[0].value)
+                    meth = getattr(self, 'rewrite__' + op.args[0].value)
                     if meth(graph, block, index):
                         return True      # graph mutated, start over again
         return False  # done
 
-    def rewrite_can_enter_jit(self, graph, block, index):
+    def rewrite__can_enter_jit(self, graph, block, index):
         #
         # In the original graphs, replace the 'can_enter_jit' operations
         # with a call to the maybe_enter_jit() helper.
@@ -151,7 +153,7 @@
         block.operations[index] = newop
         return True
 
-    def rewrite_jit_merge_point(self, origportalgraph, origblock, origindex):
+    def rewrite__jit_merge_point(self, origportalgraph, origblock, origindex):
         #
         # Mutate the original portal graph from this:
         #
@@ -280,6 +282,45 @@
         checkgraph(origportalgraph)
         return True
 
+    def rewrite__set_param(self, graph, block, index):
+        # Replace the set_param marker at 'index' with a call to a helper.
+        op = block.operations[index]
+        assert op.opname == 'jit_marker'
+        assert op.args[0].value == 'set_param'
+        param_name = op.args[2].value
+        v_param_value = op.args[3]
+        SETTERFUNC = lltype.FuncType([lltype.Signed], lltype.Void)
+        # One setter per parameter name, cached in _set_param_fn_cache.
+        try:
+            setter_fnptr = self._set_param_fn_cache[param_name]
+        except KeyError:
+            meth = getattr(self.state, 'set_param_' + param_name, None)
+            if meth is None:
+                raise Exception("set_param(): no such parameter: %r" %
+                                (param_name,))
+            # The helper just forwards to state.set_param_<name>(value).
+            def ll_setter(value):
+                meth(value)
+            ll_setter.__name__ = 'set_' + param_name
+            # Both branches must wrap the same helper, 'll_setter'.
+            if not self.translate_support_code:
+                setter_fnptr = llhelper(lltype.Ptr(SETTERFUNC), ll_setter)
+            else:
+                args_s = [annmodel.SomeInteger()]
+                s_result = annmodel.s_None
+                setter_fnptr = self.annhelper.delayedfunction(
+                    ll_setter, args_s, s_result)
+            self._set_param_fn_cache[param_name] = setter_fnptr
+        # Swap the marker in place; returns None so the caller keeps scanning.
+        vlist = [Constant(setter_fnptr, lltype.Ptr(SETTERFUNC)),
+                 v_param_value]
+        v_result = Variable()
+        v_result.concretetype = lltype.Void
+        newop = SpaceOperation('direct_call', vlist, v_result)
+        block.operations[index] = newop
+
+# ____________________________________________________________
+
 
 def make_state_class(hotrunnerdesc):
     # very minimal, just to make the first test pass
@@ -336,6 +377,7 @@
 
         def __init__(self):
             self.cells = [Counter(0)] * HASH_TABLE_SIZE
+            self.threshold = 10
 
             # Only use the hash of the arguments as the profiling key.
             # Indeed, this is all a heuristic, so if things are designed
@@ -351,7 +393,7 @@
                 # update the profiling counter
                 interp = hotrunnerdesc.interpreter
                 n = cell.counter + 1
-                if n < hotrunnerdesc.jitdrivercls.getcurrentthreshold():
+                if n < self.threshold:
                     if hotrunnerdesc.verbose_level >= 3:
                         interp.debug_trace("jit_not_entered", *args)
                     self.cells[argshash] = Counter(n)
@@ -384,7 +426,7 @@
             # not found at all, do profiling
             interp = hotrunnerdesc.interpreter
             n = next.counter + 1
-            if n < hotrunnerdesc.jitdrivercls.getcurrentthreshold():
+            if n < self.threshold:
                 if hotrunnerdesc.verbose_level >= 3:
                     interp.debug_trace("jit_not_entered", *args)
                 cell.next = Counter(n)
@@ -479,4 +521,7 @@
             return residualargs
         make_residualargs._always_inline_ = True
 
+        def set_param_threshold(self, threshold):
+            self.threshold = threshold
+
     return HotEnterState

Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/rhotpath.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/rhotpath.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/rhotpath.py	Sat Mar 29 14:12:33 2008
@@ -186,7 +186,7 @@
         # 'value' should be a Bool, but depending on the backend
         # it could have been ERASED to about anything else
         value = bool(value)
-        threshold = self.hotrunnerdesc.jitdrivercls.getcurrentthreshold()
+        threshold = self.hotrunnerdesc.state.threshold
         if value:
             counter = self.truepath_counter + 1
             assert counter > 0, (
@@ -257,7 +257,7 @@
         # XXX unsafe with a moving GC
         hash = cast_whatever_to_int(lltype.typeOf(value), value)
         counter = self.counters.get(hash, 0) + 1
-        threshold = self.hotrunnerdesc.jitdrivercls.getcurrentthreshold()
+        threshold = self.hotrunnerdesc.state.threshold
         assert counter > 0, (
             "reaching a fallback point for an already-compiled path")
         if counter >= threshold:

Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py	Sat Mar 29 14:12:33 2008
@@ -70,14 +70,14 @@
         return self._run(main, main_args)
 
     def _rewrite(self, threshold, small):
-        assert len(self.hintannotator.jitdriverclasses) == 1
-        jitdrivercls = self.hintannotator.jitdriverclasses.keys()[0]    # hack
-        jitdrivercls.getcurrentthreshold = staticmethod(lambda : threshold) #..
+        assert len(self.hintannotator.jitdriverclasses) == 1     # xxx for now
+        jitdrivercls = self.hintannotator.jitdriverclasses.keys()[0]
         self.hotrunnerdesc = HotRunnerDesc(self.hintannotator, self.rtyper,
                                        self.jitcode, self.RGenOp, self.writer,
                                        jitdrivercls,
                                        self.translate_support_code)
         self.hotrunnerdesc.rewrite_all()
+        self.hotrunnerdesc.state.set_param_threshold(threshold)
         if self.simplify_virtualizable_accesses:
             from pypy.jit.rainbow import graphopt
             graphopt.simplify_virtualizable_accesses(self.writer)
@@ -422,6 +422,36 @@
         res = self.run(ll_function, [2, 7], threshold=1, small=True)
         assert res == 72002
 
+    def test_set_threshold(self):
+        class MyJitDriver(JitDriver):
+            greens = []
+            reds = ['i', 'x']
+        def ll_function(x):
+            MyJitDriver.set_param(threshold=x)
+            i = 1024
+            while i > 0:
+                i >>= 1
+                x += i
+                MyJitDriver.jit_merge_point(i=i, x=x)
+                MyJitDriver.can_enter_jit(i=i, x=x)
+            return x
+        res = self.run(ll_function, [2], threshold=9)
+        assert res == 1025
+        self.check_traces([
+            "jit_not_entered 512 514",
+            "jit_compile",
+            "pause at hotsplit in ll_function",
+            "run_machine_code 256 770",
+            "fallback_interp",
+            "fb_leave 128 898",
+            "run_machine_code 128 898",
+            "jit_resume Bool path True in ll_function",
+            "done at jit_merge_point",
+            "resume_machine_code",
+            "fallback_interp",
+            "fb_return 1025",
+            ])
+
     def test_hp_tlr(self):
         from pypy.jit.tl import tlr
 

Modified: pypy/branch/jit-hotpath/pypy/rlib/jit.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/rlib/jit.py	(original)
+++ pypy/branch/jit-hotpath/pypy/rlib/jit.py	Sat Mar 29 14:12:33 2008
@@ -111,9 +111,18 @@
     _type_ = _JitBoundClassMethod
 
     def compute_result_annotation(self, **kwds_s):
-        from pypy.annotation import model as annmodel
         drivercls = self.instance.drivercls
         drivercls._check_class()
+        meth = getattr(self, 'annotate_%s' % self.instance.name)
+        return meth(drivercls, **kwds_s)
+
+    def specialize_call(self, hop, **kwds_i):
+        drivercls = self.instance.drivercls
+        meth = getattr(self, 'specialize_%s' % self.instance.name)
+        return meth(drivercls, hop, **kwds_i)
+
+    def annotate_jit_merge_point(self, drivercls, **kwds_s):
+        from pypy.annotation import model as annmodel
         keys = kwds_s.keys()
         keys.sort()
         expected = ['s_' + name for name in drivercls.greens + drivercls.reds]
@@ -125,14 +134,15 @@
         drivercls._emulate_method_calls(self.bookkeeper, kwds_s)
         return annmodel.s_None
 
-    def specialize_call(self, hop, **kwds_i):
+    annotate_can_enter_jit = annotate_jit_merge_point
+
+    def specialize_jit_merge_point(self, drivercls, hop, **kwds_i):
         # replace a call to MyDriverCls.hintname(**livevars)
         # with an operation 'hintname(MyDriverCls, livevars...)'
         # XXX to be complete, this could also check that the concretetype
         # of the variables are the same for each of the calls.
         from pypy.rpython.error import TyperError
         from pypy.rpython.lltypesystem import lltype
-        drivercls = self.instance.drivercls
         greens_v = []
         reds_v = []
         for name in drivercls.greens:
@@ -153,6 +163,28 @@
         return hop.genop('jit_marker', vlist,
                          resulttype=lltype.Void)
 
+    specialize_can_enter_jit = specialize_jit_merge_point
+
+    def annotate_set_param(self, drivercls, **kwds_s):
+        from pypy.annotation import model as annmodel
+        if len(kwds_s) != 1:
+            raise Exception("DriverCls.set_param(): must specify exactly "
+                            "one keyword argument")
+        return annmodel.s_None
+
+    def specialize_set_param(self, drivercls, hop, **kwds_i):
+        from pypy.rpython.lltypesystem import lltype
+        [(name, i)] = kwds_i.items()
+        assert name.startswith('i_')
+        name = name[2:]
+        v_value = hop.inputarg(lltype.Signed, arg=i)
+        vlist = [hop.inputconst(lltype.Void, "set_param"),
+                 hop.inputconst(lltype.Void, drivercls),
+                 hop.inputconst(lltype.Void, name),
+                 v_value]
+        return hop.genop('jit_marker', vlist,
+                         resulttype=lltype.Void)
+
 # ____________________________________________________________
 # User interface for the hotpath JIT policy
 
@@ -168,10 +200,7 @@
 
     jit_merge_point = _JitHintClassMethod("jit_merge_point")
     can_enter_jit = _JitHintClassMethod("can_enter_jit")
-
-    def getcurrentthreshold():
-        return 10
-    getcurrentthreshold = staticmethod(getcurrentthreshold)
+    set_param = _JitHintClassMethod("set_param")
 
     def compute_invariants(self, *greens):
         """This can compute a value or tuple that is passed as a green

Modified: pypy/branch/jit-hotpath/pypy/rlib/test/test_jit.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/rlib/test/test_jit.py	(original)
+++ pypy/branch/jit-hotpath/pypy/rlib/test/test_jit.py	Sat Mar 29 14:12:33 2008
@@ -1,5 +1,5 @@
 import py
-from pypy.rlib.jit import hint, _is_early_constant
+from pypy.rlib.jit import hint, _is_early_constant, JitDriver
 from pypy.translator.translator import TranslationContext, graphof
 from pypy.rpython.test.tool import BaseRtypingTest, LLRtypeMixin, OORtypeMixin
 
@@ -27,5 +27,12 @@
         res = self.interpret(g, [])
         assert res == 42
 
+    def test_set_param(self):
+        class MyJitDriver(JitDriver):
+            greens = reds = []
+        def f(x):
+            MyJitDriver.set_param(foo=x)
 
-
+        assert f(4) is None
+        res = self.interpret(f, [4])
+        assert res is None



More information about the Pypy-commit mailing list