[pypy-svn] r51861 - in pypy/branch/jit-refactoring/pypy/jit/rainbow: . test

cfbolz at codespeak.net cfbolz at codespeak.net
Mon Feb 25 21:24:46 CET 2008


Author: cfbolz
Date: Mon Feb 25 21:24:44 2008
New Revision: 51861

Modified:
   pypy/branch/jit-refactoring/pypy/jit/rainbow/codewriter.py
   pypy/branch/jit-refactoring/pypy/jit/rainbow/dump.py
   pypy/branch/jit-refactoring/pypy/jit/rainbow/interpreter.py
   pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_interpreter.py
Log:
first indirect_call test passes


Modified: pypy/branch/jit-refactoring/pypy/jit/rainbow/codewriter.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/rainbow/codewriter.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/rainbow/codewriter.py	Mon Feb 25 21:24:44 2008
@@ -2,7 +2,7 @@
 from pypy.rlib.objectmodel import we_are_translated
 from pypy.objspace.flow import model as flowmodel
 from pypy.rpython.annlowlevel import cachedtype
-from pypy.rpython.lltypesystem import lltype
+from pypy.rpython.lltypesystem import lltype, llmemory
 from pypy.jit.hintannotator.model import originalconcretetype
 from pypy.jit.hintannotator import model as hintmodel
 from pypy.jit.timeshifter import rtimeshift, rvalue, rcontainer, exception
@@ -54,6 +54,30 @@
     def _freeze_(self):
         return True
 
+class IndirectCallsetDesc(object):  # maps the function addresses of one indirect-call set to their JitCodes
+    __metaclass__ = cachedtype
+    # NOTE(review): cachedtype presumably shares one instance per distinct constructor arguments -- confirm
+    def __init__(self, graph2tsgraph, codewriter):
+        # graph2tsgraph: iterable of (graph, tsgraph) pairs (indirectcalldesc_position passes a sorted tuple)
+        keys = []  # address of each graph's callable; parallel to values
+        values = []  # JitCode of each tsgraph; parallel to keys
+        common_args_r = None  # NOTE(review): never used below
+        for graph, tsgraph in graph2tsgraph:
+            fnptr    = codewriter.rtyper.getcallable(graph)
+            keys.append(llmemory.cast_ptr_to_adr(fnptr))
+            values.append(codewriter.get_jitcode(tsgraph))
+        # bytecode_for_address: linear search; falls through (returns None) if fnaddress is not in the set
+        def bytecode_for_address(fnaddress):
+            # XXX optimize (linear scan over keys)
+            for i in range(len(keys)):
+                if keys[i] == fnaddress:
+                    return values[i]
+
+        self.bytecode_for_address = bytecode_for_address
+
+        self.graphs = [graph for (graph, tsgraph) in graph2tsgraph]
+        self.jitcodes = values
+
 
 class BytecodeWriter(object):
     def __init__(self, t, hannotator, RGenOp):
@@ -71,6 +95,7 @@
         self.all_graphs = {} # mapping graph to bytecode
         self.unfinished_graphs = []
         self.num_global_mergepoints = 0
+        self.ptr_to_jitcode = {}
 
     def can_raise(self, op):
         return self.raise_analyzer.analyze(op)
@@ -95,6 +120,7 @@
         self.num_local_mergepoints = 0
         self.graph_color = self.graph_calling_color(graph)
         self.calldescs = []
+        self.indirectcalldescs = []
         self.is_portal = is_portal
         # mapping constant -> index in constants
         self.const_positions = {}
@@ -126,6 +152,8 @@
         self.graph_positions = {}
         # mapping fnobjs to index
         self.calldesc_positions = {}
+        # mapping sorted (graph, tsgraph) tuples (and their prefixes) to index
+        self.indirectcalldesc_positions = {}
 
         self.graph = graph
         self.mergepoint_set = {}
@@ -152,6 +180,7 @@
                           self.num_local_mergepoints,
                           self.graph_color,
                           self.calldescs,
+                          self.indirectcalldescs,
                           self.is_portal)
         bytecode._source = self.assembler
         bytecode._interpreter = self.interpreter
@@ -163,6 +192,14 @@
                 self.num_global_mergepoints)
             return bytecode
 
+    def get_jitcode(self, graph):  # return graph's JitCode, registering a placeholder on first request
+        if graph in self.all_graphs:
+            return self.all_graphs[graph]
+        bytecode = JitCode.__new__(JitCode)  # uninitialized; presumably filled in when the graph is finished -- confirm
+        self.all_graphs[graph] = bytecode
+        self.unfinished_graphs.append(graph)  # consumed by finish_all_graphs
+        return bytecode
+
     def finish_all_graphs(self):
         while self.unfinished_graphs:
             graph = self.unfinished_graphs.pop()
@@ -487,12 +524,7 @@
     def graph_position(self, graph):
         if graph in self.graph_positions:
             return self.graph_positions[graph]
-        if graph in self.all_graphs:
-            bytecode = self.all_graphs[graph]
-        else:
-            bytecode = JitCode.__new__(JitCode)
-            self.all_graphs[graph] = bytecode
-            self.unfinished_graphs.append(graph)
+        bytecode = self.get_jitcode(graph)
         index = len(self.called_bytecodes)
         self.called_bytecodes.append(bytecode)
         self.graph_positions[graph] = index
@@ -508,6 +540,27 @@
         self.calldesc_positions[key] = result
         return result
 
+    def indirectcalldesc_position(self, graph2code):  # index into self.indirectcalldescs for this call set
+        key = graph2code.items()  # Python 2: a list of (graph, tsgraph) pairs, sorted in place below
+        key.sort()
+        key = tuple(key)  # canonical, hashable key
+        if key in self.indirectcalldesc_positions:
+            return self.indirectcalldesc_positions[key]
+        callset = IndirectCallsetDesc(key, self)
+        for i in range(len(key) + 1, 0, -1):  # longest prefix first; NOTE(review): i == len(key)+1 re-tests the full key
+            subkey = key[:i]
+            if subkey in self.indirectcalldesc_positions:
+                result = self.indirectcalldesc_positions[subkey]
+                self.indirectcalldescs[result] = callset  # reuse the prefix's slot, replacing its desc
+                break
+        else:  # no known prefix: allocate a new slot
+            result = len(self.indirectcalldescs)
+            self.indirectcalldescs.append(callset)
+        for i in range(len(key) + 1, 0, -1):  # map every non-empty prefix to this slot
+            subkey = key[:i]
+            self.indirectcalldesc_positions[subkey] = result
+        return result
+
     def interiordesc(self, op, PTRTYPE, nb_offsets):
         path = []
         CONTAINER = PTRTYPE.TO
@@ -620,6 +673,41 @@
         print op, kind, withexc
         return handler(op, withexc)
 
+    def serialize_op_indirect_call(self, op):  # only the "yellow" call kind is handled so far
+        kind, withexc = self.guess_call_kind(op)
+        if kind == "red":
+            XXX  # placeholder: red indirect calls are not implemented
+        if kind == "yellow":
+            targets = dict(self.graphs_from(op))
+            fnptrindex = self.serialize_oparg("red", op.args[0])
+            self.emit("goto_if_constant", fnptrindex, tlabel(("direct call", op)))  # constant fnptr: take the direct-call path below
+            emitted_args = []
+            for v in op.args[1:-1]:
+                emitted_args.append(self.serialize_oparg("red", v))
+            self.emit("red_residual_call")  # non-constant fnptr: emit a residual call
+            calldescindex = self.calldesc_position(op.args[0].concretetype)
+            self.emit(fnptrindex, calldescindex, withexc)
+            self.emit(len(emitted_args), *emitted_args)
+            self.emit(self.promotiondesc_position(lltype.Signed))
+            self.emit("goto", tlabel(("after indirect call", op)))
+
+            self.emit(label(("direct call", op)))
+            args = targets.values()[0].getargs()  # NOTE(review): assumes every target shares the first target's argument layout
+            emitted_args = self.args_of_call(op.args[1:-1], args)
+            self.emit("indirect_call_const")
+            self.emit(*emitted_args)
+            setdescindex = self.indirectcalldesc_position(targets)
+            self.emit(fnptrindex, setdescindex)
+            self.emit("yellow_after_direct_call")
+            self.emit("yellow_retrieve_result_as_red")
+            self.emit(self.type_position(op.result.concretetype))
+             
+
+            self.emit(label(("after indirect call", op)))
+            self.register_redvar(op.result)  # both paths leave the result as a red box
+            return
+        XXX  # placeholder: other call kinds are not implemented
+
     def handle_oopspec_call(self, op, withexc):
         from pypy.jit.timeshifter.oop import Index
         fnobj = op.args[0].value._obj
@@ -683,13 +771,8 @@
             emitted_args.append(self.serialize_oparg("red", v))
         self.emit("red_residual_direct_call")
         self.emit(func, pos, withexc, len(emitted_args), *emitted_args)
-        self.register_redvar(op.result)
-        pos = self.register_redvar(("residual_flags_red", op.args[0]))
-        self.emit("promote")
-        self.emit(pos)
         self.emit(self.promotiondesc_position(lltype.Signed))
-        self.register_greenvar(("residual_flags_green", op.args[0]), check=False)
-        self.emit("residual_fetch", True, pos)
+        self.register_redvar(op.result)
 
     def handle_rpyexc_raise_call(self, op, withexc):
         emitted_args = []
@@ -728,8 +811,6 @@
         self.emit("yellow_retrieve_result")
         self.register_greenvar(op.result)
 
-    def serialize_op_indirect_call(self, op):
-        XXX
 
     def serialize_op_malloc(self, op):
         index = self.structtypedesc_position(op.args[0].value)

Modified: pypy/branch/jit-refactoring/pypy/jit/rainbow/dump.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/rainbow/dump.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/rainbow/dump.py	Mon Feb 25 21:24:44 2008
@@ -129,6 +129,10 @@
                     index = src.load_2byte()
                     function = jitcode.calldescs[index]
                     args.append(function)
+                elif argspec == "indirectcalldesc":
+                    index = src.load_2byte()
+                    function = jitcode.indirectcalldescs[index]
+                    args.append(function)
                 elif argspec == "oopspec":
                     oopspecindex = src.load_2byte()
                     oopspec = jitcode.oopspecdescs[oopspecindex]

Modified: pypy/branch/jit-refactoring/pypy/jit/rainbow/interpreter.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/rainbow/interpreter.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/rainbow/interpreter.py	Mon Feb 25 21:24:44 2008
@@ -2,7 +2,7 @@
 from pypy.rlib.unroll import unrolling_iterable
 from pypy.jit.timeshifter import rtimeshift, rcontainer
 from pypy.jit.timeshifter.greenkey import empty_key, GreenKey, newgreendict
-from pypy.rpython.lltypesystem import lltype
+from pypy.rpython.lltypesystem import lltype, llmemory
 
 class JitCode(object):
     """
@@ -21,7 +21,7 @@
                  keydescs, structtypedescs, fielddescs, arrayfielddescs,
                  interiordescs, oopspecdescs, promotiondescs,
                  called_bytecodes, num_mergepoints,
-                 graph_color, calldescs, is_portal):
+                 graph_color, calldescs, indirectcalldescs, is_portal):
         self.name = name
         self.code = code
         self.constants = constants
@@ -38,6 +38,7 @@
         self.num_mergepoints = num_mergepoints
         self.graph_color = graph_color
         self.calldescs = calldescs
+        self.indirectcalldescs = indirectcalldescs
         self.is_portal = is_portal
 
     def _freeze_(self):
@@ -114,6 +115,10 @@
                     index = self.load_2byte()
                     function = self.frame.bytecode.calldescs[index]
                     args += (function, )
+                elif argspec == "indirectcalldesc":
+                    index = self.load_2byte()
+                    function = self.frame.bytecode.indirectcalldescs[index]
+                    args += (function, )
                 elif argspec == "oopspec":
                     oopspecindex = self.load_2byte()
                     oopspec = self.frame.bytecode.oopspecdescs[oopspecindex]
@@ -367,6 +372,11 @@
         if descision:
             self.frame.pc = target
 
+    @arguments("red", "jumptarget")
+    def opimpl_goto_if_constant(self, valuebox, target):  # jump to target when valuebox is a constant
+        if valuebox.is_constant():
+            self.frame.pc = target
+
     @arguments("red", returns="red")
     def opimpl_red_ptr_nonzero(self, ptrbox):
         return rtimeshift.genptrnonzero(self.jitstate, ptrbox, False)
@@ -480,6 +490,15 @@
         # this frame will be resumed later in the next bytecode, which is
         # yellow_after_direct_call
 
+    @arguments("green_varargs", "red_varargs", "red", "indirectcalldesc")
+    def opimpl_indirect_call_const(self, greenargs, redargs,  # run the JitCode selected by a constant function pointer
+                                      funcptrbox, callset):
+        gv = funcptrbox.getgenvar(self.jitstate)
+        addr = gv.revealconst(llmemory.Address)
+        bytecode = callset.bytecode_for_address(addr)  # NOTE(review): None if addr is not in the call set
+        self.run(self.jitstate, bytecode, greenargs, redargs,
+                 start_bytecode_loop=False)  # codewriter emits yellow_after_direct_call right after this op
+
     @arguments()
     def opimpl_yellow_after_direct_call(self):
         newjitstate = rtimeshift.collect_split(
@@ -492,6 +511,13 @@
         # XXX all this jitstate.greens business is a bit messy
         return self.jitstate.greens[0]
 
+    @arguments("2byte", returns="red")
+    def opimpl_yellow_retrieve_result_as_red(self, typeid):  # box the callee's result as a red value
+        # XXX all this jitstate.greens business is a bit messy
+        redboxcls = self.frame.bytecode.redboxclasses[typeid]  # typeid is emitted by the codewriter via type_position
+        kind = self.frame.bytecode.typekinds[typeid]
+        return redboxcls(kind, self.jitstate.greens[0])  # the result value is read from jitstate.greens[0]
+
     @arguments("oopspec", "bool", returns="red")
     def opimpl_red_oopspec_call_0(self, oopspec, deepfrozen):
         return oopspec.ll_handler(self.jitstate, oopspec, deepfrozen)
@@ -508,8 +534,9 @@
     def opimpl_red_oopspec_call_3(self, oopspec, deepfrozen, arg1, arg2, arg3):
         return oopspec.ll_handler(self.jitstate, oopspec, deepfrozen, arg1, arg2, arg3)
 
-    @arguments("red", "calldesc", "bool", "red_varargs")
-    def opimpl_red_residual_direct_call(self, funcbox, calldesc, withexc, redargs):
+    @arguments("red", "calldesc", "bool", "red_varargs", "promotiondesc")
+    def opimpl_red_residual_call(self, funcbox, calldesc, withexc,
+                                        redargs, promotiondesc):
         result = rtimeshift.gen_residual_call(self.jitstate, calldesc,
                                               funcbox, redargs)
         self.red_result(result)
@@ -519,13 +546,13 @@
             exceptiondesc = None
         flagbox = rtimeshift.after_residual_call(self.jitstate,
                                                  exceptiondesc, True)
-        self.red_result(flagbox)
-
-    @arguments("bool", "red")
-    def opimpl_residual_fetch(self, check_forced, flagbox):
+        done = rtimeshift.promote(self.jitstate, flagbox, promotiondesc)
+        if done:
+            return self.dispatch()
+        gv_flag = flagbox.getgenvar(self.jitstate)
+        assert gv_flag.is_const
         rtimeshift.residual_fetch(self.jitstate, self.exceptiondesc,
-                                  check_forced, flagbox)
-
+                                  True, flagbox)
 
     # exceptions
 

Modified: pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_interpreter.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_interpreter.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_interpreter.py	Mon Feb 25 21:24:44 2008
@@ -1040,7 +1040,6 @@
         assert res == 212
 
     def test_simple_meth(self):
-        py.test.skip("needs promote")
         class Base(object):
             def m(self):
                 raise NotImplementedError



More information about the Pypy-commit mailing list