[pypy-svn] r52474 - in pypy/branch/jit-hotpath/pypy/jit/rainbow: . test

arigo at codespeak.net arigo at codespeak.net
Fri Mar 14 09:23:05 CET 2008


Author: arigo
Date: Fri Mar 14 09:23:04 2008
New Revision: 52474

Modified:
   pypy/branch/jit-hotpath/pypy/jit/rainbow/codewriter.py
   pypy/branch/jit-hotpath/pypy/jit/rainbow/fallback.py
   pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py
   pypy/branch/jit-hotpath/pypy/jit/rainbow/interpreter.py
   pypy/branch/jit-hotpath/pypy/jit/rainbow/rhotpath.py
   pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hot_promotion.py
   pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py
Log:
* Support for promotion in the hotpath JIT (new 'hp_promote' opcode).
* Rename the rewriter (EntryPointsRewriter) to HotRunnerDesc and try to make
  it the central place that holds references to all the other objects
  (interpreter, fallbackinterp, etc.)


Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/codewriter.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/codewriter.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/codewriter.py	Fri Mar 14 09:23:04 2008
@@ -686,7 +686,11 @@
         ERASED = self.RGenOp.erasedType(TYPE)
         if ERASED in self.promotiondesc_positions:
             return self.promotiondesc_positions[ERASED]
-        promotiondesc = rtimeshift.PromotionDesc(ERASED, self.interpreter)
+        if self.hannotator.policy.hotpath:
+            from pypy.jit.rainbow.rhotpath import HotPromotionDesc
+            promotiondesc = HotPromotionDesc(ERASED, self.RGenOp)
+        else:
+            promotiondesc = rtimeshift.PromotionDesc(ERASED, self.interpreter)
         result = len(self.promotiondescs)
         self.promotiondescs.append(promotiondesc)
         self.promotiondesc_positions[ERASED] = result
@@ -837,7 +841,10 @@
         if self.varcolor(arg) == "green":
             self.register_greenvar(result, self.green_position(arg))
             return
-        self.emit("promote")
+        if self.hannotator.policy.hotpath:
+            self.emit("hp_promote")
+        else:
+            self.emit("promote")
         self.emit(self.serialize_oparg("red", arg))
         self.emit(self.promotiondesc_position(arg.concretetype))
         self.register_greenvar(result)

Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/fallback.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/fallback.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/fallback.py	Fri Mar 14 09:23:04 2008
@@ -12,31 +12,24 @@
     actual values for the live red vars, and interprets the jitcode
     normally until it reaches the 'jit_merge_point' or raises.
     """
-    def __init__(self, interpreter, exceptiondesc,
-                 DoneWithThisFrame, ContinueRunningNormally):
-        self.interpreter = interpreter
-        self.rgenop = interpreter.rgenop
-        self.exceptiondesc = exceptiondesc
-        self.DoneWithThisFrame = DoneWithThisFrame
-        self.ContinueRunningNormally = ContinueRunningNormally
-        self.register_opcode_impls(interpreter)
+    def __init__(self, hotrunnerdesc):
+        self.hotrunnerdesc = hotrunnerdesc
+        self.interpreter = hotrunnerdesc.interpreter
+        self.rgenop = self.interpreter.rgenop
+        self.exceptiondesc = hotrunnerdesc.exceptiondesc
+        self.register_opcode_impls(self.interpreter)
 
-    def run(self, fallback_point, framebase, pc):
+    def initialize_state(self, fallback_point, framebase):
         self.interpreter.debug_trace("fallback_interp")
-        self.fbp = fallback_point
-        self.framebase = framebase
-        self.initialize_state(pc)
-        self.bytecode_loop()
-
-    def initialize_state(self, pc):
-        jitstate = self.fbp.saved_jitstate
+        jitstate = fallback_point.saved_jitstate
         incoming_gv = jitstate.get_locals_gv()
+        self.framebase = framebase
+        self.frameinfo = fallback_point.frameinfo
         self.gv_to_index = {}
         for i in range(len(incoming_gv)):
             self.gv_to_index[incoming_gv[i]] = i
 
         self.initialize_from_frame(jitstate.frame)
-        self.pc = pc
         self.gv_exc_type  = self.getinitialboxgv(jitstate.exc_type_box)
         self.gv_exc_value = self.getinitialboxgv(jitstate.exc_value_box)
         self.seen_can_enter_jit = False
@@ -47,7 +40,7 @@
         if not gv.is_const:
             # fetch the value from the machine code stack
             gv = self.rgenop.genconst_from_frame_var(box.kind, self.framebase,
-                                                     self.fbp.frameinfo,
+                                                     self.frameinfo,
                                                      self.gv_to_index[gv])
         return gv
 
@@ -85,7 +78,9 @@
             self.interpreter.debug_trace("fb_raise", type_name(lltype))
             raise LLException(lltype, llvalue)
         else:
-            raise self.DoneWithThisFrame(gv_result)
+            self.interpreter.debug_trace("fb_return", gv_result)
+            DoneWithThisFrame = self.hotrunnerdesc.DoneWithThisFrame
+            raise DoneWithThisFrame(gv_result)
 
     # ____________________________________________________________
     # XXX Lots of copy and paste from interp.py!
@@ -254,8 +249,9 @@
     opimpl_make_new_greenvars.argspec = arguments("green_varargs")
 
     @arguments("green", "calldesc", "green_varargs")
-    def opimpl_green_call(self, fnptr_gv, calldesc, greenargs):
-        xxx
+    def opimpl_green_call(self, gv_fnptr, calldesc, greenargs):
+        gv_res = calldesc.perform_call(self.rgenop, gv_fnptr, greenargs)
+        self.green_result(gv_res)
 
     @arguments("green_varargs", "red_varargs", "red", "indirectcalldesc")
     def opimpl_indirect_call_const(self, greenargs, redargs,
@@ -410,8 +406,9 @@
 
     @arguments("greenkey")
     def opimpl_jit_merge_point(self, key):
-        raise self.ContinueRunningNormally(self.local_green + self.local_red,
-                                           self.seen_can_enter_jit)
+        ContinueRunningNormally = self.hotrunnerdesc.ContinueRunningNormally
+        raise ContinueRunningNormally(self.local_green + self.local_red,
+                                      self.seen_can_enter_jit)
 
     @arguments()
     def opimpl_can_enter_jit(self):
@@ -422,6 +419,10 @@
         if gv_switch.revealconst(lltype.Bool):
             self.pc = target
 
+    @arguments("red", "promotiondesc")
+    def opimpl_hp_promote(self, gv_promote, promotiondesc):
+        xxx
+
     @arguments("green_varargs", "red_varargs", "bytecode")
     def opimpl_hp_yellow_direct_call(self, greenargs, redargs, targetbytecode):
         gv_res = self.run_directly(greenargs, redargs, targetbytecode)

Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/hotpath.py	Fri Mar 14 09:23:04 2008
@@ -14,7 +14,7 @@
 from pypy.jit.rainbow.codewriter import maybe_on_top_of_llinterp
 
 
-class EntryPointsRewriter:
+class HotRunnerDesc:
 
     def __init__(self, hintannotator, rtyper, entryjitcode, RGenOp,
                  codewriter, threshold, translate_support_code = True):
@@ -22,6 +22,7 @@
         self.entryjitcode = entryjitcode
         self.rtyper = rtyper
         self.RGenOp = RGenOp
+        self.exceptiondesc = codewriter.exceptiondesc
         self.interpreter = codewriter.interpreter
         self.codewriter = codewriter
         self.threshold = threshold
@@ -34,7 +35,9 @@
         self.make_args_specification()
         self.make_enter_function()
         self.rewrite_graphs()
-        self.update_interp()
+        self.make_descs()
+        self.fallbackinterp = FallbackInterpreter(self)
+        self.interpreter.hotrunnerdesc = self
 
     def make_args_specification(self):
         origportalgraph = self.hintannotator.portalgraph
@@ -62,7 +65,7 @@
     def make_enter_function(self):
         HotEnterState = make_state_class(self)
         state = HotEnterState()
-        exceptiondesc = self.codewriter.exceptiondesc
+        exceptiondesc = self.exceptiondesc
         interpreter = self.interpreter
         num_green_args = len(self.green_args_spec)
 
@@ -87,15 +90,10 @@
         maybe_enter_jit._always_inline_ = True
         self.maybe_enter_jit_fn = maybe_enter_jit
 
-    def update_interp(self):
-        self.fallbackinterp = FallbackInterpreter(
-            self.interpreter,
-            self.codewriter.exceptiondesc,
-            self.DoneWithThisFrame,
-            self.ContinueRunningNormally)
+    def make_descs(self):
         ERASED = self.RGenOp.erasedType(lltype.Bool)
-        self.interpreter.bool_hotpromotiondesc = rhotpath.HotPromotionDesc(
-            ERASED, self.interpreter, self.threshold, self.fallbackinterp)
+        self.bool_hotpromotiondesc = rhotpath.HotPromotionDesc(ERASED,
+                                                               self.RGenOp)
 
     def rewrite_graphs(self):
         for graph in self.hintannotator.base_translator.graphs:
@@ -209,6 +207,7 @@
                         ', '.join(map(str, self.args)),)
 
             self.DoneWithThisFrame = DoneWithThisFrame
+            self.DoneWithThisFrameARG = RES
             self.ContinueRunningNormally = ContinueRunningNormally
 
             def portal_runner(*args):
@@ -252,12 +251,12 @@
         return True
 
 
-def make_state_class(rewriter):
+def make_state_class(hotrunnerdesc):
     # very minimal, just to make the first test pass
-    green_args_spec = unrolling_iterable(rewriter.green_args_spec)
-    red_args_spec = unrolling_iterable(rewriter.red_args_spec)
-    if rewriter.green_args_spec:
-        keydesc = KeyDesc(rewriter.RGenOp, *rewriter.green_args_spec)
+    green_args_spec = unrolling_iterable(hotrunnerdesc.green_args_spec)
+    red_args_spec = unrolling_iterable(hotrunnerdesc.red_args_spec)
+    if hotrunnerdesc.green_args_spec:
+        keydesc = KeyDesc(hotrunnerdesc.RGenOp, *hotrunnerdesc.green_args_spec)
     else:
         keydesc = None
 
@@ -279,7 +278,7 @@
         def getkey(self, *greenvalues):
             if keydesc is None:
                 return empty_key
-            rgenop = rewriter.interpreter.rgenop
+            rgenop = hotrunnerdesc.interpreter.rgenop
             lst_gv = [None] * len(greenvalues)
             i = 0
             for _ in green_args_spec:
@@ -292,14 +291,15 @@
                 self._compile(greenkey)
                 return True
             except Exception, e:
-                rhotpath.report_compile_time_exception(rewriter.interpreter, e)
+                rhotpath.report_compile_time_exception(
+                    hotrunnerdesc.interpreter, e)
                 return False
 
         def _compile(self, greenkey):
-            interp = rewriter.interpreter
+            interp = hotrunnerdesc.interpreter
             rgenop = interp.rgenop
             builder, gv_generated, inputargs_gv = rgenop.newgraph(
-                rewriter.sigtoken, "residual")
+                hotrunnerdesc.sigtoken, "residual")
 
             greenargs = list(greenkey.values)
             redargs = ()
@@ -313,14 +313,19 @@
 
             jitstate = interp.fresh_jitstate(builder)
             rhotpath.setup_jitstate(interp, jitstate, greenargs, redargs,
-                                    rewriter.entryjitcode, rewriter.sigtoken)
+                                    hotrunnerdesc.entryjitcode,
+                                    hotrunnerdesc.sigtoken)
             builder = jitstate.curbuilder
             builder.start_writing()
             rhotpath.compile(interp)
             builder.end()
 
-            FUNCPTR = lltype.Ptr(rewriter.RESIDUAL_FUNCTYPE)
-            self.machine_codes[greenkey] = gv_generated.revealconst(FUNCPTR)
+            FUNCPTR = lltype.Ptr(hotrunnerdesc.RESIDUAL_FUNCTYPE)
+            generated = gv_generated.revealconst(FUNCPTR)
+            self.machine_codes[greenkey] = generated
             self.counters[greenkey] = -1     # compiled
 
+            if not we_are_translated():
+                hotrunnerdesc.residual_graph = generated._obj.graph  #for tests
+
     return HotEnterState

Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/interpreter.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/interpreter.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/interpreter.py	Fri Mar 14 09:23:04 2008
@@ -857,7 +857,7 @@
         if done:
             self.debug_trace("done at jit_merge_point")
             self.newjitstate(None)
-            return STOP
+            raise rhotpath.FinishedCompiling
 
     @arguments()
     def opimpl_can_enter_jit(self):
@@ -865,10 +865,24 @@
 
     @arguments("red", "jumptarget")
     def opimpl_hp_red_goto_iftrue(self, switchbox, target):
-        self.debug_trace("pause at hotsplit in", self.frame.bytecode.name)
-        rhotpath.hotsplit(self.jitstate, self.bool_hotpromotiondesc,
-                          switchbox, self.frame.pc, target)
-        assert False, "unreachable"
+        if switchbox.is_constant():
+            if switchbox.getgenvar(self.jitstate).revealconst(lltype.Bool):
+                self.frame.pc = target
+        else:
+            self.debug_trace("pause at hotsplit in", self.frame.bytecode.name)
+            rhotpath.hotsplit(self.jitstate, self.hotrunnerdesc,
+                              switchbox, self.frame.pc, target)
+            assert False, "unreachable"
+
+    @arguments("red", "promotiondesc")
+    def opimpl_hp_promote(self, promotebox, promotiondesc):
+        if promotebox.is_constant():
+            self.green_result_from_red(promotebox)
+        else:
+            self.debug_trace("pause at promote in", self.frame.bytecode.name)
+            rhotpath.hp_promote(self.jitstate, self.hotrunnerdesc,
+                                promotebox, promotiondesc)
+            assert False, "unreachable"
 
     @arguments("green_varargs", "red_varargs", "bytecode")
     def opimpl_hp_yellow_direct_call(self, greenargs, redargs, targetbytecode):
@@ -885,7 +899,20 @@
 
     @arguments()
     def opimpl_hp_red_return(self):
-        xxx
+        gv_result = self.frame.local_boxes[0].getgenvar(self.jitstate)
+        # XXX not translatable (and slow if translated literally)
+        # XXX well, and hackish, clearly
+        def exit(llvalue):
+            DoneWithThisFrame = self.hotrunnerdesc.DoneWithThisFrame
+            raise DoneWithThisFrame(self.rgenop.genconst(llvalue))
+        FUNCTYPE = lltype.FuncType([self.hotrunnerdesc.DoneWithThisFrameARG],
+                                   lltype.Void)
+        exitfnptr = lltype.functionptr(FUNCTYPE, 'exit', _callable=exit)
+        gv_exitfnptr = self.rgenop.genconst(exitfnptr)
+        self.jitstate.curbuilder.genop_call(self.rgenop.sigToken(FUNCTYPE),
+                                            gv_exitfnptr,
+                                            [gv_result])
+        rhotpath.leave_graph(self)
 
     @arguments()
     def opimpl_hp_yellow_return(self):

Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/rhotpath.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/rhotpath.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/rhotpath.py	Fri Mar 14 09:23:04 2008
@@ -3,7 +3,8 @@
 """
 
 from pypy.jit.timeshifter import rtimeshift, rvalue
-from pypy.rlib.objectmodel import we_are_translated
+from pypy.jit.timeshifter.greenkey import KeyDesc, GreenKey, newgreendict
+from pypy.rlib.objectmodel import we_are_translated, specialize
 from pypy.rpython.annlowlevel import cachedtype, base_ptr_lltype
 from pypy.rpython.annlowlevel import llhelper
 from pypy.rpython.lltypesystem import lltype, llmemory
@@ -22,8 +23,6 @@
 
 def leave_graph(interp):
     jitstate = interp.jitstate
-    if jitstate is None:
-        return
     exceptiondesc = interp.exceptiondesc
     builder = jitstate.curbuilder
     #for virtualizable_box in jitstate.virtualizables:
@@ -34,18 +33,16 @@
     exceptiondesc.store_global_excdata(jitstate)
     jitstate.curbuilder.finish_and_return(interp.graphsigtoken, None)
     jitstate.curbuilder = None
+    raise FinishedCompiling
 
 def compile(interp):
     jitstate = interp.jitstate
     builder = jitstate.curbuilder
     try:
         interp.bytecode_loop()
+        assert False, "unreachable"
     except FinishedCompiling:
         pass
-    except GenerateReturn:
-        leave_graph(interp)
-    else:
-        leave_graph(interp)
     builder.show_incremental_progress()
 
 def report_compile_time_exception(interp, e):
@@ -70,58 +67,43 @@
 class FinishedCompiling(Exception):
     pass
 
-class GenerateReturn(Exception):
-    pass
-
 class HotPromotionDesc:
     __metaclass__ = cachedtype
 
-    def __init__(self, ERASED, interpreter, threshold, fallbackinterp):
-        self.exceptiondesc = interpreter.exceptiondesc
-        self.gv_constant_one = interpreter.rgenop.constPrebuiltGlobal(1)
+    def __init__(self, ERASED, RGenOp):
+        self.RGenOp = RGenOp
+        self.greenkeydesc = KeyDesc(RGenOp, ERASED)
+        pathkind = "%s path" % (ERASED,)
 
         def ll_reach_fallback_point(fallback_point_ptr, value, framebase):
             try:
                 fbp = fallback_point_ptr     # XXX cast
-                assert lltype.typeOf(value) is lltype.Bool   # XXX for now
-                if value:
-                    counter = fbp.truepath_counter
-                    pc = fbp.truepath_pc
-                else:
-                    counter = fbp.falsepath_counter
-                    pc = fbp.falsepath_pc
-                assert counter >= 0, (
-                    "reaching a fallback point for an already-compiled path")
-                counter += 1
+                # check if we should compile for this value.
+                path_is_hot = fbp.check_should_compile(value)
 
-                if counter >= threshold:
+                if path_is_hot:
                     # this is a hot path, compile it
-                    interpreter.debug_trace("jit_resume", "bool_path", value,
+                    interpreter = fbp.hotrunnerdesc.interpreter
+                    interpreter.debug_trace("jit_resume", pathkind, value,
                         "in", fbp.saved_jitstate.frame.bytecode.name)
-                    gv_value = interpreter.rgenop.genconst(value)
-                    fbp.compile_hot_path(interpreter, gv_value, pc)
-                    if value:
-                        fbp.truepath_counter = -1    # means "compiled"
-                    else:
-                        fbp.falsepath_counter = -1   # means "compiled"
+                    fbp.compile_hot_path(value)
                     # Done.  We return without an exception set, which causes
                     # our caller (the machine code produced by hotsplit()) to
                     # loop back to the flexswitch and execute the
                     # newly-generated code.
                     interpreter.debug_trace("resume_machine_code")
                     return
-                else:
-                    # path is still cold
-                    if value:
-                        fbp.truepath_counter = counter
-                    else:
-                        fbp.falsepath_counter = counter
+                # else: path is still cold
 
             except Exception, e:
+                interpreter = fbp.hotrunnerdesc.interpreter
                 report_compile_time_exception(interpreter, e)
 
             # exceptions below at run-time exceptions, we let them propagate
-            fallbackinterp.run(fbp, framebase, pc)
+            fallbackinterp = fbp.hotrunnerdesc.fallbackinterp
+            fallbackinterp.initialize_state(fbp, framebase)
+            fbp.prepare_fallbackinterp(fallbackinterp, value)
+            fallbackinterp.bytecode_loop()
             # If the fallback interpreter reached the next jit_merge_point(),
             # it raised ContinueRunningNormally().  This exception is
             # caught by portal_runner() from hotpath.py in order to loop
@@ -135,7 +117,7 @@
                                     llmemory.Address], lltype.Void)
         FUNCPTRTYPE = lltype.Ptr(FUNCTYPE)
         self.FUNCPTRTYPE = FUNCPTRTYPE
-        self.sigtoken = interpreter.rgenop.sigToken(FUNCTYPE)
+        self.sigtoken = RGenOp.sigToken(FUNCTYPE)
 
         def get_gv_reach_fallback_point(builder):
             fnptr = llhelper(FUNCPTRTYPE, ll_reach_fallback_point)
@@ -149,43 +131,159 @@
 
 
 class FallbackPoint(object):
-    falsepath_counter = 0     # -1 after this path was compiled
-    truepath_counter = 0      # -1 after this path was compiled
 
-    def __init__(self, jitstate, flexswitch, frameinfo,
-                 falsepath_pc, truepath_pc):
+    def __init__(self, jitstate, hotrunnerdesc, promotebox):
         # XXX we should probably trim down the jitstate once our caller
         # is done with it, to avoid keeping too much stuff in memory
         self.saved_jitstate = jitstate
+        self.hotrunnerdesc = hotrunnerdesc
+        self.promotebox = promotebox
+
+    def set_machine_code_info(self, flexswitch, frameinfo):
         self.flexswitch = flexswitch
         self.frameinfo = frameinfo
         # ^^^ 'frameinfo' describes where the machine code stored all
         # its GenVars, so that we can fish these values to pass them
         # to the fallback interpreter
+
+    # hack for testing: make the llinterpreter believe this is a Ptr to base
+    # instance
+    _TYPE = base_ptr_lltype()
+
+
+class HotSplitFallbackPoint(FallbackPoint):
+    falsepath_counter = 0     # -1 after this path was compiled
+    truepath_counter = 0      # -1 after this path was compiled
+
+    def __init__(self, jitstate, hotrunnerdesc, promotebox,
+                 falsepath_pc, truepath_pc):
+        FallbackPoint. __init__(self, jitstate, hotrunnerdesc, promotebox)
         self.falsepath_pc = falsepath_pc
         self.truepath_pc = truepath_pc
 
-    def compile_hot_path(self, interpreter, gv_case, pc):
+    @specialize.arg(1)
+    def check_should_compile(self, value):
+        assert lltype.typeOf(value) is lltype.Bool
+        threshold = self.hotrunnerdesc.threshold
+        if value:
+            counter = self.truepath_counter + 1
+            assert counter > 0, (
+                "reaching a fallback point for an already-compiled path")
+            if counter >= threshold:
+                return True
+            self.truepath_counter = counter
+            return False
+        else:
+            counter = self.falsepath_counter + 1
+            assert counter > 0, (
+                "reaching a fallback point for an already-compiled path")
+            if counter >= threshold:
+                return True
+            self.falsepath_counter = counter
+            return False   # path is still cold
+
+    @specialize.arg(2)
+    def prepare_fallbackinterp(self, fallbackinterp, value):
+        if value:
+            fallbackinterp.pc = self.truepath_pc
+        else:
+            fallbackinterp.pc = self.falsepath_pc
+
+    @specialize.arg(1)
+    def compile_hot_path(self, value):
+        if value:
+            pc = self.truepath_pc
+        else:
+            pc = self.falsepath_pc
+        gv_value = self.hotrunnerdesc.interpreter.rgenop.genconst(value)
+        self._compile_hot_path(gv_value, pc)
+        if value:
+            self.truepath_counter = -1    # means "compiled"
+        else:
+            self.falsepath_counter = -1   # means "compiled"
+
+    def _compile_hot_path(self, gv_case, pc):
         if self.falsepath_counter == -1 or self.truepath_counter == -1:
             # the other path was already compiled, we can reuse the jitstate
             jitstate = self.saved_jitstate
             self.saved_jitstate = None
+            promotebox = self.promotebox
         else:
             # clone the jitstate
             memo = rvalue.copy_memo()
             jitstate = self.saved_jitstate.clone(memo)
+            promotebox = memo.boxes[self.promotebox]
+        promotebox.setgenvar(gv_case)
+        interpreter = self.hotrunnerdesc.interpreter
         interpreter.newjitstate(jitstate)
         interpreter.frame.pc = pc
         jitstate.curbuilder = self.flexswitch.add_case(gv_case)
         compile(interpreter)
 
-    # hack for testing: make the llinterpreter believe this is a Ptr to base
-    # instance
-    _TYPE = base_ptr_lltype()
 
+class PromoteFallbackPoint(FallbackPoint):
 
-def hotsplit(jitstate, hotpromotiondesc, switchbox, falsepath_pc, truepath_pc):
+    def __init__(self, jitstate, hotrunnerdesc, promotebox, hotpromotiondesc):
+        FallbackPoint. __init__(self, jitstate, hotrunnerdesc, promotebox)
+        self.hotpromotiondesc = hotpromotiondesc
+        self.counters = newgreendict()
+
+    @specialize.arg(1)
+    def check_should_compile(self, value):
+        # XXX incredibly heavy for a supposely lightweight profiling
+        gv_value = self.hotrunnerdesc.interpreter.rgenop.genconst(value)
+        greenkey = GreenKey([gv_value], self.hotpromotiondesc.greenkeydesc)
+        counter = self.counters.get(greenkey, 0) + 1
+        threshold = self.hotrunnerdesc.threshold
+        assert counter > 0, (
+            "reaching a fallback point for an already-compiled path")
+        if counter >= threshold:
+            return True
+        self.counters[greenkey] = counter
+        return False
+
+    @specialize.arg(2)
+    def prepare_fallbackinterp(self, fallbackinterp, value):
+        gv_value = self.hotrunnerdesc.interpreter.rgenop.genconst(value)
+        fallbackinterp.local_green.append(gv_value)
+
+    @specialize.arg(1)
+    def compile_hot_path(self, value):
+        gv_value = self.hotrunnerdesc.interpreter.rgenop.genconst(value)
+        self._compile_hot_path(gv_value)
+
+    def _compile_hot_path(self, gv_value):
+        # clone the jitstate
+        memo = rvalue.copy_memo()
+        jitstate = self.saved_jitstate.clone(memo)
+        promotebox = memo.boxes[self.promotebox]
+        promotebox.setgenvar(gv_value)
+        # compile from that state
+        interpreter = self.hotrunnerdesc.interpreter
+        interpreter.newjitstate(jitstate)
+        interpreter.green_result(gv_value)
+        jitstate.curbuilder = self.flexswitch.add_case(gv_value)
+        compile(interpreter)
+        # done
+        greenkey = GreenKey([gv_value], self.hotpromotiondesc.greenkeydesc)
+        self.counters[greenkey] = -1     # means "compiled"
+
+
+def hotsplit(jitstate, hotrunnerdesc, switchbox,
+             falsepath_pc, truepath_pc):
     # produce a Bool flexswitch for now
+    fbp = HotSplitFallbackPoint(jitstate, hotrunnerdesc, switchbox,
+                                falsepath_pc, truepath_pc)
+    desc = hotrunnerdesc.bool_hotpromotiondesc
+    generate_fallback_code(fbp, desc, switchbox)
+
+def hp_promote(jitstate, hotrunnerdesc, promotebox, hotpromotiondesc):
+    fbp = PromoteFallbackPoint(jitstate, hotrunnerdesc, promotebox,
+                               hotpromotiondesc)
+    generate_fallback_code(fbp, hotpromotiondesc, promotebox)
+
+def generate_fallback_code(fbp, hotpromotiondesc, switchbox):
+    jitstate = fbp.saved_jitstate
     incoming = jitstate.enter_block_sweep_virtualizables()
     switchblock = rtimeshift.enter_next_block(jitstate, incoming)
     gv_switchvar = switchbox.genvar
@@ -195,8 +293,7 @@
     jitstate.curbuilder = default_builder
     # default case of the switch:
     frameinfo = default_builder.get_frame_info(incoming_gv)
-    fbp = FallbackPoint(jitstate, flexswitch, frameinfo,
-                        falsepath_pc, truepath_pc)
+    fbp.set_machine_code_info(flexswitch, frameinfo)
     ll_fbp = fbp        # XXX doesn't translate
     gv_fbp = default_builder.rgenop.genconst(ll_fbp)
     gv_switchvar = switchbox.genvar
@@ -208,7 +305,7 @@
     # The call above may either return normally, meaning that more machine
     # code was compiled and we should loop back to 'switchblock' to enter it,
     # or it may have set an exception.
-    exceptiondesc = hotpromotiondesc.exceptiondesc
+    exceptiondesc = fbp.hotrunnerdesc.exceptiondesc
     gv_exc_type = exceptiondesc.genop_get_exc_type(default_builder)
     gv_noexc = default_builder.genop_ptr_iszero(
         exceptiondesc.exc_type_kind, gv_exc_type)
@@ -217,4 +314,4 @@
 
     jitstate.curbuilder = excpath_builder
     excpath_builder.start_writing()
-    raise GenerateReturn
+    leave_graph(fbp.hotrunnerdesc.interpreter)

Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hot_promotion.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hot_promotion.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hot_promotion.py	Fri Mar 14 09:23:04 2008
@@ -3,34 +3,44 @@
 from pypy.rlib.jit import JitDriver, hint, JitHintError
 from pypy.jit.rainbow.test import test_hotpath
 
-py.test.skip("in-progress")
-
 
 class TestHotPromotion(test_hotpath.HotPathTest):
 
     def interpret(self, main, ll_values, opt_consts=[]):
-        if opt_consts:
-            # opt_consts lists the indices of arguments that should be
-            # passed in as constant red boxes.  To emulate this effect
-            # we make them green vars in a wrapper main() but pass
-            # them as red boxes to the original main().
-            miniglobals = {'original_main': main,
-                           'hint': hint,
-                           }
-            args = ', '.join(['a%d' % i for i in range(len(ll_values))])
-            lines = '\n'.join(['    a%d = hint(hint(a%d, concrete=True), '
-                                              'variable=True)' % (i, i)
-                               for i in opt_consts])
-            # cannot use unrolling_iterable because the main() cannot
-            # take *args...
-            src = py.code.Source("""\
-                def main(%(args)s):
-                %(lines)s
-                    return original_main(%(args)s)""" % locals())
-            exec src.compile() in miniglobals
-            main = miniglobals['main']
+        py.test.skip("fix this test")
+    def interpret_raises(self, Exception, main, ll_values, opt_consts=[]):
+        py.test.skip("fix this test")
+
+    def get_residual_graph(self):
+        return self.hotrunnerdesc.residual_graph
+
+    def check_insns_excluding_return(self, expected=None, **counts):
+        # the return is currently implemented by a direct_call(exitfnptr)
+        if expected is not None:
+            expected.setdefault('direct_call', 0)
+            expected['direct_call'] += 1
+        if 'direct_call' in counts:
+            counts['direct_call'] += 1
+        self.check_insns(expected, **counts)
 
-        return self.run(main, ll_values, threshold=1)
+    def test_easy_case(self):
+        class MyJitDriver(JitDriver):
+            greens = ['n']
+            reds = []
+        def ll_two(k):
+            return (k+1)*2
+        def ll_function(n):
+            MyJitDriver.jit_merge_point(n=n)
+            MyJitDriver.can_enter_jit(n=n)
+            hint(n, concrete=True)
+            n = hint(n, variable=True)     # n => constant red box
+            k = hint(n, promote=True)      # no-op
+            k = ll_two(k)
+            return hint(k, variable=True)
+
+        res = self.run(ll_function, [20], threshold=1)
+        assert res == 42
+        self.check_insns_excluding_return({})
 
     def test_simple_promotion(self):
         class MyJitDriver(JitDriver):
@@ -40,36 +50,54 @@
             return (k+1)*2
         def ll_function(n):
             MyJitDriver.jit_merge_point(n=n)
+            MyJitDriver.can_enter_jit(n=n)
             k = hint(n, promote=True)
             k = ll_two(k)
             return hint(k, variable=True)
 
-        # easy case: no promotion needed
-        res = self.interpret(ll_function, [20], [0])
-        assert res == 42
-        self.check_insns({})
-
-        # the real test: with promotion
-        res = self.interpret(ll_function, [20], [])
+        res = self.run(ll_function, [20], threshold=1)
         assert res == 42
         self.check_insns(int_add=0, int_mul=0)
 
     def test_many_promotions(self):
+        class MyJitDriver(JitDriver):
+            greens = []
+            reds = ['n', 'total']
         def ll_two(k):
             return k*k
         def ll_function(n, total):
             while n > 0:
-                hint(None, global_merge_point=True)
+                MyJitDriver.jit_merge_point(n=n, total=total)
+                MyJitDriver.can_enter_jit(n=n, total=total)
                 k = hint(n, promote=True)
                 k = ll_two(k)
                 total += hint(k, variable=True)
                 n -= 1
             return total
 
-        res = self.interpret(ll_function, [10, 0], [])
+        res = self.run(ll_function, [10, 0], threshold=1)
         assert res == ll_function(10, 0)
         self.check_insns(int_add=10, int_mul=0)
 
+        # the same using the fallback interp instead of compiling each case
+        res = self.run(ll_function, [10, 0], threshold=3)
+        assert res == ll_function(10, 0)
+        self.check_insns(int_add=0, int_mul=0)
+        self.check_traces([
+            "jit_not_entered 10 0",
+            "jit_not_entered 9 100",
+            "jit_compile",
+            "pause at promote in ll_function",
+            "run_machine_code 8 181", "fallback_interp", "fb_leave 7 245",
+            "run_machine_code 7 245", "fallback_interp", "fb_leave 6 294",
+            "run_machine_code 6 294", "fallback_interp", "fb_leave 5 330",
+            "run_machine_code 5 330", "fallback_interp", "fb_leave 4 355",
+            "run_machine_code 4 355", "fallback_interp", "fb_leave 3 371",
+            "run_machine_code 3 371", "fallback_interp", "fb_leave 2 380",
+            "run_machine_code 2 380", "fallback_interp", "fb_leave 1 384",
+            "run_machine_code 1 384", "fallback_interp", "fb_return (385)"
+            ])
+
     def test_promote_after_call(self):
         S = lltype.GcStruct('S', ('x', lltype.Signed))
         def ll_two(k, s):
@@ -85,7 +113,7 @@
             k *= 17
             return hint(k, variable=True) + s.x
 
-        res = self.interpret(ll_function, [4], [])
+        res = self.interpret(ll_function, [4])
         assert res == 4*17 + 10
         self.check_insns(int_mul=0, int_add=1)
 
@@ -107,7 +135,7 @@
             k += c
             return hint(k, variable=True)
 
-        res = self.interpret(ll_function, [4], [])
+        res = self.interpret(ll_function, [4])
         assert res == 49
         self.check_insns(int_add=0)
 
@@ -120,7 +148,7 @@
             hint(None, global_merge_point=True)
             return ll_two(n + 1) - 1
 
-        res = self.interpret(ll_function, [10], [])
+        res = self.interpret(ll_function, [10])
         assert res == 186
         self.check_insns(int_add=1, int_mul=0, int_sub=0)
 
@@ -137,15 +165,15 @@
                 return 42
             return ll_two(n + 1) - 1
 
-        res = self.interpret(ll_function, [10, 0], [])
+        res = self.interpret(ll_function, [10, 0])
         assert res == 186
         self.check_insns(int_add=1, int_mul=0, int_sub=0)
 
-        res = self.interpret(ll_function, [0, 0], [])
+        res = self.interpret(ll_function, [0, 0])
         assert res == -41
         self.check_insns(int_add=1, int_mul=0, int_sub=0)
 
-        res = self.interpret(ll_function, [1, 1], [])
+        res = self.interpret(ll_function, [1, 1])
         assert res == 42
         self.check_insns(int_add=1, int_mul=0, int_sub=0)
 
@@ -158,7 +186,7 @@
             s1 = n1 + m1
             return hint(s1, variable=True)
 
-        res = self.interpret(ll_function, [40, 2], [])
+        res = self.interpret(ll_function, [40, 2])
         assert res == 42
         self.check_insns(int_add=0)
 
@@ -177,7 +205,7 @@
             hint(None, global_merge_point=True)
             return ll_two(n)
 
-        res = self.interpret(ll_function, [3], [])
+        res = self.interpret(ll_function, [3])
         assert res == 340
         self.check_insns(int_lt=1, int_mul=0)
 
@@ -199,7 +227,7 @@
         self.check_insns({})
 
         # the real test: with promotion
-        res = self.interpret(ll_function, [20], [])
+        res = self.interpret(ll_function, [20])
         assert res == 62
         self.check_insns(int_add=0, int_mul=0)
 
@@ -236,7 +264,7 @@
                 i += j
             return s.x + s.y * 17
 
-        res = self.interpret(ll_function, [100, 2], [])
+        res = self.interpret(ll_function, [100, 2])
         assert res == ll_function(100, 2)
 
     def test_mixed_merges(self):
@@ -307,7 +335,7 @@
             return s
         ll_function.convert_arguments = [struct_S, int]
 
-        res = self.interpret(ll_function, ["20", 0], [])
+        res = self.interpret(ll_function, ["20", 0])
         assert res == 42
         self.check_flexswitches(1)
 
@@ -324,7 +352,7 @@
                 vl[i] = l[i]
                 i = i + 1
             return len(vl)
-        res = self.interpret(ll_function, [6, 5], [])
+        res = self.interpret(ll_function, [6, 5])
         assert res == 6
         self.check_oops(**{'newlist': 1, 'list.len': 1})
             
@@ -349,7 +377,7 @@
                 a += z
 
         assert ll_function(1, 5, 8) == 22
-        res = self.interpret(ll_function, [1, 5, 8], [])
+        res = self.interpret(ll_function, [1, 5, 8])
         assert res == 22
 
     def test_raise_result_mixup(self):
@@ -431,7 +459,7 @@
             if m == 0:
                 raise ValueError
             return n
-        self.interpret_raises(ValueError, ll_function, [1, 0], [])
+        self.interpret_raises(ValueError, ll_function, [1, 0])
 
     def test_promote_in_yellow_call(self):
         def ll_two(n):
@@ -443,7 +471,7 @@
             c = ll_two(n)
             return hint(c, variable=True)
 
-        res = self.interpret(ll_function, [4], [])
+        res = self.interpret(ll_function, [4])
         assert res == 6
         self.check_insns(int_add=0)
 
@@ -460,7 +488,7 @@
                 c = ll_two(n)
             return hint(c, variable=True)
 
-        res = self.interpret(ll_function, [4], [])
+        res = self.interpret(ll_function, [4])
         assert res == 6
         self.check_insns(int_add=0)
 
@@ -482,6 +510,6 @@
             c = ll_one(n, m)
             return c
 
-        res = self.interpret(ll_function, [4, 7], [])
+        res = self.interpret(ll_function, [4, 7])
         assert res == 11
         self.check_insns(int_add=0)

Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hotpath.py	Fri Mar 14 09:23:04 2008
@@ -2,7 +2,7 @@
 import re
 from pypy.rlib.jit import JitDriver, hint, JitHintError
 from pypy.jit.rainbow.test import test_interpreter
-from pypy.jit.rainbow.hotpath import EntryPointsRewriter
+from pypy.jit.rainbow.hotpath import HotRunnerDesc
 from pypy.jit.hintannotator.policy import HintAnnotatorPolicy
 from pypy.rpython.llinterp import LLInterpreter
 from pypy import conftest
@@ -25,11 +25,10 @@
         return self._run(main, main_args)
 
     def _rewrite(self, threshold, small):
-        rewriter = EntryPointsRewriter(self.hintannotator, self.rtyper,
+        self.hotrunnerdesc = HotRunnerDesc(self.hintannotator, self.rtyper,
                                        self.jitcode, self.RGenOp, self.writer,
                                        threshold, self.translate_support_code)
-        self.rewriter = rewriter
-        rewriter.rewrite_all()
+        self.hotrunnerdesc.rewrite_all()
         if small and conftest.option.view:
             self.rtyper.annotator.translator.view()
 
@@ -40,8 +39,11 @@
             self.rtyper, exc_data_ptr=self.writer.exceptiondesc.exc_data_ptr)
         return llinterp.eval_graph(graph, main_args)
 
+    def get_traces(self):
+        return self.hotrunnerdesc.interpreter.debug_traces
+
     def check_traces(self, expected):
-        traces = self.rewriter.interpreter.debug_traces
+        traces = self.get_traces()
         i = 0
         for trace, expect in zip(traces + ['--end of traces--'],
                                  expected + ['--end of traces--']):
@@ -108,7 +110,7 @@
             "run_machine_code 5 195",
             # now that we know which path is hot (i.e. "staying in the loop"),
             # it gets compiled
-            "jit_resume bool_path False in ll_function",
+            "jit_resume Bool path False in ll_function",
             "done at jit_merge_point",
             # execution continues purely in machine code, from the "n1 <= 1"
             # test which triggered the "jit_resume"
@@ -203,7 +205,7 @@
                 "run_machine_code * struct rpy_string {...} 5 5 30240 10",
             # the third time, compile the hot path, which closes the loop
             # in the generated machine code
-                "jit_resume bool_path True in ll_function",
+                "jit_resume Bool path True in ll_function",
                 "done at jit_merge_point",
             # continue running 100% in the machine code as long as necessary
                 "resume_machine_code",
@@ -216,7 +218,7 @@
 
         res = self.run(main, [2, 1291], threshold=3, small=True)
         assert res == 1
-        assert len(self.rewriter.interpreter.debug_traces) < 20
+        assert len(self.get_traces()) < 20
 
     def test_simple_return(self):
         class MyJitDriver(JitDriver):
@@ -246,7 +248,7 @@
 
         res = self.run(ll_function, [50], threshold=3, small=True)
         assert res == (50*51)/2
-        assert len(self.rewriter.interpreter.debug_traces) < 20
+        assert len(self.get_traces()) < 20
 
     def test_hint_errors(self):
         class MyJitDriver(JitDriver):
@@ -291,7 +293,7 @@
             "fallback_interp",
             "fb_leave * stru...} 10 64 * array [ 64, 71, 497 ]",
             "run_machine_code * stru...} 10 64 * array [ 64, 71, 497 ]",
-            "jit_resume bool_path True in hp_interpret",
+            "jit_resume Bool path True in hp_interpret",
             "done at jit_merge_point",
             "resume_machine_code",
             "fallback_interp",



More information about the Pypy-commit mailing list