[pypy-commit] pypy default: hg merge conditional_call_value_3

arigo pypy.commits at gmail.com
Mon Sep 12 10:32:17 EDT 2016


Author: Armin Rigo <arigo at tunes.org>
Branch: 
Changeset: r87044:6305cfb3bad2
Date: 2016-09-12 16:31 +0200
http://bitbucket.org/pypy/pypy/changeset/6305cfb3bad2/

Log:	hg merge conditional_call_value_3

	JIT residual calls: if the called function starts with a fast-path
	like "if x.foo != 0: return x.foo", then inline the check before
	doing the CALL.

	Right now only implemented on the x86 backend (and in the llgraph
	backend). Other backends keep the default
	supports_cond_call_value = False.

diff --git a/pypy/module/pypyjit/test_pypy_c/test_containers.py b/pypy/module/pypyjit/test_pypy_c/test_containers.py
--- a/pypy/module/pypyjit/test_pypy_c/test_containers.py
+++ b/pypy/module/pypyjit/test_pypy_c/test_containers.py
@@ -67,7 +67,7 @@
             p10 = call_r(ConstClass(ll_str__IntegerR_SignedConst_Signed), i5, descr=<Callr . i EF=3>)
             guard_no_exception(descr=...)
             guard_nonnull(p10, descr=...)
-            i12 = call_i(ConstClass(ll_strhash), p10, descr=<Calli . r EF=0>)
+            i12 = call_i(ConstClass(_ll_strhash__rpy_stringPtr), p10, descr=<Calli . r EF=0>)
             p13 = new(descr=...)
             p15 = new_array_clear(16, descr=<ArrayU 1>)
             {{{
diff --git a/rpython/jit/backend/arm/regalloc.py b/rpython/jit/backend/arm/regalloc.py
--- a/rpython/jit/backend/arm/regalloc.py
+++ b/rpython/jit/backend/arm/regalloc.py
@@ -1002,6 +1002,9 @@
     prepare_op_cond_call_gc_wb_array = prepare_op_cond_call_gc_wb
 
     def prepare_op_cond_call(self, op, fcond):
+        # XXX don't force the arguments to be loaded in specific
+        # locations before knowing if we can take the fast path
+        # XXX add cond_call_value support
         assert 2 <= op.numargs() <= 4 + 2
         tmpreg = self.get_scratch_reg(INT, selected_reg=r.r4)
         v = op.getarg(1)
diff --git a/rpython/jit/backend/llgraph/runner.py b/rpython/jit/backend/llgraph/runner.py
--- a/rpython/jit/backend/llgraph/runner.py
+++ b/rpython/jit/backend/llgraph/runner.py
@@ -325,6 +325,7 @@
     supports_longlong = r_uint is not r_ulonglong
     supports_singlefloats = True
     supports_guard_gc_type = True
+    supports_cond_call_value = True
     translate_support_code = False
     is_llgraph = True
     vector_extension = True
@@ -1334,6 +1335,16 @@
         # cond_call can't have a return value
         self.execute_call_n(calldescr, func, *args)
 
+    def execute_cond_call_value_i(self, calldescr, value, func, *args):
+        if not value:
+            value = self.execute_call_i(calldescr, func, *args)
+        return value
+
+    def execute_cond_call_value_r(self, calldescr, value, func, *args):
+        if not value:
+            value = self.execute_call_r(calldescr, func, *args)
+        return value
+
     def _execute_call(self, calldescr, func, *args):
         effectinfo = calldescr.get_extra_info()
         if effectinfo is not None and hasattr(effectinfo, 'oopspecindex'):
diff --git a/rpython/jit/backend/llsupport/regalloc.py b/rpython/jit/backend/llsupport/regalloc.py
--- a/rpython/jit/backend/llsupport/regalloc.py
+++ b/rpython/jit/backend/llsupport/regalloc.py
@@ -759,6 +759,8 @@
         if (opnum != rop.GUARD_TRUE and opnum != rop.GUARD_FALSE
                                     and opnum != rop.COND_CALL):
             return False
+        # NB: don't list COND_CALL_VALUE_I/R here, these two variants
+        # of COND_CALL don't accept a cc as input
         if next_op.getarg(0) is not op:
             return False
         if self.longevity[op][1] > i + 1:
diff --git a/rpython/jit/backend/llsupport/rewrite.py b/rpython/jit/backend/llsupport/rewrite.py
--- a/rpython/jit/backend/llsupport/rewrite.py
+++ b/rpython/jit/backend/llsupport/rewrite.py
@@ -11,7 +11,7 @@
 from rpython.jit.backend.llsupport.symbolic import (WORD,
         get_array_token)
 from rpython.jit.backend.llsupport.descr import SizeDescr, ArrayDescr,\
-     FLAG_POINTER
+     FLAG_POINTER, CallDescr
 from rpython.jit.metainterp.history import JitCellToken
 from rpython.jit.backend.llsupport.descr import (unpack_arraydescr,
         unpack_fielddescr, unpack_interiorfielddescr)
@@ -370,7 +370,9 @@
                     self.consider_setfield_gc(op)
                 elif op.getopnum() == rop.SETARRAYITEM_GC:
                     self.consider_setarrayitem_gc(op)
-            # ---------- call assembler -----------
+            # ---------- calls -----------
+            if OpHelpers.is_plain_call(op.getopnum()):
+                self.expand_call_shortcut(op)
             if OpHelpers.is_call_assembler(op.getopnum()):
                 self.handle_call_assembler(op)
                 continue
@@ -616,6 +618,30 @@
         self.emit_gc_store_or_indexed(None, ptr, ConstInt(0), value,
                                       size, 1, ofs)
 
+    def expand_call_shortcut(self, op):
+        if not self.cpu.supports_cond_call_value:
+            return
+        descr = op.getdescr()
+        if descr is None:
+            return
+        assert isinstance(descr, CallDescr)
+        effectinfo = descr.get_extra_info()
+        if effectinfo is None or effectinfo.call_shortcut is None:
+            return
+        if op.type == 'r':
+            cond_call_opnum = rop.COND_CALL_VALUE_R
+        elif op.type == 'i':
+            cond_call_opnum = rop.COND_CALL_VALUE_I
+        else:
+            return
+        cs = effectinfo.call_shortcut
+        ptr_box = op.getarg(1 + cs.argnum)
+        value_box = self.emit_getfield(ptr_box, descr=cs.fielddescr,
+                                       raw=(ptr_box.type == 'i'))
+        self.replace_op_with(op, ResOperation(cond_call_opnum,
+                                              [value_box] + op.getarglist(),
+                                              descr=descr))
+
     def handle_call_assembler(self, op):
         descrs = self.gc_ll_descr.getframedescrs(self.cpu)
         loop_token = op.getdescr()
diff --git a/rpython/jit/backend/llsupport/test/test_rewrite.py b/rpython/jit/backend/llsupport/test/test_rewrite.py
--- a/rpython/jit/backend/llsupport/test/test_rewrite.py
+++ b/rpython/jit/backend/llsupport/test/test_rewrite.py
@@ -1,7 +1,8 @@
 import py
 from rpython.jit.backend.llsupport.descr import get_size_descr,\
      get_field_descr, get_array_descr, ArrayDescr, FieldDescr,\
-     SizeDescr, get_interiorfield_descr
+     SizeDescr, get_interiorfield_descr, get_call_descr
+from rpython.jit.codewriter.effectinfo import EffectInfo, CallShortcut
 from rpython.jit.backend.llsupport.gc import GcLLDescr_boehm,\
      GcLLDescr_framework
 from rpython.jit.backend.llsupport import jitframe
@@ -80,6 +81,14 @@
                                      lltype.malloc(T, zero=True))
         self.myT = myT
         #
+        call_shortcut = CallShortcut(0, tzdescr)
+        effectinfo = EffectInfo(None, None, None, None, None, None,
+                                EffectInfo.EF_RANDOM_EFFECTS,
+                                call_shortcut=call_shortcut)
+        call_shortcut_descr = get_call_descr(self.gc_ll_descr,
+            [lltype.Ptr(T)], lltype.Signed,
+            effectinfo)
+        #
         A = lltype.GcArray(lltype.Signed)
         adescr = get_array_descr(self.gc_ll_descr, A)
         adescr.tid = 4321
@@ -200,6 +209,7 @@
 
     load_constant_offset = True
     load_supported_factors = (1,2,4,8)
+    supports_cond_call_value = True
 
     translate_support_code = None
 
@@ -1429,3 +1439,15 @@
             jump()
         """)
         assert len(self.gcrefs) == 2
+
+    def test_handle_call_shortcut(self):
+        self.check_rewrite("""
+            [p0]
+            i1 = call_i(123, p0, descr=call_shortcut_descr)
+            jump(i1)
+        """, """
+            [p0]
+            i2 = gc_load_i(p0, %(tzdescr.offset)s, %(tzdescr.field_size)s)
+            i1 = cond_call_value_i(i2, 123, p0, descr=call_shortcut_descr)
+            jump(i1)
+        """)
diff --git a/rpython/jit/backend/model.py b/rpython/jit/backend/model.py
--- a/rpython/jit/backend/model.py
+++ b/rpython/jit/backend/model.py
@@ -16,6 +16,7 @@
     # Boxes and Consts are BoxFloats and ConstFloats.
     supports_singlefloats = False
     supports_guard_gc_type = False
+    supports_cond_call_value = False
 
     propagate_exception_descr = None
 
diff --git a/rpython/jit/backend/test/runner_test.py b/rpython/jit/backend/test/runner_test.py
--- a/rpython/jit/backend/test/runner_test.py
+++ b/rpython/jit/backend/test/runner_test.py
@@ -2389,7 +2389,7 @@
             f2 = longlong.getfloatstorage(3.4)
             frame = self.cpu.execute_token(looptoken, 1, 0, 1, 2, 3, 4, 5, f1, f2)
             assert not called
-            for j in range(5):
+            for j in range(6):
                 assert self.cpu.get_int_value(frame, j) == j
             assert longlong.getrealfloat(self.cpu.get_float_value(frame, 6)) == 1.2
             assert longlong.getrealfloat(self.cpu.get_float_value(frame, 7)) == 3.4
@@ -2447,6 +2447,54 @@
                                            67, 89)
             assert called == [(67, 89)]
 
+    def test_cond_call_value(self):
+        if not self.cpu.supports_cond_call_value:
+            py.test.skip("missing supports_cond_call_value")
+
+        def func_int(*args):
+            called.append(args)
+            return len(args) * 100 + 1000
+
+        for i in range(5):
+            called = []
+
+            FUNC = self.FuncType([lltype.Signed] * i, lltype.Signed)
+            func_ptr = llhelper(lltype.Ptr(FUNC), func_int)
+            calldescr = self.cpu.calldescrof(FUNC, FUNC.ARGS, FUNC.RESULT,
+                                             EffectInfo.MOST_GENERAL)
+
+            ops = '''
+            [i0, i1, i2, i3, i4, i5, i6, f0, f1]
+            i15 = cond_call_value_i(i1, ConstClass(func_ptr), %s)
+            guard_false(i0, descr=faildescr) [i1,i2,i3,i4,i5,i6,i15, f0,f1]
+            finish(i15)
+            ''' % ', '.join(['i%d' % (j + 2) for j in range(i)] +
+                            ["descr=calldescr"])
+            loop = parse(ops, namespace={'faildescr': BasicFailDescr(),
+                                         'func_ptr': func_ptr,
+                                         'calldescr': calldescr})
+            looptoken = JitCellToken()
+            self.cpu.compile_loop(loop.inputargs, loop.operations, looptoken)
+            f1 = longlong.getfloatstorage(1.2)
+            f2 = longlong.getfloatstorage(3.4)
+            frame = self.cpu.execute_token(looptoken, 1, 50, 1, 2, 3, 4, 5,
+                                           f1, f2)
+            assert not called
+            assert [self.cpu.get_int_value(frame, j) for j in range(7)] == [
+                        50, 1, 2, 3, 4, 5, 50]
+            assert longlong.getrealfloat(
+                        self.cpu.get_float_value(frame, 7)) == 1.2
+            assert longlong.getrealfloat(
+                        self.cpu.get_float_value(frame, 8)) == 3.4
+            #
+            frame = self.cpu.execute_token(looptoken, 1, 0, 1, 2, 3, 4, 5,
+                                           f1, f2)
+            assert called == [(1, 2, 3, 4)[:i]]
+            assert [self.cpu.get_int_value(frame, j) for j in range(7)] == [
+                        0, 1, 2, 3, 4, 5, i * 100 + 1000]
+            assert longlong.getrealfloat(self.cpu.get_float_value(frame, 7)) == 1.2
+            assert longlong.getrealfloat(self.cpu.get_float_value(frame, 8)) == 3.4
+
     def test_force_operations_returning_void(self):
         values = []
         def maybe_force(token, flag):
diff --git a/rpython/jit/backend/test/test_ll_random.py b/rpython/jit/backend/test/test_ll_random.py
--- a/rpython/jit/backend/test/test_ll_random.py
+++ b/rpython/jit/backend/test/test_ll_random.py
@@ -594,7 +594,7 @@
         return subset, d['f'], vtableptr
 
     def getresulttype(self):
-        if self.opnum == rop.CALL_I:
+        if self.opnum == rop.CALL_I or self.opnum == rop.COND_CALL_VALUE_I:
             return lltype.Signed
         elif self.opnum == rop.CALL_F:
             return lltype.Float
@@ -712,7 +712,12 @@
 class CondCallOperation(BaseCallOperation):
     def produce_into(self, builder, r):
         fail_subset = builder.subset_of_intvars(r)
-        v_cond = builder.get_bool_var(r)
+        if self.opnum == rop.COND_CALL:
+            RESULT_TYPE = lltype.Void
+            v_cond = builder.get_bool_var(r)
+        else:
+            RESULT_TYPE = lltype.Signed
+            v_cond = r.choice(builder.intvars)
         subset = builder.subset_of_intvars(r)[:4]
         for i in range(len(subset)):
             if r.random() < 0.35:
@@ -724,8 +729,10 @@
                 seen.append(args)
             else:
                 assert seen[0] == args
+            if RESULT_TYPE is lltype.Signed:
+                return len(args) - 42000
         #
-        TP = lltype.FuncType([lltype.Signed] * len(subset), lltype.Void)
+        TP = lltype.FuncType([lltype.Signed] * len(subset), RESULT_TYPE)
         ptr = llhelper(lltype.Ptr(TP), call_me)
         c_addr = ConstAddr(llmemory.cast_ptr_to_adr(ptr), builder.cpu)
         args = [v_cond, c_addr] + subset
@@ -769,6 +776,7 @@
 for i in range(2):
     OPERATIONS.append(GuardClassOperation(rop.GUARD_CLASS))
     OPERATIONS.append(CondCallOperation(rop.COND_CALL))
+    OPERATIONS.append(CondCallOperation(rop.COND_CALL_VALUE_I))
     OPERATIONS.append(RaisingCallOperation(rop.CALL_N))
     OPERATIONS.append(RaisingCallOperationGuardNoException(rop.CALL_N))
     OPERATIONS.append(RaisingCallOperationWrongGuardException(rop.CALL_N))
diff --git a/rpython/jit/backend/x86/assembler.py b/rpython/jit/backend/x86/assembler.py
--- a/rpython/jit/backend/x86/assembler.py
+++ b/rpython/jit/backend/x86/assembler.py
@@ -174,8 +174,8 @@
         # copy registers to the frame, with the exception of the
         # 'cond_call_register_arguments' and eax, because these have already
         # been saved by the caller.  Note that this is not symmetrical:
-        # these 5 registers are saved by the caller but restored here at
-        # the end of this function.
+        # these 5 registers are saved by the caller but 4 of them are
+        # restored here at the end of this function.
         self._push_all_regs_to_frame(mc, cond_call_register_arguments + [eax],
                                      supports_floats, callee_only)
         # the caller already did push_gcmap(store=True)
@@ -198,7 +198,7 @@
             mc.ADD(esp, imm(WORD * 7))
         self.set_extra_stack_depth(mc, 0)
         self.pop_gcmap(mc)   # cancel the push_gcmap(store=True) in the caller
-        self._pop_all_regs_from_frame(mc, [], supports_floats, callee_only)
+        self._pop_all_regs_from_frame(mc, [eax], supports_floats, callee_only)
         mc.RET()
         return mc.materialize(self.cpu, [])
 
@@ -1703,7 +1703,8 @@
         self.implement_guard(guard_token)
         # If the previous operation was a COND_CALL, overwrite its conditional
         # jump to jump over this GUARD_NO_EXCEPTION as well, if we can
-        if self._find_nearby_operation(-1).getopnum() == rop.COND_CALL:
+        if self._find_nearby_operation(-1).getopnum() in (
+                rop.COND_CALL, rop.COND_CALL_VALUE_I, rop.COND_CALL_VALUE_R):
             jmp_adr = self.previous_cond_call_jcond
             offset = self.mc.get_relative_pos() - jmp_adr
             if offset <= 127:
@@ -2381,7 +2382,7 @@
     def label(self):
         self._check_frame_depth_debug(self.mc)
 
-    def cond_call(self, op, gcmap, imm_func, arglocs):
+    def cond_call(self, gcmap, imm_func, arglocs, resloc=None):
         assert self.guard_success_cc >= 0
         self.mc.J_il8(rx86.invert_condition(self.guard_success_cc), 0)
                                                             # patched later
@@ -2394,11 +2395,14 @@
         # plus the register 'eax'
         base_ofs = self.cpu.get_baseofs_of_frame_field()
         should_be_saved = self._regalloc.rm.reg_bindings.values()
+        restore_eax = False
         for gpr in cond_call_register_arguments + [eax]:
-            if gpr not in should_be_saved:
+            if gpr not in should_be_saved or gpr is resloc:
                 continue
             v = gpr_reg_mgr_cls.all_reg_indexes[gpr.value]
             self.mc.MOV_br(v * WORD + base_ofs, gpr.value)
+            if gpr is eax:
+                restore_eax = True
         #
         # load the 0-to-4 arguments into these registers
         from rpython.jit.backend.x86.jump import remap_frame_layout
@@ -2422,8 +2426,16 @@
                 floats = True
         cond_call_adr = self.cond_call_slowpath[floats * 2 + callee_only]
         self.mc.CALL(imm(follow_jump(cond_call_adr)))
+        # if this is a COND_CALL_VALUE, we need to move the result in place
+        if resloc is not None and resloc is not eax:
+            self.mc.MOV(resloc, eax)
         # restoring the registers saved above, and doing pop_gcmap(), is left
-        # to the cond_call_slowpath helper.  We never have any result value.
+        # to the cond_call_slowpath helper.  We must only restore eax, if
+        # needed.
+        if restore_eax:
+            v = gpr_reg_mgr_cls.all_reg_indexes[eax.value]
+            self.mc.MOV_rb(eax.value, v * WORD + base_ofs)
+        #
         offset = self.mc.get_relative_pos() - jmp_adr
         assert 0 < offset <= 127
         self.mc.overwrite(jmp_adr-1, chr(offset))
diff --git a/rpython/jit/backend/x86/regalloc.py b/rpython/jit/backend/x86/regalloc.py
--- a/rpython/jit/backend/x86/regalloc.py
+++ b/rpython/jit/backend/x86/regalloc.py
@@ -938,16 +938,45 @@
                     self.rm.force_spill_var(box)
                     assert box not in self.rm.reg_bindings
         #
-        assert op.type == 'v'
         args = op.getarglist()
         assert 2 <= len(args) <= 4 + 2     # maximum 4 arguments
-        v = args[1]
-        assert isinstance(v, Const)
-        imm_func = self.rm.convert_to_imm(v)
+        v_func = args[1]
+        assert isinstance(v_func, Const)
+        imm_func = self.rm.convert_to_imm(v_func)
+
+        # Delicate ordering here.  First get the argument's locations.
+        # If this also contains args[0], this returns the current
+        # location too.
         arglocs = [self.loc(args[i]) for i in range(2, len(args))]
         gcmap = self.get_gcmap()
-        self.load_condition_into_cc(op.getarg(0))
-        self.assembler.cond_call(op, gcmap, imm_func, arglocs)
+
+        if op.type == 'v':
+            # a plain COND_CALL.  Calls the function when args[0] is
+            # true.  Often used just after a comparison operation.
+            self.load_condition_into_cc(op.getarg(0))
+            resloc = None
+        else:
+            # COND_CALL_VALUE_I/R.  Calls the function when args[0]
+            # is equal to 0 or NULL.  Returns the result from the
+            # function call if done, or args[0] if it was not 0/NULL.
+            # Implemented by forcing the result to live in the same
+            # register as args[0], and overwriting it if we really do
+            # the call.
+
+            # Load the register for the result.  Possibly reuse 'args[0]'.
+            # But the old value of args[0], if it survives, is first
+            # spilled away.  We can't overwrite any of op.args[2:] here.
+            resloc = self.rm.force_result_in_reg(op, args[0],
+                                                 forbidden_vars=args[2:])
+
+            # Test the register for the result.
+            self.assembler.test_location(resloc)
+            self.assembler.guard_success_cc = rx86.Conditions['Z']
+
+        self.assembler.cond_call(gcmap, imm_func, arglocs, resloc)
+
+    consider_cond_call_value_i = consider_cond_call
+    consider_cond_call_value_r = consider_cond_call
 
     def consider_call_malloc_nursery(self, op):
         size_box = op.getarg(0)
diff --git a/rpython/jit/backend/x86/runner.py b/rpython/jit/backend/x86/runner.py
--- a/rpython/jit/backend/x86/runner.py
+++ b/rpython/jit/backend/x86/runner.py
@@ -15,6 +15,7 @@
     debug = True
     supports_floats = True
     supports_singlefloats = True
+    supports_cond_call_value = True
 
     dont_keepalive_stuff = False # for tests
     with_threads = False
diff --git a/rpython/jit/codewriter/call.py b/rpython/jit/codewriter/call.py
--- a/rpython/jit/codewriter/call.py
+++ b/rpython/jit/codewriter/call.py
@@ -7,9 +7,10 @@
 from rpython.jit.codewriter.jitcode import JitCode
 from rpython.jit.codewriter.effectinfo import (VirtualizableAnalyzer,
     QuasiImmutAnalyzer, RandomEffectsAnalyzer, effectinfo_from_writeanalyze,
-    EffectInfo, CallInfoCollection)
+    EffectInfo, CallInfoCollection, CallShortcut)
 from rpython.rtyper.lltypesystem import lltype, llmemory
 from rpython.rtyper.lltypesystem.lltype import getfunctionptr
+from rpython.flowspace.model import Constant, Variable
 from rpython.rlib import rposix
 from rpython.translator.backendopt.canraise import RaiseAnalyzer
 from rpython.translator.backendopt.writeanalyze import ReadWriteAnalyzer
@@ -214,6 +215,7 @@
         elidable = False
         loopinvariant = False
         call_release_gil_target = EffectInfo._NO_CALL_RELEASE_GIL_TARGET
+        call_shortcut = None
         if op.opname == "direct_call":
             funcobj = op.args[0].value._obj
             assert getattr(funcobj, 'calling_conv', 'c') == 'c', (
@@ -228,6 +230,12 @@
                 tgt_func, tgt_saveerr = func._call_aroundstate_target_
                 tgt_func = llmemory.cast_ptr_to_adr(tgt_func)
                 call_release_gil_target = (tgt_func, tgt_saveerr)
+            if hasattr(funcobj, 'graph'):
+                call_shortcut = self.find_call_shortcut(funcobj.graph)
+            if getattr(func, "_call_shortcut_", False):
+                assert call_shortcut is not None, (
+                    "%r: marked as @jit.call_shortcut but shortcut not found"
+                    % (func,))
         elif op.opname == 'indirect_call':
             # check that we're not trying to call indirectly some
             # function with the special flags
@@ -298,6 +306,7 @@
             self.readwrite_analyzer.analyze(op, self.seen_rw), self.cpu,
             extraeffect, oopspecindex, can_invalidate, call_release_gil_target,
             extradescr, self.collect_analyzer.analyze(op, self.seen_gc),
+            call_shortcut,
         )
         #
         assert effectinfo is not None
@@ -368,3 +377,65 @@
                 if GTYPE_fieldname in jd.greenfield_info.green_fields:
                     return True
         return False
+
+    def find_call_shortcut(self, graph):
+        """Identifies graphs that start like this:
+
+           def graph(x, y, z):         def graph(x, y, z):
+               if y.field:                 r = y.field
+                   return y.field          if r: return r
+        """
+        block = graph.startblock
+        if len(block.operations) == 0:
+            return
+        op = block.operations[0]
+        if op.opname != 'getfield':
+            return
+        [v_inst, c_fieldname] = op.args
+        if not isinstance(v_inst, Variable):
+            return
+        v_result = op.result
+        if v_result.concretetype != graph.getreturnvar().concretetype:
+            return
+        if v_result.concretetype == lltype.Void:
+            return
+        argnum = i = 0
+        while block.inputargs[i] is not v_inst:
+            if block.inputargs[i].concretetype != lltype.Void:
+                argnum += 1
+            i += 1
+        PSTRUCT = v_inst.concretetype
+        v_check = v_result
+        fastcase = True
+        for op in block.operations[1:]:
+            if (op.opname in ('int_is_true', 'ptr_nonzero', 'same_as')
+                    and v_check is op.args[0]):
+                v_check = op.result
+            elif op.opname == 'ptr_iszero' and v_check is op.args[0]:
+                v_check = op.result
+                fastcase = not fastcase
+            elif (op.opname in ('int_eq', 'int_ne')
+                    and v_check is op.args[0]
+                    and isinstance(op.args[1], Constant)
+                    and op.args[1].value == 0):
+                v_check = op.result
+                if op.opname == 'int_eq':
+                    fastcase = not fastcase
+            else:
+                return
+        if v_check.concretetype is not lltype.Bool:
+            return
+        if block.exitswitch is not v_check:
+            return
+
+        links = [link for link in block.exits if link.exitcase == fastcase]
+        if len(links) != 1:
+            return
+        [link] = links
+        if link.args != [v_result]:
+            return
+        if not link.target.is_final_block():
+            return
+
+        fielddescr = self.cpu.fielddescrof(PSTRUCT.TO, c_fieldname.value)
+        return CallShortcut(argnum, fielddescr)
diff --git a/rpython/jit/codewriter/effectinfo.py b/rpython/jit/codewriter/effectinfo.py
--- a/rpython/jit/codewriter/effectinfo.py
+++ b/rpython/jit/codewriter/effectinfo.py
@@ -117,7 +117,8 @@
                 can_invalidate=False,
                 call_release_gil_target=_NO_CALL_RELEASE_GIL_TARGET,
                 extradescrs=None,
-                can_collect=True):
+                can_collect=True,
+                call_shortcut=None):
         readonly_descrs_fields = frozenset_or_none(readonly_descrs_fields)
         readonly_descrs_arrays = frozenset_or_none(readonly_descrs_arrays)
         readonly_descrs_interiorfields = frozenset_or_none(
@@ -135,7 +136,8 @@
                extraeffect,
                oopspecindex,
                can_invalidate,
-               can_collect)
+               can_collect,
+               call_shortcut)
         tgt_func, tgt_saveerr = call_release_gil_target
         if tgt_func:
             key += (object(),)    # don't care about caching in this case
@@ -190,6 +192,7 @@
         result.oopspecindex = oopspecindex
         result.extradescrs = extradescrs
         result.call_release_gil_target = call_release_gil_target
+        result.call_shortcut = call_shortcut
         if result.check_can_raise(ignore_memoryerror=True):
             assert oopspecindex in cls._OS_CANRAISE
 
@@ -275,7 +278,8 @@
                                  call_release_gil_target=
                                      EffectInfo._NO_CALL_RELEASE_GIL_TARGET,
                                  extradescr=None,
-                                 can_collect=True):
+                                 can_collect=True,
+                                 call_shortcut=None):
     from rpython.translator.backendopt.writeanalyze import top_set
     if effects is top_set or extraeffect == EffectInfo.EF_RANDOM_EFFECTS:
         readonly_descrs_fields = None
@@ -364,7 +368,8 @@
                       can_invalidate,
                       call_release_gil_target,
                       extradescr,
-                      can_collect)
+                      can_collect,
+                      call_shortcut)
 
 def consider_struct(TYPE, fieldname):
     if fieldType(TYPE, fieldname) is lltype.Void:
@@ -387,6 +392,24 @@
 
 # ____________________________________________________________
 
+
+class CallShortcut(object):
+    def __init__(self, argnum, fielddescr):
+        self.argnum = argnum
+        self.fielddescr = fielddescr
+
+    def __eq__(self, other):
+        return (isinstance(other, CallShortcut) and
+                self.argnum == other.argnum and
+                self.fielddescr == other.fielddescr)
+    def __ne__(self, other):
+        return not (self == other)
+    def __hash__(self):
+        return hash((self.argnum, self.fielddescr))
+
+# ____________________________________________________________
+
+
 class VirtualizableAnalyzer(BoolGraphAnalyzer):
     def analyze_simple_operation(self, op, graphinfo):
         return op.opname in ('jit_force_virtualizable',
diff --git a/rpython/jit/codewriter/test/test_call.py b/rpython/jit/codewriter/test/test_call.py
--- a/rpython/jit/codewriter/test/test_call.py
+++ b/rpython/jit/codewriter/test/test_call.py
@@ -6,7 +6,7 @@
 from rpython.rlib import jit
 from rpython.jit.codewriter import support, call
 from rpython.jit.codewriter.call import CallControl
-from rpython.jit.codewriter.effectinfo import EffectInfo
+from rpython.jit.codewriter.effectinfo import EffectInfo, CallShortcut
 
 
 class FakePolicy:
@@ -368,3 +368,100 @@
         assert call_op.opname == 'direct_call'
         call_descr = cc.getcalldescr(call_op)
         assert call_descr.extrainfo.check_can_collect() == expected
+
+def test_find_call_shortcut():
+    class FakeCPU:
+        def fielddescrof(self, TYPE, fieldname):
+            if isinstance(TYPE, lltype.GcStruct):
+                if fieldname == 'inst_foobar':
+                    return 'foobardescr'
+                if fieldname == 'inst_fooref':
+                    return 'foorefdescr'
+            if TYPE == RAW and fieldname == 'x':
+                return 'xdescr'
+            assert False, (TYPE, fieldname)
+    cc = CallControl(FakeCPU())
+
+    class B(object):
+        foobar = 0
+        fooref = None
+
+    def f1(a, b, c):
+        if b.foobar:
+            return b.foobar
+        b.foobar = a + c
+        return b.foobar
+
+    def f2(x, y, z, b):
+        r = b.fooref
+        if r is not None:
+            return r
+        r = b.fooref = B()
+        return r
+
+    class Space(object):
+        def _freeze_(self):
+            return True
+    space = Space()
+
+    def f3(space, b):
+        r = b.foobar
+        if not r:
+            r = b.foobar = 123
+        return r
+
+    def f4(raw):
+        r = raw.x
+        if r != 0:
+            return r
+        raw.x = 123
+        return 123
+    RAW = lltype.Struct('RAW', ('x', lltype.Signed))
+
+    def f5(b):
+        r = b.foobar
+        if r == 0:
+            r = b.foobar = 123
+        return r
+
+    def f(a, c):
+        b = B()
+        f1(a, b, c)
+        f2(a, c, a, b)
+        f3(space, b)
+        r = lltype.malloc(RAW, flavor='raw')
+        f4(r)
+        f5(b)
+
+    rtyper = support.annotate(f, [10, 20])
+    f1_graph = rtyper.annotator.translator._graphof(f1)
+    assert cc.find_call_shortcut(f1_graph) == CallShortcut(1, "foobardescr")
+    f2_graph = rtyper.annotator.translator._graphof(f2)
+    assert cc.find_call_shortcut(f2_graph) == CallShortcut(3, "foorefdescr")
+    f3_graph = rtyper.annotator.translator._graphof(f3)
+    assert cc.find_call_shortcut(f3_graph) == CallShortcut(0, "foobardescr")
+    f4_graph = rtyper.annotator.translator._graphof(f4)
+    assert cc.find_call_shortcut(f4_graph) == CallShortcut(0, "xdescr")
+    f5_graph = rtyper.annotator.translator._graphof(f5)
+    assert cc.find_call_shortcut(f5_graph) == CallShortcut(0, "foobardescr")
+
+def test_cant_find_call_shortcut():
+    from rpython.jit.backend.llgraph.runner import LLGraphCPU
+
+    @jit.dont_look_inside
+    @jit.call_shortcut
+    def f1(n):
+        return n + 17   # no call shortcut found
+
+    def f(n):
+        return f1(n)
+
+    rtyper = support.annotate(f, [1])
+    jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0])
+    cc = CallControl(LLGraphCPU(rtyper), jitdrivers_sd=[jitdriver_sd])
+    res = cc.find_all_graphs(FakePolicy())
+    [f_graph] = [x for x in res if x.func is f]
+    call_op = f_graph.startblock.operations[0]
+    assert call_op.opname == 'direct_call'
+    e = py.test.raises(AssertionError, cc.getcalldescr, call_op)
+    assert "shortcut not found" in str(e.value)
diff --git a/rpython/jit/metainterp/executor.py b/rpython/jit/metainterp/executor.py
--- a/rpython/jit/metainterp/executor.py
+++ b/rpython/jit/metainterp/executor.py
@@ -101,6 +101,18 @@
     if condbox.getint():
         do_call_n(cpu, metainterp, argboxes[1:], descr)
 
+def do_cond_call_value_i(cpu, metainterp, argboxes, descr):
+    value = argboxes[0].getint()
+    if value == 0:
+        value = do_call_i(cpu, metainterp, argboxes[1:], descr)
+    return value
+
+def do_cond_call_value_r(cpu, metainterp, argboxes, descr):
+    value = argboxes[0].getref_base()
+    if not value:
+        value = do_call_r(cpu, metainterp, argboxes[1:], descr)
+    return value
+
 def do_getarrayitem_gc_i(cpu, _, arraybox, indexbox, arraydescr):
     array = arraybox.getref_base()
     index = indexbox.getint()
@@ -366,6 +378,8 @@
                          rop.CALL_ASSEMBLER_I,
                          rop.CALL_ASSEMBLER_N,
                          rop.INCREMENT_DEBUG_COUNTER,
+                         rop.COND_CALL_VALUE_R,
+                         rop.COND_CALL_VALUE_I,
                          rop.COND_CALL_GC_WB,
                          rop.COND_CALL_GC_WB_ARRAY,
                          rop.ZERO_ARRAY,
diff --git a/rpython/jit/metainterp/resoperation.py b/rpython/jit/metainterp/resoperation.py
--- a/rpython/jit/metainterp/resoperation.py
+++ b/rpython/jit/metainterp/resoperation.py
@@ -1149,8 +1149,8 @@
     '_CANRAISE_FIRST', # ----- start of can_raise operations -----
     '_CALL_FIRST',
     'CALL/*d/rfin',
-    'COND_CALL/*d/n',
-    # a conditional call, with first argument as a condition
+    'COND_CALL/*d/n',   # a conditional call, with first argument as a condition
+    'COND_CALL_VALUE/*d/ri',  # same but returns a result; emitted by rewrite
     'CALL_ASSEMBLER/*d/rfin',  # call already compiled assembler
     'CALL_MAY_FORCE/*d/rfin',
     'CALL_LOOPINVARIANT/*d/rfin',
diff --git a/rpython/jit/metainterp/test/test_dict.py b/rpython/jit/metainterp/test/test_dict.py
--- a/rpython/jit/metainterp/test/test_dict.py
+++ b/rpython/jit/metainterp/test/test_dict.py
@@ -195,7 +195,8 @@
                            'new_with_vtable': 2, 'getinteriorfield_gc_i': 2,
                            'setfield_gc': 14, 'int_gt': 2, 'int_sub': 2,
                            'call_i': 6, 'call_n': 2, 'call_r': 2, 'int_ge': 2,
-                           'guard_no_exception': 8, 'new': 2})
+                           'guard_no_exception': 8, 'new': 2,
+                           'guard_nonnull': 2})
 
     def test_unrolling_of_dict_iter(self):
         driver = JitDriver(greens = [], reds = ['n'])
diff --git a/rpython/rlib/jit.py b/rpython/rlib/jit.py
--- a/rpython/rlib/jit.py
+++ b/rpython/rlib/jit.py
@@ -257,6 +257,26 @@
     func.oopspec = "jit.not_in_trace()"   # note that 'func' may take arguments
     return func
 
+def call_shortcut(func):
+    """A decorator to ensure that a function has a fast-path.
+    Only useful on functions that the JIT doesn't normally look inside.
+    It still replaces residual calls to that function with inline code
+    that checks for a fast path, and only does the call if not.  For
+    now, graphs made by the following kinds of functions are detected:
+
+           def func(x, y, z):         def func(x, y, z):
+               if y.field:                 r = y.field
+                   return y.field          if r is None:
+               ...                             ...
+                                           return r
+
+    Fast-path detection is always on, but this decorator makes the
+    codewriter complain if it cannot find the promised fast-path.
+    """
+    func._call_shortcut_ = True
+    return func
+
+
 @oopspec("jit.isconstant(value)")
 def isconstant(value):
     """
diff --git a/rpython/rtyper/lltypesystem/rstr.py b/rpython/rtyper/lltypesystem/rstr.py
--- a/rpython/rtyper/lltypesystem/rstr.py
+++ b/rpython/rtyper/lltypesystem/rstr.py
@@ -369,13 +369,19 @@
         return b
 
     @staticmethod
+    def ll_strhash(s):
+        if s:
+            return LLHelpers._ll_strhash(s)
+        else:
+            return 0
+
+    @staticmethod
     @jit.elidable
-    def ll_strhash(s):
+    @jit.call_shortcut
+    def _ll_strhash(s):
         # unlike CPython, there is no reason to avoid to return -1
         # but our malloc initializes the memory to zero, so we use zero as the
         # special non-computed-yet value.
-        if not s:
-            return 0
         x = s.hash
         if x == 0:
             x = _hash_string(s.chars)


More information about the pypy-commit mailing list