[Python-checkins] bpo-46218: Change long_pow() to sliding window algorithm (GH-30319)

tim-one webhook-mailer at python.org
Sun Jan 2 14:18:36 EST 2022


https://github.com/python/cpython/commit/863729e9c6f599286f98ec37c8716e982c4ca9dd
commit: 863729e9c6f599286f98ec37c8716e982c4ca9dd
branch: main
author: Tim Peters <tim.peters at gmail.com>
committer: tim-one <tim.peters at gmail.com>
date: 2022-01-02T13:18:20-06:00
summary:

bpo-46218: Change long_pow() to sliding window algorithm (GH-30319)

* bpo-46218: Change long_pow() to sliding window algorithm

The primary motivation is to eliminate long_pow's reliance on the number of bits in a long "digit" being a multiple of 5. Now it no longer cares how many bits are in a digit.

But the sliding window approach also allows cutting the precomputed table of small powers in half, which reduces initialization overhead enough that the approach pays off for smaller exponents too. Depending on exponent bit patterns, a sliding window may also be able to save some bigint multiplies (sometimes when at least 5 consecutive exponent bits are 0, regardless of their starting bit position modulo 5).

Note: boosting the window width to 6 didn't work well overall. It gave marginal speed improvements for huge exponents, but the increased overhead (the small-power table needs twice as many entries) made it a loss for smaller exponents.

Co-authored-by: Oleg Iarygin <dralife at yandex.ru>

files:
M Include/cpython/longintrepr.h
M Lib/test/test_pow.py
M Objects/longobject.c

diff --git a/Include/cpython/longintrepr.h b/Include/cpython/longintrepr.h
index ff4155f9656de..68dbf9c4382dc 100644
--- a/Include/cpython/longintrepr.h
+++ b/Include/cpython/longintrepr.h
@@ -21,8 +21,6 @@ extern "C" {
    PyLong_SHIFT.  The majority of the code doesn't care about the precise
    value of PyLong_SHIFT, but there are some notable exceptions:
 
-   - long_pow() requires that PyLong_SHIFT be divisible by 5
-
    - PyLong_{As,From}ByteArray require that PyLong_SHIFT be at least 8
 
    - long_hash() requires that PyLong_SHIFT is *strictly* less than the number
@@ -63,10 +61,6 @@ typedef long stwodigits; /* signed variant of twodigits */
 #define PyLong_BASE     ((digit)1 << PyLong_SHIFT)
 #define PyLong_MASK     ((digit)(PyLong_BASE - 1))
 
-#if PyLong_SHIFT % 5 != 0
-#error "longobject.c requires that PyLong_SHIFT be divisible by 5"
-#endif
-
 /* Long integer representation.
    The absolute value of a number is equal to
         SUM(for i=0 through abs(ob_size)-1) ob_digit[i] * 2**(SHIFT*i)
diff --git a/Lib/test/test_pow.py b/Lib/test/test_pow.py
index 660ff80bbf522..5cea9ceb20f5c 100644
--- a/Lib/test/test_pow.py
+++ b/Lib/test/test_pow.py
@@ -93,6 +93,28 @@ def test_other(self):
                             pow(int(i),j,k)
                         )
 
+    def test_big_exp(self):
+        import random
+        self.assertEqual(pow(2, 50000), 1 << 50000)
+        # Randomized modular tests, checking the identities
+        #  a**(b1 + b2) == a**b1 * a**b2
+        #  a**(b1 * b2) == (a**b1)**b2
+        prime = 1000000000039 # for speed, relatively small prime modulus
+        for i in range(10):
+            a = random.randrange(1000, 1000000)
+            bpower = random.randrange(1000, 50000)
+            b = random.randrange(1 << (bpower - 1), 1 << bpower)
+            b1 = random.randrange(1, b)
+            b2 = b - b1
+            got1 = pow(a, b, prime)
+            got2 = pow(a, b1, prime) * pow(a, b2, prime) % prime
+            if got1 != got2:
+                self.fail(f"{a=:x} {b1=:x} {b2=:x} {got1=:x} {got2=:x}")
+            got3 = pow(a, b1 * b2, prime)
+            got4 = pow(pow(a, b1, prime), b2, prime)
+            if got3 != got4:
+                self.fail(f"{a=:x} {b1=:x} {b2=:x} {got3=:x} {got4=:x}")
+
     def test_bug643260(self):
         class TestRpow:
             def __rpow__(self, other):
diff --git a/Objects/longobject.c b/Objects/longobject.c
index 09ae9455c5b26..b5648fca7dc5c 100644
--- a/Objects/longobject.c
+++ b/Objects/longobject.c
@@ -74,12 +74,34 @@ maybe_small_long(PyLongObject *v)
 #define KARATSUBA_CUTOFF 70
 #define KARATSUBA_SQUARE_CUTOFF (2 * KARATSUBA_CUTOFF)
 
-/* For exponentiation, use the binary left-to-right algorithm
- * unless the exponent contains more than FIVEARY_CUTOFF digits.
- * In that case, do 5 bits at a time.  The potential drawback is that
- * a table of 2**5 intermediate results is computed.
+/* For exponentiation, use the binary left-to-right algorithm unless the
+ * exponent contains more than HUGE_EXP_CUTOFF bits.  In that case, do
+ * (no more than) EXP_WINDOW_SIZE bits at a time.  The potential drawback is
+ * that a table of 2**(EXP_WINDOW_SIZE - 1) intermediate results is
+ * precomputed.
  */
-#define FIVEARY_CUTOFF 8
+#define EXP_WINDOW_SIZE 5
+#define EXP_TABLE_LEN (1 << (EXP_WINDOW_SIZE - 1))
+/* Suppose the exponent has bit length e. All ways of doing this
+ * need e squarings. The binary method also needs a multiply for
+ * each bit set. In a k-ary method with window width w, there's a multiply
+ * for each non-zero window, so at worst (and likely!)
+ * ceiling(e/w). The k-ary sliding window method has the same
+ * worst case, but the window slides so it can sometimes skip
+ * over an all-zero window that the fixed-window method can't
+ * exploit. In addition, the windowing methods need multiplies
+ * to precompute a table of small powers.
+ *
+ * For the sliding window method with width 5, 16 precomputation
+ * multiplies are needed. Assuming about half the exponent bits
+ * are set, then, the binary method needs about e/2 extra mults
+ * and the window method about 16 + e/5.
+ *
+ * The latter is smaller for e > 53 1/3. We don't have direct
+ * access to the bit length, though, so call it 60, which is a
+ * multiple of a long digit's max bit length (15 or 30 so far).
+ */
+#define HUGE_EXP_CUTOFF 60
 
 #define SIGCHECK(PyTryBlock)                    \
     do {                                        \
@@ -4172,14 +4194,15 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
     int negativeOutput = 0;  /* if x<0 return negative output */
 
     PyLongObject *z = NULL;  /* accumulated result */
-    Py_ssize_t i, j, k;             /* counters */
+    Py_ssize_t i, j;             /* counters */
     PyLongObject *temp = NULL;
+    PyLongObject *a2 = NULL; /* may temporarily hold a**2 % c */
 
-    /* 5-ary values.  If the exponent is large enough, table is
-     * precomputed so that table[i] == a**i % c for i in range(32).
+    /* k-ary values.  If the exponent is large enough, table is
+     * precomputed so that table[i] == a**(2*i+1) % c for i in
+     * range(EXP_TABLE_LEN).
      */
-    PyLongObject *table[32] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-                               0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
+    PyLongObject *table[EXP_TABLE_LEN] = {0};
 
     /* a, b, c = v, w, x */
     CHECK_BINOP(v, w);
@@ -4332,7 +4355,7 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
         }
         /* else bi is 0, and z==1 is correct */
     }
-    else if (i <= FIVEARY_CUTOFF) {
+    else if (i <= HUGE_EXP_CUTOFF / PyLong_SHIFT ) {
         /* Left-to-right binary exponentiation (HAC Algorithm 14.79) */
         /* http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf    */
 
@@ -4366,23 +4389,59 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
         }
     }
     else {
-        /* Left-to-right 5-ary exponentiation (HAC Algorithm 14.82) */
-        Py_INCREF(z);           /* still holds 1L */
-        table[0] = z;
-        for (i = 1; i < 32; ++i)
-            MULT(table[i-1], a, table[i]);
+        /* Left-to-right k-ary sliding window exponentiation
+         * (Handbook of Applied Cryptography (HAC) Algorithm 14.85)
+         */
+        Py_INCREF(a);
+        table[0] = a;
+        MULT(a, a, a2);
+        /* table[i] == a**(2*i + 1) % c */
+        for (i = 1; i < EXP_TABLE_LEN; ++i)
+            MULT(table[i-1], a2, table[i]);
+        Py_CLEAR(a2);
+
+        /* Repeatedly extract the next (no more than) EXP_WINDOW_SIZE bits
+         * into `pending`, starting with the next 1 bit.  The current bit
+         * length of `pending` is `blen`.
+         */
+        int pending = 0, blen = 0;
+#define ABSORB_PENDING  do { \
+            int ntz = 0; /* number of trailing zeroes in `pending` */ \
+            assert(pending && blen); \
+            assert(pending >> (blen - 1)); \
+            assert(pending >> blen == 0); \
+            while ((pending & 1) == 0) { \
+                ++ntz; \
+                pending >>= 1; \
+            } \
+            assert(ntz < blen); \
+            blen -= ntz; \
+            do { \
+                MULT(z, z, z); \
+            } while (--blen); \
+            MULT(z, table[pending >> 1], z); \
+            while (ntz-- > 0) \
+                MULT(z, z, z); \
+            assert(blen == 0); \
+            pending = 0; \
+        } while(0)
 
         for (i = Py_SIZE(b) - 1; i >= 0; --i) {
             const digit bi = b->ob_digit[i];
-
-            for (j = PyLong_SHIFT - 5; j >= 0; j -= 5) {
-                const int index = (bi >> j) & 0x1f;
-                for (k = 0; k < 5; ++k)
+            for (j = PyLong_SHIFT - 1; j >= 0; --j) {
+                const int bit = (bi >> j) & 1;
+                pending = (pending << 1) | bit;
+                if (pending) {
+                    ++blen;
+                    if (blen == EXP_WINDOW_SIZE)
+                        ABSORB_PENDING;
+                }
+                else /* absorb strings of 0 bits */
                     MULT(z, z, z);
-                if (index)
-                    MULT(z, table[index], z);
             }
         }
+        if (pending)
+            ABSORB_PENDING;
     }
 
     if (negativeOutput && (Py_SIZE(z) != 0)) {
@@ -4399,13 +4458,14 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
     Py_CLEAR(z);
     /* fall through */
   Done:
-    if (Py_SIZE(b) > FIVEARY_CUTOFF) {
-        for (i = 0; i < 32; ++i)
+    if (Py_SIZE(b) > HUGE_EXP_CUTOFF / PyLong_SHIFT) {
+        for (i = 0; i < EXP_TABLE_LEN; ++i)
             Py_XDECREF(table[i]);
     }
     Py_DECREF(a);
     Py_DECREF(b);
     Py_XDECREF(c);
+    Py_XDECREF(a2);
     Py_XDECREF(temp);
     return (PyObject *)z;
 }



More information about the Python-checkins mailing list