[pypy-commit] pypy improve-rbigint: Probably my final Toom-Cook test. Didn't go so well. Also disable jit.elidable because it seems to slow down good algorithms

stian noreply at buildbot.pypy.org
Sat Jul 21 18:41:53 CEST 2012


Author: stian
Branch: improve-rbigint
Changeset: r56362:fd2621060fe3
Date: 2012-07-12 19:38 +0200
http://bitbucket.org/pypy/pypy/changeset/fd2621060fe3/

Log:	Probably my final Toom-Cook test. Didn't go so well. Also disable
	jit.elidable (commented out) because it seems to slow down the faster algorithms.

diff --git a/pypy/rlib/rbigint.py b/pypy/rlib/rbigint.py
--- a/pypy/rlib/rbigint.py
+++ b/pypy/rlib/rbigint.py
@@ -70,7 +70,7 @@
 KARATSUBA_SQUARE_CUTOFF = 2 * KARATSUBA_CUTOFF
 
 USE_TOOMCOCK = False
-TOOMCOOK_CUTOFF = 2000 # Smallest possible cutoff is 3. Ideal is probably around 150+
+TOOMCOOK_CUTOFF = 10000 # Smallest possible cutoff is 3. Ideal is probably around 150+
 
 # For exponentiation, use the binary left-to-right algorithm
 # unless the exponent contains more than FIVEARY_CUTOFF digits.
@@ -220,7 +220,7 @@
         return v
 
     @staticmethod
-    @jit.elidable
+    #@jit.elidable
     def frombool(b):
         # This function is marked as pure, so you must not call it and
         # then modify the result.
@@ -335,21 +335,21 @@
     def tofloat(self):
         return _AsDouble(self)
 
-    @jit.elidable
+    #@jit.elidable
     def format(self, digits, prefix='', suffix=''):
         # 'digits' is a string whose length is the base to use,
         # and where each character is the corresponding digit.
         return _format(self, digits, prefix, suffix)
 
-    @jit.elidable
+    #@jit.elidable
     def repr(self):
         return _format(self, BASE10, '', 'L')
 
-    @jit.elidable
+    #@jit.elidable
     def str(self):
         return _format(self, BASE10)
 
-    @jit.elidable
+    #@jit.elidable
     def eq(self, other):
         if (self.sign != other.sign or
             self.numdigits() != other.numdigits()):
@@ -365,7 +365,7 @@
     def ne(self, other):
         return not self.eq(other)
 
-    @jit.elidable
+    #@jit.elidable
     def lt(self, other):
         if self.sign > other.sign:
             return False
@@ -413,7 +413,7 @@
     def hash(self):
         return _hash(self)
 
-    @jit.elidable
+    #@jit.elidable
     def add(self, other):
         if self.sign == 0:
             return other
@@ -426,7 +426,7 @@
         result.sign *= other.sign
         return result
 
-    @jit.elidable
+    #@jit.elidable
     def sub(self, other):
         if other.sign == 0:
             return self
@@ -439,7 +439,7 @@
         result.sign *= self.sign
         return result
 
-    @jit.elidable
+    #@jit.elidable
     def mul(self, b):
         asize = self.numdigits()
         bsize = b.numdigits()
@@ -487,12 +487,12 @@
         result.sign = a.sign * b.sign
         return result
 
-    @jit.elidable
+    #@jit.elidable
     def truediv(self, other):
         div = _bigint_true_divide(self, other)
         return div
 
-    @jit.elidable
+    #@jit.elidable
     def floordiv(self, other):
         if other.numdigits() == 1 and other.sign == 1:
             digit = other.digit(0)
@@ -506,11 +506,11 @@
             div = div.sub(ONERBIGINT)
         return div
 
-    @jit.elidable
+    #@jit.elidable
     def div(self, other):
         return self.floordiv(other)
 
-    @jit.elidable
+    #@jit.elidable
     def mod(self, other):
         if self.sign == 0:
             return NULLRBIGINT
@@ -549,7 +549,7 @@
             mod = mod.add(other)
         return mod
 
-    @jit.elidable
+    #@jit.elidable
     def divmod(v, w):
         """
         The / and % operators are now defined in terms of divmod().
@@ -573,7 +573,7 @@
             div = div.sub(ONERBIGINT)
         return div, mod
 
-    @jit.elidable
+    #@jit.elidable
     def pow(a, b, c=None):
         negativeOutput = False  # if x<0 return negative output
 
@@ -726,7 +726,7 @@
         ret.sign = -ret.sign
         return ret
 
-    @jit.elidable
+    #@jit.elidable
     def lshift(self, int_other):
         if int_other < 0:
             raise ValueError("negative shift count")
@@ -760,7 +760,7 @@
         return z
     lshift._always_inline_ = True # It's so fast that it's always benefitial.
     
-    @jit.elidable
+    #@jit.elidable
     def lqshift(self, int_other):
         " A quicker one with much less checks, int_other is valid and for the most part constant."
         assert int_other > 0
@@ -780,7 +780,7 @@
         return z
     lqshift._always_inline_ = True # It's so fast that it's always benefitial.
     
-    @jit.elidable
+    #@jit.elidable
     def rshift(self, int_other, dont_invert=False):
         if int_other < 0:
             raise ValueError("negative shift count")
@@ -815,15 +815,15 @@
         return z
     rshift._always_inline_ = True # It's so fast that it's always benefitial.
     
-    @jit.elidable
+    #@jit.elidable
     def and_(self, other):
         return _bitwise(self, '&', other)
 
-    @jit.elidable
+    #@jit.elidable
     def xor(self, other):
         return _bitwise(self, '^', other)
 
-    @jit.elidable
+    #@jit.elidable
     def or_(self, other):
         return _bitwise(self, '|', other)
 
@@ -836,7 +836,7 @@
     def hex(self):
         return _format(self, BASE16, '0x', 'L')
 
-    @jit.elidable
+    #@jit.elidable
     def log(self, base):
         # base is supposed to be positive or 0.0, which means we use e
         if base == 10.0:
@@ -1134,17 +1134,16 @@
     viewing the shift as being by digits.  The sign bit is ignored, and
     the return values are >= 0.
     """
-    size_n = n.numdigits() / 3
+    size_n = n.numdigits()
     size_lo = min(size_n, size)
     lo = rbigint(n._digits[:size_lo], 1)
-    mid = rbigint(n._digits[size_lo:size * 2], 1)
+    mid = rbigint(n._digits[size_lo:size_lo * 2], 1)
     hi = rbigint(n._digits[size_lo *2:], 1)
     lo._normalize()
     mid._normalize()
     hi._normalize()
     return hi, mid, lo
 
-THREERBIGINT = rbigint.fromint(3) # Used by tc_mul
 def _tc_mul(a, b):
     """
     Toom Cook
@@ -1153,7 +1152,7 @@
     bsize = b.numdigits()
 
     # Split a & b into hi, mid and lo pieces.
-    shift = bsize // 3
+    shift = (2+bsize) // 3
     ah, am, al = _tcmul_split(a, shift)
     assert ah.sign == 1    # the split isn't degenerate
 
@@ -1164,41 +1163,46 @@
     else:
         bh, bm, bl = _tcmul_split(b, shift)
     # 2. ahl, bhl
-    ahl = al.add(ah)
-    bhl = bl.add(bh)
+    ahl = _x_add(al, ah)
+    bhl = _x_add(bl, bh)
     
     # Points
     v0 = al.mul(bl)
-    v1 = ahl.add(bm).mul(bhl.add(bm))
+    vn1 = ahl.sub(am).mul(bhl.sub(bm))
     
-    vn1 = ahl.sub(bm).mul(bhl.sub(bm))
-    v2 = al.add(am.lshift(1)).add(ah.lshift(2)).mul(bl.add(bm.lshift(1)).add(bh.lshift(2)))
+    ahml = _x_add(ahl, am)
+    bhml = _x_add(bhl, bm)
+    
+    v1 = ahml.mul(bhml)
+    v2 = _x_add(ahml, ah).lshift(1).sub(al).mul(_x_add(bhml, bh).lshift(1).sub(bl))
     vinf = ah.mul(bh)
     
-    # Construct
-    t1 = v0.mul(THREERBIGINT).add(vn1.lshift(1)).add(v2)
-    _inplace_divrem1(t1, t1, 6)
-    t1 = t1.sub(vinf.lshift(1))
-    t2 = v1.add(vn1)
+    t2 = _x_sub(v2, vn1)
+    _inplace_divrem1(t2, t2, 3)
+    tn1 = v1.sub(vn1)
+    _v_rshift(tn1, tn1, tn1.numdigits(), 1)
+    t1 = v1
+    _v_isub(t1, 0, t1.numdigits(), v0, v0.numdigits())
+    _v_isub(t2, 0, t2.numdigits(), t1, t1.numdigits())
     _v_rshift(t2, t2, t2.numdigits(), 1)
+    _v_isub(t1, 0, t1.numdigits(), tn1, tn1.numdigits())
+    _v_isub(t1, 0, t1.numdigits(), vinf, vinf.numdigits())
     
-    r1 = v1.sub(t1)
-    r2 = t2.sub(v0).sub(vinf)
-    r3 = t1.sub(t2)
-    # r0 = v0, r4 = vinf
+    t2 = t2.sub(vinf.lshift(1))
+    _v_isub(tn1, 0, tn1.numdigits(), t2, t2.numdigits())
     
-    # Now we fit r+ r2 + r4 into the new string.
+    # Now we fit t+ t2 + t4 into the new string.
     # Now we got to add the r1 and r3 in the mid shift.
     # Allocate result space.
-    ret = rbigint([NULLDIGIT] * (4*shift + vinf.numdigits()), 1)  # This is because of the size of vinf
+    ret = rbigint([NULLDIGIT] * (4 * shift + vinf.numdigits() + 1), 1)  # This is because of the size of vinf
     
     ret._digits[:v0.numdigits()] = v0._digits
     #print ret.numdigits(), r2.numdigits(), vinf.numdigits(), shift, shift * 5, asize, bsize
     #print r2.sign >= 0
-    assert r2.sign >= 0
+    assert t2.sign >= 0
     #print 2*shift + r2.numdigits() < ret.numdigits()
-    assert 2*shift + r2.numdigits() < ret.numdigits()
-    ret._digits[shift * 2:shift * 2+r2.numdigits()] = r2._digits
+    assert 2*shift + t2.numdigits() < ret.numdigits()
+    ret._digits[shift * 2:shift * 2+t2.numdigits()] = t2._digits
     #print vinf.sign >= 0
     assert vinf.sign >= 0
     #print 4*shift + vinf.numdigits() <= ret.numdigits()
@@ -1207,8 +1211,8 @@
 
 
     i = ret.numdigits() - shift
-    _v_iadd(ret, shift, i, r1, r1.numdigits())
-    _v_iadd(ret, shift * 3, i, r3, r3.numdigits())
+    _v_iadd(ret, shift, i, tn1, tn1.numdigits())
+    _v_iadd(ret, shift * 3, i, t1, t1.numdigits())
 
     ret._normalize()
     return ret
@@ -1469,14 +1473,12 @@
         carry += x.udigit(i) + y.udigit(i-xofs)
         x.setdigit(i, carry)
         carry >>= SHIFT
-        assert (carry & 1) == carry
         i += 1
     iend = xofs + m
     while carry and i < iend:
         carry += x.udigit(i)
         x.setdigit(i, carry)
         carry >>= SHIFT
-        assert (carry & 1) == carry
         i += 1
     return carry
 
diff --git a/pypy/translator/goal/targetbigintbenchmark.py b/pypy/translator/goal/targetbigintbenchmark.py
--- a/pypy/translator/goal/targetbigintbenchmark.py
+++ b/pypy/translator/goal/targetbigintbenchmark.py
@@ -2,7 +2,7 @@
 
 import os, sys
 from time import time
-from pypy.rlib.rbigint import rbigint
+from pypy.rlib.rbigint import rbigint, _k_mul, _tc_mul
 
 # __________  Entry point  __________
 
@@ -74,17 +74,30 @@
     sumTime = 0.0
     
     
-    """t = time()
-    by = rbigint.pow(rbigint.fromint(63), rbigint.fromint(100))
-    for n in xrange(9900):
+    """ t = time()
+    by = rbigint.fromint(2**62).lshift(1030000)
+    for n in xrange(5000):
         by2 = by.lshift(63)
-        rbigint.mul(by, by2)
+        _tc_mul(by, by2)
         by = by2
         
 
     _time = time() - t
     sumTime += _time
-    print "Toom-cook effectivity 100-10000 digits:", _time"""
+    print "Toom-cook effectivity _Tcmul 1030000-1035000 digits:", _time
+    
+    t = time()
+    by = rbigint.fromint(2**62).lshift(1030000)
+    for n in xrange(5000):
+        by2 = by.lshift(63)
+        _k_mul(by, by2)
+        by = by2
+        
+
+    _time = time() - t
+    sumTime += _time
+    print "Toom-cook effectivity _kMul 1030000-1035000 digits:", _time"""
+    
     
     V2 = rbigint.fromint(2)
     num = rbigint.pow(rbigint.fromint(100000000), rbigint.fromint(1024))


More information about the pypy-commit mailing list