[pypy-svn] r68495 - in pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp: . test

pedronis at codespeak.net pedronis at codespeak.net
Thu Oct 15 16:28:02 CEST 2009


Author: pedronis
Date: Thu Oct 15 16:28:02 2009
New Revision: 68495

Removed:
   pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/test/test_policy.py
Modified:
   pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/codewriter.py
   pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/policy.py
   pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/pyjitpl.py
   pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/test/test_basic.py
   pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/test/test_codewriter.py
   pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/test/test_warmspot.py
   pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/warmspot.py
Log:
(cfbolz, pedronis) Don't depend on look_inside_graph or the policy after the
first round of collecting graphs. Other cleanups, e.g. move generate_bytecode
out of the static data (MetaInterpStaticData) and into the CodeWriter.



Modified: pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/codewriter.py
==============================================================================
--- pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/codewriter.py	(original)
+++ pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/codewriter.py	Thu Oct 15 16:28:02 2009
@@ -51,22 +51,132 @@
 class CodeWriter(object):
     portal_graph = None
 
-    def __init__(self, metainterp_sd, policy):
+    def __init__(self, rtyper):
+        self.rtyper = rtyper
+        self.candidate_graphs = None
         self.all_prebuilt_values = dict_equal_consts()
         self.all_indirect_call_targets = {}
         self.all_graphs = {}
         self.all_methdescrs = {}
         self.all_listdescs = {}
         self.unfinished_graphs = []
-        self.metainterp_sd = metainterp_sd
-        self.rtyper = metainterp_sd.cpu.rtyper
-        self.cpu = metainterp_sd.cpu
-        self.policy = policy
         self.counter = 0
-        self.raise_analyzer = RaiseAnalyzer(self.rtyper.annotator.translator)
         self.class_sizes = []
         self._class_sizes_seen = {}
 
+        # set later by ._start(), which generate_bytecode() calls first
+        self.metainterp_sd = None
+        self.cpu = None
+        self.portal_runner_ptr = None
+        self.raise_analyzer = None
+
+
+    def find_all_graphs(self, portal_graph, leave_graph,
+                        policy, supports_floats):
+        from pypy.translator.simplify import get_graph
+        def is_candidate(graph):
+            return policy.look_inside_graph(graph, supports_floats)
+        
+        todo = [portal_graph]
+        if leave_graph is not None:
+            todo.append(leave_graph)        
+        self.candidate_graphs = seen = set(todo)
+        while todo:
+            top_graph = todo.pop()
+            for _, op in top_graph.iterblockops():
+                if op.opname not in ("direct_call", "indirect_call", "oosend"):
+                    continue
+                kind = self.guess_call_kind(op, is_candidate)
+                if kind != "regular":
+                    continue
+                for graph in self.graphs_from(op, is_candidate):
+                    if graph in seen:
+                        continue
+                    assert is_candidate(graph)
+                    todo.append(graph)
+                    seen.add(graph)
+        return self.candidate_graphs
+
+    def graphs_from(self, op, is_candidate=None):
+        if is_candidate is None:
+            is_candidate = self.is_candidate
+        if op.opname == 'direct_call':
+            funcobj = get_funcobj(op.args[0].value)
+            graph = funcobj.graph
+            if is_candidate(graph):
+                return [graph]     # common case: look inside this graph
+        else:
+            assert op.opname in ('indirect_call', 'oosend')
+            if op.opname == 'indirect_call':
+                graphs = op.args[-1].value
+            else:
+                v_obj = op.args[1].concretetype
+                graphs = v_obj._lookup_graphs(op.args[0].value)
+            if graphs is not None:
+                result = []
+                for graph in graphs:
+                    if is_candidate(graph):
+                        result.append(graph)
+                if result:
+                    return result  # common case: look inside these graphs,
+                                   # and ignore the others if there are any
+            else:
+                # special case: handle the indirect call that goes to
+                # the 'instantiate' methods.  This check is a bit imprecise
+                # but it's not too bad if we mistake a random indirect call
+                # for the one to 'instantiate'.
+                CALLTYPE = op.args[0].concretetype
+                if (op.opname == 'indirect_call' and len(op.args) == 2 and
+                    CALLTYPE == rclass.OBJECT_VTABLE.instantiate):
+                    return list(self._graphs_of_all_instantiate())
+        # residual call case: we don't need to look into any graph
+        return None
+
+    def _graphs_of_all_instantiate(self        ):
+        for vtable in self.rtyper.lltype2vtable.values():
+            if vtable.instantiate:
+                yield vtable.instantiate._obj.graph
+                
+    def guess_call_kind(self, op, is_candidate=None):
+        if op.opname == 'direct_call':
+            funcptr = op.args[0].value
+            if funcptr is self.portal_runner_ptr:
+                return 'recursive'
+            funcobj = get_funcobj(funcptr)
+            if getattr(funcobj, 'graph', None) is None:
+                return 'residual'
+            targetgraph = funcobj.graph
+            if (hasattr(targetgraph, 'func') and
+                hasattr(targetgraph.func, 'oopspec')):
+                return 'builtin'
+        elif op.opname == 'oosend':
+            SELFTYPE, methname, opargs = support.decompose_oosend(op)
+            if SELFTYPE.oopspec_name is not None:
+                return 'builtin'
+        if self.graphs_from(op, is_candidate) is None:
+            return 'residual'
+        return 'regular'
+
+    def is_candidate(self, graph):
+        return graph in self.candidate_graphs
+
+
+    def generate_bytecode(self, metainterp_sd, portal_graph, leave_graph,
+                          portal_runner_ptr):
+        self._start(metainterp_sd, portal_runner_ptr)
+        leave_code = None
+        if leave_graph:
+            leave_code = self.make_one_bytecode((leave_graph, None), False)
+        portal_code = self.make_portal_bytecode(portal_graph)
+
+        self.metainterp_sd.info_from_codewriter(portal_code, leave_code,
+                                                self.class_sizes)
+
+    def _start(self, metainterp_sd, portal_runner_ptr):
+        self.metainterp_sd = metainterp_sd
+        self.cpu = metainterp_sd.cpu
+        self.portal_runner_ptr = portal_runner_ptr
+        self.raise_analyzer = RaiseAnalyzer(self.rtyper.annotator.translator)
+
     def make_portal_bytecode(self, graph):
         log.info("making JitCodes...")
         self.portal_graph = graph
@@ -148,8 +258,7 @@
         return (cfnptr, calldescr)
 
     def register_indirect_call_targets(self, op):
-        targets = self.policy.graphs_from(op, self.rtyper,
-                                          self.cpu.supports_floats)
+        targets = self.graphs_from(op)
         assert targets is not None
         for graph in targets:
             if graph in self.all_indirect_call_targets:
@@ -183,8 +292,7 @@
             _, meth = T._lookup(methname)
             if not getattr(meth, 'abstract', False):
                 assert meth.graph
-                if self.policy.look_inside_graph(meth.graph,
-                                                 self.cpu.supports_floats):
+                if self.is_candidate(meth.graph):
                     jitcode = self.get_jitcode(meth.graph,
                                                oosend_methdescr=methdescr)
                 else:
@@ -234,8 +342,7 @@
         graph, oosend_methdescr = graph_key
         self.bytecode = self.codewriter.get_jitcode(graph,
                                              oosend_methdescr=oosend_methdescr)
-        assert codewriter.policy.look_inside_graph(graph,
-                                                   self.cpu.supports_floats)
+        assert codewriter.is_candidate(graph)
         self.graph = graph
 
     def assemble(self):
@@ -1007,28 +1114,20 @@
             self.emit('can_enter_jit')
 
     def serialize_op_direct_call(self, op):
-        kind = self.codewriter.policy.guess_call_kind(op,
-                                           self.codewriter.rtyper,
-                                           self.codewriter.cpu.supports_floats)
+        kind = self.codewriter.guess_call_kind(op)
         return getattr(self, 'handle_%s_call' % kind)(op)
 
     def serialize_op_indirect_call(self, op):
-        kind = self.codewriter.policy.guess_call_kind(op,
-                                           self.codewriter.rtyper,
-                                           self.codewriter.cpu.supports_floats)
+        kind = self.codewriter.guess_call_kind(op)
         return getattr(self, 'handle_%s_indirect_call' % kind)(op)
 
     def serialize_op_oosend(self, op):
-        kind = self.codewriter.policy.guess_call_kind(op,
-                                           self.codewriter.rtyper,
-                                           self.codewriter.cpu.supports_floats)
+        kind = self.codewriter.guess_call_kind(op)
         return getattr(self, 'handle_%s_oosend' % kind)(op)
 
     def handle_regular_call(self, op, oosend_methdescr=None):
         self.minimize_variables()
-        [targetgraph] = self.codewriter.policy.graphs_from(op,
-                                           self.codewriter.rtyper,
-                                           self.codewriter.cpu.supports_floats)
+        [targetgraph] = self.codewriter.graphs_from(op)
         jitbox = self.codewriter.get_jitcode(targetgraph, self.graph,
                                              oosend_methdescr=oosend_methdescr)
         if oosend_methdescr:

Modified: pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/policy.py
==============================================================================
--- pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/policy.py	(original)
+++ pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/policy.py	Thu Oct 15 16:28:02 2009
@@ -15,8 +15,6 @@
             print >> f, graph
         f.close()
 
-    portal_runner_ptr = None # set by WarmRunnerDesc.rewrite_jit_merge_point
-
     def look_inside_function(self, func):
         if hasattr(func, '_jit_look_inside_'):
             return func._jit_look_inside_
@@ -60,64 +58,6 @@
             self.unsafe_loopy_graphs.add(graph)
         return res and not contains_loop
 
-    def graphs_from(self, op, rtyper, supports_floats):
-        if op.opname == 'direct_call':
-            funcobj = get_funcobj(op.args[0].value)
-            graph = funcobj.graph
-            if self.look_inside_graph(graph, supports_floats):
-                return [graph]     # common case: look inside this graph
-        else:
-            assert op.opname in ('indirect_call', 'oosend')
-            if op.opname == 'indirect_call':
-                graphs = op.args[-1].value
-            else:
-                v_obj = op.args[1].concretetype
-                graphs = v_obj._lookup_graphs(op.args[0].value)
-            if graphs is not None:
-                result = []
-                for graph in graphs:
-                    if self.look_inside_graph(graph, supports_floats):
-                        result.append(graph)
-                if result:
-                    return result  # common case: look inside these graphs,
-                                   # and ignore the others if there are any
-            else:
-                # special case: handle the indirect call that goes to
-                # the 'instantiate' methods.  This check is a bit imprecise
-                # but it's not too bad if we mistake a random indirect call
-                # for the one to 'instantiate'.
-                CALLTYPE = op.args[0].concretetype
-                if (op.opname == 'indirect_call' and len(op.args) == 2 and
-                    CALLTYPE == rclass.OBJECT_VTABLE.instantiate):
-                    return list(self._graphs_of_all_instantiate(rtyper))
-        # residual call case: we don't need to look into any graph
-        return None
-
-    def _graphs_of_all_instantiate(self, rtyper):
-        for vtable in rtyper.lltype2vtable.values():
-            if vtable.instantiate:
-                yield vtable.instantiate._obj.graph
-
-    def guess_call_kind(self, op, rtyper, supports_floats):
-        if op.opname == 'direct_call':
-            funcptr = op.args[0].value
-            funcobj = get_funcobj(funcptr)
-            if funcptr is self.portal_runner_ptr:
-                return 'recursive'
-            if getattr(funcobj, 'graph', None) is None:
-                return 'residual'
-            targetgraph = funcobj.graph
-            if (hasattr(targetgraph, 'func') and
-                hasattr(targetgraph.func, 'oopspec')):
-                return 'builtin'
-        elif op.opname == 'oosend':
-            SELFTYPE, methname, opargs = support.decompose_oosend(op)
-            if SELFTYPE.oopspec_name is not None:
-                return 'builtin'
-        if self.graphs_from(op, rtyper, supports_floats) is None:
-            return 'residual'
-        return 'regular'
-
 def contains_unsupported_variable_type(graph, supports_floats):
     getkind = history.getkind
     try:

Modified: pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/pyjitpl.py
==============================================================================
--- pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/pyjitpl.py	(original)
+++ pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/pyjitpl.py	Thu Oct 15 16:28:02 2009
@@ -967,10 +967,8 @@
     logger_noopt = None
     logger_ops = None
 
-    def __init__(self, portal_graph, graphs, cpu, stats, options,
-                 ProfilerClass=EmptyProfiler, warmrunnerdesc=None,
-                 leave_graph=None):
-        self.portal_graph = portal_graph
+    def __init__(self, portal_graph, cpu, stats, options,
+                 ProfilerClass=EmptyProfiler, warmrunnerdesc=None):
         self.cpu = cpu
         self.stats = stats
         self.options = options
@@ -995,11 +993,19 @@
         backendmodule = self.cpu.__module__
         backendmodule = backendmodule.split('.')[-2]
         self.jit_starting_line = 'JIT starting (%s)' % backendmodule
-        self.leave_graph = leave_graph
+
+        self.portal_code = None
+        self.leave_code = None
+        self._class_sizes = None        
 
     def _freeze_(self):
         return True
 
+    def info_from_codewriter(self, portal_code, leave_code, class_sizes):
+        self.portal_code = portal_code
+        self.leave_code = leave_code
+        self._class_sizes = class_sizes
+
     def finish_setup(self, optimizer=None):
         warmrunnerdesc = self.warmrunnerdesc
         if warmrunnerdesc is not None:
@@ -1032,17 +1038,6 @@
             class_sizes[vtable] = sizedescr
         self.cpu.set_class_sizes(class_sizes)
 
-    def generate_bytecode(self, policy):
-        self._codewriter = codewriter.CodeWriter(self, policy)
-        self.leave_code = None
-        if self.leave_graph:
-            self.leave_code = self._codewriter.make_one_bytecode(
-                                                    (self.leave_graph, None),
-                                                    False)
-        self.portal_code = self._codewriter.make_portal_bytecode(
-            self.portal_graph)
-        self._class_sizes = self._codewriter.class_sizes
-
     def bytecode_for_address(self, fnaddress):
         if we_are_translated():
             d = self.globaldata.indirectcall_dict
@@ -1065,14 +1060,6 @@
 
     # ---------- construction-time interface ----------
 
-    def _register_opcode(self, opname):
-        assert len(self.opcode_implementations) < 256, \
-               "too many implementations of opcodes!"
-        name = "opimpl_" + opname
-        self.opname_to_index[opname] = len(self.opcode_implementations)
-        self.opcode_names.append(opname)
-        self.opcode_implementations.append(getattr(MIFrame, name).im_func)
-
     def _register_indirect_call_target(self, fnaddress, jitcode):
         self.indirectcall_keys.append(fnaddress)
         self.indirectcall_values.append(jitcode)
@@ -1084,6 +1071,14 @@
             self._register_opcode(name)
             return self.opname_to_index[name]
 
+    def _register_opcode(self, opname):
+        assert len(self.opcode_implementations) < 256, \
+               "too many implementations of opcodes!"
+        name = "opimpl_" + opname
+        self.opname_to_index[opname] = len(self.opcode_implementations)
+        self.opcode_names.append(opname)
+        self.opcode_implementations.append(getattr(MIFrame, name).im_func)
+
     # ---------------- logging ------------------------
 
     def log(self, msg, event_kind='info'):

Modified: pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/test/test_basic.py
==============================================================================
--- pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/test/test_basic.py	(original)
+++ pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/test/test_basic.py	Thu Oct 15 16:28:02 2009
@@ -12,8 +12,8 @@
 from pypy.rpython.lltypesystem import lltype
 from pypy.rpython.ootypesystem import ootype
 
-def get_metainterp(func, values, CPUClass, type_system, policy,
-                   listops=False, optimizer=OPTIMIZER_FULL):
+def _get_bare_metainterp(func, values, CPUClass, type_system,
+                         listops=False):
     from pypy.annotation.policy import AnnotatorPolicy
     from pypy.annotation.model import lltype_to_annotation
     from pypy.rpython.test.test_llinterp import gengraph
@@ -22,10 +22,10 @@
 
     stats = history.Stats()
     cpu = CPUClass(rtyper, stats, False)
-    graph = rtyper.annotator.translator.graphs[0]
+    graphs = rtyper.annotator.translator.graphs
     opt = history.Options(listops=listops)
-    metainterp_sd = pyjitpl.MetaInterpStaticData(graph, [], cpu, stats, opt)
-    metainterp_sd.finish_setup(optimizer=optimizer)
+    metainterp_sd = pyjitpl.MetaInterpStaticData(graphs[0], cpu, stats, opt)
+    metainterp_sd.finish_setup(optimizer="bogus")
     metainterp = pyjitpl.MetaInterp(metainterp_sd)
     return metainterp, rtyper
 
@@ -60,7 +60,7 @@
             kwds["backendopt"] = False
         return ll_meta_interp(*args, **kwds)
 
-    def interp_operations(self, f, args, policy=None, **kwds):
+    def interp_operations(self, f, args, **kwds):
         from pypy.jit.metainterp import simple_optimize
 
         class DoneWithThisFrame(Exception):
@@ -81,17 +81,17 @@
             trace_limit = sys.maxint
             debug_level = 2
         
-        if policy is None:
-            policy = JitPolicy()
-        metainterp, rtyper = get_metainterp(f, args, self.CPUClass,
-                                            self.type_system, policy=policy,
-                                            optimizer="bogus",
-                                            **kwds)
-        cw = codewriter.CodeWriter(metainterp.staticdata, policy)
-        graph = rtyper.annotator.translator.graphs[0]
-        graph.func._jit_unroll_safe_ = True
-        graph_key = (graph, None)
-        maingraph = cw.make_one_bytecode(graph_key, False)
+        metainterp, rtyper = _get_bare_metainterp(f, args, self.CPUClass,
+                                                  self.type_system,
+                                                  **kwds)
+        portal_graph = rtyper.annotator.translator.graphs[0]
+        cw = codewriter.CodeWriter(rtyper)
+        
+        graphs = cw.find_all_graphs(portal_graph, None, JitPolicy(),
+                                    self.CPUClass.supports_floats)
+        cw._start(metainterp.staticdata, None)
+        portal_graph.func._jit_unroll_safe_ = True
+        maingraph = cw.make_one_bytecode((portal_graph, None), False)
         cw.finish_making_bytecodes()
         metainterp.staticdata.portal_code = maingraph
         metainterp.staticdata._class_sizes = cw.class_sizes
@@ -297,11 +297,12 @@
         assert res == ord(u"?")
 
     def test_residual_call(self):
+        @dont_look_inside
         def externfn(x, y):
             return x * y
         def f(n):
             return externfn(n, n+1)
-        res = self.interp_operations(f, [6], policy=StopAtXPolicy(externfn))
+        res = self.interp_operations(f, [6])
         assert res == 42
         self.check_history_(int_add=1, int_mul=0, call=1, guard_no_exception=0)
 

Modified: pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/test/test_codewriter.py
==============================================================================
--- pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/test/test_codewriter.py	(original)
+++ pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/test/test_codewriter.py	Thu Oct 15 16:28:02 2009
@@ -7,6 +7,60 @@
 from pypy.translator.translator import graphof
 from pypy.rpython.lltypesystem.rbuiltin import ll_instantiate
 
+def test_find_all_graphs():
+    def f(x):
+        if x < 0:
+            return f(-x)
+        return x + 1
+    @jit.purefunction
+    def g(x):
+        return x + 2
+    @jit.dont_look_inside
+    def h(x):
+        return x + 3
+    def i(x):
+        return f(x) * g(x) * h(x)
+
+    rtyper = support.annotate(i, [7])
+    cw = CodeWriter(rtyper)
+    jitpolicy = JitPolicy()
+    res = cw.find_all_graphs(rtyper.annotator.translator.graphs[0], None,
+                             jitpolicy, True)
+    translator = rtyper.annotator.translator
+
+    funcs = set([graph.func for graph in res])
+    assert funcs == set([i, f])
+
+def test_find_all_graphs_without_floats():
+    def g(x):
+        return int(x * 12.5)
+    def f(x):
+        return g(x) + 1
+    rtyper = support.annotate(f, [7])
+    cw = CodeWriter(rtyper)
+    jitpolicy = JitPolicy()
+    translator = rtyper.annotator.translator
+    res = cw.find_all_graphs(translator.graphs[0], None, jitpolicy,
+                             supports_floats=True)
+    funcs = set([graph.func for graph in res])
+    assert funcs == set([f, g])
+
+    cw = CodeWriter(rtyper)        
+    res = cw.find_all_graphs(translator.graphs[0], None, jitpolicy,
+                             supports_floats=False)
+    funcs = [graph.func for graph in res]
+    assert funcs == [f]
+
+def test_find_all_graphs_str_join():
+    def i(x, y):
+        return "hello".join([str(x), str(y), "bye"])
+
+    rtyper = support.annotate(i, [7, 100])
+    cw = CodeWriter(rtyper)
+    jitpolicy = JitPolicy()
+    translator = rtyper.annotator.translator
+    # does not explode
+    cw.find_all_graphs(translator.graphs[0], None, jitpolicy, True)
 
 class SomeLabel(object):
     def __eq__(self, other):
@@ -15,7 +69,7 @@
 
 class TestCodeWriter:
 
-    def make_graph(self, func, values, type_system='lltype'):
+    def make_graphs(self, func, values, type_system='lltype'):
         class FakeMetaInterpSd:
             virtualizable_info = None
             def find_opcode(self, name):
@@ -56,9 +110,9 @@
         self.metainterp_sd.indirectcalls = []
         self.metainterp_sd.cpu = FakeCPU()
 
-        rtyper = support.annotate(func, values, type_system=type_system)
-        self.metainterp_sd.cpu.rtyper = rtyper
-        return rtyper.annotator.translator.graphs[0]
+        self.rtyper = support.annotate(func, values, type_system=type_system)
+        self.metainterp_sd.cpu.rtyper = self.rtyper
+        return self.rtyper.annotator.translator.graphs
 
     def graphof(self, func):
         rtyper = self.metainterp_sd.cpu.rtyper
@@ -67,15 +121,86 @@
     def test_basic(self):
         def f(n):
             return n + 10
-        graph = self.make_graph(f, [5])
-        cw = CodeWriter(self.metainterp_sd, JitPolicy())
-        jitcode = cw.make_one_bytecode((graph, None), False)
+        graphs = self.make_graphs(f, [5])
+        cw = CodeWriter(self.rtyper)
+        cw.candidate_graphs = graphs
+        cw._start(self.metainterp_sd, None)
+        jitcode = cw.make_one_bytecode((graphs[0], None), False)
         assert jitcode._source == [
             SomeLabel(),
             'int_add', 0, 1, '# => r1',
             'make_new_vars_1', 2,
             'return']
 
+    def test_guess_call_kind_and_calls_from_graphs(self):
+        from pypy.objspace.flow.model import SpaceOperation, Constant, Variable
+
+        portal_runner_ptr = object()
+        g = object()
+        g1 = object()
+        cw = CodeWriter(None)
+        cw.candidate_graphs = [g, g1]
+        cw.portal_runner_ptr = portal_runner_ptr
+
+        op = SpaceOperation('direct_call', [Constant(portal_runner_ptr)],
+                            Variable())
+        assert cw.guess_call_kind(op) == 'recursive'
+
+        op = SpaceOperation('direct_call', [Constant(object())],
+                            Variable())
+        assert cw.guess_call_kind(op) == 'residual'        
+
+        class funcptr:
+            class graph:
+                class func:
+                    oopspec = "spec"
+        op = SpaceOperation('direct_call', [Constant(funcptr)],
+                            Variable())
+        assert cw.guess_call_kind(op) == 'builtin'
+        
+        class funcptr:
+            graph = g
+        op = SpaceOperation('direct_call', [Constant(funcptr)],
+                            Variable())
+        res = cw.graphs_from(op)
+        assert res == [g]        
+        assert cw.guess_call_kind(op) == 'regular'
+
+        class funcptr:
+            graph = object()
+        op = SpaceOperation('direct_call', [Constant(funcptr)],
+                            Variable())
+        res = cw.graphs_from(op)
+        assert res is None        
+        assert cw.guess_call_kind(op) == 'residual'
+
+        h = object()
+        op = SpaceOperation('indirect_call', [Variable(),
+                                              Constant([g, g1, h])],
+                            Variable())
+        res = cw.graphs_from(op)
+        assert res == [g, g1]
+        assert cw.guess_call_kind(op) == 'regular'
+
+        op = SpaceOperation('indirect_call', [Variable(),
+                                              Constant([h])],
+                            Variable())
+        res = cw.graphs_from(op)
+        assert res is None
+        assert cw.guess_call_kind(op) == 'residual'        
+        
+    def test_direct_call(self):
+        def g(m):
+            return 123
+        def f(n):
+            return g(n+1)
+        graphs = self.make_graphs(f, [5])
+        cw = CodeWriter(self.rtyper)
+        cw.candidate_graphs = graphs
+        cw._start(self.metainterp_sd, None)
+        jitcode = cw.make_one_bytecode((graphs[0], None), False)
+        assert len(cw.all_graphs) == 2        
+
     def test_indirect_call_target(self):
         def g(m):
             return 123
@@ -87,9 +212,11 @@
             else:
                 call = h
             return call(n+1) + call(n+2)
-        graph = self.make_graph(f, [5])
-        cw = CodeWriter(self.metainterp_sd, JitPolicy())
-        jitcode = cw.make_one_bytecode((graph, None), False)
+        graphs = self.make_graphs(f, [5])
+        cw = CodeWriter(self.rtyper)
+        cw.candidate_graphs = graphs
+        cw._start(self.metainterp_sd, None)        
+        jitcode = cw.make_one_bytecode((graphs[0], None), False)
         assert len(self.metainterp_sd.indirectcalls) == 2
         names = [jitcode.name for (fnaddress, jitcode)
                                in self.metainterp_sd.indirectcalls]
@@ -107,9 +234,13 @@
             else:
                 call = h
             return call(n+1) + call(n+2)
-        graph = self.make_graph(f, [5])
-        cw = CodeWriter(self.metainterp_sd, JitPolicy())
-        jitcode = cw.make_one_bytecode((graph, None), False)
+        graphs = self.make_graphs(f, [5])
+        graphs = [g for g in graphs if getattr(g.func, '_jit_look_inside_',
+                                               True)]
+        cw = CodeWriter(self.rtyper)
+        cw.candidate_graphs = graphs
+        cw._start(self.metainterp_sd, None)        
+        jitcode = cw.make_one_bytecode((graphs[0], None), False)
         assert len(self.metainterp_sd.indirectcalls) == 1
         names = [jitcode.name for (fnaddress, jitcode)
                                in self.metainterp_sd.indirectcalls]
@@ -131,9 +262,13 @@
             else:
                 x = C()
             return x.g() + x.g()
-        graph = self.make_graph(f, [5], type_system='ootype')
-        cw = CodeWriter(self.metainterp_sd, JitPolicy())
-        jitcode = cw.make_one_bytecode((graph, None), False)
+        graphs = self.make_graphs(f, [5], type_system='ootype')
+        graphs = [g for g in graphs if getattr(g.func, '_jit_look_inside_',
+                                               True)]
+        cw = CodeWriter(self.rtyper)
+        cw.candidate_graphs = graphs
+        cw._start(self.metainterp_sd, None)        
+        jitcode = cw.make_one_bytecode((graphs[0], None), False)
         assert len(self.methdescrs) == 1
         assert self.methdescrs[0].CLASS._name.endswith('.A')
         assert self.methdescrs[0].methname == 'og'
@@ -157,9 +292,11 @@
                 x, y = A2, B2
             n += 1
             return x().id + y().id + n
-        graph = self.make_graph(f, [5])
-        cw = CodeWriter(self.metainterp_sd, JitPolicy())
-        cw.make_one_bytecode((graph, None), False)
+        graphs = self.make_graphs(f, [5])
+        cw = CodeWriter(self.rtyper)
+        cw.candidate_graphs = graphs
+        cw._start(self.metainterp_sd, None)        
+        cw.make_one_bytecode((graphs[0], None), False)
         graph2 = self.graphof(ll_instantiate)
         jitcode = cw.make_one_bytecode((graph2, None), False)
         assert 'residual_call' not in jitcode._source

Modified: pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/test/test_warmspot.py
==============================================================================
--- pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/test/test_warmspot.py	(original)
+++ pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/test/test_warmspot.py	Thu Oct 15 16:28:02 2009
@@ -270,6 +270,32 @@
         self.meta_interp(f, [50])
         self.check_enter_count_at_most(2)
 
+    def test_wanted_unrolling_and_preinlining(self):
+        mydriver = JitDriver(reds = ['n', 'm'], greens = [])
+
+        @unroll_safe
+        def loop2(n):
+            # the jit looks here, due to the decorator
+            for i in range(5):
+                n += 1
+            return n
+        loop2._always_inline_ = True
+
+        def g(n):
+            return loop2(n)
+        g._dont_inline_ = True
+
+        def f(m):
+            n = 0
+            while n < m:
+                mydriver.can_enter_jit(n=n, m=m)
+                mydriver.jit_merge_point(n=n, m=m)
+                n = g(n)
+            return n
+        self.meta_interp(f, [50], backendopt=True)
+        self.check_enter_count_at_most(2)
+        self.check_loops(call=0)
+
 
 class TestLLWarmspot(WarmspotTests, LLJitMixin):
     CPUClass = runner.LLtypeCPU

Modified: pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/warmspot.py
==============================================================================
--- pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/warmspot.py	(original)
+++ pypy/branch/prevent-silly-unrolling/pypy/jit/metainterp/warmspot.py	Thu Oct 15 16:28:02 2009
@@ -17,6 +17,7 @@
 from pypy.translator.simplify import get_funcobj, get_functype
 from pypy.translator.unsimplify import call_final_function
 
+from pypy.jit.metainterp import codewriter
 from pypy.jit.metainterp import support, history, pyjitpl, gc
 from pypy.jit.metainterp.pyjitpl import MetaInterpStaticData, MetaInterp
 from pypy.jit.metainterp.policy import JitPolicy
@@ -153,8 +154,11 @@
         self.set_translator(translator)
         self.find_portal()
         self.make_leave_jit_graph()
-        graphs = find_all_graphs(self.portal_graph, policy, self.translator,
-                                 CPUClass.supports_floats)
+        self.codewriter = codewriter.CodeWriter(self.rtyper)
+        graphs = self.codewriter.find_all_graphs(self.portal_graph,
+                                                 self.leave_graph,
+                                                 policy,
+                                                 CPUClass.supports_floats)
         policy.dump_unsafe_loops()
         self.check_access_directly_sanity(graphs)
         if backendopt:
@@ -167,7 +171,12 @@
         if self.jitdriver.virtualizables:
             from pypy.jit.metainterp.virtualizable import VirtualizableInfo
             self.metainterp_sd.virtualizable_info = VirtualizableInfo(self)
-        self.metainterp_sd.generate_bytecode(policy)
+                
+        self.codewriter.generate_bytecode(self.metainterp_sd,
+                                          self.portal_graph,
+                                          self.leave_graph,
+                                          self.portal_runner_ptr
+                                          )
         self.make_enter_function()
         self.rewrite_can_enter_jit()
         self.rewrite_set_param()
@@ -186,6 +195,7 @@
 
     def set_translator(self, translator):
         self.translator = translator
+        self.rtyper = translator.rtyper
         self.gcdescr = gc.get_description(translator.config)
 
     def find_portal(self):
@@ -252,13 +262,12 @@
         cpu = CPUClass(self.translator.rtyper, self.stats,
                        translate_support_code, gcdescr=self.gcdescr)
         self.cpu = cpu
-        self.metainterp_sd = MetaInterpStaticData(self.portal_graph,
-                                                  self.translator.graphs, cpu,
+        self.metainterp_sd = MetaInterpStaticData(self.portal_graph, # xxx
+                                                  cpu,
                                                   self.stats, opt,
                                                   ProfilerClass=ProfilerClass,
-                                                  warmrunnerdesc=self,
-                                                  leave_graph=self.leave_graph)
-
+                                                  warmrunnerdesc=self)
+        
     def make_enter_function(self):
         WarmEnterState = make_state_class(self)
         state = WarmEnterState()
@@ -511,9 +520,8 @@
                         value = cast_base_ptr_to_instance(Exception, value)
                         raise Exception, value
         
-        portal_runner_ptr = self.helper_func(self.PTR_PORTAL_FUNCTYPE,
-                                             ll_portal_runner)
-        policy.portal_runner_ptr = portal_runner_ptr
+        self.portal_runner_ptr = self.helper_func(self.PTR_PORTAL_FUNCTYPE,
+                                                  ll_portal_runner)
 
         # ____________________________________________________________
         # Now mutate origportalgraph to end with a call to portal_runner_ptr
@@ -523,7 +531,7 @@
         assert op.opname == 'jit_marker'
         assert op.args[0].value == 'jit_merge_point'
         greens_v, reds_v = decode_hp_hint_args(op)
-        vlist = [Constant(portal_runner_ptr, self.PTR_PORTAL_FUNCTYPE)]
+        vlist = [Constant(self.portal_runner_ptr, self.PTR_PORTAL_FUNCTYPE)]
         vlist += greens_v
         vlist += reds_v
         v_result = Variable()
@@ -566,30 +574,6 @@
             op.args[:3] = [closures[funcname]]
 
 
-def find_all_graphs(portal, policy, translator, supports_floats):
-    from pypy.translator.simplify import get_graph
-    rtyper = translator.rtyper
-    all_graphs = [portal]
-    seen = set([portal])
-    todo = [portal]
-    while todo:
-        top_graph = todo.pop()
-        for _, op in top_graph.iterblockops():
-            if op.opname not in ("direct_call", "indirect_call", "oosend"):
-                continue
-            kind = policy.guess_call_kind(op, rtyper, supports_floats)
-            if kind != "regular":
-                continue
-            for graph in policy.graphs_from(op, rtyper, supports_floats):
-                if graph in seen:
-                    continue
-                if policy.look_inside_graph(graph, supports_floats):
-                    todo.append(graph)
-                    all_graphs.append(graph)
-                    seen.add(graph)
-    return all_graphs
-
-
 def decode_hp_hint_args(op):
     # Returns (list-of-green-vars, list-of-red-vars) without Voids.
     assert op.opname == 'jit_marker'



More information about the Pypy-commit mailing list