[Python-checkins] bpo-37295: Optimize math.comb() and math.perm() (GH-29090)

serhiy-storchaka webhook-mailer at python.org
Sun Dec 5 15:26:20 EST 2021


https://github.com/python/cpython/commit/60c320c38e4e95877cde0b1d8562ebd6bc02ac61
commit: 60c320c38e4e95877cde0b1d8562ebd6bc02ac61
branch: main
author: Serhiy Storchaka <storchaka at gmail.com>
committer: serhiy-storchaka <storchaka at gmail.com>
date: 2021-12-05T22:26:10+02:00
summary:

bpo-37295: Optimize math.comb() and math.perm() (GH-29090)

For very large numbers use divide-and-conquer algorithm for getting
benefit of Karatsuba multiplication of large numbers.

Do calculations completely in C unsigned long long instead of Python
integers if possible.

files:
A Misc/NEWS.d/next/Library/2021-10-18-16-08-55.bpo-37295.wBEWH2.rst
M Doc/whatsnew/3.11.rst
M Modules/mathmodule.c

diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst
index 10dc30939414c..b06d8d4215033 100644
--- a/Doc/whatsnew/3.11.rst
+++ b/Doc/whatsnew/3.11.rst
@@ -351,6 +351,11 @@ Optimizations
 * Pure ASCII strings are now normalized in constant time by :func:`unicodedata.normalize`.
   (Contributed by Dong-hee Na in :issue:`44987`.)
 
+* :mod:`math` functions :func:`~math.comb` and :func:`~math.perm` are now up
+  to 10 times or more faster for large arguments (the speed up is larger for
+  larger *k*).
+  (Contributed by Serhiy Storchaka in :issue:`37295`.)
+
 
 CPython bytecode changes
 ========================
diff --git a/Misc/NEWS.d/next/Library/2021-10-18-16-08-55.bpo-37295.wBEWH2.rst b/Misc/NEWS.d/next/Library/2021-10-18-16-08-55.bpo-37295.wBEWH2.rst
new file mode 100644
index 0000000000000..634f0c453884e
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-10-18-16-08-55.bpo-37295.wBEWH2.rst
@@ -0,0 +1 @@
+Optimize :func:`math.comb` and :func:`math.perm`.
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index 64ce4e6a13fd5..84b5b954b1051 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -3221,6 +3221,138 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
 }
 
 
+/* Number of permutations and combinations.
+ * P(n, k) = n! / (n-k)!
+ * C(n, k) = P(n, k) / k!
+ */
+
+/* Calculate P(n, k) or C(n, k) for n in the 63-bit range. */
+static PyObject *
+perm_comb_small(unsigned long long n, unsigned long long k, int iscomb)
+{
+    /* long long is at least 64 bit */
+    static const unsigned long long fast_comb_limits[] = {
+        0, ULLONG_MAX, 4294967296ULL, 3329022, 102570, 13467, 3612, 1449,  // 0-7
+        746, 453, 308, 227, 178, 147, 125, 110,  // 8-15
+        99, 90, 84, 79, 75, 72, 69, 68,  // 16-23
+        66, 65, 64, 63, 63, 62, 62, 62,  // 24-31
+    };
+    static const unsigned long long fast_perm_limits[] = {
+        0, ULLONG_MAX, 4294967296ULL, 2642246, 65537, 7133, 1627, 568,  // 0-7
+        259, 142, 88, 61, 45, 36, 30,  // 8-14
+    };
+
+    if (k == 0) {
+        return PyLong_FromLong(1);
+    }
+
+    /* For small enough n and k the result fits in the 64-bit range and can
+     * be calculated without allocating intermediate PyLong objects. */
+    if (iscomb
+        ? (k < Py_ARRAY_LENGTH(fast_comb_limits)
+           && n <= fast_comb_limits[k])
+        : (k < Py_ARRAY_LENGTH(fast_perm_limits)
+            && n <= fast_perm_limits[k]))
+    {
+        unsigned long long result = n;
+        if (iscomb) {
+            for (unsigned long long i = 1; i < k;) {
+                result *= --n;
+                result /= ++i;
+            }
+        }
+        else {
+            for (unsigned long long i = 1; i < k;) {
+                result *= --n;
+                ++i;
+            }
+        }
+        return PyLong_FromUnsignedLongLong(result);
+    }
+
+    /* For larger n recurse: P(n, k) = P(n, j) * P(n-j, k-j) */
+    /* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
+    unsigned long long j = k / 2;
+    PyObject *a, *b;
+    a = perm_comb_small(n, j, iscomb);
+    if (a == NULL) {
+        return NULL;
+    }
+    b = perm_comb_small(n - j, k - j, iscomb);
+    if (b == NULL) {
+        goto error;
+    }
+    Py_SETREF(a, PyNumber_Multiply(a, b));
+    Py_DECREF(b);
+    if (iscomb && a != NULL) {
+        b = perm_comb_small(k, j, 1);
+        if (b == NULL) {
+            goto error;
+        }
+        Py_SETREF(a, PyNumber_FloorDivide(a, b));
+        Py_DECREF(b);
+    }
+    return a;
+
+error:
+    Py_DECREF(a);
+    return NULL;
+}
+
+/* Calculate P(n, k) or C(n, k) using recursive formulas.
+ * It is more efficient than sequential multiplication thanks to
+ * Karatsuba multiplication.
+ */
+static PyObject *
+perm_comb(PyObject *n, unsigned long long k, int iscomb)
+{
+    if (k == 0) {
+        return PyLong_FromLong(1);
+    }
+    if (k == 1) {
+        Py_INCREF(n);
+        return n;
+    }
+
+    /* P(n, k) = P(n, j) * P(n-j, k-j) */
+    /* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
+    unsigned long long j = k / 2;
+    PyObject *a, *b;
+    a = perm_comb(n, j, iscomb);
+    if (a == NULL) {
+        return NULL;
+    }
+    PyObject *t = PyLong_FromUnsignedLongLong(j);
+    if (t == NULL) {
+        goto error;
+    }
+    n = PyNumber_Subtract(n, t);
+    Py_DECREF(t);
+    if (n == NULL) {
+        goto error;
+    }
+    b = perm_comb(n, k - j, iscomb);
+    Py_DECREF(n);
+    if (b == NULL) {
+        goto error;
+    }
+    Py_SETREF(a, PyNumber_Multiply(a, b));
+    Py_DECREF(b);
+    if (iscomb && a != NULL) {
+        b = perm_comb_small(k, j, 1);
+        if (b == NULL) {
+            goto error;
+        }
+        Py_SETREF(a, PyNumber_FloorDivide(a, b));
+        Py_DECREF(b);
+    }
+    return a;
+
+error:
+    Py_DECREF(a);
+    return NULL;
+}
+
 /*[clinic input]
 math.perm
 
@@ -3244,9 +3376,9 @@ static PyObject *
 math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
 /*[clinic end generated code: output=e021a25469653e23 input=5311c5a00f359b53]*/
 {
-    PyObject *result = NULL, *factor = NULL;
+    PyObject *result = NULL;
     int overflow, cmp;
-    long long i, factors;
+    long long ki, ni;
 
     if (k == Py_None) {
         return math_factorial(module, n);
@@ -3260,6 +3392,7 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
         Py_DECREF(n);
         return NULL;
     }
+    assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));
 
     if (Py_SIZE(n) < 0) {
         PyErr_SetString(PyExc_ValueError,
@@ -3281,42 +3414,26 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
         goto error;
     }
 
-    factors = PyLong_AsLongLongAndOverflow(k, &overflow);
+    ki = PyLong_AsLongLongAndOverflow(k, &overflow);
+    assert(overflow >= 0 && !PyErr_Occurred());
     if (overflow > 0) {
         PyErr_Format(PyExc_OverflowError,
                      "k must not exceed %lld",
                      LLONG_MAX);
         goto error;
     }
-    else if (factors == -1) {
-        /* k is nonnegative, so a return value of -1 can only indicate error */
-        goto error;
-    }
+    assert(ki >= 0);
 
-    if (factors == 0) {
-        result = PyLong_FromLong(1);
-        goto done;
+    ni = PyLong_AsLongLongAndOverflow(n, &overflow);
+    assert(overflow >= 0 && !PyErr_Occurred());
+    if (!overflow && ki > 1) {
+        assert(ni >= 0);
+        result = perm_comb_small((unsigned long long)ni,
+                                 (unsigned long long)ki, 0);
     }
-
-    result = n;
-    Py_INCREF(result);
-    if (factors == 1) {
-        goto done;
-    }
-
-    factor = Py_NewRef(n);
-    PyObject *one = _PyLong_GetOne();  // borrowed ref
-    for (i = 1; i < factors; ++i) {
-        Py_SETREF(factor, PyNumber_Subtract(factor, one));
-        if (factor == NULL) {
-            goto error;
-        }
-        Py_SETREF(result, PyNumber_Multiply(result, factor));
-        if (result == NULL) {
-            goto error;
-        }
+    else {
+        result = perm_comb(n, (unsigned long long)ki, 0);
     }
-    Py_DECREF(factor);
 
 done:
     Py_DECREF(n);
@@ -3324,14 +3441,11 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
     return result;
 
 error:
-    Py_XDECREF(factor);
-    Py_XDECREF(result);
     Py_DECREF(n);
     Py_DECREF(k);
     return NULL;
 }
 
-
 /*[clinic input]
 math.comb
 
@@ -3357,9 +3471,9 @@ static PyObject *
 math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
 /*[clinic end generated code: output=bd2cec8d854f3493 input=9a05315af2518709]*/
 {
-    PyObject *result = NULL, *factor = NULL, *temp;
+    PyObject *result = NULL, *temp;
     int overflow, cmp;
-    long long i, factors;
+    long long ki, ni;
 
     n = PyNumber_Index(n);
     if (n == NULL) {
@@ -3370,6 +3484,7 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
         Py_DECREF(n);
         return NULL;
     }
+    assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));
 
     if (Py_SIZE(n) < 0) {
         PyErr_SetString(PyExc_ValueError,
@@ -3382,73 +3497,59 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
         goto error;
     }
 
-    /* k = min(k, n - k) */
-    temp = PyNumber_Subtract(n, k);
-    if (temp == NULL) {
-        goto error;
-    }
-    if (Py_SIZE(temp) < 0) {
-        Py_DECREF(temp);
-        result = PyLong_FromLong(0);
-        goto done;
-    }
-    cmp = PyObject_RichCompareBool(temp, k, Py_LT);
-    if (cmp > 0) {
-        Py_SETREF(k, temp);
+    ni = PyLong_AsLongLongAndOverflow(n, &overflow);
+    assert(overflow >= 0 && !PyErr_Occurred());
+    if (!overflow) {
+        assert(ni >= 0);
+        ki = PyLong_AsLongLongAndOverflow(k, &overflow);
+        assert(overflow >= 0 && !PyErr_Occurred());
+        if (overflow || ki > ni) {
+            result = PyLong_FromLong(0);
+            goto done;
+        }
+        assert(ki >= 0);
+        ki = Py_MIN(ki, ni - ki);
+        if (ki > 1) {
+            result = perm_comb_small((unsigned long long)ni,
+                                     (unsigned long long)ki, 1);
+            goto done;
+        }
+        /* For k <= 1 fall through: perm_comb() returns 1 or the original n. */
     }
     else {
-        Py_DECREF(temp);
-        if (cmp < 0) {
+        /* k = min(k, n - k) */
+        temp = PyNumber_Subtract(n, k);
+        if (temp == NULL) {
             goto error;
         }
-    }
-
-    factors = PyLong_AsLongLongAndOverflow(k, &overflow);
-    if (overflow > 0) {
-        PyErr_Format(PyExc_OverflowError,
-                     "min(n - k, k) must not exceed %lld",
-                     LLONG_MAX);
-        goto error;
-    }
-    if (factors == -1) {
-        /* k is nonnegative, so a return value of -1 can only indicate error */
-        goto error;
-    }
-
-    if (factors == 0) {
-        result = PyLong_FromLong(1);
-        goto done;
-    }
-
-    result = n;
-    Py_INCREF(result);
-    if (factors == 1) {
-        goto done;
-    }
-
-    factor = Py_NewRef(n);
-    PyObject *one = _PyLong_GetOne();  // borrowed ref
-    for (i = 1; i < factors; ++i) {
-        Py_SETREF(factor, PyNumber_Subtract(factor, one));
-        if (factor == NULL) {
-            goto error;
+        if (Py_SIZE(temp) < 0) {
+            Py_DECREF(temp);
+            result = PyLong_FromLong(0);
+            goto done;
         }
-        Py_SETREF(result, PyNumber_Multiply(result, factor));
-        if (result == NULL) {
-            goto error;
+        cmp = PyObject_RichCompareBool(temp, k, Py_LT);
+        if (cmp > 0) {
+            Py_SETREF(k, temp);
         }
-
-        temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
-        if (temp == NULL) {
-            goto error;
+        else {
+            Py_DECREF(temp);
+            if (cmp < 0) {
+                goto error;
+            }
         }
-        Py_SETREF(result, PyNumber_FloorDivide(result, temp));
-        Py_DECREF(temp);
-        if (result == NULL) {
+
+        ki = PyLong_AsLongLongAndOverflow(k, &overflow);
+        assert(overflow >= 0 && !PyErr_Occurred());
+        if (overflow) {
+            PyErr_Format(PyExc_OverflowError,
+                         "min(n - k, k) must not exceed %lld",
+                         LLONG_MAX);
             goto error;
         }
+        assert(ki >= 0);
     }
-    Py_DECREF(factor);
+
+    result = perm_comb(n, (unsigned long long)ki, 1);
 
 done:
     Py_DECREF(n);
@@ -3456,8 +3557,6 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
     return result;
 
 error:
-    Py_XDECREF(factor);
-    Py_XDECREF(result);
     Py_DECREF(n);
     Py_DECREF(k);
     return NULL;



More information about the Python-checkins mailing list