[pypy-svn] r17169 - in pypy/dist/pypy/translator: . test

cfbolz at codespeak.net cfbolz at codespeak.net
Fri Sep 2 18:31:36 CEST 2005


Author: cfbolz
Date: Fri Sep  2 18:31:35 2005
New Revision: 17169

Added:
   pypy/dist/pypy/translator/test/test_unsimplify.py
Modified:
   pypy/dist/pypy/translator/backendoptimization.py
   pypy/dist/pypy/translator/test/test_backendoptimization.py
   pypy/dist/pypy/translator/unsimplify.py
Log:
Implemented a simple version of inlining, which does not yet work for calls guarded by a try...except.
In addition, I implemented split_block, which splits a block in two and correctly passes the needed variables from the first half to the second.
It is used by inline_function: inline_function splits the block where the call occurs into two blocks
and then copies the blocks of the function to be inlined. The link between the two split blocks is replaced by
these copies.


Modified: pypy/dist/pypy/translator/backendoptimization.py
==============================================================================
--- pypy/dist/pypy/translator/backendoptimization.py	(original)
+++ pypy/dist/pypy/translator/backendoptimization.py	Fri Sep  2 18:31:35 2005
@@ -1,8 +1,9 @@
 import autopath
 from pypy.translator.translator import Translator
-from pypy.translator.simplify import eliminate_empty_blocks
+from pypy.translator.simplify import eliminate_empty_blocks, join_blocks, remove_identical_vars
+from pypy.translator.unsimplify import copyvar, split_block
 from pypy.objspace.flow.model import Variable, Constant, Block, Link
-from pypy.objspace.flow.model import SpaceOperation
+from pypy.objspace.flow.model import SpaceOperation, last_exception
 from pypy.objspace.flow.model import traverse, mkentrymap, checkgraph
 from pypy.tool.unionfind import UnionFind
 from pypy.rpython.lltype import Void
@@ -147,6 +148,113 @@
             "variables called %s have mixed concretetypes: %r" % (vname, vct))
 
 
+def _find_callsites(graph, inline_func):
+    """Return (block, index) for every direct_call to inline_func in graph."""
+    callsites = []
+    def find_callsites(block):
+        if isinstance(block, Block):
+            for i, op in enumerate(block.operations):
+                if not (op.opname == "direct_call" and
+                        isinstance(op.args[0], Constant)):
+                    continue
+                if op.args[0].value._obj._callable is inline_func:
+                    callsites.append((block, i))
+    traverse(find_callsites, graph)
+    return callsites
+
+def inline_function(translator, inline_func, graph):
+    """Inline every direct_call to inline_func found in graph.
+
+    Each inlining splits the block containing the call (keeping only the
+    operations before it) and then cleans up the graph, so (block, index)
+    pairs computed before an inlining can be stale afterwards -- e.g. when
+    the same block contains two calls.  The call sites are therefore looked
+    up again after every inlining.  The number of inlinings is bounded by
+    the number of call sites initially present, so a recursive inline_func
+    (whose copied body contains new calls to itself) cannot loop forever.
+    """
+    remaining = len(_find_callsites(graph, inline_func))
+    while remaining > 0:
+        callsites = _find_callsites(graph, inline_func)
+        if not callsites:
+            break
+        block, index_operation = callsites[0]
+        _inline_function(translator, graph, block, index_operation)
+        checkgraph(graph)
+        remaining -= 1
+    
+def _inline_function(translator, graph, block, index_operation):
+    """Inline the direct_call block.operations[index_operation] into graph.
+
+    The block is split right after the call and the call operation is
+    removed.  The link between the two halves is rerouted through copies of
+    all blocks of the called function's flow graph: the copied return block
+    continues to the second half, and links into the copied exception block
+    are redirected to graph.exceptblock.  A call that is the last operation
+    of a block whose exitswitch is last_exception (i.e. whose exceptions are
+    caught) is not supported yet.
+    """
+    if block.exitswitch == Constant(last_exception):
+        assert index_operation != len(block.operations) - 1, (
+            "can't handle exceptions yet")
+    op = block.operations[index_operation]
+    graph_to_inline = translator.flowgraphs[op.args[0].value._obj._callable]
+    entrymap = mkentrymap(graph_to_inline)
+    beforeblock = block
+    afterblock = split_block(translator, graph, block, index_operation + 1)
+    assert beforeblock.operations[-1] is op
+    #vars that need to be passed through the blocks of the inlined function;
+    #this excludes the result variable of the direct_call
+    passon_vars = {beforeblock: [arg for arg in beforeblock.exits[0].args
+                                     if isinstance(arg, Variable) and
+                                         arg != op.result]}
+    #maps blocks of graph_to_inline to their copies
+    copied_blocks = {}
+    #maps variables of graph_to_inline to fresh copies
+    varmap = {}
+    def get_new_name(var):
+        if var is None:
+            return None
+        if isinstance(var, Constant):
+            return var
+        if var not in varmap:
+            varmap[var] = copyvar(translator, var)
+        return varmap[var]
+    def get_new_passon_var_names(block):
+        #every copied block receives its own fresh copies of the pass-on vars
+        result = [copyvar(translator, var) for var in passon_vars[beforeblock]]
+        passon_vars[block] = result
+        return result
+    def copy_operation(op):
+        args = [get_new_name(arg) for arg in op.args]
+        return SpaceOperation(op.opname, args, get_new_name(op.result))
+    def copy_block(block):
+        if block in copied_blocks:
+            #already copied
+            return copied_blocks[block]
+        args = ([get_new_name(var) for var in block.inputargs] +
+                get_new_passon_var_names(block))
+        newblock = Block(args)
+        #register the copy before copying the exits, so that links leading
+        #back to this block find the copy instead of recursing again
+        copied_blocks[block] = newblock
+        newblock.operations = [copy_operation(op) for op in block.operations]
+        newblock.exits = [copy_link(link, block) for link in block.exits]
+        newblock.exitswitch = get_new_name(block.exitswitch)
+        newblock.exc_handler = block.exc_handler
+        return newblock
+    def copy_link(link, prevblock):
+        #the pass-on vars of the source block are appended to the link args
+        newargs = [get_new_name(a) for a in link.args] + passon_vars[prevblock]
+        newlink = Link(newargs, copy_block(link.target), link.exitcase)
+        newlink.prevblock = copy_block(link.prevblock)
+        newlink.last_exception = get_new_name(link.last_exception)
+        newlink.last_exc_value = get_new_name(link.last_exc_value)
+        if hasattr(link, 'llexitcase'):
+            newlink.llexitcase = link.llexitcase
+        return newlink
+    linktoinlined = beforeblock.exits[0]
+    assert linktoinlined.target is afterblock
+    copiedstartblock = copy_block(graph_to_inline.startblock)
+    copiedstartblock.isstartblock = False
+    copiedreturnblock = copied_blocks[graph_to_inline.returnblock]
+    #args of the link from the copied return block to afterblock: constants
+    #stay, the call result becomes the copied return block's first input
+    #arg, every other variable becomes the matching pass-on var of the
+    #copied return block (same order as passon_vars[beforeblock])
+    passon_args = []
+    i = 0
+    for arg in linktoinlined.args:
+        if isinstance(arg, Constant):
+            passon_args.append(arg)
+        elif arg == op.result:
+            passon_args.append(copiedreturnblock.inputargs[0])
+        else:
+            passon_args.append(passon_vars[graph_to_inline.returnblock][i])
+            i += 1
+    #enter the copied start block with the call arguments instead of going
+    #directly to afterblock
+    linktoinlined.target = copiedstartblock
+    linktoinlined.args = op.args[1:] + passon_vars[beforeblock]
+    #drop the direct_call itself
+    beforeblock.operations = beforeblock.operations[:-1]
+    linkfrominlined = Link(passon_args, afterblock)
+    linkfrominlined.prevblock = copiedreturnblock
+    copiedreturnblock.exitswitch = None
+    copiedreturnblock.exits = [linkfrominlined]
+    assert copiedreturnblock.exits[0].target == afterblock
+    #redirect links into the exceptblock of the inlined graph to the
+    #exceptblock of graph
+    if graph_to_inline.exceptblock in entrymap:
+        for link in entrymap[graph_to_inline.exceptblock]:
+            copiedblock = copied_blocks[link.prevblock]
+            assert len(copiedblock.exits) == 1
+            #drop the appended pass-on vars; the first two args are presumably
+            #the exception type and value expected by graph.exceptblock
+            copiedblock.exits[0].args = copiedblock.exits[0].args[:2]
+            copiedblock.exits[0].target = graph.exceptblock
+    #cleaning up -- makes sense to be here, because quite a few empty
+    #blocks and blocks that can be joined were inserted above
+    eliminate_empty_blocks(graph)
+    join_blocks(graph)
+    remove_identical_vars(graph)
+
 def backend_optimizations(graph):
     remove_same_as(graph)
     eliminate_empty_blocks(graph)

Modified: pypy/dist/pypy/translator/test/test_backendoptimization.py
==============================================================================
--- pypy/dist/pypy/translator/test/test_backendoptimization.py	(original)
+++ pypy/dist/pypy/translator/test/test_backendoptimization.py	Fri Sep  2 18:31:35 2005
@@ -1,9 +1,9 @@
-from pypy.translator.backendoptimization import remove_void
+from pypy.translator.backendoptimization import remove_void, inline_function
 from pypy.translator.translator import Translator
 from pypy.rpython.lltype import Void
 from pypy.rpython.llinterp import LLInterpreter
 from pypy.objspace.flow.model import checkgraph
-from pypy.translator.test.snippet import simple_method
+from pypy.translator.test.snippet import simple_method, is_perfect_number
 from pypy.translator.llvm.log import log
 
 import py
@@ -42,3 +42,67 @@
     #interp = LLInterpreter(t.flowgraphs, t.rtyper)
     #assert interp.eval_function(f, [0]) == 1 
 
+def test_inline_simple():
+    # inline g into f, then check that the interpreted f still agrees with
+    # the plain Python f on both branches of g
+    def f(x, y):
+        return (g(x, y) + 1) * x
+    def g(x, y):
+        if x > 0:
+            return x * y
+        else:
+            return -x * y
+    translator = Translator(f)
+    annotator = translator.annotate([int, int])
+    annotator.simplify()
+    translator.specialize()
+    inline_function(translator, g, translator.flowgraphs[f])
+    interp = LLInterpreter(translator.flowgraphs, translator.rtyper)
+    for args in [(-1, 5), (2, 12)]:
+        assert interp.eval_function(f, list(args)) == f(*args)
+
+def test_inline_big():
+    # inline is_perfect_number (from the snippet module) into a looping
+    # caller and compare the length of the list it builds
+    def f(x):
+        result = []
+        for i in range(1, x+1):
+            if is_perfect_number(i):
+                result.append(i)
+        return result
+    translator = Translator(f)
+    annotator = translator.annotate([int])
+    annotator.simplify()
+    translator.specialize()
+    graph = translator.flowgraphs[f]
+    inline_function(translator, is_perfect_number, graph)
+    interp = LLInterpreter(translator.flowgraphs, translator.rtyper)
+    interpreted = interp.eval_function(f, [10])
+    assert interpreted.length == len(f(10))
+
+def test_inline_raising():
+    # inline a function that may raise (f) into g; the exceptions must still
+    # reach the handlers in h, which calls g inside a try...except
+    def f(x):
+        if x == 1:
+            raise ValueError
+        return x
+    def g(x):
+        a = f(x)
+        if x == 2:
+            raise KeyError
+    def h(x):
+        try:
+            g(x)
+        except ValueError:
+            return 1
+        except KeyError:
+            return 2
+        return x
+    translator = Translator(h)
+    annotator = translator.annotate([int])
+    annotator.simplify()
+    translator.specialize()
+    inline_function(translator, f, translator.flowgraphs[g])
+    interp = LLInterpreter(translator.flowgraphs, translator.rtyper)
+    for arg, expected in [(0, 0), (1, 1), (2, 2)]:
+        assert interp.eval_function(h, [arg]) == expected

Added: pypy/dist/pypy/translator/test/test_unsimplify.py
==============================================================================
--- (empty file)
+++ pypy/dist/pypy/translator/test/test_unsimplify.py	Fri Sep  2 18:31:35 2005
@@ -0,0 +1,67 @@
+from pypy.rpython.llinterp import LLInterpreter
+from pypy.translator.translator import Translator
+from pypy.translator.unsimplify import split_block
+
+def test_split_blocks_simple():
+    # split the start block at positions 0 to 3; a fresh Translator is
+    # built for each position so every split starts from an unsplit graph
+    for index in range(4):
+        def f(x, y):
+            z = x + y
+            w = x * y
+            return z + w
+        translator = Translator(f)
+        translator.annotate([int, int])
+        translator.specialize()
+        graph = translator.flowgraphs[f]
+        split_block(translator, graph, graph.startblock, index)
+        interp = LLInterpreter(translator.flowgraphs, translator.rtyper)
+        assert interp.eval_function(f, [1, 2]) == 5
+    
+def test_split_blocks_conditional():
+    # split a start block that ends in a conditional switch at positions
+    # 0 to 2 and check both branches still work
+    for index in range(3):
+        def f(x, y):
+            if x + 12:
+                return y + 1
+            else:
+                return y + 2
+        translator = Translator(f)
+        translator.annotate([int, int])
+        translator.specialize()
+        graph = translator.flowgraphs[f]
+        split_block(translator, graph, graph.startblock, index)
+        interp = LLInterpreter(translator.flowgraphs, translator.rtyper)
+        # x == -12 makes the condition false, x == 0 makes it true
+        assert interp.eval_function(f, [-12, 2]) == 4
+        assert interp.eval_function(f, [0, 2]) == 3
+
+def test_split_block_exceptions():
+    # split a start block whose call is guarded by a try...except and check
+    # that each handler and the normal path are still reached
+    for index in range(2):
+        def raises(x):
+            if x == 1:
+                raise ValueError
+            elif x == 2:
+                raise KeyError
+            return x
+        def catches(x):
+            try:
+                y = x + 1
+                raises(y)
+            except ValueError:
+                return 0
+            except KeyError:
+                return 1
+            return x
+        translator = Translator(catches)
+        translator.annotate([int])
+        translator.specialize()
+        graph = translator.flowgraphs[catches]
+        split_block(translator, graph, graph.startblock, index)
+        interp = LLInterpreter(translator.flowgraphs, translator.rtyper)
+        for arg, expected in [(0, 0), (1, 1), (2, 2)]:
+            assert interp.eval_function(catches, [arg]) == expected
+    

Modified: pypy/dist/pypy/translator/unsimplify.py
==============================================================================
--- pypy/dist/pypy/translator/unsimplify.py	(original)
+++ pypy/dist/pypy/translator/unsimplify.py	Fri Sep  2 18:31:35 2005
@@ -35,6 +35,56 @@
     link.target = newblock
     return newblock
 
+def split_block(translator, graph, block, index):
+    """split a block in two, inserting a proper link between the new blocks
+
+    The operations block.operations[:index] stay in block.  The remaining
+    operations, together with the exits, exitswitch and exc_handler of
+    block, move to a new block, which is returned.  block is left with a
+    single link to the new block that passes every variable the new block
+    uses but does not produce itself; inside the new block those variables
+    are renamed to fresh copies.  checkgraph(graph) is run at the end.
+    """
+    assert 0 <= index <= len(block.operations)
+    if block.exitswitch == Constant(last_exception):
+        #the new block takes over the last_exception exitswitch, so it must
+        #receive at least one operation
+        assert index < len(block.operations)
+    #varmap maps variables of the old block that are needed in the new
+    #block to the fresh copies used inside the new block
+    varmap = {}
+    vars_produced_in_new_block = {}
+    def get_new_name(var):
+        if var is None:
+            return None
+        if isinstance(var, Constant):
+            return var
+        if var in vars_produced_in_new_block:
+            return var
+        if var not in varmap:
+            varmap[var] = copyvar(translator, var)
+        return varmap[var]
+    moved_operations = block.operations[index:]
+    for op in moved_operations:
+        for i, arg in enumerate(op.args):
+            op.args[i] = get_new_name(arg)
+        vars_produced_in_new_block[op.result] = True
+    for link in block.exits:
+        for i, arg in enumerate(link.args):
+            #last_exception and last_exc_value are considered to be created
+            #when the link is entered, so they keep their names
+            if arg not in [link.last_exception, link.last_exc_value]:
+                link.args[i] = get_new_name(arg)
+    exitswitch = get_new_name(block.exitswitch)
+    #take old and new names from one list of pairs so that the arguments of
+    #the connecting link line up with the input arguments of the new block
+    renamings = varmap.items()
+    #the new block gets all the attributes relevant to outgoing links
+    #from the old block
+    newblock = Block([newvar for oldvar, newvar in renamings])
+    newblock.operations = moved_operations
+    newblock.exits = block.exits
+    newblock.exitswitch = exitswitch
+    newblock.exc_handler = block.exc_handler
+    for link in newblock.exits:
+        link.prevblock = newblock
+    link = Link([oldvar for oldvar, newvar in renamings], newblock)
+    link.prevblock = block
+    block.operations = block.operations[:index]
+    block.exits = [link]
+    block.exitswitch = None
+    block.exc_handler = False
+    checkgraph(graph)
+    return newblock
+
 def remove_direct_loops(translator, graph):
     """This is useful for code generators: it ensures that no link has
     common input and output variables, which could occur if a block's exit



More information about the Pypy-commit mailing list