[pypy-commit] pypy fix-bytearray-complexity: Most bytearray methods fixed
waedt
noreply at buildbot.pypy.org
Mon Jun 2 19:47:03 CEST 2014
Author: Tyler Wade <wayedt at gmail.com>
Branch: fix-bytearray-complexity
Changeset: r71872:3521f66aed64
Date: 2014-05-26 03:48 -0500
http://bitbucket.org/pypy/pypy/changeset/3521f66aed64/
Log: Most bytearray methods fixed
diff --git a/pypy/objspace/std/bytearrayobject.py b/pypy/objspace/std/bytearrayobject.py
--- a/pypy/objspace/std/bytearrayobject.py
+++ b/pypy/objspace/std/bytearrayobject.py
@@ -1,7 +1,7 @@
"""The builtin bytearray implementation"""
from rpython.rlib.objectmodel import (
- import_from_mixin, newlist_hint, resizelist_hint)
+ import_from_mixin, newlist_hint, resizelist_hint, specialize)
from rpython.rlib.buffer import Buffer
from rpython.rlib.rstring import StringBuilder
@@ -11,7 +11,7 @@
from pypy.interpreter.signature import Signature
from pypy.objspace.std.sliceobject import W_SliceObject
from pypy.objspace.std.stdtypedef import StdTypeDef
-from pypy.objspace.std.stringmethods import StringMethods
+from pypy.objspace.std.stringmethods import StringMethods, _get_buffer
from pypy.objspace.std.util import get_positive_index
NON_HEX_MSG = "non-hexadecimal number found in fromhex() arg at position %d"
@@ -40,7 +40,11 @@
return ''.join(self.data)
def _new(self, value):
- return W_BytearrayObject(_make_data(value))
+ return W_BytearrayObject(value)
+
+ def _new_from_buffer(self, buffer):
+ length = buffer.getlength()
+ return W_BytearrayObject([buffer.getitem(i) for i in range(length)])
def _new_from_list(self, value):
return W_BytearrayObject(value)
@@ -58,7 +62,12 @@
raise oefmt(space.w_IndexError, "bytearray index out of range")
return space.wrap(ord(character))
- _val = charbuf_w
+ def _val(self, space):
+ return self.data
+
+ @staticmethod
+ def _use_rstr_ops(space, w_other):
+ return False
@staticmethod
def _op_val(space, w_other):
@@ -68,7 +77,9 @@
assert len(char) == 1
return str(char)[0]
- _builder = StringBuilder
+ @staticmethod
+ def _builder(size=100):
+ return BytearrayBuilder(size)
def _newlist_unwrapped(self, space, res):
return space.newlist([W_BytearrayObject(_make_data(i)) for i in res])
@@ -260,58 +271,116 @@
return space.wrap(''.join(self.data))
def descr_eq(self, space, w_other):
+ if isinstance(w_other, W_BytearrayObject):
+ return space.newbool(self.data == w_other.data)
+
try:
- res = self._val(space) == self._op_val(space, w_other)
+ buffer = _get_buffer(space, w_other)
except OperationError as e:
if e.match(space, space.w_TypeError):
return space.w_NotImplemented
raise
- return space.newbool(res)
+
+ value = self._val(space)
+ buffer_len = buffer.getlength()
+
+ if len(value) != buffer_len:
+ return space.newbool(False)
+
+ min_length = min(len(value), buffer_len)
+ return space.newbool(_memcmp(value, buffer, min_length) == 0)
def descr_ne(self, space, w_other):
+ if isinstance(w_other, W_BytearrayObject):
+ return space.newbool(self.data != w_other.data)
+
try:
- res = self._val(space) != self._op_val(space, w_other)
+ buffer = _get_buffer(space, w_other)
except OperationError as e:
if e.match(space, space.w_TypeError):
return space.w_NotImplemented
raise
- return space.newbool(res)
+
+ value = self._val(space)
+ buffer_len = buffer.getlength()
+
+ if len(value) != buffer_len:
+ return space.newbool(True)
+
+ min_length = min(len(value), buffer_len)
+ return space.newbool(_memcmp(value, buffer, min_length) != 0)
def descr_lt(self, space, w_other):
+ if isinstance(w_other, W_BytearrayObject):
+ return space.newbool(self.data < w_other.data)
+
try:
- res = self._val(space) < self._op_val(space, w_other)
+ buffer = _get_buffer(space, w_other)
except OperationError as e:
if e.match(space, space.w_TypeError):
return space.w_NotImplemented
raise
- return space.newbool(res)
+
+ value = self._val(space)
+ buffer_len = buffer.getlength()
+
+ cmp = _memcmp(value, buffer, min(len(value), buffer_len))
+ return space.newbool(
+ cmp < 0 or (cmp == 0 and space.newbool(len(value) < buffer_len)))
def descr_le(self, space, w_other):
+ if isinstance(w_other, W_BytearrayObject):
+ return space.newbool(self.data <= w_other.data)
+
try:
- res = self._val(space) <= self._op_val(space, w_other)
+ buffer = _get_buffer(space, w_other)
except OperationError as e:
if e.match(space, space.w_TypeError):
return space.w_NotImplemented
raise
- return space.newbool(res)
+
+ value = self._val(space)
+ buffer_len = buffer.getlength()
+
+ cmp = _memcmp(value, buffer, min(len(value), buffer_len))
+ return space.newbool(
+ cmp < 0 or (cmp == 0 and space.newbool(len(value) <= buffer_len)))
def descr_gt(self, space, w_other):
+ if isinstance(w_other, W_BytearrayObject):
+ return space.newbool(self.data > w_other.data)
+
try:
- res = self._val(space) > self._op_val(space, w_other)
+ buffer = _get_buffer(space, w_other)
except OperationError as e:
if e.match(space, space.w_TypeError):
return space.w_NotImplemented
raise
- return space.newbool(res)
+
+ value = self._val(space)
+ buffer_len = buffer.getlength()
+
+ cmp = _memcmp(value, buffer, min(len(value), buffer_len))
+ return space.newbool(
+ cmp > 0 or (cmp == 0 and space.newbool(len(value) > buffer_len)))
def descr_ge(self, space, w_other):
+ if isinstance(w_other, W_BytearrayObject):
+ return space.newbool(self.data >= w_other.data)
+
try:
- res = self._val(space) >= self._op_val(space, w_other)
+ buffer = _get_buffer(space, w_other)
except OperationError as e:
if e.match(space, space.w_TypeError):
return space.w_NotImplemented
raise
- return space.newbool(res)
+
+ value = self._val(space)
+ buffer_len = buffer.getlength()
+
+ cmp = _memcmp(value, buffer, min(len(value), buffer_len))
+ return space.newbool(
+ cmp > 0 or (cmp == 0 and space.newbool(len(value) >= buffer_len)))
def descr_iter(self, space):
return space.newseqiter(self)
@@ -319,8 +388,11 @@
def descr_inplace_add(self, space, w_other):
if isinstance(w_other, W_BytearrayObject):
self.data += w_other.data
- else:
- self.data += self._op_val(space, w_other)
+ return self
+
+ buffer = _get_buffer(space, w_other)
+ for i in range(buffer.getlength()):
+ self.data.append(buffer.getitem(i))
return self
def descr_inplace_mul(self, space, w_times):
@@ -403,11 +475,42 @@
if space.isinstance_w(w_sub, space.w_int):
char = space.int_w(w_sub)
return _descr_contains_bytearray(self.data, space, char)
+
return self._StringMethods_descr_contains(space, w_sub)
+ def descr_add(self, space, w_other):
+ if isinstance(w_other, W_BytearrayObject):
+ return self._new(self.data + w_other.data)
+
+ try:
+ buffer = _get_buffer(space, w_other)
+ except OperationError as e:
+ if e.match(space, space.w_TypeError):
+ return space.w_NotImplemented
+ raise
+
+ buffer_len = buffer.getlength()
+ data = list(self.data + ['\0'] * buffer_len)
+ for i in range(buffer_len):
+ data[len(self.data) + i] = buffer.getitem(i)
+ return self._new(data)
+
+
def descr_reverse(self, space):
self.data.reverse()
+class BytearrayBuilder(object):
+ def __init__(self, size):
+ self.data = newlist_hint(size)
+
+ def append(self, s):
+ for i in range(len(s)):
+ self.data.append(s[i])
+
+ def build(self):
+ return self.data
+
+
# ____________________________________________________________
# helpers for slow paths, moved out because they contain loops
@@ -1152,3 +1255,13 @@
def setitem(self, index, char):
self.data[index] = char
+
+
+@specialize.argtype(0)
+def _memcmp(selfvalue, buffer, length):
+ for i in range(length):
+ if selfvalue[i] < buffer.getitem(i):
+ return -1
+ if selfvalue[i] > buffer.getitem(i):
+ return 1
+ return 0
diff --git a/pypy/objspace/std/bytesobject.py b/pypy/objspace/std/bytesobject.py
--- a/pypy/objspace/std/bytesobject.py
+++ b/pypy/objspace/std/bytesobject.py
@@ -480,6 +480,11 @@
_val = str_w
@staticmethod
+ def _use_rstr_ops(space, w_other):
+ from pypy.objspace.std.unicodeobject import W_UnicodeObject
+ return isinstance(w_other, (W_BytesObject, W_UnicodeObject))
+
+ @staticmethod
def _op_val(space, w_other):
try:
return space.str_w(w_other)
diff --git a/pypy/objspace/std/stringmethods.py b/pypy/objspace/std/stringmethods.py
--- a/pypy/objspace/std/stringmethods.py
+++ b/pypy/objspace/std/stringmethods.py
@@ -1,7 +1,7 @@
"""Functionality shared between bytes/bytearray/unicode"""
from rpython.rlib import jit
-from rpython.rlib.objectmodel import specialize
+from rpython.rlib.objectmodel import specialize, newlist_hint
from rpython.rlib.rarithmetic import ovfcheck
from rpython.rlib.rstring import endswith, replace, rsplit, split, startswith
@@ -36,17 +36,27 @@
def descr_contains(self, space, w_sub):
value = self._val(space)
- other = self._op_val(space, w_sub)
- return space.newbool(value.find(other) >= 0)
+ if self._use_rstr_ops(space, w_sub):
+ other = self._op_val(space, w_sub)
+ return space.newbool(value.find(other) >= 0)
+
+ buffer = _get_buffer(space, w_sub)
+ res = _search_slowpath(value, buffer, 0, len(value), FAST_FIND)
+ return space.newbool(res >= 0)
def descr_add(self, space, w_other):
- try:
- other = self._op_val(space, w_other)
- except OperationError as e:
- if e.match(space, space.w_TypeError):
- return space.w_NotImplemented
- raise
- return self._new(self._val(space) + other)
+ if self._use_rstr_ops(space, w_other):
+ try:
+ other = self._op_val(space, w_other)
+ except OperationError as e:
+ if e.match(space, space.w_TypeError):
+ return space.w_NotImplemented
+ raise
+ return self._new(self._val(space) + other)
+
+ # Bytearray overrides this method; CPython doesn't support concatenating
+ # buffers and strs, and unicodes are always handled above.
+ return space.w_NotImplemented
def descr_mul(self, space, w_times):
try:
@@ -128,14 +138,21 @@
def descr_count(self, space, w_sub, w_start=None, w_end=None):
value, start, end = self._convert_idx_params(space, w_start, w_end)
- return space.newint(value.count(self._op_val(space, w_sub), start,
- end))
+
+ if self._use_rstr_ops(space, w_sub):
+ return space.newint(value.count(self._op_val(space, w_sub), start,
+ end))
+
+ buffer = _get_buffer(space, w_sub)
+ res = _search_slowpath(value, buffer, start, end, FAST_COUNT)
+ return space.wrap(max(res, 0))
def descr_decode(self, space, w_encoding=None, w_errors=None):
from pypy.objspace.std.unicodeobject import (
_get_encoding_and_errors, decode_object, unicode_from_string)
encoding, errors = _get_encoding_and_errors(space, w_encoding,
w_errors)
+ # TODO: On CPython calling bytearray.decode with no arguments works.
if encoding is None and errors is None:
return unicode_from_string(space, self)
return decode_object(space, self, encoding, errors)
@@ -192,30 +209,52 @@
def descr_find(self, space, w_sub, w_start=None, w_end=None):
(value, start, end) = self._convert_idx_params(space, w_start, w_end)
- res = value.find(self._op_val(space, w_sub), start, end)
+
+ if self._use_rstr_ops(space, w_sub):
+ res = value.find(self._op_val(space, w_sub), start, end)
+ return space.wrap(res)
+
+ buffer = _get_buffer(space, w_sub)
+ res = _search_slowpath(value, buffer, start, end, FAST_FIND)
return space.wrap(res)
def descr_rfind(self, space, w_sub, w_start=None, w_end=None):
(value, start, end) = self._convert_idx_params(space, w_start, w_end)
- res = value.rfind(self._op_val(space, w_sub), start, end)
+
+ if self._use_rstr_ops(space, w_sub):
+ res = value.rfind(self._op_val(space, w_sub), start, end)
+ return space.wrap(res)
+
+ buffer = _get_buffer(space, w_sub)
+ res = _search_slowpath(value, buffer, start, end, FAST_RFIND)
return space.wrap(res)
def descr_index(self, space, w_sub, w_start=None, w_end=None):
(value, start, end) = self._convert_idx_params(space, w_start, w_end)
- res = value.find(self._op_val(space, w_sub), start, end)
+
+ if self._use_rstr_ops(space, w_sub):
+ res = value.find(self._op_val(space, w_sub), start, end)
+ else:
+ buffer = _get_buffer(space, w_sub)
+ res = _search_slowpath(value, buffer, start, end, FAST_FIND)
+
if res < 0:
raise oefmt(space.w_ValueError,
"substring not found in string.index")
-
return space.wrap(res)
def descr_rindex(self, space, w_sub, w_start=None, w_end=None):
(value, start, end) = self._convert_idx_params(space, w_start, w_end)
- res = value.rfind(self._op_val(space, w_sub), start, end)
+
+ if self._use_rstr_ops(space, w_sub):
+ res = value.rfind(self._op_val(space, w_sub), start, end)
+ else:
+ buffer = _get_buffer(space, w_sub)
+ res = _search_slowpath(value, buffer, start, end, FAST_RFIND)
+
if res < 0:
raise oefmt(space.w_ValueError,
"substring not found in string.rindex")
-
return space.wrap(res)
@specialize.arg(2)
@@ -328,6 +367,7 @@
value = self._val(space)
prealloc_size = len(value) * (size - 1)
+ unwrapped = newlist_hint(size)
for i in range(size):
w_s = list_w[i]
check_item = self._join_check_item(space, w_s)
@@ -337,13 +377,16 @@
i, w_s)
elif check_item == 2:
return self._join_autoconvert(space, list_w)
- prealloc_size += len(self._op_val(space, w_s))
+ # XXX Maybe the extra copy here is okay? It was basically going to
+ # happen anyway, what with being placed into the builder
+ unwrapped.append(self._op_val(space, w_s))
+ prealloc_size += len(unwrapped[0])
sb = self._builder(prealloc_size)
for i in range(size):
if value and i != 0:
sb.append(value)
- sb.append(self._op_val(space, list_w[i]))
+ sb.append(unwrapped[i])
return self._new(sb.build())
def _join_autoconvert(self, space, list_w):
@@ -386,10 +429,22 @@
def descr_partition(self, space, w_sub):
value = self._val(space)
- sub = self._op_val(space, w_sub)
- if not sub:
+
+ if self._use_rstr_ops(space, w_sub):
+ sub = self._op_val(space, w_sub)
+ sublen = len(sub)
+ else:
+ sub = _get_buffer(space, w_sub)
+ sublen = sub.getlength()
+
+ if sublen == 0:
raise oefmt(space.w_ValueError, "empty separator")
- pos = value.find(sub)
+
+ if self._use_rstr_ops(space, w_sub):
+ pos = value.find(sub)
+ else:
+ pos = _search_slowpath(value, sub, 0, len(value), FAST_FIND)
+
if pos == -1:
from pypy.objspace.std.bytearrayobject import W_BytearrayObject
if isinstance(self, W_BytearrayObject):
@@ -398,17 +453,29 @@
else:
from pypy.objspace.std.bytearrayobject import W_BytearrayObject
if isinstance(self, W_BytearrayObject):
- w_sub = self._new(sub)
+ w_sub = self._new_from_buffer(sub)
return space.newtuple(
[self._sliced(space, value, 0, pos, self), w_sub,
self._sliced(space, value, pos+len(sub), len(value), self)])
def descr_rpartition(self, space, w_sub):
value = self._val(space)
- sub = self._op_val(space, w_sub)
- if not sub:
+
+ if self._use_rstr_ops(space, w_sub):
+ sub = self._op_val(space, w_sub)
+ sublen = len(sub)
+ else:
+ sub = _get_buffer(space, w_sub)
+ sublen = sub.getlength()
+
+ if sublen == 0:
raise oefmt(space.w_ValueError, "empty separator")
- pos = value.rfind(sub)
+
+ if self._use_rstr_ops(space, w_sub):
+ pos = value.rfind(sub)
+ else:
+ pos = _search_slowpath(value, sub, 0, len(value), FAST_RFIND)
+
if pos == -1:
from pypy.objspace.std.bytearrayobject import W_BytearrayObject
if isinstance(self, W_BytearrayObject):
@@ -417,7 +484,7 @@
else:
from pypy.objspace.std.bytearrayobject import W_BytearrayObject
if isinstance(self, W_BytearrayObject):
- w_sub = self._new(sub)
+ w_sub = self._new_from_buffer(sub)
return space.newtuple(
[self._sliced(space, value, 0, pos, self), w_sub,
self._sliced(space, value, pos+len(sub), len(value), self)])
@@ -616,10 +683,11 @@
for char in string:
buf.append(table[ord(char)])
else:
+ # XXX Why not preallocate here too?
buf = self._builder()
deletion_table = [False] * 256
- for c in deletechars:
- deletion_table[ord(c)] = True
+ for i in range(len(deletechars)):
+ deletion_table[ord(deletechars[i])] = True
for char in string:
if not deletion_table[ord(char)]:
buf.append(table[ord(char)])
@@ -662,3 +730,118 @@
@specialize.argtype(0)
def _descr_getslice_slowpath(selfvalue, start, step, sl):
return [selfvalue[start + i*step] for i in range(sl)]
+
+def _get_buffer(space, w_obj):
+ return space.buffer_w(w_obj, space.BUF_SIMPLE)
+
+
+
+# Stolen from rpython.rtyper.lltypesystem.rstr
+# TODO: Ask about what to do with this...
+
+FAST_COUNT = 0
+FAST_FIND = 1
+FAST_RFIND = 2
+
+from rpython.rlib.rarithmetic import LONG_BIT as BLOOM_WIDTH
+
+def bloom_add(mask, c):
+ return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1)))
+
+
+def bloom(mask, c):
+ return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1)))
+
+@specialize.argtype(0, 1)
+def _search_slowpath(value, buffer, start, end, mode):
+ if start < 0:
+ start = 0
+ if end > len(value):
+ end = len(value)
+ if start > end:
+ return -1
+
+ count = 0
+ n = end - start
+ m = buffer.getlength()
+
+ if m == 0:
+ if mode == FAST_COUNT:
+ return end - start + 1
+ elif mode == FAST_RFIND:
+ return end
+ else:
+ return start
+
+ w = n - m
+
+ if w < 0:
+ return -1
+
+ mlast = m - 1
+ skip = mlast - 1
+ mask = 0
+
+ if mode != FAST_RFIND:
+ for i in range(mlast):
+ mask = bloom_add(mask, buffer.getitem(i))
+ if buffer.getitem(i) == buffer.getitem(mlast):
+ skip = mlast - i - 1
+ mask = bloom_add(mask, buffer.getitem(mlast))
+
+ i = start - 1
+ while i + 1 <= start + w:
+ i += 1
+ if value[i + m - 1] == buffer.getitem(m - 1):
+ for j in range(mlast):
+ if value[i + j] != buffer.getitem(j):
+ break
+ else:
+ if mode != FAST_COUNT:
+ return i
+ count += 1
+ i += mlast
+ continue
+
+ if i + m < len(value):
+ c = value[i + m]
+ else:
+ c = '\0'
+ if not bloom(mask, c):
+ i += m
+ else:
+ i += skip
+ else:
+ if i + m < len(value):
+ c = value[i + m]
+ else:
+ c = '\0'
+ if not bloom(mask, c):
+ i += m
+ else:
+ mask = bloom_add(mask, buffer.getitem(0))
+ for i in range(mlast, 0, -1):
+ mask = bloom_add(mask, buffer.getitem(i))
+ if buffer.getitem(i) == buffer.getitem(0):
+ skip = i - 1
+
+ i = start + w + 1
+ while i - 1 >= start:
+ i -= 1
+ if value[i] == buffer.getitem(0):
+ for j in xrange(mlast, 0, -1):
+ if value[i + j] != buffer.getitem(j):
+ break
+ else:
+ return i
+ if i - 1 >= 0 and not bloom(mask, value[i - 1]):
+ i -= m
+ else:
+ i -= skip
+ else:
+ if i - 1 >= 0 and not bloom(mask, value[i - 1]):
+ i -= m
+
+ if mode != FAST_COUNT:
+ return -1
+ return count
diff --git a/pypy/objspace/std/test/test_bytearrayobject.py b/pypy/objspace/std/test/test_bytearrayobject.py
--- a/pypy/objspace/std/test/test_bytearrayobject.py
+++ b/pypy/objspace/std/test/test_bytearrayobject.py
@@ -178,8 +178,10 @@
assert bytearray('hello').rindex('l') == 3
assert bytearray('hello').index(bytearray('e')) == 1
assert bytearray('hello').find('l') == 2
+ assert bytearray('hello').find('l', -2) == 3
assert bytearray('hello').rfind('l') == 3
+
# these checks used to not raise in pypy but they should
raises(TypeError, bytearray('hello').index, ord('e'))
raises(TypeError, bytearray('hello').rindex, ord('e'))
diff --git a/pypy/objspace/std/unicodeobject.py b/pypy/objspace/std/unicodeobject.py
--- a/pypy/objspace/std/unicodeobject.py
+++ b/pypy/objspace/std/unicodeobject.py
@@ -103,6 +103,12 @@
_val = unicode_w
@staticmethod
+ def _use_rstr_ops(space, w_other):
+ # Always return true because we always need to copy the other
+ # operand(s) before we can do comparisons
+ return True
+
+ @staticmethod
def _op_val(space, w_other):
if isinstance(w_other, W_UnicodeObject):
return w_other._value
More information about the pypy-commit
mailing list