[pypy-commit] pypy py3.5-fix-decimal-module-name: (stevie, robert-zaremba) FIX: test_pickle (test.test_decimal.PyPythonAPItests)

robert-zaremba pypy.commits at gmail.com
Tue Feb 28 09:49:43 EST 2017


Author: Robert Zaremba <robert.zaremba at scale-it.pl>
Branch: py3.5-fix-decimal-module-name
Changeset: r90424:793b49cacddd
Date: 2017-02-28 15:32 +0100
http://bitbucket.org/pypy/pypy/changeset/793b49cacddd/

Log:	(stevie, robert-zaremba) FIX: test_pickle
	(test.test_decimal.PyPythonAPItests)

	Fixes: http://buildbot.pypy.org/summary/longrepr?testname=unmodified&builder=pypy-c-jit-linux-x86-64&build=4406&mod=lib-python.3.test.test_decimal

	We removed the per-class __module__ hack from the classes in
	lib_pypy/_decimal.py and instead set the module-level __name__ to
	'decimal', so that pickling refers to the right module name.

diff --git a/extra_tests/support.py b/extra_tests/support.py
new file mode 100644
--- /dev/null
+++ b/extra_tests/support.py
@@ -0,0 +1,98 @@
+import contextlib
+import importlib
+import sys
+import warnings
+
+
+@contextlib.contextmanager
+def _ignore_deprecated_imports(ignore=True):
+    """Context manager to suppress package and module deprecation
+    warnings when importing them.
+
+    If ignore is False, this context manager has no effect.
+    """
+    if ignore:
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", ".+ (module|package)",
+                                    DeprecationWarning)
+            yield
+    else:
+        yield
+
+
+def _save_and_remove_module(name, orig_modules):
+    """Helper function to save and remove a module from sys.modules
+
+    Raise ImportError if the module can't be imported.
+    """
+    # try to import the module and raise an error if it can't be imported
+    if name not in sys.modules:
+        __import__(name)
+        del sys.modules[name]
+    for modname in list(sys.modules):
+        if modname == name or modname.startswith(name + '.'):
+            orig_modules[modname] = sys.modules[modname]
+            del sys.modules[modname]
+
+def _save_and_block_module(name, orig_modules):
+    """Helper function to save and block a module in sys.modules
+
+    Return True if the module was in sys.modules, False otherwise.
+    """
+    saved = True
+    try:
+        orig_modules[name] = sys.modules[name]
+    except KeyError:
+        saved = False
+    sys.modules[name] = None
+    return saved
+
+
+def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
+    """Import and return a module, deliberately bypassing sys.modules.
+
+    This function imports and returns a fresh copy of the named Python module
+    by removing the named module from sys.modules before doing the import.
+    Note that unlike reload, the original module is not affected by
+    this operation.
+
+    *fresh* is an iterable of additional module names that are also removed
+    from the sys.modules cache before doing the import.
+
+    *blocked* is an iterable of module names that are replaced with None
+    in the module cache during the import to ensure that attempts to import
+    them raise ImportError.
+
+    The named module and any modules named in the *fresh* and *blocked*
+    parameters are saved before starting the import and then reinserted into
+    sys.modules when the fresh import is complete.
+
+    Module and package deprecation messages are suppressed during this import
+    if *deprecated* is True.
+
+    This function will raise ImportError if the named module cannot be
+    imported.
+    """
+    # NOTE: test_heapq, test_json and test_warnings include extra sanity checks
+    # to make sure that this utility function is working as expected
+    with _ignore_deprecated_imports(deprecated):
+        # Keep track of modules saved for later restoration as well
+        # as those which just need a blocking entry removed
+        orig_modules = {}
+        names_to_remove = []
+        _save_and_remove_module(name, orig_modules)
+        try:
+            for fresh_name in fresh:
+                _save_and_remove_module(fresh_name, orig_modules)
+            for blocked_name in blocked:
+                if not _save_and_block_module(blocked_name, orig_modules):
+                    names_to_remove.append(blocked_name)
+            fresh_module = importlib.import_module(name)
+        except ImportError:
+            fresh_module = None
+        finally:
+            for orig_name, module in orig_modules.items():
+                sys.modules[orig_name] = module
+            for name_to_remove in names_to_remove:
+                del sys.modules[name_to_remove]
+        return fresh_module
diff --git a/extra_tests/test_decimal.py b/extra_tests/test_decimal.py
new file mode 100644
--- /dev/null
+++ b/extra_tests/test_decimal.py
@@ -0,0 +1,59 @@
+import pickle
+import sys
+
+from support import import_fresh_module
+
+C = import_fresh_module('decimal', fresh=['_decimal'])
+P = import_fresh_module('decimal', blocked=['_decimal'])
+# import _decimal as C
+# import _pydecimal as P
+
+
+class TestPythonAPI:
+
+    def check_equal(self, val, proto):
+        d = C.Decimal(val)
+        p = pickle.dumps(d, proto)
+        assert d == pickle.loads(p)
+
+    def test_C(self):
+        sys.modules["decimal"] = C
+        import decimal
+        d = decimal.Decimal('1')
+        assert isinstance(d, C.Decimal)
+        assert isinstance(d, decimal.Decimal)
+        assert isinstance(d.as_tuple(), C.DecimalTuple)
+
+        assert d == C.Decimal('1')
+
+    def test_pickle(self):
+        v = '-3.123e81723'
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            sys.modules["decimal"] = C
+            self.check_equal('-3.141590000', proto)
+            self.check_equal(v, proto)
+
+            cd = C.Decimal(v)
+            pd = P.Decimal(v)
+            cdt = cd.as_tuple()
+            pdt = pd.as_tuple()
+            assert cdt.__module__ == pdt.__module__
+
+            p = pickle.dumps(cdt, proto)
+            r = pickle.loads(p)
+            assert isinstance(r, C.DecimalTuple)
+            assert cdt == r
+
+            sys.modules["decimal"] = C
+            p = pickle.dumps(cd, proto)
+            sys.modules["decimal"] = P
+            r = pickle.loads(p)
+            assert isinstance(r, P.Decimal)
+            assert r == pd
+
+            sys.modules["decimal"] = C
+            p = pickle.dumps(cdt, proto)
+            sys.modules["decimal"] = P
+            r = pickle.loads(p)
+            assert isinstance(r, P.DecimalTuple)
+            assert r == pdt
diff --git a/lib_pypy/_decimal.py b/lib_pypy/_decimal.py
--- a/lib_pypy/_decimal.py
+++ b/lib_pypy/_decimal.py
@@ -1,5 +1,9 @@
 # Implementation of the "decimal" module, based on libmpdec library.
 
+__xname__ = __name__    # sys.modules lookup (--without-threads)
+__name__ = 'decimal'    # For pickling
+
+
 import collections as _collections
 import math as _math
 import numbers as _numbers
@@ -23,15 +27,13 @@
 # Errors
 
 class DecimalException(ArithmeticError):
-    __module__ = 'decimal'
     def handle(self, context, *args):
         pass
 
 class Clamped(DecimalException):
-    __module__ = 'decimal'
+    pass
 
 class InvalidOperation(DecimalException):
-    __module__ = 'decimal'
     def handle(self, context, *args):
         if args:
             ans = _dec_from_triple(args[0]._sign, args[0]._int, 'n', True)
@@ -39,41 +41,35 @@
         return _NaN
 
 class ConversionSyntax(InvalidOperation):
-    __module__ = 'decimal'
     def handle(self, context, *args):
         return _NaN
 
 class DivisionByZero(DecimalException, ZeroDivisionError):
-    __module__ = 'decimal'
     def handle(self, context, sign, *args):
         return _SignedInfinity[sign]
 
 class DivisionImpossible(InvalidOperation):
-    __module__ = 'decimal'
     def handle(self, context, *args):
         return _NaN
 
 class DivisionUndefined(InvalidOperation, ZeroDivisionError):
-    __module__ = 'decimal'
     def handle(self, context, *args):
         return _NaN
 
 class Inexact(DecimalException):
-    __module__ = 'decimal'
+    pass
 
 class InvalidContext(InvalidOperation):
-    __module__ = 'decimal'
     def handle(self, context, *args):
         return _NaN
 
 class Rounded(DecimalException):
-    __module__ = 'decimal'
+    pass
 
 class Subnormal(DecimalException):
-    __module__ = 'decimal'
+    pass
 
 class Overflow(Inexact, Rounded):
-    __module__ = 'decimal'
     def handle(self, context, sign, *args):
         if context.rounding in (ROUND_HALF_UP, ROUND_HALF_EVEN,
                                 ROUND_HALF_DOWN, ROUND_UP):
@@ -90,10 +86,10 @@
                              context.Emax-context.prec+1)
 
 class Underflow(Inexact, Rounded, Subnormal):
-    __module__ = 'decimal'
+    pass
 
 class FloatOperation(DecimalException, TypeError):
-    __module__ = 'decimal'
+    pass
 
 
 __version__ = "1.70"
@@ -107,7 +103,7 @@
 
 def getcontext():
     """Returns this thread's context.
-    
+
     If this thread does not yet have a context, returns
     a new context and sets this thread's context.
     New contexts are copies of DefaultContext.
@@ -173,8 +169,6 @@
 _DEC_MINALLOC = 4
 
 class Decimal(object):
-    __module__ = 'decimal'
-
     __slots__ = ('_mpd', '_data')
 
     def __new__(cls, value="0", context=None):
@@ -326,7 +320,7 @@
             builder.append(b'E')
             builder.append(str(exponent).encode())
 
-        return cls._from_bytes(b''.join(builder), context, exact=exact) 
+        return cls._from_bytes(b''.join(builder), context, exact=exact)
 
     @classmethod
     def from_float(cls, value):
@@ -481,7 +475,7 @@
             numerator = Decimal._from_int(other.numerator, context)
             if not _mpdec.mpd_isspecial(self._mpd):
                 # multiplied = self * other.denominator
-                # 
+                #
                 # Prevent Overflow in the following multiplication.
                 # The result of the multiplication is
                 # only used in mpd_qcmp, which can handle values that
@@ -542,7 +536,7 @@
         _mpdec.mpd_qset_ssize(p._mpd, self._PyHASH_MODULUS,
                               maxctx, status_ptr)
         ten = self._new_empty()
-        _mpdec.mpd_qset_ssize(ten._mpd, 10, 
+        _mpdec.mpd_qset_ssize(ten._mpd, 10,
                               maxctx, status_ptr)
         inv10_p = self._new_empty()
         _mpdec.mpd_qset_ssize(inv10_p._mpd, self._PyHASH_10INV,
@@ -755,7 +749,7 @@
     number_class = _make_unary_operation('number_class')
 
     to_eng_string = _make_unary_operation('to_eng_string')
-    
+
     def fma(self, other, third, context=None):
         context = _getcontext(context)
         return context.fma(self, other, third)
@@ -790,7 +784,7 @@
             result = int.from_bytes(s, 'little', signed=False)
         if _mpdec.mpd_isnegative(x._mpd) and not _mpdec.mpd_iszero(x._mpd):
             result = -result
-        return result        
+        return result
 
     def __int__(self):
         return self._to_int(_mpdec.MPD_ROUND_DOWN)
@@ -798,10 +792,10 @@
     __trunc__ = __int__
 
     def __floor__(self):
-        return self._to_int(_mpdec.MPD_ROUND_FLOOR)        
+        return self._to_int(_mpdec.MPD_ROUND_FLOOR)
 
     def __ceil__(self):
-        return self._to_int(_mpdec.MPD_ROUND_CEILING)        
+        return self._to_int(_mpdec.MPD_ROUND_CEILING)
 
     def to_integral(self, rounding=None, context=None):
         context = _getcontext(context)
@@ -817,7 +811,7 @@
         return result
 
     to_integral_value = to_integral
-        
+
     def to_integral_exact(self, rounding=None, context=None):
         context = _getcontext(context)
         workctx = context.copy()
@@ -886,7 +880,7 @@
         if _mpdec.mpd_isspecial(self._mpd):
             return 0
         return _mpdec.mpd_adjexp(self._mpd)
-    
+
     @property
     def real(self):
         return self
@@ -916,7 +910,7 @@
         fmt = specifier.encode('utf-8')
         context = getcontext()
 
-        replace_fillchar = False 
+        replace_fillchar = False
         if fmt and fmt[0] == 0:
             # NUL fill character: must be replaced with a valid UTF-8 char
             # before calling mpd_parse_fmt_str().
@@ -975,7 +969,7 @@
             result = result.replace(b'\xff', b'\0')
         return result.decode('utf-8')
 
-        
+
 # Register Decimal as a kind of Number (an abstract base class).
 # However, do not register it as Real (because Decimals are not
 # interoperable with floats).
@@ -988,7 +982,7 @@
 
 # Rounding
 _ROUNDINGS = {
-    'ROUND_DOWN': _mpdec.MPD_ROUND_DOWN, 
+    'ROUND_DOWN': _mpdec.MPD_ROUND_DOWN,
     'ROUND_HALF_UP': _mpdec.MPD_ROUND_HALF_UP,
     'ROUND_HALF_EVEN': _mpdec.MPD_ROUND_HALF_EVEN,
     'ROUND_CEILING': _mpdec.MPD_ROUND_CEILING,
@@ -1047,8 +1041,6 @@
     clamp -  If 1, change exponents if too high (Default 0)
     """
 
-    __module__ = 'decimal'
-
     __slots__ = ('_ctx', '_capitals')
 
     def __new__(cls, prec=None, rounding=None, Emin=None, Emax=None,
@@ -1068,7 +1060,7 @@
         ctx.round = _mpdec.MPD_ROUND_HALF_EVEN
         ctx.clamp = 0
         ctx.allcr = 1
-        
+
         self._capitals = 1
         return self
 
@@ -1291,7 +1283,7 @@
         if b is NotImplemented:
             return b, b
         return a, b
-        
+
     def _make_unary_method(name, mpd_func_name):
         mpd_func = getattr(_mpdec, mpd_func_name)
 
@@ -1570,7 +1562,7 @@
 
     def copy(self):
         return self._as_dict()
-        
+
     def __len__(self):
         return len(_SIGNALS)
 
@@ -1629,7 +1621,7 @@
     def __enter__(self):
         self.status_ptr = _ffi.new("uint32_t*")
         return self.context._ctx, self.status_ptr
-        
+
     def __exit__(self, *args):
         status = self.status_ptr[0]
         # May raise a DecimalException


More information about the pypy-commit mailing list