[Python-checkins] cpython: Rewrite PyBytes_FromFormatV() using _PyBytesWriter API

victor.stinner python-checkins at python.org
Tue Oct 13 19:55:13 EDT 2015


https://hg.python.org/cpython/rev/388483b53cde
changeset:   98729:388483b53cde
user:        Victor Stinner <victor.stinner at gmail.com>
date:        Wed Oct 14 00:21:35 2015 +0200
summary:
  Rewrite PyBytes_FromFormatV() using _PyBytesWriter API

* Add many more unit tests for PyBytes_FromFormatV()
* Remove the first loop to compute the length of the output string
* Use _PyBytesWriter to handle the bytes buffer, use overallocation
* Clean up the code to make it simpler and easier to review

files:
  Lib/test/test_bytes.py |   90 ++++++-
  Objects/bytesobject.c  |  352 ++++++++++++++--------------
  2 files changed, 252 insertions(+), 190 deletions(-)


diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py
--- a/Lib/test/test_bytes.py
+++ b/Lib/test/test_bytes.py
@@ -783,25 +783,93 @@
     # Test PyBytes_FromFormat()
     def test_from_format(self):
         test.support.import_module('ctypes')
-        from ctypes import pythonapi, py_object, c_int, c_char_p
+        _testcapi = test.support.import_module('_testcapi')
+        from ctypes import pythonapi, py_object
+        from ctypes import (
+            c_int, c_uint,
+            c_long, c_ulong,
+            c_size_t, c_ssize_t,
+            c_char_p)
+
         PyBytes_FromFormat = pythonapi.PyBytes_FromFormat
         PyBytes_FromFormat.restype = py_object
 
+        # basic tests
         self.assertEqual(PyBytes_FromFormat(b'format'),
                          b'format')
+        self.assertEqual(PyBytes_FromFormat(b'Hello %s !', b'world'),
+                         b'Hello world !')
 
+        # test formatters
+        self.assertEqual(PyBytes_FromFormat(b'c=%c', c_int(0)),
+                         b'c=\0')
+        self.assertEqual(PyBytes_FromFormat(b'c=%c', c_int(ord('@'))),
+                         b'c=@')
+        self.assertEqual(PyBytes_FromFormat(b'c=%c', c_int(255)),
+                         b'c=\xff')
+        self.assertEqual(PyBytes_FromFormat(b'd=%d ld=%ld zd=%zd',
+                                            c_int(1), c_long(2),
+                                            c_size_t(3)),
+                         b'd=1 ld=2 zd=3')
+        self.assertEqual(PyBytes_FromFormat(b'd=%d ld=%ld zd=%zd',
+                                            c_int(-1), c_long(-2),
+                                            c_size_t(-3)),
+                         b'd=-1 ld=-2 zd=-3')
+        self.assertEqual(PyBytes_FromFormat(b'u=%u lu=%lu zu=%zu',
+                                            c_uint(123), c_ulong(456),
+                                            c_size_t(789)),
+                         b'u=123 lu=456 zu=789')
+        self.assertEqual(PyBytes_FromFormat(b'i=%i', c_int(123)),
+                         b'i=123')
+        self.assertEqual(PyBytes_FromFormat(b'i=%i', c_int(-123)),
+                         b'i=-123')
+        self.assertEqual(PyBytes_FromFormat(b'x=%x', c_int(0xabc)),
+                         b'x=abc')
+        self.assertEqual(PyBytes_FromFormat(b'ptr=%p',
+                                            c_char_p(0xabcdef)),
+                         b'ptr=0xabcdef')
+        self.assertEqual(PyBytes_FromFormat(b's=%s', c_char_p(b'cstr')),
+                         b's=cstr')
+
+        # test minimum and maximum integer values
+        size_max = c_size_t(-1).value
+        for formatstr, ctypes_type, value, py_formatter in (
+            (b'%d', c_int, _testcapi.INT_MIN, str),
+            (b'%d', c_int, _testcapi.INT_MAX, str),
+            (b'%ld', c_long, _testcapi.LONG_MIN, str),
+            (b'%ld', c_long, _testcapi.LONG_MAX, str),
+            (b'%lu', c_ulong, _testcapi.ULONG_MAX, str),
+            (b'%zd', c_ssize_t, _testcapi.PY_SSIZE_T_MIN, str),
+            (b'%zd', c_ssize_t, _testcapi.PY_SSIZE_T_MAX, str),
+            (b'%zu', c_size_t, size_max, str),
+            (b'%p', c_char_p, size_max, lambda value: '%#x' % value),
+        ):
+            self.assertEqual(PyBytes_FromFormat(formatstr, ctypes_type(value)),
+                             py_formatter(value).encode('ascii')),
+
+        # width and precision (width is currently ignored)
+        self.assertEqual(PyBytes_FromFormat(b'%5s', b'a'),
+                         b'a')
+        self.assertEqual(PyBytes_FromFormat(b'%.3s', b'abcdef'),
+                         b'abc')
+
+        # '%%' formatter
+        self.assertEqual(PyBytes_FromFormat(b'%%'),
+                         b'%')
+        self.assertEqual(PyBytes_FromFormat(b'[%%]'),
+                         b'[%]')
+        self.assertEqual(PyBytes_FromFormat(b'%%%c', c_int(ord('_'))),
+                         b'%_')
+        self.assertEqual(PyBytes_FromFormat(b'%%s'),
+                         b'%s')
+
+        # Invalid formats and partial formatting
         self.assertEqual(PyBytes_FromFormat(b'%'), b'%')
-        self.assertEqual(PyBytes_FromFormat(b'%%'), b'%')
-        self.assertEqual(PyBytes_FromFormat(b'%%s'), b'%s')
-        self.assertEqual(PyBytes_FromFormat(b'[%%]'), b'[%]')
-        self.assertEqual(PyBytes_FromFormat(b'%%%c', c_int(ord('_'))), b'%_')
+        self.assertEqual(PyBytes_FromFormat(b'x=%i y=%', c_int(2), c_int(3)),
+                         b'x=2 y=%')
 
-        self.assertEqual(PyBytes_FromFormat(b'c:%c', c_int(255)),
-                         b'c:\xff')
-        self.assertEqual(PyBytes_FromFormat(b's:%s', c_char_p(b'cstr')),
-                         b's:cstr')
-
-        # Issue #19969
+        # Issue #19969: %c must raise OverflowError for values
+        # not in the range [0; 255]
         self.assertRaises(OverflowError,
                           PyBytes_FromFormat, b'%c', c_int(-1))
         self.assertRaises(OverflowError,
diff --git a/Objects/bytesobject.c b/Objects/bytesobject.c
--- a/Objects/bytesobject.c
+++ b/Objects/bytesobject.c
@@ -174,190 +174,184 @@
 PyObject *
 PyBytes_FromFormatV(const char *format, va_list vargs)
 {
-    va_list count;
-    Py_ssize_t n = 0;
-    const char* f;
     char *s;
-    PyObject* string;
-
-    Py_VA_COPY(count, vargs);
-    /* step 1: figure out how large a buffer we need */
+    const char *f;
+    const char *p;
+    Py_ssize_t prec;
+    int longflag;
+    int size_tflag;
+    /* Longest 64-bit formatted numbers:
+       - "18446744073709551615\0" (21 bytes)
+       - "-9223372036854775808\0" (21 bytes)
+       Decimal takes the most space (it isn't enough for octal.)
+
+       Longest 64-bit pointer representation:
+       "0xffffffffffffffff\0" (19 bytes). */
+    char buffer[21];
+    _PyBytesWriter writer;
+
+    _PyBytesWriter_Init(&writer);
+
+    s = _PyBytesWriter_Alloc(&writer, strlen(format));
+    if (s == NULL)
+        return NULL;
+    writer.overallocate = 1;
+
+#define WRITE_BYTES(str) \
+    do { \
+        s = _PyBytesWriter_WriteBytes(&writer, s, (str), strlen(str)); \
+        if (s == NULL) \
+            goto error; \
+    } while (0)
+
     for (f = format; *f; f++) {
-        if (*f == '%') {
-            const char* p = f;
-            while (*++f && *f != '%' && !Py_ISALPHA(*f))
-                ;
-
-            /* skip the 'l' or 'z' in {%ld, %zd, %lu, %zu} since
-             * they don't affect the amount of space we reserve.
-             */
-            if ((*f == 'l' || *f == 'z') &&
-                            (f[1] == 'd' || f[1] == 'u'))
-                ++f;
-
-            switch (*f) {
-            case 'c':
-            {
-                int c = va_arg(count, int);
-                if (c < 0 || c > 255) {
-                    PyErr_SetString(PyExc_OverflowError,
-                                    "PyBytes_FromFormatV(): %c format "
-                                    "expects an integer in range [0; 255]");
-                    return NULL;
-                }
-                n++;
-                break;
+        if (*f != '%') {
+            *s++ = *f;
+            continue;
+        }
+
+        p = f++;
+
+        /* ignore the width (ex: 10 in "%10s") */
+        while (Py_ISDIGIT(*f))
+            f++;
+
+        /* parse the precision (ex: 10 in "%.10s") */
+        prec = 0;
+        if (*f == '.') {
+            f++;
+            for (; Py_ISDIGIT(*f); f++) {
+                prec = (prec * 10) + (*f - '0');
             }
-            case '%':
-                n++;
-                break;
-            case 'd': case 'u': case 'i': case 'x':
-                (void) va_arg(count, int);
-                /* 20 bytes is enough to hold a 64-bit
-                   integer.  Decimal takes the most space.
-                   This isn't enough for octal. */
-                n += 20;
-                break;
-            case 's':
-                s = va_arg(count, char*);
-                n += strlen(s);
-                break;
-            case 'p':
-                (void) va_arg(count, int);
-                /* maximum 64-bit pointer representation:
-                 * 0xffffffffffffffff
-                 * so 19 characters is enough.
-                 * XXX I count 18 -- what's the extra for?
-                 */
-                n += 19;
-                break;
-            default:
-                /* if we stumble upon an unknown
-                   formatting code, copy the rest of
-                   the format string to the output
-                   string. (we cannot just skip the
-                   code, since there's no way to know
-                   what's in the argument list) */
-                n += strlen(p);
-                goto expand;
+        }
+
+        while (*f && *f != '%' && !Py_ISALPHA(*f))
+            f++;
+
+        /* handle the long flag ('l'), but only for %ld and %lu.
+           others can be added when necessary. */
+        longflag = 0;
+        if (*f == 'l' && (f[1] == 'd' || f[1] == 'u')) {
+            longflag = 1;
+            ++f;
+        }
+
+        /* handle the size_t flag ('z'). */
+        size_tflag = 0;
+        if (*f == 'z' && (f[1] == 'd' || f[1] == 'u')) {
+            size_tflag = 1;
+            ++f;
+        }
+
+        /* subtract bytes preallocated for the format string
+           (ex: 2 for "%s") */
+        writer.min_size -= (f - p + 1);
+
+        switch (*f) {
+        case 'c':
+        {
+            int c = va_arg(vargs, int);
+            if (c < 0 || c > 255) {
+                PyErr_SetString(PyExc_OverflowError,
+                                "PyBytes_FromFormatV(): %c format "
+                                "expects an integer in range [0; 255]");
+                goto error;
             }
-        } else
-            n++;
+            writer.min_size++;
+            *s++ = (unsigned char)c;
+            break;
+        }
+
+        case 'd':
+            if (longflag)
+                sprintf(buffer, "%ld", va_arg(vargs, long));
+            else if (size_tflag)
+                sprintf(buffer, "%" PY_FORMAT_SIZE_T "d",
+                    va_arg(vargs, Py_ssize_t));
+            else
+                sprintf(buffer, "%d", va_arg(vargs, int));
+            assert(strlen(buffer) < sizeof(buffer));
+            WRITE_BYTES(buffer);
+            break;
+
+        case 'u':
+            if (longflag)
+                sprintf(buffer, "%lu",
+                    va_arg(vargs, unsigned long));
+            else if (size_tflag)
+                sprintf(buffer, "%" PY_FORMAT_SIZE_T "u",
+                    va_arg(vargs, size_t));
+            else
+                sprintf(buffer, "%u",
+                    va_arg(vargs, unsigned int));
+            assert(strlen(buffer) < sizeof(buffer));
+            WRITE_BYTES(buffer);
+            break;
+
+        case 'i':
+            sprintf(buffer, "%i", va_arg(vargs, int));
+            assert(strlen(buffer) < sizeof(buffer));
+            WRITE_BYTES(buffer);
+            break;
+
+        case 'x':
+            sprintf(buffer, "%x", va_arg(vargs, int));
+            assert(strlen(buffer) < sizeof(buffer));
+            WRITE_BYTES(buffer);
+            break;
+
+        case 's':
+        {
+            Py_ssize_t i;
+
+            p = va_arg(vargs, char*);
+            i = strlen(p);
+            if (prec > 0 && i > prec)
+                i = prec;
+            s = _PyBytesWriter_WriteBytes(&writer, s, p, i);
+            if (s == NULL)
+                goto error;
+            break;
+        }
+
+        case 'p':
+            sprintf(buffer, "%p", va_arg(vargs, void*));
+            assert(strlen(buffer) < sizeof(buffer));
+            /* %p is ill-defined:  ensure leading 0x. */
+            if (buffer[1] == 'X')
+                buffer[1] = 'x';
+            else if (buffer[1] != 'x') {
+                memmove(buffer+2, buffer, strlen(buffer)+1);
+                buffer[0] = '0';
+                buffer[1] = 'x';
+            }
+            WRITE_BYTES(buffer);
+            break;
+
+        case '%':
+            writer.min_size++;
+            *s++ = '%';
+            break;
+
+        default:
+            if (*f == 0) {
+                /* fix min_size if we reached the end of the format string */
+                writer.min_size++;
+            }
+
+            /* invalid format string: copy unformatted string and exit */
+            WRITE_BYTES(p);
+            return _PyBytesWriter_Finish(&writer, s);
+        }
     }
- expand:
-    /* step 2: fill the buffer */
-    /* Since we've analyzed how much space we need for the worst case,
-       use sprintf directly instead of the slower PyOS_snprintf. */
-    string = PyBytes_FromStringAndSize(NULL, n);
-    if (!string)
-        return NULL;
-
-    s = PyBytes_AsString(string);
-
-    for (f = format; *f; f++) {
-        if (*f == '%') {
-            const char* p = f++;
-            Py_ssize_t i;
-            int longflag = 0;
-            int size_tflag = 0;
-            /* parse the width.precision part (we're only
-               interested in the precision value, if any) */
-            n = 0;
-            while (Py_ISDIGIT(*f))
-                n = (n*10) + *f++ - '0';
-            if (*f == '.') {
-                f++;
-                n = 0;
-                while (Py_ISDIGIT(*f))
-                    n = (n*10) + *f++ - '0';
-            }
-            while (*f && *f != '%' && !Py_ISALPHA(*f))
-                f++;
-            /* handle the long flag, but only for %ld and %lu.
-               others can be added when necessary. */
-            if (*f == 'l' && (f[1] == 'd' || f[1] == 'u')) {
-                longflag = 1;
-                ++f;
-            }
-            /* handle the size_t flag. */
-            if (*f == 'z' && (f[1] == 'd' || f[1] == 'u')) {
-                size_tflag = 1;
-                ++f;
-            }
-
-            switch (*f) {
-            case 'c':
-            {
-                int c = va_arg(vargs, int);
-                /* c has been checked for overflow in the first step */
-                *s++ = (unsigned char)c;
-                break;
-            }
-            case 'd':
-                if (longflag)
-                    sprintf(s, "%ld", va_arg(vargs, long));
-                else if (size_tflag)
-                    sprintf(s, "%" PY_FORMAT_SIZE_T "d",
-                        va_arg(vargs, Py_ssize_t));
-                else
-                    sprintf(s, "%d", va_arg(vargs, int));
-                s += strlen(s);
-                break;
-            case 'u':
-                if (longflag)
-                    sprintf(s, "%lu",
-                        va_arg(vargs, unsigned long));
-                else if (size_tflag)
-                    sprintf(s, "%" PY_FORMAT_SIZE_T "u",
-                        va_arg(vargs, size_t));
-                else
-                    sprintf(s, "%u",
-                        va_arg(vargs, unsigned int));
-                s += strlen(s);
-                break;
-            case 'i':
-                sprintf(s, "%i", va_arg(vargs, int));
-                s += strlen(s);
-                break;
-            case 'x':
-                sprintf(s, "%x", va_arg(vargs, int));
-                s += strlen(s);
-                break;
-            case 's':
-                p = va_arg(vargs, char*);
-                i = strlen(p);
-                if (n > 0 && i > n)
-                    i = n;
-                Py_MEMCPY(s, p, i);
-                s += i;
-                break;
-            case 'p':
-                sprintf(s, "%p", va_arg(vargs, void*));
-                /* %p is ill-defined:  ensure leading 0x. */
-                if (s[1] == 'X')
-                    s[1] = 'x';
-                else if (s[1] != 'x') {
-                    memmove(s+2, s, strlen(s)+1);
-                    s[0] = '0';
-                    s[1] = 'x';
-                }
-                s += strlen(s);
-                break;
-            case '%':
-                *s++ = '%';
-                break;
-            default:
-                strcpy(s, p);
-                s += strlen(s);
-                goto end;
-            }
-        } else
-            *s++ = *f;
-    }
-
- end:
-    _PyBytes_Resize(&string, s - PyBytes_AS_STRING(string));
-    return string;
+
+#undef WRITE_BYTES
+
+    return _PyBytesWriter_Finish(&writer, s);
+
+ error:
+    _PyBytesWriter_Dealloc(&writer);
+    return NULL;
 }
 
 PyObject *

-- 
Repository URL: https://hg.python.org/cpython


More information about the Python-checkins mailing list