[pypy-commit] pypy default: Merge recursion_and_inlining.

ltratt noreply at buildbot.pypy.org
Tue Dec 9 17:29:57 CET 2014


Author: Laurence Tratt <laurie at tratt.net>
Branch: 
Changeset: r74869:70d88f23b9bb
Date: 2014-12-09 16:29 +0000
http://bitbucket.org/pypy/pypy/changeset/70d88f23b9bb/

Log:	Merge recursion_and_inlining.

	This branch stops inlining in recursive function calls after N
	levels of (possibly indirect) recursion in a function (where N is
	configurable; what the best possible value of N might be is still a
	little unclear, and ideally requires testing on a wider range of
	benchmarks). This stops us abusing abort as a way of stopping
	inlining in recursion, and improves performance.

diff --git a/rpython/jit/metainterp/pyjitpl.py b/rpython/jit/metainterp/pyjitpl.py
--- a/rpython/jit/metainterp/pyjitpl.py
+++ b/rpython/jit/metainterp/pyjitpl.py
@@ -964,9 +964,40 @@
         assembler_call = False
         if warmrunnerstate.inlining:
             if warmrunnerstate.can_inline_callable(greenboxes):
+                # We've found a potentially inlinable function; now we need to
+                # see if it's already on the stack. In other words: are we about
+                # to enter recursion? If so, we don't want to inline the
+                # recursion, which would be equivalent to unrolling a while
+                # loop.
                 portal_code = targetjitdriver_sd.mainjitcode
-                return self.metainterp.perform_call(portal_code, allboxes,
-                                                    greenkey=greenboxes)
+                count = 0
+                for f in self.metainterp.framestack:
+                    if f.jitcode is not portal_code:
+                        continue
+                    gk = f.greenkey
+                    if gk is None:
+                        continue
+                    assert len(gk) == len(greenboxes)
+                    i = 0
+                    for i in range(len(gk)):
+                        if not gk[i].same_constant(greenboxes[i]):
+                            break
+                    else:
+                        count += 1
+                memmgr = self.metainterp.staticdata.warmrunnerdesc.memory_manager
+                if count >= memmgr.max_unroll_recursion:
+                    # This function is recursive and has exceeded the
+                    # maximum number of unrollings we allow. We want to stop
+                    # inlining it further and to make sure that, if it
+                    # hasn't happened already, the function is traced
+                    # separately as soon as possible.
+                    if have_debug_prints():
+                        loc = targetjitdriver_sd.warmstate.get_location_str(greenboxes)
+                        debug_print("recursive function (not inlined):", loc)
+                    warmrunnerstate.dont_trace_here(greenboxes)
+                else:
+                    return self.metainterp.perform_call(portal_code, allboxes,
+                                greenkey=greenboxes)
             assembler_call = True
             # verify that we have all green args, needed to make sure
             # that assembler that we call is still correct
diff --git a/rpython/jit/metainterp/test/test_recursive.py b/rpython/jit/metainterp/test/test_recursive.py
--- a/rpython/jit/metainterp/test/test_recursive.py
+++ b/rpython/jit/metainterp/test/test_recursive.py
@@ -1112,6 +1112,37 @@
         assert res == 2095
         self.check_resops(call_assembler=12)
 
+    def test_inline_recursion_limit(self):
+        driver = JitDriver(greens = ["threshold", "loop"], reds=["i"])
+        @dont_look_inside
+        def f():
+            set_param(driver, "max_unroll_recursion", 10)
+        def portal(threshold, loop, i):
+            f()
+            if i > threshold:
+                return i
+            while True:
+                driver.jit_merge_point(threshold=threshold, loop=loop, i=i)
+                if loop:
+                    portal(threshold, False, 0)
+                else:
+                    portal(threshold, False, i + 1)
+                    return i
+                if i > 10:
+                    return 1
+                i += 1
+                driver.can_enter_jit(threshold=threshold, loop=loop, i=i)
+
+        res1 = portal(10, True, 0)
+        res2 = self.meta_interp(portal, [10, True, 0], inline=True)
+        assert res1 == res2
+        self.check_resops(call_assembler=2)
+
+        res1 = portal(9, True, 0)
+        res2 = self.meta_interp(portal, [9, True, 0], inline=True)
+        assert res1 == res2
+        self.check_resops(call_assembler=0)
+
     def test_handle_jitexception_in_portal(self):
         # a test for _handle_jitexception_in_portal in blackhole.py
         driver = JitDriver(greens = ['codeno'], reds = ['i', 'str'],
diff --git a/rpython/jit/metainterp/warmspot.py b/rpython/jit/metainterp/warmspot.py
--- a/rpython/jit/metainterp/warmspot.py
+++ b/rpython/jit/metainterp/warmspot.py
@@ -69,7 +69,8 @@
                     backendopt=False, trace_limit=sys.maxint,
                     inline=False, loop_longevity=0, retrace_limit=5,
                     function_threshold=4,
-                    enable_opts=ALL_OPTS_NAMES, max_retrace_guards=15, **kwds):
+                    enable_opts=ALL_OPTS_NAMES, max_retrace_guards=15, 
+                    max_unroll_recursion=7, **kwds):
     from rpython.config.config import ConfigError
     translator = interp.typer.annotator.translator
     try:
@@ -91,6 +92,7 @@
         jd.warmstate.set_param_retrace_limit(retrace_limit)
         jd.warmstate.set_param_max_retrace_guards(max_retrace_guards)
         jd.warmstate.set_param_enable_opts(enable_opts)
+        jd.warmstate.set_param_max_unroll_recursion(max_unroll_recursion)
     warmrunnerdesc.finish()
     if graph_and_interp_only:
         return interp, graph
diff --git a/rpython/jit/metainterp/warmstate.py b/rpython/jit/metainterp/warmstate.py
--- a/rpython/jit/metainterp/warmstate.py
+++ b/rpython/jit/metainterp/warmstate.py
@@ -291,6 +291,11 @@
             if self.warmrunnerdesc.memory_manager:
                 self.warmrunnerdesc.memory_manager.max_unroll_loops = value
 
+    def set_param_max_unroll_recursion(self, value):
+        if self.warmrunnerdesc:
+            if self.warmrunnerdesc.memory_manager:
+                self.warmrunnerdesc.memory_manager.max_unroll_recursion = value
+
     def disable_noninlinable_function(self, greenkey):
         cell = self.JitCell.ensure_jit_cell_at_key(greenkey)
         cell.flags |= JC_DONT_TRACE_HERE
@@ -567,19 +572,26 @@
         jd = self.jitdriver_sd
         cpu = self.cpu
 
-        def can_inline_greenargs(*greenargs):
+        def can_inline_callable(greenkey):
+            greenargs = unwrap_greenkey(greenkey)
             if can_never_inline(*greenargs):
                 return False
             cell = JitCell.get_jitcell(*greenargs)
             if cell is not None and (cell.flags & JC_DONT_TRACE_HERE) != 0:
                 return False
             return True
-        def can_inline_callable(greenkey):
-            greenargs = unwrap_greenkey(greenkey)
-            return can_inline_greenargs(*greenargs)
-        self.can_inline_greenargs = can_inline_greenargs
         self.can_inline_callable = can_inline_callable
 
+        def dont_trace_here(greenkey):
+            # Set greenkey as somewhere that tracing should not occur into;
+            # notice that, as per the description of JC_DONT_TRACE_HERE earlier,
+            # if greenkey hasn't been traced separately, setting
+            # JC_DONT_TRACE_HERE will force tracing the next time the function
+            # is encountered.
+            cell = JitCell.ensure_jit_cell_at_key(greenkey)
+            cell.flags |= JC_DONT_TRACE_HERE
+        self.dont_trace_here = dont_trace_here
+
         if jd._should_unroll_one_iteration_ptr is None:
             def should_unroll_one_iteration(greenkey):
                 return False
diff --git a/rpython/rlib/jit.py b/rpython/rlib/jit.py
--- a/rpython/rlib/jit.py
+++ b/rpython/rlib/jit.py
@@ -463,6 +463,7 @@
     'max_unroll_loops': 'number of extra unrollings a loop can cause',
     'enable_opts': 'INTERNAL USE ONLY (MAY NOT WORK OR LEAD TO CRASHES): '
                    'optimizations to enable, or all = %s' % ENABLE_ALL_OPTS,
+    'max_unroll_recursion': 'how many levels deep to unroll a recursive function'
     }
 
 PARAMETERS = {'threshold': 1039, # just above 1024, prime
@@ -476,6 +477,7 @@
               'max_retrace_guards': 15,
               'max_unroll_loops': 0,
               'enable_opts': 'all',
+              'max_unroll_recursion': 7,
               }
 unroll_parameters = unrolling_iterable(PARAMETERS.items())
 


More information about the pypy-commit mailing list