[pypy-svn] r61953 - in pypy/branch/pyjitpl5/pypy/jit/metainterp: . test

fijal at codespeak.net fijal at codespeak.net
Mon Feb 16 19:24:25 CET 2009


Author: fijal
Date: Mon Feb 16 19:24:23 2009
New Revision: 61953

Modified:
   pypy/branch/pyjitpl5/pypy/jit/metainterp/optimize.py
   pypy/branch/pyjitpl5/pypy/jit/metainterp/test/test_tlc.py
   pypy/branch/pyjitpl5/pypy/jit/metainterp/test/test_vlist.py
Log:
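Tests and fixes: handle lists whose length is not known at trace time (new lists with a non-constant length; append/insert/pop/len no longer assert a known length), compute negative list indices as known_length + index, copy curfields before merging origfields into it, and un-skip/add virtual-list escape tests.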


Modified: pypy/branch/pyjitpl5/pypy/jit/metainterp/optimize.py
==============================================================================
--- pypy/branch/pyjitpl5/pypy/jit/metainterp/optimize.py	(original)
+++ pypy/branch/pyjitpl5/pypy/jit/metainterp/optimize.py	Mon Feb 16 19:24:23 2009
@@ -37,7 +37,7 @@
         self.setitems = []
 
     def deal_with_box(self, box, nodes, liveboxes, memo, ready):
-        if isinstance(box, Const):
+        if isinstance(box, Const) or box not in nodes:
             virtual = False
             virtualized = False
         else:
@@ -48,8 +48,6 @@
                 return memo[box]
             virtual = instnode.virtual
             virtualized = instnode.virtualized
-            if not virtual:
-                assert box in ready
         if virtual:
             if isinstance(instnode.cls.source, ListDescr):
                 ld = instnode.cls.source
@@ -112,14 +110,26 @@
         self.expanded_fields = {}
         self.known_length = -1
 
+    def set_known_length(self, val):
+        if val != -1:
+            print "Setting: %r to %d" % (self, val)
+        assert val >= -1
+        self._kl = val
+
+    def get_known_length(self):
+        return self._kl
+
+    known_length = property(get_known_length, set_known_length)
+
     def escape_if_startbox(self, memo):
         if self in memo:
             return
         memo[self] = None
         if self.startbox:
             self.escaped = True
-        for node in self.curfields.values():
-            node.escape_if_startbox(memo)
+        if not self.virtualized:
+            for node in self.curfields.values():
+                node.escape_if_startbox(memo)
 
     def add_to_dependency_graph(self, other, dep_graph):
         dep_graph.append((self, other))
@@ -149,11 +159,10 @@
                 XXX # XXX think
             fields = []
             if self is other:
-                d = other.curfields
+                d = other.curfields.copy()
                 d.update(self.origfields)
             else:
                 d = other.curfields
-            d = other.curfields
             lst = d.items()
             lst.sort()
             for ofs, node in lst:
@@ -284,7 +293,7 @@
             if ofs >= field:
                 instnode.curfields[ofs + 1] = node
         instnode.curfields[field] = fieldnode
-        instnode.known_length += 1
+        instnode.known_length = instnode.known_length + 1
         self.dependency_graph.append((instnode, fieldnode))
         
     def find_nodes(self):
@@ -308,20 +317,21 @@
                 instnode = InstanceNode(box, escaped=False)
                 self.nodes[box] = instnode
                 self.first_escaping_op = False
-                # XXX we don't support lists with variable length
-                assert isinstance(op.args[1], ConstInt)
-                instnode.known_length = op.args[1].getint()
-                # XXX following guard_builtin will set the
-                #     correct class, otherwise it's a mess
-                continue
+                if (isinstance(op.args[1], ConstInt) or
+                    self.nodes[op.args[1]].const):
+                    instnode.known_length = self.getsource(op.args[1]).getint()
+                    # XXX following guard_builtin will set the
+                    #     correct class, otherwise it's a mess
+                    continue
             elif opname == 'guard_builtin':
                 instnode = self.nodes[op.args[0]]
                 # all builtins have equal classes
                 instnode.cls = InstanceNode(op.args[1])
                 continue
             elif opname == 'guard_len':
-                instnode = self.nodes[op.args[0]]
-                instnode.known_length = op.args[1].getint()
+                if instnode.known_length == -1:
+                    instnode = self.nodes[op.args[0]]
+                    instnode.known_length = op.args[1].getint()
                 continue
             elif opname == 'setfield_gc':
                 instnode = self.getnode(op.args[0])
@@ -346,41 +356,47 @@
                     self.nodes[op.args[2]].const):
                     field = self.getsource(fieldbox).getint()
                     if field < 0:
-                        field = instnode.known_length - field
+                        field = instnode.known_length + field
                     box = op.results[0]
                     self.find_nodes_getfield(instnode, field, box)
+                    print instnode, instnode.curfields, instnode.known_length 
                     continue
                 else:
                     instnode.escaped = True
             elif opname == 'append':
                 instnode = self.getnode(op.args[1])
                 assert isinstance(instnode.cls.source, ListDescr)
-                assert instnode.known_length != -1
-                field = instnode.known_length
-                instnode.known_length += 1
-                self.find_nodes_setfield(instnode, field,
-                                         self.getnode(op.args[2]))
+                if instnode.known_length != -1:
+                    field = instnode.known_length
+                    instnode.known_length = instnode.known_length + 1
+                    self.find_nodes_setfield(instnode, field,
+                                             self.getnode(op.args[2]))
+                    print instnode, instnode.curfields, instnode.known_length
                 continue
             elif opname == 'insert':
                 instnode = self.getnode(op.args[1])
                 assert isinstance(instnode.cls.source, ListDescr)
-                assert instnode.known_length != -1
-                fieldbox = self.getsource(op.args[2])
-                assert isinstance(fieldbox, Const) or fieldbox.const
-                field = fieldbox.getint()
-                if field < 0:
-                    field = instnode.known_length - field
-                self.find_nodes_insert(instnode, field,
-                                       self.getnode(op.args[3]))
+                if instnode.known_length != -1:
+                    fieldbox = self.getsource(op.args[2])
+                    assert isinstance(fieldbox, Const) or fieldbox.const
+                    field = fieldbox.getint()
+                    if field < 0:
+                        field = instnode.known_length + field
+                    self.find_nodes_insert(instnode, field,
+                                           self.getnode(op.args[3]))
+                print instnode, instnode.curfields, instnode.known_length
                 continue
             elif opname == 'pop':
                 instnode = self.getnode(op.args[1])
                 assert isinstance(instnode.cls.source, ListDescr)
-                assert instnode.known_length != -1
-                instnode.known_length -= 1
-                field = instnode.known_length
-                self.find_nodes_getfield(instnode, field, op.results[0])
-                continue
+                if instnode.known_length != -1:
+                    instnode.known_length = instnode.known_length - 1
+                    field = instnode.known_length
+                    self.find_nodes_getfield(instnode, field, op.results[0])
+                    if field in instnode.curfields:
+                        del instnode.curfields[field]                
+                    print instnode, instnode.curfields, instnode.known_length
+                    continue
             elif opname == 'len' or opname == 'listnonzero':
                 instnode = self.getnode(op.args[1])
                 if not instnode.escaped:
@@ -395,10 +411,11 @@
                     or self.nodes[op.args[2]].const):
                     field = self.getsource(fieldbox).getint()
                     if field < 0:
-                        field = instnode.known_length - field
+                        field = instnode.known_length + field
                     assert field < instnode.known_length
                     self.find_nodes_setfield(instnode, field,
                                              self.getnode(op.args[3]))
+                    print instnode, instnode.curfields, instnode.known_length
                     continue
                 else:
                     instnode.escaped = True
@@ -544,7 +561,7 @@
     def optimize_getfield(self, instnode, ofs, box):
         if instnode.virtual or instnode.virtualized:
             if ofs < 0:
-                ofs = instnode.known_length - ofs
+                ofs = instnode.known_length + ofs
             assert ofs in instnode.curfields
             return True # this means field is never actually
         elif ofs in instnode.cleanfields:
@@ -556,6 +573,8 @@
 
     def optimize_setfield(self, instnode, ofs, valuenode, valuebox):
         if instnode.virtual or instnode.virtualized:
+            if ofs < 0:
+                ofs = instnode.known_length + ofs
             instnode.curfields[ofs] = valuenode
         else:
             assert not valuenode.virtual
@@ -569,12 +588,13 @@
             if ofs >= field:
                 instnode.curfields[ofs + 1] = node
         instnode.curfields[field] = valuenode
-        instnode.known_length += 1
+        instnode.known_length = instnode.known_length + 1
 
     def optimize_loop(self):
         self.ready_results = {}
         newoperations = []
         exception_might_have_happened = False
+        ops_so_far = []
         mp = self.loop.operations[0]
         if mp.opname == 'merge_point':
             assert len(mp.args) == len(self.specnodes)
@@ -588,6 +608,8 @@
                 assert not self.nodes[box].virtual
 
         for op in self.loop.operations:
+            ops_so_far.append(op)
+
             if newoperations and newoperations[-1].results:
                 self.ready_results[newoperations[-1].results[0]] = None
             opname = op.opname
@@ -623,13 +645,13 @@
             elif opname == 'guard_len':
                 # it should be completely gone, because if it escapes
                 # we don't virtualize it anymore
-                if not instnode.escaped:
+                if not instnode.escaped and instnode.known_length == -1:
                     instnode = self.nodes[op.args[0]]
                     instnode.known_length = op.args[1].getint()
                 continue
             elif opname == 'guard_nonvirtualized':
                 instnode = self.nodes[op.args[0]]
-                if instnode.virtualized:
+                if instnode.virtualized or instnode.virtual:
                     continue
                 op = self.optimize_guard(op)
                 newoperations.append(op)
@@ -680,40 +702,40 @@
                 if not instnode.escaped:
                     instnode.virtual = True
                     valuesource = self.getsource(op.args[2])
-                    assert isinstance(valuesource, Const)
-                    maxlength = max(instnode.curfields.keys() +
-                                    instnode.origfields.keys())
-                    for i in range(maxlength + 1):
-                        instnode.curfields[i] = InstanceNode(valuesource,
-                                                             const=True)
-                    for ofs, item in instnode.origfields.items():
-                        self.nodes[item.source] = instnode.curfields[ofs]
                     instnode.known_length = op.args[1].getint()
+                    curfields = {}
+                    for i in range(instnode.known_length):
+                        curfields[i] = InstanceNode(valuesource,
+                                                    const=True)
+                    instnode.curfields = curfields
                     continue
             elif opname == 'append':
                 instnode = self.nodes[op.args[1]]
-                valuenode = self.nodes[op.args[2]]
-                ofs = instnode.known_length
-                instnode.known_length += 1
-                assert ofs != -1
-                self.optimize_setfield(instnode, ofs, valuenode, op.args[2])
-                continue
+                valuenode = self.getnode(op.args[2])
+                if instnode.virtual:
+                    ofs = instnode.known_length
+                    instnode.known_length = instnode.known_length + 1
+                    self.optimize_setfield(instnode, ofs, valuenode, op.args[2])
+                    continue
             elif opname == 'insert':
                 instnode = self.nodes[op.args[1]]
-                ofs = self.getsource(op.args[2]).getint()
-                valuenode = self.nodes[op.args[3]]
-                self.optimize_insert(instnode, ofs, valuenode, op.args[3])
-                continue
+                if instnode.virtual:
+                    ofs = self.getsource(op.args[2]).getint()
+                    valuenode = self.nodes[op.args[3]]
+                    self.optimize_insert(instnode, ofs, valuenode, op.args[3])
+                    continue
             elif opname == 'pop':
                 instnode = self.nodes[op.args[1]]
-                instnode.known_length -= 1
-                ofs = instnode.known_length
-                if self.optimize_getfield(instnode, ofs, op.results[0]):
+                if instnode.virtual:
+                    instnode.known_length = instnode.known_length - 1
+                    ofs = instnode.known_length
+                    if self.optimize_getfield(instnode, ofs, op.results[0]):
+                        del instnode.curfields[ofs]
                     continue
             elif opname == 'len' or opname == 'listnonzero':
                 instnode = self.nodes[op.args[1]]
-                assert instnode.known_length
-                continue
+                if instnode.virtual:
+                    continue
             elif opname == 'setfield_gc':
                 instnode = self.nodes[op.args[0]]
                 valuenode = self.nodes[op.args[2]]

Modified: pypy/branch/pyjitpl5/pypy/jit/metainterp/test/test_tlc.py
==============================================================================
--- pypy/branch/pyjitpl5/pypy/jit/metainterp/test/test_tlc.py	(original)
+++ pypy/branch/pyjitpl5/pypy/jit/metainterp/test/test_tlc.py	Mon Feb 16 19:24:23 2009
@@ -44,7 +44,7 @@
         assert res == 42
 
     def test_accumulator(self):
-        py.test.skip("X")
+        py.test.skip("x")
         path = py.path.local(tlc.__file__).dirpath('accumulator.tlc.src')
         code = path.read()
         res = self.exec_code(code, 20)

Modified: pypy/branch/pyjitpl5/pypy/jit/metainterp/test/test_vlist.py
==============================================================================
--- pypy/branch/pyjitpl5/pypy/jit/metainterp/test/test_vlist.py	(original)
+++ pypy/branch/pyjitpl5/pypy/jit/metainterp/test/test_vlist.py	Mon Feb 16 19:24:23 2009
@@ -69,7 +69,8 @@
                 jitdriver.can_enter_jit(n=n)
                 jitdriver.jit_merge_point(n=n)
                 l = [0] * 20
-                x = l[3]
+                l[3] = 5
+                x = l[-17]
                 if n < 3:
                     return x
                 n -= 1
@@ -131,37 +132,62 @@
         self.check_all_virtualized()
         self.check_loops(listnonzero=0, guard_true=1, guard_false=0)
 
+    def test_append_pop_rebuild(self):
+        jitdriver = JitDriver(greens = [], reds = ['n'])
+        def f(n):
+            while n > 0:
+                jitdriver.can_enter_jit(n=n)
+                jitdriver.jit_merge_point(n=n)
+                lst = []
+                lst.append(5)
+                lst.append(n)
+                lst[0] -= len(lst)
+                three = lst[0]
+                n = lst.pop() - three
+                if n == 2:
+                    return n + lst.pop()
+            return n
+        res = self.meta_interp(f, [31])
+        assert res == -2
+        self.check_all_virtualized()
+
     def test_list_escapes(self):
-        py.test.skip("XXX")
+        jitdriver = JitDriver(greens = [], reds = ['n'])
         def f(n):
             while True:
+                jitdriver.can_enter_jit(n=n)
+                jitdriver.jit_merge_point(n=n)
                 lst = []
                 lst.append(n)
                 n = lst.pop() - 3
                 if n < 0:
                     return len(lst)
-        res = self.meta_interp(f, [31], exceptions=False)
+        res = self.meta_interp(f, [31])
         assert res == 0
         self.check_all_virtualized()
 
     def test_list_reenters(self):
-        py.test.skip("XXX")
+        jitdriver = JitDriver(greens = [], reds = ['n'])
         def f(n):
             while n > 0:
+                jitdriver.can_enter_jit(n=n)
+                jitdriver.jit_merge_point(n=n)
                 lst = []
                 lst.append(n)
                 if n < 10:
                     lst[-1] = n-1
                 n = lst.pop() - 3
             return n
-        res = self.meta_interp(f, [31], exceptions=False)
+        res = self.meta_interp(f, [31])
         assert res == -1
         self.check_all_virtualized()
 
     def test_cannot_merge(self):
-        py.test.skip("XXX")
+        jitdriver = JitDriver(greens = [], reds = ['n'])
         def f(n):
             while n > 0:
+                jitdriver.can_enter_jit(n=n)
+                jitdriver.jit_merge_point(n=n)
                 lst = []
                 if n < 20:
                     lst.append(n-3)
@@ -169,10 +195,92 @@
                     lst.append(n-4)
                 n = lst.pop()
             return n
-        res = self.meta_interp(f, [30], exceptions=False)
+        res = self.meta_interp(f, [30])
         assert res == -1
         self.check_all_virtualized()
 
+    def test_list_escapes(self):
+        jitdriver = JitDriver(greens = [], reds = ['n'])
+        def g(l):
+            pass
+        
+        def f(n):
+            while n > 0:
+                jitdriver.can_enter_jit(n=n)
+                jitdriver.jit_merge_point(n=n)
+                l = []
+                l.append(3)
+                g(l)
+                n -= 1
+            return n
+        res = self.meta_interp(f, [30], policy=StopAtXPolicy(g))
+        assert res == 0
+        self.check_loops(append=1)
+
+    def test_list_escapes_various_ops(self):
+        jitdriver = JitDriver(greens = [], reds = ['n'])
+        def g(l):
+            pass
+        
+        def f(n):
+            while n > 0:
+                jitdriver.can_enter_jit(n=n)
+                jitdriver.jit_merge_point(n=n)
+                l = []
+                l.append(3)
+                l.append(1)
+                n -= l.pop()
+                n -= l[0]
+                if l:
+                    g(l)
+                n -= 1
+            return n
+        res = self.meta_interp(f, [30], policy=StopAtXPolicy(g))
+        assert res == 0
+        self.check_loops(append=2)
+
+    def test_list_escapes_find_nodes(self):
+        jitdriver = JitDriver(greens = [], reds = ['n'])
+        def g(l):
+            pass
+        
+        def f(n):
+            while n > 0:
+                jitdriver.can_enter_jit(n=n)
+                jitdriver.jit_merge_point(n=n)
+                l = [0] * n
+                l.append(3)
+                l.append(1)
+                n -= l.pop()
+                n -= l[-1]
+                if l:
+                    g(l)
+                n -= 1
+            return n
+        res = self.meta_interp(f, [30], policy=StopAtXPolicy(g))
+        assert res == 0
+        self.check_loops(append=2)
+
+    def test_stuff_escapes_via_setitem(self):
+        jitdriver = JitDriver(greens = [], reds = ['n', 'l'])
+        class Stuff(object):
+            def __init__(self, x):
+                self.x = x
+        
+        def f(n):
+            l = [None]
+            while n > 0:
+                jitdriver.can_enter_jit(n=n, l=l)
+                jitdriver.jit_merge_point(n=n, l=l)
+                s = Stuff(3)
+                l.append(s)
+                n -= l[0].x
+            return n
+        res = self.meta_interp(f, [30])
+        assert res == 0
+        self.check_loops(append=1)
+        
+
     def test_extend(self):
         py.test.skip("XXX")
         def f(n):



More information about the Pypy-commit mailing list