[pypy-svn] r76066 - in pypy/branch/kill-caninline/pypy: jit/codewriter jit/metainterp jit/metainterp/test rlib

arigo at codespeak.net
Fri Jul 9 12:38:06 CEST 2010


Author: arigo
Date: Fri Jul  9 12:38:04 2010
New Revision: 76066

Modified:
   pypy/branch/kill-caninline/pypy/jit/codewriter/jtransform.py
   pypy/branch/kill-caninline/pypy/jit/metainterp/history.py
   pypy/branch/kill-caninline/pypy/jit/metainterp/pyjitpl.py
   pypy/branch/kill-caninline/pypy/jit/metainterp/test/test_basic.py
   pypy/branch/kill-caninline/pypy/jit/metainterp/test/test_jitdriver.py
   pypy/branch/kill-caninline/pypy/jit/metainterp/test/test_recursive.py
   pypy/branch/kill-caninline/pypy/jit/metainterp/warmspot.py
   pypy/branch/kill-caninline/pypy/rlib/jit.py
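Log:
Progress.  Remove the can_inline hook from JitDriver and the
CannotInlineCanEnterJit exception.  When the tracer reaches a
can_enter_jit followed by jit_merge_point while inlining a recursive
portal call, it now emits a residual call to the portal (new helper
MIFrame.do_recursive_call()) and finishes the inlined frame, instead of
aborting.  Add an 'everywhere' flag to check_loops() so that tests can
also count operations in loops marked _ignore_during_counting.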


Modified: pypy/branch/kill-caninline/pypy/jit/codewriter/jtransform.py
==============================================================================
--- pypy/branch/kill-caninline/pypy/jit/codewriter/jtransform.py	(original)
+++ pypy/branch/kill-caninline/pypy/jit/codewriter/jtransform.py	Fri Jul  9 12:38:04 2010
@@ -804,7 +804,9 @@
                 self.make_three_lists(op.args[2:2+num_green_args]) +
                 self.make_three_lists(op.args[2+num_green_args:]))
         op1 = SpaceOperation('jit_merge_point', args, None)
-        return ops + [op1]
+        op2 = SpaceOperation('-live-', [], None)
+        # ^^^ we need a -live- for the case of do_recursive_call()
+        return ops + [op1, op2]
 
     def handle_jit_marker__can_enter_jit(self, op, jitdriver):
         jd = self.callcontrol.jitdriver_sd_from_jitdriver(jitdriver)

Modified: pypy/branch/kill-caninline/pypy/jit/metainterp/history.py
==============================================================================
--- pypy/branch/kill-caninline/pypy/jit/metainterp/history.py	(original)
+++ pypy/branch/kill-caninline/pypy/jit/metainterp/history.py	Fri Jul  9 12:38:04 2010
@@ -919,11 +919,12 @@
                 "found %d %r, expected %d" % (found, insn, expected_count))
         return insns
 
-    def check_loops(self, expected=None, **check):
+    def check_loops(self, expected=None, everywhere=False, **check):
         insns = {}
         for loop in self.loops:
-            if getattr(loop, '_ignore_during_counting', False):
-                continue
+            if not everywhere:
+                if getattr(loop, '_ignore_during_counting', False):
+                    continue
             insns = loop.summary(adding_insns=insns)
         if expected is not None:
             insns.pop('debug_merge_point', None)

Modified: pypy/branch/kill-caninline/pypy/jit/metainterp/pyjitpl.py
==============================================================================
--- pypy/branch/kill-caninline/pypy/jit/metainterp/pyjitpl.py	(original)
+++ pypy/branch/kill-caninline/pypy/jit/metainterp/pyjitpl.py	Fri Jul  9 12:38:04 2010
@@ -149,8 +149,6 @@
             assert oldbox not in registers[count:]
 
     def make_result_of_lastop(self, resultbox):
-        if resultbox is None:
-            return
         target_index = ord(self.bytecode[self.pc-1])
         if resultbox.type == history.INT:
             self.registers_i[target_index] = resultbox
@@ -685,11 +683,11 @@
     def _opimpl_recursive_call(self, jdindex, greenboxes, redboxes):
         targetjitdriver_sd = self.metainterp.staticdata.jitdrivers_sd[jdindex]
         allboxes = greenboxes + redboxes
-        portal_code = targetjitdriver_sd.mainjitcode
         warmrunnerstate = targetjitdriver_sd.warmstate
         token = None
         if warmrunnerstate.inlining:
             if warmrunnerstate.can_inline_callable(greenboxes):
+                portal_code = targetjitdriver_sd.mainjitcode
                 return self.metainterp.perform_call(portal_code, allboxes,
                                                     greenkey=greenboxes)
             token = warmrunnerstate.get_assembler_token(greenboxes)
@@ -697,6 +695,10 @@
             # that assembler that we call is still correct
             self.verify_green_args(targetjitdriver_sd, greenboxes)
         #
+        return self.do_recursive_call(targetjitdriver_sd, allboxes, token)
+
+    def do_recursive_call(self, targetjitdriver_sd, allboxes, token=None):
+        portal_code = targetjitdriver_sd.mainjitcode
         k = targetjitdriver_sd.portal_runner_adr
         funcbox = ConstInt(heaptracker.adr2int(k))
         return self.do_residual_call(funcbox, portal_code.calldescr,
@@ -787,12 +789,7 @@
 
     @arguments("int")
     def opimpl_can_enter_jit(self, jdindex):
-        if self.metainterp.in_recursion:
-            from pypy.jit.metainterp.warmspot import CannotInlineCanEnterJit
-            raise CannotInlineCanEnterJit()
-        assert jdindex == self.metainterp.jitdriver_sd.index, (
-            "found a can_enter_jit that does not match the current jitdriver")
-        self.metainterp.seen_can_enter_jit = True
+        self.metainterp.seen_can_enter_jit_for_jdindex = jdindex
 
     def verify_green_args(self, jitdriver_sd, varargs):
         num_green_args = jitdriver_sd.num_green_args
@@ -806,13 +803,15 @@
         self.verify_green_args(jitdriver_sd, greenboxes)
         # xxx we may disable the following line in some context later
         self.debug_merge_point(jitdriver_sd, greenboxes)
-        if self.metainterp.seen_can_enter_jit:
-            self.metainterp.seen_can_enter_jit = False
-            # Assert that it's impossible to arrive here with in_recursion
-            # set to a non-zero value: seen_can_enter_jit can only be set
-            # to True by opimpl_can_enter_jit, which should be executed
-            # just before opimpl_jit_merge_point (no recursion inbetween).
-            assert not self.metainterp.in_recursion
+        if self.metainterp.seen_can_enter_jit_for_jdindex < 0:
+            return
+        #
+        assert self.metainterp.seen_can_enter_jit_for_jdindex == jdindex, (
+            "found a can_enter_jit for a JitDriver that does not match "
+            "the following jit_merge_point's")
+        self.metainterp.seen_can_enter_jit_for_jdindex = -1
+        #
+        if not self.metainterp.in_recursion:
             assert jitdriver_sd is self.metainterp.jitdriver_sd
             # Set self.pc to point to jit_merge_point instead of just after:
             # if reached_can_enter_jit() raises SwitchToBlackhole, then the
@@ -822,6 +821,15 @@
             self.pc = orgpc
             self.metainterp.reached_can_enter_jit(greenboxes, redboxes)
             self.pc = saved_pc
+        else:
+            warmrunnerstate = jitdriver_sd.warmstate
+            token = warmrunnerstate.get_assembler_token(greenboxes)
+            resbox = self.do_recursive_call(jitdriver_sd,
+                                            greenboxes + redboxes,
+                                            token)
+            # in case of exception, do_recursive_call() stops by raising
+            # the ChangeFrame exception already.
+            self.metainterp.finishframe(resbox)
 
     def debug_merge_point(self, jitdriver_sd, greenkey):
         # debugging: produce a DEBUG_MERGE_POINT operation
@@ -1018,9 +1026,10 @@
         self.metainterp.clear_exception()
         resbox = self.metainterp.execute_and_record_varargs(opnum, argboxes,
                                                             descr=descr)
-        self.make_result_of_lastop(resbox)
-        # ^^^ this is done before handle_possible_exception() because we need
-        # the box to show up in get_list_of_active_boxes()
+        if resbox is not None:
+            self.make_result_of_lastop(resbox)
+            # ^^^ this is done before handle_possible_exception() because we
+            # need the box to show up in get_list_of_active_boxes()
         if exc:
             self.metainterp.handle_possible_exception()
         else:
@@ -1323,7 +1332,8 @@
         self.last_exc_value_box = None
         self.popframe()
         if self.framestack:
-            self.framestack[-1].make_result_of_lastop(resultbox)
+            if resultbox is not None:
+                self.framestack[-1].make_result_of_lastop(resultbox)
             raise ChangeFrame
         else:
             try:
@@ -1552,7 +1562,7 @@
         redkey = original_boxes[num_green_args:]
         self.resumekey = compile.ResumeFromInterpDescr(original_greenkey,
                                                        redkey)
-        self.seen_can_enter_jit = False
+        self.seen_can_enter_jit_for_jdindex = -1
         try:
             self.interpret()
         except GenerateMergePoint, gmp:
@@ -1579,7 +1589,7 @@
         # because we cannot reconstruct the beginning of the proper loop
         self.current_merge_points = [(original_greenkey, -1)]
         self.resumekey = key
-        self.seen_can_enter_jit = False
+        self.seen_can_enter_jit_for_jdindex = -1
         try:
             self.prepare_resume_from_failure(key.guard_opnum)
             self.interpret()
@@ -2232,7 +2242,8 @@
         else:
             resultbox = unboundmethod(self, *args)
         #
-        self.make_result_of_lastop(resultbox)
+        if resultbox is not None:
+            self.make_result_of_lastop(resultbox)
     #
     unboundmethod = getattr(MIFrame, 'opimpl_' + name).im_func
     argtypes = unrolling_iterable(unboundmethod.argtypes)

Modified: pypy/branch/kill-caninline/pypy/jit/metainterp/test/test_basic.py
==============================================================================
--- pypy/branch/kill-caninline/pypy/jit/metainterp/test/test_basic.py	(original)
+++ pypy/branch/kill-caninline/pypy/jit/metainterp/test/test_basic.py	Fri Jul  9 12:38:04 2010
@@ -116,8 +116,9 @@
 
 class JitMixin:
     basic = True
-    def check_loops(self, expected=None, **check):
-        get_stats().check_loops(expected=expected, **check)
+    def check_loops(self, expected=None, everywhere=False, **check):
+        get_stats().check_loops(expected=expected, everywhere=everywhere,
+                                **check)
     def check_loop_count(self, count):
         """NB. This is a hack; use check_tree_loop_count() or
         check_enter_count() for the real thing.

Modified: pypy/branch/kill-caninline/pypy/jit/metainterp/test/test_jitdriver.py
==============================================================================
--- pypy/branch/kill-caninline/pypy/jit/metainterp/test/test_jitdriver.py	(original)
+++ pypy/branch/kill-caninline/pypy/jit/metainterp/test/test_jitdriver.py	Fri Jul  9 12:38:04 2010
@@ -14,7 +14,6 @@
 
     def test_simple(self):
         myjitdriver1 = JitDriver(greens=[], reds=['n', 'm'],
-                                 can_inline = lambda *args: False,
                                  get_printable_location = getloc1)
         myjitdriver2 = JitDriver(greens=['g'], reds=['r'],
                                  get_printable_location = getloc2)

Modified: pypy/branch/kill-caninline/pypy/jit/metainterp/test/test_recursive.py
==============================================================================
--- pypy/branch/kill-caninline/pypy/jit/metainterp/test/test_recursive.py	(original)
+++ pypy/branch/kill-caninline/pypy/jit/metainterp/test/test_recursive.py	Fri Jul  9 12:38:04 2010
@@ -5,7 +5,7 @@
 from pypy.jit.metainterp.test.test_basic import LLJitMixin, OOJitMixin
 from pypy.jit.codewriter.policy import StopAtXPolicy
 from pypy.rpython.annlowlevel import hlstr
-from pypy.jit.metainterp.warmspot import CannotInlineCanEnterJit, get_stats
+from pypy.jit.metainterp.warmspot import get_stats
 
 class RecursiveTests:
 
@@ -968,6 +968,94 @@
         assert res == portal(2, 0)
         self.check_loops(call_assembler=2)
 
+    def test_inline_without_hitting_the_loop(self):
+        driver = JitDriver(greens = ['codeno'], reds = ['i'],
+                           get_printable_location = lambda codeno : str(codeno))
+
+        def portal(codeno):
+            i = 0
+            while True:
+                driver.jit_merge_point(codeno=codeno, i=i)
+                if codeno < 10:
+                    i += portal(20)
+                    codeno += 1
+                elif codeno == 10:
+                    if i > 63:
+                        return i
+                    codeno = 0
+                    driver.can_enter_jit(codeno=codeno, i=i)
+                else:
+                    return 1
+
+        assert portal(0) == 70
+        res = self.meta_interp(portal, [0], inline=True)
+        assert res == 70
+        self.check_loops(call_assembler=0)
+
+    def test_inline_with_hitting_the_loop_sometimes(self):
+        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
+                           get_printable_location = lambda codeno : str(codeno))
+
+        def portal(codeno, k):
+            if k > 2:
+                return 1
+            i = 0
+            while True:
+                driver.jit_merge_point(codeno=codeno, i=i, k=k)
+                if codeno < 10:
+                    i += portal(codeno + 5, k+1)
+                    codeno += 1
+                elif codeno == 10:
+                    if i > [-1, 2000, 63][k]:
+                        return i
+                    codeno = 0
+                    driver.can_enter_jit(codeno=codeno, i=i, k=k)
+                else:
+                    return 1
+
+        assert portal(0, 1) == 2095
+        res = self.meta_interp(portal, [0, 1], inline=True)
+        assert res == 2095
+        self.check_loops(call_assembler=6, everywhere=True)
+
+    def test_inline_with_hitting_the_loop_sometimes_exc(self):
+        driver = JitDriver(greens = ['codeno'], reds = ['i', 'k'],
+                           get_printable_location = lambda codeno : str(codeno))
+        class GotValue(Exception):
+            def __init__(self, result):
+                self.result = result
+
+        def portal(codeno, k):
+            if k > 2:
+                raise GotValue(1)
+            i = 0
+            while True:
+                driver.jit_merge_point(codeno=codeno, i=i, k=k)
+                if codeno < 10:
+                    try:
+                        portal(codeno + 5, k+1)
+                    except GotValue, e:
+                        i += e.result
+                    codeno += 1
+                elif codeno == 10:
+                    if i > [-1, 2000, 63][k]:
+                        raise GotValue(i)
+                    codeno = 0
+                    driver.can_enter_jit(codeno=codeno, i=i, k=k)
+                else:
+                    raise GotValue(1)
+
+        def main(codeno, k):
+            try:
+                portal(codeno, k)
+            except GotValue, e:
+                return e.result
+
+        assert main(0, 1) == 2095
+        res = self.meta_interp(main, [0, 1], inline=True)
+        assert res == 2095
+        self.check_loops(call_assembler=6, everywhere=True)
+
     # There is a test which I fail to write.
     #   * what happens if we call recursive_call while blackholing
     #     this seems to be completely corner case and not really happening

Modified: pypy/branch/kill-caninline/pypy/jit/metainterp/warmspot.py
==============================================================================
--- pypy/branch/kill-caninline/pypy/jit/metainterp/warmspot.py	(original)
+++ pypy/branch/kill-caninline/pypy/jit/metainterp/warmspot.py	Fri Jul  9 12:38:04 2010
@@ -136,9 +136,6 @@
 class ContinueRunningNormallyBase(JitException):
     pass
 
-class CannotInlineCanEnterJit(JitException):
-    pass
-
 # ____________________________________________________________
 
 class WarmRunnerDesc(object):

Modified: pypy/branch/kill-caninline/pypy/rlib/jit.py
==============================================================================
--- pypy/branch/kill-caninline/pypy/rlib/jit.py	(original)
+++ pypy/branch/kill-caninline/pypy/rlib/jit.py	Fri Jul  9 12:38:04 2010
@@ -253,8 +253,7 @@
     
     def __init__(self, greens=None, reds=None, virtualizables=None,
                  get_jitcell_at=None, set_jitcell_at=None,
-                 can_inline=None, get_printable_location=None,
-                 confirm_enter_jit=None):
+                 get_printable_location=None, confirm_enter_jit=None):
         if greens is not None:
             self.greens = greens
         if reds is not None:
@@ -270,7 +269,6 @@
         self.get_jitcell_at = get_jitcell_at
         self.set_jitcell_at = set_jitcell_at
         self.get_printable_location = get_printable_location
-        self.can_inline = can_inline
         self.confirm_enter_jit = confirm_enter_jit
 
     def _freeze_(self):
@@ -384,7 +382,6 @@
         self.annotate_hook(driver.get_jitcell_at, driver.greens, **kwds_s)
         self.annotate_hook(driver.set_jitcell_at, driver.greens, [s_jitcell],
                            **kwds_s)
-        self.annotate_hook(driver.can_inline, driver.greens, **kwds_s)
         self.annotate_hook(driver.get_printable_location, driver.greens, **kwds_s)
 
     def annotate_hook(self, func, variables, args_s=[], **kwds_s):



More information about the Pypy-commit mailing list