[pypy-commit] pypy default: merge math-improvements
cfbolz
pypy.commits at gmail.com
Fri Feb 8 05:57:35 EST 2019
Author: Carl Friedrich Bolz-Tereick <cfbolz at gmx.de>
Branch:
Changeset: r95895:7a4d0769c63d
Date: 2019-02-08 11:01 +0100
http://bitbucket.org/pypy/pypy/changeset/7a4d0769c63d/
Log: merge math-improvements
diff too long, truncating to 2000 out of 2276 lines
diff --git a/pypy/doc/whatsnew-head.rst b/pypy/doc/whatsnew-head.rst
--- a/pypy/doc/whatsnew-head.rst
+++ b/pypy/doc/whatsnew-head.rst
@@ -13,3 +13,9 @@
The zlib module's compressobj and decompressobj now expose copy methods
as they do on CPython.
+
+
+.. math-improvements
+
+Improve performance of long operations where one of the operands fits into
+an int.
\ No newline at end of file
diff --git a/pypy/objspace/std/intobject.py b/pypy/objspace/std/intobject.py
--- a/pypy/objspace/std/intobject.py
+++ b/pypy/objspace/std/intobject.py
@@ -299,7 +299,7 @@
return ix
-def _pow_ovf2long(space, iv, iw, w_modulus):
+def _pow_ovf2long(space, iv, w_iv, iw, w_iw, w_modulus):
if space.is_none(w_modulus) and _recover_with_smalllong(space):
from pypy.objspace.std.smalllongobject import _pow as _pow_small
try:
@@ -308,9 +308,12 @@
return _pow_small(space, r_longlong(iv), iw, r_longlong(0))
except (OverflowError, ValueError):
pass
- from pypy.objspace.std.longobject import W_LongObject
- w_iv = W_LongObject.fromint(space, iv)
- w_iw = W_LongObject.fromint(space, iw)
+ from pypy.objspace.std.longobject import W_LongObject, W_AbstractLongObject
+ if w_iv is None or not isinstance(w_iv, W_AbstractLongObject):
+ w_iv = W_LongObject.fromint(space, iv)
+ if w_iw is None or not isinstance(w_iw, W_AbstractLongObject):
+ w_iw = W_LongObject.fromint(space, iw)
+
return w_iv.descr_pow(space, w_iw, w_modulus)
@@ -318,7 +321,7 @@
op = getattr(operator, opname, None)
assert op or ovf2small
- def ovf2long(space, x, y):
+ def ovf2long(space, x, w_x, y, w_y):
"""Handle overflowing to smalllong or long"""
if _recover_with_smalllong(space):
if ovf2small:
@@ -330,9 +333,12 @@
b = r_longlong(y)
return W_SmallLongObject(op(a, b))
- from pypy.objspace.std.longobject import W_LongObject
- w_x = W_LongObject.fromint(space, x)
- w_y = W_LongObject.fromint(space, y)
+ from pypy.objspace.std.longobject import W_LongObject, W_AbstractLongObject
+ if w_x is None or not isinstance(w_x, W_AbstractLongObject):
+ w_x = W_LongObject.fromint(space, x)
+ if w_y is None or not isinstance(w_y, W_AbstractLongObject):
+ w_y = W_LongObject.fromint(space, y)
+
return getattr(w_x, 'descr_' + opname)(space, w_y)
return ovf2long
@@ -496,12 +502,18 @@
# can't return NotImplemented (space.pow doesn't do full
# ternary, i.e. w_modulus.__zpow__(self, w_exponent)), so
# handle it ourselves
- return _pow_ovf2long(space, x, y, w_modulus)
+ return _pow_ovf2long(space, x, self, y, w_exponent, w_modulus)
try:
result = _pow(space, x, y, z)
- except (OverflowError, ValueError):
- return _pow_ovf2long(space, x, y, w_modulus)
+ except OverflowError:
+ return _pow_ovf2long(space, x, self, y, w_exponent, w_modulus)
+ except ValueError:
+ # float result, so let's avoid a roundtrip in rbigint.
+ self = self.descr_float(space)
+ w_exponent = w_exponent.descr_float(space)
+ return space.pow(self, w_exponent, space.w_None)
+
return space.newint(result)
@unwrap_spec(w_modulus=WrappedDefault(None))
@@ -546,7 +558,7 @@
try:
z = ovfcheck(op(x, y))
except OverflowError:
- return ovf2long(space, x, y)
+ return ovf2long(space, x, self, y, w_other)
else:
z = op(x, y)
return wrapint(space, z)
@@ -568,7 +580,7 @@
try:
z = ovfcheck(op(y, x))
except OverflowError:
- return ovf2long(space, y, x)
+ return ovf2long(space, y, w_other, x, self) # XXX write a test
else:
z = op(y, x)
return wrapint(space, z)
@@ -599,7 +611,7 @@
try:
return func(space, x, y)
except OverflowError:
- return ovf2long(space, x, y)
+ return ovf2long(space, x, self, y, w_other)
else:
return func(space, x, y)
@@ -614,7 +626,7 @@
try:
return func(space, y, x)
except OverflowError:
- return ovf2long(space, y, x)
+ return ovf2long(space, y, w_other, x, self)
else:
return func(space, y, x)
diff --git a/pypy/objspace/std/longobject.py b/pypy/objspace/std/longobject.py
--- a/pypy/objspace/std/longobject.py
+++ b/pypy/objspace/std/longobject.py
@@ -308,28 +308,47 @@
@unwrap_spec(w_modulus=WrappedDefault(None))
def descr_pow(self, space, w_exponent, w_modulus=None):
+ exp_int = 0
+ exp_bigint = None
+ sign = 0
+
if isinstance(w_exponent, W_AbstractIntObject):
- w_exponent = w_exponent.descr_long(space)
+ exp_int = w_exponent.int_w(space)
+ if exp_int > 0:
+ sign = 1
+ elif exp_int < 0:
+ sign = -1
elif not isinstance(w_exponent, W_AbstractLongObject):
return space.w_NotImplemented
+ else:
+ exp_bigint = w_exponent.asbigint()
+ sign = exp_bigint.sign
if space.is_none(w_modulus):
- if w_exponent.asbigint().sign < 0:
+ if sign < 0:
self = self.descr_float(space)
w_exponent = w_exponent.descr_float(space)
return space.pow(self, w_exponent, space.w_None)
- return W_LongObject(self.num.pow(w_exponent.asbigint()))
+ if not exp_bigint:
+ return W_LongObject(self.num.int_pow(exp_int))
+ else:
+ return W_LongObject(self.num.pow(exp_bigint))
+
elif isinstance(w_modulus, W_AbstractIntObject):
w_modulus = w_modulus.descr_long(space)
+
elif not isinstance(w_modulus, W_AbstractLongObject):
return space.w_NotImplemented
- if w_exponent.asbigint().sign < 0:
+ if sign < 0:
raise oefmt(space.w_TypeError,
"pow() 2nd argument cannot be negative when 3rd "
"argument specified")
try:
- result = self.num.pow(w_exponent.asbigint(), w_modulus.asbigint())
+ if not exp_bigint:
+ result = self.num.int_pow(exp_int, w_modulus.asbigint())
+ else:
+ result = self.num.pow(exp_bigint, w_modulus.asbigint())
except ValueError:
raise oefmt(space.w_ValueError, "pow 3rd argument cannot be 0")
return W_LongObject(result)
@@ -372,22 +391,16 @@
descr_gt = _make_descr_cmp('gt')
descr_ge = _make_descr_cmp('ge')
- def _make_generic_descr_binop_noncommutative(opname):
- methname = opname + '_' if opname in ('and', 'or') else opname
- descr_rname = 'descr_r' + opname
- op = getattr(rbigint, methname)
+ def descr_sub(self, space, w_other):
+ if isinstance(w_other, W_AbstractIntObject):
+ return W_LongObject(self.num.int_sub(w_other.int_w(space)))
+ elif not isinstance(w_other, W_AbstractLongObject):
+ return space.w_NotImplemented
+ return W_LongObject(self.num.sub(w_other.asbigint()))
- @func_renamer('descr_' + opname)
- @delegate_other
- def descr_binop(self, space, w_other):
- return W_LongObject(op(self.num, w_other.asbigint()))
-
- @func_renamer(descr_rname)
- @delegate_other
- def descr_rbinop(self, space, w_other):
- return W_LongObject(op(w_other.asbigint(), self.num))
-
- return descr_binop, descr_rbinop
+ @delegate_other
+ def descr_rsub(self, space, w_other):
+ return W_LongObject(w_other.asbigint().sub(self.num))
def _make_generic_descr_binop(opname):
if opname not in COMMUTATIVE_OPS:
@@ -419,28 +432,23 @@
return descr_binop, descr_rbinop
descr_add, descr_radd = _make_generic_descr_binop('add')
- descr_sub, descr_rsub = _make_generic_descr_binop_noncommutative('sub')
+
descr_mul, descr_rmul = _make_generic_descr_binop('mul')
descr_and, descr_rand = _make_generic_descr_binop('and')
descr_or, descr_ror = _make_generic_descr_binop('or')
descr_xor, descr_rxor = _make_generic_descr_binop('xor')
- def _make_descr_binop(func, int_func=None):
+ def _make_descr_binop(func, int_func):
opname = func.__name__[1:]
- if int_func:
- @func_renamer('descr_' + opname)
- def descr_binop(self, space, w_other):
- if isinstance(w_other, W_AbstractIntObject):
- return int_func(self, space, w_other.int_w(space))
- elif not isinstance(w_other, W_AbstractLongObject):
- return space.w_NotImplemented
- return func(self, space, w_other)
- else:
- @delegate_other
- @func_renamer('descr_' + opname)
- def descr_binop(self, space, w_other):
- return func(self, space, w_other)
+ @func_renamer('descr_' + opname)
+ def descr_binop(self, space, w_other):
+ if isinstance(w_other, W_AbstractIntObject):
+ return int_func(self, space, w_other.int_w(space))
+ elif not isinstance(w_other, W_AbstractLongObject):
+ return space.w_NotImplemented
+ return func(self, space, w_other)
+
@delegate_other
@func_renamer('descr_r' + opname)
def descr_rbinop(self, space, w_other):
@@ -460,10 +468,10 @@
raise oefmt(space.w_OverflowError, "shift count too large")
return W_LongObject(self.num.lshift(shift))
- def _int_lshift(self, space, w_other):
- if w_other < 0:
+ def _int_lshift(self, space, other):
+ if other < 0:
raise oefmt(space.w_ValueError, "negative shift count")
- return W_LongObject(self.num.lshift(w_other))
+ return W_LongObject(self.num.lshift(other))
descr_lshift, descr_rlshift = _make_descr_binop(_lshift, _int_lshift)
@@ -476,11 +484,11 @@
raise oefmt(space.w_OverflowError, "shift count too large")
return newlong(space, self.num.rshift(shift))
- def _int_rshift(self, space, w_other):
- if w_other < 0:
+ def _int_rshift(self, space, other):
+ if other < 0:
raise oefmt(space.w_ValueError, "negative shift count")
- return newlong(space, self.num.rshift(w_other))
+ return newlong(space, self.num.rshift(other))
descr_rshift, descr_rrshift = _make_descr_binop(_rshift, _int_rshift)
def _floordiv(self, space, w_other):
@@ -491,17 +499,18 @@
"long division or modulo by zero")
return newlong(space, z)
- def _floordiv(self, space, w_other):
+ def _int_floordiv(self, space, other):
try:
- z = self.num.floordiv(w_other.asbigint())
+ z = self.num.int_floordiv(other)
except ZeroDivisionError:
raise oefmt(space.w_ZeroDivisionError,
"long division or modulo by zero")
return newlong(space, z)
- descr_floordiv, descr_rfloordiv = _make_descr_binop(_floordiv)
+ descr_floordiv, descr_rfloordiv = _make_descr_binop(_floordiv, _int_floordiv)
_div = func_with_new_name(_floordiv, '_div')
- descr_div, descr_rdiv = _make_descr_binop(_div)
+ _int_div = func_with_new_name(_int_floordiv, '_int_div')
+ descr_div, descr_rdiv = _make_descr_binop(_div, _int_div)
def _mod(self, space, w_other):
try:
@@ -511,9 +520,9 @@
"long division or modulo by zero")
return newlong(space, z)
- def _int_mod(self, space, w_other):
+ def _int_mod(self, space, other):
try:
- z = self.num.int_mod(w_other)
+ z = self.num.int_mod(other)
except ZeroDivisionError:
raise oefmt(space.w_ZeroDivisionError,
"long division or modulo by zero")
@@ -527,7 +536,16 @@
raise oefmt(space.w_ZeroDivisionError,
"long division or modulo by zero")
return space.newtuple([newlong(space, div), newlong(space, mod)])
- descr_divmod, descr_rdivmod = _make_descr_binop(_divmod)
+
+ def _int_divmod(self, space, other):
+ try:
+ div, mod = self.num.int_divmod(other)
+ except ZeroDivisionError:
+ raise oefmt(space.w_ZeroDivisionError,
+ "long division or modulo by zero")
+ return space.newtuple([newlong(space, div), newlong(space, mod)])
+
+ descr_divmod, descr_rdivmod = _make_descr_binop(_divmod, _int_divmod)
def newlong(space, bigint):
diff --git a/pypy/objspace/std/test/test_intobject.py b/pypy/objspace/std/test/test_intobject.py
--- a/pypy/objspace/std/test/test_intobject.py
+++ b/pypy/objspace/std/test/test_intobject.py
@@ -679,6 +679,11 @@
x = int(321)
assert x.__rlshift__(333) == 1422567365923326114875084456308921708325401211889530744784729710809598337369906606315292749899759616L
+ def test_some_rops(self):
+ import sys
+ x = int(-sys.maxint)
+ assert x.__rsub__(2) == (2 + sys.maxint)
+
class AppTestIntShortcut(AppTestInt):
spaceconfig = {"objspace.std.intshortcut": True}
diff --git a/rpython/rlib/rarithmetic.py b/rpython/rlib/rarithmetic.py
--- a/rpython/rlib/rarithmetic.py
+++ b/rpython/rlib/rarithmetic.py
@@ -612,6 +612,7 @@
r_ulonglong = build_int('r_ulonglong', False, 64)
r_longlonglong = build_int('r_longlonglong', True, 128)
+r_ulonglonglong = build_int('r_ulonglonglong', False, 128)
longlongmax = r_longlong(LONGLONG_TEST - 1)
if r_longlong is not r_int:
diff --git a/rpython/rlib/rbigint.py b/rpython/rlib/rbigint.py
--- a/rpython/rlib/rbigint.py
+++ b/rpython/rlib/rbigint.py
@@ -27,6 +27,7 @@
else:
UDIGIT_MASK = longlongmask
LONG_TYPE = rffi.__INT128_T
+ ULONG_TYPE = rffi.__UINT128_T
if LONG_BIT > SHIFT:
STORE_TYPE = lltype.Signed
UNSIGNED_TYPE = lltype.Unsigned
@@ -40,6 +41,7 @@
STORE_TYPE = lltype.Signed
UNSIGNED_TYPE = lltype.Unsigned
LONG_TYPE = rffi.LONGLONG
+ ULONG_TYPE = rffi.ULONGLONG
MASK = int((1 << SHIFT) - 1)
FLOAT_MULTIPLIER = float(1 << SHIFT)
@@ -97,6 +99,9 @@
def _widen_digit(x):
return rffi.cast(LONG_TYPE, x)
+def _unsigned_widen_digit(x):
+ return rffi.cast(ULONG_TYPE, x)
+
@specialize.argtype(0)
def _store_digit(x):
return rffi.cast(STORE_TYPE, x)
@@ -108,6 +113,7 @@
NULLDIGIT = _store_digit(0)
ONEDIGIT = _store_digit(1)
+NULLDIGITS = [NULLDIGIT]
def _check_digits(l):
for x in l:
@@ -133,22 +139,26 @@
def specialize_call(self, hop):
hop.exception_cannot_occur()
+def intsign(i):
+ return -1 if i < 0 else 1
class rbigint(object):
"""This is a reimplementation of longs using a list of digits."""
_immutable_ = True
- _immutable_fields_ = ["_digits"]
-
- def __init__(self, digits=[NULLDIGIT], sign=0, size=0):
+ _immutable_fields_ = ["_digits[*]", "size", "sign"]
+
+ def __init__(self, digits=NULLDIGITS, sign=0, size=0):
if not we_are_translated():
_check_digits(digits)
make_sure_not_resized(digits)
self._digits = digits
+
assert size >= 0
self.size = size or len(digits)
+
self.sign = sign
- # __eq__ and __ne__ method exist for testingl only, they are not RPython!
+ # __eq__ and __ne__ method exist for testing only, they are not RPython!
@not_rpython
def __eq__(self, other):
if not isinstance(other, rbigint):
@@ -159,6 +169,7 @@
def __ne__(self, other):
return not (self == other)
+ @specialize.argtype(1)
def digit(self, x):
"""Return the x'th digit, as an int."""
return self._digits[x]
@@ -170,6 +181,12 @@
return _widen_digit(self._digits[x])
widedigit._always_inline_ = True
+ def uwidedigit(self, x):
+ """Return the x'th digit, as an unsigned long long int if needed
+ to have enough room to contain two digits."""
+ return _unsigned_widen_digit(self._digits[x])
+ uwidedigit._always_inline_ = True
+
def udigit(self, x):
"""Return the x'th digit, as an unsigned int."""
return _load_unsigned_digit(self._digits[x])
@@ -183,7 +200,9 @@
setdigit._always_inline_ = True
def numdigits(self):
- return self.size
+ w = self.size
+ assert w > 0
+ return w
numdigits._always_inline_ = True
@staticmethod
@@ -196,13 +215,15 @@
if intval < 0:
sign = -1
ival = -r_uint(intval)
+ carry = ival >> SHIFT
elif intval > 0:
sign = 1
ival = r_uint(intval)
+ carry = 0
else:
return NULLRBIGINT
- carry = ival >> SHIFT
+
if carry:
return rbigint([_store_digit(ival & MASK),
_store_digit(carry)], sign, 2)
@@ -509,23 +530,22 @@
return True
@jit.elidable
- def int_eq(self, other):
+ def int_eq(self, iother):
""" eq with int """
-
- if not int_in_valid_range(other):
- # Fallback to Long.
- return self.eq(rbigint.fromint(other))
+ if not int_in_valid_range(iother):
+ # Fallback to Long.
+ return self.eq(rbigint.fromint(iother))
if self.numdigits() > 1:
return False
- return (self.sign * self.digit(0)) == other
+ return (self.sign * self.digit(0)) == iother
def ne(self, other):
return not self.eq(other)
- def int_ne(self, other):
- return not self.int_eq(other)
+ def int_ne(self, iother):
+ return not self.int_eq(iother)
@jit.elidable
def lt(self, other):
@@ -563,59 +583,38 @@
return False
@jit.elidable
- def int_lt(self, other):
+ def int_lt(self, iother):
""" lt where other is an int """
- if not int_in_valid_range(other):
+ if not int_in_valid_range(iother):
# Fallback to Long.
- return self.lt(rbigint.fromint(other))
-
- osign = 1
- if other == 0:
- osign = 0
- elif other < 0:
- osign = -1
-
- if self.sign > osign:
- return False
- elif self.sign < osign:
- return True
-
- digits = self.numdigits()
-
- if digits > 1:
- if osign == 1:
- return False
- else:
- return True
-
- d1 = self.sign * self.digit(0)
- if d1 < other:
- return True
- return False
+ return self.lt(rbigint.fromint(iother))
+
+ return _x_int_lt(self, iother, False)
def le(self, other):
return not other.lt(self)
- def int_le(self, other):
- # Alternative that might be faster, reimplant this. as a check with other + 1. But we got to check for overflow
- # or reduce valid range.
-
- if self.int_eq(other):
- return True
- return self.int_lt(other)
+ def int_le(self, iother):
+ """ le where iother is an int """
+
+ if not int_in_valid_range(iother):
+ # Fallback to Long.
+ return self.le(rbigint.fromint(iother))
+
+ return _x_int_lt(self, iother, True)
def gt(self, other):
return other.lt(self)
- def int_gt(self, other):
- return not self.int_le(other)
+ def int_gt(self, iother):
+ return not self.int_le(iother)
def ge(self, other):
return not self.lt(other)
- def int_ge(self, other):
- return not self.int_lt(other)
+ def int_ge(self, iother):
+ return not self.int_lt(iother)
@jit.elidable
def hash(self):
@@ -635,20 +634,20 @@
return result
@jit.elidable
- def int_add(self, other):
- if not int_in_valid_range(other):
+ def int_add(self, iother):
+ if not int_in_valid_range(iother):
# Fallback to long.
- return self.add(rbigint.fromint(other))
+ return self.add(rbigint.fromint(iother))
elif self.sign == 0:
- return rbigint.fromint(other)
- elif other == 0:
+ return rbigint.fromint(iother)
+ elif iother == 0:
return self
- sign = -1 if other < 0 else 1
+ sign = intsign(iother)
if self.sign == sign:
- result = _x_int_add(self, other)
+ result = _x_int_add(self, iother)
else:
- result = _x_int_sub(self, other)
+ result = _x_int_sub(self, iother)
result.sign *= -1
result.sign *= sign
return result
@@ -658,7 +657,7 @@
if other.sign == 0:
return self
elif self.sign == 0:
- return rbigint(other._digits[:other.size], -other.sign, other.size)
+ return rbigint(other._digits[:other.numdigits()], -other.sign, other.numdigits())
elif self.sign == other.sign:
result = _x_sub(self, other)
else:
@@ -667,93 +666,94 @@
return result
@jit.elidable
- def int_sub(self, other):
- if not int_in_valid_range(other):
+ def int_sub(self, iother):
+ if not int_in_valid_range(iother):
# Fallback to long.
- return self.sub(rbigint.fromint(other))
- elif other == 0:
+ return self.sub(rbigint.fromint(iother))
+ elif iother == 0:
return self
elif self.sign == 0:
- return rbigint.fromint(-other)
- elif self.sign == (-1 if other < 0 else 1):
- result = _x_int_sub(self, other)
+ return rbigint.fromint(-iother)
+ elif self.sign == intsign(iother):
+ result = _x_int_sub(self, iother)
else:
- result = _x_int_add(self, other)
+ result = _x_int_add(self, iother)
result.sign *= self.sign
return result
@jit.elidable
- def mul(self, b):
- asize = self.numdigits()
- bsize = b.numdigits()
-
- a = self
-
- if asize > bsize:
- a, b, asize, bsize = b, a, bsize, asize
-
- if a.sign == 0 or b.sign == 0:
+ def mul(self, other):
+ selfsize = self.numdigits()
+ othersize = other.numdigits()
+
+ if selfsize > othersize:
+ self, other, selfsize, othersize = other, self, othersize, selfsize
+
+ if self.sign == 0 or other.sign == 0:
return NULLRBIGINT
- if asize == 1:
- if a._digits[0] == ONEDIGIT:
- return rbigint(b._digits[:b.size], a.sign * b.sign, b.size)
- elif bsize == 1:
- res = b.widedigit(0) * a.widedigit(0)
+ if selfsize == 1:
+ if self._digits[0] == ONEDIGIT:
+ return rbigint(other._digits[:othersize], self.sign * other.sign, othersize)
+ elif othersize == 1:
+ res = other.uwidedigit(0) * self.udigit(0)
carry = res >> SHIFT
if carry:
- return rbigint([_store_digit(res & MASK), _store_digit(carry)], a.sign * b.sign, 2)
+ return rbigint([_store_digit(res & MASK), _store_digit(carry)], self.sign * other.sign, 2)
else:
- return rbigint([_store_digit(res & MASK)], a.sign * b.sign, 1)
-
- result = _x_mul(a, b, a.digit(0))
+ return rbigint([_store_digit(res & MASK)], self.sign * other.sign, 1)
+
+ result = _x_mul(self, other, self.digit(0))
elif USE_KARATSUBA:
- if a is b:
+ if self is other:
i = KARATSUBA_SQUARE_CUTOFF
else:
i = KARATSUBA_CUTOFF
- if asize <= i:
- result = _x_mul(a, b)
- """elif 2 * asize <= bsize:
- result = _k_lopsided_mul(a, b)"""
+ if selfsize <= i:
+ result = _x_mul(self, other)
+ """elif 2 * selfsize <= othersize:
+ result = _k_lopsided_mul(self, other)"""
else:
- result = _k_mul(a, b)
+ result = _k_mul(self, other)
else:
- result = _x_mul(a, b)
-
- result.sign = a.sign * b.sign
+ result = _x_mul(self, other)
+
+ result.sign = self.sign * other.sign
return result
@jit.elidable
- def int_mul(self, b):
- if not int_in_valid_range(b):
+ def int_mul(self, iother):
+ if not int_in_valid_range(iother):
# Fallback to long.
- return self.mul(rbigint.fromint(b))
-
- if self.sign == 0 or b == 0:
+ return self.mul(rbigint.fromint(iother))
+
+ if self.sign == 0 or iother == 0:
return NULLRBIGINT
asize = self.numdigits()
- digit = abs(b)
- bsign = -1 if b < 0 else 1
+ digit = abs(iother)
+
+ othersign = intsign(iother)
if digit == 1:
- return rbigint(self._digits[:self.size], self.sign * bsign, asize)
+ if othersign == 1:
+ return self
+ return rbigint(self._digits[:asize], self.sign * othersign, asize)
elif asize == 1:
- res = self.widedigit(0) * digit
+ udigit = r_uint(digit)
+ res = self.uwidedigit(0) * udigit
carry = res >> SHIFT
if carry:
- return rbigint([_store_digit(res & MASK), _store_digit(carry)], self.sign * bsign, 2)
+ return rbigint([_store_digit(res & MASK), _store_digit(carry)], self.sign * othersign, 2)
else:
- return rbigint([_store_digit(res & MASK)], self.sign * bsign, 1)
-
+ return rbigint([_store_digit(res & MASK)], self.sign * othersign, 1)
elif digit & (digit - 1) == 0:
result = self.lqshift(ptwotable[digit])
else:
result = _muladd1(self, digit)
- result.sign = self.sign * bsign
+ result.sign = self.sign * othersign
return result
@jit.elidable
@@ -763,12 +763,10 @@
@jit.elidable
def floordiv(self, other):
- if self.sign == 1 and other.numdigits() == 1 and other.sign == 1:
- digit = other.digit(0)
- if digit == 1:
- return rbigint(self._digits[:self.size], 1, self.size)
- elif digit and digit & (digit - 1) == 0:
- return self.rshift(ptwotable[digit])
+ if other.numdigits() == 1:
+ otherint = other.digit(0) * other.sign
+ assert int_in_valid_range(otherint)
+ return self.int_floordiv(otherint)
div, mod = _divrem(self, other)
if mod.sign * other.sign == -1:
@@ -782,6 +780,37 @@
return self.floordiv(other)
@jit.elidable
+ def int_floordiv(self, iother):
+ if not int_in_valid_range(iother):
+ # Fallback to long.
+ return self.floordiv(rbigint.fromint(iother))
+
+ if iother == 0:
+ raise ZeroDivisionError("long division by zero")
+
+ digit = abs(iother)
+ assert digit > 0
+
+ if self.sign == 1 and iother > 0:
+ if digit == 1:
+ return self
+ elif digit & (digit - 1) == 0:
+ return self.rqshift(ptwotable[digit])
+
+ div, mod = _divrem1(self, digit)
+
+ if mod != 0 and self.sign * intsign(iother) == -1:
+ if div.sign == 0:
+ return ONENEGATIVERBIGINT
+ div = div.int_add(1)
+ div.sign = self.sign * intsign(iother)
+ div._normalize()
+ return div
+
+ def int_div(self, iother):
+ return self.int_floordiv(iother)
+
+ @jit.elidable
def mod(self, other):
if other.sign == 0:
raise ZeroDivisionError("long division or modulo by zero")
@@ -799,50 +828,50 @@
return mod
@jit.elidable
- def int_mod(self, other):
- if other == 0:
+ def int_mod(self, iother):
+ if iother == 0:
raise ZeroDivisionError("long division or modulo by zero")
if self.sign == 0:
return NULLRBIGINT
- elif not int_in_valid_range(other):
+ elif not int_in_valid_range(iother):
# Fallback to long.
- return self.mod(rbigint.fromint(other))
+ return self.mod(rbigint.fromint(iother))
if 1: # preserve indentation to preserve history
- digit = abs(other)
+ digit = abs(iother)
if digit == 1:
return NULLRBIGINT
elif digit == 2:
modm = self.digit(0) & 1
if modm:
- return ONENEGATIVERBIGINT if other < 0 else ONERBIGINT
+ return ONENEGATIVERBIGINT if iother < 0 else ONERBIGINT
return NULLRBIGINT
elif digit & (digit - 1) == 0:
mod = self.int_and_(digit - 1)
else:
# Perform
- size = self.numdigits() - 1
+ size = UDIGIT_TYPE(self.numdigits() - 1)
if size > 0:
- rem = self.widedigit(size)
- size -= 1
- while size >= 0:
- rem = ((rem << SHIFT) + self.widedigit(size)) % digit
+ wrem = self.widedigit(size)
+ while size > 0:
size -= 1
+ wrem = ((wrem << SHIFT) | self.digit(size)) % digit
+ rem = _store_digit(wrem)
else:
- rem = self.digit(0) % digit
+ rem = _store_digit(self.digit(0) % digit)
if rem == 0:
return NULLRBIGINT
- mod = rbigint([_store_digit(rem)], -1 if self.sign < 0 else 1, 1)
-
- if mod.sign * (-1 if other < 0 else 1) == -1:
- mod = mod.int_add(other)
+ mod = rbigint([rem], -1 if self.sign < 0 else 1, 1)
+
+ if mod.sign * intsign(iother) == -1:
+ mod = mod.int_add(iother)
return mod
@jit.elidable
- def divmod(v, w):
+ def divmod(self, other):
"""
The / and % operators are now defined in terms of divmod().
The expression a mod b has the value a - b*floor(a/b).
@@ -859,46 +888,78 @@
have different signs. We then subtract one from the 'div'
part of the outcome to keep the invariant intact.
"""
- div, mod = _divrem(v, w)
- if mod.sign * w.sign == -1:
- mod = mod.add(w)
+ div, mod = _divrem(self, other)
+ if mod.sign * other.sign == -1:
+ mod = mod.add(other)
if div.sign == 0:
return ONENEGATIVERBIGINT, mod
div = div.int_sub(1)
return div, mod
@jit.elidable
- def pow(a, b, c=None):
+ def int_divmod(self, iother):
+ """ Divmod with int """
+
+ if iother == 0:
+ raise ZeroDivisionError("long division or modulo by zero")
+
+ wsign = intsign(iother)
+ if not int_in_valid_range(iother) or (wsign == -1 and self.sign != wsign):
+ # Just fallback.
+ return self.divmod(rbigint.fromint(iother))
+
+ digit = abs(iother)
+ assert digit > 0
+
+ div, mod = _divrem1(self, digit)
+ # _divrem1 doesn't fix the sign
+ if div.size == 1 and div._digits[0] == NULLDIGIT:
+ div.sign = 0
+ else:
+ div.sign = self.sign * wsign
+ if self.sign < 0:
+ mod = -mod
+ if mod and self.sign * wsign == -1:
+ mod += iother
+ if div.sign == 0:
+ div = ONENEGATIVERBIGINT
+ else:
+ div = div.int_sub(1)
+ mod = rbigint.fromint(mod)
+ return div, mod
+
+ @jit.elidable
+ def pow(self, other, modulus=None):
negativeOutput = False # if x<0 return negative output
# 5-ary values. If the exponent is large enough, table is
- # precomputed so that table[i] == a**i % c for i in range(32).
+ # precomputed so that table[i] == self**i % modulus for i in range(32).
# python translation: the table is computed when needed.
- if b.sign < 0: # if exponent is negative
- if c is not None:
+ if other.sign < 0: # if exponent is negative
+ if modulus is not None:
raise TypeError(
"pow() 2nd argument "
"cannot be negative when 3rd argument specified")
# XXX failed to implement
raise ValueError("bigint pow() too negative")
- size_b = b.numdigits()
-
- if c is not None:
- if c.sign == 0:
+ size_b = UDIGIT_TYPE(other.numdigits())
+
+ if modulus is not None:
+ if modulus.sign == 0:
raise ValueError("pow() 3rd argument cannot be 0")
# if modulus < 0:
# negativeOutput = True
# modulus = -modulus
- if c.sign < 0:
+ if modulus.sign < 0:
negativeOutput = True
- c = c.neg()
+ modulus = modulus.neg()
# if modulus == 1:
# return 0
- if c.numdigits() == 1 and c._digits[0] == ONEDIGIT:
+ if modulus.numdigits() == 1 and modulus._digits[0] == ONEDIGIT:
return NULLRBIGINT
# Reduce base by modulus in some cases:
@@ -910,63 +971,61 @@
# base % modulus instead.
# We could _always_ do this reduction, but mod() isn't cheap,
# so we only do it when it buys something.
- if a.sign < 0 or a.numdigits() > c.numdigits():
- a = a.mod(c)
-
- elif b.sign == 0:
+ if self.sign < 0 or self.numdigits() > modulus.numdigits():
+ self = self.mod(modulus)
+ elif other.sign == 0:
return ONERBIGINT
- elif a.sign == 0:
+ elif self.sign == 0:
return NULLRBIGINT
elif size_b == 1:
- if b._digits[0] == NULLDIGIT:
- return ONERBIGINT if a.sign == 1 else ONENEGATIVERBIGINT
- elif b._digits[0] == ONEDIGIT:
- return a
- elif a.numdigits() == 1:
- adigit = a.digit(0)
- digit = b.digit(0)
+ if other._digits[0] == ONEDIGIT:
+ return self
+ elif self.numdigits() == 1 and modulus is None:
+ adigit = self.digit(0)
+ digit = other.digit(0)
if adigit == 1:
- if a.sign == -1 and digit % 2:
+ if self.sign == -1 and digit % 2:
return ONENEGATIVERBIGINT
return ONERBIGINT
elif adigit & (adigit - 1) == 0:
- ret = a.lshift(((digit-1)*(ptwotable[adigit]-1)) + digit-1)
- if a.sign == -1 and not digit % 2:
+ ret = self.lshift(((digit-1)*(ptwotable[adigit]-1)) + digit-1)
+ if self.sign == -1 and not digit % 2:
ret.sign = 1
return ret
- # At this point a, b, and c are guaranteed non-negative UNLESS
- # c is NULL, in which case a may be negative. */
-
- z = rbigint([ONEDIGIT], 1, 1)
+ # At this point self, other, and modulus are guaranteed non-negative UNLESS
+ # modulus is NULL, in which case self may be negative. */
+
+ z = ONERBIGINT
# python adaptation: moved macros REDUCE(X) and MULT(X, Y, result)
# into helper function result = _help_mult(x, y, c)
if size_b <= FIVEARY_CUTOFF:
# Left-to-right binary exponentiation (HAC Algorithm 14.79)
# http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
- size_b -= 1
- while size_b >= 0:
- bi = b.digit(size_b)
+
+ while size_b > 0:
+ size_b -= 1
+ bi = other.digit(size_b)
j = 1 << (SHIFT-1)
while j != 0:
- z = _help_mult(z, z, c)
+ z = _help_mult(z, z, modulus)
if bi & j:
- z = _help_mult(z, a, c)
+ z = _help_mult(z, self, modulus)
j >>= 1
- size_b -= 1
+
else:
# Left-to-right 5-ary exponentiation (HAC Algorithm 14.82)
- # This is only useful in the case where c != None.
+ # This is only useful in the case where modulus != None.
# z still holds 1L
table = [z] * 32
table[0] = z
for i in range(1, 32):
- table[i] = _help_mult(table[i-1], a, c)
+ table[i] = _help_mult(table[i-1], self, modulus)
# Note that here SHIFT is not a multiple of 5. The difficulty
- # is to extract 5 bits at a time from 'b', starting from the
+ # is to extract 5 bits at a time from 'other', starting from the
# most significant digits, so that at the end of the algorithm
# it falls exactly to zero.
# m = max number of bits = i * SHIFT
@@ -985,37 +1044,120 @@
index = (accum >> j) & 0x1f
else:
# 'accum' does not have enough digit.
- # must get the next digit from 'b' in order to complete
+ # must get the next digit from 'other' in order to complete
if size_b == 0:
break # Done
size_b -= 1
assert size_b >= 0
- bi = b.udigit(size_b)
+ bi = other.udigit(size_b)
index = ((accum << (-j)) | (bi >> (j+SHIFT))) & 0x1f
accum = bi
j += SHIFT
#
for k in range(5):
- z = _help_mult(z, z, c)
+ z = _help_mult(z, z, modulus)
if index:
- z = _help_mult(z, table[index], c)
+ z = _help_mult(z, table[index], modulus)
#
assert j == -5
if negativeOutput and z.sign != 0:
- z = z.sub(c)
+ z = z.sub(modulus)
+ return z
+
+ @jit.elidable
+ def int_pow(self, iother, modulus=None):
+ negativeOutput = False # if x<0 return negative output
+
+ # 5-ary values. If the exponent is large enough, table is
+ # precomputed so that table[i] == self**i % modulus for i in range(32).
+ # python translation: the table is computed when needed.
+
+ if iother < 0: # if exponent is negative
+ if modulus is not None:
+ raise TypeError(
+ "pow() 2nd argument "
+ "cannot be negative when 3rd argument specified")
+ # XXX failed to implement
+ raise ValueError("bigint pow() too negative")
+
+ assert iother >= 0
+ if modulus is not None:
+ if modulus.sign == 0:
+ raise ValueError("pow() 3rd argument cannot be 0")
+
+ # if modulus < 0:
+ # negativeOutput = True
+ # modulus = -modulus
+ if modulus.sign < 0:
+ negativeOutput = True
+ modulus = modulus.neg()
+
+ # if modulus == 1:
+ # return 0
+ if modulus.numdigits() == 1 and modulus._digits[0] == ONEDIGIT:
+ return NULLRBIGINT
+
+ # Reduce base by modulus in some cases:
+ # 1. If base < 0. Forcing the base non-neg makes things easier.
+ # 2. If base is obviously larger than the modulus. The "small
+ # exponent" case later can multiply directly by base repeatedly,
+ # while the "large exponent" case multiplies directly by base 31
+ # times. It can be unboundedly faster to multiply by
+ # base % modulus instead.
+ # We could _always_ do this reduction, but mod() isn't cheap,
+ # so we only do it when it buys something.
+ if self.sign < 0 or self.numdigits() > modulus.numdigits():
+ self = self.mod(modulus)
+ elif iother == 0:
+ return ONERBIGINT
+ elif self.sign == 0:
+ return NULLRBIGINT
+ elif iother == 1:
+ return self
+ elif self.numdigits() == 1:
+ adigit = self.digit(0)
+ if adigit == 1:
+ if self.sign == -1 and iother % 2:
+ return ONENEGATIVERBIGINT
+ return ONERBIGINT
+ elif adigit & (adigit - 1) == 0:
+ ret = self.lshift(((iother-1)*(ptwotable[adigit]-1)) + iother-1)
+ if self.sign == -1 and not iother % 2:
+ ret.sign = 1
+ return ret
+
+ # At this point self, iother, and modulus are guaranteed non-negative UNLESS
+ # modulus is NULL, in which case self may be negative. */
+
+ z = ONERBIGINT
+
+ # python adaptation: moved macros REDUCE(X) and MULT(X, Y, result)
+ # into helper function result = _help_mult(x, y, modulus)
+ # Left-to-right binary exponentiation (HAC Algorithm 14.79)
+ # http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
+ j = 1 << (SHIFT-1)
+
+ while j != 0:
+ z = _help_mult(z, z, modulus)
+ if iother & j:
+ z = _help_mult(z, self, modulus)
+ j >>= 1
+
+ if negativeOutput and z.sign != 0:
+ z = z.sub(modulus)
return z
@jit.elidable
def neg(self):
- return rbigint(self._digits, -self.sign, self.size)
+ return rbigint(self._digits, -self.sign, self.numdigits())
@jit.elidable
def abs(self):
if self.sign != -1:
return self
- return rbigint(self._digits, 1, self.size)
+ return rbigint(self._digits, 1, self.numdigits())
@jit.elidable
def invert(self): #Implement ~x as -(x + 1)
@@ -1041,15 +1183,15 @@
# So we can avoid problems with eq, AND avoid the need for normalize.
if self.sign == 0:
return self
- return rbigint([NULLDIGIT] * wordshift + self._digits, self.sign, self.size + wordshift)
+ return rbigint([NULLDIGIT] * wordshift + self._digits, self.sign, self.numdigits() + wordshift)
oldsize = self.numdigits()
newsize = oldsize + wordshift + 1
z = rbigint([NULLDIGIT] * newsize, self.sign, newsize)
- accum = _widen_digit(0)
+ accum = _unsigned_widen_digit(0)
j = 0
while j < oldsize:
- accum += self.widedigit(j) << remshift
+ accum += self.uwidedigit(j) << remshift
z.setdigit(wordshift, accum)
accum >>= SHIFT
wordshift += 1
@@ -1061,7 +1203,7 @@
z._normalize()
return z
- lshift._always_inline_ = True # It's so fast that it's always benefitial.
+ lshift._always_inline_ = True # It's so fast that it's always beneficial.
@jit.elidable
def lqshift(self, int_other):
@@ -1071,17 +1213,17 @@
oldsize = self.numdigits()
z = rbigint([NULLDIGIT] * (oldsize + 1), self.sign, (oldsize + 1))
- accum = _widen_digit(0)
+ accum = _unsigned_widen_digit(0)
i = 0
while i < oldsize:
- accum += self.widedigit(i) << int_other
+ accum += self.uwidedigit(i) << int_other
z.setdigit(i, accum)
accum >>= SHIFT
i += 1
z.setdigit(oldsize, accum)
z._normalize()
return z
- lqshift._always_inline_ = True # It's so fast that it's always benefitial.
+ lqshift._always_inline_ = True # It's so fast that it's always beneficial.
@jit.elidable
def rshift(self, int_other, dont_invert=False):
@@ -1112,6 +1254,31 @@
z._normalize()
return z
rshift._always_inline_ = 'try' # It's so fast that it's always benefitial.
+
+ @jit.elidable
+ def rqshift(self, int_other):
+ wordshift = int_other / SHIFT
+ loshift = int_other % SHIFT
+ newsize = self.numdigits() - wordshift
+
+ if newsize <= 0:
+ return NULLRBIGINT
+
+ hishift = SHIFT - loshift
+ z = rbigint([NULLDIGIT] * newsize, self.sign, newsize)
+ i = 0
+
+ while i < newsize:
+ digit = self.udigit(wordshift)
+ newdigit = (digit >> loshift)
+ if i+1 < newsize:
+ newdigit |= (self.udigit(wordshift+1) << hishift)
+ z.setdigit(i, newdigit)
+ i += 1
+ wordshift += 1
+ z._normalize()
+ return z
+ rshift._always_inline_ = 'try' # It's so fast that it's always beneficial.
@jit.elidable
def abs_rshift_and_mask(self, bigshiftcount, mask):
@@ -1167,24 +1334,24 @@
return _bitwise(self, '&', other)
@jit.elidable
- def int_and_(self, other):
- return _int_bitwise(self, '&', other)
+ def int_and_(self, iother):
+ return _int_bitwise(self, '&', iother)
@jit.elidable
def xor(self, other):
return _bitwise(self, '^', other)
@jit.elidable
- def int_xor(self, other):
- return _int_bitwise(self, '^', other)
+ def int_xor(self, iother):
+ return _int_bitwise(self, '^', iother)
@jit.elidable
def or_(self, other):
return _bitwise(self, '|', other)
@jit.elidable
- def int_or_(self, other):
- return _int_bitwise(self, '|', other)
+ def int_or_(self, iother):
+ return _int_bitwise(self, '|', iother)
@jit.elidable
def oct(self):
@@ -1218,7 +1385,10 @@
for d in digits:
l = l << SHIFT
l += intmask(d)
- return l * self.sign
+ result = l * self.sign
+ if result == 0:
+ assert self.sign == 0
+ return result
def _normalize(self):
i = self.numdigits()
@@ -1227,11 +1397,10 @@
i -= 1
assert i > 0
- if i != self.numdigits():
- self.size = i
- if self.numdigits() == 1 and self._digits[0] == NULLDIGIT:
+ self.size = i
+ if i == 1 and self._digits[0] == NULLDIGIT:
self.sign = 0
- self._digits = [NULLDIGIT]
+ self._digits = NULLDIGITS
_normalize._always_inline_ = True
@@ -1256,8 +1425,8 @@
def __repr__(self):
return "<rbigint digits=%s, sign=%s, size=%d, len=%d, %s>" % (self._digits,
- self.sign, self.size, len(self._digits),
- self.str())
+ self.sign, self.numdigits(), len(self._digits),
+ self.tolong())
ONERBIGINT = rbigint([ONEDIGIT], 1, 1)
ONENEGATIVERBIGINT = rbigint([ONEDIGIT], -1, 1)
@@ -1322,7 +1491,7 @@
if x > 0:
return digits_from_nonneg_long(x), 1
elif x == 0:
- return [NULLDIGIT], 0
+ return NULLDIGITS, 0
elif x != most_neg_value_of_same_type(x):
# normal case
return digits_from_nonneg_long(-x), -1
@@ -1340,7 +1509,7 @@
def args_from_long(x):
if x >= 0:
if x == 0:
- return [NULLDIGIT], 0
+ return NULLDIGITS, 0
else:
return digits_from_nonneg_long(x), 1
else:
@@ -1450,7 +1619,7 @@
if adigit == bdigit:
return NULLRBIGINT
-
+
return rbigint.fromint(adigit - bdigit)
z = rbigint([NULLDIGIT] * size_a, 1, size_a)
@@ -1497,11 +1666,11 @@
z = rbigint([NULLDIGIT] * (size_a + size_b), 1)
i = UDIGIT_TYPE(0)
while i < size_a:
- f = a.widedigit(i)
+ f = a.uwidedigit(i)
pz = i << 1
pa = i + 1
- carry = z.widedigit(pz) + f * f
+ carry = z.uwidedigit(pz) + f * f
z.setdigit(pz, carry)
pz += 1
carry >>= SHIFT
@@ -1511,18 +1680,18 @@
# pyramid it appears. Same as adding f<<1 once.
f <<= 1
while pa < size_a:
- carry += z.widedigit(pz) + a.widedigit(pa) * f
+ carry += z.uwidedigit(pz) + a.uwidedigit(pa) * f
pa += 1
z.setdigit(pz, carry)
pz += 1
carry >>= SHIFT
if carry:
- carry += z.widedigit(pz)
+ carry += z.udigit(pz)
z.setdigit(pz, carry)
pz += 1
carry >>= SHIFT
if carry:
- z.setdigit(pz, z.widedigit(pz) + carry)
+ z.setdigit(pz, z.udigit(pz) + carry)
assert (carry >> SHIFT) == 0
i += 1
z._normalize()
@@ -1543,29 +1712,29 @@
size_a1 = UDIGIT_TYPE(size_a - 1)
size_b1 = UDIGIT_TYPE(size_b - 1)
while i < size_a1:
- f0 = a.widedigit(i)
- f1 = a.widedigit(i + 1)
+ f0 = a.uwidedigit(i)
+ f1 = a.uwidedigit(i + 1)
pz = i
- carry = z.widedigit(pz) + b.widedigit(0) * f0
+ carry = z.uwidedigit(pz) + b.uwidedigit(0) * f0
z.setdigit(pz, carry)
pz += 1
carry >>= SHIFT
j = UDIGIT_TYPE(0)
while j < size_b1:
- # this operation does not overflow using
+ # this operation does not overflow using
# SHIFT = (LONG_BIT // 2) - 1 = B - 1; in fact before it
# carry and z.widedigit(pz) are less than 2**(B - 1);
# b.widedigit(j + 1) * f0 < (2**(B-1) - 1)**2; so
# carry + z.widedigit(pz) + b.widedigit(j + 1) * f0 +
# b.widedigit(j) * f1 < 2**(2*B - 1) - 2**B < 2**LONG)BIT - 1
- carry += z.widedigit(pz) + b.widedigit(j + 1) * f0 + \
- b.widedigit(j) * f1
+ carry += z.uwidedigit(pz) + b.uwidedigit(j + 1) * f0 + \
+ b.uwidedigit(j) * f1
z.setdigit(pz, carry)
pz += 1
carry >>= SHIFT
j += 1
# carry < 2**(B + 1) - 2
- carry += z.widedigit(pz) + b.widedigit(size_b1) * f1
+ carry += z.uwidedigit(pz) + b.uwidedigit(size_b1) * f1
z.setdigit(pz, carry)
pz += 1
carry >>= SHIFT
@@ -1576,17 +1745,17 @@
i += 2
if size_a & 1:
pz = size_a1
- f = a.widedigit(pz)
+ f = a.uwidedigit(pz)
pb = 0
- carry = _widen_digit(0)
+ carry = _unsigned_widen_digit(0)
while pb < size_b:
- carry += z.widedigit(pz) + b.widedigit(pb) * f
+ carry += z.uwidedigit(pz) + b.uwidedigit(pb) * f
pb += 1
z.setdigit(pz, carry)
pz += 1
carry >>= SHIFT
if carry:
- z.setdigit(pz, z.widedigit(pz) + carry)
+ z.setdigit(pz, z.udigit(pz) + carry)
z._normalize()
return z
@@ -1602,8 +1771,8 @@
size_lo = min(size_n, size)
# We use "or" her to avoid having a check where list can be empty in _normalize.
- lo = rbigint(n._digits[:size_lo] or [NULLDIGIT], 1)
- hi = rbigint(n._digits[size_lo:n.size] or [NULLDIGIT], 1)
+ lo = rbigint(n._digits[:size_lo] or NULLDIGITS, 1)
+ hi = rbigint(n._digits[size_lo:size_n] or NULLDIGITS, 1)
lo._normalize()
hi._normalize()
return hi, lo
@@ -1708,113 +1877,16 @@
ret._normalize()
return ret
-""" (*) Why adding t3 can't "run out of room" above.
-
-Let f(x) mean the floor of x and c(x) mean the ceiling of x. Some facts
-to start with:
-
-1. For any integer i, i = c(i/2) + f(i/2). In particular,
- bsize = c(bsize/2) + f(bsize/2).
-2. shift = f(bsize/2)
-3. asize <= bsize
-4. Since we call k_lopsided_mul if asize*2 <= bsize, asize*2 > bsize in this
- routine, so asize > bsize/2 >= f(bsize/2) in this routine.
-
-We allocated asize + bsize result digits, and add t3 into them at an offset
-of shift. This leaves asize+bsize-shift allocated digit positions for t3
-to fit into, = (by #1 and #2) asize + f(bsize/2) + c(bsize/2) - f(bsize/2) =
-asize + c(bsize/2) available digit positions.
-
-bh has c(bsize/2) digits, and bl at most f(size/2) digits. So bh+hl has
-at most c(bsize/2) digits + 1 bit.
-
-If asize == bsize, ah has c(bsize/2) digits, else ah has at most f(bsize/2)
-digits, and al has at most f(bsize/2) digits in any case. So ah+al has at
-most (asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 1 bit.
-
-The product (ah+al)*(bh+bl) therefore has at most
-
- c(bsize/2) + (asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 2 bits
-
-and we have asize + c(bsize/2) available digit positions. We need to show
-this is always enough. An instance of c(bsize/2) cancels out in both, so
-the question reduces to whether asize digits is enough to hold
-(asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 2 bits. If asize < bsize,
-then we're asking whether asize digits >= f(bsize/2) digits + 2 bits. By #4,
-asize is at least f(bsize/2)+1 digits, so this in turn reduces to whether 1
-digit is enough to hold 2 bits. This is so since SHIFT=15 >= 2. If
-asize == bsize, then we're asking whether bsize digits is enough to hold
-c(bsize/2) digits + 2 bits, or equivalently (by #1) whether f(bsize/2) digits
-is enough to hold 2 bits. This is so if bsize >= 2, which holds because
-bsize >= KARATSUBA_CUTOFF >= 2.
-
-Note that since there's always enough room for (ah+al)*(bh+bl), and that's
-clearly >= each of ah*bh and al*bl, there's always enough room to subtract
-ah*bh and al*bl too.
-"""
-
-def _k_lopsided_mul(a, b):
- # Not in use anymore, only account for like 1% performance. Perhaps if we
- # Got rid of the extra list allocation this would be more effective.
- """
- b has at least twice the digits of a, and a is big enough that Karatsuba
- would pay off *if* the inputs had balanced sizes. View b as a sequence
- of slices, each with a->ob_size digits, and multiply the slices by a,
- one at a time. This gives k_mul balanced inputs to work with, and is
- also cache-friendly (we compute one double-width slice of the result
- at a time, then move on, never bactracking except for the helpful
- single-width slice overlap between successive partial sums).
- """
- asize = a.numdigits()
- bsize = b.numdigits()
- # nbdone is # of b digits already multiplied
-
- assert asize > KARATSUBA_CUTOFF
- assert 2 * asize <= bsize
-
- # Allocate result space, and zero it out.
- ret = rbigint([NULLDIGIT] * (asize + bsize), 1)
-
- # Successive slices of b are copied into bslice.
- #bslice = rbigint([0] * asize, 1)
- # XXX we cannot pre-allocate, see comments below!
- # XXX prevent one list from being created.
- bslice = rbigint(sign=1)
-
- nbdone = 0
- while bsize > 0:
- nbtouse = min(bsize, asize)
-
- # Multiply the next slice of b by a.
-
- #bslice.digits[:nbtouse] = b.digits[nbdone : nbdone + nbtouse]
- # XXX: this would be more efficient if we adopted CPython's
- # way to store the size, instead of resizing the list!
- # XXX change the implementation, encoding length via the sign.
- bslice._digits = b._digits[nbdone : nbdone + nbtouse]
- bslice.size = nbtouse
- product = _k_mul(a, bslice)
-
- # Add into result.
- _v_iadd(ret, nbdone, ret.numdigits() - nbdone,
- product, product.numdigits())
-
- bsize -= nbtouse
- nbdone += nbtouse
-
- ret._normalize()
- return ret
-
def _inplace_divrem1(pout, pin, n):
"""
Divide bigint pin by non-zero digit n, storing quotient
in pout, and returning the remainder. It's OK for pin == pout on entry.
"""
- rem = _widen_digit(0)
+ rem = _unsigned_widen_digit(0)
assert n > 0 and n <= MASK
size = pin.numdigits() - 1
while size >= 0:
- rem = (rem << SHIFT) | pin.widedigit(size)
+ rem = (rem << SHIFT) | pin.udigit(size)
hi = rem // n
pout.setdigit(size, hi)
rem -= hi * n
@@ -1891,14 +1963,15 @@
def _muladd1(a, n, extra=0):
"""Multiply by a single digit and add a single digit, ignoring the sign.
"""
+ assert n > 0
size_a = a.numdigits()
z = rbigint([NULLDIGIT] * (size_a+1), 1)
assert extra & MASK == extra
- carry = _widen_digit(extra)
+ carry = _unsigned_widen_digit(extra)
i = 0
while i < size_a:
- carry += a.widedigit(i) * n
+ carry += a.uwidedigit(i) * n
z.setdigit(i, carry)
carry >>= SHIFT
i += 1
@@ -1912,10 +1985,10 @@
"""
carry = 0
- assert 0 <= d and d < SHIFT
+ #assert 0 <= d and d < SHIFT
i = 0
while i < m:
- acc = a.widedigit(i) << d | carry
+ acc = a.uwidedigit(i) << d | carry
z.setdigit(i, acc)
carry = acc >> SHIFT
i += 1
@@ -1927,14 +2000,14 @@
* result in z[0:m], and return the d bits shifted out of the bottom.
"""
- carry = _widen_digit(0)
- acc = _widen_digit(0)
+ carry = _unsigned_widen_digit(0)
+ acc = _unsigned_widen_digit(0)
mask = (1 << d) - 1
- assert 0 <= d and d < SHIFT
+ #assert 0 <= d and d < SHIFT
i = m-1
while i >= 0:
- acc = (carry << SHIFT) | a.widedigit(i)
+ acc = (carry << SHIFT) | a.udigit(i)
carry = acc & mask
z.setdigit(i, acc >> d)
i -= 1
@@ -1989,10 +2062,17 @@
else:
vtop = v.widedigit(j)
assert vtop <= wm1
+
vv = (vtop << SHIFT) | v.widedigit(abs(j-1))
+
+ # Hints to make division just as fast as doing it unsigned. But avoids casting to get correct results.
+ assert vv >= 0
+ assert wm1 >= 1
+
q = vv / wm1
- r = vv - wm1 * q
- while wm2 * q > ((r << SHIFT) | v.widedigit(abs(j-2))):
+ r = vv % wm1 # This seems to be slightly faster on widen digits than vv - wm1 * q.
+ vj2 = v.digit(abs(j-2))
+ while wm2 * q > ((r << SHIFT) | vj2):
q -= 1
r += wm1
@@ -2059,6 +2139,36 @@
rem.sign = - rem.sign
return z, rem
+def _x_int_lt(a, b, eq=False):
+ """ Compare bigint a with int b: a < b, or a <= b when eq is True. """
+ osign = 1
+ if b == 0:
+ osign = 0
+ elif b < 0:
+ osign = -1
+
+ if a.sign > osign:
+ return False
+ elif a.sign < osign:
+ return True
+
+ digits = a.numdigits()
+
+ if digits > 1:
+ if osign == 1:
+ return False
+ else:
+ return True
+
+ d1 = a.sign * a.digit(0)
+ if eq:
+ if d1 <= b:
+ return True
+ else:
+ if d1 < b:
+ return True
+ return False
+
# ______________ conversions to double _______________
def _AsScaledDouble(v):
@@ -2764,7 +2874,7 @@
elif s[p] == '+':
p += 1
- a = rbigint()
+ a = NULLRBIGINT
tens = 1
dig = 0
ord0 = ord('0')
@@ -2785,7 +2895,7 @@
base = parser.base
if (base & (base - 1)) == 0 and base >= 2:
return parse_string_from_binary_base(parser)
- a = rbigint()
+ a = NULLRBIGINT
digitmax = BASE_MAX[base]
tens, dig = 1, 0
while True:
diff --git a/rpython/rlib/test/test_rbigint.py b/rpython/rlib/test/test_rbigint.py
--- a/rpython/rlib/test/test_rbigint.py
+++ b/rpython/rlib/test/test_rbigint.py
@@ -95,6 +95,46 @@
r2 = op1 // op2
assert r1.tolong() == r2
+ def test_int_floordiv(self):
+ x = 1000L
+ r = rbigint.fromlong(x)
+ r2 = r.int_floordiv(10)
+ assert r2.tolong() == 100L
+
+ for op1 in gen_signs(long_vals):
+ for op2 in signed_int_vals:
+ if not op2:
+ continue
+ rl_op1 = rbigint.fromlong(op1)
+ r1 = rl_op1.int_floordiv(op2)
+ r2 = op1 // op2
+ assert r1.tolong() == r2
+
+ assert pytest.raises(ZeroDivisionError, r.int_floordiv, 0)
+
+ # Error pointed out by Armin Rigo
+ n = sys.maxint+1
+ r = rbigint.fromlong(n)
+ assert r.int_floordiv(int(-n)).tolong() == -1L
+
+ for x in int_vals:
+ if not x:
+ continue
+ r = rbigint.fromlong(x)
+ rn = rbigint.fromlong(-x)
+ res = r.int_floordiv(x)
+ res2 = r.int_floordiv(-x)
+ res3 = rn.int_floordiv(x)
+ assert res.tolong() == 1L
+ assert res2.tolong() == -1L
+ assert res3.tolong() == -1L
+
+ def test_floordiv2(self):
+ n1 = rbigint.fromlong(sys.maxint + 1)
+ n2 = rbigint.fromlong(-(sys.maxint + 1))
+ assert n1.floordiv(n2).tolong() == -1L
+ assert n2.floordiv(n1).tolong() == -1L
+
def test_truediv(self):
for op1 in gen_signs(long_vals_not_too_big):
rl_op1 = rbigint.fromlong(op1)
@@ -185,9 +225,26 @@
r4 = pow(op1, op2, op3)
assert r3.tolong() == r4
+ def test_int_pow(self):
+ for op1 in gen_signs(long_vals_not_too_big):
+ rl_op1 = rbigint.fromlong(op1)
+ for op2 in [0, 1, 2, 8, 9, 10, 11, 127, 128, 129]:
+ r1 = rl_op1.int_pow(op2)
+ r2 = op1 ** op2
+ assert r1.tolong() == r2
+
+ for op3 in gen_signs(long_vals_not_too_big):
+ if not op3:
+ continue
+ r3 = rl_op1.int_pow(op2, rbigint.fromlong(op3))
+ r4 = pow(op1, op2, op3)
+ print op1, op2, op3
+ assert r3.tolong() == r4
+
def test_pow_raises(self):
r1 = rbigint.fromint(2)
r0 = rbigint.fromint(0)
+ py.test.raises(ValueError, r1.int_pow, 2, r0)
py.test.raises(ValueError, r1.pow, r1, r0)
def test_touint(self):
@@ -601,6 +658,9 @@
# test special optimization case in rshift:
assert rbigint.fromlong(-(1 << 100)).rshift(5).tolong() == -(1 << 100) >> 5
+ # Check value accuracy.
+ assert rbigint.fromlong(18446744073709551615L).rshift(1).tolong() == 18446744073709551615L >> 1
+
def test_qshift(self):
for x in range(10):
for y in range(1, 161, 16):
@@ -610,11 +670,18 @@
for z in range(1, 31):
res1 = f1.lqshift(z).tolong()
+ res2 = f1.rqshift(z).tolong()
res3 = nf1.lqshift(z).tolong()
assert res1 == num << z
+ assert res2 == num >> z
assert res3 == -num << z
+ # Large digit
+ for x in range((1 << SHIFT) - 10, (1 << SHIFT) + 10):
+ f1 = rbigint.fromlong(x)
+ assert f1.rqshift(SHIFT).tolong() == x >> SHIFT
+ assert f1.rqshift(SHIFT+1).tolong() == x >> (SHIFT+1)
def test_from_list_n_bits(self):
for x in ([3L ** 30L, 5L ** 20L, 7 ** 300] +
@@ -864,6 +931,27 @@
assert rem.tolong() == _rem
+ def test_int_divmod(self):
+ for x in long_vals:
+ for y in int_vals + [-sys.maxint-1]:
+ if not y:
+ continue
+ for sx, sy in (1, 1), (1, -1), (-1, -1), (-1, 1):
+ sx *= x
+ sy *= y
+ if sy == sys.maxint + 1:
+ continue
+ f1 = rbigint.fromlong(sx)
+ div, rem = f1.int_divmod(sy)
+ div1, rem1 = f1.divmod(rbigint.fromlong(sy))
+ _div, _rem = divmod(sx, sy)
+ print sx, sy, " | ", div.tolong(), rem.tolong()
+ assert div1.tolong() == _div
+ assert rem1.tolong() == _rem
+ assert div.tolong() == _div
+ assert rem.tolong() == _rem
+ py.test.raises(ZeroDivisionError, rbigint.fromlong(x).int_divmod, 0)
+
# testing Karatsuba stuff
def test__v_iadd(self):
f1 = bigint([lobj.MASK] * 10, 1)
@@ -1067,8 +1155,14 @@
except Exception as e:
pytest.raises(type(e), f1.pow, f2, f3)
else:
- v = f1.pow(f2, f3)
- assert v.tolong() == res
+ v1 = f1.pow(f2, f3)
+ try:
+ v2 = f1.int_pow(f2.toint(), f3)
+ except OverflowError:
+ pass
+ else:
+ assert v2.tolong() == res
+ assert v1.tolong() == res
@given(biglongs, biglongs)
@example(510439143470502793407446782273075179618477362188870662225920,
@@ -1088,6 +1182,18 @@
a, b = f1.divmod(f2)
assert (a.tolong(), b.tolong()) == res
+ @given(biglongs, ints)
+ def test_int_divmod(self, x, iy):
+ f1 = rbigint.fromlong(x)
+ try:
+ res = divmod(x, iy)
+ except Exception as e:
+ pytest.raises(type(e), f1.int_divmod, iy)
+ else:
+ print x, iy
+ a, b = f1.int_divmod(iy)
+ assert (a.tolong(), b.tolong()) == res
+
@given(longs)
def test_hash(self, x):
# hash of large integers: should be equal to the hash of the
@@ -1118,10 +1224,34 @@
assert ra.truediv(rb) == a / b
@given(longs, longs)
- def test_bitwise(self, x, y):
+ def test_bitwise_and_mul(self, x, y):
lx = rbigint.fromlong(x)
ly = rbigint.fromlong(y)
- for mod in "xor and_ or_".split():
- res1 = getattr(lx, mod)(ly).tolong()
+ for mod in "xor and_ or_ mul".split():
+ res1a = getattr(lx, mod)(ly).tolong()
+ res1b = getattr(ly, mod)(lx).tolong()
+ res2 = getattr(operator, mod)(x, y)
+ assert res1a == res2
+
+ @given(longs, ints)
+ def test_int_bitwise_and_mul(self, x, y):
+ lx = rbigint.fromlong(x)
+ for mod in "xor and_ or_ mul".split():
+ res1 = getattr(lx, 'int_' + mod)(y).tolong()
res2 = getattr(operator, mod)(x, y)
assert res1 == res2
+
+ @given(longs, ints)
+ def test_int_comparison(self, x, y):
+ lx = rbigint.fromlong(x)
+ assert lx.int_lt(y) == (x < y)
+ assert lx.int_eq(y) == (x == y)
+ assert lx.int_le(y) == (x <= y)
+
+ @given(longs, longs)
+ def test_int_comparison(self, x, y):
+ lx = rbigint.fromlong(x)
+ ly = rbigint.fromlong(y)
+ assert lx.lt(ly) == (x < y)
+ assert lx.eq(ly) == (x == y)
+ assert lx.le(ly) == (x <= y)
diff --git a/rpython/rtyper/lltypesystem/ll2ctypes.py b/rpython/rtyper/lltypesystem/ll2ctypes.py
--- a/rpython/rtyper/lltypesystem/ll2ctypes.py
+++ b/rpython/rtyper/lltypesystem/ll2ctypes.py
@@ -175,7 +175,16 @@
if res >= (1 << 127):
res -= 1 << 128
return res
+ class c_uint128(ctypes.Array): # based on 2 ulongs
+ _type_ = ctypes.c_uint64
+ _length_ = 2
+ @property
+ def value(self):
+ res = self[0] | (self[1] << 64)
+ return res
+
_ctypes_cache[rffi.__INT128_T] = c_int128
+ _ctypes_cache[rffi.__UINT128_T] = c_uint128
# for unicode strings, do not use ctypes.c_wchar because ctypes
# automatically converts arrays into unicode strings.
diff --git a/rpython/rtyper/lltypesystem/lloperation.py b/rpython/rtyper/lltypesystem/lloperation.py
--- a/rpython/rtyper/lltypesystem/lloperation.py
+++ b/rpython/rtyper/lltypesystem/lloperation.py
@@ -324,6 +324,26 @@
'lllong_rshift': LLOp(canfold=True), # args (r_longlonglong, int)
'lllong_xor': LLOp(canfold=True),
+ 'ulllong_is_true': LLOp(canfold=True),
+ 'ulllong_invert': LLOp(canfold=True),
+
+ 'ulllong_add': LLOp(canfold=True),
+ 'ulllong_sub': LLOp(canfold=True),
+ 'ulllong_mul': LLOp(canfold=True),
+ 'ulllong_floordiv': LLOp(canfold=True),
+ 'ulllong_mod': LLOp(canfold=True),
+ 'ulllong_lt': LLOp(canfold=True),
+ 'ulllong_le': LLOp(canfold=True),
+ 'ulllong_eq': LLOp(canfold=True),
+ 'ulllong_ne': LLOp(canfold=True),
+ 'ulllong_gt': LLOp(canfold=True),
+ 'ulllong_ge': LLOp(canfold=True),
+ 'ulllong_and': LLOp(canfold=True),
+ 'ulllong_or': LLOp(canfold=True),
+ 'ulllong_lshift': LLOp(canfold=True), # args (r_ulonglonglong, int)
+ 'ulllong_rshift': LLOp(canfold=True), # args (r_ulonglonglong, int)
+ 'ulllong_xor': LLOp(canfold=True),
+
'cast_primitive': LLOp(canfold=True),
'cast_bool_to_int': LLOp(canfold=True),
'cast_bool_to_uint': LLOp(canfold=True),
diff --git a/rpython/rtyper/lltypesystem/lltype.py b/rpython/rtyper/lltypesystem/lltype.py
--- a/rpython/rtyper/lltypesystem/lltype.py
+++ b/rpython/rtyper/lltypesystem/lltype.py
@@ -8,7 +8,7 @@
from rpython.rlib.rarithmetic import (
base_int, intmask, is_emulated_long, is_valid_int, longlonglongmask,
longlongmask, maxint, normalizedinttype, r_int, r_longfloat, r_longlong,
- r_longlonglong, r_singlefloat, r_uint, r_ulonglong)
+ r_longlonglong, r_singlefloat, r_uint, r_ulonglong, r_ulonglonglong)
from rpython.rtyper.extregistry import ExtRegistryEntry
from rpython.tool import leakfinder
from rpython.tool.identity_dict import identity_dict
@@ -676,6 +676,7 @@
_numbertypes[r_int] = _numbertypes[int]
_numbertypes[r_longlonglong] = Number("SignedLongLongLong", r_longlonglong,
longlonglongmask)
+
More information about the pypy-commit
mailing list