[pypy-commit] pypy py3.5-fix-decimal-module-name: (stevie, robert-zaremba) FIX: test_pickle (test.test_decimal.PyPythonAPItests)
robert-zaremba
pypy.commits at gmail.com
Tue Feb 28 09:49:43 EST 2017
Author: Robert Zaremba <robert.zaremba at scale-it.pl>
Branch: py3.5-fix-decimal-module-name
Changeset: r90424:793b49cacddd
Date: 2017-02-28 15:32 +0100
http://bitbucket.org/pypy/pypy/changeset/793b49cacddd/
Log: (stevie, robert-zaremba) FIX: test_pickle
(test.test_decimal.PyPythonAPItests)
Fixes: http://buildbot.pypy.org/summary/longrepr?testname=unmodified&builder=pypy-c-jit-linux-x86-64&build=4406&mod=lib-python.3.test.test_decimal
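We removed the per-class `__module__ = 'decimal'` hack from the classes in
lib_pypy/_decimal.py. Instead, the module-level __name__ is now set to
'decimal' (the original name is kept in __xname__), so pickles refer to the
public decimal module.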
diff --git a/extra_tests/support.py b/extra_tests/support.py
new file mode 100644
--- /dev/null
+++ b/extra_tests/support.py
@@ -0,0 +1,98 @@
+import contextlib
+import importlib
+import sys
+import warnings
+
+
@contextlib.contextmanager
def _ignore_deprecated_imports(ignore=True):
    """Silence module/package DeprecationWarnings for the duration of a block.

    When *ignore* is true, DeprecationWarnings whose message matches the
    pattern ``".+ (module|package)"`` are filtered out while the block runs.
    The previous warning filters are restored on exit. When *ignore* is
    false, this context manager has no effect.
    """
    if not ignore:
        # Nothing to suppress: just run the managed block.
        yield
        return
    # catch_warnings() snapshots the filter list and restores it on exit,
    # so the extra "ignore" filter only lives for the managed block.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", ".+ (module|package)",
                                DeprecationWarning)
        yield
+
+
def _save_and_remove_module(name, orig_modules):
    """Stash and drop *name* and its submodules from sys.modules.

    Every sys.modules entry that is *name* itself, or a dotted child of it
    (``name + '.'`` prefix), is moved into the *orig_modules* dict.

    Raise ImportError if the module can't be imported.
    """
    # Importing up front makes a missing module fail loudly with ImportError.
    # If this call was what loaded it, drop the top-level entry again.
    if name not in sys.modules:
        __import__(name)
        del sys.modules[name]
    child_prefix = name + '.'
    # Snapshot the matching keys first; sys.modules is mutated below.
    matching = [key for key in sys.modules
                if key == name or key.startswith(child_prefix)]
    for key in matching:
        orig_modules[key] = sys.modules.pop(key)
+
def _save_and_block_module(name, orig_modules):
    """Stash *name*'s sys.modules entry and replace it with None.

    With a None entry in place, attempts to import *name* raise ImportError.
    If *name* was loaded, its previous entry is stored in *orig_modules*.

    Return True if the module was in sys.modules, False otherwise.
    """
    present = name in sys.modules
    if present:
        orig_modules[name] = sys.modules[name]
    sys.modules[name] = None
    return present
+
+
def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
    """Import and return a module, deliberately bypassing sys.modules.

    This function imports and returns a fresh copy of the named Python module
    by removing the named module from sys.modules before doing the import.
    Note that unlike reload, the original module is not affected by
    this operation.

    *fresh* is an iterable of additional module names that are also removed
    from the sys.modules cache before doing the import.

    *blocked* is an iterable of module names that are replaced with None
    in the module cache during the import to ensure that attempts to import
    them raise ImportError.

    The named module and any modules named in the *fresh* and *blocked*
    parameters are saved before starting the import. They are reinserted into
    sys.modules when the fresh import is complete. Blocking placeholders for
    names that had no entry beforehand are removed again.

    Module and package deprecation messages are suppressed during this import
    if *deprecated* is True.

    Return the fresh module, or None if the fresh import raised ImportError.
    That also covers saving one of the *fresh* modules failing, and the
    module needing one of the *blocked* modules.

    Raise ImportError only if the named module cannot be imported at all.
    This is checked up front, before *fresh* and *blocked* are applied.
    """
    # NOTE: the docstring and logic mirror CPython's test.support helper,
    # where test_heapq, test_json and test_warnings include extra sanity
    # checks for it; confirm whether equivalent checks exist for this copy.
    with _ignore_deprecated_imports(deprecated):
        # Keep track of modules saved for later restoration as well
        # as those which just need a blocking entry removed
        orig_modules = {}
        names_to_remove = []
        # Outside the try on purpose: a module that cannot be imported at all
        # raises ImportError to the caller instead of returning None.
        _save_and_remove_module(name, orig_modules)
        try:
            for fresh_name in fresh:
                _save_and_remove_module(fresh_name, orig_modules)
            for blocked_name in blocked:
                if not _save_and_block_module(blocked_name, orig_modules):
                    names_to_remove.append(blocked_name)
            fresh_module = importlib.import_module(name)
        except ImportError:
            fresh_module = None
        finally:
            # Drop the blocking placeholders *before* restoring the saved
            # modules. A blocked name may also have been saved above (e.g. a
            # submodule of *name* or an entry of *fresh*), and removing it
            # afterwards would discard the restored original. pop() also
            # avoids a KeyError in this finally clause masking the result.
            for name_to_remove in names_to_remove:
                sys.modules.pop(name_to_remove, None)
            for orig_name, module in orig_modules.items():
                sys.modules[orig_name] = module
        return fresh_module
diff --git a/extra_tests/test_decimal.py b/extra_tests/test_decimal.py
new file mode 100644
--- /dev/null
+++ b/extra_tests/test_decimal.py
@@ -0,0 +1,59 @@
+import pickle
+import sys
+
+from support import import_fresh_module
+
# C: the decimal module re-imported together with a fresh _decimal, i.e. the
#    accelerated implementation (lib_pypy/_decimal.py on PyPy).
# P: the decimal module re-imported with _decimal blocked, so it falls back
#    to the pure-Python implementation (_pydecimal).
C = import_fresh_module('decimal', fresh=['_decimal'])
P = import_fresh_module('decimal', blocked=['_decimal'])
+
+
class TestPythonAPI:
    """Pickle and module-identity checks for the accelerated decimal module.

    C is the decimal module backed by a fresh _decimal, and P is the one
    imported with _decimal blocked. Several tests point
    sys.modules["decimal"] at one of them, because pickle resolves classes
    by module name when loading. The original entry is restored after each
    test so the swap does not leak into other tests in the same process.
    """

    # Sentinel recorded when sys.modules had no "decimal" entry at all.
    _NO_ENTRY = object()

    def setup_method(self, method=None):
        """Remember the current sys.modules["decimal"] entry."""
        self._saved_decimal = sys.modules.get("decimal", self._NO_ENTRY)

    def teardown_method(self, method=None):
        """Restore sys.modules["decimal"] to its pre-test state."""
        if self._saved_decimal is self._NO_ENTRY:
            sys.modules.pop("decimal", None)
        else:
            sys.modules["decimal"] = self._saved_decimal

    def check_equal(self, val, proto):
        """Assert C.Decimal(val) round-trips through pickle protocol *proto*."""
        d = C.Decimal(val)
        p = pickle.dumps(d, proto)
        assert d == pickle.loads(p)

    def test_C(self):
        """``import decimal`` yields C once C is installed in sys.modules."""
        sys.modules["decimal"] = C
        import decimal
        d = decimal.Decimal('1')
        assert isinstance(d, C.Decimal)
        assert isinstance(d, decimal.Decimal)
        assert isinstance(d.as_tuple(), C.DecimalTuple)

        assert d == C.Decimal('1')

    def test_pickle(self):
        """Decimal/DecimalTuple pickles made with C load under both C and P."""
        v = '-3.123e81723'
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            sys.modules["decimal"] = C
            self.check_equal('-3.141590000', proto)
            self.check_equal(v, proto)

            cd = C.Decimal(v)
            pd = P.Decimal(v)
            cdt = cd.as_tuple()
            pdt = pd.as_tuple()
            # Both implementations must report the same public module,
            # otherwise their pickles would not be interchangeable.
            assert cdt.__module__ == pdt.__module__

            p = pickle.dumps(cdt, proto)
            r = pickle.loads(p)
            assert isinstance(r, C.DecimalTuple)
            assert cdt == r

            # Dump under C, load under P: the class is looked up through
            # sys.modules["decimal"], so the result must be P's Decimal.
            sys.modules["decimal"] = C
            p = pickle.dumps(cd, proto)
            sys.modules["decimal"] = P
            r = pickle.loads(p)
            assert isinstance(r, P.Decimal)
            assert r == pd

            sys.modules["decimal"] = C
            p = pickle.dumps(cdt, proto)
            sys.modules["decimal"] = P
            r = pickle.loads(p)
            assert isinstance(r, P.DecimalTuple)
            assert r == pdt
diff --git a/lib_pypy/_decimal.py b/lib_pypy/_decimal.py
--- a/lib_pypy/_decimal.py
+++ b/lib_pypy/_decimal.py
@@ -1,5 +1,9 @@
# Implementation of the "decimal" module, based on libmpdec library.
+__xname__ = __name__ # sys.modules lookup (--without-threads)
+__name__ = 'decimal' # For pickling
+
+
import collections as _collections
import math as _math
import numbers as _numbers
@@ -23,15 +27,13 @@
# Errors
class DecimalException(ArithmeticError):
- __module__ = 'decimal'
def handle(self, context, *args):
pass
class Clamped(DecimalException):
- __module__ = 'decimal'
+ pass
class InvalidOperation(DecimalException):
- __module__ = 'decimal'
def handle(self, context, *args):
if args:
ans = _dec_from_triple(args[0]._sign, args[0]._int, 'n', True)
@@ -39,41 +41,35 @@
return _NaN
class ConversionSyntax(InvalidOperation):
- __module__ = 'decimal'
def handle(self, context, *args):
return _NaN
class DivisionByZero(DecimalException, ZeroDivisionError):
- __module__ = 'decimal'
def handle(self, context, sign, *args):
return _SignedInfinity[sign]
class DivisionImpossible(InvalidOperation):
- __module__ = 'decimal'
def handle(self, context, *args):
return _NaN
class DivisionUndefined(InvalidOperation, ZeroDivisionError):
- __module__ = 'decimal'
def handle(self, context, *args):
return _NaN
class Inexact(DecimalException):
- __module__ = 'decimal'
+ pass
class InvalidContext(InvalidOperation):
- __module__ = 'decimal'
def handle(self, context, *args):
return _NaN
class Rounded(DecimalException):
- __module__ = 'decimal'
+ pass
class Subnormal(DecimalException):
- __module__ = 'decimal'
+ pass
class Overflow(Inexact, Rounded):
- __module__ = 'decimal'
def handle(self, context, sign, *args):
if context.rounding in (ROUND_HALF_UP, ROUND_HALF_EVEN,
ROUND_HALF_DOWN, ROUND_UP):
@@ -90,10 +86,10 @@
context.Emax-context.prec+1)
class Underflow(Inexact, Rounded, Subnormal):
- __module__ = 'decimal'
+ pass
class FloatOperation(DecimalException, TypeError):
- __module__ = 'decimal'
+ pass
__version__ = "1.70"
@@ -107,7 +103,7 @@
def getcontext():
"""Returns this thread's context.
-
+
If this thread does not yet have a context, returns
a new context and sets this thread's context.
New contexts are copies of DefaultContext.
@@ -173,8 +169,6 @@
_DEC_MINALLOC = 4
class Decimal(object):
- __module__ = 'decimal'
-
__slots__ = ('_mpd', '_data')
def __new__(cls, value="0", context=None):
@@ -326,7 +320,7 @@
builder.append(b'E')
builder.append(str(exponent).encode())
- return cls._from_bytes(b''.join(builder), context, exact=exact)
+ return cls._from_bytes(b''.join(builder), context, exact=exact)
@classmethod
def from_float(cls, value):
@@ -481,7 +475,7 @@
numerator = Decimal._from_int(other.numerator, context)
if not _mpdec.mpd_isspecial(self._mpd):
# multiplied = self * other.denominator
- #
+ #
# Prevent Overflow in the following multiplication.
# The result of the multiplication is
# only used in mpd_qcmp, which can handle values that
@@ -542,7 +536,7 @@
_mpdec.mpd_qset_ssize(p._mpd, self._PyHASH_MODULUS,
maxctx, status_ptr)
ten = self._new_empty()
- _mpdec.mpd_qset_ssize(ten._mpd, 10,
+ _mpdec.mpd_qset_ssize(ten._mpd, 10,
maxctx, status_ptr)
inv10_p = self._new_empty()
_mpdec.mpd_qset_ssize(inv10_p._mpd, self._PyHASH_10INV,
@@ -755,7 +749,7 @@
number_class = _make_unary_operation('number_class')
to_eng_string = _make_unary_operation('to_eng_string')
-
+
def fma(self, other, third, context=None):
context = _getcontext(context)
return context.fma(self, other, third)
@@ -790,7 +784,7 @@
result = int.from_bytes(s, 'little', signed=False)
if _mpdec.mpd_isnegative(x._mpd) and not _mpdec.mpd_iszero(x._mpd):
result = -result
- return result
+ return result
def __int__(self):
return self._to_int(_mpdec.MPD_ROUND_DOWN)
@@ -798,10 +792,10 @@
__trunc__ = __int__
def __floor__(self):
- return self._to_int(_mpdec.MPD_ROUND_FLOOR)
+ return self._to_int(_mpdec.MPD_ROUND_FLOOR)
def __ceil__(self):
- return self._to_int(_mpdec.MPD_ROUND_CEILING)
+ return self._to_int(_mpdec.MPD_ROUND_CEILING)
def to_integral(self, rounding=None, context=None):
context = _getcontext(context)
@@ -817,7 +811,7 @@
return result
to_integral_value = to_integral
-
+
def to_integral_exact(self, rounding=None, context=None):
context = _getcontext(context)
workctx = context.copy()
@@ -886,7 +880,7 @@
if _mpdec.mpd_isspecial(self._mpd):
return 0
return _mpdec.mpd_adjexp(self._mpd)
-
+
@property
def real(self):
return self
@@ -916,7 +910,7 @@
fmt = specifier.encode('utf-8')
context = getcontext()
- replace_fillchar = False
+ replace_fillchar = False
if fmt and fmt[0] == 0:
# NUL fill character: must be replaced with a valid UTF-8 char
# before calling mpd_parse_fmt_str().
@@ -975,7 +969,7 @@
result = result.replace(b'\xff', b'\0')
return result.decode('utf-8')
-
+
# Register Decimal as a kind of Number (an abstract base class).
# However, do not register it as Real (because Decimals are not
# interoperable with floats).
@@ -988,7 +982,7 @@
# Rounding
_ROUNDINGS = {
- 'ROUND_DOWN': _mpdec.MPD_ROUND_DOWN,
+ 'ROUND_DOWN': _mpdec.MPD_ROUND_DOWN,
'ROUND_HALF_UP': _mpdec.MPD_ROUND_HALF_UP,
'ROUND_HALF_EVEN': _mpdec.MPD_ROUND_HALF_EVEN,
'ROUND_CEILING': _mpdec.MPD_ROUND_CEILING,
@@ -1047,8 +1041,6 @@
clamp - If 1, change exponents if too high (Default 0)
"""
- __module__ = 'decimal'
-
__slots__ = ('_ctx', '_capitals')
def __new__(cls, prec=None, rounding=None, Emin=None, Emax=None,
@@ -1068,7 +1060,7 @@
ctx.round = _mpdec.MPD_ROUND_HALF_EVEN
ctx.clamp = 0
ctx.allcr = 1
-
+
self._capitals = 1
return self
@@ -1291,7 +1283,7 @@
if b is NotImplemented:
return b, b
return a, b
-
+
def _make_unary_method(name, mpd_func_name):
mpd_func = getattr(_mpdec, mpd_func_name)
@@ -1570,7 +1562,7 @@
def copy(self):
return self._as_dict()
-
+
def __len__(self):
return len(_SIGNALS)
@@ -1629,7 +1621,7 @@
def __enter__(self):
self.status_ptr = _ffi.new("uint32_t*")
return self.context._ctx, self.status_ptr
-
+
def __exit__(self, *args):
status = self.status_ptr[0]
# May raise a DecimalException
More information about the pypy-commit mailing list