[pypy-commit] pypy math-improvements: Test + fix for bugs found by Armin Rigo

stian pypy.commits at gmail.com
Wed Oct 18 16:58:19 EDT 2017


Author: stian
Branch: math-improvements
Changeset: r92792:06129c0b5e0e
Date: 2017-10-18 22:13 +0200
http://bitbucket.org/pypy/pypy/changeset/06129c0b5e0e/

Log:	Test + fix for bugs found by Armin Rigo

diff --git a/pypy/objspace/std/test/test_longobject.py b/pypy/objspace/std/test/test_longobject.py
--- a/pypy/objspace/std/test/test_longobject.py
+++ b/pypy/objspace/std/test/test_longobject.py
@@ -2,7 +2,6 @@
 from pypy.objspace.std import longobject as lobj
 from rpython.rlib.rbigint import rbigint
 
-
 class TestW_LongObject:
     def test_bigint_w(self):
         space = self.space
@@ -70,6 +69,23 @@
         a = x // 10000000L
         assert a == 3L
 
+    def test_int_floordiv(self):
+        import sys
+        # Positive long // positive int, exact quotient.
+        x = 3000L
+        a = x // 1000
+        assert a == 3L
+        # Exact division by a negative int divisor.
+        x = 3000L
+        a = x // -1000
+        assert a == -3L
+        # Division by int zero must raise (x is read by the string).
+        x = 3000L
+        raises(ZeroDivisionError, "x // 0")
+        # int(-n) is -sys.maxint-1, whose magnitude exceeds sys.maxint.
+        n = sys.maxint+1
+        assert n / int(-n) == -1L
+
     def test_numerator_denominator(self):
         assert (1L).numerator == 1L
         assert (1L).denominator == 1L
@@ -208,6 +224,11 @@
         check_division(x, y)
         raises(ZeroDivisionError, "x // 0L")
 
+    def test_int_divmod(self):
+        q, r = divmod(100L, 11)  # long divmod int: 100 == 9 * 11 + 1
+        assert q == 9L
+        assert r == 1L
+        
     def test_format(self):
         assert repr(12345678901234567890) == '12345678901234567890L'
         assert str(12345678901234567890) == '12345678901234567890'
@@ -386,3 +407,4 @@
         n = "a" * size
         expected = (2 << (size * 4)) // 3
         assert long(n, 16) == expected
+
diff --git a/rpython/rlib/rbigint.py b/rpython/rlib/rbigint.py
--- a/rpython/rlib/rbigint.py
+++ b/rpython/rlib/rbigint.py
@@ -145,6 +145,7 @@
         make_sure_not_resized(digits)
         self._digits = digits
         assert size >= 0
+        
         self.size = size or len(digits)
         self.sign = sign
 
@@ -183,7 +184,9 @@
     setdigit._always_inline_ = True
 
     def numdigits(self):
-        return self.size
+        w = self.size
+        assert w > 0
+        return w
     numdigits._always_inline_ = True
 
     @staticmethod
@@ -510,7 +513,6 @@
     @jit.elidable
     def int_eq(self, other):
         """ eq with int """
-        
         if not int_in_valid_range(other):
             # Fallback to Long. 
             return self.eq(rbigint.fromint(other))
@@ -657,7 +659,7 @@
         if other.sign == 0:
             return self
         elif self.sign == 0:
-            return rbigint(other._digits[:other.size], -other.sign, other.size)
+            return rbigint(other._digits[:other.numdigits()], -other.sign, other.numdigits())
         elif self.sign == other.sign:
             result = _x_sub(self, other)
         else:
@@ -698,7 +700,7 @@
             if a._digits[0] == NULLDIGIT:
                 return NULLRBIGINT
             elif a._digits[0] == ONEDIGIT:
-                return rbigint(b._digits[:b.size], a.sign * b.sign, b.size)
+                return rbigint(b._digits[:b.numdigits()], a.sign * b.sign, b.numdigits())
             elif bsize == 1:
                 res = b.widedigit(0) * a.widedigit(0)
                 carry = res >> SHIFT
@@ -740,7 +742,7 @@
         bsign = -1 if b < 0 else 1
 
         if digit == 1:
-            return rbigint(self._digits[:self.size], self.sign * bsign, asize)
+            return rbigint(self._digits[:self.numdigits()], self.sign * bsign, asize)
         elif asize == 1:
             res = self.widedigit(0) * digit
             carry = res >> SHIFT
@@ -767,7 +769,7 @@
         if self.sign == 1 and other.numdigits() == 1 and other.sign == 1:
             digit = other.digit(0)
             if digit == 1:
-                return rbigint(self._digits[:self.size], 1, self.size)
+                return rbigint(self._digits[:self.numdigits()], 1, self.numdigits())
             elif digit and digit & (digit - 1) == 0:
                 return self.rshift(ptwotable[digit])
 
@@ -781,7 +783,37 @@
 
     def div(self, other):
         return self.floordiv(other)
-
+
+    @jit.elidable
+    def int_floordiv(self, b):
+        """ Floor division self // b, where b is a machine int. """
+        if not int_in_valid_range(b):
+            # Fallback to long.
+            return self.floordiv(rbigint.fromint(b))
+        if b == 0:
+            raise ZeroDivisionError("long division by zero")
+        bsign = -1 if b < 0 else 1
+        digit = abs(b)
+        assert digit > 0
+        if self.sign == 1 and b > 0:
+            if digit == 1:
+                return self
+            elif digit & (digit - 1) == 0:
+                return self.rshift(ptwotable[digit])
+        # _divrem1 does not apply the signs; they are fixed up below.
+        div, mod = _divrem1(self, digit)
+        if mod != 0 and self.sign * bsign == -1:
+            # Inexact negative quotient: floor rounds away from zero.
+            if div.sign == 0:
+                return ONENEGATIVERBIGINT
+            div = div.int_add(1)
+        if div.sign != 0:  # a zero quotient (|self| < |b|) keeps sign 0
+            div.sign = self.sign * bsign
+        return div
+
+    def int_div(self, other):
+        return self.int_floordiv(other)
+
     @jit.elidable
     def mod(self, other):
         if self.sign == 0:
@@ -888,6 +920,30 @@
         return div, mod
 
     @jit.elidable
+    def int_divmod(v, w):
+        """ Divmod with int: returns the pair (v // w, v % w). """
+
+        if w == 0:
+            raise ZeroDivisionError("long division or modulo by zero")
+
+        wsign = (-1 if w < 0 else 1)
+        if not int_in_valid_range(w) or v.sign != wsign:
+            # Divrem1 doesn't deal with the sign difference. Instead of
+            # having yet another copy, just fall back.
+            return v.divmod(rbigint.fromint(w))
+
+        # v and w have the same sign here, so the floor quotient is
+        # |v| // |w| (non-negative) and the remainder takes w's sign.
+        digit = abs(w)
+        assert digit > 0
+        div, mod = _divrem1(v, digit)
+        mod = rbigint.fromint(mod * wsign)
+        if div.sign != 0:
+            # A zero quotient (|v| < |w|) must keep sign 0.
+            div.sign = v.sign * wsign
+        return div, mod
+
+    @jit.elidable
     def pow(a, b, c=None):
         negativeOutput = False  # if x<0 return negative output
 
@@ -1029,13 +1085,13 @@
 
     @jit.elidable
     def neg(self):
-        return rbigint(self._digits, -self.sign, self.size)
+        return rbigint(self._digits, -self.sign, self.numdigits())
 
     @jit.elidable
     def abs(self):
         if self.sign != -1:
             return self
-        return rbigint(self._digits, 1, self.size)
+        return rbigint(self._digits, 1, self.numdigits())
 
     @jit.elidable
     def invert(self): #Implement ~x as -(x + 1)
@@ -1061,7 +1117,7 @@
             # So we can avoid problems with eq, AND avoid the need for normalize.
             if self.sign == 0:
                 return self
-            return rbigint([NULLDIGIT] * wordshift + self._digits, self.sign, self.size + wordshift)
+            return rbigint([NULLDIGIT] * wordshift + self._digits, self.sign, self.numdigits() + wordshift)
 
         oldsize = self.numdigits()
         newsize = oldsize + wordshift + 1
@@ -1276,7 +1332,7 @@
 
     def __repr__(self):
         return "<rbigint digits=%s, sign=%s, size=%d, len=%d, %s>" % (self._digits,
-                                            self.sign, self.size, len(self._digits),
+                                            self.sign, self.numdigits(), len(self._digits),
                                             self.str())
 
 ONERBIGINT = rbigint([ONEDIGIT], 1, 1)


More information about the pypy-commit mailing list