[pypy-svn] pypy default: Turn INT_MUL with a constant argument that is a power of 2 into an lshift.

alex_gaynor commits-noreply at bitbucket.org
Tue Jan 11 08:44:21 CET 2011


Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch: 
Changeset: r40572:76034ab5f03a
Date: 2011-01-11 01:44 -0600
http://bitbucket.org/pypy/pypy/changeset/76034ab5f03a/

Log:	Turn INT_MUL with a constant argument that is a power of 2 into an
	lshift.

diff --git a/pypy/jit/metainterp/test/test_optimizeopt.py b/pypy/jit/metainterp/test/test_optimizeopt.py
--- a/pypy/jit/metainterp/test/test_optimizeopt.py
+++ b/pypy/jit/metainterp/test/test_optimizeopt.py
@@ -4074,6 +4074,29 @@
         """
         self.optimize_loop(ops, expected)
 
+    def test_mul_to_lshift(self):
+        # int_mul by a power-of-2 constant (on either side) becomes an
+        # int_lshift; any other multiplication must be left untouched
+        ops = """
+        [i1, i2]
+        i3 = int_mul(i1, 2)
+        i4 = int_mul(2, i2)
+        i5 = int_mul(i1, 32)
+        i6 = int_mul(i1, i2)
+        i7 = int_mul(i1, 3)
+        jump(i5, i6)
+        """
+        expected = """
+        [i1, i2]
+        i3 = int_lshift(i1, 1)
+        i4 = int_lshift(i2, 1)
+        i5 = int_lshift(i1, 5)
+        i6 = int_mul(i1, i2)
+        i7 = int_mul(i1, 3)
+        jump(i5, i6)
+        """
+        self.optimize_loop(ops, expected)
+
     def test_lshift_rshift(self):
         ops = """
         [i1, i2, i2b, i1b]

diff --git a/pypy/rlib/rarithmetic.py b/pypy/rlib/rarithmetic.py
--- a/pypy/rlib/rarithmetic.py
+++ b/pypy/rlib/rarithmetic.py
@@ -165,6 +165,20 @@
         assert t.BITS <= r_longlong.BITS
         return build_int(None, t.SIGNED, r_longlong.BITS)
 
+def highest_bit(n):
+    """
+    Return the index of the single set bit in n, i.e. log2(n).  n must be a
+    positive power of 2.  Negative values are rejected explicitly: with
+    machine-word wraparound -sys.maxint-1 passes the (n & (n - 1)) == 0
+    test, yet n >>= 1 never reaches 0 for a negative n.
+    """
+    assert n > 0 and (n & (n - 1)) == 0
+    i = -1
+    while n:
+        i += 1
+        n >>= 1
+    return i
+
 class base_int(long):
     """ fake unsigned integer implementation """
 

diff --git a/pypy/jit/metainterp/optimizeopt/rewrite.py b/pypy/jit/metainterp/optimizeopt/rewrite.py
--- a/pypy/jit/metainterp/optimizeopt/rewrite.py
+++ b/pypy/jit/metainterp/optimizeopt/rewrite.py
@@ -5,6 +5,8 @@
 from pypy.jit.metainterp.resoperation import rop, ResOperation
 from pypy.jit.codewriter.effectinfo import EffectInfo
 from pypy.jit.metainterp.optimizeopt.intutils import IntBound
+from pypy.rlib.rarithmetic import highest_bit
+
 
 class OptRewrite(Optimization):
     """Rewrite operations into equivalent, cheaper operations.
@@ -142,6 +144,19 @@
              (v2.is_constant() and v2.box.getint() == 0):
             self.make_constant_int(op.result, 0)
         else:
+            for lhs, rhs in [(v1, v2), (v2, v1)]:
+                if lhs.is_constant():
+                    x = lhs.box.getint()
+                    # x & (x - 1) == 0 is a quick test for power of 2.  As in
+                    # the INT_RSHIFT rewrite below, also require x > 0: with
+                    # machine-word wraparound -sys.maxint-1 passes the bit
+                    # test, and highest_bit() cannot handle negative values.
+                    if x > 0 and x & (x - 1) == 0:
+                        new_rhs = ConstInt(highest_bit(x))
+                        op = op.copy_and_change(rop.INT_LSHIFT,
+                                                args=[rhs.box, new_rhs])
+                        break
+
             self.emit_operation(op)
 
     def optimize_CALL_PURE(self, op):
@@ -387,11 +402,8 @@
         if v1.intbound.known_ge(IntBound(0, 0)) and v2.is_constant():
             val = v2.box.getint()
             if val & (val - 1) == 0 and val > 0: # val == 2**shift
-                shift = 0
-                while (1 << shift) < val:
-                    shift += 1
                 op = op.copy_and_change(rop.INT_RSHIFT,
-                                        args = [op.getarg(0), ConstInt(shift)])
+                                        args = [op.getarg(0), ConstInt(highest_bit(val))])
         self.emit_operation(op)
 
 

diff --git a/pypy/rlib/test/test_rarithmetic.py b/pypy/rlib/test/test_rarithmetic.py
--- a/pypy/rlib/test/test_rarithmetic.py
+++ b/pypy/rlib/test/test_rarithmetic.py
@@ -394,3 +394,13 @@
 def test_int_real_union():
     from pypy.rpython.lltypesystem.rffi import r_int_real
     assert compute_restype(r_int_real, r_int_real) is r_int_real
+
+def test_highest_bit():
+    py.test.raises(AssertionError, highest_bit, 0)
+    py.test.raises(AssertionError, highest_bit, 14)
+    # negative values are not positive powers of 2 and must be rejected
+    py.test.raises(AssertionError, highest_bit, -1)
+    py.test.raises(AssertionError, highest_bit, -4)
+
+    for i in xrange(31):
+        assert highest_bit(2**i) == i


More information about the Pypy-commit mailing list