[pypy-svn] pypy default: shift optimizations reenabled, now with proper overflow checking

hakanardo commits-noreply at bitbucket.org
Sat Apr 23 11:55:10 CEST 2011


Author: Hakan Ardo <hakan at debian.org>
Branch: 
Changeset: r43539:065cd49333e7
Date: 2011-04-23 11:54 +0200
http://bitbucket.org/pypy/pypy/changeset/065cd49333e7/

Log:	shift optimizations reenabled, now with proper overflow checking

diff --git a/pypy/jit/metainterp/test/test_optimizeopt.py b/pypy/jit/metainterp/test/test_optimizeopt.py
--- a/pypy/jit/metainterp/test/test_optimizeopt.py
+++ b/pypy/jit/metainterp/test/test_optimizeopt.py
@@ -4398,6 +4398,8 @@
         i4 = int_rshift(i3, i2)
         i5 = int_lshift(i1, 2)
         i6 = int_rshift(i5, 2)
+        i6t= int_eq(i6, i1)
+        guard_true(i6t) []
         i7 = int_lshift(i1, 100)
         i8 = int_rshift(i7, 100)
         i9 = int_lt(i1b, 100)
@@ -4422,6 +4424,8 @@
         i4 = int_rshift(i3, i2)
         i5 = int_lshift(i1, 2)
         i6 = int_rshift(i5, 2)
+        i6t= int_eq(i6, i1)
+        guard_true(i6t) []
         i7 = int_lshift(i1, 100)
         i8 = int_rshift(i7, 100)
         i9 = int_lt(i1b, 100)
@@ -4431,11 +4435,8 @@
         i13 = int_lshift(i1b, i2)
         i14 = int_rshift(i13, i2)
         i15 = int_lshift(i1b, 2)
-        i16 = int_rshift(i15, 2)
         i17 = int_lshift(i1b, 100)
         i18 = int_rshift(i17, 100)
-        i19 = int_eq(i1b, i16)
-        guard_true(i19) []
         jump(i2, i3, i1b, i2b)
         """
         self.optimize_loop(ops, expected)

diff --git a/pypy/jit/metainterp/optimizeopt/intbounds.py b/pypy/jit/metainterp/optimizeopt/intbounds.py
--- a/pypy/jit/metainterp/optimizeopt/intbounds.py
+++ b/pypy/jit/metainterp/optimizeopt/intbounds.py
@@ -130,12 +130,12 @@
         r = self.getvalue(op.result)
         b = v1.intbound.lshift_bound(v2.intbound)
         r.intbound.intersect(b)
-        # --- The following is actually wrong if the INT_LSHIFT overflowed.
-        # --- It is precisely the pattern we use to detect overflows of the
-        # --- app-level '<<' operator: INT_LSHIFT/INT_RSHIFT/INT_EQ
-        #if b.has_lower and b.has_upper:
-        #    # Synthesize the reverse op for optimize_default to reuse
-        #    self.pure(rop.INT_RSHIFT, [op.result, op.getarg(1)], op.getarg(0))
+        # lshift_bound() checks for overflow: only when the lshift can be
+        # proven not to overflow does it return a bound with both
+        # has_lower and has_upper set, so the reverse op is safe to add.
+        if b.has_lower and b.has_upper:
+            # Synthesize the reverse op for optimize_default to reuse
+            self.pure(rop.INT_RSHIFT, [op.result, op.getarg(1)], op.getarg(0))
 
     def optimize_INT_RSHIFT(self, op):
         v1 = self.getvalue(op.getarg(0))

diff --git a/pypy/jit/metainterp/optimizeopt/intutils.py b/pypy/jit/metainterp/optimizeopt/intutils.py
--- a/pypy/jit/metainterp/optimizeopt/intutils.py
+++ b/pypy/jit/metainterp/optimizeopt/intutils.py
@@ -1,4 +1,4 @@
-from pypy.rlib.rarithmetic import ovfcheck
+from pypy.rlib.rarithmetic import ovfcheck, ovfcheck_lshift
 
 class IntBound(object):
     _attrs_ = ('has_upper', 'has_lower', 'upper', 'lower')
@@ -163,12 +163,12 @@
            other.has_upper and other.has_lower and \
            other.known_ge(IntBound(0, 0)):
             try:
-                vals = (ovfcheck(self.upper * pow2(other.upper)),
-                        ovfcheck(self.upper * pow2(other.lower)),
-                        ovfcheck(self.lower * pow2(other.upper)),
-                        ovfcheck(self.lower * pow2(other.lower)))
+                vals = (ovfcheck_lshift(self.upper, other.upper),
+                        ovfcheck_lshift(self.upper, other.lower),
+                        ovfcheck_lshift(self.lower, other.upper),
+                        ovfcheck_lshift(self.lower, other.lower))
                 return IntBound(min4(vals), max4(vals))
-            except OverflowError:
+            except (OverflowError, ValueError):
                 return IntUnbounded()
         else:
             return IntUnbounded()
@@ -177,14 +177,11 @@
         if self.has_upper and self.has_lower and \
            other.has_upper and other.has_lower and \
            other.known_ge(IntBound(0, 0)):
-            try:
-                vals = (ovfcheck(self.upper / pow2(other.upper)),
-                        ovfcheck(self.upper / pow2(other.lower)),
-                        ovfcheck(self.lower / pow2(other.upper)),
-                        ovfcheck(self.lower / pow2(other.lower)))
-                return IntBound(min4(vals), max4(vals))
-            except OverflowError:
-                return IntUnbounded()
+            vals = (self.upper >> other.upper,
+                    self.upper >> other.lower,
+                    self.lower >> other.upper,
+                    self.lower >> other.lower)
+            return IntBound(min4(vals), max4(vals))
         else:
             return IntUnbounded()
 
@@ -252,11 +249,3 @@
 
 def max4(t):
     return max(max(t[0], t[1]), max(t[2], t[3]))
-
-def pow2(x):
-    y = 1 << x
-    if y < 1:
-        raise OverflowError, "pow2 did overflow"
-    return y
-
-        

diff --git a/pypy/jit/metainterp/test/test_intbound.py b/pypy/jit/metainterp/test/test_intbound.py
--- a/pypy/jit/metainterp/test/test_intbound.py
+++ b/pypy/jit/metainterp/test/test_intbound.py
@@ -1,6 +1,7 @@
 from pypy.jit.metainterp.optimizeopt.intutils import IntBound, IntUpperBound, \
      IntLowerBound, IntUnbounded
 from copy import copy
+import sys
 
 def bound(a,b):
     if a is None and b is None:
@@ -221,6 +222,14 @@
                         assert bleft.contains(n1 << n2)
                         assert bright.contains(n1 >> n2)
 
+def test_shift_overflow():
+    b10 = IntBound(0, 10)
+    b100 = IntBound(0, 100)
+    bmax = IntBound(0, sys.maxint/2)
+    assert not b10.lshift_bound(b100).has_upper
+    assert not bmax.lshift_bound(b10).has_upper
+    assert b10.lshift_bound(b10).has_upper
+    
 def test_div_bound():
     for _, _, b1 in some_bounds():
         for _, _, b2 in some_bounds():

diff --git a/pypy/jit/metainterp/test/test_ajit.py b/pypy/jit/metainterp/test/test_ajit.py
--- a/pypy/jit/metainterp/test/test_ajit.py
+++ b/pypy/jit/metainterp/test/test_ajit.py
@@ -1984,6 +1984,122 @@
         assert res == 12
         self.check_tree_loop_count(2)
 
+    def test_overflowing_shift_pos(self):
+        myjitdriver = JitDriver(greens = [], reds = ['a', 'b', 'n', 'sa'])
+        def f1(a, b):
+            n = sa = 0
+            while n < 10:
+                myjitdriver.jit_merge_point(a=a, b=b, n=n, sa=sa)
+                if 0 < a < 10: pass
+                if 0 < b < 10: pass
+                sa += (a << b) >> b
+                n += 1
+            return sa
+
+        def f2(a, b):
+            n = sa = 0
+            while n < 10:
+                myjitdriver.jit_merge_point(a=a, b=b, n=n, sa=sa)
+                if 0 < a < hint(sys.maxint/2, promote=True): pass
+                if 0 < b < 100: pass
+                sa += (a << b) >> b
+                n += 1
+            return sa
+        
+        assert self.meta_interp(f1, [5, 5]) == 50
+        self.check_loops(int_rshift=0, everywhere=True)
+
+        for f in (f1, f2):
+            assert self.meta_interp(f, [5, 10]) == 50
+            self.check_loops(int_rshift=1, everywhere=True)
+
+            assert self.meta_interp(f, [10, 5]) == 100
+            self.check_loops(int_rshift=1, everywhere=True)
+
+            assert self.meta_interp(f, [10, 10]) == 100
+            self.check_loops(int_rshift=1, everywhere=True)
+
+            assert self.meta_interp(f, [5, 100]) == 0
+            self.check_loops(int_rshift=1, everywhere=True)
+
+            bigval = 1
+            while (bigval << 3).__class__ is int:
+                bigval = bigval << 1
+
+            assert self.meta_interp(f, [bigval, 5]) == 0
+            self.check_loops(int_rshift=1, everywhere=True)
+
+    def test_overflowing_shift_neg(self):
+        myjitdriver = JitDriver(greens = [], reds = ['a', 'b', 'n', 'sa'])
+        def f1(a, b):
+            n = sa = 0
+            while n < 10:
+                myjitdriver.jit_merge_point(a=a, b=b, n=n, sa=sa)
+                if -10 < a < 0: pass
+                if 0 < b < 10: pass
+                sa += (a << b) >> b
+                n += 1
+            return sa
+
+        def f2(a, b):
+            n = sa = 0
+            while n < 10:
+                myjitdriver.jit_merge_point(a=a, b=b, n=n, sa=sa)
+                if -hint(sys.maxint/2, promote=True) < a < 0: pass
+                if 0 < b < 100: pass
+                sa += (a << b) >> b
+                n += 1
+            return sa
+        
+        assert self.meta_interp(f1, [-5, 5]) == -50
+        self.check_loops(int_rshift=0, everywhere=True)
+
+        for f in (f1, f2):
+            assert self.meta_interp(f, [-5, 10]) == -50
+            self.check_loops(int_rshift=1, everywhere=True)
+
+            assert self.meta_interp(f, [-10, 5]) == -100
+            self.check_loops(int_rshift=1, everywhere=True)
+
+            assert self.meta_interp(f, [-10, 10]) == -100
+            self.check_loops(int_rshift=1, everywhere=True)
+
+            assert self.meta_interp(f, [-5, 100]) == 0
+            self.check_loops(int_rshift=1, everywhere=True)
+
+            bigval = 1
+            while (bigval << 3).__class__ is int:
+                bigval = bigval << 1
+
+            assert self.meta_interp(f, [-bigval, 5]) == 0
+            self.check_loops(int_rshift=1, everywhere=True)
+
+    def notest_overflowing_shift2(self):
+        myjitdriver = JitDriver(greens = [], reds = ['a', 'b', 'n', 'sa'])
+        def f(a, b):
+            n = sa = 0
+            while n < 10:
+                myjitdriver.jit_merge_point(a=a, b=b, n=n, sa=sa)
+                if 0 < a < hint(sys.maxint/2, promote=True): pass
+                if 0 < b < 100: pass
+                sa += (a << b) >> b
+                n += 1
+            return sa
+
+        assert self.meta_interp(f, [5, 5]) == 50
+        self.check_loops(int_rshift=0, everywhere=True)
+
+        assert self.meta_interp(f, [5, 10]) == 50
+        self.check_loops(int_rshift=1, everywhere=True)
+
+        assert self.meta_interp(f, [10, 5]) == 100
+        self.check_loops(int_rshift=1, everywhere=True)
+
+        assert self.meta_interp(f, [10, 10]) == 100
+        self.check_loops(int_rshift=1, everywhere=True)
+
+        assert self.meta_interp(f, [5, 100]) == 0
+        self.check_loops(int_rshift=1, everywhere=True)
 
 
 class TestOOtype(BasicTests, OOJitMixin):


More information about the Pypy-commit mailing list