[Python-checkins] [3.8] Revert bpo-39576: Prevent memory error for overly optimistic precisions (GH-20747)

Stefan Krah webhook-mailer at python.org
Mon Jun 8 19:57:15 EDT 2020


https://github.com/python/cpython/commit/0f5a28f834bdac2da8a04597dc0fc5b71e50da9d
commit: 0f5a28f834bdac2da8a04597dc0fc5b71e50da9d
branch: 3.8
author: Stefan Krah <skrah at bytereef.org>
committer: GitHub <noreply at github.com>
date: 2020-06-09T01:57:11+02:00
summary:

[3.8] Revert bpo-39576: Prevent memory error for overly optimistic precisions (GH-20747)

This reverts commit b6271025c640c228505dc9f194362a0c2ab81c61.

files:
M Lib/test/test_decimal.py
M Modules/_decimal/libmpdec/mpdecimal.c
M Modules/_decimal/tests/deccheck.py

diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index 0e9cd3095c85e..1f37b5372a3e7 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -5476,41 +5476,6 @@ def __abs__(self):
             self.assertEqual(Decimal.from_float(cls(101.1)),
                              Decimal.from_float(101.1))
 
-    def test_maxcontext_exact_arith(self):
-
-        # Make sure that exact operations do not raise MemoryError due
-        # to huge intermediate values when the context precision is very
-        # large.
-
-        # The following functions fill the available precision and are
-        # therefore not suitable for large precisions (by design of the
-        # specification).
-        MaxContextSkip = ['logical_invert', 'next_minus', 'next_plus',
-                          'logical_and', 'logical_or', 'logical_xor',
-                          'next_toward', 'rotate', 'shift']
-
-        Decimal = C.Decimal
-        Context = C.Context
-        localcontext = C.localcontext
-
-        # Here only some functions that are likely candidates for triggering a
-        # MemoryError are tested.  deccheck.py has an exhaustive test.
-        maxcontext = Context(prec=C.MAX_PREC, Emin=C.MIN_EMIN, Emax=C.MAX_EMAX)
-        with localcontext(maxcontext):
-            self.assertEqual(Decimal(0).exp(), 1)
-            self.assertEqual(Decimal(1).ln(), 0)
-            self.assertEqual(Decimal(1).log10(), 0)
-            self.assertEqual(Decimal(10**2).log10(), 2)
-            self.assertEqual(Decimal(10**223).log10(), 223)
-            self.assertEqual(Decimal(10**19).logb(), 19)
-            self.assertEqual(Decimal(4).sqrt(), 2)
-            self.assertEqual(Decimal("40E9").sqrt(), Decimal('2.0E+5'))
-            self.assertEqual(divmod(Decimal(10), 3), (3, 1))
-            self.assertEqual(Decimal(10) // 3, 3)
-            self.assertEqual(Decimal(4) / 2, 2)
-            self.assertEqual(Decimal(400) ** -1, Decimal('0.0025'))
-
-
 @requires_docstrings
 @unittest.skipUnless(C, "test requires C version")
 class SignatureTest(unittest.TestCase):
diff --git a/Modules/_decimal/libmpdec/mpdecimal.c b/Modules/_decimal/libmpdec/mpdecimal.c
index 0986edb576a10..bfa8bb343e60c 100644
--- a/Modules/_decimal/libmpdec/mpdecimal.c
+++ b/Modules/_decimal/libmpdec/mpdecimal.c
@@ -3781,43 +3781,6 @@ mpd_qdiv(mpd_t *q, const mpd_t *a, const mpd_t *b,
          const mpd_context_t *ctx, uint32_t *status)
 {
     _mpd_qdiv(SET_IDEAL_EXP, q, a, b, ctx, status);
-
-    if (*status & MPD_Malloc_error) {
-        /* Inexact quotients (the usual case) fill the entire context precision,
-         * which can lead to malloc() failures for very high precisions. Retry
-         * the operation with a lower precision in case the result is exact.
-         *
-         * We need an upper bound for the number of digits of a_coeff / b_coeff
-         * when the result is exact.  If a_coeff' * 1 / b_coeff' is in lowest
-         * terms, then maxdigits(a_coeff') + maxdigits(1 / b_coeff') is a suitable
-         * bound.
-         *
-         * 1 / b_coeff' is exact iff b_coeff' exclusively has prime factors 2 or 5.
-         * The largest amount of digits is generated if b_coeff' is a power of 2 or
-         * a power of 5 and is less than or equal to log5(b_coeff') <= log2(b_coeff').
-         *
-         * We arrive at a total upper bound:
-         *
-         *   maxdigits(a_coeff') + maxdigits(1 / b_coeff') <=
-         *   a->digits + log2(b_coeff) =
-         *   a->digits + log10(b_coeff) / log10(2) <=
-         *   a->digits + b->digits * 4;
-         */
-        uint32_t workstatus = 0;
-        mpd_context_t workctx = *ctx;
-        workctx.prec = a->digits + b->digits * 4;
-        if (workctx.prec >= ctx->prec) {
-            return;  /* No point in retrying, keep the original error. */
-        }
-
-        _mpd_qdiv(SET_IDEAL_EXP, q, a, b, &workctx, &workstatus);
-        if (workstatus == 0) { /* The result is exact, unrounded, normal etc. */
-            *status = 0;
-            return;
-        }
-
-        mpd_seterror(q, *status, status);
-    }
 }
 
 /* Internal function. */
@@ -7739,9 +7702,9 @@ mpd_qinvroot(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
 /* END LIBMPDEC_ONLY */
 
 /* Algorithm from decimal.py */
-static void
-_mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
-           uint32_t *status)
+void
+mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
+          uint32_t *status)
 {
     mpd_context_t maxcontext;
     MPD_NEW_STATIC(c,0,0,0,0);
@@ -7873,40 +7836,6 @@ _mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
     goto out;
 }
 
-void
-mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
-          uint32_t *status)
-{
-    _mpd_qsqrt(result, a, ctx, status);
-
-    if (*status & (MPD_Malloc_error|MPD_Division_impossible)) {
-        /* The above conditions can occur at very high context precisions
-         * if intermediate values get too large. Retry the operation with
-         * a lower context precision in case the result is exact.
-         *
-         * If the result is exact, an upper bound for the number of digits
-         * is the number of digits in the input.
-         *
-         * NOTE: sqrt(40e9) = 2.0e+5 /\ digits(40e9) = digits(2.0e+5) = 2
-         */
-        uint32_t workstatus = 0;
-        mpd_context_t workctx = *ctx;
-        workctx.prec = a->digits;
-
-        if (workctx.prec >= ctx->prec) {
-            return; /* No point in repeating this, keep the original error. */
-        }
-
-        _mpd_qsqrt(result, a, &workctx, &workstatus);
-        if (workstatus == 0) {
-            *status = 0;
-            return;
-        }
-
-        mpd_seterror(result, *status, status);
-    }
-}
-
 
 /******************************************************************************/
 /*                              Base conversions                              */
diff --git a/Modules/_decimal/tests/deccheck.py b/Modules/_decimal/tests/deccheck.py
index 5cd5db5711426..f907531e1ffa5 100644
--- a/Modules/_decimal/tests/deccheck.py
+++ b/Modules/_decimal/tests/deccheck.py
@@ -125,12 +125,6 @@
     'special': ('context.__reduce_ex__', 'context.create_decimal_from_float')
 }
 
-# Functions that set no context flags but whose result can differ depending
-# on prec, Emin and Emax.
-MaxContextSkip = ['is_normal', 'is_subnormal', 'logical_invert', 'next_minus',
-                  'next_plus', 'number_class', 'logical_and', 'logical_or',
-                  'logical_xor', 'next_toward', 'rotate', 'shift']
-
 # Functions that require a restricted exponent range for reasonable runtimes.
 UnaryRestricted = [
   '__ceil__', '__floor__', '__int__', '__trunc__',
@@ -350,20 +344,6 @@ def __init__(self, funcname, operands):
         self.pex = RestrictedList()      # Python exceptions for P.Decimal
         self.presults = RestrictedList() # P.Decimal results
 
-        # If the above results are exact, unrounded and not clamped, repeat
-        # the operation with a maxcontext to ensure that huge intermediate
-        # values do not cause a MemoryError.
-        self.with_maxcontext = False
-        self.maxcontext = context.c.copy()
-        self.maxcontext.prec = C.MAX_PREC
-        self.maxcontext.Emax = C.MAX_EMAX
-        self.maxcontext.Emin = C.MIN_EMIN
-        self.maxcontext.clear_flags()
-
-        self.maxop = RestrictedList()       # converted C.Decimal operands
-        self.maxex = RestrictedList()       # Python exceptions for C.Decimal
-        self.maxresults = RestrictedList()  # C.Decimal results
-
 
 # ======================================================================
 #                SkipHandler: skip known discrepancies
@@ -565,17 +545,13 @@ def function_as_string(t):
     if t.contextfunc:
         cargs = t.cop
         pargs = t.pop
-        maxargs = t.maxop
         cfunc = "c_func: %s(" % t.funcname
         pfunc = "p_func: %s(" % t.funcname
-        maxfunc = "max_func: %s(" % t.funcname
     else:
         cself, cargs = t.cop[0], t.cop[1:]
         pself, pargs = t.pop[0], t.pop[1:]
-        maxself, maxargs = t.maxop[0], t.maxop[1:]
         cfunc = "c_func: %s.%s(" % (repr(cself), t.funcname)
         pfunc = "p_func: %s.%s(" % (repr(pself), t.funcname)
-        maxfunc = "max_func: %s.%s(" % (repr(maxself), t.funcname)
 
     err = cfunc
     for arg in cargs:
@@ -589,14 +565,6 @@ def function_as_string(t):
     err = err.rstrip(", ")
     err += ")"
 
-    if t.with_maxcontext:
-        err += "\n"
-        err += maxfunc
-        for arg in maxargs:
-            err += "%s, " % repr(arg)
-        err = err.rstrip(", ")
-        err += ")"
-
     return err
 
 def raise_error(t):
@@ -609,24 +577,9 @@ def raise_error(t):
     err = "Error in %s:\n\n" % t.funcname
     err += "input operands: %s\n\n" % (t.op,)
     err += function_as_string(t)
-
-    err += "\n\nc_result: %s\np_result: %s\n" % (t.cresults, t.presults)
-    if t.with_maxcontext:
-        err += "max_result: %s\n\n" % (t.maxresults)
-    else:
-        err += "\n"
-
-    err += "c_exceptions: %s\np_exceptions: %s\n" % (t.cex, t.pex)
-    if t.with_maxcontext:
-        err += "max_exceptions: %s\n\n" % t.maxex
-    else:
-        err += "\n"
-
-    err += "%s\n" % str(t.context)
-    if t.with_maxcontext:
-        err += "%s\n" % str(t.maxcontext)
-    else:
-        err += "\n"
+    err += "\n\nc_result: %s\np_result: %s\n\n" % (t.cresults, t.presults)
+    err += "c_exceptions: %s\np_exceptions: %s\n\n" % (t.cex, t.pex)
+    err += "%s\n\n" % str(t.context)
 
     raise VerifyError(err)
 
@@ -650,13 +603,6 @@ def raise_error(t):
 #                are printed to stdout.
 # ======================================================================
 
-def all_nan(a):
-    if isinstance(a, C.Decimal):
-        return a.is_nan()
-    elif isinstance(a, tuple):
-        return all(all_nan(v) for v in a)
-    return False
-
 def convert(t, convstr=True):
     """ t is the testset. At this stage the testset contains a tuple of
         operands t.op of various types. For decimal methods the first
@@ -671,12 +617,10 @@ def convert(t, convstr=True):
     for i, op in enumerate(t.op):
 
         context.clear_status()
-        t.maxcontext.clear_flags()
 
         if op in RoundModes:
             t.cop.append(op)
             t.pop.append(op)
-            t.maxop.append(op)
 
         elif not t.contextfunc and i == 0 or \
              convstr and isinstance(op, str):
@@ -694,25 +638,11 @@ def convert(t, convstr=True):
                 p = None
                 pex = e.__class__
 
-            try:
-                C.setcontext(t.maxcontext)
-                maxop = C.Decimal(op)
-                maxex = None
-            except (TypeError, ValueError, OverflowError) as e:
-                maxop = None
-                maxex = e.__class__
-            finally:
-                C.setcontext(context.c)
-
             t.cop.append(c)
             t.cex.append(cex)
-
             t.pop.append(p)
             t.pex.append(pex)
 
-            t.maxop.append(maxop)
-            t.maxex.append(maxex)
-
             if cex is pex:
                 if str(c) != str(p) or not context.assert_eq_status():
                     raise_error(t)
@@ -722,21 +652,14 @@ def convert(t, convstr=True):
             else:
                 raise_error(t)
 
-            # The exceptions in the maxcontext operation can legitimately
-            # differ, only test that maxex implies cex:
-            if maxex is not None and cex is not maxex:
-                raise_error(t)
-
         elif isinstance(op, Context):
             t.context = op
             t.cop.append(op.c)
             t.pop.append(op.p)
-            t.maxop.append(t.maxcontext)
 
         else:
             t.cop.append(op)
             t.pop.append(op)
-            t.maxop.append(op)
 
     return 1
 
@@ -750,7 +673,6 @@ def callfuncs(t):
         t.rc and t.rp are the results of the operation.
     """
     context.clear_status()
-    t.maxcontext.clear_flags()
 
     try:
         if t.contextfunc:
@@ -778,35 +700,6 @@ def callfuncs(t):
         t.rp = None
         t.pex.append(e.__class__)
 
-    # If the above results are exact, unrounded, normal etc., repeat the
-    # operation with a maxcontext to ensure that huge intermediate values
-    # do not cause a MemoryError.
-    if (t.funcname not in MaxContextSkip and
-        not context.c.flags[C.InvalidOperation] and
-        not context.c.flags[C.Inexact] and
-        not context.c.flags[C.Rounded] and
-        not context.c.flags[C.Subnormal] and
-        not context.c.flags[C.Clamped] and
-        not context.clamp and # results are padded to context.prec if context.clamp==1.
-        not any(isinstance(v, C.Context) for v in t.cop)): # another context is used.
-        t.with_maxcontext = True
-        try:
-            if t.contextfunc:
-                maxargs = t.maxop
-                t.rmax = getattr(t.maxcontext, t.funcname)(*maxargs)
-            else:
-                maxself = t.maxop[0]
-                maxargs = t.maxop[1:]
-                try:
-                    C.setcontext(t.maxcontext)
-                    t.rmax = getattr(maxself, t.funcname)(*maxargs)
-                finally:
-                    C.setcontext(context.c)
-            t.maxex.append(None)
-        except (TypeError, ValueError, OverflowError, MemoryError) as e:
-            t.rmax = None
-            t.maxex.append(e.__class__)
-
 def verify(t, stat):
     """ t is the testset. At this stage the testset contains the following
         tuples:
@@ -821,9 +714,6 @@ def verify(t, stat):
     """
     t.cresults.append(str(t.rc))
     t.presults.append(str(t.rp))
-    if t.with_maxcontext:
-        t.maxresults.append(str(t.rmax))
-
     if isinstance(t.rc, C.Decimal) and isinstance(t.rp, P.Decimal):
         # General case: both results are Decimals.
         t.cresults.append(t.rc.to_eng_string())
@@ -835,12 +725,6 @@ def verify(t, stat):
         t.presults.append(str(t.rp.imag))
         t.presults.append(str(t.rp.real))
 
-        if t.with_maxcontext and isinstance(t.rmax, C.Decimal):
-            t.maxresults.append(t.rmax.to_eng_string())
-            t.maxresults.append(t.rmax.as_tuple())
-            t.maxresults.append(str(t.rmax.imag))
-            t.maxresults.append(str(t.rmax.real))
-
         nc = t.rc.number_class().lstrip('+-s')
         stat[nc] += 1
     else:
@@ -848,9 +732,6 @@ def verify(t, stat):
         if not isinstance(t.rc, tuple) and not isinstance(t.rp, tuple):
             if t.rc != t.rp:
                 raise_error(t)
-            if t.with_maxcontext and not isinstance(t.rmax, tuple):
-                if t.rmax != t.rc:
-                    raise_error(t)
         stat[type(t.rc).__name__] += 1
 
     # The return value lists must be equal.
@@ -863,20 +744,6 @@ def verify(t, stat):
     if not t.context.assert_eq_status():
         raise_error(t)
 
-    if t.with_maxcontext:
-        # NaN payloads etc. depend on precision and clamp.
-        if all_nan(t.rc) and all_nan(t.rmax):
-            return
-        # The return value lists must be equal.
-        if t.maxresults != t.cresults:
-            raise_error(t)
-        # The Python exception lists (TypeError, etc.) must be equal.
-        if t.maxex != t.cex:
-            raise_error(t)
-        # The context flags must be equal.
-        if t.maxcontext.flags != t.context.c.flags:
-            raise_error(t)
-
 
 # ======================================================================
 #                           Main test loops



More information about the Python-checkins mailing list