[pypy-commit] pypy jit-counter: in-progress

arigo noreply at buildbot.pypy.org
Wed Oct 30 18:02:07 CET 2013


Author: Armin Rigo <arigo at tunes.org>
Branch: jit-counter
Changeset: r67742:cebfa3182b4d
Date: 2013-10-30 18:01 +0100
http://bitbucket.org/pypy/pypy/changeset/cebfa3182b4d/

Log:	in-progress

diff --git a/rpython/jit/metainterp/counter.py b/rpython/jit/metainterp/counter.py
--- a/rpython/jit/metainterp/counter.py
+++ b/rpython/jit/metainterp/counter.py
@@ -1,14 +1,23 @@
-from rpython.rlib.rarithmetic import r_singlefloat
+from rpython.rlib.rarithmetic import r_singlefloat, intmask
 from rpython.rtyper.lltypesystem import lltype, rffi
 from rpython.translator.tool.cbuild import ExternalCompilationInfo
 
 
+r_uint32 = rffi.r_uint
+assert r_uint32.BITS == 32
+UINT32MAX = 2 ** 32 - 1
+
+
 class JitCounter:
     DEFAULT_SIZE = 4096
 
     def __init__(self, size=DEFAULT_SIZE):
-        assert size >= 1 and (size & (size - 1)) == 0     # a power of two
-        self.mask = size - 1
+        "NOT_RPYTHON"
+        self.size = size
+        self.shift = 1
+        while (UINT32MAX >> self.shift) != size - 1:
+            self.shift += 1
+            assert self.shift < 999, "size is not a power of two <= 2**31"
         self.timetable = lltype.malloc(rffi.CArray(rffi.FLOAT), size,
                                        flavor='raw', zero=True,
                                        track_allocation=False)
@@ -22,30 +31,33 @@
             threshold = 2
         return 1.0 / threshold   # the number is at most 0.5
 
-    def tick(self, hash, increment):
-        hash &= self.mask
-        counter = float(self.timetable[hash]) + increment
+    def get_index(self, hash):
+        """Return the index (< self.size) from a hash value.  This keeps
+        the *high* bits of hash!  Be sure that hash is computed correctly."""
+        return intmask(r_uint32(hash) >> self.shift)
+    get_index._always_inline_ = True
+
+    def tick(self, index, increment):
+        counter = float(self.timetable[index]) + increment
         if counter < 1.0:
-            self.timetable[hash] = r_singlefloat(counter)
+            self.timetable[index] = r_singlefloat(counter)
             return False
         else:
             return True
     tick._always_inline_ = True
 
-    def reset(self, hash):
-        hash &= self.mask
-        self.timetable[hash] = r_singlefloat(0.0)
+    def reset(self, index):
+        self.timetable[index] = r_singlefloat(0.0)
 
-    def lookup_chain(self, hash):
-        hash &= self.mask
-        return self.celltable[hash]
+    def lookup_chain(self, index):
+        return self.celltable[index]
 
-    def cleanup_chain(self, hash):
-        self.install_new_cell(hash, None)
+    def cleanup_chain(self, index):
+        self.reset(index)
+        self.install_new_cell(index, None)
 
-    def install_new_cell(self, hash, newcell):
-        hash &= self.mask
-        cell = self.celltable[hash]
+    def install_new_cell(self, index, newcell):
+        cell = self.celltable[index]
         keep = newcell
         while cell is not None:
             remove_me = cell.should_remove_jitcell()
@@ -54,7 +66,7 @@
                 cell.next = keep
                 keep = cell
             cell = nextcell
-        self.celltable[hash] = keep
+        self.celltable[index] = keep
 
     def set_decay(self, decay):
         """Set the decay, from 0 (none) to 1000 (max)."""
@@ -75,7 +87,7 @@
         # important in corner cases where we would suddenly compile more
         # than one loop because all counters reach the bound at the same
         # time, but where compiling all but the first one is pointless.
-        size = self.mask + 1
+        size = self.size
         pypy__decay_jit_counters(self.timetable, self.decay_by_mult, size)
 
 
diff --git a/rpython/jit/metainterp/test/test_counter.py b/rpython/jit/metainterp/test/test_counter.py
--- a/rpython/jit/metainterp/test/test_counter.py
+++ b/rpython/jit/metainterp/test/test_counter.py
@@ -1,21 +1,28 @@
 from rpython.jit.metainterp.counter import JitCounter
 
 
+def test_get_index():
+    jc = JitCounter(size=128)    # 7 bits
+    for i in range(10):
+        hash = 400000001 * i
+        index = jc.get_index(hash)
+        assert index == (hash >> (32 - 7))
+
 def test_tick():
     jc = JitCounter()
     incr = jc.compute_threshold(4)
     for i in range(5):
-        r = jc.tick(1234567, incr)
+        r = jc.tick(104, incr)
         assert r is (i >= 3)
     for i in range(5):
-        r = jc.tick(1234568, incr)
-        s = jc.tick(1234569, incr)
+        r = jc.tick(108, incr)
+        s = jc.tick(109, incr)
         assert r is (i >= 3)
         assert s is (i >= 3)
-    jc.reset(1234568)
+    jc.reset(108)
     for i in range(5):
-        r = jc.tick(1234568, incr)
-        s = jc.tick(1234569, incr)
+        r = jc.tick(108, incr)
+        s = jc.tick(109, incr)
         assert r is (i >= 3)
         assert s is True
 
@@ -30,21 +37,21 @@
             return False
     #
     jc = JitCounter()
-    assert jc.lookup_chain(1234567) is None
-    d1 = Dead()
-    jc.install_new_cell(1234567, d1)
-    assert jc.lookup_chain(1234567) is d1
+    assert jc.lookup_chain(104) is None
+    d1 = Dead() 
+    jc.install_new_cell(104, d1)
+    assert jc.lookup_chain(104) is d1
     d2 = Dead()
-    jc.install_new_cell(1234567, d2)
-    assert jc.lookup_chain(1234567) is d2
+    jc.install_new_cell(104, d2)
+    assert jc.lookup_chain(104) is d2
     assert d2.next is None
     #
     d3 = Alive()
-    jc.install_new_cell(1234567, d3)
-    assert jc.lookup_chain(1234567) is d3
+    jc.install_new_cell(104, d3)
+    assert jc.lookup_chain(104) is d3
     assert d3.next is None
     d4 = Alive()
-    jc.install_new_cell(1234567, d4)
-    assert jc.lookup_chain(1234567) is d3
+    jc.install_new_cell(104, d4)
+    assert jc.lookup_chain(104) is d3
     assert d3.next is d4
     assert d4.next is None
diff --git a/rpython/jit/metainterp/warmspot.py b/rpython/jit/metainterp/warmspot.py
--- a/rpython/jit/metainterp/warmspot.py
+++ b/rpython/jit/metainterp/warmspot.py
@@ -522,11 +522,6 @@
         #
         annhelper = MixLevelHelperAnnotator(self.translator.rtyper)
         for jd in self.jitdrivers_sd:
-            jd._set_jitcell_at_ptr = self._make_hook_graph(jd,
-                annhelper, jd.jitdriver.set_jitcell_at, annmodel.s_None,
-                s_BaseJitCell_not_None)
-            jd._get_jitcell_at_ptr = self._make_hook_graph(jd,
-                annhelper, jd.jitdriver.get_jitcell_at, s_BaseJitCell_or_None)
             jd._get_printable_location_ptr = self._make_hook_graph(jd,
                 annhelper, jd.jitdriver.get_printable_location, s_Str)
             jd._confirm_enter_jit_ptr = self._make_hook_graph(jd,
diff --git a/rpython/jit/metainterp/warmstate.py b/rpython/jit/metainterp/warmstate.py
--- a/rpython/jit/metainterp/warmstate.py
+++ b/rpython/jit/metainterp/warmstate.py
@@ -4,7 +4,7 @@
 from rpython.jit.codewriter import support, heaptracker, longlong
 from rpython.jit.metainterp import history
 from rpython.rlib.debug import debug_start, debug_stop, debug_print
-from rpython.rlib.jit import PARAMETERS, BaseJitCell
+from rpython.rlib.jit import PARAMETERS
 from rpython.rlib.nonconst import NonConstant
 from rpython.rlib.objectmodel import specialize, we_are_translated, r_dict
 from rpython.rlib.rarithmetic import intmask
@@ -124,7 +124,7 @@
         return rffi.cast(lltype.Signed, x)
 
 
-class JitCell(BaseJitCell):
+class BaseJitCell(object):
     tracing = False
     dont_trace_here = chr(0)
     wref_procedure_token = None
@@ -267,7 +267,7 @@
         vinfo = jitdriver_sd.virtualizable_info
         index_of_virtualizable = jitdriver_sd.index_of_virtualizable
         num_green_args = jitdriver_sd.num_green_args
-        get_jitcell = self.make_jitcell_getter()
+        JitCell = self.make_jitcell_subclass()
         self.make_jitdriver_callbacks()
         confirm_enter_jit = self.confirm_enter_jit
         range_red_args = unrolling_iterable(
@@ -310,63 +310,63 @@
             #
             assert 0, "should have raised"
 
-        def bound_reached(cell, *args):
-            jitcounter.reset(
-            cell.counter = 0
+        def bound_reached(index, *args):
+            jitcounter.reset(index)
             if not confirm_enter_jit(*args):
                 return
             # start tracing
             from rpython.jit.metainterp.pyjitpl import MetaInterp
             metainterp = MetaInterp(metainterp_sd, jitdriver_sd)
-            cell.tracing = True
-            cell.reset_counter()
+            greenargs = args[:num_green_args]
+            newcell = JitCell(*greenargs)
+            newcell.tracing = True
+            jitcounter.install_new_cell(index, newcell)
             try:
                 metainterp.compile_and_run_once(jitdriver_sd, *args)
             finally:
-                cell.tracing = False
-                cell.reset_counter()
+                newcell.tracing = False
 
-        def maybe_compile_and_run(threshold, *args):
+        def maybe_compile_and_run(increment_threshold, *args):
             """Entry point to the JIT.  Called at the point with the
             can_enter_jit() hint.
             """
-            # look for the cell corresponding to the current greenargs
+            # Look for the cell corresponding to the current greenargs.
+            # Search for the JitCell that is of the correct subclass of
+            # BaseJitCell, and that stores a key that compares equal
             greenargs = args[:num_green_args]
-            cell = get_jitcell(True, *greenargs)
-            mode = cell.mode
+            index = JitCell.get_index(*greenargs)
+            cell = jitcounter.lookup_chain(index)
+            while cell is not None:
+                if isinstance(cell, JitCell) and cell.comparekey(*greenargs):
+                    break    # found
+            else:
+                # not found. increment the counter
+                if jitcounter.tick(index, increment_threshold):
+                    bound_reached(index, *args)
+                return
 
-            if mode == MODE_COUNTING:
-                # update the profiling counter
-                n = cell.counter + threshold
-                if n <= self.THRESHOLD_LIMIT:       # bound not reached
-                    cell.counter = n
-                    return
-                else:
-                    bound_reached(cell, *args)
-                    return
-
-            else:
-                if mode != MODE_HAVE_PROC:
-                    assert mode == MODE_TRACING
-                    # tracing already happening in some outer invocation of
-                    # this function. don't trace a second time.
-                    return
-                if not confirm_enter_jit(*args):
-                    return
-                # machine code was already compiled for these greenargs
-                procedure_token = cell.get_procedure_token()
-                if procedure_token is None:   # it was a weakref that has been freed
-                    cell.counter = 0
-                    cell.mode = MODE_COUNTING
-                    return
-                # extract and unspecialize the red arguments to pass to
-                # the assembler
-                execute_args = ()
-                for i in range_red_args:
-                    execute_args += (unspecialize_value(args[i]), )
-                # run it!  this executes until interrupted by an exception
-                execute_assembler(procedure_token, *execute_args)
+            # Here, we have found 'cell'.
             #
+            if cell.tracing:
+                # tracing already happening in some outer invocation of
+                # this function. don't trace a second time.
+                return
+            # machine code was already compiled for these greenargs
+            procedure_token = cell.get_procedure_token()
+            if procedure_token is None:
+                # it was an aborted compilation, or maybe a weakref that
+                # has been freed
+                jitcounter.cleanup_chain(index)
+                return
+            if not confirm_enter_jit(*args):
+                return
+            # extract and unspecialize the red arguments to pass to
+            # the assembler
+            execute_args = ()
+            for i in range_red_args:
+                execute_args += (unspecialize_value(args[i]), )
+            # run it!  this executes until interrupted by an exception
+            execute_assembler(procedure_token, *execute_args)
             assert 0, "should not reach this point"
 
         maybe_compile_and_run._dont_inline_ = True
@@ -401,144 +401,45 @@
 
     # ----------
 
-    def make_jitcell_getter(self):
+    def make_jitcell_subclass(self):
         "NOT_RPYTHON"
-        if hasattr(self, 'jit_getter'):
-            return self.jit_getter
+        if hasattr(self, 'JitCell'):
+            return self.JitCell
         #
-        if self.jitdriver_sd._get_jitcell_at_ptr is None:
-            jit_getter = self._make_jitcell_getter_default()
-        else:
-            jit_getter = self._make_jitcell_getter_custom()
+        jitcounter = self.warmrunnerdesc.jitcounter
+        jitdriver_sd = self.jitdriver_sd
+        green_args_spec = unrolling_iterable([('g%d' % i, TYPE)
+                     for i, TYPE in enumerate(jitdriver_sd._green_args_spec)])
         #
-        unwrap_greenkey = self.make_unwrap_greenkey()
+        class JitCell(BaseJitCell):
+            def __init__(self, *greenargs):
+                i = 0
+                for attrname, _ in green_args_spec:
+                    setattr(self, attrname, greenargs[i])
+                    i = i + 1
+
+            def comparekey(self, *greenargs2):
+                i = 0
+                for attrname, TYPE in green_args_spec:
+                    item1 = getattr(self, attrname)
+                    if not equal_whatever(TYPE, item1, greenargs2[i]):
+                        return False
+                    i = i + 1
+                return True
+
+            @staticmethod
+            def get_index(*greenargs):
+                x = 0
+                i = 0
+                for TYPE in green_args_spec:
+                    item = greenargs[i]
+                    y = hash_whatever(TYPE, item)
+                    x = intmask((x ^ y) * 1405695061)  # prime number, 2**30~31
+                    i = i + 1
+                return jitcounter.get_index(x)
         #
-        def jit_cell_at_key(greenkey):
-            greenargs = unwrap_greenkey(greenkey)
-            return jit_getter(True, *greenargs)
-        self.jit_cell_at_key = jit_cell_at_key
-        self.jit_getter = jit_getter
-        #
-        return jit_getter
-
-    def _make_jitcell_getter_default(self):
-        "NOT_RPYTHON"
-        jitdriver_sd = self.jitdriver_sd
-        green_args_spec = unrolling_iterable(jitdriver_sd._green_args_spec)
-        #
-        def comparekey(greenargs1, greenargs2):
-            i = 0
-            for TYPE in green_args_spec:
-                if not equal_whatever(TYPE, greenargs1[i], greenargs2[i]):
-                    return False
-                i = i + 1
-            return True
-        #
-        def hashkey(greenargs):
-            x = 0x345678
-            i = 0
-            for TYPE in green_args_spec:
-                item = greenargs[i]
-                y = hash_whatever(TYPE, item)
-                x = intmask((1000003 * x) ^ y)
-                i = i + 1
-            return x
-        #
-        jitcell_dict = r_dict(comparekey, hashkey)
-        try:
-            self.warmrunnerdesc.stats.jitcell_dicts.append(jitcell_dict)
-        except AttributeError:
-            pass
-        #
-        def _cleanup_dict():
-            minimum = self.THRESHOLD_LIMIT // 20     # minimum 5%
-            killme = []
-            for key, cell in jitcell_dict.iteritems():
-                if cell.mode == MODE_COUNTING:
-                    cell.counter = int(cell.counter * 0.92)
-                    if cell.counter < minimum:
-                        killme.append(key)
-                elif (cell.mode == MODE_HAVE_PROC
-                      and cell.get_procedure_token() is None):
-                    killme.append(key)
-            for key in killme:
-                del jitcell_dict[key]
-        #
-        def _maybe_cleanup_dict():
-            # Once in a while, rarely, when too many entries have
-            # been put in the jitdict_dict, we do a cleanup phase:
-            # we decay all counters and kill entries with a too
-            # low counter.
-            self._trigger_automatic_cleanup += 1
-            if self._trigger_automatic_cleanup > 20000:
-                self._trigger_automatic_cleanup = 0
-                _cleanup_dict()
-        #
-        self._trigger_automatic_cleanup = 0
-        self._jitcell_dict = jitcell_dict       # for tests
-        #
-        def get_jitcell(build, *greenargs):
-            try:
-                cell = jitcell_dict[greenargs]
-            except KeyError:
-                if not build:
-                    return None
-                _maybe_cleanup_dict()
-                cell = JitCell()
-                jitcell_dict[greenargs] = cell
-            return cell
-        return get_jitcell
-
-    def _make_jitcell_getter_custom(self):
-        "NOT_RPYTHON"
-        rtyper = self.warmrunnerdesc.rtyper
-        get_jitcell_at_ptr = self.jitdriver_sd._get_jitcell_at_ptr
-        set_jitcell_at_ptr = self.jitdriver_sd._set_jitcell_at_ptr
-        lltohlhack = {}
-        # note that there is no equivalent of _maybe_cleanup_dict()
-        # in the case of custom getters.  We assume that the interpreter
-        # stores the JitCells on some objects that can go away by GC,
-        # like the PyCode objects in PyPy.
-        #
-        def get_jitcell(build, *greenargs):
-            fn = support.maybe_on_top_of_llinterp(rtyper, get_jitcell_at_ptr)
-            cellref = fn(*greenargs)
-            # <hacks>
-            if we_are_translated():
-                BASEJITCELL = lltype.typeOf(cellref)
-                cell = cast_base_ptr_to_instance(JitCell, cellref)
-            else:
-                if isinstance(cellref, (BaseJitCell, type(None))):
-                    BASEJITCELL = None
-                    cell = cellref
-                else:
-                    BASEJITCELL = lltype.typeOf(cellref)
-                    if cellref:
-                        cell = lltohlhack[rtyper.type_system.deref(cellref)]
-                    else:
-                        cell = None
-            if not build:
-                return cell
-            if cell is None:
-                cell = JitCell()
-                # <hacks>
-                if we_are_translated():
-                    cellref = cast_object_to_ptr(BASEJITCELL, cell)
-                else:
-                    if BASEJITCELL is None:
-                        cellref = cell
-                    else:
-                        if isinstance(BASEJITCELL, lltype.Ptr):
-                            cellref = lltype.malloc(BASEJITCELL.TO)
-                        else:
-                            assert False, "no clue"
-                        lltohlhack[rtyper.type_system.deref(cellref)] = cell
-                # </hacks>
-                fn = support.maybe_on_top_of_llinterp(rtyper,
-                                                      set_jitcell_at_ptr)
-                fn(cellref, *greenargs)
-            return cell
-        return get_jitcell
+        self.JitCell = JitCell
+        return JitCell
 
     # ----------
 
diff --git a/rpython/rlib/jit.py b/rpython/rlib/jit.py
--- a/rpython/rlib/jit.py
+++ b/rpython/rlib/jit.py
@@ -515,8 +515,8 @@
                                   if '.' not in name])
         self._heuristic_order = {}   # check if 'reds' and 'greens' are ordered
         self._make_extregistryentries()
-        self.get_jitcell_at = get_jitcell_at
-        self.set_jitcell_at = set_jitcell_at
+        assert get_jitcell_at is None, "get_jitcell_at no longer used"
+        assert set_jitcell_at is None, "set_jitcell_at no longer used"
         self.get_printable_location = get_printable_location
         self.confirm_enter_jit = confirm_enter_jit
         self.can_never_inline = can_never_inline
@@ -696,9 +696,6 @@
 #
 # Annotation and rtyping of some of the JitDriver methods
 
-class BaseJitCell(object):
-    __slots__ = ()
-
 
 class ExtEnterLeaveMarker(ExtRegistryEntry):
     # Replace a call to myjitdriver.jit_merge_point(**livevars)
@@ -746,10 +743,7 @@
 
     def annotate_hooks(self, **kwds_s):
         driver = self.instance.im_self
-        s_jitcell = self.bookkeeper.valueoftype(BaseJitCell)
         h = self.annotate_hook
-        h(driver.get_jitcell_at, driver.greens, **kwds_s)
-        h(driver.set_jitcell_at, driver.greens, [s_jitcell], **kwds_s)
         h(driver.get_printable_location, driver.greens, **kwds_s)
 
     def annotate_hook(self, func, variables, args_s=[], **kwds_s):


More information about the pypy-commit mailing list