[pypy-commit] pypy fix-bytearray-complexity: Most bytearray methods fixed

waedt noreply at buildbot.pypy.org
Mon Jun 2 19:47:03 CEST 2014


Author: Tyler Wade <wayedt at gmail.com>
Branch: fix-bytearray-complexity
Changeset: r71872:3521f66aed64
Date: 2014-05-26 03:48 -0500
http://bitbucket.org/pypy/pypy/changeset/3521f66aed64/

Log:	Most bytearray methods fixed

diff --git a/pypy/objspace/std/bytearrayobject.py b/pypy/objspace/std/bytearrayobject.py
--- a/pypy/objspace/std/bytearrayobject.py
+++ b/pypy/objspace/std/bytearrayobject.py
@@ -1,7 +1,7 @@
 """The builtin bytearray implementation"""
 
 from rpython.rlib.objectmodel import (
-    import_from_mixin, newlist_hint, resizelist_hint)
+    import_from_mixin, newlist_hint, resizelist_hint, specialize)
 from rpython.rlib.buffer import Buffer
 from rpython.rlib.rstring import StringBuilder
 
@@ -11,7 +11,7 @@
 from pypy.interpreter.signature import Signature
 from pypy.objspace.std.sliceobject import W_SliceObject
 from pypy.objspace.std.stdtypedef import StdTypeDef
-from pypy.objspace.std.stringmethods import StringMethods
+from pypy.objspace.std.stringmethods import StringMethods, _get_buffer
 from pypy.objspace.std.util import get_positive_index
 
 NON_HEX_MSG = "non-hexadecimal number found in fromhex() arg at position %d"
@@ -40,7 +40,11 @@
         return ''.join(self.data)
 
     def _new(self, value):
-        return W_BytearrayObject(_make_data(value))
+        return W_BytearrayObject(value)
+
+    def _new_from_buffer(self, buffer):
+        length = buffer.getlength()
+        return W_BytearrayObject([buffer.getitem(i) for i in range(length)])
 
     def _new_from_list(self, value):
         return W_BytearrayObject(value)
@@ -58,7 +62,12 @@
             raise oefmt(space.w_IndexError, "bytearray index out of range")
         return space.wrap(ord(character))
 
-    _val = charbuf_w
+    def _val(self, space):
+        return self.data
+
+    @staticmethod
+    def _use_rstr_ops(space, w_other):
+        return False
 
     @staticmethod
     def _op_val(space, w_other):
@@ -68,7 +77,9 @@
         assert len(char) == 1
         return str(char)[0]
 
-    _builder = StringBuilder
+    @staticmethod
+    def _builder(size=100):
+        return BytearrayBuilder(size)
 
     def _newlist_unwrapped(self, space, res):
         return space.newlist([W_BytearrayObject(_make_data(i)) for i in res])
@@ -260,58 +271,116 @@
         return space.wrap(''.join(self.data))
 
     def descr_eq(self, space, w_other):
+        if isinstance(w_other, W_BytearrayObject):
+            return space.newbool(self.data == w_other.data)
+
         try:
-            res = self._val(space) == self._op_val(space, w_other)
+            buffer = _get_buffer(space, w_other)
         except OperationError as e:
             if e.match(space, space.w_TypeError):
                 return space.w_NotImplemented
             raise
-        return space.newbool(res)
+
+        value = self._val(space)
+        buffer_len = buffer.getlength()
+
+        if len(value) != buffer_len:
+            return space.newbool(False)
+
+        min_length = min(len(value), buffer_len)
+        return space.newbool(_memcmp(value, buffer, min_length) == 0)
 
     def descr_ne(self, space, w_other):
+        if isinstance(w_other, W_BytearrayObject):
+            return space.newbool(self.data != w_other.data)
+
         try:
-            res = self._val(space) != self._op_val(space, w_other)
+            buffer = _get_buffer(space, w_other)
         except OperationError as e:
             if e.match(space, space.w_TypeError):
                 return space.w_NotImplemented
             raise
-        return space.newbool(res)
+
+        value = self._val(space)
+        buffer_len = buffer.getlength()
+
+        if len(value) != buffer_len:
+            return space.newbool(True)
+
+        min_length = min(len(value), buffer_len)
+        return space.newbool(_memcmp(value, buffer, min_length) != 0)
 
     def descr_lt(self, space, w_other):
+        if isinstance(w_other, W_BytearrayObject):
+            return space.newbool(self.data < w_other.data)
+
         try:
-            res = self._val(space) < self._op_val(space, w_other)
+            buffer = _get_buffer(space, w_other)
         except OperationError as e:
             if e.match(space, space.w_TypeError):
                 return space.w_NotImplemented
             raise
-        return space.newbool(res)
+
+        value = self._val(space)
+        buffer_len = buffer.getlength()
+
+        cmp = _memcmp(value, buffer, min(len(value), buffer_len))
+        return space.newbool(
+            cmp < 0 or (cmp == 0 and space.newbool(len(value) < buffer_len)))
 
     def descr_le(self, space, w_other):
+        if isinstance(w_other, W_BytearrayObject):
+            return space.newbool(self.data <= w_other.data)
+
         try:
-            res = self._val(space) <= self._op_val(space, w_other)
+            buffer = _get_buffer(space, w_other)
         except OperationError as e:
             if e.match(space, space.w_TypeError):
                 return space.w_NotImplemented
             raise
-        return space.newbool(res)
+
+        value = self._val(space)
+        buffer_len = buffer.getlength()
+
+        cmp = _memcmp(value, buffer, min(len(value), buffer_len))
+        return space.newbool(
+            cmp < 0 or (cmp == 0 and space.newbool(len(value) <= buffer_len)))
 
     def descr_gt(self, space, w_other):
+        if isinstance(w_other, W_BytearrayObject):
+            return space.newbool(self.data > w_other.data)
+
         try:
-            res = self._val(space) > self._op_val(space, w_other)
+            buffer = _get_buffer(space, w_other)
         except OperationError as e:
             if e.match(space, space.w_TypeError):
                 return space.w_NotImplemented
             raise
-        return space.newbool(res)
+
+        value = self._val(space)
+        buffer_len = buffer.getlength()
+
+        cmp = _memcmp(value, buffer, min(len(value), buffer_len))
+        return space.newbool(
+            cmp > 0 or (cmp == 0 and space.newbool(len(value) > buffer_len)))
 
     def descr_ge(self, space, w_other):
+        if isinstance(w_other, W_BytearrayObject):
+            return space.newbool(self.data >= w_other.data)
+
         try:
-            res = self._val(space) >= self._op_val(space, w_other)
+            buffer = _get_buffer(space, w_other)
         except OperationError as e:
             if e.match(space, space.w_TypeError):
                 return space.w_NotImplemented
             raise
-        return space.newbool(res)
+
+        value = self._val(space)
+        buffer_len = buffer.getlength()
+
+        cmp = _memcmp(value, buffer, min(len(value), buffer_len))
+        return space.newbool(
+            cmp > 0 or (cmp == 0 and space.newbool(len(value) >= buffer_len)))
 
     def descr_iter(self, space):
         return space.newseqiter(self)
@@ -319,8 +388,11 @@
     def descr_inplace_add(self, space, w_other):
         if isinstance(w_other, W_BytearrayObject):
             self.data += w_other.data
-        else:
-            self.data += self._op_val(space, w_other)
+            return self
+
+        buffer = _get_buffer(space, w_other)
+        for i in range(buffer.getlength()):
+            self.data.append(buffer.getitem(i))
         return self
 
     def descr_inplace_mul(self, space, w_times):
@@ -403,11 +475,42 @@
         if space.isinstance_w(w_sub, space.w_int):
             char = space.int_w(w_sub)
             return _descr_contains_bytearray(self.data, space, char)
+
         return self._StringMethods_descr_contains(space, w_sub)
 
+    def descr_add(self, space, w_other):
+        if isinstance(w_other, W_BytearrayObject):
+            return self._new(self.data + w_other.data)
+
+        try:
+            buffer = _get_buffer(space, w_other)
+        except OperationError as e:
+            if e.match(space, space.w_TypeError):
+                return space.w_NotImplemented
+            raise
+
+        buffer_len = buffer.getlength()
+        data = list(self.data + ['\0'] * buffer_len)
+        for i in range(buffer_len):
+            data[len(self.data) + i] = buffer.getitem(i)
+        return self._new(data)
+
+
     def descr_reverse(self, space):
         self.data.reverse()
 
+class BytearrayBuilder(object):
+    def __init__(self, size):
+        self.data = newlist_hint(size)
+
+    def append(self, s):
+        for i in range(len(s)):
+            self.data.append(s[i])
+
+    def build(self):
+        return self.data
+
+
 
 # ____________________________________________________________
 # helpers for slow paths, moved out because they contain loops
@@ -1152,3 +1255,13 @@
 
     def setitem(self, index, char):
         self.data[index] = char
+
+
+@specialize.argtype(0)
+def _memcmp(selfvalue, buffer, length):
+    for i in range(length):
+        if selfvalue[i] < buffer.getitem(i):
+            return -1
+        if selfvalue[i] > buffer.getitem(i):
+            return 1
+    return 0
diff --git a/pypy/objspace/std/bytesobject.py b/pypy/objspace/std/bytesobject.py
--- a/pypy/objspace/std/bytesobject.py
+++ b/pypy/objspace/std/bytesobject.py
@@ -480,6 +480,11 @@
     _val = str_w
 
     @staticmethod
+    def _use_rstr_ops(space, w_other):
+        from pypy.objspace.std.unicodeobject import W_UnicodeObject
+        return isinstance(w_other, (W_BytesObject, W_UnicodeObject))
+
+    @staticmethod
     def _op_val(space, w_other):
         try:
             return space.str_w(w_other)
diff --git a/pypy/objspace/std/stringmethods.py b/pypy/objspace/std/stringmethods.py
--- a/pypy/objspace/std/stringmethods.py
+++ b/pypy/objspace/std/stringmethods.py
@@ -1,7 +1,7 @@
 """Functionality shared between bytes/bytearray/unicode"""
 
 from rpython.rlib import jit
-from rpython.rlib.objectmodel import specialize
+from rpython.rlib.objectmodel import specialize, newlist_hint
 from rpython.rlib.rarithmetic import ovfcheck
 from rpython.rlib.rstring import endswith, replace, rsplit, split, startswith
 
@@ -36,17 +36,27 @@
 
     def descr_contains(self, space, w_sub):
         value = self._val(space)
-        other = self._op_val(space, w_sub)
-        return space.newbool(value.find(other) >= 0)
+        if self._use_rstr_ops(space, w_sub):
+            other = self._op_val(space, w_sub)
+            return space.newbool(value.find(other) >= 0)
+
+        buffer = _get_buffer(space, w_sub)
+        res = _search_slowpath(value, buffer, 0, len(value), FAST_FIND)
+        return space.newbool(res >= 0)
 
     def descr_add(self, space, w_other):
-        try:
-            other = self._op_val(space, w_other)
-        except OperationError as e:
-            if e.match(space, space.w_TypeError):
-                return space.w_NotImplemented
-            raise
-        return self._new(self._val(space) + other)
+        if self._use_rstr_ops(space, w_other):
+            try:
+                other = self._op_val(space, w_other)
+            except OperationError as e:
+                if e.match(space, space.w_TypeError):
+                    return space.w_NotImplemented
+                raise
+            return self._new(self._val(space) + other)
+
+        # Bytearray overrides this method, CPython doesn't support
+        # concatenating buffers and strs, and unicodes are always handled above
+        return space.w_NotImplemented
 
     def descr_mul(self, space, w_times):
         try:
@@ -128,14 +138,21 @@
 
     def descr_count(self, space, w_sub, w_start=None, w_end=None):
         value, start, end = self._convert_idx_params(space, w_start, w_end)
-        return space.newint(value.count(self._op_val(space, w_sub), start,
-                                        end))
+
+        if self._use_rstr_ops(space, w_sub):
+            return space.newint(value.count(self._op_val(space, w_sub), start,
+                                            end))
+
+        buffer = _get_buffer(space, w_sub)
+        res = _search_slowpath(value, buffer, start, end, FAST_COUNT)
+        return space.wrap(max(res, 0))
 
     def descr_decode(self, space, w_encoding=None, w_errors=None):
         from pypy.objspace.std.unicodeobject import (
             _get_encoding_and_errors, decode_object, unicode_from_string)
         encoding, errors = _get_encoding_and_errors(space, w_encoding,
                                                     w_errors)
+        # TODO: On CPython calling bytearray.decode with no arguments works.
         if encoding is None and errors is None:
             return unicode_from_string(space, self)
         return decode_object(space, self, encoding, errors)
@@ -192,30 +209,52 @@
 
     def descr_find(self, space, w_sub, w_start=None, w_end=None):
         (value, start, end) = self._convert_idx_params(space, w_start, w_end)
-        res = value.find(self._op_val(space, w_sub), start, end)
+
+        if self._use_rstr_ops(space, w_sub):
+            res = value.find(self._op_val(space, w_sub), start, end)
+            return space.wrap(res)
+
+        buffer = _get_buffer(space, w_sub)
+        res = _search_slowpath(value, buffer, start, end, FAST_FIND)
         return space.wrap(res)
 
     def descr_rfind(self, space, w_sub, w_start=None, w_end=None):
         (value, start, end) = self._convert_idx_params(space, w_start, w_end)
-        res = value.rfind(self._op_val(space, w_sub), start, end)
+
+        if self._use_rstr_ops(space, w_sub):
+            res = value.rfind(self._op_val(space, w_sub), start, end)
+            return space.wrap(res)
+
+        buffer = _get_buffer(space, w_sub)
+        res = _search_slowpath(value, buffer, start, end, FAST_RFIND)
         return space.wrap(res)
 
     def descr_index(self, space, w_sub, w_start=None, w_end=None):
         (value, start, end) = self._convert_idx_params(space, w_start, w_end)
-        res = value.find(self._op_val(space, w_sub), start, end)
+
+        if self._use_rstr_ops(space, w_sub):
+            res = value.find(self._op_val(space, w_sub), start, end)
+        else:
+            buffer = _get_buffer(space, w_sub)
+            res = _search_slowpath(value, buffer, start, end, FAST_FIND)
+
         if res < 0:
             raise oefmt(space.w_ValueError,
                         "substring not found in string.index")
-
         return space.wrap(res)
 
     def descr_rindex(self, space, w_sub, w_start=None, w_end=None):
         (value, start, end) = self._convert_idx_params(space, w_start, w_end)
-        res = value.rfind(self._op_val(space, w_sub), start, end)
+
+        if self._use_rstr_ops(space, w_sub):
+            res = value.rfind(self._op_val(space, w_sub), start, end)
+        else:
+            buffer = _get_buffer(space, w_sub)
+            res = _search_slowpath(value, buffer, start, end, FAST_RFIND)
+
         if res < 0:
             raise oefmt(space.w_ValueError,
                         "substring not found in string.rindex")
-
         return space.wrap(res)
 
     @specialize.arg(2)
@@ -328,6 +367,7 @@
         value = self._val(space)
 
         prealloc_size = len(value) * (size - 1)
+        unwrapped = newlist_hint(size)
         for i in range(size):
             w_s = list_w[i]
             check_item = self._join_check_item(space, w_s)
@@ -337,13 +377,16 @@
                             i, w_s)
             elif check_item == 2:
                 return self._join_autoconvert(space, list_w)
-            prealloc_size += len(self._op_val(space, w_s))
+            # XXX Maybe the extra copy here is okay? It was basically going to
+            #     happen anyway, what with being placed into the builder
+            unwrapped.append(self._op_val(space, w_s))
+            prealloc_size += len(unwrapped[0])
 
         sb = self._builder(prealloc_size)
         for i in range(size):
             if value and i != 0:
                 sb.append(value)
-            sb.append(self._op_val(space, list_w[i]))
+            sb.append(unwrapped[i])
         return self._new(sb.build())
 
     def _join_autoconvert(self, space, list_w):
@@ -386,10 +429,22 @@
 
     def descr_partition(self, space, w_sub):
         value = self._val(space)
-        sub = self._op_val(space, w_sub)
-        if not sub:
+
+        if self._use_rstr_ops(space, w_sub):
+            sub = self._op_val(space, w_sub)
+            sublen = len(sub)
+        else:
+            sub = _get_buffer(space, w_sub)
+            sublen = sub.getlength()
+
+        if sublen == 0:
             raise oefmt(space.w_ValueError, "empty separator")
-        pos = value.find(sub)
+
+        if self._use_rstr_ops(space, w_sub):
+            pos = value.find(sub)
+        else:
+            pos = _search_slowpath(value, sub, 0, len(value), FAST_FIND)
+
         if pos == -1:
             from pypy.objspace.std.bytearrayobject import W_BytearrayObject
             if isinstance(self, W_BytearrayObject):
@@ -398,17 +453,29 @@
         else:
             from pypy.objspace.std.bytearrayobject import W_BytearrayObject
             if isinstance(self, W_BytearrayObject):
-                w_sub = self._new(sub)
+                w_sub = self._new_from_buffer(sub)
             return space.newtuple(
                 [self._sliced(space, value, 0, pos, self), w_sub,
                  self._sliced(space, value, pos+len(sub), len(value), self)])
 
     def descr_rpartition(self, space, w_sub):
         value = self._val(space)
-        sub = self._op_val(space, w_sub)
-        if not sub:
+
+        if self._use_rstr_ops(space, w_sub):
+            sub = self._op_val(space, w_sub)
+            sublen = len(sub)
+        else:
+            sub = _get_buffer(space, w_sub)
+            sublen = sub.getlength()
+
+        if sublen == 0:
             raise oefmt(space.w_ValueError, "empty separator")
-        pos = value.rfind(sub)
+
+        if self._use_rstr_ops(space, w_sub):
+            pos = value.rfind(sub)
+        else:
+            pos = _search_slowpath(value, sub, 0, len(value), FAST_RFIND)
+
         if pos == -1:
             from pypy.objspace.std.bytearrayobject import W_BytearrayObject
             if isinstance(self, W_BytearrayObject):
@@ -417,7 +484,7 @@
         else:
             from pypy.objspace.std.bytearrayobject import W_BytearrayObject
             if isinstance(self, W_BytearrayObject):
-                w_sub = self._new(sub)
+                w_sub = self._new_from_buffer(sub)
             return space.newtuple(
                 [self._sliced(space, value, 0, pos, self), w_sub,
                  self._sliced(space, value, pos+len(sub), len(value), self)])
@@ -616,10 +683,11 @@
             for char in string:
                 buf.append(table[ord(char)])
         else:
+            # XXX Why not preallocate here too?
             buf = self._builder()
             deletion_table = [False] * 256
-            for c in deletechars:
-                deletion_table[ord(c)] = True
+            for i in range(len(deletechars)):
+                deletion_table[ord(deletechars[i])] = True
             for char in string:
                 if not deletion_table[ord(char)]:
                     buf.append(table[ord(char)])
@@ -662,3 +730,118 @@
 @specialize.argtype(0)
 def _descr_getslice_slowpath(selfvalue, start, step, sl):
     return [selfvalue[start + i*step] for i in range(sl)]
+
+def _get_buffer(space, w_obj):
+    return space.buffer_w(w_obj, space.BUF_SIMPLE)
+
+
+
+# Stolen from rpython.rtyper.lltypesystem.rstr
+# TODO: Ask about what to do with this...
+
+FAST_COUNT = 0
+FAST_FIND = 1
+FAST_RFIND = 2
+
+from rpython.rlib.rarithmetic import LONG_BIT as BLOOM_WIDTH
+
+def bloom_add(mask, c):
+    return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1)))
+
+
+def bloom(mask, c):
+    return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1)))
+
+@specialize.argtype(0, 1)
+def _search_slowpath(value, buffer, start, end, mode):
+    if start < 0:
+        start = 0
+    if end > len(value):
+        end = len(value)
+    if start > end:
+        return -1
+
+    count = 0
+    n = end - start
+    m = buffer.getlength()
+
+    if m == 0:
+        if mode == FAST_COUNT:
+            return end - start + 1
+        elif mode == FAST_RFIND:
+            return end
+        else:
+            return start
+
+    w = n - m
+
+    if w < 0:
+        return -1
+
+    mlast = m - 1
+    skip = mlast - 1
+    mask = 0
+
+    if mode != FAST_RFIND:
+        for i in range(mlast):
+            mask = bloom_add(mask, buffer.getitem(i))
+            if buffer.getitem(i) == buffer.getitem(mlast):
+                skip = mlast - i - 1
+        mask = bloom_add(mask, buffer.getitem(mlast))
+
+        i = start - 1
+        while i + 1 <= start + w:
+            i += 1
+            if value[i + m - 1] == buffer.getitem(m - 1):
+                for j in range(mlast):
+                    if value[i + j] != buffer.getitem(j):
+                        break
+                else:
+                    if mode != FAST_COUNT:
+                        return i
+                    count += 1
+                    i += mlast
+                    continue
+
+                if i + m < len(value):
+                    c = value[i + m]
+                else:
+                    c = '\0'
+                if not bloom(mask, c):
+                    i += m
+                else:
+                    i += skip
+            else:
+                if i + m < len(value):
+                    c = value[i + m]
+                else:
+                    c = '\0'
+                if not bloom(mask, c):
+                    i += m
+    else:
+        mask = bloom_add(mask, buffer.getitem(0))
+        for i in range(mlast, 0, -1):
+            mask = bloom_add(mask, buffer.getitem(i))
+            if buffer.getitem(i) == buffer.getitem(0):
+                skip = i - 1
+
+        i = start + w + 1
+        while i - 1 >= start:
+            i -= 1
+            if value[i] == buffer.getitem(0):
+                for j in xrange(mlast, 0, -1):
+                    if value[i + j] != buffer.getitem(j):
+                        break
+                else:
+                    return i
+                if i - 1 >= 0 and not bloom(mask, value[i - 1]):
+                    i -= m
+                else:
+                    i -= skip
+            else:
+                if i - 1 >= 0 and not bloom(mask, value[i - 1]):
+                    i -= m
+
+    if mode != FAST_COUNT:
+        return -1
+    return count
diff --git a/pypy/objspace/std/test/test_bytearrayobject.py b/pypy/objspace/std/test/test_bytearrayobject.py
--- a/pypy/objspace/std/test/test_bytearrayobject.py
+++ b/pypy/objspace/std/test/test_bytearrayobject.py
@@ -178,8 +178,10 @@
         assert bytearray('hello').rindex('l') == 3
         assert bytearray('hello').index(bytearray('e')) == 1
         assert bytearray('hello').find('l') == 2
+        assert bytearray('hello').find('l', -2) == 3
         assert bytearray('hello').rfind('l') == 3
 
+
         # these checks used to not raise in pypy but they should
         raises(TypeError, bytearray('hello').index, ord('e'))
         raises(TypeError, bytearray('hello').rindex, ord('e'))
diff --git a/pypy/objspace/std/unicodeobject.py b/pypy/objspace/std/unicodeobject.py
--- a/pypy/objspace/std/unicodeobject.py
+++ b/pypy/objspace/std/unicodeobject.py
@@ -103,6 +103,12 @@
     _val = unicode_w
 
     @staticmethod
+    def _use_rstr_ops(space, w_other):
+        # Always return true because we always need to copy the other
+        # operand(s) before we can do comparisons
+        return True
+
+    @staticmethod
     def _op_val(space, w_other):
         if isinstance(w_other, W_UnicodeObject):
             return w_other._value


More information about the pypy-commit mailing list