[Python-checkins] bpo-41324 Add a minimal decimal capsule API (#21519)

Stefan Krah webhook-mailer at python.org
Mon Aug 10 10:32:30 EDT 2020


https://github.com/python/cpython/commit/39042e00ab01d6521548c1b7cc6554c09f4389ff
commit: 39042e00ab01d6521548c1b7cc6554c09f4389ff
branch: master
author: Stefan Krah <skrah at bytereef.org>
committer: GitHub <noreply at github.com>
date: 2020-08-10T16:32:21+02:00
summary:

bpo-41324 Add a minimal decimal capsule API (#21519)

files:
A Doc/c-api/decimal.rst
A Include/pydecimal.h
A Misc/NEWS.d/next/C API/2020-08-10-16-05-08.bpo-41324.waZD35.rst
M Doc/c-api/concrete.rst
M Lib/test/test_decimal.py
M Modules/_decimal/_decimal.c
M Modules/_decimal/tests/deccheck.py
M Modules/_testcapimodule.c

diff --git a/Doc/c-api/concrete.rst b/Doc/c-api/concrete.rst
index c1d9fa1b41a3f..bf263d6e4c264 100644
--- a/Doc/c-api/concrete.rst
+++ b/Doc/c-api/concrete.rst
@@ -115,3 +115,4 @@ Other Objects
    coro.rst
    contextvars.rst
    datetime.rst
+   decimal.rst
diff --git a/Doc/c-api/decimal.rst b/Doc/c-api/decimal.rst
new file mode 100644
index 0000000000000..f530571ebae57
--- /dev/null
+++ b/Doc/c-api/decimal.rst
@@ -0,0 +1,231 @@
+.. sectionauthor:: Stefan Krah
+
+.. highlight:: c
+
+
+Decimal capsule API
+===================
+
+Capsule API functions can be used in the same manner as regular library
+functions, provided that the API has been initialized.
+
+
+Initialize
+----------
+
+Typically, a C extension module that uses the decimal API will do these
+steps in its init function:
+
+.. code-block::
+
+    #include "pydecimal.h"
+
+    static int decimal_initialized = 0;
+    if (!decimal_initialized) {
+        if (import_decimal() < 0) {
+            return NULL;
+        }
+
+        decimal_initialized = 1;
+    }
+
+
+Type checking, predicates, accessors
+------------------------------------
+
+.. c:function:: int PyDec_TypeCheck(const PyObject *dec)
+
+   Return 1 if ``dec`` is a Decimal, 0 otherwise.  This function does not set
+   any exceptions.
+
+
+.. c:function:: int PyDec_IsSpecial(const PyObject *dec)
+
+   Return 1 if ``dec`` is ``NaN``, ``sNaN`` or ``Infinity``, 0 otherwise.
+
+   Set TypeError and return -1 if ``dec`` is not a Decimal.  It is guaranteed that
+   this is the only failure mode, so if ``dec`` has already been type-checked, no
+   errors can occur and the function can be treated as a simple predicate.
+
+
+.. c:function:: int PyDec_IsNaN(const PyObject *dec)
+
+   Return 1 if ``dec`` is ``NaN`` or ``sNaN``, 0 otherwise.
+
+   Set TypeError and return -1 if ``dec`` is not a Decimal.  It is guaranteed that
+   this is the only failure mode, so if ``dec`` has already been type-checked, no
+   errors can occur and the function can be treated as a simple predicate.
+
+
+.. c:function:: int PyDec_IsInfinite(const PyObject *dec)
+
+   Return 1 if ``dec`` is ``Infinity``, 0 otherwise.
+
+   Set TypeError and return -1 if ``dec`` is not a Decimal.  It is guaranteed that
+   this is the only failure mode, so if ``dec`` has already been type-checked, no
+   errors can occur and the function can be treated as a simple predicate.
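The three predicates above can be modeled in Python with the query methods of ``decimal.Decimal``. This is a hedged illustration of the documented return values, not the C implementation. The helper names ``dec_is_special``, ``dec_is_nan`` and ``dec_is_infinite`` are hypothetical. Raising ``TypeError`` stands in for the C functions' "set TypeError and return -1" failure mode.

```python
from decimal import Decimal


def _check(dec):
    # Model of the only documented failure mode: a non-Decimal argument.
    if not isinstance(dec, Decimal):
        raise TypeError("argument must be a Decimal")


def dec_is_special(dec):
    """Model of PyDec_IsSpecial: True for NaN, sNaN or Infinity."""
    _check(dec)
    return not dec.is_finite()


def dec_is_nan(dec):
    """Model of PyDec_IsNaN: True for quiet and signaling NaNs."""
    _check(dec)
    return dec.is_nan()  # is_nan() is True for both NaN and sNaN


def dec_is_infinite(dec):
    """Model of PyDec_IsInfinite: True for +/-Infinity only."""
    _check(dec)
    return dec.is_infinite()
```

These expectations match the ones in ``test_decimal_api_predicates`` further down in this commit.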
+
+
+.. c:function:: int64_t PyDec_GetDigits(const PyObject *dec)
+
+   Return the number of digits in the coefficient.  For ``Infinity``, the
+   number of digits is always zero.  Typically, the same applies to ``NaN``
+   and ``sNaN``, but both of these can have a payload that is equivalent to
+   a coefficient.  Therefore, ``NaNs`` can have a nonzero return value.
+
+   Set TypeError and return -1 if ``dec`` is not a Decimal.  It is guaranteed that
+   this is the only failure mode, so if ``dec`` has already been type-checked, no
+   errors can occur and the function can be treated as a simple accessor.
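A rough Python model of the digit count described above follows. It assumes that ``Decimal.as_tuple().digits`` holds the coefficient for finite values and the payload for NaNs. Infinity is special-cased because ``as_tuple()`` reports a placeholder ``(0,)`` for it. The name ``dec_get_digits`` is hypothetical.

```python
from decimal import Decimal


def dec_get_digits(dec):
    """Hypothetical Python model of PyDec_GetDigits."""
    if not isinstance(dec, Decimal):
        raise TypeError("argument must be a Decimal")
    if dec.is_infinite():
        # Documented: the number of digits for Infinity is always zero.
        return 0
    # Finite values: length of the coefficient.  NaN/sNaN: length of the
    # payload, which is an empty tuple when there is no payload.
    return len(dec.as_tuple().digits)
```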
+
+
+Exact conversions between decimals and primitive C types
+--------------------------------------------------------
+
+This API supports conversions for decimals with a coefficient up to 38 digits.
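The 38-digit limit follows from the 128-bit coefficient. ``2**128 - 1 = 340282366920938463463374607431768211455`` has 39 digits, so every 38-digit coefficient fits but only some 39-digit ones do. The following quick arithmetic check uses illustrative helper names.

```python
UINT128_MAX = 2**128 - 1


def fits_uint128(coeff):
    """True if a non-negative integer coefficient fits in a uint128_t."""
    return 0 <= coeff <= UINT128_MAX


def max_guaranteed_digits():
    """Largest n such that every n-digit coefficient fits in a uint128_t."""
    n = 0
    while fits_uint128(10**(n + 1) - 1):  # largest (n+1)-digit number
        n += 1
    return n
```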
+
+Data structures
+~~~~~~~~~~~~~~~
+
+The conversion functions use the following status codes and data structures:
+
+.. code-block::
+
+   /* status cases for getting a triple */
+   enum mpd_triple_class {
+     MPD_TRIPLE_NORMAL,
+     MPD_TRIPLE_INF,
+     MPD_TRIPLE_QNAN,
+     MPD_TRIPLE_SNAN,
+     MPD_TRIPLE_ERROR,
+   };
+
+   typedef struct {
+     enum mpd_triple_class tag;
+     uint8_t sign;
+     uint64_t hi;
+     uint64_t lo;
+     int64_t exp;
+   } mpd_uint128_triple_t;
+
+The status cases are explained below.  ``sign`` is 0 for positive and 1 for negative.
+``((uint128_t)hi << 64) + lo`` is the coefficient, ``exp`` is the exponent.
+
+The data structure is called "triple" because the decimal triple (sign, coeff, exp)
+is an established term and (``hi``, ``lo``) represents a single ``uint128_t`` coefficient.
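The split of the coefficient into ``hi`` and ``lo`` is a plain 64-bit split. The Python sketch below mirrors ``((uint128_t)hi << 64) + lo``, in the same way the ``as_triple`` helper in ``test_decimal.py`` uses ``divmod``. The function names are illustrative, not part of the API.

```python
def split_uint128(coeff):
    """Split a coefficient into the (hi, lo) halves of mpd_uint128_triple_t."""
    if not 0 <= coeff < 2**128:
        raise ValueError("value out of bounds for a uint128 triple")
    hi, lo = divmod(coeff, 2**64)  # hi = coeff >> 64, lo = coeff & (2**64 - 1)
    return hi, lo


def join_uint128(hi, lo):
    """Inverse of split_uint128: ((uint128_t)hi << 64) + lo."""
    if not (0 <= hi < 2**64 and 0 <= lo < 2**64):
        raise ValueError("hi and lo must each fit in a uint64_t")
    return (hi << 64) + lo
```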
+
+
+Functions
+~~~~~~~~~
+
+.. c:function:: mpd_uint128_triple_t PyDec_AsUint128Triple(const PyObject *dec)
+
+   Convert a decimal to a triple.  As above, it is guaranteed that the only
+   Python failure mode is a TypeError, so the check can be omitted if the type
+   is known.
+
+   For simplicity, the usage of the function and all special cases are
+   explained in code form and comments:
+
+.. code-block::
+
+    triple = PyDec_AsUint128Triple(dec);
+    switch (triple.tag) {
+    case MPD_TRIPLE_QNAN:
+        /*
+         * Success: handle a quiet NaN.
+         *   1) triple.sign is 0 or 1.
+         *   2) triple.exp is always 0.
+         *   3) If triple.hi or triple.lo are nonzero, the NaN has a payload.
+         */
+        break;
+
+    case MPD_TRIPLE_SNAN:
+        /*
+         * Success: handle a signaling NaN.
+         *   1) triple.sign is 0 or 1.
+         *   2) triple.exp is always 0.
+         *   3) If triple.hi or triple.lo are nonzero, the sNaN has a payload.
+         */
+        break;
+
+    case MPD_TRIPLE_INF:
+        /*
+         * Success: handle Infinity.
+         *   1) triple.sign is 0 or 1.
+         *   2) triple.exp is always 0.
+         *   3) triple.hi and triple.lo are always zero.
+         */
+        break;
+
+    case MPD_TRIPLE_NORMAL:
+        /* Success: handle a finite value. */
+        break;
+
+    case MPD_TRIPLE_ERROR:
+        /* TypeError check: can be omitted if the type of dec is known. */
+        if (PyErr_Occurred()) {
+            return NULL;
+        }
+
+        /* Too large for conversion.  PyDec_AsUint128Triple() does not set an
+           exception, so applications can choose one themselves.  Typically
+           this would be a ValueError. */
+        PyErr_SetString(PyExc_ValueError,
+            "value out of bounds for a uint128 triple");
+        return NULL;
+    }
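The case analysis in the switch statement above can be modeled in Python as follows. This sketch returns ``(tag, sign, hi, lo, exp)`` with the tag as a string naming the ``MPD_TRIPLE_*`` case. It is an illustration under stated assumptions, not the C code. In particular, zeroing the other fields of an ``ERROR`` result is an assumption of this model, and callers should not rely on them. ``as_uint128_triple`` is a hypothetical name.

```python
from decimal import Decimal


def as_uint128_triple(dec):
    """Hypothetical Python model of PyDec_AsUint128Triple."""
    if not isinstance(dec, Decimal):
        raise TypeError("argument must be a Decimal")

    sign, digits, exp = dec.as_tuple()

    if dec.is_infinite():
        # Infinity: exp, hi and lo are always zero.
        return ("INF", sign, 0, 0, 0)

    coeff = int("".join(map(str, digits))) if digits else 0
    if coeff >= 2**128:
        # Too large: the C function returns MPD_TRIPLE_ERROR without setting
        # an exception.  The zero fields here are a placeholder of this model.
        return ("ERROR", 0, 0, 0, 0)

    hi, lo = divmod(coeff, 2**64)
    if dec.is_snan():
        return ("SNAN", sign, hi, lo, 0)  # nonzero hi/lo means a payload
    if dec.is_qnan():
        return ("QNAN", sign, hi, lo, 0)
    return ("NORMAL", sign, hi, lo, exp)
```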
+
+.. c:function:: PyObject *PyDec_FromUint128Triple(const mpd_uint128_triple_t *triple)
+
+   Create a decimal from a triple.  The following rules must be observed for
+   initializing the triple:
+
+   1) ``triple.sign`` must always be 0 (for positive) or 1 (for negative).
+
+   2) ``MPD_TRIPLE_QNAN``: ``triple.exp`` must be 0.  If ``triple.hi`` or ``triple.lo``
+      is nonzero, create a ``NaN`` with a payload.
+
+   3) ``MPD_TRIPLE_SNAN``: ``triple.exp`` must be 0.  If ``triple.hi`` or ``triple.lo``
+      is nonzero, create an ``sNaN`` with a payload.
+
+   4) ``MPD_TRIPLE_INF``: ``triple.exp``, ``triple.hi`` and ``triple.lo`` must be zero.
+
+   5) ``MPD_TRIPLE_NORMAL``: ``MPD_MIN_ETINY + 38 < triple.exp < MPD_MAX_EMAX - 38``.
+      ``triple.hi`` and ``triple.lo`` can be chosen freely.
+
+   6) ``MPD_TRIPLE_ERROR``: It is always an error to set this tag.
+
+
+   If one of the above conditions is not met, the function returns ``NaN`` if
+   the ``InvalidOperation`` trap is not set in the thread-local context.  Otherwise,
+   it sets the ``InvalidOperation`` exception and returns ``NULL``.
+
+   Additionally, though this is extremely unlikely given the small allocation sizes,
+   the function can set ``MemoryError`` and return ``NULL``.
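Rules 1) to 6) and the error behavior can be sketched in Python as a validation step followed by ordinary tuple construction. The assumptions are loudly stated: the tag is a string, ``decimal.MIN_ETINY`` and ``decimal.MAX_EMAX`` stand in for ``MPD_MIN_ETINY`` and ``MPD_MAX_EMAX``, and the strict exponent bounds are taken from rule 5 as documented. ``from_uint128_triple`` is a hypothetical name, not the C function.

```python
import decimal
from decimal import Decimal


def from_uint128_triple(tag, sign, hi, lo, exp):
    """Hypothetical Python model of the PyDec_FromUint128Triple rules."""
    coeff = (hi << 64) + lo
    valid = sign in (0, 1) and 0 <= hi < 2**64 and 0 <= lo < 2**64  # rule 1
    if tag in ("QNAN", "SNAN"):                                      # rules 2, 3
        valid = valid and exp == 0
    elif tag == "INF":                                               # rule 4
        valid = valid and exp == 0 and coeff == 0
    elif tag == "NORMAL":                                            # rule 5
        valid = valid and decimal.MIN_ETINY + 38 < exp < decimal.MAX_EMAX - 38
    else:                                                            # rule 6
        valid = False

    if not valid:
        # Return NaN unless InvalidOperation is trapped in the current context.
        ctx = decimal.getcontext()
        if ctx.traps[decimal.InvalidOperation]:
            raise decimal.InvalidOperation("invalid uint128 triple")
        ctx.flags[decimal.InvalidOperation] = True
        return Decimal("NaN")

    if tag == "INF":
        return Decimal((sign, (0,), "F"))
    digits = tuple(int(c) for c in str(coeff))
    if tag in ("QNAN", "SNAN"):
        payload = digits if coeff else ()  # no payload for a zero coefficient
        return Decimal((sign, payload, "n" if tag == "QNAN" else "N"))
    return Decimal((sign, digits, exp))
```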
+
+
+Advanced API
+------------
+
+This API enables the use of ``libmpdec`` functions.  Since Python is compiled with
+hidden symbols, the API requires an external libmpdec and the ``mpdecimal.h``
+header.
+
+
+Functions
+~~~~~~~~~
+
+.. c:function:: PyObject *PyDec_Alloc(void)
+
+   Return a new decimal that can be used in the ``result`` position of ``libmpdec``
+   functions.
+
+.. c:function:: mpd_t *PyDec_Get(PyObject *v)
+
+   Get a pointer to the internal ``mpd_t`` of the decimal.  Decimals are immutable,
+   so this function must only be used on a new Decimal that has been created by
+   PyDec_Alloc().
+
+.. c:function:: const mpd_t *PyDec_GetConst(const PyObject *v)
+
+   Get a pointer to the constant internal ``mpd_t`` of the decimal.
diff --git a/Include/pydecimal.h b/Include/pydecimal.h
new file mode 100644
index 0000000000000..9b6440e1c2ab1
--- /dev/null
+++ b/Include/pydecimal.h
@@ -0,0 +1,180 @@
+/*
+ * Copyright (c) 2020 Stefan Krah. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright
+ *    notice, this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright
+ *    notice, this list of conditions and the following disclaimer in the
+ *    documentation and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
+ * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+ * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
+ * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
+ * SUCH DAMAGE.
+ */
+
+
+#ifndef CPYTHON_DECIMAL_H_
+#define CPYTHON_DECIMAL_H_
+
+
+#include <Python.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/****************************************************************************/
+/*                               Libmpdec API                               */
+/****************************************************************************/
+
+#ifndef LIBMPDEC_MPDECIMAL_H_
+struct mpd_t;  /* ABI-stable in the libmpdec-2.x series */
+
+/* status cases for getting a triple */
+enum mpd_triple_class {
+  MPD_TRIPLE_NORMAL,
+  MPD_TRIPLE_INF,
+  MPD_TRIPLE_QNAN,
+  MPD_TRIPLE_SNAN,
+  MPD_TRIPLE_ERROR,
+};
+
+typedef struct {
+  enum mpd_triple_class tag;
+  uint8_t sign;
+  uint64_t hi;
+  uint64_t lo;
+  int64_t exp;
+} mpd_uint128_triple_t;
+#endif
+
+
+/****************************************************************************/
+/*                                Capsule API                               */
+/****************************************************************************/
+
+/* Simple API */
+#define PyDec_TypeCheck_INDEX 0
+#define PyDec_TypeCheck_RETURN int
+#define PyDec_TypeCheck_ARGS (const PyObject *)
+
+#define PyDec_IsSpecial_INDEX 1
+#define PyDec_IsSpecial_RETURN int
+#define PyDec_IsSpecial_ARGS (const PyObject *)
+
+#define PyDec_IsNaN_INDEX 2
+#define PyDec_IsNaN_RETURN int
+#define PyDec_IsNaN_ARGS (const PyObject *)
+
+#define PyDec_IsInfinite_INDEX 3
+#define PyDec_IsInfinite_RETURN int
+#define PyDec_IsInfinite_ARGS (const PyObject *)
+
+#define PyDec_GetDigits_INDEX 4
+#define PyDec_GetDigits_RETURN int64_t
+#define PyDec_GetDigits_ARGS (const PyObject *)
+
+#define PyDec_AsUint128Triple_INDEX 5
+#define PyDec_AsUint128Triple_RETURN mpd_uint128_triple_t
+#define PyDec_AsUint128Triple_ARGS (const PyObject *)
+
+#define PyDec_FromUint128Triple_INDEX 6
+#define PyDec_FromUint128Triple_RETURN PyObject *
+#define PyDec_FromUint128Triple_ARGS (const mpd_uint128_triple_t *triple)
+
+/* Advanced API */
+#define PyDec_Alloc_INDEX 7
+#define PyDec_Alloc_RETURN PyObject *
+#define PyDec_Alloc_ARGS (void)
+
+#define PyDec_Get_INDEX 8
+#define PyDec_Get_RETURN mpd_t *
+#define PyDec_Get_ARGS (PyObject *)
+
+#define PyDec_GetConst_INDEX 9
+#define PyDec_GetConst_RETURN const mpd_t *
+#define PyDec_GetConst_ARGS (const PyObject *)
+
+#define CPYTHON_DECIMAL_MAX_API 10
+
+
+#ifdef CPYTHON_DECIMAL_MODULE
+/* Simple API */
+static PyDec_TypeCheck_RETURN PyDec_TypeCheck PyDec_TypeCheck_ARGS;
+static PyDec_IsSpecial_RETURN PyDec_IsSpecial PyDec_IsSpecial_ARGS;
+static PyDec_IsNaN_RETURN PyDec_IsNaN PyDec_IsNaN_ARGS;
+static PyDec_IsInfinite_RETURN PyDec_IsInfinite PyDec_IsInfinite_ARGS;
+static PyDec_GetDigits_RETURN PyDec_GetDigits PyDec_GetDigits_ARGS;
+static PyDec_AsUint128Triple_RETURN PyDec_AsUint128Triple PyDec_AsUint128Triple_ARGS;
+static PyDec_FromUint128Triple_RETURN PyDec_FromUint128Triple PyDec_FromUint128Triple_ARGS;
+
+/* Advanced API */
+static PyDec_Alloc_RETURN PyDec_Alloc PyDec_Alloc_ARGS;
+static PyDec_Get_RETURN PyDec_Get PyDec_Get_ARGS;
+static PyDec_GetConst_RETURN PyDec_GetConst PyDec_GetConst_ARGS;
+#else
+static void **_decimal_api;
+
+/* Simple API */
+#define PyDec_TypeCheck \
+    (*(PyDec_TypeCheck_RETURN (*)PyDec_TypeCheck_ARGS) _decimal_api[PyDec_TypeCheck_INDEX])
+
+#define PyDec_IsSpecial \
+    (*(PyDec_IsSpecial_RETURN (*)PyDec_IsSpecial_ARGS) _decimal_api[PyDec_IsSpecial_INDEX])
+
+#define PyDec_IsNaN \
+    (*(PyDec_IsNaN_RETURN (*)PyDec_IsNaN_ARGS) _decimal_api[PyDec_IsNaN_INDEX])
+
+#define PyDec_IsInfinite \
+    (*(PyDec_IsInfinite_RETURN (*)PyDec_IsInfinite_ARGS) _decimal_api[PyDec_IsInfinite_INDEX])
+
+#define PyDec_GetDigits \
+    (*(PyDec_GetDigits_RETURN (*)PyDec_GetDigits_ARGS) _decimal_api[PyDec_GetDigits_INDEX])
+
+#define PyDec_AsUint128Triple \
+    (*(PyDec_AsUint128Triple_RETURN (*)PyDec_AsUint128Triple_ARGS) _decimal_api[PyDec_AsUint128Triple_INDEX])
+
+#define PyDec_FromUint128Triple \
+    (*(PyDec_FromUint128Triple_RETURN (*)PyDec_FromUint128Triple_ARGS) _decimal_api[PyDec_FromUint128Triple_INDEX])
+
+/* Advanced API */
+#define PyDec_Alloc \
+    (*(PyDec_Alloc_RETURN (*)PyDec_Alloc_ARGS) _decimal_api[PyDec_Alloc_INDEX])
+
+#define PyDec_Get \
+    (*(PyDec_Get_RETURN (*)PyDec_Get_ARGS) _decimal_api[PyDec_Get_INDEX])
+
+#define PyDec_GetConst \
+    (*(PyDec_GetConst_RETURN (*)PyDec_GetConst_ARGS) _decimal_api[PyDec_GetConst_INDEX])
+
+
+static int
+import_decimal(void)
+{
+    _decimal_api = (void **)PyCapsule_Import("_decimal._API", 0);
+    if (_decimal_api == NULL) {
+        return -1;
+    }
+
+    return 0;
+}
+#endif
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  /* CPYTHON_DECIMAL_H_ */
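``import_decimal()`` above is a thin wrapper around ``PyCapsule_Import``. It imports the module, looks up the attribute, checks the capsule name and returns the stored pointer. The ``_decimal._API`` capsule only exists in interpreters built with this change. The sketch below therefore uses the long-standing ``datetime.datetime_CAPI`` capsule as a stand-in. It calls ``PyCapsule_Import`` through ``ctypes.pythonapi`` purely for illustration; ``import_capsule`` is a hypothetical helper.

```python
import ctypes

# void *PyCapsule_Import(const char *name, int no_block)
_capsule_import = ctypes.pythonapi.PyCapsule_Import
_capsule_import.argtypes = [ctypes.c_char_p, ctypes.c_int]
_capsule_import.restype = ctypes.c_void_p


def import_capsule(name):
    """Return the pointer stored in the capsule called 'module.attribute'.

    ctypes.pythonapi checks the Python error indicator after the call, so
    a failed import surfaces as a Python exception, much as a NULL return
    from import_decimal() leaves an exception set in C.
    """
    return _capsule_import(name.encode("ascii"), 0)


api = import_capsule("datetime.datetime_CAPI")  # stand-in for "_decimal._API"
```

In C, the returned pointer is cast to ``void **``, and each ``PyDec_*`` macro indexes it with the matching ``*_INDEX`` constant.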
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index 5d0992a66e6ce..113b37ddaa9cd 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -43,6 +43,13 @@
 import inspect
 import threading
 
+from _testcapi import decimal_is_special
+from _testcapi import decimal_is_nan
+from _testcapi import decimal_is_infinite
+from _testcapi import decimal_get_digits
+from _testcapi import decimal_as_triple
+from _testcapi import decimal_from_triple
+
 
 C = import_fresh_module('decimal', fresh=['_decimal'])
 P = import_fresh_module('decimal', blocked=['_decimal'])
@@ -4751,6 +4758,175 @@ def test_constants(self):
         self.assertEqual(C.DecTraps,
                          C.DecErrors|C.DecOverflow|C.DecUnderflow)
 
+    def test_decimal_api_predicates(self):
+        # Capsule API
+
+        d = C.Decimal("0")
+        self.assertFalse(decimal_is_special(d))
+        self.assertFalse(decimal_is_nan(d))
+        self.assertFalse(decimal_is_infinite(d))
+
+        d = C.Decimal("NaN")
+        self.assertTrue(decimal_is_special(d))
+        self.assertTrue(decimal_is_nan(d))
+        self.assertFalse(decimal_is_infinite(d))
+
+        d = C.Decimal("sNaN")
+        self.assertTrue(decimal_is_special(d))
+        self.assertTrue(decimal_is_nan(d))
+        self.assertFalse(decimal_is_infinite(d))
+
+        d = C.Decimal("inf")
+        self.assertTrue(decimal_is_special(d))
+        self.assertFalse(decimal_is_nan(d))
+        self.assertTrue(decimal_is_infinite(d))
+
+    def test_decimal_api_get_digits(self):
+        # Capsule API
+
+        d = C.Decimal("0")
+        self.assertEqual(decimal_get_digits(d), 1)
+
+        d = C.Decimal("1234567890")
+        self.assertEqual(decimal_get_digits(d), 10)
+
+        d = C.Decimal("inf")
+        self.assertEqual(decimal_get_digits(d), 0)
+
+        d = C.Decimal("NaN")
+        self.assertEqual(decimal_get_digits(d), 0)
+
+        d = C.Decimal("sNaN")
+        self.assertEqual(decimal_get_digits(d), 0)
+
+        d = C.Decimal("NaN1234567890")
+        self.assertEqual(decimal_get_digits(d), 10)
+
+        d = C.Decimal("sNaN1234567890")
+        self.assertEqual(decimal_get_digits(d), 10)
+
+    def test_decimal_api_triple(self):
+        # Capsule API
+
+        def as_triple(d):
+            """Convert a decimal to a decimal triple with a split uint128_t
+               coefficient:
+
+                   (sign, hi, lo, exp)
+
+               It is called 'triple' because (hi, lo) are regarded as a single
+               uint128_t that is split because not all compilers support uint128_t.
+            """
+            sign, digits, exp = d.as_tuple()
+
+            s = "".join(str(d) for d in digits)
+            coeff = int(s) if s else 0
+
+            if coeff < 0 or coeff >= 2**128:
+                raise ValueError("value out of bounds for a uint128 triple")
+
+            hi, lo = divmod(coeff, 2**64)
+            return (sign, hi, lo, exp)
+
+        def from_triple(triple):
+            """Convert a decimal triple with a split uint128_t coefficient to a Decimal.
+            """
+            sign, hi, lo, exp = triple
+            coeff = hi * 2**64 + lo
+
+            if coeff < 0 or coeff >= 2**128:
+                raise ValueError("value out of bounds for a uint128 triple")
+
+            digits = tuple(int(c) for c in str(coeff))
+
+            return P.Decimal((sign, digits, exp))
+
+        signs = ["", "-"]
+
+        coefficients = [
+            "000000000000000000000000000000000000000",
+
+            "299999999999999999999999999999999999999",
+            "299999999999999999990000000000000000000",
+            "200000000000000000009999999999999999999",
+            "000000000000000000009999999999999999999",
+
+            "299999999999999999999999999999000000000",
+            "299999999999999999999000000000999999999",
+            "299999999999000000000999999999999999999",
+            "299000000000999999999999999999999999999",
+            "000999999999999999999999999999999999999",
+
+            "300000000000000000000000000000000000000",
+            "310000000000000000001000000000000000000",
+            "310000000000000000000000000000000000000",
+            "300000000000000000001000000000000000000",
+
+            "340100000000100000000100000000100000000",
+            "340100000000100000000100000000000000000",
+            "340100000000100000000000000000100000000",
+            "340100000000000000000100000000100000000",
+            "340000000000100000000100000000100000000",
+
+            "340282366920938463463374607431768211455",
+        ]
+
+        exponents = [
+            "E+0", "E+1", "E-1",
+            "E+%s" % str(C.MAX_EMAX-38),
+            "E%s" % str(C.MIN_ETINY+38),
+        ]
+
+        for sign in signs:
+            for coeff in coefficients:
+                for exp in exponents:
+                    s = sign + coeff + exp
+
+                    ctriple = decimal_as_triple(C.Decimal(s))
+                    ptriple = as_triple(P.Decimal(s))
+                    self.assertEqual(ctriple, ptriple)
+
+                    c = decimal_from_triple(ctriple)
+                    p = from_triple(ptriple)
+                    self.assertEqual(str(c), str(p))
+
+        for s in ["NaN", "-NaN", "sNaN", "-sNaN", "NaN123", "sNaN123", "inf", "-inf"]:
+            ctriple = decimal_as_triple(C.Decimal(s))
+            ptriple = as_triple(P.Decimal(s))
+            self.assertEqual(ctriple, ptriple)
+
+            c = decimal_from_triple(ctriple)
+            p = from_triple(ptriple)
+            self.assertEqual(str(c), str(p))
+
+    def test_decimal_api_errors(self):
+        # Capsule API
+
+        self.assertRaises(TypeError, decimal_as_triple, "X")
+        self.assertRaises(ValueError, decimal_as_triple, C.Decimal(2**128))
+        self.assertRaises(ValueError, decimal_as_triple, C.Decimal(-2**128))
+
+        self.assertRaises(TypeError, decimal_from_triple, "X")
+        self.assertRaises(ValueError, decimal_from_triple, ())
+        self.assertRaises(ValueError, decimal_from_triple, (1, 2, 3, 4, 5))
+        self.assertRaises(ValueError, decimal_from_triple, (2**8, 0, 0, 0))
+        self.assertRaises(OverflowError, decimal_from_triple, (0, 2**64, 0, 0))
+        self.assertRaises(OverflowError, decimal_from_triple, (0, 0, 2**64, 0))
+        self.assertRaises(OverflowError, decimal_from_triple, (0, 0, 0, 2**63))
+        self.assertRaises(OverflowError, decimal_from_triple, (0, 0, 0, -2**63-1))
+        self.assertRaises(ValueError, decimal_from_triple, (0, 0, 0, "X"))
+        self.assertRaises(TypeError, decimal_from_triple, (0, 0, 0, ()))
+
+        with C.localcontext(C.Context()):
+            self.assertRaises(C.InvalidOperation, decimal_from_triple, (2, 0, 0, 0))
+            self.assertRaises(C.InvalidOperation, decimal_from_triple, (0, 0, 0, 2**63-1))
+            self.assertRaises(C.InvalidOperation, decimal_from_triple, (0, 0, 0, -2**63))
+
+        self.assertRaises(TypeError, decimal_is_special, "X")
+        self.assertRaises(TypeError, decimal_is_nan, "X")
+        self.assertRaises(TypeError, decimal_is_infinite, "X")
+        self.assertRaises(TypeError, decimal_get_digits, "X")
+
 class CWhitebox(unittest.TestCase):
     """Whitebox testing for _decimal"""
 
diff --git a/Misc/NEWS.d/next/C API/2020-08-10-16-05-08.bpo-41324.waZD35.rst b/Misc/NEWS.d/next/C API/2020-08-10-16-05-08.bpo-41324.waZD35.rst
new file mode 100644
index 0000000000000..e09332ab11e1d
--- /dev/null
+++ b/Misc/NEWS.d/next/C API/2020-08-10-16-05-08.bpo-41324.waZD35.rst	
@@ -0,0 +1,3 @@
+Add a minimal decimal capsule API.  The API supports fast conversions
+between Decimals up to 38 digits and their triple representation as a C
+struct.
diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c
index fb4e020f1260e..e7c44acba02fc 100644
--- a/Modules/_decimal/_decimal.c
+++ b/Modules/_decimal/_decimal.c
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2008-2012 Stefan Krah. All rights reserved.
+ * Copyright (c) 2008-2020 Stefan Krah. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without
  * modification, are permitted provided that the following conditions
@@ -33,6 +33,8 @@
 
 #include <stdlib.h>
 
+#define CPYTHON_DECIMAL_MODULE
+#include "pydecimal.h"
 #include "docstrings.h"
 
 
@@ -5555,6 +5557,160 @@ static PyTypeObject PyDecContext_Type =
 };
 
 
+/****************************************************************************/
+/*                                   C-API                                  */
+/****************************************************************************/
+
+static void *_decimal_api[CPYTHON_DECIMAL_MAX_API];
+
+/* Simple API */
+static int
+PyDec_TypeCheck(const PyObject *v)
+{
+    return PyDec_Check(v);
+}
+
+static int
+PyDec_IsSpecial(const PyObject *v)
+{
+    if (!PyDec_Check(v)) {
+        PyErr_SetString(PyExc_TypeError,
+            "PyDec_IsSpecial: argument must be a Decimal");
+        return -1;
+    }
+
+    return mpd_isspecial(MPD(v));
+}
+
+static int
+PyDec_IsNaN(const PyObject *v)
+{
+    if (!PyDec_Check(v)) {
+        PyErr_SetString(PyExc_TypeError,
+            "PyDec_IsNaN: argument must be a Decimal");
+        return -1;
+    }
+
+    return mpd_isnan(MPD(v));
+}
+
+static int
+PyDec_IsInfinite(const PyObject *v)
+{
+    if (!PyDec_Check(v)) {
+        PyErr_SetString(PyExc_TypeError,
+            "PyDec_IsInfinite: argument must be a Decimal");
+        return -1;
+    }
+
+    return mpd_isinfinite(MPD(v));
+}
+
+static int64_t
+PyDec_GetDigits(const PyObject *v)
+{
+    if (!PyDec_Check(v)) {
+        PyErr_SetString(PyExc_TypeError,
+            "PyDec_GetDigits: argument must be a Decimal");
+        return -1;
+    }
+
+    return MPD(v)->digits;
+}
+
+static mpd_uint128_triple_t
+PyDec_AsUint128Triple(const PyObject *v)
+{
+    if (!PyDec_Check(v)) {
+        mpd_uint128_triple_t triple = { MPD_TRIPLE_ERROR, 0, 0, 0, 0 };
+        PyErr_SetString(PyExc_TypeError,
+            "PyDec_AsUint128Triple: argument must be a Decimal");
+        return triple;
+    }
+
+    return mpd_as_uint128_triple(MPD(v));
+}
+
+static PyObject *
+PyDec_FromUint128Triple(const mpd_uint128_triple_t *triple)
+{
+    PyObject *context;
+    PyObject *result;
+    uint32_t status = 0;
+
+    CURRENT_CONTEXT(context);
+
+    result = dec_alloc();
+    if (result == NULL) {
+        return NULL;
+    }
+
+    if (mpd_from_uint128_triple(MPD(result), triple, &status) < 0) {
+        if (dec_addstatus(context, status)) {
+            Py_DECREF(result);
+            return NULL;
+        }
+    }
+
+    return result;
+}
+
+/* Advanced API */
+static PyObject *
+PyDec_Alloc(void)
+{
+    return dec_alloc();
+}
+
+static mpd_t *
+PyDec_Get(PyObject *v)
+{
+    if (!PyDec_Check(v)) {
+        PyErr_SetString(PyExc_TypeError,
+            "PyDec_Get: argument must be a Decimal");
+        return NULL;
+    }
+
+    return MPD(v);
+}
+
+static const mpd_t *
+PyDec_GetConst(const PyObject *v)
+{
+    if (!PyDec_Check(v)) {
+        PyErr_SetString(PyExc_TypeError,
+            "PyDec_GetConst: argument must be a Decimal");
+        return NULL;
+    }
+
+    return MPD(v);
+}
+
+static PyObject *
+init_api(void)
+{
+    /* Simple API */
+    _decimal_api[PyDec_TypeCheck_INDEX] = (void *)PyDec_TypeCheck;
+    _decimal_api[PyDec_IsSpecial_INDEX] = (void *)PyDec_IsSpecial;
+    _decimal_api[PyDec_IsNaN_INDEX] = (void *)PyDec_IsNaN;
+    _decimal_api[PyDec_IsInfinite_INDEX] = (void *)PyDec_IsInfinite;
+    _decimal_api[PyDec_GetDigits_INDEX] = (void *)PyDec_GetDigits;
+    _decimal_api[PyDec_AsUint128Triple_INDEX] = (void *)PyDec_AsUint128Triple;
+    _decimal_api[PyDec_FromUint128Triple_INDEX] = (void *)PyDec_FromUint128Triple;
+
+    /* Advanced API */
+    _decimal_api[PyDec_Alloc_INDEX] = (void *)PyDec_Alloc;
+    _decimal_api[PyDec_Get_INDEX] = (void *)PyDec_Get;
+    _decimal_api[PyDec_GetConst_INDEX] = (void *)PyDec_GetConst;
+
+    return PyCapsule_New(_decimal_api, "_decimal._API", NULL);
+}
+
+
+/****************************************************************************/
+/*                                  Module                                  */
+/****************************************************************************/
+
 static PyMethodDef _decimal_methods [] =
 {
   { "getcontext", (PyCFunction)PyDec_GetCurrentContext, METH_NOARGS, doc_getcontext},
@@ -5665,17 +5821,27 @@ PyInit__decimal(void)
     DecCondMap *cm;
     struct ssize_constmap *ssize_cm;
     struct int_constmap *int_cm;
+    static PyObject *capsule = NULL;
+    static int initialized = 0;
     int i;
 
 
     /* Init libmpdec */
-    mpd_traphandler = dec_traphandler;
-    mpd_mallocfunc = PyMem_Malloc;
-    mpd_reallocfunc = PyMem_Realloc;
-    mpd_callocfunc = mpd_callocfunc_em;
-    mpd_free = PyMem_Free;
-    mpd_setminalloc(_Py_DEC_MINALLOC);
+    if (!initialized) {
+        mpd_traphandler = dec_traphandler;
+        mpd_mallocfunc = PyMem_Malloc;
+        mpd_reallocfunc = PyMem_Realloc;
+        mpd_callocfunc = mpd_callocfunc_em;
+        mpd_free = PyMem_Free;
+        mpd_setminalloc(_Py_DEC_MINALLOC);
+
+        capsule = init_api();
+        if (capsule == NULL) {
+            return NULL;
+        }
 
+        initialized = 1;
+    }
 
     /* Init external C-API functions */
     _py_long_multiply = PyLong_Type.tp_as_number->nb_multiply;
@@ -5900,6 +6066,11 @@ PyInit__decimal(void)
     CHECK_INT(PyModule_AddStringConstant(m, "__version__", "1.70"));
     CHECK_INT(PyModule_AddStringConstant(m, "__libmpdec_version__", mpd_version()));
 
+    /* Add capsule API */
+    Py_INCREF(capsule);
+    if (PyModule_AddObject(m, "_API", capsule) < 0) {
+        Py_DECREF(capsule);
+        goto error;
+    }
 
     return m;
 
diff --git a/Modules/_decimal/tests/deccheck.py b/Modules/_decimal/tests/deccheck.py
index 5d9179e61689d..15f104dc463cb 100644
--- a/Modules/_decimal/tests/deccheck.py
+++ b/Modules/_decimal/tests/deccheck.py
@@ -49,6 +49,9 @@
 from formathelper import rand_format, rand_locale
 from _pydecimal import _dec_from_triple
 
+from _testcapi import decimal_as_triple
+from _testcapi import decimal_from_triple
+
 C = import_fresh_module('decimal', fresh=['_decimal'])
 P = import_fresh_module('decimal', blocked=['_decimal'])
 EXIT_STATUS = 0
@@ -153,6 +156,45 @@
 TernaryRestricted = ['__pow__', 'context.power']
 
 
+# ======================================================================
+#                            Triple tests
+# ======================================================================
+
+def c_as_triple(dec):
+    sign, hi, lo, exp = decimal_as_triple(dec)
+
+    coeff = hi * 2**64 + lo
+    return (sign, coeff, exp)
+
+def c_from_triple(triple):
+    sign, coeff, exp = triple
+
+    hi = coeff // 2**64
+    lo = coeff % 2**64
+    return decimal_from_triple((sign, hi, lo, exp))
+
+def p_as_triple(dec):
+    sign, digits, exp = dec.as_tuple()
+
+    s = "".join(str(d) for d in digits)
+    coeff = int(s) if s else 0
+
+    if coeff < 0 or coeff >= 2**128:
+        raise ValueError("value out of bounds for a uint128 triple")
+
+    return (sign, coeff, exp)
+
+def p_from_triple(triple):
+    sign, coeff, exp = triple
+
+    if coeff < 0 or coeff >= 2**128:
+        raise ValueError("value out of bounds for a uint128 triple")
+
+    digits = tuple(int(c) for c in str(coeff))
+
+    return P.Decimal((sign, digits, exp))
+
+
 # ======================================================================
 #                            Unified Context
 # ======================================================================
@@ -846,12 +888,44 @@ def verify(t, stat):
         t.presults.append(str(t.rp.imag))
         t.presults.append(str(t.rp.real))
 
+        ctriple = None
+        if t.funcname not in ['__radd__', '__rmul__']: # see skip handler
+            try:
+                ctriple = c_as_triple(t.rc)
+            except ValueError:
+                try:
+                    ptriple = p_as_triple(t.rp)
+                except ValueError:
+                    pass
+                else:
+                    raise RuntimeError("ValueError not raised")
+            else:
+                cres = c_from_triple(ctriple)
+                t.cresults.append(ctriple)
+                t.cresults.append(str(cres))
+
+                ptriple = p_as_triple(t.rp)
+                pres = p_from_triple(ptriple)
+                t.presults.append(ptriple)
+                t.presults.append(str(pres))
+
         if t.with_maxcontext and isinstance(t.rmax, C.Decimal):
             t.maxresults.append(t.rmax.to_eng_string())
             t.maxresults.append(t.rmax.as_tuple())
             t.maxresults.append(str(t.rmax.imag))
             t.maxresults.append(str(t.rmax.real))
 
+            if ctriple is not None:
+                # NaN payloads etc. depend on precision and clamp.
+                if all_nan(t.rc) and all_nan(t.rmax):
+                    t.maxresults.append(ctriple)
+                    t.maxresults.append(str(cres))
+                else:
+                    maxtriple = c_as_triple(t.rmax)
+                    maxres = c_from_triple(maxtriple)
+                    t.maxresults.append(maxtriple)
+                    t.maxresults.append(str(maxres))
+
         nc = t.rc.number_class().lstrip('+-s')
         stat[nc] += 1
     else:
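The new `verify()` code checks that a result survives a round trip through the uint128 triple representation, for both the C result (`t.rc`) and the pure-Python result (`t.rp`). Below is a rough standalone illustration of that round trip. It is a sketch, not the deccheck code, and it uses only the public `decimal` module. `dec_as_triple`, `dec_from_triple`, `split_coeff` and `join_coeff` are hypothetical helper names. The split into two 64-bit halves assumes that the `hi`/`lo` fields of `mpd_uint128_triple_t` hold the upper and lower 64 bits of the coefficient: the type name suggests this, but this diff does not state it. Special values use the same exponent markers as `Decimal.as_tuple()` (`'n'`, `'N'`, `'F'`), which are also the strings the `_testcapi` helpers use.

```python
from decimal import Decimal

UINT64 = 2**64
UINT128 = 2**128

def dec_as_triple(dec):
    """Hypothetical helper: Decimal -> (sign, coeff, exp), coeff as one int."""
    sign, digits, exp = dec.as_tuple()
    coeff = int("".join(map(str, digits))) if digits else 0
    if coeff >= UINT128:
        raise ValueError("value out of bounds for a uint128 triple")
    return (sign, coeff, exp)

def split_coeff(coeff):
    """Split a coefficient below 2**128 into assumed (hi, lo) 64-bit halves."""
    return divmod(coeff, UINT64)

def join_coeff(hi, lo):
    """Inverse of split_coeff()."""
    return hi * UINT64 + lo

def dec_from_triple(triple):
    """Hypothetical helper: (sign, coeff, exp) -> Decimal."""
    sign, coeff, exp = triple
    if not 0 <= coeff < UINT128:
        raise ValueError("value out of bounds for a uint128 triple")
    if coeff:
        digits = tuple(int(c) for c in str(coeff))
    elif exp in ("n", "N"):
        digits = ()        # NaN without a payload, as as_tuple() reports it
    else:
        digits = (0,)      # zero or Infinity, as as_tuple() reports it
    return Decimal((sign, digits, exp))
```

The digit tuple is rebuilt in the same shape that `as_tuple()` reports: empty for a NaN without a payload, and `(0,)` for zero and infinity. This keeps the round trip exact, including signed zeros and NaN payloads.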
diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c
index fca94a83a5d04..593034ef65e2c 100644
--- a/Modules/_testcapimodule.c
+++ b/Modules/_testcapimodule.c
@@ -19,6 +19,7 @@
 
 #include "Python.h"
 #include "datetime.h"
+#include "pydecimal.h"
 #include "marshal.h"
 #include "structmember.h"         // PyMemberDef
 #include <float.h>
@@ -2705,6 +2706,252 @@ test_PyDateTime_DELTA_GET(PyObject *self, PyObject *obj)
     return Py_BuildValue("(lll)", days, seconds, microseconds);
 }
 
+/* Test decimal API */
+static int decimal_initialized = 0;
+static PyObject *
+decimal_is_special(PyObject *module, PyObject *dec)
+{
+    int is_special;
+
+    (void)module;
+    if (!decimal_initialized) {
+        if (import_decimal() < 0) {
+            return NULL;
+        }
+
+        decimal_initialized = 1;
+    }
+
+    is_special = PyDec_IsSpecial(dec);
+    if (is_special < 0) {
+        return NULL;
+    }
+
+    return PyBool_FromLong(is_special);
+}
+
+static PyObject *
+decimal_is_nan(PyObject *module, PyObject *dec)
+{
+    int is_nan;
+
+    (void)module;
+    if (!decimal_initialized) {
+        if (import_decimal() < 0) {
+            return NULL;
+        }
+
+        decimal_initialized = 1;
+    }
+
+    is_nan = PyDec_IsNaN(dec);
+    if (is_nan < 0) {
+        return NULL;
+    }
+
+    return PyBool_FromLong(is_nan);
+}
+
+static PyObject *
+decimal_is_infinite(PyObject *module, PyObject *dec)
+{
+    int is_infinite;
+
+    (void)module;
+    if (!decimal_initialized) {
+        if (import_decimal() < 0) {
+            return NULL;
+        }
+
+        decimal_initialized = 1;
+    }
+
+    is_infinite = PyDec_IsInfinite(dec);
+    if (is_infinite < 0) {
+        return NULL;
+    }
+
+    return PyBool_FromLong(is_infinite);
+}
+
+static PyObject *
+decimal_get_digits(PyObject *module, PyObject *dec)
+{
+    int64_t digits;
+
+    (void)module;
+    if (!decimal_initialized) {
+        if (import_decimal() < 0) {
+            return NULL;
+        }
+
+        decimal_initialized = 1;
+    }
+
+    digits = PyDec_GetDigits(dec);
+    if (digits < 0) {
+        return NULL;
+    }
+
+    return PyLong_FromLongLong(digits);
+}
+
+static PyObject *
+decimal_as_triple(PyObject *module, PyObject *dec)
+{
+    PyObject *tuple = NULL;
+    PyObject *sign, *hi, *lo;
+    mpd_uint128_triple_t triple;
+
+    (void)module;
+    if (!decimal_initialized) {
+        if (import_decimal() < 0) {
+            return NULL;
+        }
+
+        decimal_initialized = 1;
+    }
+
+    triple = PyDec_AsUint128Triple(dec);
+    if (triple.tag == MPD_TRIPLE_ERROR && PyErr_Occurred()) {
+        return NULL;
+    }
+
+    sign = PyLong_FromUnsignedLong(triple.sign);
+    if (sign == NULL) {
+        return NULL;
+    }
+
+    hi = PyLong_FromUnsignedLongLong(triple.hi);
+    if (hi == NULL) {
+        Py_DECREF(sign);
+        return NULL;
+    }
+
+    lo = PyLong_FromUnsignedLongLong(triple.lo);
+    if (lo == NULL) {
+        Py_DECREF(hi);
+        Py_DECREF(sign);
+        return NULL;
+    }
+
+    switch (triple.tag) {
+    case MPD_TRIPLE_QNAN:
+        assert(triple.exp == 0);
+        tuple = Py_BuildValue("(OOOs)", sign, hi, lo, "n");
+        break;
+
+    case MPD_TRIPLE_SNAN:
+        assert(triple.exp == 0);
+        tuple = Py_BuildValue("(OOOs)", sign, hi, lo, "N");
+        break;
+
+    case MPD_TRIPLE_INF:
+        assert(triple.hi == 0);
+        assert(triple.lo == 0);
+        assert(triple.exp == 0);
+        tuple = Py_BuildValue("(OOOs)", sign, hi, lo, "F");
+        break;
+
+    case MPD_TRIPLE_NORMAL:
+        tuple = Py_BuildValue("(OOOL)", sign, hi, lo, triple.exp);
+        break;
+
+    case MPD_TRIPLE_ERROR:
+        PyErr_SetString(PyExc_ValueError,
+            "value out of bounds for a uint128 triple");
+        break;
+
+    default:
+        PyErr_SetString(PyExc_RuntimeError,
+            "decimal_as_triple: internal error: unexpected tag");
+        break;
+    }
+
+    Py_DECREF(lo);
+    Py_DECREF(hi);
+    Py_DECREF(sign);
+
+    return tuple;
+}
+
+static PyObject *
+decimal_from_triple(PyObject *module, PyObject *tuple)
+{
+    mpd_uint128_triple_t triple = { MPD_TRIPLE_ERROR, 0, 0, 0, 0 };
+    PyObject *exp;
+    unsigned long sign;
+
+    (void)module;
+    if (!decimal_initialized) {
+        if (import_decimal() < 0) {
+            return NULL;
+        }
+
+        decimal_initialized = 1;
+    }
+
+    if (!PyTuple_Check(tuple)) {
+        PyErr_SetString(PyExc_TypeError, "argument must be a tuple");
+        return NULL;
+    }
+
+    if (PyTuple_GET_SIZE(tuple) != 4) {
+        PyErr_SetString(PyExc_ValueError, "tuple size must be 4");
+        return NULL;
+    }
+
+    sign = PyLong_AsUnsignedLong(PyTuple_GET_ITEM(tuple, 0));
+    if (sign == (unsigned long)-1 && PyErr_Occurred()) {
+        return NULL;
+    }
+    if (sign > UINT8_MAX) {
+        PyErr_SetString(PyExc_ValueError, "sign must be 0 or 1");
+        return NULL;
+    }
+    triple.sign = (uint8_t)sign;
+
+    triple.hi = PyLong_AsUnsignedLongLong(PyTuple_GET_ITEM(tuple, 1));
+    if (triple.hi == (unsigned long long)-1 && PyErr_Occurred()) {
+        return NULL;
+    }
+
+    triple.lo = PyLong_AsUnsignedLongLong(PyTuple_GET_ITEM(tuple, 2));
+    if (triple.lo == (unsigned long long)-1 && PyErr_Occurred()) {
+        return NULL;
+    }
+
+    exp = PyTuple_GET_ITEM(tuple, 3);
+    if (PyLong_Check(exp)) {
+        triple.tag = MPD_TRIPLE_NORMAL;
+        triple.exp = PyLong_AsLongLong(exp);
+        if (triple.exp == -1 && PyErr_Occurred()) {
+            return NULL;
+        }
+    }
+    else if (PyUnicode_Check(exp)) {
+        if (PyUnicode_CompareWithASCIIString(exp, "F") == 0) {
+            triple.tag = MPD_TRIPLE_INF;
+        }
+        else if (PyUnicode_CompareWithASCIIString(exp, "n") == 0) {
+            triple.tag = MPD_TRIPLE_QNAN;
+        }
+        else if (PyUnicode_CompareWithASCIIString(exp, "N") == 0) {
+            triple.tag = MPD_TRIPLE_SNAN;
+        }
+        else {
+            PyErr_SetString(PyExc_ValueError, "not a valid exponent");
+            return NULL;
+        }
+    }
+    else {
+        PyErr_SetString(PyExc_TypeError, "exponent must be int or string");
+        return NULL;
+    }
+
+    return PyDec_FromUint128Triple(&triple);
+}
+
 /* test_thread_state spawns a thread of its own, and that thread releases
  * `thread_done` when it's finished.  The driver code has to know when the
  * thread finishes, because the thread uses a PyObject (the callable) that
@@ -5314,6 +5561,12 @@ static PyMethodDef TestMethods[] = {
     {"PyDateTime_DATE_GET",        test_PyDateTime_DATE_GET,      METH_O},
     {"PyDateTime_TIME_GET",        test_PyDateTime_TIME_GET,      METH_O},
     {"PyDateTime_DELTA_GET",       test_PyDateTime_DELTA_GET,     METH_O},
+    {"decimal_is_special",      decimal_is_special,              METH_O},
+    {"decimal_is_nan",          decimal_is_nan,                  METH_O},
+    {"decimal_is_infinite",     decimal_is_infinite,             METH_O},
+    {"decimal_get_digits",      decimal_get_digits,              METH_O},
+    {"decimal_as_triple",       decimal_as_triple,               METH_O},
+    {"decimal_from_triple",     decimal_from_triple,             METH_O},
     {"test_list_api",           test_list_api,                   METH_NOARGS},
     {"test_dict_iteration",     test_dict_iteration,             METH_NOARGS},
     {"dict_getitem_knownhash",  dict_getitem_knownhash,          METH_VARARGS},
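`decimal_from_triple` accepts a 4-tuple `(sign, hi, lo, exp)`, where `exp` is either an integer exponent or one of the tags `"F"`, `"n"` and `"N"`. Below is a pure-Python model of that argument handling, which may help predict which inputs the test helper rejects. `from_triple_tuple` is a hypothetical name, and the coefficient is assumed to be `hi * 2**64 + lo`. Error handling is simplified. The C helper converts the integers with `PyLong_AsUnsignedLong()` and `PyLong_AsUnsignedLongLong()`, which raise `OverflowError` for negative or oversized values. It also checks only `sign > UINT8_MAX` itself and passes signs 2..255 through to `PyDec_FromUint128Triple()`. This sketch rejects any sign other than 0 or 1 up front.

```python
from decimal import Decimal

UINT64 = 2**64

def from_triple_tuple(arg):
    """Hypothetical model of _testcapi.decimal_from_triple((sign, hi, lo, exp))."""
    if not isinstance(arg, tuple):
        raise TypeError("argument must be a tuple")
    if len(arg) != 4:
        raise ValueError("tuple size must be 4")
    sign, hi, lo, exp = arg

    # Simplified stand-in for the PyLong_AsUnsigned*() conversions.
    for value in (sign, hi, lo):
        if not isinstance(value, int):
            raise TypeError("sign, hi and lo must be integers")
        if not 0 <= value < UINT64:
            raise OverflowError("value does not fit in an unsigned 64-bit integer")
    # Simplification: the C helper only rejects sign > UINT8_MAX at this point.
    if sign not in (0, 1):
        raise ValueError("sign must be 0 or 1")

    # Integer exponent -> finite number; "F"/"n"/"N" -> Infinity/qNaN/sNaN.
    if isinstance(exp, str):
        if exp not in ("F", "n", "N"):
            raise ValueError("not a valid exponent")
    elif not isinstance(exp, int):
        raise TypeError("exponent must be int or string")

    coeff = hi * UINT64 + lo      # assumed layout of the 128-bit coefficient
    if coeff:
        digits = tuple(int(c) for c in str(coeff))
    elif exp in ("n", "N"):
        digits = ()               # NaN without a payload
    else:
        digits = (0,)             # zero; Decimal() ignores the digits of an infinity
    return Decimal((sign, digits, exp))
```

Because the tags match the exponent markers of `Decimal.as_tuple()`, the model can hand the converted tuple directly to the `Decimal` constructor instead of keeping a separate tag enumeration like `MPD_TRIPLE_INF`, `MPD_TRIPLE_QNAN` and `MPD_TRIPLE_SNAN`.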


