[pypy-commit] pypy fix-bytearray-complexity: All bytearray ops work except for .replace. Some operands are still copied

waedt noreply at buildbot.pypy.org
Mon Jun 2 19:47:04 CEST 2014


Author: Tyler Wade <wayedt at gmail.com>
Branch: fix-bytearray-complexity
Changeset: r71873:fd4cad71fdab
Date: 2014-05-26 09:04 -0500
http://bitbucket.org/pypy/pypy/changeset/fd4cad71fdab/

Log:	All bytearray ops work except for .replace. Some operands are still
	copied

diff --git a/pypy/objspace/std/bytearrayobject.py b/pypy/objspace/std/bytearrayobject.py
--- a/pypy/objspace/std/bytearrayobject.py
+++ b/pypy/objspace/std/bytearrayobject.py
@@ -62,6 +62,10 @@
             raise oefmt(space.w_IndexError, "bytearray index out of range")
         return space.wrap(ord(character))
 
+    def _fillchar(self, space, w_fillchar):
+        c = self._op_val(space, w_fillchar)
+        return [c], len(c)
+
     def _val(self, space):
         return self.data
 
@@ -77,6 +81,9 @@
         assert len(char) == 1
         return str(char)[0]
 
+    def _multi_chr(self, char):
+        return [self._chr(char)]
+
     @staticmethod
     def _builder(size=100):
         return BytearrayBuilder(size)
@@ -495,7 +502,6 @@
             data[len(self.data) + i] = buffer.getitem(i)
         return self._new(data)
 
-
     def descr_reverse(self, space):
         self.data.reverse()
 
@@ -507,6 +513,13 @@
         for i in range(len(s)):
             self.data.append(s[i])
 
+    def append_multiple_char(self, c, count):
+        self.data.extend([c] * count)
+
+    def append_slice(self, value, start, end):
+        for i in range(start, end):
+            self.data.append(value[i])
+
     def build(self):
         return self.data
 
diff --git a/pypy/objspace/std/stringmethods.py b/pypy/objspace/std/stringmethods.py
--- a/pypy/objspace/std/stringmethods.py
+++ b/pypy/objspace/std/stringmethods.py
@@ -3,7 +3,9 @@
 from rpython.rlib import jit
 from rpython.rlib.objectmodel import specialize, newlist_hint
 from rpython.rlib.rarithmetic import ovfcheck
-from rpython.rlib.rstring import endswith, replace, rsplit, split, startswith
+from rpython.rlib.rstring import (
+    search, SEARCH_FIND, SEARCH_RFIND, SEARCH_COUNT, endswith, replace, rsplit,
+    split, startswith)
 
 from pypy.interpreter.error import OperationError, oefmt
 from pypy.interpreter.gateway import WrappedDefault, unwrap_spec
@@ -28,6 +30,9 @@
             space, lenself, w_start, w_end, upper_bound=upper_bound)
         return (value, start, end)
 
+    def _multi_chr(self, c):
+        return self._chr(c)
+
     def descr_len(self, space):
         return space.wrap(self._len())
 
@@ -41,7 +46,7 @@
             return space.newbool(value.find(other) >= 0)
 
         buffer = _get_buffer(space, w_sub)
-        res = _search_slowpath(value, buffer, 0, len(value), FAST_FIND)
+        res = search(value, buffer, 0, len(value), SEARCH_FIND)
         return space.newbool(res >= 0)
 
     def descr_add(self, space, w_other):
@@ -129,7 +134,7 @@
         d = width - len(value)
         if d > 0:
             offset = d//2 + (d & width & 1)
-            fillchar = fillchar[0]    # annotator hint: it's a single character
+            fillchar = self._multi_chr(fillchar)
             centered = offset * fillchar + value + (d - offset) * fillchar
         else:
             centered = value
@@ -144,7 +149,7 @@
                                             end))
 
         buffer = _get_buffer(space, w_sub)
-        res = _search_slowpath(value, buffer, start, end, FAST_COUNT)
+        res = search(value, buffer, start, end, SEARCH_COUNT)
         return space.wrap(max(res, 0))
 
     def descr_decode(self, space, w_encoding=None, w_errors=None):
@@ -152,7 +157,6 @@
             _get_encoding_and_errors, decode_object, unicode_from_string)
         encoding, errors = _get_encoding_and_errors(space, w_encoding,
                                                     w_errors)
-        # TODO: On CPython calling bytearray.decode with no arguments works.
         if encoding is None and errors is None:
             return unicode_from_string(space, self)
         return decode_object(space, self, encoding, errors)
@@ -170,7 +174,11 @@
         if not value:
             return self._empty()
 
-        splitted = value.split(self._chr('\t'))
+        if self._use_rstr_ops(space, self):
+            splitted = value.split(self._chr('\t'))
+        else:
+            splitted = split(value, self._chr('\t'))
+
         try:
             ovfcheck(len(splitted) * tabsize)
         except OverflowError:
@@ -178,7 +186,7 @@
         expanded = oldtoken = splitted.pop(0)
 
         for token in splitted:
-            expanded += self._chr(' ') * self._tabindent(oldtoken,
+            expanded += self._multi_chr(' ') * self._tabindent(oldtoken,
                                                          tabsize) + token
             oldtoken = token
 
@@ -215,7 +223,7 @@
             return space.wrap(res)
 
         buffer = _get_buffer(space, w_sub)
-        res = _search_slowpath(value, buffer, start, end, FAST_FIND)
+        res = search(value, buffer, start, end, SEARCH_FIND)
         return space.wrap(res)
 
     def descr_rfind(self, space, w_sub, w_start=None, w_end=None):
@@ -226,7 +234,7 @@
             return space.wrap(res)
 
         buffer = _get_buffer(space, w_sub)
-        res = _search_slowpath(value, buffer, start, end, FAST_RFIND)
+        res = search(value, buffer, start, end, SEARCH_RFIND)
         return space.wrap(res)
 
     def descr_index(self, space, w_sub, w_start=None, w_end=None):
@@ -236,7 +244,7 @@
             res = value.find(self._op_val(space, w_sub), start, end)
         else:
             buffer = _get_buffer(space, w_sub)
-            res = _search_slowpath(value, buffer, start, end, FAST_FIND)
+            res = search(value, buffer, start, end, SEARCH_FIND)
 
         if res < 0:
             raise oefmt(space.w_ValueError,
@@ -250,7 +258,7 @@
             res = value.rfind(self._op_val(space, w_sub), start, end)
         else:
             buffer = _get_buffer(space, w_sub)
-            res = _search_slowpath(value, buffer, start, end, FAST_RFIND)
+            res = search(value, buffer, start, end, SEARCH_RFIND)
 
         if res < 0:
             raise oefmt(space.w_ValueError,
@@ -401,7 +409,7 @@
                         "ljust() argument 2 must be a single character")
         d = width - len(value)
         if d > 0:
-            fillchar = fillchar[0]    # annotator hint: it's a single character
+            fillchar = self._multi_chr(fillchar)
             value += d * fillchar
 
         return self._new(value)
@@ -415,7 +423,7 @@
                         "rjust() argument 2 must be a single character")
         d = width - len(value)
         if d > 0:
-            fillchar = fillchar[0]    # annotator hint: it's a single character
+            fillchar = self._multi_chr(fillchar)
             value = d * fillchar + value
 
         return self._new(value)
@@ -443,7 +451,7 @@
         if self._use_rstr_ops(space, w_sub):
             pos = value.find(sub)
         else:
-            pos = _search_slowpath(value, sub, 0, len(value), FAST_FIND)
+            pos = search(value, sub, 0, len(value), SEARCH_FIND)
 
         if pos == -1:
             from pypy.objspace.std.bytearrayobject import W_BytearrayObject
@@ -456,7 +464,7 @@
                 w_sub = self._new_from_buffer(sub)
             return space.newtuple(
                 [self._sliced(space, value, 0, pos, self), w_sub,
-                 self._sliced(space, value, pos+len(sub), len(value), self)])
+                 self._sliced(space, value, pos + sublen, len(value), self)])
 
     def descr_rpartition(self, space, w_sub):
         value = self._val(space)
@@ -474,7 +482,7 @@
         if self._use_rstr_ops(space, w_sub):
             pos = value.rfind(sub)
         else:
-            pos = _search_slowpath(value, sub, 0, len(value), FAST_RFIND)
+            pos = search(value, sub, 0, len(value), SEARCH_RFIND)
 
         if pos == -1:
             from pypy.objspace.std.bytearrayobject import W_BytearrayObject
@@ -487,7 +495,7 @@
                 w_sub = self._new_from_buffer(sub)
             return space.newtuple(
                 [self._sliced(space, value, 0, pos, self), w_sub,
-                 self._sliced(space, value, pos+len(sub), len(value), self)])
+                 self._sliced(space, value, pos + sublen, len(value), self)])
 
     @unwrap_spec(count=int)
     def descr_replace(self, space, w_old, w_new, count=-1):
@@ -735,113 +743,3 @@
     return space.buffer_w(w_obj, space.BUF_SIMPLE)
 
 
-
-# Stolen form rpython.rtyper.lltypesytem.rstr
-# TODO: Ask about what to do with this...
-
-FAST_COUNT = 0
-FAST_FIND = 1
-FAST_RFIND = 2
-
-from rpython.rlib.rarithmetic import LONG_BIT as BLOOM_WIDTH
-
-def bloom_add(mask, c):
-    return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1)))
-
-
-def bloom(mask, c):
-    return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1)))
-
-@specialize.argtype(0, 1)
-def _search_slowpath(value, buffer, start, end, mode):
-    if start < 0:
-        start = 0
-    if end > len(value):
-        end = len(value)
-    if start > end:
-        return -1
-
-    count = 0
-    n = end - start
-    m = buffer.getlength()
-
-    if m == 0:
-        if mode == FAST_COUNT:
-            return end - start + 1
-        elif mode == FAST_RFIND:
-            return end
-        else:
-            return start
-
-    w = n - m
-
-    if w < 0:
-        return -1
-
-    mlast = m - 1
-    skip = mlast - 1
-    mask = 0
-
-    if mode != FAST_RFIND:
-        for i in range(mlast):
-            mask = bloom_add(mask, buffer.getitem(i))
-            if buffer.getitem(i) == buffer.getitem(mlast):
-                skip = mlast - i - 1
-        mask = bloom_add(mask, buffer.getitem(mlast))
-
-        i = start - 1
-        while i + 1 <= start + w:
-            i += 1
-            if value[i + m - 1] == buffer.getitem(m - 1):
-                for j in range(mlast):
-                    if value[i + j] != buffer.getitem(j):
-                        break
-                else:
-                    if mode != FAST_COUNT:
-                        return i
-                    count += 1
-                    i += mlast
-                    continue
-
-                if i + m < len(value):
-                    c = value[i + m]
-                else:
-                    c = '\0'
-                if not bloom(mask, c):
-                    i += m
-                else:
-                    i += skip
-            else:
-                if i + m < len(value):
-                    c = value[i + m]
-                else:
-                    c = '\0'
-                if not bloom(mask, c):
-                    i += m
-    else:
-        mask = bloom_add(mask, buffer.getitem(0))
-        for i in range(mlast, 0, -1):
-            mask = bloom_add(mask, buffer.getitem(i))
-            if buffer.getitem(i) == buffer.getitem(0):
-                skip = i - 1
-
-        i = start + w + 1
-        while i - 1 >= start:
-            i -= 1
-            if value[i] == buffer.getitem(0):
-                for j in xrange(mlast, 0, -1):
-                    if value[i + j] != buffer.getitem(j):
-                        break
-                else:
-                    return i
-                if i - 1 >= 0 and not bloom(mask, value[i - 1]):
-                    i -= m
-                else:
-                    i -= skip
-            else:
-                if i - 1 >= 0 and not bloom(mask, value[i - 1]):
-                    i -= m
-
-    if mode != FAST_COUNT:
-        return -1
-    return count
diff --git a/pypy/objspace/std/test/test_bytearrayobject.py b/pypy/objspace/std/test/test_bytearrayobject.py
--- a/pypy/objspace/std/test/test_bytearrayobject.py
+++ b/pypy/objspace/std/test/test_bytearrayobject.py
@@ -442,6 +442,7 @@
         u = b.decode('utf-8')
         assert isinstance(u, unicode)
         assert u == u'abcdefghi'
+        assert b.decode() == u'abcdefghi'
 
     def test_int(self):
         assert int(bytearray('-1234')) == -1234
diff --git a/rpython/rlib/rstring.py b/rpython/rlib/rstring.py
--- a/rpython/rlib/rstring.py
+++ b/rpython/rlib/rstring.py
@@ -7,7 +7,8 @@
 from rpython.rtyper.llannotation import SomePtr
 from rpython.rlib import jit
 from rpython.rlib.objectmodel import newlist_hint, specialize
-from rpython.rlib.rarithmetic import ovfcheck
+from rpython.rlib.rarithmetic import ovfcheck, LONG_BIT as BLOOM_WIDTH
+from rpython.rlib.buffer import Buffer
 from rpython.rlib.unicodedata import unicodedb_5_2_0 as unicodedb
 from rpython.rtyper.extregistry import ExtRegistryEntry
 from rpython.tool.pairtype import pairtype
@@ -15,6 +16,35 @@
 
 # -------------- public API for string functions -----------------------
 
+@specialize.argtype(0, 1)
+def _get_access_functions(value, other):  # -> getitem, getlength, find, rfind
+    if isinstance(other, (str, unicode, list)):  # indexable with [] and len()
+        def getitem(obj, i):
+            return obj[i]
+        def getlength(obj):
+            return len(obj)
+    else:
+        assert isinstance(other, Buffer)  # needle read via the Buffer API
+        def getitem(obj, i):
+            return obj.getitem(i)
+        def getlength(obj):
+            return obj.getlength()
+
+    if isinstance(value, list) or isinstance(other, Buffer):  # use search()
+        def find(obj, other, start, end):
+            return search(obj, other, start, end, SEARCH_FIND)
+        def rfind(obj, other, start, end):
+            return search(obj, other, start, end, SEARCH_RFIND)
+    else:  # str/unicode on both sides: use the native methods
+        assert isinstance(value, (str, unicode))
+        assert isinstance(other, (str, unicode))
+        def find(obj, other, start, end):
+            return obj.find(other, start, end)
+        def rfind(obj, other, start, end):
+            return obj.rfind(other, start, end)
+
+    return getitem, getlength, find, rfind
+
 @specialize.argtype(0)
 def _isspace(char):
     if isinstance(char, str):
@@ -55,10 +85,11 @@
             i = j + 1
         return res
 
-    if isinstance(value, str):
+    if isinstance(value, (list, str)):
         assert isinstance(by, str)
     else:
         assert isinstance(by, unicode)
+    _, _, find, _ = _get_access_functions(value, by)
     bylen = len(by)
     if bylen == 0:
         raise ValueError("empty separator")
@@ -72,7 +103,7 @@
             count = maxsplit
         res = newlist_hint(count + 1)
         while count > 0:
-            next = value.find(by, start)
+            next = find(value, by, start, len(value))
             assert next >= 0 # cannot fail due to the value.count above
             res.append(value[start:next])
             start = next + bylen
@@ -86,7 +117,7 @@
         res = []
 
     while maxsplit != 0:
-        next = value.find(by, start)
+        next = find(value, by, start, len(value))
         if next < 0:
             break
         res.append(value[start:next])
@@ -133,7 +164,7 @@
         res.reverse()
         return res
 
-    if isinstance(value, str):
+    if isinstance(value, (list, str)):
         assert isinstance(by, str)
     else:
         assert isinstance(by, unicode)
@@ -141,13 +172,14 @@
         res = newlist_hint(min(maxsplit + 1, len(value)))
     else:
         res = []
+    _, _, _, rfind = _get_access_functions(value, by)
     end = len(value)
     bylen = len(by)
     if bylen == 0:
         raise ValueError("empty separator")
 
     while maxsplit != 0:
-        next = value.rfind(by, 0, end)
+        next = rfind(value, by, 0, end)
         if next < 0:
             break
         res.append(value[next + bylen:end])
@@ -166,13 +198,20 @@
         assert isinstance(sub, str)
         assert isinstance(by, str)
         Builder = StringBuilder
-    else:
+    elif isinstance(input, unicode):
         assert isinstance(sub, unicode)
         assert isinstance(by, unicode)
         Builder = UnicodeBuilder
+    elif isinstance(input, list):  # bytearray data (a list of chars)
+        assert isinstance(sub, str)
+        assert isinstance(by, str)
+        # FIXME: replace() on bytearray (list) input does not work yet
+        Builder = StringBuilder
     if maxsplit == 0:
         return input
 
+    _, _, find, _ = _get_access_functions(input, sub)
+
     if not sub:
         upper = len(input)
         if maxsplit > 0 and maxsplit < upper + 2:
@@ -210,7 +249,7 @@
         sublen = len(sub)
 
         while maxsplit != 0:
-            next = input.find(sub, start)
+            next = find(input, sub, start, len(input))
             if next < 0:
                 break
             builder.append_slice(input, start, next)
@@ -261,6 +300,114 @@
             return False
     return True
 
+# Adapted from rpython.rtyper.lltypesystem.rstr.
+# TODO: decide whether to share this code with rstr instead of copying it.
+
+SEARCH_COUNT = 0   # search() returns the number of non-overlapping matches
+SEARCH_FIND = 1    # search() returns the lowest match index, or -1
+SEARCH_RFIND = 2   # search() returns the highest match index, or -1
+
+def bloom_add(mask, c):  # add c to a BLOOM_WIDTH-bit filter mask
+    return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1)))
+
+def bloom(mask, c):  # nonzero if c may be in mask (false positives possible)
+    return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1)))
+
+@specialize.argtype(0, 1)
+def search(value, other, start, end, mode):  # mode is one of SEARCH_*
+    getitem, getlength, _, _ = _get_access_functions(value, other)
+    if start < 0:  # clamp start/end to the bounds of value
+        start = 0
+    if end > len(value):
+        end = len(value)
+    if start > end:
+        return -1
+
+    count = 0
+    n = end - start
+    m = getlength(other)
+
+    if m == 0:  # an empty needle matches at every position
+        if mode == SEARCH_COUNT:
+            return end - start + 1
+        elif mode == SEARCH_RFIND:
+            return end
+        else:
+            return start
+
+    w = n - m  # last candidate offset, relative to start
+
+    if w < 0:
+        return -1
+
+    mlast = m - 1
+    skip = mlast - 1
+    mask = 0
+
+    if mode != SEARCH_RFIND:  # forward scan: find / count
+        for i in range(mlast):
+            mask = bloom_add(mask, getitem(other, i))
+            if getitem(other, i) == getitem(other, mlast):
+                skip = mlast - i - 1
+        mask = bloom_add(mask, getitem(other, mlast))
+
+        i = start - 1
+        while i + 1 <= start + w:
+            i += 1
+            if value[i + m - 1] == getitem(other, m - 1):
+                for j in range(mlast):
+                    if value[i + j] != getitem(other, j):
+                        break
+                else:
+                    if mode != SEARCH_COUNT:
+                        return i
+                    count += 1
+                    i += mlast  # counted matches do not overlap
+                    continue
+
+                if i + m < len(value):
+                    c = value[i + m]
+                else:
+                    c = '\0'
+                if not bloom(mask, c):
+                    i += m
+                else:
+                    i += skip
+            else:
+                if i + m < len(value):
+                    c = value[i + m]
+                else:
+                    c = '\0'
+                if not bloom(mask, c):
+                    i += m
+    else:  # backward scan: rfind
+        mask = bloom_add(mask, getitem(other, 0))
+        for i in range(mlast, 0, -1):
+            mask = bloom_add(mask, getitem(other, i))
+            if getitem(other, i) == getitem(other, 0):
+                skip = i - 1
+
+        i = start + w + 1
+        while i - 1 >= start:
+            i -= 1
+            if value[i] == getitem(other, 0):
+                for j in xrange(mlast, 0, -1):
+                    if value[i + j] != getitem(other, j):
+                        break
+                else:
+                    return i
+                if i - 1 >= 0 and not bloom(mask, value[i - 1]):
+                    i -= m
+                else:
+                    i -= skip
+            else:
+                if i - 1 >= 0 and not bloom(mask, value[i - 1]):
+                    i -= m
+
+    if mode != SEARCH_COUNT:
+        return -1
+    return count
+
 # -------------- numeric parsing support --------------------
 
 def strip_spaces(s):


More information about the pypy-commit mailing list