[pypy-commit] pypy optinfo-into-bridges: propagate class information into bridges by storing extra bits into the resume

cfbolz pypy.commits at gmail.com
Wed Oct 12 11:47:41 EDT 2016


Author: Carl Friedrich Bolz <cfbolz at gmx.de>
Branch: optinfo-into-bridges
Changeset: r87729:64035cb4e678
Date: 2016-10-12 17:07 +0200
http://bitbucket.org/pypy/pypy/changeset/64035cb4e678/

Log:	propagate class information into bridges by storing extra bits into
	the resume bytecode

	(approach is open to discussion)

diff --git a/rpython/jit/metainterp/compile.py b/rpython/jit/metainterp/compile.py
--- a/rpython/jit/metainterp/compile.py
+++ b/rpython/jit/metainterp/compile.py
@@ -85,13 +85,14 @@
     """ This represents ops() with a jump at the end that goes to some
     loop, we need to deal with virtual state and inlining of short preamble
     """
-    def __init__(self, trace, runtime_boxes, call_pure_results=None,
+    def __init__(self, trace, runtime_boxes, key, call_pure_results=None,
                  enable_opts=None, inline_short_preamble=False):
         self.trace = trace
         self.runtime_boxes = runtime_boxes
         self.call_pure_results = call_pure_results
         self.enable_opts = enable_opts
         self.inline_short_preamble = inline_short_preamble
+        self.resumestorage = key  # guard descr; optimize() passes its rd_numb to optimize_bridge
 
     def optimize(self, metainterp_sd, jitdriver_sd, optimizations, unroll):
         from rpython.jit.metainterp.optimizeopt.unroll import UnrollOptimizer
@@ -100,7 +101,8 @@
         return opt.optimize_bridge(self.trace, self.runtime_boxes,
                                    self.call_pure_results,
                                    self.inline_short_preamble,
-                                   self.box_names_memo)
+                                   self.box_names_memo,
+                                   self.resumestorage.rd_numb)
 
 class UnrolledLoopData(CompileData):
     """ This represents label() ops jump with extra info that's from the
@@ -1068,7 +1070,11 @@
     call_pure_results = metainterp.call_pure_results
 
     if metainterp.history.ends_with_jump:
-        data = BridgeCompileData(trace, runtime_boxes,
+        if isinstance(resumekey, ResumeGuardCopiedDescr):
+            key = resumekey.prev
+        else:
+            key = resumekey
+        data = BridgeCompileData(trace, runtime_boxes, key,
                                  call_pure_results=call_pure_results,
                                  enable_opts=enable_opts,
                                  inline_short_preamble=inline_short_preamble)
diff --git a/rpython/jit/metainterp/opencoder.py b/rpython/jit/metainterp/opencoder.py
--- a/rpython/jit/metainterp/opencoder.py
+++ b/rpython/jit/metainterp/opencoder.py
@@ -62,7 +62,7 @@
         assert isinstance(snapshot, TopSnapshot)
         self.vable_array = snapshot.vable_array
         self.vref_array = snapshot.vref_array
-        self.size = len(self.vable_array) + len(self.vref_array) + 2
+        self.size = len(self.vable_array) + len(self.vref_array) + 3
         jc_index, pc = unpack_uint(snapshot.packed_jitcode_pc)
         self.framestack = []
         if jc_index == 2**16-1:
diff --git a/rpython/jit/metainterp/optimizeopt/bridgeopt.py b/rpython/jit/metainterp/optimizeopt/bridgeopt.py
new file mode 100644
--- /dev/null
+++ b/rpython/jit/metainterp/optimizeopt/bridgeopt.py
@@ -0,0 +1,59 @@
+""" Code to feed information from the optimizer via the resume code into the
+optimizer of the bridge attached to a guard. """
+
+from rpython.jit.metainterp.resumecode import numb_next_item, numb_next_n_items, unpack_numbering
+
+# XXX at the moment this is all quite ad-hoc. Could be delegated to the
+# different optimization passes
+
+# adds the following sections at the end of the resume code:
+#
+# ---- known classes
+# <bitfield> one bit per reference box in the liveboxes, in liveboxes order
+#            1 class known
+#            0 class unknown
+#            (the class itself is not stored; it is read from the runtime value)
+#            the bits are packed 7 per item, most significant bit first
+#
+# ----
+
+def serialize_optimizer_knowledge(optimizer, numb_state, liveboxes, memo): # memo is currently unused
+    numb_state.grow(len(liveboxes)) # over-allocates: at most ceil(#ref boxes / 7) items are appended
+    # class knowledge: one bit per ref box (1 = known class), 7 bits per item
+    bitfield = 0
+    shifts = 0
+    for box in liveboxes:
+        if box.type != "r":
+            continue
+        info = optimizer.getptrinfo(box)
+        known_class = info is not None and info.get_known_class(optimizer.cpu) is not None
+        bitfield <<= 1
+        bitfield |= known_class # bool, contributes 0 or 1
+        shifts += 1
+        if shifts == 7: # a full group of 7 bits: flush it as one item
+            numb_state.append_int(bitfield)
+            bitfield = shifts = 0
+    if shifts: # partial last group: left-align it so the reader's 0b1000000 mask lines up
+        numb_state.append_int(bitfield << (7 - shifts))
+
+def deserialize_optimizer_knowledge(optimizer, numb, runtime_boxes, liveboxes):
+    # skip the resume section; the optimizer sections start right after it
+    index = skip_resume_section(numb, optimizer)
+    # class knowledge: mirrors the layout written by serialize_optimizer_knowledge
+    bitfield = 0
+    mask = 0
+    for i, box in enumerate(liveboxes): # runtime_boxes must be index-aligned with liveboxes
+        if box.type != "r":
+            continue
+        if not mask: # current 7-bit group used up: read the next item
+            bitfield, index = numb_next_item(numb, index)
+            mask = 0b1000000
+        class_known = bitfield & mask
+        mask >>= 1
+        if class_known:
+            cls = optimizer.cpu.ts.cls_of_box(runtime_boxes[i]) # class is not stored; read it off the runtime value
+            optimizer.make_constant_class(box, cls)
+
+def skip_resume_section(numb, optimizer): # optimizer is unused; returns the index just past the resume section
+    startcount, index = numb_next_item(numb, 0) # header: item count of the resume section, header included
+    return numb_next_n_items(numb, startcount, 0) # restart at 0 so the header item is skipped as well
diff --git a/rpython/jit/metainterp/optimizeopt/optimizer.py b/rpython/jit/metainterp/optimizeopt/optimizer.py
--- a/rpython/jit/metainterp/optimizeopt/optimizer.py
+++ b/rpython/jit/metainterp/optimizeopt/optimizer.py
@@ -285,6 +285,7 @@
         self.optrewrite = None
         self.optearlyforce = None
         self.optunroll = None
+        self._really_emitted_operation = None
 
         self._last_guard_op = None
 
diff --git a/rpython/jit/metainterp/optimizeopt/unroll.py b/rpython/jit/metainterp/optimizeopt/unroll.py
--- a/rpython/jit/metainterp/optimizeopt/unroll.py
+++ b/rpython/jit/metainterp/optimizeopt/unroll.py
@@ -237,9 +237,13 @@
         return label_vs
 
     def optimize_bridge(self, trace, runtime_boxes, call_pure_results,
-                        inline_short_preamble, box_names_memo):
+                        inline_short_preamble, box_names_memo, numb):
+        from rpython.jit.metainterp.optimizeopt.bridgeopt import deserialize_optimizer_knowledge
         trace = trace.get_iter()
         self._check_no_forwarding([trace.inputargs])
+        deserialize_optimizer_knowledge(self.optimizer,
+                                        numb, runtime_boxes,
+                                        trace.inputargs)
         info, ops = self.optimizer.propagate_all_forward(trace,
             call_pure_results, False)
         jump_op = info.jump_op
diff --git a/rpython/jit/metainterp/resume.py b/rpython/jit/metainterp/resume.py
--- a/rpython/jit/metainterp/resume.py
+++ b/rpython/jit/metainterp/resume.py
@@ -165,7 +165,8 @@
 class NumberingState(object):
     def __init__(self, size):
         self.liveboxes = {}
-        self.current = [rffi.cast(rffi.SHORT, 0)] * size
+        self.current = []
+        self.grow(size)
         self._pos = 0
         self.num_boxes = 0
         self.num_virtuals = 0
@@ -180,7 +181,16 @@
         return self.append_short(short)
 
     def create_numbering(self):
-        return resumecode.create_numbering(self.current)
+        return resumecode.create_numbering(self.current[:self._pos]) # drop unused zero slots left by grow()
+
+    def grow(self, size): # pre-allocate `size` more zero SHORT slots for later appends
+        self.current.extend([rffi.cast(rffi.SHORT, 0)] * size)
+
+    def patch_current_size(self, index): # overwrite the placeholder at `index` with the number of items written so far
+        item = self._pos
+        short = rffi.cast(rffi.SHORT, item)
+        assert rffi.cast(lltype.Signed, short) == item # the count must fit in a SHORT
+        self.current[index] = short # store the SHORT, like every other entry of self.current
 
 class ResumeDataLoopMemo(object):
 
@@ -268,6 +278,7 @@
     def number(self, optimizer, position, trace):
         snapshot_iter = trace.get_snapshot_iter(position)
         numb_state = NumberingState(snapshot_iter.size)
+        numb_state.append_int(-1) # patch later
 
         arr = snapshot_iter.vable_array
 
@@ -287,6 +298,7 @@
             numb_state.append_int(pc)
             self._number_boxes(
                     snapshot_iter, snapshot.box_array, optimizer, numb_state)
+        numb_state.patch_current_size(0)
 
         return numb_state
 
@@ -471,6 +483,7 @@
         self._number_virtuals(liveboxes, optimizer, num_virtuals)
         self._add_pending_fields(optimizer, pending_setfields)
 
+        self._add_optimizer_sections(numb_state, liveboxes)
         storage.rd_numb = numb_state.create_numbering()
         storage.rd_consts = self.memo.consts
         return liveboxes[:]
@@ -590,6 +603,11 @@
                 return self.liveboxes_from_env[box]
             return self.liveboxes[box]
 
+    def _add_optimizer_sections(self, numb_state, liveboxes):
+        # append the extra sections describing what the optimizer learned (format: bridgeopt.py)
+        from rpython.jit.metainterp.optimizeopt.bridgeopt import serialize_optimizer_knowledge # local import, presumably to avoid an import cycle
+        serialize_optimizer_knowledge(self.optimizer, numb_state, liveboxes, self.memo)
+
 class AbstractVirtualInfo(object):
     kind = REF
     is_about_raw = False
@@ -932,7 +950,11 @@
     def _init(self, cpu, storage):
         self.cpu = cpu
         self.numb = storage.rd_numb
-        self.cur_index = 0
+        count, self.cur_index = resumecode.numb_next_item(
+            self.numb, 0)
+        # XXX inefficient
+        self.size_resume_section = resumecode.numb_next_n_items(
+            self.numb, count, 0)
         self.count = storage.rd_count
         self.consts = storage.rd_consts
 
@@ -948,7 +970,7 @@
         return jitcode_pos, pc
 
     def done_reading(self):
-        return self.cur_index >= len(self.numb.code)
+        return self.cur_index >= self.size_resume_section # optimizer sections after the resume section are not frames
 
     def getvirtual_ptr(self, index):
         # Returns the index'th virtual, building it lazily if needed.
@@ -1057,6 +1079,7 @@
     boxes = resumereader.consume_vref_and_vable_boxes(virtualizable_info,
                                                       greenfield_info)
     virtualizable_boxes, virtualref_boxes = boxes
+
     while not resumereader.done_reading():
         jitcode_pos, pc = resumereader.read_jitcode_pos_pc()
         jitcode = metainterp.staticdata.jitcodes[jitcode_pos]
@@ -1110,7 +1133,7 @@
         return lst, index
 
     def consume_vref_and_vable_boxes(self, vinfo, ginfo):
-        vable_size, index = resumecode.numb_next_item(self.numb, 0)
+        vable_size, index = resumecode.numb_next_item(self.numb, self.cur_index)
         if vinfo is not None:
             virtualizable_boxes, index = self.consume_virtualizable_boxes(vinfo,
                                                                           index)
@@ -1443,7 +1466,7 @@
     load_value_of_type._annspecialcase_ = 'specialize:arg(1)'
 
     def consume_vref_and_vable(self, vrefinfo, vinfo, ginfo):
-        vable_size, index = resumecode.numb_next_item(self.numb, 0)
+        vable_size, index = resumecode.numb_next_item(self.numb, self.cur_index)
         if self.resume_after_guard_not_forced != 2:
             if vinfo is not None:
                 index = self.consume_vable_info(vinfo, index)
diff --git a/rpython/jit/metainterp/resumecode.py b/rpython/jit/metainterp/resumecode.py
--- a/rpython/jit/metainterp/resumecode.py
+++ b/rpython/jit/metainterp/resumecode.py
@@ -1,6 +1,8 @@
 
 """ Resume bytecode. It goes as following:
 
+  # ----- resume section
+  [number of items in the resume section, this one included; a plain count, not a tagged value]
   [<length> <virtualizable object> <numb> <numb> <numb>]    if vinfo is not None
    -OR-
   [1 <ginfo object>]                                        if ginfo is not None
@@ -13,7 +15,10 @@
   [<pc> <jitcode> <numb> <numb>]
   ...
 
-  until the length of the array.
+  until the number of items given by the first entry has been read
+
+  # ----- optimization section
+  <more code>                                      further sections according to bridgeopt.py
 """
 
 from rpython.rtyper.lltypesystem import rffi, lltype
diff --git a/rpython/jit/metainterp/test/test_bridgeopt.py b/rpython/jit/metainterp/test/test_bridgeopt.py
new file mode 100644
--- /dev/null
+++ b/rpython/jit/metainterp/test/test_bridgeopt.py
@@ -0,0 +1,98 @@
+# tests that check that information is fed from the optimizer into the bridges
+
+from rpython.rlib import jit
+from rpython.jit.metainterp.test.support import LLJitMixin
+from rpython.jit.metainterp.optimizeopt.bridgeopt import serialize_optimizer_knowledge
+from rpython.jit.metainterp.optimizeopt.bridgeopt import deserialize_optimizer_knowledge
+from rpython.jit.metainterp.resoperation import InputArgRef, InputArgInt
+from rpython.jit.metainterp.resume import NumberingState
+from rpython.jit.metainterp.optimizeopt.info import InstancePtrInfo
+
+class FakeTS(object): # stands in for cpu.ts in these tests
+    def __init__(self, dct):
+        self.dct = dct
+
+    def cls_of_box(self, box): # class of box, looked up in the dict given to __init__
+        return self.dct[box]
+
+
+class FakeCPU(object): # only provides .ts, which deserialization uses to find runtime classes
+    def __init__(self, dct):
+        self.ts = FakeTS(dct)
+
+class FakeOptimizer(object): # records make_constant_class calls instead of optimizing
+    metainterp_sd = None
+
+    def __init__(self, dct=None, cpu=None): # None sentinel avoids a shared mutable default dict
+        self.dct = {} if dct is None else dct
+        self.constant_classes = {} # box -> class, filled by make_constant_class
+        self.cpu = cpu
+
+    def getptrinfo(self, arg):
+        return self.dct.get(arg, None)
+
+    def make_constant_class(self, arg, cls):
+        self.constant_classes[arg] = cls
+
+class FakeClass(object): # opaque stand-in for a known class object
+    pass
+
+def test_simple():
+    box1 = InputArgRef()
+    box2 = InputArgRef()
+    box3 = InputArgRef()
+
+    cls = FakeClass()
+    dct = {box1: InstancePtrInfo(known_class=cls)}
+    optimizer = FakeOptimizer(dct)
+
+    numb_state = NumberingState(4)
+    numb_state.append_int(1) # resume-section header: the section consists of just this item
+    liveboxes = [InputArgInt(), box2, box1, box3]
+
+    serialize_optimizer_knowledge(optimizer, numb_state, liveboxes, None)
+
+    assert numb_state.current[:numb_state._pos] == [1, 0b0100000] # header, then bits 0,1,0 (box2, box1, box3) padded to 7
+
+    rbox1 = InputArgRef()
+    rbox2 = InputArgRef()
+    rbox3 = InputArgRef()
+    after_optimizer = FakeOptimizer(cpu=FakeCPU({rbox1: cls})) # only rbox1's runtime class is needed
+    deserialize_optimizer_knowledge(
+        after_optimizer, numb_state.create_numbering(),
+        [InputArgInt(), rbox2, rbox1, rbox3], liveboxes)
+    assert box1 in after_optimizer.constant_classes
+    assert box2 not in after_optimizer.constant_classes
+    assert box3 not in after_optimizer.constant_classes
+
+
+class TestOptBridge(LLJitMixin):
+    # integration tests: run the whole JIT and inspect the resulting traces
+    def test_bridge(self):
+        myjitdriver = jit.JitDriver(greens=[], reds=['y', 'res', 'n', 'a'])
+        class A(object):
+            def f(self):
+                return 1
+        class B(A):
+            def f(self):
+                return 2
+        def f(x, y, n):
+            if x:
+                a = A()
+            else:
+                a = B()
+            a.x = 0
+            res = 0
+            while y > 0:
+                myjitdriver.jit_merge_point(y=y, n=n, res=res, a=a)
+                res += a.f()
+                a.x += 1
+                if y > n: # flips once y counts down to n, presumably making this guard fail and grow a bridge
+                    res += 1
+                res += a.f()
+                y -= 1
+            return res
+        res = self.meta_interp(f, [6, 32, 16])
+        assert res == f(6, 32, 16)
+        self.check_trace_count(3)
+        self.check_resops(guard_class=1) # presumably: the bridge receives a's known class, so it needs no guard_class of its own
diff --git a/rpython/jit/metainterp/test/test_resume.py b/rpython/jit/metainterp/test/test_resume.py
--- a/rpython/jit/metainterp/test/test_resume.py
+++ b/rpython/jit/metainterp/test/test_resume.py
@@ -289,7 +289,8 @@
     assert bh.written_f == expected_f
 
 
-Numbering = create_numbering
+def Numbering(l):
+    return create_numbering([len(l) + 1] + l) # prefix the resume-section item count, which (as in number()) includes the prefix itself
 
 def tagconst(i):
     return tag(i + TAG_CONST_OFFSET, TAGCONST)
@@ -362,7 +363,7 @@
             return s
     class FakeStorage(object):
         rd_virtuals = [FakeVinfo(), None]
-        rd_numb = []
+        rd_numb = Numbering([])
         rd_consts = []
         rd_pendingfields = None
         rd_count = 0
@@ -858,7 +859,7 @@
     base = [0, 0, tag(0, TAGBOX), tag(1, TAGINT),
             tag(1, TAGBOX), tag(0, TAGBOX), tag(2, TAGINT)]
 
-    assert unpack_numbering(numb) == [0, 0] + base + [0, 2, tag(3, TAGINT), tag(2, TAGBOX),
+    assert unpack_numbering(numb) == [16, 0, 0] + base + [0, 2, tag(3, TAGINT), tag(2, TAGBOX),
                                       tag(0, TAGBOX), tag(1, TAGINT)]
     t.append(0)
     snap2 = t.create_top_snapshot(FakeJitCode("jitcode", 0), 2, Frame(env2),
@@ -872,7 +873,7 @@
     assert numb_state2.liveboxes == {b1: tag(0, TAGBOX), b2: tag(1, TAGBOX),
                                      b3: tag(2, TAGBOX)}
     assert numb_state2.liveboxes is not numb_state.liveboxes
-    assert unpack_numbering(numb2) == [0, 0] + base + [0, 2, tag(3, TAGINT), tag(2, TAGBOX),
+    assert unpack_numbering(numb2) == [16, 0, 0] + base + [0, 2, tag(3, TAGINT), tag(2, TAGBOX),
                                        tag(0, TAGBOX), tag(3, TAGINT)]
 
     t.append(0)
@@ -894,7 +895,7 @@
     assert numb_state3.num_virtuals == 0
     
     assert numb_state3.liveboxes == {b1: tag(0, TAGBOX), b2: tag(1, TAGBOX)}
-    assert unpack_numbering(numb3) == ([0, 2, tag(3, TAGINT), tag(4, TAGINT),
+    assert unpack_numbering(numb3) == ([16, 0, 2, tag(3, TAGINT), tag(4, TAGINT),
                                        tag(0, TAGBOX), tag(3, TAGINT)] +
                                        base + [0, 2])
 
@@ -911,7 +912,7 @@
     
     assert numb_state4.liveboxes == {b1: tag(0, TAGBOX), b2: tag(1, TAGBOX),
                                      b4: tag(0, TAGVIRTUAL)}
-    assert unpack_numbering(numb4) == [0, 2, tag(3, TAGINT), tag(0, TAGVIRTUAL),
+    assert unpack_numbering(numb4) == [16, 0, 2, tag(3, TAGINT), tag(0, TAGVIRTUAL),
                                        tag(0, TAGBOX), tag(3, TAGINT)] + base + [0, 2]
 
     t.append(0)
@@ -930,7 +931,7 @@
 
     assert numb_state5.liveboxes == {b1: tag(0, TAGBOX), b2: tag(1, TAGBOX),
                                      b4: tag(0, TAGVIRTUAL), b5: tag(1, TAGVIRTUAL)}
-    assert unpack_numbering(numb5) == [
+    assert unpack_numbering(numb5) == [21,
         3, tag(0, TAGBOX), tag(0, TAGVIRTUAL), tag(1, TAGVIRTUAL),
         0] + base + [
         2, 1, tag(3, TAGINT), tag(0, TAGVIRTUAL), tag(0, TAGBOX), tag(3, TAGINT)
@@ -949,15 +950,16 @@
     numb_state = memo.number(FakeOptimizer(), 0, i)
     numb = numb_state.create_numbering()
     l = unpack_numbering(numb)
-    assert l[0] == 0
+    assert l[0] == len(l)
     assert l[1] == 0
     assert l[2] == 0
     assert l[3] == 0
+    assert l[4] == 0
     mapping = dict(zip(inpargs, i.inputargs))
     for i, item in enumerate(lst):
-        v, tag = untag(l[i + 4])
+        v, tag = untag(l[i + 5])
         if tag == TAGBOX:
-            assert l[i + 4] == numb_state.liveboxes[mapping[item]]
+            assert l[i + 5] == numb_state.liveboxes[mapping[item]]
         elif tag == TAGCONST:
             assert memo.consts[v].getint() == item.getint()
         elif tag == TAGINT:
@@ -1228,7 +1230,7 @@
     liveboxes = []
     modifier._number_virtuals(liveboxes, FakeOptimizer(), 0)
     storage.rd_consts = memo.consts[:]
-    storage.rd_numb = None
+    storage.rd_numb = Numbering([])
     # resume
     b3t, b5t = [IntFrontendOp(0), RefFrontendOp(0)]
     b5t.setref_base(demo55o)
@@ -1299,7 +1301,7 @@
     modifier._number_virtuals(liveboxes, FakeOptimizer(), 0)
     dump_storage(storage, liveboxes)
     storage.rd_consts = memo.consts[:]
-    storage.rd_numb = None
+    storage.rd_numb = Numbering([])
     # resume
     b1t, b3t, b4t = [IntFrontendOp(0), IntFrontendOp(0), IntFrontendOp(0)]
     b1t.setint(11)
@@ -1352,7 +1354,7 @@
     modifier._number_virtuals(liveboxes, FakeOptimizer(), 0)
     dump_storage(storage, liveboxes)
     storage.rd_consts = memo.consts[:]
-    storage.rd_numb = None
+    storage.rd_numb = Numbering([])
     b4t = RefFrontendOp(0)
     newboxes = _resume_remap(liveboxes, [#b2s -- virtual
                                          b4s], b4t)
@@ -1398,7 +1400,7 @@
     modifier._add_pending_fields(FakeOptimizer(), [
         ResOperation(rop.SETFIELD_GC, [b2s, b4s], descr=LLtypeMixin.nextdescr)])
     storage.rd_consts = memo.consts[:]
-    storage.rd_numb = None
+    storage.rd_numb = Numbering([])
     # resume
     demo55.next = lltype.nullptr(LLtypeMixin.NODE)
     b2t = RefFrontendOp(0)


More information about the pypy-commit mailing list