[pypy-svn] r52802 - pypy/branch/jit-hotpath/pypy/jit/rainbow/test

arigo at codespeak.net arigo at codespeak.net
Fri Mar 21 15:06:09 CET 2008


Author: arigo
Date: Fri Mar 21 15:06:08 2008
New Revision: 52802

Modified:
   pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hp_virtualizable.py
Log:
Phew. It took me a while to make
test_simple_interpreter_with_frame_with_stack
fail in a way that looks similar to pypy-c-jit's failure.


Modified: pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hp_virtualizable.py
==============================================================================
--- pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hp_virtualizable.py	(original)
+++ pypy/branch/jit-hotpath/pypy/jit/rainbow/test/test_hp_virtualizable.py	Fri Mar 21 15:06:08 2008
@@ -118,9 +118,6 @@
 
 
 class TestVirtualizableExplicit(test_hotpath.HotPathTest):
-    def timeshift_from_portal(self, *args, **kwargs):
-        py.test.skip("port me")
-
     type_system = "lltype"
 
     def test_simple(self):
@@ -1073,6 +1070,9 @@
         py.test.skip("port me")
 
     def test_simple(self):
+        class MyJitDriver(JitDriver):
+            greens = []
+            reds = ['xy', 'i', 'res']
 
         class XY(object):
             _virtualizable_ = True
@@ -1082,17 +1082,28 @@
                 self.y = y
    
         def f(xy):
-            return xy.x+xy.y
+            i = 1024
+            while i > 0:
+                i >>= 1
+                res = xy.x+xy.y
+                MyJitDriver.jit_merge_point(xy=xy, res=res, i=i)
+                MyJitDriver.can_enter_jit(xy=xy, res=res, i=i)
+            return res
 
         def main(x, y):
             xy = XY(x, y)
             return f(xy)
 
-        res = self.timeshift_from_portal(main, f, [20, 22], policy=P_OOPSPEC)
+        res = self.run(main, [20, 22], threshold=2)
         assert res == 42
-        self.check_insns(getfield=0)
+        self.check_insns_in_loops(getfield=0)
 
     def test_simple__class__(self):
+        py.test.skip("in-progress")
+        class MyJitDriver(JitDriver):
+            greens = []
+            reds = ['v', 'i', 'res']
+
         class V(object):
             _virtualizable_ = True
             def __init__(self, a):
@@ -1108,9 +1119,13 @@
                 V.__init__(self, 2)
 
         def f(v):
-            hint(None, global_merge_point=True)
-            #V1(0).b
-            return v.__class__
+            i = 1024
+            while i > 0:
+                i >>= 1
+                res = v.__class__
+                MyJitDriver.jit_merge_point(v=v, res=res, i=i)
+                MyJitDriver.can_enter_jit(v=v, res=res, i=i)
+            return res
 
         def main(x, y):
             if x:
@@ -1124,12 +1139,17 @@
             V2()
             return c is not None
 
-        res = self.timeshift_from_portal(main, f, [0, 1], policy=P_OOPSPEC)
+        res = self.run(main, [0, 1], threshold=2)
         assert not res
-        res = self.timeshift_from_portal(main, f, [1, 0], policy=P_OOPSPEC)
+        res = self.run(main, [1, 0], threshold=2)
+        assert res
+        res = self.run(main, [1, 0], threshold=1)
         assert res
 
     def test_simple_inheritance(self):
+        class MyJitDriver(JitDriver):
+            greens = []
+            reds = ['xy', 'i', 'res']
 
         class X(object):
             _virtualizable_ = True
@@ -1144,18 +1164,32 @@
                 self.y = y
    
         def f(xy):
-            return xy.x+xy.y
+            i = 1024
+            while i > 0:
+                i >>= 1
+                res = xy.x+xy.y
+                MyJitDriver.jit_merge_point(xy=xy, res=res, i=i)
+                MyJitDriver.can_enter_jit(xy=xy, res=res, i=i)
+            return res
 
         def main(x, y):
             X(0)
             xy = XY(x, y)
             return f(xy)
 
-        res = self.timeshift_from_portal(main, f, [20, 22], policy=P_OOPSPEC)
+        res = self.run(main, [20, 22], threshold=2)
         assert res == 42
-        self.check_insns(getfield=0)
+        self.check_insns_in_loops(getfield=0)
+
+        res = self.run(main, [20, 22], threshold=1)
+        assert res == 42
+        self.check_insns_in_loops(getfield=0)
 
     def test_simple_interpreter_with_frame(self):
+        class MyJitDriver(JitDriver):
+            greens = ['pc', 'n', 's']
+            reds = ['frame']
+
         class Log:
             acc = 0
         log = Log()
@@ -1169,42 +1203,53 @@
 
             def run(self):
                 self.plus_minus(self.code)
+                assert self.pc == len(self.code)
                 return self.acc
 
             def plus_minus(self, s):
                 n = len(s)
                 pc = 0
-                while pc < n:
-                    hint(None, global_merge_point=True)
+                while True:
+                    MyJitDriver.jit_merge_point(frame=self, pc=pc, n=n, s=s)
                     self.pc = pc
+                    if hint(pc >= n, concrete=True):
+                        break
                     op = s[pc]
                     op = hint(op, concrete=True)
+                    pc += 1
                     if op == '+':
                         self.acc += self.y
                     elif op == '-':
                         self.acc -= self.y
+                    elif op == 'r':
+                        if self.acc > 0:
+                            pc -= 3
+                            assert pc >= 0
+                            MyJitDriver.can_enter_jit(frame=self, pc=pc,
+                                                      n=n, s=s)
                     elif op == 'd':
                         self.debug()
-                    pc += 1
                 return 0
 
             def debug(self):
                 log.acc = self.acc
             
-        def main(x, y):
-            code = '+d+-+'
+        def main(x, y, case):
+            code = ['+d+-+++++ -r++', '+++++++++d-r+++'][case]
             f = Frame(code, x, y)
             return f.run() * 10 + log.acc
 
-        res = self.timeshift_from_portal(main, Frame.plus_minus.im_func,
-                            [0, 2],
-                            policy=StopAtXPolicy(Frame.debug.im_func))
+        assert main(0, 2, 0) == 42
+        assert main(0, 2, 1) == 62
+
+        res = self.run(main, [0, 2, 0], threshold=2,
+                       policy=StopAtXPolicy(Frame.debug.im_func))
         assert res == 42
-        if self.on_llgraph:
-            calls = self.count_direct_calls()
-            call_count = sum(calls.values())
-            # one call to "continue_compilation" and one call to debug
-            assert call_count == 2
+        self.check_insns_in_loops({'int_sub': 1, 'int_gt': 1})
+
+        res = self.run(main, [0, 2, 1], threshold=2,
+                       policy=StopAtXPolicy(Frame.debug.im_func))
+        assert res == 62
 
 
     def test_setting_pointer_in_residual_call(self):
@@ -1465,6 +1510,10 @@
 
         
     def test_virtual_list(self):
+        class MyJitDriver(JitDriver):
+            greens = []
+            reds = ['v', 'i', 'res']
+
         class V(object):
             _virtualizable_ = True
             def __init__(self, l):
@@ -1479,21 +1528,32 @@
             v.l = [x*100, y*100]
             
         def f(v):
-            hint(None, global_merge_point=True)
-            l = [1, 10]
-            v.l = l
-            g(v)
-            l2 = v.l
-            return l[0]*2 + l[1] + l2[0] * 2 + l2[1]
+            i = 1024
+            while i > 0:
+                i >>= 1
+                l = [1, 10]
+                v.l = l
+                g(v)
+                l2 = v.l
+                res = l[0]*2 + l[1] + l2[0] * 2 + l2[1]
+                MyJitDriver.jit_merge_point(v=v, res=res, i=i)
+                MyJitDriver.can_enter_jit(v=v, res=res, i=i)
+            return res
 
         def main():
             v = V(None)
             return f(v)
 
-        res = self.timeshift_from_portal(main, f, [], policy=StopAtXPolicy(g))
-        assert res == 20 + 1 + 200 + 1000
+        res = self.run(main, [], threshold=2, policy=StopAtXPolicy(g))
+        assert res == main()
+        res = self.run(main, [], threshold=1, policy=StopAtXPolicy(g))
+        assert res == main()
 
     def test_virtual_list_and_struct(self):
+        class MyJitDriver(JitDriver):
+            greens = []
+            reds = ['v', 'i', 'res']
+
         class S(object):
             def __init__(self, x, y):
                 self.x = x
@@ -1513,24 +1573,47 @@
             v.l = [x*100, y*100]
             
         def f(v):
-            hint(None, global_merge_point=True)
-            l = [1, 10]
-            s = S(3, 7)
-            v.l = l
-            v.s = s
-            g(v)
-            l2 = v.l
-            s2 = v.s
-            return l[0]*2 + l[1] + l2[0] * 2 + l2[1] + s.x * 7 + s.y + s2.x * 7 + s2.y 
+            i = 1024
+            while i > 0:
+                i >>= 1
+                l = [1, 10]
+                s = S(3, 7)
+                v.l = l
+                v.s = s
+                g(v)
+                l2 = v.l
+                s2 = v.s
+                res = l[0]*2 + l[1] + l2[0] * 2 + l2[1] + s.x * 7 + s.y + s2.x * 7 + s2.y
+                MyJitDriver.jit_merge_point(v=v, res=res, i=i)
+                MyJitDriver.can_enter_jit(v=v, res=res, i=i)
+            return res
 
         def main():
             v = V(None, None)
             return f(v)
 
-        res = self.timeshift_from_portal(main, f, [], policy=StopAtXPolicy(g))
+        res = self.run(main, [], threshold=2, policy=StopAtXPolicy(g))
+        assert res == main()
+        res = self.run(main, [], threshold=1, policy=StopAtXPolicy(g))
         assert res == main()
 
     def test_simple_interpreter_with_frame_with_stack(self):
+        class MyJitDriver(JitDriver):
+            greens = ['pc', 's']
+            reds = ['frame']
+
+            def on_enter_jit(self):
+                frame = self.frame
+                origstack = frame.stack
+                stacklen = hint(len(origstack), promote=True)
+                curstack = []
+                i = 0
+                while i < stacklen:
+                    hint(i, concrete=True)
+                    curstack.append(origstack[i])
+                    i += 1
+                frame.stack = curstack
+
         class Log:
             stack = None
         log = Log()
@@ -1542,25 +1625,19 @@
                 self.stack = list(args)
                 
             def run(self):
-                return self.interpret(self.code)
+                self.trace = 0
+                self.interpret(self.code)
+                assert self.pc == len(self.code)
+                assert len(self.stack) == 1
+                return self.stack.pop()
 
             def interpret(self, s):
-                hint(None, global_merge_point=True)
-                n = len(s)
                 pc = 0
-                origstack = self.stack
-                stacklen = len(origstack)
-                stacklen = hint(stacklen, promote=True)
-                curstack = [0] * stacklen
-                i = 0
-                while i < stacklen:
-                    hint(i, concrete=True)
-                    curstack[i] = origstack[i]
-                    i += 1
-                self.stack = curstack
-                while pc < n:
-                    hint(None, global_merge_point=True)
+                while True:
+                    MyJitDriver.jit_merge_point(frame=self, s=s, pc=pc)
                     self.pc = pc
+                    if hint(pc >= len(s), concrete=True):
+                        break
                     op = s[pc]
                     pc += 1
                     op = hint(op, concrete=True)
@@ -1569,6 +1646,8 @@
                         pc += 1
                         hint(arg, concrete=True)
                         self.stack.append(ord(arg) - ord('0')) 
+                    elif op == 'D':
+                        self.stack.append(self.stack[-1])    # dup
                     elif op == 'p':
                         self.stack.pop()
                     elif op == '+':
@@ -1577,27 +1656,52 @@
                     elif op == '-':
                         arg = self.stack.pop()
                         self.stack[-1] -= arg
+                    elif op == 'J':
+                        target = self.stack.pop()
+                        cond = self.stack.pop()
+                        if cond > 0:
+                            pc = hint(target, promote=True)
+                            MyJitDriver.can_enter_jit(frame=self, s=s, pc=pc)
+                    elif op == 't':
+                        self.trace = self.trace * 3 + self.stack[-1]
                     elif op == 'd':
                         self.debug()
                     else:
                         raise NotImplementedError
-                result = self.stack.pop()
-                self.stack = None
-                return result
 
             def debug(self):
-                log.stack = self.stack[:]
+                for item in self.stack:
+                    self.trace = self.trace * 7 - item
             
-        def main(x):
-            code = 'P2+P5+P3-'
+        def main(x, case, expected):
+            code = ['P2+tP5+tP3-', 'P1+tP3-DP3J', 'P4d-DtP0J'][case]
+            log.stack = []
             f = Frame(code, x)
-            return f.run()
+            res = f.run()
+            assert res == expected
+            return f.trace
+
+        assert main(38, 0, 42) == 40*3+45
+        assert main(15, 1, -2) == ((((16*3+13)*3+10)*3+7)*3+4)*3+1
+        main(21, 2, -3)   # to check that this works too
+
+        res = self.run(main, [38, 0, 42], threshold=2,
+                       policy=StopAtXPolicy(Frame.debug.im_func))
+        assert res == 40*3+45
+        self.check_nothing_compiled_at_all()
+
+        res = self.run(main, [15, 1, -2], threshold=2,
+                       policy=StopAtXPolicy(Frame.debug.im_func))
+        assert res == ((((16*3+13)*3+10)*3+7)*3+4)*3+1
+        self.check_insns_in_loops({'int_sub': 1, 'int_gt': 1,
+                                   'int_mul': 1, 'int_add': 1})
 
-        res = self.timeshift_from_portal(main, Frame.interpret.im_func,
-                            [38],
-                            policy=StopAtXPolicy(Frame.debug.im_func))
-        assert res == 42
-        self.check_oops(newlist=0)
+        py.test.skip("in-progress")
+        res = self.run(main, [21, 2, -3], threshold=2,
+                       policy=StopAtXPolicy(Frame.debug.im_func))
+        assert res == main(21, 2, -3)
+        self.check_insns_in_loops({'int_sub': 1, 'int_gt': 1,
+                                   'int_mul': 1, 'int_add': 1})
 
 
     def test_recursive(self):



More information about the Pypy-commit mailing list