[Python-checkins] gh-90716: Refactor PyLong_FromString to separate concerns (GH-96808)

mdickinson webhook-mailer at python.org
Sun Sep 25 05:09:58 EDT 2022


https://github.com/python/cpython/commit/817fa28f81eed43539fad2c8e696df954afa6ad7
commit: 817fa28f81eed43539fad2c8e696df954afa6ad7
branch: main
author: Oscar Benjamin <oscar.j.benjamin at gmail.com>
committer: mdickinson <dickinsm at gmail.com>
date: 2022-09-25T10:09:50+01:00
summary:

gh-90716: Refactor PyLong_FromString to separate concerns (GH-96808)

This is a preliminary PR to refactor `PyLong_FromString`, which is currently quite messy: it has spaghetti-like code that mixes different concerns and duplicates logic.

In particular:

- `PyLong_FromString` now only handles sign, base and prefix detection and calls a new function `long_from_string_base` to parse the main body of the string.
- The `long_from_string_base` function handles all string validation and then calls `long_from_binary_base` or a new function `long_from_non_binary_base` to construct the actual `PyLong`.
- The existing `long_from_binary_base` function is simplified by factoring duplicated logic out into `long_from_string_base`.
- The new function `long_from_non_binary_base` factors out much of the code from `PyLong_FromString`, in particular the quadratic algorithm referred to in gh-95778, so that it can be considered separately from unrelated concerns such as string validation.

files:
A Misc/NEWS.d/next/Core and Builtins/2022-09-13-21-45-07.gh-issue-95778.Oll4_5.rst
M Objects/longobject.c

diff --git a/Misc/NEWS.d/next/Core and Builtins/2022-09-13-21-45-07.gh-issue-95778.Oll4_5.rst b/Misc/NEWS.d/next/Core and Builtins/2022-09-13-21-45-07.gh-issue-95778.Oll4_5.rst
new file mode 100644
index 000000000000..f202afc1f259
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2022-09-13-21-45-07.gh-issue-95778.Oll4_5.rst	
@@ -0,0 +1,2 @@
+The ``PyLong_FromString`` function was refactored to make it more maintainable
+and extensible.
diff --git a/Objects/longobject.c b/Objects/longobject.c
index c0bade182218..77a8782d8a67 100644
--- a/Objects/longobject.c
+++ b/Objects/longobject.c
@@ -2193,23 +2193,23 @@ unsigned char _PyLong_DigitValue[256] = {
     37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
 };
 
-/* *str points to the first digit in a string of base `base` digits.  base
- * is a power of 2 (2, 4, 8, 16, or 32).  *str is set to point to the first
- * non-digit (which may be *str!).  A normalized int is returned.
- * The point to this routine is that it takes time linear in the number of
- * string characters.
+/* `start` and `end` point to the start and end of a string of base `base`
+ * digits.  base is a power of 2 (2, 4, 8, 16, or 32). An unnormalized int is
+ * returned in *res. The string should be already validated by the caller and
+ * consists only of valid digit characters and underscores. `digits` gives the
+ * number of digit characters.
+ *
+ * The point to this routine is that it takes time linear in the
+ * number of string characters.
  *
  * Return values:
  *   -1 on syntax error (exception needs to be set, *res is untouched)
  *   0 else (exception may be set, in that case *res is set to NULL)
  */
 static int
-long_from_binary_base(const char **str, int base, PyLongObject **res)
+long_from_binary_base(const char *start, const char *end, Py_ssize_t digits, int base, PyLongObject **res)
 {
-    const char *p = *str;
-    const char *start = p;
-    char prev = 0;
-    Py_ssize_t digits = 0;
+    const char *p;
     int bits_per_char;
     Py_ssize_t n;
     PyLongObject *z;
@@ -2222,26 +2222,7 @@ long_from_binary_base(const char **str, int base, PyLongObject **res)
     for (bits_per_char = -1; n; ++bits_per_char) {
         n >>= 1;
     }
-    /* count digits and set p to end-of-string */
-    while (_PyLong_DigitValue[Py_CHARMASK(*p)] < base || *p == '_') {
-        if (*p == '_') {
-            if (prev == '_') {
-                *str = p - 1;
-                return -1;
-            }
-        } else {
-            ++digits;
-        }
-        prev = *p;
-        ++p;
-    }
-    if (prev == '_') {
-        /* Trailing underscore not allowed. */
-        *str = p - 1;
-        return -1;
-    }
 
-    *str = p;
     /* n <- the number of Python digits needed,
             = ceiling((digits * bits_per_char) / PyLong_SHIFT). */
     if (digits > (PY_SSIZE_T_MAX - (PyLong_SHIFT - 1)) / bits_per_char) {
@@ -2262,6 +2243,7 @@ long_from_binary_base(const char **str, int base, PyLongObject **res)
     accum = 0;
     bits_in_accum = 0;
     pdigit = z->ob_digit;
+    p = end;
     while (--p >= start) {
         int k;
         if (*p == '_') {
@@ -2286,88 +2268,14 @@ long_from_binary_base(const char **str, int base, PyLongObject **res)
     }
     while (pdigit - z->ob_digit < n)
         *pdigit++ = 0;
-    *res = long_normalize(z);
+    *res = z;
     return 0;
 }
 
-/* Parses an int from a bytestring. Leading and trailing whitespace will be
- * ignored.
- *
- * If successful, a PyLong object will be returned and 'pend' will be pointing
- * to the first unused byte unless it's NULL.
- *
- * If unsuccessful, NULL will be returned.
- */
-PyObject *
-PyLong_FromString(const char *str, char **pend, int base)
-{
-    int sign = 1, error_if_nonzero = 0;
-    const char *start, *orig_str = str;
-    PyLongObject *z = NULL;
-    PyObject *strobj;
-    Py_ssize_t slen;
-
-    if ((base != 0 && base < 2) || base > 36) {
-        PyErr_SetString(PyExc_ValueError,
-                        "int() arg 2 must be >= 2 and <= 36");
-        return NULL;
-    }
-    while (*str != '\0' && Py_ISSPACE(*str)) {
-        str++;
-    }
-    if (*str == '+') {
-        ++str;
-    }
-    else if (*str == '-') {
-        ++str;
-        sign = -1;
-    }
-    if (base == 0) {
-        if (str[0] != '0') {
-            base = 10;
-        }
-        else if (str[1] == 'x' || str[1] == 'X') {
-            base = 16;
-        }
-        else if (str[1] == 'o' || str[1] == 'O') {
-            base = 8;
-        }
-        else if (str[1] == 'b' || str[1] == 'B') {
-            base = 2;
-        }
-        else {
-            /* "old" (C-style) octal literal, now invalid.
-               it might still be zero though */
-            error_if_nonzero = 1;
-            base = 10;
-        }
-    }
-    if (str[0] == '0' &&
-        ((base == 16 && (str[1] == 'x' || str[1] == 'X')) ||
-         (base == 8  && (str[1] == 'o' || str[1] == 'O')) ||
-         (base == 2  && (str[1] == 'b' || str[1] == 'B')))) {
-        str += 2;
-        /* One underscore allowed here. */
-        if (*str == '_') {
-            ++str;
-        }
-    }
-    if (str[0] == '_') {
-        /* May not start with underscores. */
-        goto onError;
-    }
-
-    start = str;
-    if ((base & (base - 1)) == 0) {
-        /* binary bases are not limited by int_max_str_digits */
-        int res = long_from_binary_base(&str, base, &z);
-        if (res < 0) {
-            /* Syntax error. */
-            goto onError;
-        }
-    }
-    else {
 /***
+long_from_non_binary_base: parameters and return values are the same as
+long_from_binary_base.
+
 Binary bases can be converted in time linear in the number of digits, because
 Python's representation base is binary.  Other bases (including decimal!) use
 the simple quadratic-time algorithm below, complicated by some speed tricks.
@@ -2452,171 +2360,311 @@ that triggers it(!).  Instead the code was tested by artificially allocating
 just 1 digit at the start, so that the copying code was exercised for every
 digit beyond the first.
 ***/
-        twodigits c;           /* current input character */
-        Py_ssize_t size_z;
-        Py_ssize_t digits = 0;
-        int i;
-        int convwidth;
-        twodigits convmultmax, convmult;
-        digit *pz, *pzstop;
-        const char *scan, *lastdigit;
-        char prev = 0;
-
-        static double log_base_BASE[37] = {0.0e0,};
-        static int convwidth_base[37] = {0,};
-        static twodigits convmultmax_base[37] = {0,};
-
-        if (log_base_BASE[base] == 0.0) {
-            twodigits convmax = base;
-            int i = 1;
-
-            log_base_BASE[base] = (log((double)base) /
-                                   log((double)PyLong_BASE));
-            for (;;) {
-                twodigits next = convmax * base;
-                if (next > PyLong_BASE) {
-                    break;
-                }
-                convmax = next;
-                ++i;
+static int
+long_from_non_binary_base(const char *start, const char *end, Py_ssize_t digits, int base, PyLongObject **res)
+{
+    twodigits c;           /* current input character */
+    Py_ssize_t size_z;
+    int i;
+    int convwidth;
+    twodigits convmultmax, convmult;
+    digit *pz, *pzstop;
+    PyLongObject *z;
+    const char *p;
+
+    static double log_base_BASE[37] = {0.0e0,};
+    static int convwidth_base[37] = {0,};
+    static twodigits convmultmax_base[37] = {0,};
+
+    if (log_base_BASE[base] == 0.0) {
+        twodigits convmax = base;
+        int i = 1;
+
+        log_base_BASE[base] = (log((double)base) /
+                               log((double)PyLong_BASE));
+        for (;;) {
+            twodigits next = convmax * base;
+            if (next > PyLong_BASE) {
+                break;
+            }
+            convmax = next;
+            ++i;
+        }
+        convmultmax_base[base] = convmax;
+        assert(i > 0);
+        convwidth_base[base] = i;
+    }
+
+    /* Create an int object that can contain the largest possible
+     * integer with this base and length.  Note that there's no
+     * need to initialize z->ob_digit -- no slot is read up before
+     * being stored into.
+     */
+    double fsize_z = (double)digits * log_base_BASE[base] + 1.0;
+    if (fsize_z > (double)MAX_LONG_DIGITS) {
+        /* The same exception as in _PyLong_New(). */
+        PyErr_SetString(PyExc_OverflowError,
+                        "too many digits in integer");
+        *res = NULL;
+        return 0;
+    }
+    size_z = (Py_ssize_t)fsize_z;
+    /* Uncomment next line to test exceedingly rare copy code */
+    /* size_z = 1; */
+    assert(size_z > 0);
+    z = _PyLong_New(size_z);
+    if (z == NULL) {
+        *res = NULL;
+        return 0;
+    }
+    Py_SET_SIZE(z, 0);
+
+    /* `convwidth` consecutive input digits are treated as a single
+     * digit in base `convmultmax`.
+     */
+    convwidth = convwidth_base[base];
+    convmultmax = convmultmax_base[base];
+
+    /* Work ;-) */
+    p = start;
+    while (p < end) {
+        if (*p == '_') {
+            p++;
+            continue;
+        }
+        /* grab up to convwidth digits from the input string */
+        c = (digit)_PyLong_DigitValue[Py_CHARMASK(*p++)];
+        for (i = 1; i < convwidth && p != end; ++p) {
+            if (*p == '_') {
+                continue;
             }
-            convmultmax_base[base] = convmax;
-            assert(i > 0);
-            convwidth_base[base] = i;
+            i++;
+            c = (twodigits)(c *  base +
+                            (int)_PyLong_DigitValue[Py_CHARMASK(*p)]);
+            assert(c < PyLong_BASE);
         }
 
-        /* Find length of the string of numeric characters. */
-        scan = str;
-        lastdigit = str;
+        convmult = convmultmax;
+        /* Calculate the shift only if we couldn't get
+         * convwidth digits.
+         */
+        if (i != convwidth) {
+            convmult = base;
+            for ( ; i > 1; --i) {
+                convmult *= base;
+            }
+        }
 
-        while (_PyLong_DigitValue[Py_CHARMASK(*scan)] < base || *scan == '_') {
-            if (*scan == '_') {
-                if (prev == '_') {
-                    /* Only one underscore allowed. */
-                    str = lastdigit + 1;
-                    goto onError;
-                }
+        /* Multiply z by convmult, and add c. */
+        pz = z->ob_digit;
+        pzstop = pz + Py_SIZE(z);
+        for (; pz < pzstop; ++pz) {
+            c += (twodigits)*pz * convmult;
+            *pz = (digit)(c & PyLong_MASK);
+            c >>= PyLong_SHIFT;
+        }
+        /* carry off the current end? */
+        if (c) {
+            assert(c < PyLong_BASE);
+            if (Py_SIZE(z) < size_z) {
+                *pz = (digit)c;
+                Py_SET_SIZE(z, Py_SIZE(z) + 1);
             }
             else {
-                ++digits;
-                lastdigit = scan;
+                PyLongObject *tmp;
+                /* Extremely rare.  Get more space. */
+                assert(Py_SIZE(z) == size_z);
+                tmp = _PyLong_New(size_z + 1);
+                if (tmp == NULL) {
+                    Py_DECREF(z);
+                    *res = NULL;
+                    return 0;
+                }
+                memcpy(tmp->ob_digit,
+                       z->ob_digit,
+                       sizeof(digit) * size_z);
+                Py_DECREF(z);
+                z = tmp;
+                z->ob_digit[size_z] = (digit)c;
+                ++size_z;
             }
-            prev = *scan;
-            ++scan;
         }
-        if (prev == '_') {
-            /* Trailing underscore not allowed. */
-            /* Set error pointer to first underscore. */
-            str = lastdigit + 1;
-            goto onError;
+    }
+    *res = z;
+    return 0;
+}
+
+/* *str points to the first digit in a string of base `base` digits. base is an
+ * integer from 2 to 36 inclusive. Here we don't need to worry about prefixes
+ * like 0x or leading +- signs. The string should be null terminated consisting
+ * of ASCII digits and separating underscores possibly with trailing whitespace
+ * but we have to validate all of those points here.
+ *
+ * If base is a power of 2 then the complexity is linear in the number of
+ * characters in the string. Otherwise a quadratic algorithm is used for
+ * non-binary bases.
+ *
+ * Return values:
+ *
+ *   - Returns -1 on syntax error (exception needs to be set, *res is untouched)
+ *   - Returns 0 and sets *res to NULL for MemoryError/OverflowError.
+ *   - Returns 0 and sets *res to an unsigned, unnormalized PyLong (success!).
+ *
+ * Afterwards *str is set to point to the first non-digit (which may be *str!).
+ */
+static int
+long_from_string_base(const char **str, int base, PyLongObject **res)
+{
+    const char *start, *end, *p;
+    char prev = 0;
+    Py_ssize_t digits = 0;
+    int is_binary_base = (base & (base - 1)) == 0;
+
+    /* Here we do four things:
+     *
+     * - Find the `end` of the string.
+     * - Validate the string.
+     * - Count the number of `digits` (rather than underscores)
+     * - Point *str to the end-of-string or first invalid character.
+     */
+    start = p = *str;
+    /* Leading underscore not allowed. */
+    if (*start == '_') {
+        return -1;
+    }
+    /* Verify all characters are digits and underscores. */
+    while (_PyLong_DigitValue[Py_CHARMASK(*p)] < base || *p == '_') {
+        if (*p == '_') {
+            /* Double underscore not allowed. */
+            if (prev == '_') {
+                *str = p - 1;
+                return -1;
+            }
+        } else {
+            ++digits;
         }
+        prev = *p;
+        ++p;
+    }
+    /* Trailing underscore not allowed. */
+    if (prev == '_') {
+        *str = p - 1;
+        return -1;
+    }
+    *str = end = p;
+    /* Reject empty strings */
+    if (start == end) {
+        return -1;
+    }
+    /* Allow only trailing whitespace after `end` */
+    while (*p && Py_ISSPACE(*p)) {
+        p++;
+    }
+    *str = p;
+    if (*p != '\0') {
+        return -1;
+    }
 
-        /* Limit the size to avoid excessive computation attacks. */
+    /*
+     * Pass a validated string consisting of only valid digits and underscores
+     * to long_from_xxx_base.
+     */
+    if (is_binary_base) {
+        /* Use the linear algorithm for binary bases. */
+        return long_from_binary_base(start, end, digits, base, res);
+    }
+    else {
+        /* Limit the size to avoid excessive computation attacks exploiting the
+         * quadratic algorithm. */
         if (digits > _PY_LONG_MAX_STR_DIGITS_THRESHOLD) {
             PyInterpreterState *interp = _PyInterpreterState_GET();
             int max_str_digits = interp->int_max_str_digits;
             if ((max_str_digits > 0) && (digits > max_str_digits)) {
                 PyErr_Format(PyExc_ValueError, _MAX_STR_DIGITS_ERROR_FMT_TO_INT,
                              max_str_digits, digits);
-                return NULL;
+                *res = NULL;
+                return 0;
             }
         }
+        /* Use the quadratic algorithm for non binary bases. */
+        return long_from_non_binary_base(start, end, digits, base, res);
+    }
+}
 
-        /* Create an int object that can contain the largest possible
-         * integer with this base and length.  Note that there's no
-         * need to initialize z->ob_digit -- no slot is read up before
-         * being stored into.
-         */
-        double fsize_z = (double)digits * log_base_BASE[base] + 1.0;
-        if (fsize_z > (double)MAX_LONG_DIGITS) {
-            /* The same exception as in _PyLong_New(). */
-            PyErr_SetString(PyExc_OverflowError,
-                            "too many digits in integer");
-            return NULL;
+/* Parses an int from a bytestring. Leading and trailing whitespace will be
+ * ignored.
+ *
+ * If successful, a PyLong object will be returned and 'pend' will be pointing
+ * to the first unused byte unless it's NULL.
+ *
+ * If unsuccessful, NULL will be returned.
+ */
+PyObject *
+PyLong_FromString(const char *str, char **pend, int base)
+{
+    int sign = 1, error_if_nonzero = 0;
+    const char *orig_str = str;
+    PyLongObject *z = NULL;
+    PyObject *strobj;
+    Py_ssize_t slen;
+
+    if ((base != 0 && base < 2) || base > 36) {
+        PyErr_SetString(PyExc_ValueError,
+                        "int() arg 2 must be >= 2 and <= 36");
+        return NULL;
+    }
+    while (*str != '\0' && Py_ISSPACE(*str)) {
+        ++str;
+    }
+    if (*str == '+') {
+        ++str;
+    }
+    else if (*str == '-') {
+        ++str;
+        sign = -1;
+    }
+    if (base == 0) {
+        if (str[0] != '0') {
+            base = 10;
         }
-        size_z = (Py_ssize_t)fsize_z;
-        /* Uncomment next line to test exceedingly rare copy code */
-        /* size_z = 1; */
-        assert(size_z > 0);
-        z = _PyLong_New(size_z);
-        if (z == NULL) {
-            return NULL;
+        else if (str[1] == 'x' || str[1] == 'X') {
+            base = 16;
         }
-        Py_SET_SIZE(z, 0);
-
-        /* `convwidth` consecutive input digits are treated as a single
-         * digit in base `convmultmax`.
-         */
-        convwidth = convwidth_base[base];
-        convmultmax = convmultmax_base[base];
-
-        /* Work ;-) */
-        while (str < scan) {
-            if (*str == '_') {
-                str++;
-                continue;
-            }
-            /* grab up to convwidth digits from the input string */
-            c = (digit)_PyLong_DigitValue[Py_CHARMASK(*str++)];
-            for (i = 1; i < convwidth && str != scan; ++str) {
-                if (*str == '_') {
-                    continue;
-                }
-                i++;
-                c = (twodigits)(c *  base +
-                                (int)_PyLong_DigitValue[Py_CHARMASK(*str)]);
-                assert(c < PyLong_BASE);
-            }
-
-            convmult = convmultmax;
-            /* Calculate the shift only if we couldn't get
-             * convwidth digits.
-             */
-            if (i != convwidth) {
-                convmult = base;
-                for ( ; i > 1; --i) {
-                    convmult *= base;
-                }
-            }
-
-            /* Multiply z by convmult, and add c. */
-            pz = z->ob_digit;
-            pzstop = pz + Py_SIZE(z);
-            for (; pz < pzstop; ++pz) {
-                c += (twodigits)*pz * convmult;
-                *pz = (digit)(c & PyLong_MASK);
-                c >>= PyLong_SHIFT;
-            }
-            /* carry off the current end? */
-            if (c) {
-                assert(c < PyLong_BASE);
-                if (Py_SIZE(z) < size_z) {
-                    *pz = (digit)c;
-                    Py_SET_SIZE(z, Py_SIZE(z) + 1);
-                }
-                else {
-                    PyLongObject *tmp;
-                    /* Extremely rare.  Get more space. */
-                    assert(Py_SIZE(z) == size_z);
-                    tmp = _PyLong_New(size_z + 1);
-                    if (tmp == NULL) {
-                        Py_DECREF(z);
-                        return NULL;
-                    }
-                    memcpy(tmp->ob_digit,
-                           z->ob_digit,
-                           sizeof(digit) * size_z);
-                    Py_DECREF(z);
-                    z = tmp;
-                    z->ob_digit[size_z] = (digit)c;
-                    ++size_z;
-                }
-            }
+        else if (str[1] == 'o' || str[1] == 'O') {
+            base = 8;
+        }
+        else if (str[1] == 'b' || str[1] == 'B') {
+            base = 2;
+        }
+        else {
+            /* "old" (C-style) octal literal, now invalid.
+               it might still be zero though */
+            error_if_nonzero = 1;
+            base = 10;
         }
     }
+    if (str[0] == '0' &&
+        ((base == 16 && (str[1] == 'x' || str[1] == 'X')) ||
+         (base == 8  && (str[1] == 'o' || str[1] == 'O')) ||
+         (base == 2  && (str[1] == 'b' || str[1] == 'B')))) {
+        str += 2;
+        /* One underscore allowed here. */
+        if (*str == '_') {
+            ++str;
+        }
+    }
+
+    /* long_from_string_base is the main workhorse here. */
+    int ret = long_from_string_base(&str, base, &z);
+    if (ret == -1) {
+        /* Syntax error. */
+        goto onError;
+    }
     if (z == NULL) {
+        /* Error. exception already set. */
         return NULL;
     }
+
     if (error_if_nonzero) {
         /* reset the base to 0, else the exception message
            doesn't make too much sense */
@@ -2627,23 +2675,14 @@ digit beyond the first.
         /* there might still be other problems, therefore base
            remains zero here for the same reason */
     }
-    if (str == start) {
-        goto onError;
-    }
+
+    /* Set sign and normalize */
     if (sign < 0) {
         Py_SET_SIZE(z, -(Py_SIZE(z)));
     }
-    while (*str && Py_ISSPACE(*str)) {
-        str++;
-    }
-    if (*str != '\0') {
-        goto onError;
-    }
     long_normalize(z);
     z = maybe_small_long(z);
-    if (z == NULL) {
-        return NULL;
-    }
+
     if (pend != NULL) {
         *pend = (char *)str;
     }



More information about the Python-checkins mailing list