How fast can we multiply?
Christian Tismer
tismer at appliedbiometrics.com
Sun Jul 18 14:48:21 EDT 1999
Sorry,
I detected an oversight which made me lose a lot of speed.
Version 0.2 is now 4 times faster at 100000 bits.
"""
fastlong.py
Fast long integers.
Christian Tismer
990718
Version 0.2:
Adding some forgotten fastlong calls made it much faster.
From time to time I have argued on the main list that Python's
long integers should have a faster multiplication. It is a well
known fact that by some factorization, integer multiplication's
complexity can be brought from O(n**2) down to O(n**log2(3)), about O(n**1.585).
The method is sometimes called Karatsuba algorithm. But studying
chapter 12.7 of (1), it turns out to be nothing else but a special
case of preconditioning polynomials.
I seem to remember that the GMP library uses this algorithm,
but I didn't look it up and tried my own straightforward way.
The reason why I no longer aim to modify the C implementation
is that it would not be substantially faster in C. The real
workhorse is still the basic internal long multiplication
(sometimes called the "highschool algorithm"), which is fast for
small long integers. The only effect of a C port would be to move
the break_even point to something smaller.
My proposal to get something like this into the C code:
For very long numbers to multiply, the C version could
be enabled to call back into the Python implementation.
A function to obtain the bit length of a long would
be most helpful. Some kind of split function as well.
Currently, with a "clean" nbits implementation, the
break even point appears to be 50000 bits.
With a marshal-based inaccurate hack, break even is
around 10000 bits.
For a bit length of 100000, my algorithm outperforms
the built-in long numbers by a factor of 4.
(1) Aho/Hopcroft/Ullman:
"The Design and Analysis of Computer Algorithms"
"""
class fastlong:
    """a class around long numbers which can multiply faster than
    long"""
    break_even = 50000 # clean nbits
    break_even = 10000 # hacked nbits
    def __init__(self, val):
        if isinstance(val, fastlong):
            self.val = val.val
        else:
            self.val = long(val)
    def __repr__(self):
        return repr(self.val)
    def __add__(self, other):
        return fastlong(self.val + fastlong(other).val)
    def __sub__(self, other):
        return fastlong(self.val - fastlong(other).val)
    def __div__(self, other):
        # not optimized yet
        return fastlong(self.val / fastlong(other).val)
    def __pos__(self):
        return self
    def __neg__(self):
        return fastlong(-self.val)
    def __lshift__(self, bits):
        return fastlong(self.val << bits)
    def __rshift__(self, bits):
        return fastlong(self.val >> bits)
    def __mul__(self, other):
        """Karatsuba algorithm
        (a+bX)(c+dX)=(X^2+X)(bd)+(X)(a-b)(d-c)+(X+1)(ac)
        """
        val1 = self.val
        val2 = fastlong(other).val
        # short cut: the product is zero if either operand is zero
        if not (val1 and val2):
            return 0
        bitcount = max(nbits(val1), nbits(val2))
        if bitcount < self.break_even:
            # small operands: the built-in multiplication is faster
            return val1 * val2
        shift = bitcount / 2
        # with X = 2**shift: val1 = lo1 + hi1*X, val2 = lo2 + hi2*X
        (hi1, lo1) = split_num(val1, shift)
        (hi2, lo2) = split_num(val2, shift)
        sum1 = lo1-hi1    # (a-b)
        sum2 = hi2-lo2    # (d-c)
        mul1 = hi1 * hi2  # bd
        mul2 = lo1 * lo2  # ac
        mul3 = sum1 * sum2
        res = mul2 + ((mul1 + mul2 + mul3) << shift) + (mul1 << shift+shift)
        return res
def split_num(longval, shift):
    scale = 1L << (shift)
    mask = scale - 1
    lo = longval & mask
    hi = longval >> shift
    return fastlong(hi), fastlong(lo)
def nbits(x):
    """clean, accurate, not so fast"""
    if x <= 2:
        if x < 0:
            return nbits(-x)+1
        return x
    sum = 1
    while x:
        shift = 1
        x = x2 = x >> 1
        while x2:
            x = x2
            sum = sum + shift
            shift = shift + shift
            x2 = x >> shift
    return sum
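import marshal, struct
def nbits(x):
    """not clean, rather fast, inaccurate"""
    # A marshalled long is a type byte followed by a signed 32-bit
    # little-endian digit count; every digit holds 15 bits.
    s=marshal.dumps(x)
    return abs(struct.unpack("<i", s[1:5])[0])*15 -7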
#---------------------------------------------
def test():
    import sys
    be = fastlong.break_even
    for i in range(be/2, be*16, be/4):
        arg = (1L << i+1) -1
        fastarg = fastlong(arg)
        tim1, res1 = timing(lambda x:x*x, arg)
        tim2, res2 = timing(lambda x:x*x, fastarg)
        if res1 <> res2:
            raise ValueError, "different results for i=%d" % i
        print i, tim1, tim2
        sys.stdout.flush()
def timing(func, args, n=1, **keywords) :
    import time
    time=time.time
    appl=apply
    if type(args) != type(()) : args=(args,)
    rep=range(n)
    before=time()
    for i in rep: res=appl(func, args, keywords)
    return round(time()-before,4), res
# the end ----------------------------------------------------
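
For readers on a current interpreter, the same three-product scheme can be sketched roughly as below. This is only an illustration, not part of fastlong.py: it assumes Python 3, where the built-in int has a bit_length() method, so neither nbits variant is needed. The function name karatsuba and the break_even parameter are made up for the sketch, and the default of 10000 is simply the figure quoted above, not a fresh measurement.

```python
def split_num(value, shift):
    """Split a non-negative int so that value == (hi << shift) + lo."""
    mask = (1 << shift) - 1
    return value >> shift, value & mask


def karatsuba(x, y, break_even=10000):
    """Multiply two ints using three recursive products instead of four.

    Uses the identity from the __mul__ docstring:
    (a+bX)(c+dX) = (X^2+X)(bd) + X(a-b)(d-c) + (X+1)(ac)
    """
    # A threshold below 2 bits would make the recursion stop shrinking.
    break_even = max(break_even, 2)
    if x == 0 or y == 0:
        return 0
    # Work on magnitudes and restore the sign at the end.
    sign = -1 if (x < 0) != (y < 0) else 1
    x, y = abs(x), abs(y)
    bitcount = max(x.bit_length(), y.bit_length())
    if bitcount < break_even:
        # Small operands: the built-in multiplication is faster.
        return sign * x * y
    shift = bitcount // 2
    hi1, lo1 = split_num(x, shift)   # x = lo1 + hi1 * 2**shift
    hi2, lo2 = split_num(y, shift)   # y = lo2 + hi2 * 2**shift
    mul1 = karatsuba(hi1, hi2, break_even)              # bd
    mul2 = karatsuba(lo1, lo2, break_even)              # ac
    mul3 = karatsuba(lo1 - hi1, hi2 - lo2, break_even)  # (a-b)(d-c)
    middle = mul1 + mul2 + mul3                         # equals ad + bc
    return sign * (mul2 + (middle << shift) + (mul1 << (2 * shift)))
```

Calling karatsuba(a, b, break_even=64) on numbers of a few thousand bits forces several levels of recursion and should agree with a * b. Whether it beats the built-in multiplication at any size depends on the interpreter version, so the break-even value would have to be measured again.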
--
Christian Tismer :^) <mailto:tismer at appliedbiometrics.com>
Applied Biometrics GmbH : Have a break! Take a ride on Python's
Kaiserin-Augusta-Allee 101 : *Starship* http://starship.python.net
10553 Berlin : PGP key -> http://wwwkeys.pgp.net
PGP Fingerprint E182 71C7 1A9D 66E9 9D15 D3CC D4D7 93E2 1FAE F6DF
we're tired of banana software - shipped green, ripens at home