[pypy-svn] r52956 - in pypy/branch/jit-hotpath/pypy: jit/rainbow jit/rainbow/test module/pypyjit

arigo at codespeak.net arigo at codespeak.net
Wed Mar 26 12:14:11 CET 2008


Author: arigo
Date: Wed Mar 26 12:14:09 2008
New Revision: 52956

Modified:
   pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py
   pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py
   pypy/branch/jit-hotpath/pypy/module/pypyjit/interp_jit.py
Log:
Experimental: remove the allocation of the greenkey needed to enter
machine code.

This uses a fixed-size hash table that handles collisions in a
reasonable way, but it is probably slightly too clever for its own good
(and not well tested).  Checking in anyway, perhaps as an intermediate
step towards a clever but genuinely good approach (cuckoo hashing?).



Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py	Wed Mar 26 12:14:09 2008
@@ -5,14 +5,12 @@
 from pypy.rpython.annlowlevel import llhelper
 from pypy.rpython.lltypesystem import lltype, lloperation
 from pypy.rpython.llinterp import LLInterpreter
-from pypy.rlib.objectmodel import we_are_translated
-from pypy.rlib.rarithmetic import intmask
+from pypy.rlib.objectmodel import we_are_translated, UnboxedValue
+from pypy.rlib.rarithmetic import r_uint
 from pypy.rlib.unroll import unrolling_iterable
 from pypy.jit.codegen.i386.rgenop import cast_whatever_to_int
 from pypy.jit.hintannotator.model import originalconcretetype
 from pypy.jit.timeshifter import rvalue
-from pypy.jit.timeshifter.greenkey import KeyDesc, empty_key
-from pypy.jit.timeshifter.greenkey import GreenKey, newgreendict
 from pypy.jit.timeshifter.oop import maybe_on_top_of_llinterp
 from pypy.jit.rainbow import rhotpath, fallback
 from pypy.jit.rainbow.portal import getjitenterargdesc
@@ -83,20 +81,7 @@
 
         def maybe_enter_jit(*args):
             greenargs = args[:num_green_args]
-            argshash = state.getkey(*greenargs)
-            counter = state.counters.get(argshash, 0)
-            if counter >= 0:
-                counter += 1
-                if counter < self.jitdrivercls.getcurrentthreshold():
-                    if self.verbose_level >= 3:
-                        interpreter.debug_trace("jit_not_entered", *args)
-                    state.counters[argshash] = counter
-                    return
-                interpreter.debug_trace("jit_compile", *args[:num_green_args])
-                mc = state.compile(argshash, *greenargs)
-            else:
-                greenkey = state.getgreenkey(*greenargs)
-                mc = state.machine_codes.get(greenkey, state.NULL_MC)
+            mc = state.maybe_compile(*greenargs)
             if not mc:
                 return
             if self.verbose_level >= 2:
@@ -104,9 +89,8 @@
             run = maybe_on_top_of_llinterp(exceptiondesc, mc)
             residualargs = state.make_residualargs(*args[num_green_args:])
             run(*residualargs)
-
-        HotEnterState.compile.im_func._dont_inline_ = True
         maybe_enter_jit._always_inline_ = True
+
         self.maybe_enter_jit_fn = maybe_enter_jit
 
     def make_descs(self):
@@ -301,50 +285,111 @@
     # very minimal, just to make the first test pass
     green_args_spec = unrolling_iterable(hotrunnerdesc.green_args_spec)
     red_args_spec = unrolling_iterable(hotrunnerdesc.red_args_spec)
+    green_args_names = unrolling_iterable(
+        ['g%d' % i for i in range(len(hotrunnerdesc.green_args_spec))])
+    green_args_range = unrolling_iterable(
+        range(len(hotrunnerdesc.green_args_spec)))
     if hotrunnerdesc.green_args_spec:
-        keydesc = KeyDesc(hotrunnerdesc.RGenOp, *hotrunnerdesc.green_args_spec)
+        HASH_TABLE_SIZE = 2 ** 14
     else:
-        keydesc = None
+        HASH_TABLE_SIZE = 1
+
+    class StateCell(object):
+        __slots__ = []
+
+    class Counter(StateCell, UnboxedValue):
+        __slots__ = 'counter'
+
+    class MachineCodeEntryPoint(StateCell):
+        def __init__(self, mc, *greenargs):
+            self.mc = mc
+            self.next = Counter(0)
+            i = 0
+            for name in green_args_names:
+                setattr(self, name, greenargs[i])
+                i += 1
+        def equalkey(self, *greenargs):
+            i = 0
+            for name in green_args_names:
+                if getattr(self, name) != greenargs[i]:
+                    return False
+                i += 1
+            return True
 
     class HotEnterState:
         NULL_MC = lltype.nullptr(hotrunnerdesc.RESIDUAL_FUNCTYPE)
 
         def __init__(self):
-            self.machine_codes = newgreendict()
-            self.counters = {}     # value of -1 means "compiled"
+            self.cells = [Counter(0)] * HASH_TABLE_SIZE
 
-            # Only use the hash of the arguments as the key.
+            # Only use the hash of the arguments as the profiling key.
             # Indeed, this is all a heuristic, so if things are designed
             # correctly, the occasional mistake due to hash collision is
             # not too bad.
-            
-            # Another idea would be to replace the 'counters' with some
-            # hand-written fixed-sized hash table.  The fixed-size-ness would
-            # also let old recorded counters gradually disappear as they get
-            # replaced by more recent ones.
 
-        def getkey(self, *greenargs):
-            result = 0x345678
+        def maybe_compile(self, *greenargs):
+            argshash = self.getkeyhash(*greenargs)
+            argshash &= (HASH_TABLE_SIZE - 1)
+            cell = self.cells[argshash]
+            if isinstance(cell, Counter):
+                # update the profiling counter
+                interp = hotrunnerdesc.interpreter
+                n = cell.counter + 1
+                if n < hotrunnerdesc.jitdrivercls.getcurrentthreshold():
+                    if hotrunnerdesc.verbose_level >= 3:
+                        interp.debug_trace("jit_not_entered", *greenargs)
+                    self.cells[argshash] = Counter(n)
+                    return self.NULL_MC
+                interp.debug_trace("jit_compile", *greenargs)
+                return self.compile(argshash, *greenargs)
+            else:
+                # machine code was already compiled for these greenargs
+                # (or we have a hash collision)
+                assert isinstance(cell, MachineCodeEntryPoint)
+                if cell.equalkey(*greenargs):
+                    return cell.mc
+                else:
+                    return self.handle_hash_collision(cell, argshash,
+                                                      *greenargs)
+        maybe_compile._dont_inline_ = True
+
+        def handle_hash_collision(self, cell, argshash, *greenargs):
+            next = cell.next
+            while not isinstance(next, Counter):
+                assert isinstance(next, MachineCodeEntryPoint)
+                if next.equalkey(*greenargs):
+                    # found, move to the front of the linked list
+                    cell.next = next.next
+                    next.next = self.cells[argshash]
+                    self.cells[argshash] = next
+                    return next.mc
+                cell = next
+                next = cell.next
+            # not found at all, do profiling
+            interp = hotrunnerdesc.interpreter
+            n = next.counter + 1
+            if n < hotrunnerdesc.jitdrivercls.getcurrentthreshold():
+                if hotrunnerdesc.verbose_level >= 3:
+                    interp.debug_trace("jit_not_entered", *greenargs)
+                cell.next = Counter(n)
+                return self.NULL_MC
+            interp.debug_trace("jit_compile", *greenargs)
+            return self.compile(argshash, *greenargs)
+        handle_hash_collision._dont_inline_ = True
+
+        def getkeyhash(self, *greenargs):
+            result = r_uint(0x345678)
             i = 0
-            mult = 1000003
+            mult = r_uint(1000003)
             for TYPE in green_args_spec:
+                if i > 0:
+                    result = result * mult
+                    mult = mult + 82520 + 2*len(greenargs)
                 item = greenargs[i]
-                result = intmask((result ^ cast_whatever_to_int(TYPE, item)) *
-                                 intmask(mult))
-                mult = mult + 82520 + 2*len(greenargs)
+                result = result ^ cast_whatever_to_int(TYPE, item)
                 i += 1
             return result
-
-        def getgreenkey(self, *greenvalues):
-            if keydesc is None:
-                return empty_key
-            rgenop = hotrunnerdesc.interpreter.rgenop
-            lst_gv = [None] * len(greenvalues)
-            i = 0
-            for _ in green_args_spec:
-                lst_gv[i] = rgenop.genconst(greenvalues[i])
-                i += 1
-            return GreenKey(lst_gv, keydesc)
+        getkeyhash._always_inline_ = True
 
         def compile(self, argshash, *greenargs):
             try:
@@ -370,9 +415,9 @@
                 red_i += make_arg_redbox.consumes
             redargs = list(redargs)
 
-            greenkey = self.getgreenkey(*greenargs)
-            greenargs = list(greenkey.values)
-            rhotpath.setup_jitstate(interp, jitstate, greenargs, redargs,
+            greenargs_gv = [rgenop.genconst(greenargs[i])
+                            for i in green_args_range]
+            rhotpath.setup_jitstate(interp, jitstate, greenargs_gv, redargs,
                                     hotrunnerdesc.entryjitcode,
                                     hotrunnerdesc.sigtoken)
             builder.start_writing()
@@ -381,8 +426,19 @@
 
             FUNCPTR = lltype.Ptr(hotrunnerdesc.RESIDUAL_FUNCTYPE)
             generated = gv_generated.revealconst(FUNCPTR)
-            self.machine_codes[greenkey] = generated
-            self.counters[argshash] = -1     # compiled
+
+            newcell = MachineCodeEntryPoint(generated, *greenargs)
+            cell = self.cells[argshash]
+            if not isinstance(cell, Counter):
+                while True:
+                    assert isinstance(cell, MachineCodeEntryPoint)
+                    next = cell.next
+                    if isinstance(next, Counter):
+                        cell.next = Counter(0)
+                        break
+                    cell = next
+                newcell.next = self.cells[argshash]
+            self.cells[argshash] = newcell
 
             if not we_are_translated():
                 hotrunnerdesc.residual_graph = generated._obj.graph  #for tests
@@ -395,5 +451,6 @@
                 residualargs = residualargs + collect_residual_args(redargs[i])
                 i += 1
             return residualargs
+        make_residualargs._always_inline_ = True
 
     return HotEnterState

Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py	Wed Mar 26 12:14:09 2008
@@ -177,13 +177,13 @@
         self.check_traces([
             # running non-JITted leaves the initial profiling traces
             # recorded by jit_may_enter().  We see the values of n1 and total.
-            "jit_not_entered 19 20",
-            "jit_not_entered 18 39",
-            "jit_not_entered 17 57",
-            "jit_not_entered 16 74",
-            "jit_not_entered 15 90",
-            "jit_not_entered 14 105",
-            "jit_not_entered 13 119",
+            "jit_not_entered",  # 19 20
+            "jit_not_entered",  # 18 39
+            "jit_not_entered",  # 17 57
+            "jit_not_entered",  # 16 74
+            "jit_not_entered",  # 15 90
+            "jit_not_entered",  # 14 105
+            "jit_not_entered",  # 13 119
             # on the start of the next iteration, compile the 'total += n1'
             "jit_compile",
             "pause at hotsplit in ll_function",
@@ -278,8 +278,8 @@
         assert res == main(1, 10)
         self.check_traces([
             # start compiling the 3rd time we loop back
-                "jit_not_entered * struct rpy_string {...} 5 9 10 10",
-                "jit_not_entered * struct rpy_string {...} 5 8 90 10",
+                "jit_not_entered * struct rpy_string {...} 5",  # 9 10 10
+                "jit_not_entered * struct rpy_string {...} 5",  # 8 90 10
                 "jit_compile * struct rpy_string {...} 5",
             # stop compiling at the red split ending an extra iteration
                 "pause at hotsplit in ll_function",
@@ -331,8 +331,9 @@
 
         res = self.run(ll_function, [3], threshold=3, small=True)
         assert res == (3*4)/2
-        self.check_traces(['jit_not_entered 2 3',
-                           'jit_not_entered 1 5'])
+        self.check_traces(['jit_not_entered',  # 2 3
+                           'jit_not_entered',  # 1 5
+                           ])
 
         res = self.run(ll_function, [50], threshold=3, small=True)
         assert res == (50*51)/2
@@ -409,8 +410,8 @@
         res = self.run(main, [1, 71], threshold=3)
         assert res == 5041
         self.check_traces([
-            "jit_not_entered * stru...} 10 70 * array [ 70, 71, 71 ]",
-            "jit_not_entered * stru...} 10 69 * array [ 69, 71, 142 ]",
+            "jit_not_entered * stru...} 10",  # 70 * array [ 70, 71, 71 ]
+            "jit_not_entered * stru...} 10",  # 69 * array [ 69, 71, 142 ]
             "jit_compile * stru...} 10",
         # we first see the promotion of len(regs) in on_enter_jit()
             "pause at promote in TLRJitDriver.on_enter_jit_Hv",

Modified: pypy/branch/jit-hotpath/pypy/module/pypyjit/interp_jit.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/module/pypyjit/interp_jit.py	(original)
+++ pypy/branch/jit-hotpath/pypy/module/pypyjit/interp_jit.py	Wed Mar 26 12:14:09 2008
@@ -114,21 +114,25 @@
 #
 # Public interface
 
+MAX_THRESHOLD = sys.maxint // 2
+
 class PyPyJITConfig:
     def __init__(self):
-        self.cur_threshold = sys.maxint    # disabled until the space is ready
+        self.cur_threshold = MAX_THRESHOLD  # disabled until the space is ready
         self.configured_threshold = JitDriver.getcurrentthreshold()
 
     def isenabled(self):
-        return self.cur_threshold < sys.maxint
+        return self.cur_threshold < MAX_THRESHOLD
 
     def enable(self):
         self.cur_threshold = self.configured_threshold
 
     def disable(self):
-        self.cur_threshold = sys.maxint
+        self.cur_threshold = MAX_THRESHOLD
 
     def setthreshold(self, threshold):
+        if threshold >= MAX_THRESHOLD:
+            threshold = MAX_THRESHOLD - 1
         self.configured_threshold = threshold
         if self.isenabled():
             self.cur_threshold = threshold



More information about the Pypy-commit mailing list