[Python-checkins] bpo-36887: add math.isqrt (GH-13244)

Mark Dickinson webhook-mailer at python.org
Sat May 18 07:29:56 EDT 2019


https://github.com/python/cpython/commit/73934b9da07daefb203e7d26089e7486a1ce4fdf
commit: 73934b9da07daefb203e7d26089e7486a1ce4fdf
branch: master
author: Mark Dickinson <dickinsm at gmail.com>
committer: GitHub <noreply at github.com>
date: 2019-05-18T12:29:50+01:00
summary:

bpo-36887: add math.isqrt (GH-13244)

* Add math.isqrt function computing the integer square root.

* Code cleanup: remove redundant comments, rename some variables.

* Tighten up code a bit more; use Py_XDECREF to simplify error handling.

* Update Modules/mathmodule.c

Co-Authored-By: Serhiy Storchaka <storchaka at gmail.com>

* Update Modules/mathmodule.c

Use real argument clinic type instead of an alias

Co-Authored-By: Serhiy Storchaka <storchaka at gmail.com>

* Add proof sketch

* Updates from review.

* Correct and expand documentation.

* Fix bad reference handling on error; make some variables block-local; other tidying.

* Style and consistency fixes.

* Add missing error check; don't try to DECREF a NULL a

* Simplify some error returns.

* Another two test cases:

- clarify that floats are rejected even if they happen to be
  squares of small integers
- TypeError beats ValueError for a negative float

* Documentation and markup improvements; thanks Serhiy for the suggestions!

* Cleaner Misc/NEWS entry wording.

* Clean up (with one fix) to the algorithm explanation and proof.

files:
A Misc/NEWS.d/next/Library/2019-05-11-14-50-59.bpo-36887.XD3f22.rst
M Doc/library/math.rst
M Doc/whatsnew/3.8.rst
M Lib/test/test_math.py
M Modules/clinic/mathmodule.c.h
M Modules/mathmodule.c

diff --git a/Doc/library/math.rst b/Doc/library/math.rst
index 49f932d03845..bf660ae9defa 100644
--- a/Doc/library/math.rst
+++ b/Doc/library/math.rst
@@ -166,6 +166,20 @@ Number-theoretic and representation functions
    Return ``True`` if *x* is a NaN (not a number), and ``False`` otherwise.
 
 
+.. function:: isqrt(n)
+
+   Return the integer square root of the nonnegative integer *n*. This is the
+   floor of the exact square root of *n*, or equivalently the greatest integer
+   *a* such that *a*\ ² |nbsp| ≤ |nbsp| *n*.
+
+   For some applications, it may be more convenient to have the least integer
+   *a* such that *n* |nbsp| ≤ |nbsp| *a*\ ², or in other words the ceiling of
+   the exact square root of *n*. For positive *n*, this can be computed using
+   ``a = 1 + isqrt(n - 1)``.
+
+   .. versionadded:: 3.8
+
+
 .. function:: ldexp(x, i)
 
    Return ``x * (2**i)``.  This is essentially the inverse of function
@@ -538,3 +552,6 @@ Constants
 
    Module :mod:`cmath`
       Complex number versions of many of these functions.
+
+.. |nbsp| unicode:: 0xA0
+   :trim:
diff --git a/Doc/whatsnew/3.8.rst b/Doc/whatsnew/3.8.rst
index d47993bf1129..07da4047a383 100644
--- a/Doc/whatsnew/3.8.rst
+++ b/Doc/whatsnew/3.8.rst
@@ -344,6 +344,9 @@ Added new function, :func:`math.prod`, as analogous function to :func:`sum`
 that returns the product of a 'start' value (default: 1) times an iterable of
 numbers. (Contributed by Pablo Galindo in :issue:`35606`)
 
+Added new function :func:`math.isqrt` for computing integer square roots.
+(Contributed by Mark Dickinson in :issue:`36887`.)
+
 os
 --
 
diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py
index cb05dee0e0fd..a11a34478564 100644
--- a/Lib/test/test_math.py
+++ b/Lib/test/test_math.py
@@ -912,6 +912,57 @@ class T(tuple):
             self.assertEqual(math.dist(p, q), 5*scale)
             self.assertEqual(math.dist(q, p), 5*scale)
 
+    def testIsqrt(self):
+        # Test a variety of inputs, large and small.
+        test_values = (
+            list(range(1000))
+            + list(range(10**6 - 1000, 10**6 + 1000))
+            + [3**9999, 10**5001]
+        )
+
+        for value in test_values:
+            with self.subTest(value=value):
+                s = math.isqrt(value)
+                self.assertIs(type(s), int)
+                self.assertLessEqual(s*s, value)
+                self.assertLess(value, (s+1)*(s+1))
+
+        # Negative values
+        with self.assertRaises(ValueError):
+            math.isqrt(-1)
+
+        # Integer-like things
+        s = math.isqrt(True)
+        self.assertIs(type(s), int)
+        self.assertEqual(s, 1)
+
+        s = math.isqrt(False)
+        self.assertIs(type(s), int)
+        self.assertEqual(s, 0)
+
+        class IntegerLike(object):
+            def __init__(self, value):
+                self.value = value
+
+            def __index__(self):
+                return self.value
+
+        s = math.isqrt(IntegerLike(1729))
+        self.assertIs(type(s), int)
+        self.assertEqual(s, 41)
+
+        with self.assertRaises(ValueError):
+            math.isqrt(IntegerLike(-3))
+
+        # Non-integer-like things
+        bad_values = [
+            3.5, "a string", decimal.Decimal("3.5"), 3.5j,
+            100.0, -4.0,
+        ]
+        for value in bad_values:
+            with self.subTest(value=value):
+                with self.assertRaises(TypeError):
+                    math.isqrt(value)
 
     def testLdexp(self):
         self.assertRaises(TypeError, math.ldexp)
diff --git a/Misc/NEWS.d/next/Library/2019-05-11-14-50-59.bpo-36887.XD3f22.rst b/Misc/NEWS.d/next/Library/2019-05-11-14-50-59.bpo-36887.XD3f22.rst
new file mode 100644
index 000000000000..fe2929cea85a
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2019-05-11-14-50-59.bpo-36887.XD3f22.rst
@@ -0,0 +1 @@
+Add new function :func:`math.isqrt` to compute integer square roots.
diff --git a/Modules/clinic/mathmodule.c.h b/Modules/clinic/mathmodule.c.h
index 1806a01588c5..e677bd896cd8 100644
--- a/Modules/clinic/mathmodule.c.h
+++ b/Modules/clinic/mathmodule.c.h
@@ -65,6 +65,15 @@ PyDoc_STRVAR(math_fsum__doc__,
 #define MATH_FSUM_METHODDEF    \
     {"fsum", (PyCFunction)math_fsum, METH_O, math_fsum__doc__},
 
+PyDoc_STRVAR(math_isqrt__doc__,
+"isqrt($module, n, /)\n"
+"--\n"
+"\n"
+"Return the integer part of the square root of the input.");
+
+#define MATH_ISQRT_METHODDEF    \
+    {"isqrt", (PyCFunction)math_isqrt, METH_O, math_isqrt__doc__},
+
 PyDoc_STRVAR(math_factorial__doc__,
 "factorial($module, x, /)\n"
 "--\n"
@@ -628,4 +637,4 @@ math_prod(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *k
 exit:
     return return_value;
 }
-/*[clinic end generated code: output=96e71135dce41c48 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=aeed62f403b90199 input=a9049054013a1b77]*/
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index 8f6a303cc4de..821309221f82 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -1476,6 +1476,266 @@ count_set_bits(unsigned long n)
     return count;
 }
 
+/* Integer square root
+
+Given a nonnegative integer `n`, we want to compute the largest integer
+`a` for which `a * a <= n`, or equivalently the integer part of the exact
+square root of `n`.
+
+We use an adaptive-precision pure-integer version of Newton's iteration. Given
+a positive integer `n`, the algorithm produces at each iteration an integer
+approximation `a` to the square root of `n >> s` for some even integer `s`,
+with `s` decreasing as the iterations progress. On the final iteration, `s` is
+zero and we have an approximation to the square root of `n` itself.
+
+At every step, the approximation `a` is strictly within 1.0 of the true square
+root, so we have
+
+    (a - 1)**2 < (n >> s) < (a + 1)**2
+
+After the final iteration, a check-and-correct step is needed to determine
+whether `a` or `a - 1` gives the desired integer square root of `n`.
+
+The algorithm is remarkable in its simplicity. There's no need for a
+per-iteration check-and-correct step, and termination is straightforward: the
+number of iterations is known in advance (it's exactly `floor(log2(log2(n)))`
+for `n > 1`). The only tricky part of the correctness proof is in establishing
+that the bound `(a - 1)**2 < (n >> s) < (a + 1)**2` is maintained from one
+iteration to the next. A sketch of the proof of this is given below.
+
+In addition to the proof sketch, a formal, computer-verified proof
+of correctness (using Lean) of an equivalent recursive algorithm can be found
+here:
+
+    https://github.com/mdickinson/snippets/blob/master/proofs/isqrt/src/isqrt.lean
+
+
+Here's Python code equivalent to the C implementation below:
+
+    def isqrt(n):
+        """
+        Return the integer part of the square root of the input.
+        """
+        n = operator.index(n)
+
+        if n < 0:
+            raise ValueError("isqrt() argument must be nonnegative")
+        if n == 0:
+            return 0
+
+        c = (n.bit_length() - 1) // 2
+        a = 1
+        d = 0
+        for s in reversed(range(c.bit_length())):
+            e = d
+            d = c >> s
+            a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
+            assert (a-1)**2 < n >> 2*(c - d) < (a+1)**2
+
+        return a - (a*a > n)
+
+
+Sketch of proof of correctness
+------------------------------
+
+The delicate part of the correctness proof is showing that the loop invariant
+is preserved from one iteration to the next. That is, just before the line
+
+    a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
+
+is executed in the above code, we know that
+
+    (1)  (a - 1)**2 < (n >> 2*(c - e)) < (a + 1)**2.
+
+(since `e` is always the value of `d` from the previous iteration). We must
+prove that after that line is executed, we have
+
+    (a - 1)**2 < (n >> 2*(c - d)) < (a + 1)**2
+
+To facilitate the proof, we make some changes of notation. Write `m` for
+`n >> 2*(c-d)`, and write `b` for the new value of `a`, so
+
+    b = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
+
+or equivalently:
+
+    (2)  b = (a << d - e - 1) + (m >> d - e + 1) // a
+
+Then we can rewrite (1) as:
+
+    (3)  (a - 1)**2 < (m >> 2*(d - e)) < (a + 1)**2
+
+and we must show that (b - 1)**2 < m < (b + 1)**2.
+
+From this point on, we switch to mathematical notation, so `/` means exact
+division rather than integer division and `^` is used for exponentiation. We
+use the `√` symbol for the exact square root. In (3), we can remove the
+implicit floor operation to give:
+
+    (4)  (a - 1)^2 < m / 4^(d - e) < (a + 1)^2
+
+Taking square roots throughout (4), scaling by `2^(d-e)`, and rearranging gives
+
+    (5)  0 <= | 2^(d-e)a - √m | < 2^(d-e)
+
+Squaring and dividing through by `2^(d-e+1) a` gives
+
+    (6)  0 <= 2^(d-e-1) a + m / (2^(d-e+1) a) - √m < 2^(d-e-1) / a
+
+We'll show below that `2^(d-e-1) <= a`. Given that, we can replace the
+right-hand side of (6) with `1`, and now replacing the central
+term `m / (2^(d-e+1) a)` with its floor in (6) gives
+
+    (7)  -1 < 2^(d-e-1) a + ⌊m / (2^(d-e+1) a)⌋ - √m < 1
+
+Or equivalently, from (2):
+
+    (7')  -1 < b - √m < 1
+
+and rearranging gives that `(b-1)^2 < m < (b+1)^2`, which is what we needed
+to prove.
+
+We're not quite done: we still have to prove the inequality `2^(d - e - 1) <=
+a` that was used to get line (7) above. From the definition of `c`, we have
+`4^c <= n`, which implies
+
+    (8)  4^d <= m
+
+also, since `e == d >> 1`, `d` is at most `2e + 1`, from which it follows
+that `2d - 2e - 1 <= d` and hence that
+
+    (9)  4^(2d - 2e - 1) <= m
+
+Dividing both sides by `4^(d - e)` gives
+
+    (10)  4^(d - e - 1) <= m / 4^(d - e)
+
+But we know from (4) that `m / 4^(d-e) < (a + 1)^2`, hence
+
+    (11)  4^(d - e - 1) < (a + 1)^2
+
+Now taking square roots of both sides and observing that both `2^(d-e-1)` and
+`a` are integers gives `2^(d - e - 1) <= a`, which is what we needed. This
+completes the proof sketch.
+
+*/
+
+/*[clinic input]
+math.isqrt
+
+    n: object
+    /
+
+Return the integer part of the square root of the input.
+[clinic start generated code]*/
+
+static PyObject *
+math_isqrt(PyObject *module, PyObject *n)
+/*[clinic end generated code: output=35a6f7f980beab26 input=5b6e7ae4fa6c43d6]*/
+{
+    int a_too_large, s;
+    size_t c, d;
+    PyObject *a = NULL, *b;
+
+    n = PyNumber_Index(n);
+    if (n == NULL) {
+        return NULL;
+    }
+
+    if (_PyLong_Sign(n) < 0) {
+        PyErr_SetString(
+            PyExc_ValueError,
+            "isqrt() argument must be nonnegative");
+        goto error;
+    }
+    if (_PyLong_Sign(n) == 0) {
+        Py_DECREF(n);
+        return PyLong_FromLong(0);
+    }
+
+    c = _PyLong_NumBits(n);
+    if (c == (size_t)(-1)) {
+        goto error;
+    }
+    c = (c - 1U) / 2U;
+
+    /* s = c.bit_length() */
+    s = 0;
+    while ((c >> s) > 0) {
+        ++s;
+    }
+
+    a = PyLong_FromLong(1);
+    if (a == NULL) {
+        goto error;
+    }
+    d = 0;
+    while (--s >= 0) {
+        PyObject *q, *shift;
+        size_t e = d;
+
+        d = c >> s;
+
+        /* q = (n >> 2*c - e - d + 1) // a */
+        shift = PyLong_FromSize_t(2U*c - d - e + 1U);
+        if (shift == NULL) {
+            goto error;
+        }
+        q = PyNumber_Rshift(n, shift);
+        Py_DECREF(shift);
+        if (q == NULL) {
+            goto error;
+        }
+        Py_SETREF(q, PyNumber_FloorDivide(q, a));
+        if (q == NULL) {
+            goto error;
+        }
+
+        /* a = (a << d - 1 - e) + q */
+        shift = PyLong_FromSize_t(d - 1U - e);
+        if (shift == NULL) {
+            Py_DECREF(q);
+            goto error;
+        }
+        Py_SETREF(a, PyNumber_Lshift(a, shift));
+        Py_DECREF(shift);
+        if (a == NULL) {
+            Py_DECREF(q);
+            goto error;
+        }
+        Py_SETREF(a, PyNumber_Add(a, q));
+        Py_DECREF(q);
+        if (a == NULL) {
+            goto error;
+        }
+    }
+
+    /* The correct result is either a or a - 1. Figure out which, and
+       decrement a if necessary. */
+
+    /* a_too_large = n < a * a */
+    b = PyNumber_Multiply(a, a);
+    if (b == NULL) {
+        goto error;
+    }
+    a_too_large = PyObject_RichCompareBool(n, b, Py_LT);
+    Py_DECREF(b);
+    if (a_too_large == -1) {
+        goto error;
+    }
+
+    if (a_too_large) {
+        Py_SETREF(a, PyNumber_Subtract(a, _PyLong_One));
+    }
+    Py_DECREF(n);
+    return a;
+
+  error:
+    Py_XDECREF(a);
+    Py_DECREF(n);
+    return NULL;
+}
+
 /* Divide-and-conquer factorial algorithm
  *
  * Based on the formula and pseudo-code provided at:
@@ -2737,6 +2997,7 @@ static PyMethodDef math_methods[] = {
     MATH_ISFINITE_METHODDEF
     MATH_ISINF_METHODDEF
     MATH_ISNAN_METHODDEF
+    MATH_ISQRT_METHODDEF
     MATH_LDEXP_METHODDEF
     {"lgamma",          math_lgamma,    METH_O,         math_lgamma_doc},
     MATH_LOG_METHODDEF



More information about the Python-checkins mailing list