[pypy-svn] r34944 - in pypy/dist/pypy/jit/timeshifter: . test

ac at codespeak.net ac at codespeak.net
Fri Nov 24 19:17:18 CET 2006


Author: ac
Date: Fri Nov 24 19:17:06 2006
New Revision: 34944

Modified:
   pypy/dist/pypy/jit/timeshifter/hrtyper.py
   pypy/dist/pypy/jit/timeshifter/rtimeshift.py
   pypy/dist/pypy/jit/timeshifter/rvalue.py
   pypy/dist/pypy/jit/timeshifter/test/test_portal.py
   pypy/dist/pypy/jit/timeshifter/transform.py
Log:
(pedronis, arre, arigo around)

Make check_insns() work for tests whose portal is not the entry point.

Add support for less aggressive merging (not used yet). Support
red and gray calls that return more than one jitstate.



Modified: pypy/dist/pypy/jit/timeshifter/hrtyper.py
==============================================================================
--- pypy/dist/pypy/jit/timeshifter/hrtyper.py	(original)
+++ pypy/dist/pypy/jit/timeshifter/hrtyper.py	Fri Nov 24 19:17:06 2006
@@ -3,6 +3,7 @@
 from pypy.translator.unsimplify import varoftype
 from pypy.translator.backendopt.ssa import SSA_to_SSI
 from pypy.annotation import model as annmodel
+from pypy.annotation import listdef
 from pypy.annotation.pairtype import pair, pairtype
 from pypy.rpython.annlowlevel import PseudoHighLevelCallable
 from pypy.rlib.unroll import unrolling_iterable
@@ -152,11 +153,16 @@
         bk = self.annotator.bookkeeper
         bk.compute_after_normalization()
         entrygraph = self.annotator.translator.graphs[0]
+        if origportalgraph:
+            portalgraph = bk.get_graph_by_key(origportalgraph, None)
+        else:
+            portalgraph = None
         pending = [entrygraph]
         seen = {entrygraph: True}
         while pending:
             graph = pending.pop()
-            for nextgraph in self.transform_graph(graph):
+            for nextgraph in self.transform_graph(graph,
+                                is_portal=graph is portalgraph):
                 if nextgraph not in seen:
                     pending.append(nextgraph)
                     seen[nextgraph] = True
@@ -170,7 +176,6 @@
             self.timeshift_graph(graph)
 
         if origportalgraph:
-            portalgraph = bk.get_graph_by_key(origportalgraph, None)
             self.rewire_portal(origportalgraph, portalgraph)
         
     def rewire_portal(self, origportalgraph, portalgraph):
@@ -222,6 +227,9 @@
                 return cache[key]
             except KeyError:
                 return lltype.nullptr(FUNC)
+
+        def readallportals():
+            return state.cache.values()
         
         def portalentry(*args):
             i = 0
@@ -280,8 +288,13 @@
         portalentrygraph = annhelper.getgraph(portalentry, args_s, s_result)
         portalentrygraph.tag = "portal_entry"
 
+        s_funcptr = annmodel.SomePtr(lltype.Ptr(FUNC))
         self.readportalgraph = annhelper.getgraph(readportal, args_s,
-                                   annmodel.SomePtr(lltype.Ptr(FUNC)))
+                                   s_funcptr)
+
+        s_funcptrlist = annmodel.SomeList(listdef.ListDef(None, s_funcptr))
+        self.readallportalsgraph = annhelper.getgraph(readallportals, [],
+                                                      s_funcptrlist)
 
         annhelper.finish()
 
@@ -290,11 +303,12 @@
         origportalgraph.exceptblock = portalentrygraph.exceptblock
         # name, func?
 
-    def transform_graph(self, graph):
+    def transform_graph(self, graph, is_portal=False):
         # prepare the graphs by inserting all bookkeeping/dispatching logic
         # as special operations
         assert graph.startblock in self.annotator.annotated
-        transformer = HintGraphTransformer(self.annotator, graph)
+        transformer = HintGraphTransformer(self.annotator, graph,
+                                           is_portal=is_portal)
         transformer.transform()
         flowmodel.checkgraph(graph)    # for now
         return transformer.tsgraphs_seen
@@ -752,14 +766,18 @@
                                         [v_jitstate     , self.v_queue],
                                         annmodel.s_None)
 
-    def translate_op_leave_graph_red(self, hop):
+    def translate_op_leave_graph_red(self, hop, is_portal=False):
         v_jitstate = hop.llops.getjitstate()
+        c_is_portal = inputconst(lltype.Bool, is_portal)
         v_newjs = hop.llops.genmixlevelhelpercall(rtimeshift.leave_graph_red,
-                            [self.s_JITState, self.s_Queue],
-                            [v_jitstate     , self.v_queue],
+                            [self.s_JITState, self.s_Queue, annmodel.s_Bool],
+                            [v_jitstate     , self.v_queue, c_is_portal],
                             self.s_JITState)
         hop.llops.setjitstate(v_newjs)
 
+    def translate_op_leave_graph_portal(self, hop):
+        self.translate_op_leave_graph_red(hop, is_portal=True)
+
     def translate_op_leave_graph_gray(self, hop):
         v_jitstate = hop.llops.getjitstate()
         v_newjs = hop.llops.genmixlevelhelpercall(rtimeshift.leave_graph_gray,

Modified: pypy/dist/pypy/jit/timeshifter/rtimeshift.py
==============================================================================
--- pypy/dist/pypy/jit/timeshifter/rtimeshift.py	(original)
+++ pypy/dist/pypy/jit/timeshifter/rtimeshift.py	Fri Nov 24 19:17:06 2006
@@ -212,7 +212,7 @@
 def return_marker(jitstate):
     raise AssertionError("shouldn't get here")
 
-def start_new_block(states_dic, jitstate, key, global_resumer):
+def start_new_block(states_dic, jitstate, key, global_resumer, index=-1):
     memo = rvalue.freeze_memo()
     frozen = jitstate.freeze(memo)
     memo = rvalue.exactmatch_memo()
@@ -221,7 +221,11 @@
     assert res, "exactmatch() failed"
     cleanup_partial_data(memo.partialdatamatch)
     newblock = enter_next_block(jitstate, outgoingvarboxes)
-    states_dic[key] = frozen, newblock
+    if index == -1:
+        states_dic[key].append((frozen, newblock))
+    else:
+        states_dic[key][index] = (frozen, newblock)
+        
     if global_resumer is not None and global_resumer is not return_marker:
         greens_gv = jitstate.greens
         rgenop = jitstate.curbuilder.rgenop
@@ -233,35 +237,51 @@
         #debug_print(lltype.Void, "PROMOTION ROOT")
 start_new_block._annspecialcase_ = "specialize:arglltype(2)"
 
-def retrieve_jitstate_for_merge(states_dic, jitstate, key, global_resumer):
+class DontMerge(Exception):
+    pass
+
+def retrieve_jitstate_for_merge(states_dic, jitstate, key, global_resumer,
+                                force_merge=False):
     if key not in states_dic:
+        states_dic[key] = []
         start_new_block(states_dic, jitstate, key, global_resumer)
         return False   # continue
 
-    frozen, oldblock = states_dic[key]
-    memo = rvalue.exactmatch_memo()
-    outgoingvarboxes = []
-
-    if frozen.exactmatch(jitstate, outgoingvarboxes, memo):
-        linkargs = []
+    states = states_dic[key]
+    for i in range(len(states) -1, -1, -1):
+        frozen, oldblock =  states[i]
+        memo = rvalue.exactmatch_memo(force_merge)
+        outgoingvarboxes = []
+        
+        try:
+            match = frozen.exactmatch(jitstate, outgoingvarboxes, memo)
+        except DontMerge:
+            continue
+        if match:
+            linkargs = []
+            for box in outgoingvarboxes:
+                linkargs.append(box.getgenvar(jitstate.curbuilder))
+            jitstate.curbuilder.finish_and_goto(linkargs, oldblock)
+            return True    # finished
+        # A mergable blook found
+        # We need a more general block.  Do it by generalizing all the
+        # redboxes from outgoingvarboxes, by making them variables.
+        # Then we make a new block based on this new state.
+        cleanup_partial_data(memo.partialdatamatch)
+        replace_memo = rvalue.copy_memo()
         for box in outgoingvarboxes:
-            linkargs.append(box.getgenvar(jitstate.curbuilder))
-        jitstate.curbuilder.finish_and_goto(linkargs, oldblock)
-        return True    # finished
-
-    # We need a more general block.  Do it by generalizing all the
-    # redboxes from outgoingvarboxes, by making them variables.
-    # Then we make a new block based on this new state.
-    cleanup_partial_data(memo.partialdatamatch)
-    replace_memo = rvalue.copy_memo()
-    for box in outgoingvarboxes:
-        box.forcevar(jitstate.curbuilder, replace_memo)
-    if replace_memo.boxes:
-        jitstate.replace(replace_memo)
+            box.forcevar(jitstate.curbuilder, replace_memo)
+        if replace_memo.boxes:
+            jitstate.replace(replace_memo)
+        start_new_block(states_dic, jitstate, key, global_resumer, index=i)
+        if global_resumer is None:
+            merge_generalized(jitstate)
+        return False       # continue
+
+    # No mergable states found, make a new.
     start_new_block(states_dic, jitstate, key, global_resumer)
-    if global_resumer is None:
-        merge_generalized(jitstate)
-    return False       # continue
+    return False   
+
 retrieve_jitstate_for_merge._annspecialcase_ = "specialize:arglltype(2)"
 
 def cleanup_partial_data(partialdatamatch):
@@ -888,54 +908,64 @@
                 parent_mergesleft = MC_CALL_NOT_TAKEN
         dispatchqueue.mergecounter = parent_mergesleft
 
-def merge_returning_jitstates(jitstate, dispatchqueue):
+def merge_returning_jitstates(jitstate, dispatchqueue, force_merge=False):
     return_chain = dispatchqueue.return_chain
-    resuming = jitstate.resuming
     return_cache = {}
     still_pending = None
     while return_chain is not None:
         jitstate = return_chain
         return_chain = return_chain.next
         res = retrieve_jitstate_for_merge(return_cache, jitstate, (),
-                                          return_marker)
+                                          return_marker,
+                                          force_merge=force_merge)
         if res is False:    # not finished
             jitstate.next = still_pending
             still_pending = jitstate
-    most_general_jitstate = still_pending
-    # if there are more than one jitstate still left, merge them forcefully
-    if still_pending is not None:
-        still_pending = still_pending.next
-        while still_pending is not None:
-            jitstate = still_pending
-            still_pending = still_pending.next
+    
+    # Of the jitstates we have left some may be mergable to a later
+    # more general one.
+    return_chain = still_pending
+    if return_chain is not None:
+        return_cache = {}
+        still_pending = None
+        while return_chain is not None:
+            jitstate = return_chain
+            return_chain = return_chain.next
             res = retrieve_jitstate_for_merge(return_cache, jitstate, (),
-                                              return_marker)
-            assert res is True   # finished
+                                              return_marker,
+                                              force_merge=force_merge)
+            if res is False:    # not finished
+                jitstate.next = still_pending
+                still_pending = jitstate
+    return still_pending
 
+def leave_graph_red(jitstate, dispatchqueue, is_portal):
+    resuming = jitstate.resuming
+    return_chain = merge_returning_jitstates(jitstate, dispatchqueue,
+                                             force_merge=is_portal)
     if resuming is not None:
         resuming.leave_call(dispatchqueue)
-        
-    return most_general_jitstate
-
-def leave_graph_red(jitstate, dispatchqueue):
-    jitstate = merge_returning_jitstates(jitstate, dispatchqueue)
-    if jitstate is not None:
+    jitstate = return_chain
+    while jitstate is not None:
         myframe = jitstate.frame
         leave_frame(jitstate)
         jitstate.greens = []
-        jitstate.next = None
         jitstate.returnbox = myframe.local_boxes[0]
-        # ^^^ fetched by a 'fetch_return' operation
-    return jitstate
+        jitstate = jitstate.next
+    return return_chain
 
 def leave_graph_gray(jitstate, dispatchqueue):
-    jitstate = merge_returning_jitstates(jitstate, dispatchqueue)
-    if jitstate is not None:
+    resuming = jitstate.resuming
+    return_chain = merge_returning_jitstates(jitstate, dispatchqueue)
+    if resuming is not None:
+        resuming.leave_call(dispatchqueue)
+    jitstate = return_chain
+    while jitstate is not None:
         leave_frame(jitstate)
         jitstate.greens = []
-        jitstate.next = None        
         jitstate.returnbox = None
-    return jitstate
+        jitstate = jitstate.next
+    return return_chain
 
 def leave_frame(jitstate):
     myframe = jitstate.frame

Modified: pypy/dist/pypy/jit/timeshifter/rvalue.py
==============================================================================
--- pypy/dist/pypy/jit/timeshifter/rvalue.py	(original)
+++ pypy/dist/pypy/jit/timeshifter/rvalue.py	Fri Nov 24 19:17:06 2006
@@ -13,9 +13,10 @@
 def freeze_memo():
     return Memo()
 
-def exactmatch_memo():
+def exactmatch_memo(force_merge=False):
     memo = Memo()
     memo.partialdatamatch = {}
+    memo.force_merge=force_merge
     return memo
 
 def copy_memo():

Modified: pypy/dist/pypy/jit/timeshifter/test/test_portal.py
==============================================================================
--- pypy/dist/pypy/jit/timeshifter/test/test_portal.py	(original)
+++ pypy/dist/pypy/jit/timeshifter/test/test_portal.py	Fri Nov 24 19:17:06 2006
@@ -25,6 +25,7 @@
 
     def postprocess_timeshifting(self):
         self.readportalgraph = self.hrtyper.readportalgraph
+        self.readallportalsgraph = self.hrtyper.readallportalsgraph
         
     def _timeshift_from_portal(self, main, portal, main_args,
                               inline=None, policy=None,
@@ -82,6 +83,7 @@
                                                 inline=inline, policy=policy,
                                                 backendoptimize=backendoptimize)
         self.main_args = main_args
+        self.main_is_portal = main is portal
         llinterp = LLInterpreter(self.rtyper)
         res = llinterp.eval_graph(self.maingraph, main_args)
         return res
@@ -89,8 +91,14 @@
     def check_insns(self, expected=None, **counts):
         # XXX only works if the portal is the same as the main
         llinterp = LLInterpreter(self.rtyper)
-        residual_graph = llinterp.eval_graph(self.readportalgraph,
-                                             self.main_args)._obj.graph
+        if self.main_is_portal:
+            residual_graph = llinterp.eval_graph(self.readportalgraph,
+                                                 self.main_args)._obj.graph
+        else:
+            residual_graphs = llinterp.eval_graph(self.readallportalsgraph, [])
+            assert residual_graphs.ll_length() == 1
+            residual_graph = residual_graphs.ll_getitem_fast(0)._obj.graph
+            
         self.insns = summary(residual_graph)
         if expected is not None:
             assert self.insns == expected
@@ -207,6 +215,45 @@
             def get(self):
                 return ord(self.s[4])
 
+        def ll_main(n):
+            if n > 0:
+                o = Int(n)
+            else:
+                o = Str('123')
+            return ll_function(o)
+
+        def ll_function(o):
+            hint(None, global_merge_point=True)
+            hint(o.__class__, promote=True)
+            return o.double().get()
+
+        res = self.timeshift_from_portal(ll_main, ll_function, [5], policy=P_NOVIRTUAL)
+        assert res == 10
+        self.check_insns(indirect_call=0)
+
+        res = self.timeshift_from_portal(ll_main, ll_function, [0], policy=P_NOVIRTUAL)
+        assert res == ord('2')
+        self.check_insns(indirect_call=0)
+
+    def test_virt_obj_method_call_promote(self):
+        py.test.skip('WIP')
+        class Base(object):
+            pass
+        class Int(Base):
+            def __init__(self, n):
+                self.n = n
+            def double(self):
+                return Int(self.n * 2)
+            def get(self):
+                return self.n
+        class Str(Base):
+            def __init__(self, s):
+                self.s = s
+            def double(self):
+                return Str(self.s + self.s)
+            def get(self):
+                return ord(self.s[4])
+
         def ll_make(n):
             if n > 0:
                 return Int(n)
@@ -221,8 +268,8 @@
 
         res = self.timeshift_from_portal(ll_function, ll_function, [5], policy=P_NOVIRTUAL)
         assert res == 10
-        self.check_insns(indirect_call=0) #, malloc=0)
+        self.check_insns(indirect_call=0, malloc=0)
 
         res = self.timeshift_from_portal(ll_function, ll_function, [0], policy=P_NOVIRTUAL)
         assert res == ord('2')
-        self.check_insns(indirect_call=0) #, malloc=0)
+        self.check_insns(indirect_call=0, malloc=0)

Modified: pypy/dist/pypy/jit/timeshifter/transform.py
==============================================================================
--- pypy/dist/pypy/jit/timeshifter/transform.py	(original)
+++ pypy/dist/pypy/jit/timeshifter/transform.py	Fri Nov 24 19:17:06 2006
@@ -37,9 +37,10 @@
 class HintGraphTransformer(object):
     c_dummy = inputconst(lltype.Void, None)
 
-    def __init__(self, hannotator, graph):
+    def __init__(self, hannotator, graph, is_portal=False):
         self.hannotator = hannotator
         self.graph = graph
+        self.is_portal = is_portal
         self.graphcolor = self.graph_calling_color(graph)
         self.resumepoints = {}
         self.mergepoint_set = {}    # set of blocks
@@ -399,7 +400,6 @@
         elif self.graphcolor == 'yellow':
             self.genop(block, 'save_greens', [v_retbox])
         elif self.graphcolor == 'red':
-            self.leave_graph_opname = 'leave_graph_red'
             self.genop(block, 'save_locals', [v_retbox])
         else:
             raise AssertionError(self.graph, self.graphcolor)
@@ -419,7 +419,11 @@
 
     def insert_leave_graph(self):
         block = self.before_return_block()
-        self.genop(block, 'leave_graph_%s' % (self.graphcolor,), [])
+        if self.is_portal:
+            assert self.graphcolor == 'red'
+            self.genop(block, 'leave_graph_portal', [])
+        else:
+            self.genop(block, 'leave_graph_%s' % (self.graphcolor,), [])
 
     # __________ handling of the various kinds of calls __________
 



More information about the Pypy-commit mailing list