[Python-checkins] bpo-37295: Speed up math.comb(n, k) for 0 <= k <= n <= 67 (GH-30275)

mdickinson webhook-mailer at python.org
Tue Dec 28 07:26:45 EST 2021


https://github.com/python/cpython/commit/02b5417f1107415abaf81acab7522f9aa84269ea
commit: 02b5417f1107415abaf81acab7522f9aa84269ea
branch: main
author: Mark Dickinson <mdickinson at enthought.com>
committer: mdickinson <dickinsm at gmail.com>
date: 2021-12-28T12:26:40Z
summary:

bpo-37295: Speed up math.comb(n, k) for 0 <= k <= n <= 67 (GH-30275)

files:
A Misc/NEWS.d/next/Library/2021-12-27-15-52-28.bpo-37295.s3LPo0.rst
M Modules/mathmodule.c

diff --git a/Misc/NEWS.d/next/Library/2021-12-27-15-52-28.bpo-37295.s3LPo0.rst b/Misc/NEWS.d/next/Library/2021-12-27-15-52-28.bpo-37295.s3LPo0.rst
new file mode 100644
index 0000000000000..a624f10637002
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-12-27-15-52-28.bpo-37295.s3LPo0.rst
@@ -0,0 +1 @@
+Add a fast path to :func:`math.comb` for ``0 <= k <= n <= 67``.
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index 011ce0afd3aec..c4a23b6514f6b 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -3450,6 +3450,71 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
     return NULL;
 }
 
+/* Least significant 64 bits of the odd part of factorial(n), for n in range(68).
+   Read-only lookup table indexed by n (0 <= n <= 67), hence declared const.
+Python code to generate the values:
+
+    import math
+
+    for n in range(68):
+        fac = math.factorial(n)
+        fac_odd_part = fac // (fac & -fac)
+        reduced_fac_odd_part = fac_odd_part % (2**64)
+        print(f"{reduced_fac_odd_part:#018x}u")
+*/
+static const uint64_t reduced_factorial_odd_part[] = {
+    0x0000000000000001u, 0x0000000000000001u, 0x0000000000000001u, 0x0000000000000003u,
+    0x0000000000000003u, 0x000000000000000fu, 0x000000000000002du, 0x000000000000013bu,
+    0x000000000000013bu, 0x0000000000000b13u, 0x000000000000375fu, 0x0000000000026115u,
+    0x000000000007233fu, 0x00000000005cca33u, 0x0000000002898765u, 0x00000000260eeeebu,
+    0x00000000260eeeebu, 0x0000000286fddd9bu, 0x00000016beecca73u, 0x000001b02b930689u,
+    0x00000870d9df20adu, 0x0000b141df4dae31u, 0x00079dd498567c1bu, 0x00af2e19afc5266du,
+    0x020d8a4d0f4f7347u, 0x335281867ec241efu, 0x9b3093d46fdd5923u, 0x5e1f9767cc5866b1u,
+    0x92dd23d6966aced7u, 0xa30d0f4f0a196e5bu, 0x8dc3e5a1977d7755u, 0x2ab8ce915831734bu,
+    0x2ab8ce915831734bu, 0x81d2a0bc5e5fdcabu, 0x9efcac82445da75bu, 0xbc8b95cf58cde171u,
+    0xa0e8444a1f3cecf9u, 0x4191deb683ce3ffdu, 0xddd3878bc84ebfc7u, 0xcb39a64b83ff3751u,
+    0xf8203f7993fc1495u, 0xbd2a2a78b35f4bddu, 0x84757be6b6d13921u, 0x3fbbcfc0b524988bu,
+    0xbd11ed47c8928df9u, 0x3c26b59e41c2f4c5u, 0x677a5137e883fdb3u, 0xff74e943b03b93ddu,
+    0xfe5ebbcb10b2bb97u, 0xb021f1de3235e7e7u, 0x33509eb2e743a58fu, 0x390f9da41279fb7du,
+    0xe5cb0154f031c559u, 0x93074695ba4ddb6du, 0x81c471caa636247fu, 0xe1347289b5a1d749u,
+    0x286f21c3f76ce2ffu, 0x00be84a2173e8ac7u, 0x1595065ca215b88bu, 0xf95877595b018809u,
+    0x9c2efe3c5516f887u, 0x373294604679382bu, 0xaf1ff7a888adcd35u, 0x18ddf279a2c5800bu,
+    0x18ddf279a2c5800bu, 0x505a90e2542582cbu, 0x5bacad2cd8d5dc2bu, 0xfe3152bcbff89f41u,
+};
+
+/* Inverses modulo 2**64 of the odd part of factorial(n), for n in range(68).
+   Read-only lookup table indexed by n (0 <= n <= 67), hence declared const.
+Python code to generate the values:
+
+    import math
+
+    for n in range(68):
+        fac = math.factorial(n)
+        fac_odd_part = fac // (fac & -fac)
+        inverted_fac_odd_part = pow(fac_odd_part, -1, 2**64)
+        print(f"{inverted_fac_odd_part:#018x}u")
+*/
+static const uint64_t inverted_factorial_odd_part[] = {
+    0x0000000000000001u, 0x0000000000000001u, 0x0000000000000001u, 0xaaaaaaaaaaaaaaabu,
+    0xaaaaaaaaaaaaaaabu, 0xeeeeeeeeeeeeeeefu, 0x4fa4fa4fa4fa4fa5u, 0x2ff2ff2ff2ff2ff3u,
+    0x2ff2ff2ff2ff2ff3u, 0x938cc70553e3771bu, 0xb71c27cddd93e49fu, 0xb38e3229fcdee63du,
+    0xe684bb63544a4cbfu, 0xc2f684917ca340fbu, 0xf747c9cba417526du, 0xbb26eb51d7bd49c3u,
+    0xbb26eb51d7bd49c3u, 0xb0a7efb985294093u, 0xbe4b8c69f259eabbu, 0x6854d17ed6dc4fb9u,
+    0xe1aa904c915f4325u, 0x3b8206df131cead1u, 0x79c6009fea76fe13u, 0xd8c5d381633cd365u,
+    0x4841f12b21144677u, 0x4a91ff68200b0d0fu, 0x8f9513a58c4f9e8bu, 0x2b3e690621a42251u,
+    0x4f520f00e03c04e7u, 0x2edf84ee600211d3u, 0xadcaa2764aaacdfdu, 0x161f4f9033f4fe63u,
+    0x161f4f9033f4fe63u, 0xbada2932ea4d3e03u, 0xcec189f3efaa30d3u, 0xf7475bb68330bf91u,
+    0x37eb7bf7d5b01549u, 0x46b35660a4e91555u, 0xa567c12d81f151f7u, 0x4c724007bb2071b1u,
+    0x0f4a0cce58a016bdu, 0xfa21068e66106475u, 0x244ab72b5a318ae1u, 0x366ce67e080d0f23u,
+    0xd666fdae5dd2a449u, 0xd740ddd0acc06a0du, 0xb050bbbb28e6f97bu, 0x70b003fe890a5c75u,
+    0xd03aabff83037427u, 0x13ec4ca72c783bd7u, 0x90282c06afdbd96fu, 0x4414ddb9db4a95d5u,
+    0xa2c68735ae6832e9u, 0xbf72d71455676665u, 0xa8469fab6b759b7fu, 0xc1e55b56e606caf9u,
+    0x40455630fc4a1cffu, 0x0120a7b0046d16f7u, 0xa7c3553b08faef23u, 0x9f0bfd1b08d48639u,
+    0xa433ffce9a304d37u, 0xa22ad1d53915c683u, 0xcb6cbc723ba5dd1du, 0x547fb1b8ab9d0ba3u,
+    0x547fb1b8ab9d0ba3u, 0x8f15a826498852e3u, 0x32e1a03f38880283u, 0x3de4cce63283f0c1u,
+};
+
+
 /*[clinic input]
 math.comb
 
@@ -3512,6 +3577,30 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
             goto done;
         }
         assert(ki >= 0);
+
+        if (ni <= 67) {
+            /*
+                For 0 <= k <= n <= 67, comb(n, k) always fits into a uint64_t.
+                We compute it as
+
+                    comb_odd_part << shift
+
+                where 2**shift is the largest power of two dividing comb(n, k)
+                and comb_odd_part is comb(n, k) >> shift. comb_odd_part can be
+                calculated efficiently via arithmetic modulo 2**64, using three
+                lookups and two uint64_t multiplications, while the necessary
+                shift can be computed via Kummer's theorem: it's the number of
+                carries when adding k to n - k in binary, which in turn is the
+                number of set bits of n ^ k ^ (n - k).
+            */
+            uint64_t comb_odd_part = reduced_factorial_odd_part[ni]
+                                   * inverted_factorial_odd_part[ki]
+                                   * inverted_factorial_odd_part[ni - ki];
+            int shift = _Py_popcount32((uint32_t)(ni ^ ki ^ (ni - ki)));
+            result = PyLong_FromUnsignedLongLong(comb_odd_part << shift);
+            goto done;
+        }
+
         ki = Py_MIN(ki, ni - ki);
         if (ki > 1) {
             result = perm_comb_small((unsigned long long)ni,



More information about the Python-checkins mailing list