[pypy-commit] pypy decimal-libmpdec: Implement Decimal.__hash__, and improve comparison with float.

amauryfa noreply at buildbot.pypy.org
Wed May 21 00:23:56 CEST 2014


Author: Amaury Forgeot d'Arc <amauryfa at gmail.com>
Branch: decimal-libmpdec
Changeset: r71624:7d4a82d375d2
Date: 2014-05-21 00:22 +0200
http://bitbucket.org/pypy/pypy/changeset/7d4a82d375d2/

Log:	Implement Decimal.__hash__, and improve comparison with float.

diff --git a/pypy/module/_decimal/interp_decimal.py b/pypy/module/_decimal/interp_decimal.py
--- a/pypy/module/_decimal/interp_decimal.py
+++ b/pypy/module/_decimal/interp_decimal.py
@@ -8,8 +8,17 @@
 from pypy.interpreter.typedef import (TypeDef, GetSetProperty, descr_get_dict,
     descr_set_dict, descr_del_dict)
 from pypy.objspace.std import unicodeobject
+from pypy.objspace.std.floatobject import HASH_MODULUS, HASH_INF, HASH_NAN
 from pypy.module._decimal import interp_context
 
+# 10**-1 modulo HASH_MODULUS; descr_hash raises it to -exp when exp < 0.
+if HASH_MODULUS == 2**31 - 1:
+    INVERSE_10_MODULUS = 1503238553
+elif HASH_MODULUS == 2**61 - 1:
+    INVERSE_10_MODULUS = 2075258708292324556
+else:
+    raise NotImplementedError('Unsupported HASH_MODULUS')
+assert (INVERSE_10_MODULUS * 10) % HASH_MODULUS == 1  # import-time check
 
 IEEE_CONTEXT_MAX_BITS = rmpdec.MPD_IEEE_CONTEXT_MAX_BITS
 MAX_PREC = rmpdec.MPD_MAX_PREC
@@ -82,6 +91,78 @@
             rmpdec.mpd_free(cp)
         return space.wrap("Decimal('%s')" % result)
 
+    def descr_hash(self, space):
+        """Return hash(self), consistent with hash() of equal ints/floats.
+        A finite value c * 10**e hashes to +/-(|c| * 10**e mod HASH_MODULUS),
+        taking the sign of the value; a result of -1 is mapped to -2."""
+        if rmpdec.mpd_isspecial(self.mpd):
+            if rmpdec.mpd_issnan(self.mpd):
+                raise oefmt(space.w_TypeError,
+                            "cannot hash a signaling NaN value")
+            elif rmpdec.mpd_isnan(self.mpd):
+                return space.wrap(HASH_NAN)
+            elif rmpdec.mpd_isnegative(self.mpd):
+                return space.wrap(-HASH_INF)
+            else:
+                return space.wrap(HASH_INF)
+
+        with lltype.scoped_alloc(rffi.CArrayPtr(rffi.UINT).TO, 1,
+                                 zero=True) as status_ptr:
+            with lltype.scoped_alloc(rmpdec.MPD_CONTEXT_PTR.TO) as ctx:
+                rmpdec.mpd_maxcontext(ctx)
+
+                # XXX cache these: p, 10 and 10**-1 (mod p) are constants
+                w_p = W_Decimal.allocate(space)
+                rmpdec.mpd_qset_ssize(w_p.mpd, HASH_MODULUS, ctx, status_ptr)
+                w_ten = W_Decimal.allocate(space)
+                rmpdec.mpd_qset_ssize(w_ten.mpd, 10, ctx, status_ptr)
+                w_inv10_p = W_Decimal.allocate(space)
+                rmpdec.mpd_qset_ssize(w_inv10_p.mpd, INVERSE_10_MODULUS,
+                                      ctx, status_ptr)
+
+                w_exp_hash = W_Decimal.allocate(space)
+                w_tmp = W_Decimal.allocate(space)
+                exp = self.mpd.c_exp
+                if exp >= 0:
+                    # 10**exp(v) % p
+                    rmpdec.mpd_qsset_ssize(w_tmp.mpd, exp, ctx, status_ptr)
+                    rmpdec.mpd_qpowmod(
+                        w_exp_hash.mpd, w_ten.mpd, w_tmp.mpd, w_p.mpd,
+                        ctx, status_ptr)
+                else:
+                    # inv10_p**(-exp(v)) % p
+                    rmpdec.mpd_qsset_ssize(w_tmp.mpd, -exp, ctx, status_ptr)
+                    rmpdec.mpd_qpowmod(
+                        w_exp_hash.mpd, w_inv10_p.mpd, w_tmp.mpd, w_p.mpd,
+                        ctx, status_ptr)
+                # hash = (|coefficient(v)| * exp_hash) % p
+                rmpdec.mpd_qcopy(w_tmp.mpd, self.mpd, status_ptr)
+                w_tmp.mpd.c_exp = 0
+                rmpdec.mpd_set_positive(w_tmp.mpd)
+
+                # Widen the limits (presumably so the product below is exact).
+                ctx.c_prec = rmpdec.MPD_MAX_PREC + 21
+                ctx.c_emax = rmpdec.MPD_MAX_EMAX + 21
+                ctx.c_emin = rmpdec.MPD_MIN_EMIN - 21
+
+                rmpdec.mpd_qmul(w_tmp.mpd, w_tmp.mpd, w_exp_hash.mpd,
+                                ctx, status_ptr)
+                rmpdec.mpd_qrem(w_tmp.mpd, w_tmp.mpd, w_p.mpd, ctx, status_ptr)
+
+                result = rmpdec.mpd_qget_ssize(w_tmp.mpd, status_ptr)
+                if rmpdec.mpd_isnegative(self.mpd):
+                    result = -result
+                if result == -1:
+                    result = -2  # hash(-1) == -2, as for ints
+            status = rffi.cast(lltype.Signed, status_ptr[0])
+        if status:
+            if status & rmpdec.MPD_Malloc_error:
+                raise OperationError(space.w_MemoryError, space.w_None)
+            else:
+                raise OperationError(space.w_SystemError, space.wrap(
+                        "Decimal.__hash__ internal error; please report"))
+        return space.wrap(result)
+
     def descr_bool(self, space):
         return space.wrap(not rmpdec.mpd_iszero(self.mpd))
 
@@ -166,16 +247,19 @@
 
     def compare(self, space, w_other, op):
         context = interp_context.getcontext(space)
-        w_err, w_other = convert_op(space, context, w_other)
+        w_err, w_self, w_other = convert_binop_cmp(
+            space, context, op, self, w_other)
         if w_err:
             return w_err
-        with lltype.scoped_alloc(rffi.CArrayPtr(rffi.UINT).TO, 1) as status_ptr:
-            r = rmpdec.mpd_qcmp(self.mpd, w_other.mpd, status_ptr)
+        with lltype.scoped_alloc(rffi.CArrayPtr(rffi.UINT).TO, 1,
+                                 zero=True) as status_ptr:
+            r = rmpdec.mpd_qcmp(w_self.mpd, w_other.mpd, status_ptr)
 
             if r > 0xFFFF:
                 # sNaNs or op={le,ge,lt,gt} always signal.
-                if (rmpdec.mpd_issnan(self.mpd) or rmpdec.mpd_issnan(w_other.mpd)
-                    or (op not in ('eq', 'ne'))):
+                if (rmpdec.mpd_issnan(w_self.mpd) or
+                    rmpdec.mpd_issnan(w_other.mpd) or
+                    op not in ('eq', 'ne')):
                     status = rffi.cast(lltype.Signed, status_ptr[0])
                     context.addstatus(space, status)
                 # qNaN comparison with op={eq,ne} or comparison with
@@ -436,6 +520,41 @@
                     space.type(w_y))
     return w_a, w_b
 
+def convert_binop_cmp(space, context, op, w_v, w_w):
+    # Coerce the right operand of a rich comparison to a W_Decimal.
+    # Returns (w_err, w_v, w_w): w_err is space.w_NotImplemented when
+    # w_w cannot be compared with a Decimal, otherwise None.
+    if isinstance(w_w, W_Decimal):
+        return None, w_v, w_w
+    elif space.isinstance_w(w_w, space.w_int):
+        value = space.bigint_w(w_w)
+        w_w = decimal_from_bigint(space, None, value, context, exact=True)
+        return None, w_v, w_w
+    elif space.isinstance_w(w_w, space.w_float):
+        w_float = w_w
+    elif space.isinstance_w(w_w, space.w_complex):
+        # A complex is only comparable for (in)equality, and only when its
+        # imaginary part is zero; its real part is then compared exactly.
+        if op not in ('eq', 'ne'):
+            return space.w_NotImplemented, None, None
+        real, imag = space.unpackcomplex(w_w)
+        if imag != 0.0:
+            return space.w_NotImplemented, None, None
+        w_float = space.wrap(real)
+    else:
+        return space.w_NotImplemented, None, None
+    if op not in ('eq', 'ne'):
+        # Ordering: record FloatOperation, raising if it is trapped.
+        context.addstatus(space, rmpdec.MPD_Float_operation)
+    else:
+        # Equality: record FloatOperation in the flags but never raise.
+        new_status = (rmpdec.MPD_Float_operation |
+                      rffi.cast(lltype.Signed, context.ctx.c_status))
+        context.ctx.c_status = rffi.cast(rffi.UINT, new_status)
+    w_w = decimal_from_float(space, None, w_float, context, exact=True)
+    return None, w_v, w_w
+
+
 def binary_number_method(space, mpd_func, w_x, w_y):
     context = interp_context.getcontext(space)
 
@@ -684,6 +803,7 @@
     __new__ = interp2app(descr_new_decimal),
     __str__ = interp2app(W_Decimal.descr_str),
     __repr__ = interp2app(W_Decimal.descr_repr),
+    __hash__ = interp2app(W_Decimal.descr_hash),
     __bool__ = interp2app(W_Decimal.descr_bool),
     __float__ = interp2app(W_Decimal.descr_float),
     __int__ = interp2app(W_Decimal.descr_int),
diff --git a/pypy/module/_decimal/test/test_decimal.py b/pypy/module/_decimal/test/test_decimal.py
--- a/pypy/module/_decimal/test/test_decimal.py
+++ b/pypy/module/_decimal/test/test_decimal.py
@@ -644,6 +644,180 @@
         assert -Decimal(45) == Decimal(-45)
         assert abs(Decimal(45)) == abs(Decimal(-45))
 
+    def test_hash_method(self):
+        """hash(Decimal) must agree with hash() of equal ints/floats."""
+        Decimal = self.decimal.Decimal
+        localcontext = self.decimal.localcontext
+
+        def hashit(d):
+            a = hash(d)
+            b = d.__hash__()
+            assert a == b
+            return a
+
+        # just check that these are hashable
+        hashit(Decimal(23))
+        hashit(Decimal('Infinity'))
+        hashit(Decimal('-Infinity'))
+        hashit(Decimal('nan123'))
+        hashit(Decimal('-NaN'))
+
+        test_values = [Decimal(sign*(2**m + n))
+                       for m in [0, 14, 15, 16, 17, 30, 31,
+                                 32, 33, 61, 62, 63, 64, 65, 66]
+                       for n in range(-10, 10)
+                       for sign in [-1, 1]]
+        test_values.extend([
+                Decimal("-1"), # ==> -2
+                Decimal("-0"), # zeros
+                Decimal("0.00"),
+                Decimal("-0.000"),
+                Decimal("0E10"),
+                Decimal("-0E12"),
+                Decimal("10.0"), # negative exponent
+                Decimal("-23.00000"),
+                Decimal("1230E100"), # positive exponent
+                Decimal("-4.5678E50"),
+                # a value for which hash(n) != hash(n % (2**64-1))
+                # in Python pre-2.6
+                Decimal(2**64 + 2**32 - 1),
+                # selection of values which fail with the old (before
+                # version 2.6) long.__hash__
+                Decimal("1.634E100"),
+                Decimal("90.697E100"),
+                Decimal("188.83E100"),
+                Decimal("1652.9E100"),
+                Decimal("56531E100"),
+                ])
+
+        # check that hash(d) == hash(int(d)) for integral values
+        for value in test_values:
+            assert hashit(value) == hashit(int(value))
+
+        # the same hash as the equal int
+        assert hashit(Decimal(23)) == hashit(23)
+        raises(TypeError, hash, Decimal('sNaN'))  # sNaN is unhashable
+        assert hashit(Decimal('Inf'))
+        assert hashit(Decimal('-Inf'))
+
+        # check that the hashes of a Decimal and a float match when they
+        # represent exactly the same values
+        test_strings = ['inf', '-Inf', '0.0', '-.0e1',
+                        '34.0', '2.5', '112390.625', '-0.515625']
+        for s in test_strings:
+            f = float(s)
+            d = Decimal(s)
+            assert hashit(f) == hashit(d)
+
+        with localcontext() as c:
+            # check that the value of the hash doesn't depend on the
+            # current context (issue #1757)
+            x = Decimal("123456789.1")
+
+            c.prec = 6
+            h1 = hashit(x)
+            c.prec = 10
+            h2 = hashit(x)
+            c.prec = 16
+            h3 = hashit(x)
+
+            assert h1 == h2 == h3
+
+            c.prec = 10000
+            x = 1100 ** 1248
+            assert hashit(Decimal(x)) == hashit(x)
+
+    def test_float_comparison(self):
+        Decimal = self.decimal.Decimal
+        Context = self.decimal.Context
+        FloatOperation = self.decimal.FloatOperation
+        localcontext = self.decimal.localcontext
+        # a.attr(b) must return True (or raise `signal`); either way the
+        def assert_attr(a, b, attr, context, signal=None):
+            context.clear_flags()
+            f = getattr(a, attr)
+            if signal == FloatOperation:
+                raises(signal, f, b)
+            else:
+                assert f(b) is True
+            assert context.flags[FloatOperation]  # FloatOperation flag is set
+        # Decimal and float operands holding the same values
+        small_d = Decimal('0.25')
+        big_d = Decimal('3.0')
+        small_f = 0.25
+        big_f = 3.0
+
+        zero_d = Decimal('0.0')
+        neg_zero_d = Decimal('-0.0')
+        zero_f = 0.0
+        neg_zero_f = -0.0
+
+        inf_d = Decimal('Infinity')
+        neg_inf_d = Decimal('-Infinity')
+        inf_f = float('inf')
+        neg_inf_f = float('-inf')
+        # Ordering ops may trap FloatOperation; __eq__/__ne__ never raise.
+        def doit(c, signal=None):
+            # Order
+            for attr in '__lt__', '__le__':
+                assert_attr(small_d, big_f, attr, c, signal)
+
+            for attr in '__gt__', '__ge__':
+                assert_attr(big_d, small_f, attr, c, signal)
+
+            # Equality
+            assert_attr(small_d, small_f, '__eq__', c, None)
+
+            assert_attr(neg_zero_d, neg_zero_f, '__eq__', c, None)
+            assert_attr(neg_zero_d, zero_f, '__eq__', c, None)
+
+            assert_attr(zero_d, neg_zero_f, '__eq__', c, None)
+            assert_attr(zero_d, zero_f, '__eq__', c, None)
+
+            assert_attr(neg_inf_d, neg_inf_f, '__eq__', c, None)
+            assert_attr(inf_d, inf_f, '__eq__', c, None)
+
+            # Inequality
+            assert_attr(small_d, big_f, '__ne__', c, None)
+
+            assert_attr(Decimal('0.1'), 0.1, '__ne__', c, None)
+
+            assert_attr(neg_inf_d, inf_f, '__ne__', c, None)
+            assert_attr(inf_d, neg_inf_f, '__ne__', c, None)
+
+            assert_attr(Decimal('NaN'), float('nan'), '__ne__', c, None)
+        # Mixed Decimal/float values in a set, sorted(), and `in` tests.
+        def test_containers(c, signal=None):
+            c.clear_flags()
+            s = set([100.0, Decimal('100.0')])
+            assert len(s) == 1  # equal values must hash alike
+            assert c.flags[FloatOperation]
+
+            c.clear_flags()
+            if signal:
+                raises(signal, sorted, [1.0, Decimal('10.0')])
+            else:
+                s = sorted([10.0, Decimal('10.0')])
+            assert c.flags[FloatOperation]
+
+            c.clear_flags()
+            b = 10.0 in [Decimal('10.0'), 1.0]
+            assert c.flags[FloatOperation]
+
+            c.clear_flags()
+            b = 10.0 in {Decimal('10.0'):'a', 1.0:'b'}
+            assert c.flags[FloatOperation]
+        # Run both checks untrapped, then with FloatOperation trapped.
+        nc = Context()
+        with localcontext(nc) as c:
+            assert not c.traps[FloatOperation]
+            doit(c, signal=None)
+            test_containers(c, signal=None)
+
+            c.traps[FloatOperation] = True
+            doit(c, signal=FloatOperation)
+            test_containers(c, signal=FloatOperation)
+
     def test_nan_comparisons(self):
         import operator
         # comparisons involving signaling nans signal InvalidOperation
diff --git a/rpython/rlib/rmpdec.py b/rpython/rlib/rmpdec.py
--- a/rpython/rlib/rmpdec.py
+++ b/rpython/rlib/rmpdec.py
@@ -37,9 +37,10 @@
                            ],
     export_symbols=[
         "mpd_qset_ssize", "mpd_qset_uint", "mpd_qset_string",
+        "mpd_qsset_ssize", "mpd_qget_ssize",
         "mpd_qcopy", "mpd_qncopy", "mpd_setspecial", "mpd_clear_flags",
         "mpd_qimport_u32", "mpd_qexport_u32", "mpd_qexport_u16",
-        "mpd_set_sign", "mpd_sign", "mpd_qfinalize",
+        "mpd_set_sign", "mpd_set_positive", "mpd_sign", "mpd_qfinalize",
         "mpd_getprec", "mpd_getemin",  "mpd_getemax", "mpd_getround", "mpd_getclamp",
         "mpd_qsetprec", "mpd_qsetemin",  "mpd_qsetemax", "mpd_qsetround", "mpd_qsetclamp",
         "mpd_maxcontext",
@@ -151,6 +152,10 @@
     'mpd_qset_uint', [MPD_PTR, rffi.UINT, MPD_CONTEXT_PTR, rffi.UINTP], lltype.Void)
 mpd_qset_string = external(
     'mpd_qset_string', [MPD_PTR, rffi.CCHARP, MPD_CONTEXT_PTR, rffi.UINTP], lltype.Void)
+mpd_qsset_ssize = external(  # NOTE(review): does not resize target? confirm
+    'mpd_qsset_ssize', [MPD_PTR, rffi.SSIZE_T, MPD_CONTEXT_PTR, rffi.UINTP], lltype.Void)
+mpd_qget_ssize = external(  # status flags are written through the UINTP
+    'mpd_qget_ssize', [MPD_PTR, rffi.UINTP], rffi.SSIZE_T)
 mpd_qimport_u32 = external(
     'mpd_qimport_u32', [
         MPD_PTR, rffi.UINTP, rffi.SIZE_T,
@@ -171,6 +176,8 @@
     'mpd_setspecial', [MPD_PTR, rffi.UCHAR, rffi.UCHAR], lltype.Void)
 mpd_set_sign = external(
     'mpd_set_sign', [MPD_PTR, rffi.UCHAR], lltype.Void)
+mpd_set_positive = external(  # descr_hash uses it to take |coefficient|
+    'mpd_set_positive', [MPD_PTR], lltype.Void)
 mpd_clear_flags = external(
     'mpd_clear_flags', [MPD_PTR], lltype.Void)
 mpd_sign = external(


More information about the pypy-commit mailing list