[pypy-svn] r51225 - in pypy/branch/jit-refactoring/pypy/jit: rainbow rainbow/test timeshifter timeshifter/test

cfbolz at codespeak.net cfbolz at codespeak.net
Sun Feb 3 14:39:32 CET 2008


Author: cfbolz
Date: Sun Feb  3 14:39:31 2008
New Revision: 51225

Modified:
   pypy/branch/jit-refactoring/pypy/jit/rainbow/bytecode.py
   pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_interpreter.py
   pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_serializegraph.py
   pypy/branch/jit-refactoring/pypy/jit/timeshifter/rtimeshift.py
   pypy/branch/jit-refactoring/pypy/jit/timeshifter/test/test_timeshift.py
Log:
Add some support for direct red calls in the JIT (nothing fancy; exceptions are not handled yet).


Modified: pypy/branch/jit-refactoring/pypy/jit/rainbow/bytecode.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/rainbow/bytecode.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/rainbow/bytecode.py	Sun Feb  3 14:39:31 2008
@@ -3,6 +3,7 @@
 from pypy.objspace.flow import model as flowmodel
 from pypy.rpython.lltypesystem import lltype
 from pypy.jit.hintannotator.model import originalconcretetype
+from pypy.jit.hintannotator import model as hintmodel
 from pypy.jit.timeshifter import rtimeshift, rvalue
 from pypy.jit.timeshifter.greenkey import KeyDesc, empty_key, GreenKey
 
@@ -19,14 +20,17 @@
     green consts are negative indexes
     """
 
-    def __init__(self, code, constants, typekinds, redboxclasses, keydescs,
-                 num_mergepoints):
+    def __init__(self, name, code, constants, typekinds, redboxclasses,
+                 keydescs, called_bytecodes, num_mergepoints, is_portal):
+        self.name = name
         self.code = code
         self.constants = constants
         self.typekinds = typekinds
         self.redboxclasses = redboxclasses
         self.keydescs = keydescs
+        self.called_bytecodes = called_bytecodes
         self.num_mergepoints = num_mergepoints
+        self.is_portal = is_portal
 
     def _freeze_(self):
         return True
@@ -44,11 +48,10 @@
         self.opname_to_index = {}
         self.jitstate = None
         self.queue = None
-        self.bytecode = None
-        self.pc = -1
         self._add_implemented_opcodes()
 
-    def run(self, jitstate, bytecode, greenargs, redargs):
+    def run(self, jitstate, bytecode, greenargs, redargs,
+            start_bytecode_loop=True):
         self.jitstate = jitstate
         self.queue = rtimeshift.DispatchQueue(bytecode.num_mergepoints)
         rtimeshift.enter_frame(self.jitstate, self.queue)
@@ -57,7 +60,8 @@
         self.frame.bytecode = bytecode
         self.frame.local_boxes = redargs
         self.frame.local_green = greenargs
-        self.bytecode_loop()
+        if start_bytecode_loop:
+            self.bytecode_loop()
         return self.jitstate
 
     def bytecode_loop(self):
@@ -70,14 +74,16 @@
                 assert result is None
 
     def dispatch(self):
+        is_portal = self.frame.bytecode.is_portal
         newjitstate = rtimeshift.dispatch_next(self.queue)
         resumepoint = rtimeshift.getresumepoint(newjitstate)
         self.newjitstate(newjitstate)
         if resumepoint == -1:
-            # XXX what about green returns?
-            newjitstate = rtimeshift.leave_graph_red(self.queue, is_portal=True)
+            newjitstate = rtimeshift.leave_graph_red(
+                    self.queue, is_portal)
             self.newjitstate(newjitstate)
-            return STOP
+            if newjitstate is None or is_portal:
+                return STOP
         else:
             self.frame.pc = resumepoint
 
@@ -192,6 +198,30 @@
         if done:
             return self.dispatch()
 
+    def opimpl_red_direct_call(self):
+        greenargs = []
+        num = self.load_2byte()
+        for i in range(num):
+            greenargs.append(self.get_greenarg())
+        redargs = []
+        num = self.load_2byte()
+        for i in range(num):
+            redargs.append(self.get_redarg())
+        bytecodenum = self.load_2byte()
+        targetbytecode = self.frame.bytecode.called_bytecodes[bytecodenum]
+        self.run(self.jitstate, targetbytecode, greenargs, redargs,
+                 start_bytecode_loop=False)
+        # this frame will be resumed later in the next bytecode, which is
+        # red_after_direct_call
+
+    def opimpl_red_after_direct_call(self):
+        newjitstate = rtimeshift.collect_split(
+            self.jitstate, self.frame.pc,
+            self.frame.local_green)
+        assert newjitstate is self.jitstate
+
+
+    # ____________________________________________________________
     # construction-time interface
 
     def _add_implemented_opcodes(self):
@@ -244,21 +274,32 @@
 
 
 class BytecodeWriter(object):
-    def __init__(self, t, hintannotator, RGenOp):
+    def __init__(self, t, hannotator, RGenOp):
         self.translator = t
         self.annotator = t.annotator
-        self.hannotator = hintannotator
+        self.hannotator = hannotator
         self.interpreter = JitInterpreter()
         self.RGenOp = RGenOp
         self.current_block = None
-
-    def make_bytecode(self, graph):
+        self.raise_analyzer = hannotator.exceptiontransformer.raise_analyzer
+        self.all_graphs = {} # mapping graph to bytecode
+        self.unfinished_graphs = []
+
+    def can_raise(self, op):
+        return self.raise_analyzer.analyze(op)
+
+    def make_bytecode(self, graph, is_portal=True):
+        if is_portal:
+            self.all_graphs[graph] = JitCode.__new__(JitCode)
         self.seen_blocks = {}
         self.assembler = []
         self.constants = []
         self.typekinds = []
         self.redboxclasses = []
         self.keydescs = []
+        self.called_bytecodes = []
+        self.num_mergepoints = 0
+        self.is_portal = is_portal
         # mapping constant -> index in constants
         self.const_positions = {}
         # mapping blocks to True
@@ -273,19 +314,31 @@
         self.type_positions = {}
         # mapping tuple of green TYPES to index
         self.keydesc_positions = {}
-
-        self.num_mergepoints = 0
+        # mapping graphs to index
+        self.graph_positions = {}
 
         self.graph = graph
         self.entrymap = flowmodel.mkentrymap(graph)
         self.make_bytecode_block(graph.startblock)
         assert self.current_block is None
-        return JitCode(assemble(self.interpreter, *self.assembler),
-                       self.constants,
-                       self.typekinds,
-                       self.redboxclasses,
-                       self.keydescs,
-                       self.num_mergepoints)
+        bytecode = self.all_graphs[graph]
+        bytecode.__init__(graph.name,
+                          assemble(self.interpreter, *self.assembler),
+                          self.constants,
+                          self.typekinds,
+                          self.redboxclasses,
+                          self.keydescs,
+                          self.called_bytecodes,
+                          self.num_mergepoints,
+                          self.is_portal)
+        if is_portal:
+            self.finish_all_graphs()
+            return bytecode
+
+    def finish_all_graphs(self):
+        while self.unfinished_graphs:
+            graph = self.unfinished_graphs.pop()
+            self.make_bytecode(graph, is_portal=False)
 
     def make_bytecode_block(self, block, insert_goto=False):
         if block in self.seen_blocks:
@@ -482,6 +535,18 @@
         result = len(self.type_positions)
         self.type_positions[TYPE] = result
         return result
+
+    def graph_position(self, graph):
+        if graph in self.graph_positions:
+            return self.graph_positions[graph]
+        bytecode = JitCode.__new__(JitCode)
+        index = len(self.called_bytecodes)
+        self.called_bytecodes.append(bytecode)
+        self.all_graphs[graph] = bytecode
+        self.graph_positions[graph] = index
+        self.unfinished_graphs.append(graph)
+        return index
+
         
     def emit(self, stuff):
         assert stuff is not None
@@ -501,6 +566,7 @@
                 reds.append(v)
         return reds, greens
 
+    # ____________________________________________________________
     # operation special cases
 
     def serialize_op_hint(self, op):
@@ -522,6 +588,108 @@
             return
         XXX
 
+    def serialize_op_direct_call(self, op):
+        targets = dict(self.graphs_from(op))
+        assert len(targets) == 1
+        targetgraph, = targets.values()
+        kind, exc = self.guess_call_kind(op)
+        if kind == "red":
+            graphindex = self.graph_position(targetgraph)
+            args = targetgraph.getargs()
+            reds, greens = self.sort_by_color(op.args[1:], args)
+            result = []
+            for color, args in [("green", greens), ("red", reds)]:
+                result.append(len(args))
+                for v in args:
+                    result.append(self.serialize_oparg(color, v))
+            self.emit("red_direct_call")
+            for index in result:
+                self.emit(index)
+            self.emit(graphindex)
+            self.register_redvar(op.result)
+            self.emit("red_after_direct_call")
+        else:
+            XXX
+
+    def serialize_op_indirect_call(self, op):
+        XXX
+
+    # call handling
+
+    def graphs_from(self, spaceop):
+        if spaceop.opname == 'direct_call':
+            c_func = spaceop.args[0]
+            fnobj = c_func.value._obj
+            graphs = [fnobj.graph]
+            args_v = spaceop.args[1:]
+        elif spaceop.opname == 'indirect_call':
+            graphs = spaceop.args[-1].value
+            if graphs is None:
+                return       # cannot follow at all
+            args_v = spaceop.args[1:-1]
+        else:
+            raise AssertionError(spaceop.opname)
+        # if the graph - or all the called graphs - are marked as "don't
+        # follow", directly return None as a special case.  (This is only
+        # an optimization for the indirect_call case.)
+        for graph in graphs:
+            if self.hannotator.policy.look_inside_graph(graph):
+                break
+        else:
+            return
+        for graph in graphs:
+            tsgraph = self.specialized_graph_of(graph, args_v, spaceop.result)
+            yield graph, tsgraph
+
+    def guess_call_kind(self, spaceop):
+        if spaceop.opname == 'direct_call':
+            c_func = spaceop.args[0]
+            fnobj = c_func.value._obj
+            if hasattr(fnobj, 'jitcallkind'):
+                return fnobj.jitcallkind, None
+            if (hasattr(fnobj._callable, 'oopspec') and
+                self.hannotator.policy.oopspec):
+                if fnobj._callable.oopspec.startswith('vable.'):
+                    return 'vable', None
+                hs_result = self.hannotator.binding(spaceop.result)
+                if (hs_result.is_green() and
+                    hs_result.concretetype is not lltype.Void):
+                    return 'green', self.can_raise(spaceop)
+                return 'oopspec', self.can_raise(spaceop)
+        if self.hannotator.bookkeeper.is_green_call(spaceop):
+            return 'green', None
+        withexc = self.can_raise(spaceop)
+        colors = {}
+        for graph, tsgraph in self.graphs_from(spaceop):
+            color = self.graph_calling_color(tsgraph)
+            colors[color] = tsgraph
+        if not colors: # cannot follow this call
+            return 'residual', withexc
+        assert len(colors) == 1, colors   # buggy normalization?
+        return color, withexc
+
+    def specialized_graph_of(self, graph, args_v, v_result):
+        bk = self.hannotator.bookkeeper
+        args_hs = [self.hannotator.binding(v) for v in args_v]
+        hs_result = self.hannotator.binding(v_result)
+        if isinstance(hs_result, hintmodel.SomeLLAbstractConstant):
+            fixed = hs_result.is_fixed()
+        else:
+            fixed = False
+        specialization_key = bk.specialization_key(fixed, args_hs)
+        special_graph = bk.get_graph_by_key(graph, specialization_key)
+        return special_graph
+
+    def graph_calling_color(self, graph):
+        hs_res = self.hannotator.binding(graph.getreturnvar())
+        if originalconcretetype(hs_res) is lltype.Void:
+            c = 'gray'
+        elif hs_res.is_green():
+            c = 'yellow'
+        else:
+            c = 'red'
+        return c
+
 
 class label(object):
     def __init__(self, name):

Modified: pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_interpreter.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_interpreter.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_interpreter.py	Sun Feb  3 14:39:31 2008
@@ -405,6 +405,91 @@
         assert res == 42
         self.check_insns({'int_add': 2, 'int_sub': 1})
 
+    def test_call_simple(self):
+        def ll_add_one(x):
+            return x + 1
+        def ll_function(y):
+            return ll_add_one(y)
+        res = self.interpret(ll_function, [5], [])
+        assert res == 6
+        self.check_insns({'int_add': 1})
+
+    def test_call_2(self):
+        def ll_add_one(x):
+            return x + 1
+        def ll_function(y):
+            return ll_add_one(y) + y
+        res = self.interpret(ll_function, [5], [])
+        assert res == 11
+        self.check_insns({'int_add': 2})
+
+    def test_call_3(self):
+        def ll_add_one(x):
+            return x + 1
+        def ll_two(x):
+            return ll_add_one(ll_add_one(x)) - x
+        def ll_function(y):
+            return ll_two(y) * y
+        res = self.interpret(ll_function, [5], [])
+        assert res == 10
+        self.check_insns({'int_add': 2, 'int_sub': 1, 'int_mul': 1})
+
+    def test_call_4(self):
+        def ll_two(x):
+            if x > 0:
+                return x + 5
+            else:
+                return x - 4
+        def ll_function(y):
+            return ll_two(y) * y
+
+        res = self.interpret(ll_function, [3], [])
+        assert res == 24
+        self.check_insns({'int_gt': 1, 'int_add': 1,
+                          'int_sub': 1, 'int_mul': 1})
+
+        res = self.interpret(ll_function, [-3], [])
+        assert res == 21
+        self.check_insns({'int_gt': 1, 'int_add': 1,
+                          'int_sub': 1, 'int_mul': 1})
+
+    def test_void_call(self):
+        py.test.skip("calls are WIP")
+        def ll_do_nothing(x):
+            pass
+        def ll_function(y):
+            ll_do_nothing(y)
+            return y
+
+        res = self.interpret(ll_function, [3], [])
+        assert res == 3
+
+    def test_green_call(self):
+        py.test.skip("calls are WIP")
+        def ll_add_one(x):
+            return x+1
+        def ll_function(y):
+            z = ll_add_one(y)
+            z = hint(z, concrete=True)
+            return hint(z, variable=True)
+
+        res = self.interpret(ll_function, [3], [0])
+        assert res == 4
+        self.check_insns({})
+
+    def test_split_on_green_return(self):
+        py.test.skip("calls are WIP")
+        def ll_two(x):
+            if x > 0:
+                return 17
+            else:
+                return 22
+        def ll_function(x):
+            n = ll_two(x)
+            return hint(n+1, variable=True)
+        res = self.interpret(ll_function, [-70], [])
+        assert res == 23
+        self.check_insns({'int_gt': 1})
 
 class TestLLType(SimpleTests):
     type_system = "lltype"

Modified: pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_serializegraph.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_serializegraph.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_serializegraph.py	Sun Feb  3 14:39:31 2008
@@ -64,6 +64,8 @@
                                         "red_return", 0)
         assert len(jitcode.constants) == 0
         assert len(jitcode.typekinds) == 0
+        assert jitcode.is_portal
+        assert len(jitcode.called_bytecodes) == 0
 
     def test_constant(self):
         def f(x):
@@ -78,6 +80,8 @@
         assert len(jitcode.constants) == 1
         assert len(jitcode.typekinds) == 1
         assert len(jitcode.redboxclasses) == 1
+        assert jitcode.is_portal
+        assert len(jitcode.called_bytecodes) == 0
  
     def test_green_switch(self):
         def f(x, y, z):
@@ -102,6 +106,8 @@
         assert jitcode.code == expected
         assert len(jitcode.constants) == 0
         assert len(jitcode.typekinds) == 0
+        assert jitcode.is_portal
+        assert len(jitcode.called_bytecodes) == 0
 
     def test_green_switch2(self):
         def f(x, y, z):
@@ -134,6 +140,8 @@
         assert jitcode.code == expected
         assert len(jitcode.constants) == 0
         assert len(jitcode.typekinds) == 0
+        assert jitcode.is_portal
+        assert len(jitcode.called_bytecodes) == 0
 
     def test_merge(self):
         def f(x, y, z):
@@ -169,6 +177,8 @@
         assert jitcode.code == expected
         assert len(jitcode.constants) == 1
         assert len(jitcode.typekinds) == 1
+        assert jitcode.is_portal
+        assert len(jitcode.called_bytecodes) == 0
 
     def test_loop(self):
         def f(x):
@@ -197,6 +207,37 @@
                             "red_int_sub", 0, 3,
                             "make_new_redvars", 2, 2, 4,
                             "goto", tlabel("while"))
+        assert jitcode.is_portal
+        assert len(jitcode.called_bytecodes) == 0
+
+    def test_call(self):
+        def g(x):
+            return x + 1
+        def f(x):
+            return g(x) * 2
+        writer, jitcode = self.serialize(f, [int])
+        assert jitcode.code == assemble(writer.interpreter,
+                                        "red_direct_call", 0, 1, 0, 0,
+                                        "red_after_direct_call",
+                                        "make_redbox", -1, 0,
+                                        "red_int_mul", 1, 2,
+                                        "make_new_redvars", 1, 3,
+                                        "make_new_greenvars", 0,
+                                        "red_return", 0)
+        assert jitcode.is_portal
+        assert len(jitcode.called_bytecodes) == 1
+        called_jitcode = jitcode.called_bytecodes[0]
+        assert called_jitcode.code == assemble(writer.interpreter,
+                                               "make_redbox", -1, 0,
+                                               "red_int_add", 0, 1,
+                                               "make_new_redvars", 1, 2,
+                                               "make_new_greenvars", 0,
+                                               "red_return", 0)
+        assert not called_jitcode.is_portal
+        assert len(called_jitcode.called_bytecodes) == 0
+
+
+
 
 
 class TestLLType(AbstractSerializationTest):

Modified: pypy/branch/jit-refactoring/pypy/jit/timeshifter/rtimeshift.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/timeshifter/rtimeshift.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/timeshifter/rtimeshift.py	Sun Feb  3 14:39:31 2008
@@ -453,11 +453,10 @@
     if gotexc:
         jitstate.residual_ll_exception(ll_evalue)
 
-def collect_split(jitstate_chain, resumepoint, *greens_gv):
+def collect_split(jitstate_chain, resumepoint, greens_gv):
     # YYY split to avoid over-specialization
     # assumes that the head of the jitstate_chain is ready for writing,
     # and all the other jitstates in the chain are paused
-    greens_gv = list(greens_gv)
     pending = jitstate_chain
     resuming = jitstate_chain.get_resuming()
     if resuming is not None and resuming.mergesleft == 0:
@@ -467,7 +466,7 @@
             pending = pending.next
         pending.greens.extend(greens_gv)
         if pending.returnbox is not None:
-            pending.frame.local_boxes.insert(0, getreturnbox(pending))
+            pending.frame.local_boxes.append(getreturnbox(pending))
         pending.next = None
         start_writing(pending, jitstate_chain)
         return pending
@@ -478,7 +477,7 @@
         pending = pending.next
         jitstate.greens.extend(greens_gv)   # item 0 is the return value
         if jitstate.returnbox is not None:
-            jitstate.frame.local_boxes.insert(0, getreturnbox(jitstate))
+            jitstate.frame.local_boxes.append(getreturnbox(jitstate))
         jitstate.resumepoint = resumepoint
         if resuming is None:
             node = jitstate.promotion_path

Modified: pypy/branch/jit-refactoring/pypy/jit/timeshifter/test/test_timeshift.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/timeshifter/test/test_timeshift.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/timeshifter/test/test_timeshift.py	Sun Feb  3 14:39:31 2008
@@ -762,88 +762,6 @@
         assert res == 1 + 2
         self.check_insns({'int_is_true': 1, 'int_add': 1})
 
-    def test_call_simple(self):
-        def ll_add_one(x):
-            return x + 1
-        def ll_function(y):
-            return ll_add_one(y)
-        res = self.timeshift(ll_function, [5], [], policy=P_NOVIRTUAL)
-        assert res == 6
-        self.check_insns({'int_add': 1})
-
-    def test_call_2(self):
-        def ll_add_one(x):
-            return x + 1
-        def ll_function(y):
-            return ll_add_one(y) + y
-        res = self.timeshift(ll_function, [5], [], policy=P_NOVIRTUAL)
-        assert res == 11
-        self.check_insns({'int_add': 2})
-
-    def test_call_3(self):
-        def ll_add_one(x):
-            return x + 1
-        def ll_two(x):
-            return ll_add_one(ll_add_one(x)) - x
-        def ll_function(y):
-            return ll_two(y) * y
-        res = self.timeshift(ll_function, [5], [], policy=P_NOVIRTUAL)
-        assert res == 10
-        self.check_insns({'int_add': 2, 'int_sub': 1, 'int_mul': 1})
-
-    def test_call_4(self):
-        def ll_two(x):
-            if x > 0:
-                return x + 5
-            else:
-                return x - 4
-        def ll_function(y):
-            return ll_two(y) * y
-
-        res = self.timeshift(ll_function, [3], [], policy=P_NOVIRTUAL)
-        assert res == 24
-        self.check_insns({'int_gt': 1, 'int_add': 1,
-                          'int_sub': 1, 'int_mul': 1})
-
-        res = self.timeshift(ll_function, [-3], [], policy=P_NOVIRTUAL)
-        assert res == 21
-        self.check_insns({'int_gt': 1, 'int_add': 1,
-                          'int_sub': 1, 'int_mul': 1})
-
-    def test_void_call(self):
-        def ll_do_nothing(x):
-            pass
-        def ll_function(y):
-            ll_do_nothing(y)
-            return y
-
-        res = self.timeshift(ll_function, [3], [], policy=P_NOVIRTUAL)
-        assert res == 3
-
-    def test_green_call(self):
-        def ll_add_one(x):
-            return x+1
-        def ll_function(y):
-            z = ll_add_one(y)
-            z = hint(z, concrete=True)
-            return hint(z, variable=True)
-
-        res = self.timeshift(ll_function, [3], [0], policy=P_NOVIRTUAL)
-        assert res == 4
-        self.check_insns({})
-
-    def test_split_on_green_return(self):
-        def ll_two(x):
-            if x > 0:
-                return 17
-            else:
-                return 22
-        def ll_function(x):
-            n = ll_two(x)
-            return hint(n+1, variable=True)
-        res = self.timeshift(ll_function, [-70], [])
-        assert res == 23
-        self.check_insns({'int_gt': 1})
 
     def test_green_with_side_effects(self):
         S = lltype.GcStruct('S', ('flag', lltype.Bool))



More information about the Pypy-commit mailing list