[pypy-commit] pypy conditional_call_value_4: in-progress

arigo pypy.commits at gmail.com
Wed Nov 23 04:09:22 EST 2016


Author: Armin Rigo <arigo at tunes.org>
Branch: conditional_call_value_4
Changeset: r88568:2353e0a7c50b
Date: 2016-11-23 10:08 +0100
http://bitbucket.org/pypy/pypy/changeset/2353e0a7c50b/

Log:	in-progress

diff --git a/rpython/jit/codewriter/jtransform.py b/rpython/jit/codewriter/jtransform.py
--- a/rpython/jit/codewriter/jtransform.py
+++ b/rpython/jit/codewriter/jtransform.py
@@ -1569,7 +1569,7 @@
             return []
         return getattr(self, 'handle_jit_marker__%s' % key)(op, jitdriver)
 
-    def rewrite_op_jit_conditional_call(self, op):
+    def _rewrite_op_cond_call(self, op, rewritten_opname):
         have_floats = False
         for arg in op.args:
             if getkind(arg.concretetype) == 'float':
@@ -1580,13 +1580,19 @@
         callop = SpaceOperation('direct_call', op.args[1:], op.result)
         calldescr = self.callcontrol.getcalldescr(callop)
         assert not calldescr.get_extra_info().check_forces_virtual_or_virtualizable()
-        op1 = self.rewrite_call(op, 'conditional_call',
+        op1 = self.rewrite_call(op, rewritten_opname,
                                  op.args[:2], args=op.args[2:],
                                  calldescr=calldescr, force_ir=True)
         if self.callcontrol.calldescr_canraise(calldescr):
             op1 = [op1, SpaceOperation('-live-', [], None)]
         return op1
 
+    # Both llops share _rewrite_op_cond_call(); only the emitted opname differs.
+    def rewrite_op_jit_conditional_call(self, op):
+        return self._rewrite_op_cond_call(op, 'conditional_call')
+    def rewrite_op_jit_conditional_call_value(self, op):
+        return self._rewrite_op_cond_call(op, 'conditional_call_value')
+
     def handle_jit_marker__jit_merge_point(self, op, jitdriver):
         assert self.portal_jd is not None, (
             "'jit_merge_point' in non-portal graph!")
diff --git a/rpython/jit/metainterp/blackhole.py b/rpython/jit/metainterp/blackhole.py
--- a/rpython/jit/metainterp/blackhole.py
+++ b/rpython/jit/metainterp/blackhole.py
@@ -1199,13 +1199,29 @@
     def bhimpl_residual_call_irf_v(cpu, func, args_i,args_r,args_f,calldescr):
         return cpu.bh_call_v(func, args_i, args_r, args_f, calldescr)
 
-    # conditional calls - note that they cannot return stuff
     @arguments("cpu", "i", "i", "I", "R", "d")
     def bhimpl_conditional_call_ir_v(cpu, condition, func, args_i, args_r,
                                      calldescr):
+        # conditional calls - condition is a flag, and they cannot return stuff
         if condition:
             cpu.bh_call_v(func, args_i, args_r, None, calldescr)
 
+    @arguments("cpu", "i", "i", "I", "R", "d", returns="i")
+    def bhimpl_conditional_call_value_ir_i(cpu, value, func, args_i, args_r,
+                                           calldescr):
+        # returns 'value' if it is nonzero, else the result of func(args)
+        if value == 0:
+            value = cpu.bh_call_i(func, args_i, args_r, None, calldescr)
+        return value
+
+    @arguments("cpu", "r", "i", "I", "R", "d", returns="r")
+    def bhimpl_conditional_call_value_ir_r(cpu, value, func, args_i, args_r,
+                                           calldescr):
+        # returns 'value' if it is non-null, else the result of func(args)
+        if not value:
+            value = cpu.bh_call_r(func, args_i, args_r, None, calldescr)
+        return value
+
     @arguments("cpu", "j", "R", returns="i")
     def bhimpl_inline_call_r_i(cpu, jitcode, args_r):
         return cpu.bh_call_i(jitcode.get_fnaddr_as_int(),
diff --git a/rpython/jit/metainterp/pyjitpl.py b/rpython/jit/metainterp/pyjitpl.py
--- a/rpython/jit/metainterp/pyjitpl.py
+++ b/rpython/jit/metainterp/pyjitpl.py
@@ -1059,8 +1059,19 @@
     @arguments("box", "box", "boxes2", "descr", "orgpc")
     def opimpl_conditional_call_ir_v(self, condbox, funcbox, argboxes,
                                      calldescr, pc):
+        if isinstance(condbox, ConstInt) and condbox.value == 0:
+            return   # so that the heapcache can keep argboxes virtual
         self.do_conditional_call(condbox, funcbox, argboxes, calldescr, pc)
 
+    @arguments("box", "box", "boxes2", "descr", "orgpc")
+    def _opimpl_conditional_call_value(self, valuebox, funcbox, argboxes,
+                                       calldescr, pc):
+        return self.do_conditional_call(valuebox, funcbox, argboxes,
+                                        calldescr, pc, is_value=True)
+
+    opimpl_conditional_call_value_ir_i = _opimpl_conditional_call_value
+    opimpl_conditional_call_value_ir_r = _opimpl_conditional_call_value
+
     @arguments("int", "boxes3", "boxes3", "orgpc")
     def _opimpl_recursive_call(self, jdindex, greenboxes, redboxes, pc):
         targetjitdriver_sd = self.metainterp.staticdata.jitdrivers_sd[jdindex]
@@ -1538,7 +1549,7 @@
                                                             descr=descr)
         if pure and not self.metainterp.last_exc_value and op:
             op = self.metainterp.record_result_of_call_pure(op, argboxes, descr,
-                patch_pos)
+                patch_pos, opnum)
             exc = exc and not isinstance(op, Const)
         if exc:
             if op is not None:
@@ -1712,16 +1723,21 @@
             else:
                 assert False
 
-    def do_conditional_call(self, condbox, funcbox, argboxes, descr, pc):
-        if isinstance(condbox, ConstInt) and condbox.value == 0:
-            return   # so that the heapcache can keep argboxes virtual
+    def do_conditional_call(self, condbox, funcbox, argboxes, descr, pc,
+                            is_value=False):
         allboxes = self._build_allboxes(funcbox, argboxes, descr)
         effectinfo = descr.get_extra_info()
         assert not effectinfo.check_forces_virtual_or_virtualizable()
         exc = effectinfo.check_can_raise()
-        pure = effectinfo.check_is_elidable()
-        return self.execute_varargs(rop.COND_CALL, [condbox] + allboxes, descr,
-                                    exc, pure)
+        if not is_value:
+            opnum = rop.COND_CALL
+        else:
+            opnum = OpHelpers.cond_call_value_for_descr(descr)
+        # COND_CALL cannot be pure (=elidable): it has no result.
+        # On the other hand, COND_CALL_VALUE is always calling a pure
+        # function.
+        return self.execute_varargs(opnum, [condbox] + allboxes, descr,
+                                    exc, pure=is_value)
 
     def _do_jit_force_virtual(self, allboxes, descr, pc):
         assert len(allboxes) == 2
@@ -3061,11 +3077,16 @@
         debug_stop("jit-abort-longest-function")
         return max_jdsd, max_key
 
-    def record_result_of_call_pure(self, op, argboxes, descr, patch_pos):
+    def record_result_of_call_pure(self, op, argboxes, descr, patch_pos, opnum):
         """ Patch a CALL into a CALL_PURE.
         """
         resbox_as_const = executor.constant_from_op(op)
-        for argbox in argboxes:
+        is_cond_value = OpHelpers.is_cond_call_value(opnum)
+        if is_cond_value:
+            normargboxes = argboxes[1:]    # ignore the 'value' arg
+        else:
+            normargboxes = argboxes
+        for argbox in normargboxes:
             if not isinstance(argbox, Const):
                 break
         else:
@@ -3077,6 +3098,8 @@
         # be either removed later by optimizeopt or turned back into CALL.
         arg_consts = [executor.constant_from_op(a) for a in argboxes]
         self.call_pure_results[arg_consts] = resbox_as_const
+        if is_cond_value:
+            return op       # but COND_CALL_VALUE remains
         opnum = OpHelpers.call_pure_for_descr(descr)
         self.history.cut(patch_pos)
         newop = self.history.record_nospec(opnum, argboxes, descr)
diff --git a/rpython/jit/metainterp/resoperation.py b/rpython/jit/metainterp/resoperation.py
--- a/rpython/jit/metainterp/resoperation.py
+++ b/rpython/jit/metainterp/resoperation.py
@@ -1151,7 +1151,7 @@
     '_CALL_FIRST',
     'CALL/*d/rfin',
     'COND_CALL/*d/n',   # a conditional call, with first argument as a condition
-    'COND_CALL_VALUE/*d/ri',  # same but returns a result; emitted by rewrite
+    'COND_CALL_VALUE/*d/ri',  # "return a0 or a1(a2, ...)"; a1 must be elidable
     'CALL_ASSEMBLER/*d/rfin',  # call already compiled assembler
     'CALL_MAY_FORCE/*d/rfin',
     'CALL_LOOPINVARIANT/*d/rfin',
@@ -1274,6 +1274,15 @@
         return rop.CALL_LOOPINVARIANT_N
 
     @staticmethod
+    def cond_call_value_for_descr(descr):
+        tp = descr.get_normalized_result_type()
+        if tp == 'i':
+            return rop.COND_CALL_VALUE_I
+        elif tp == 'r':
+            return rop.COND_CALL_VALUE_R
+        assert False, tp    # only int ('i') and ref ('r') variants exist
+
+    @staticmethod
     def getfield_pure_for_descr(descr):
         if descr.is_pointer_field():
             return rop.GETFIELD_GC_PURE_R
@@ -1447,6 +1456,11 @@
                 opnum == rop.CALL_RELEASE_GIL_N)
 
     @staticmethod
+    def is_cond_call_value(opnum):
+        return (opnum == rop.COND_CALL_VALUE_I or
+                opnum == rop.COND_CALL_VALUE_R)
+
+    @staticmethod
     def is_ovf(opnum):
         return rop._OVF_FIRST <= opnum <= rop._OVF_LAST
 
diff --git a/rpython/jit/metainterp/test/test_ajit.py b/rpython/jit/metainterp/test/test_ajit.py
--- a/rpython/jit/metainterp/test/test_ajit.py
+++ b/rpython/jit/metainterp/test/test_ajit.py
@@ -4575,3 +4575,16 @@
 
         self.meta_interp(g, [5, 5, 5])
         self.check_resops(guard_true=10)   # 5 unrolled, plus 5 unrelated
+
+    def test_conditional_call_value(self):
+        # conditional_call_value(i, g, j) gives i if i is nonzero, else g(j)
+        from rpython.rlib.jit import conditional_call_value
+        @elidable
+        def g(j):
+            return j + 5
+        def f(i, j):
+            return conditional_call_value(i, g, j)
+        res = self.interp_operations(f, [-42, 200])
+        assert res == -42
+        res = self.interp_operations(f, [0, 200])
+        assert res == 205
diff --git a/rpython/rlib/jit.py b/rpython/rlib/jit.py
--- a/rpython/rlib/jit.py
+++ b/rpython/rlib/jit.py
@@ -1246,6 +1246,9 @@
                                                  args_s[1], args_s[2:])
         if self.instance == _jit_conditional_call_value:
             from rpython.annotator import model as annmodel
+            func = args_s[1].const
+            # conditional_call_value(): function must be @elidable
+            assert getattr(func, '_elidable_function_', False), func
             return annmodel.unionof(s_res, args_s[0])
 
     def specialize_call(self, hop):
diff --git a/rpython/rlib/test/test_jit.py b/rpython/rlib/test/test_jit.py
--- a/rpython/rlib/test/test_jit.py
+++ b/rpython/rlib/test/test_jit.py
@@ -303,6 +303,8 @@
         mix.finish()
 
     def test_conditional_call_value(self):
+        # conditional_call_value() insists on an @elidable function
+        @elidable
         def g(x, y):
             return x - y + 5
         def f(n, x, y):
diff --git a/rpython/rtyper/lltypesystem/lloperation.py b/rpython/rtyper/lltypesystem/lloperation.py
--- a/rpython/rtyper/lltypesystem/lloperation.py
+++ b/rpython/rtyper/lltypesystem/lloperation.py
@@ -431,6 +431,7 @@
     'jit_record_exact_class'  : LLOp(canrun=True),
     'jit_ffi_save_result':  LLOp(canrun=True),
     'jit_conditional_call': LLOp(),
+    'jit_conditional_call_value': LLOp(),  # like jit_conditional_call, with a result
     'jit_enter_portal_frame': LLOp(canrun=True),
     'jit_leave_portal_frame': LLOp(canrun=True),
     'get_exception_addr':   LLOp(),
diff --git a/rpython/translator/c/funcgen.py b/rpython/translator/c/funcgen.py
--- a/rpython/translator/c/funcgen.py
+++ b/rpython/translator/c/funcgen.py
@@ -456,6 +456,10 @@
     def OP_JIT_CONDITIONAL_CALL(self, op):
         return 'abort();  /* jit_conditional_call */'
 
+    def OP_JIT_CONDITIONAL_CALL_VALUE(self, op):
+        # like jit_conditional_call above: the C code just aborts if reached
+        return 'abort();  /* jit_conditional_call_value */'
+
     # low-level operations
     def generic_get(self, op, sourceexpr):
         T = self.lltypemap(op.result)


More information about the pypy-commit mailing list