[pypy-commit] pypy fix-bytearray-complexity: All bytearray ops work except for .replace. Some operands are still copied
waedt
noreply at buildbot.pypy.org
Mon Jun 2 19:47:04 CEST 2014
Author: Tyler Wade <wayedt at gmail.com>
Branch: fix-bytearray-complexity
Changeset: r71873:fd4cad71fdab
Date: 2014-05-26 09:04 -0500
http://bitbucket.org/pypy/pypy/changeset/fd4cad71fdab/
Log: All bytearray ops work except for .replace. Some operands are still
copied
diff --git a/pypy/objspace/std/bytearrayobject.py b/pypy/objspace/std/bytearrayobject.py
--- a/pypy/objspace/std/bytearrayobject.py
+++ b/pypy/objspace/std/bytearrayobject.py
@@ -62,6 +62,10 @@
raise oefmt(space.w_IndexError, "bytearray index out of range")
return space.wrap(ord(character))
+ def _fillchar(self, space, w_fillchar):
+ c = self._op_val(space, w_fillchar)
+ return [c], len(c)
+
def _val(self, space):
return self.data
@@ -77,6 +81,9 @@
assert len(char) == 1
return str(char)[0]
+ def _multi_chr(self, char):
+ return [self._chr(char)]
+
@staticmethod
def _builder(size=100):
return BytearrayBuilder(size)
@@ -495,7 +502,6 @@
data[len(self.data) + i] = buffer.getitem(i)
return self._new(data)
-
def descr_reverse(self, space):
self.data.reverse()
@@ -507,6 +513,13 @@
for i in range(len(s)):
self.data.append(s[i])
+ def append_multiple_char(self, c, count):
+ self.data.extend([c] * count)
+
+ def append_slice(self, value, start, end):
+ for i in range(start, end):
+ self.data.append(value[i])
+
def build(self):
return self.data
diff --git a/pypy/objspace/std/stringmethods.py b/pypy/objspace/std/stringmethods.py
--- a/pypy/objspace/std/stringmethods.py
+++ b/pypy/objspace/std/stringmethods.py
@@ -3,7 +3,9 @@
from rpython.rlib import jit
from rpython.rlib.objectmodel import specialize, newlist_hint
from rpython.rlib.rarithmetic import ovfcheck
-from rpython.rlib.rstring import endswith, replace, rsplit, split, startswith
+from rpython.rlib.rstring import (
+ search, SEARCH_FIND, SEARCH_RFIND, SEARCH_COUNT, endswith, replace, rsplit,
+ split, startswith)
from pypy.interpreter.error import OperationError, oefmt
from pypy.interpreter.gateway import WrappedDefault, unwrap_spec
@@ -28,6 +30,9 @@
space, lenself, w_start, w_end, upper_bound=upper_bound)
return (value, start, end)
+ def _multi_chr(self, c):
+ return self._chr(c)
+
def descr_len(self, space):
return space.wrap(self._len())
@@ -41,7 +46,7 @@
return space.newbool(value.find(other) >= 0)
buffer = _get_buffer(space, w_sub)
- res = _search_slowpath(value, buffer, 0, len(value), FAST_FIND)
+ res = search(value, buffer, 0, len(value), SEARCH_FIND)
return space.newbool(res >= 0)
def descr_add(self, space, w_other):
@@ -129,7 +134,7 @@
d = width - len(value)
if d > 0:
offset = d//2 + (d & width & 1)
- fillchar = fillchar[0] # annotator hint: it's a single character
+ fillchar = self._multi_chr(fillchar)
centered = offset * fillchar + value + (d - offset) * fillchar
else:
centered = value
@@ -144,7 +149,7 @@
end))
buffer = _get_buffer(space, w_sub)
- res = _search_slowpath(value, buffer, start, end, FAST_COUNT)
+ res = search(value, buffer, start, end, SEARCH_COUNT)
return space.wrap(max(res, 0))
def descr_decode(self, space, w_encoding=None, w_errors=None):
@@ -152,7 +157,6 @@
_get_encoding_and_errors, decode_object, unicode_from_string)
encoding, errors = _get_encoding_and_errors(space, w_encoding,
w_errors)
- # TODO: On CPython calling bytearray.decode with no arguments works.
if encoding is None and errors is None:
return unicode_from_string(space, self)
return decode_object(space, self, encoding, errors)
@@ -170,7 +174,11 @@
if not value:
return self._empty()
- splitted = value.split(self._chr('\t'))
+ if self._use_rstr_ops(value, self):
+ splitted = value.split(self._chr('\t'))
+ else:
+ splitted = split(value, self._chr('\t'))
+
try:
ovfcheck(len(splitted) * tabsize)
except OverflowError:
@@ -178,7 +186,7 @@
expanded = oldtoken = splitted.pop(0)
for token in splitted:
- expanded += self._chr(' ') * self._tabindent(oldtoken,
+ expanded += self._multi_chr(' ') * self._tabindent(oldtoken,
tabsize) + token
oldtoken = token
@@ -215,7 +223,7 @@
return space.wrap(res)
buffer = _get_buffer(space, w_sub)
- res = _search_slowpath(value, buffer, start, end, FAST_FIND)
+ res = search(value, buffer, start, end, SEARCH_FIND)
return space.wrap(res)
def descr_rfind(self, space, w_sub, w_start=None, w_end=None):
@@ -226,7 +234,7 @@
return space.wrap(res)
buffer = _get_buffer(space, w_sub)
- res = _search_slowpath(value, buffer, start, end, FAST_RFIND)
+ res = search(value, buffer, start, end, SEARCH_RFIND)
return space.wrap(res)
def descr_index(self, space, w_sub, w_start=None, w_end=None):
@@ -236,7 +244,7 @@
res = value.find(self._op_val(space, w_sub), start, end)
else:
buffer = _get_buffer(space, w_sub)
- res = _search_slowpath(value, buffer, start, end, FAST_FIND)
+ res = search(value, buffer, start, end, SEARCH_FIND)
if res < 0:
raise oefmt(space.w_ValueError,
@@ -250,7 +258,7 @@
res = value.rfind(self._op_val(space, w_sub), start, end)
else:
buffer = _get_buffer(space, w_sub)
- res = _search_slowpath(value, buffer, start, end, FAST_RFIND)
+ res = search(value, buffer, start, end, SEARCH_RFIND)
if res < 0:
raise oefmt(space.w_ValueError,
@@ -401,7 +409,7 @@
"ljust() argument 2 must be a single character")
d = width - len(value)
if d > 0:
- fillchar = fillchar[0] # annotator hint: it's a single character
+ fillchar = self._multi_chr(fillchar)
value += d * fillchar
return self._new(value)
@@ -415,7 +423,7 @@
"rjust() argument 2 must be a single character")
d = width - len(value)
if d > 0:
- fillchar = fillchar[0] # annotator hint: it's a single character
+ fillchar = self._multi_chr(fillchar)
value = d * fillchar + value
return self._new(value)
@@ -443,7 +451,7 @@
if self._use_rstr_ops(space, w_sub):
pos = value.find(sub)
else:
- pos = _search_slowpath(value, sub, 0, len(value), FAST_FIND)
+ pos = search(value, sub, 0, len(value), SEARCH_FIND)
if pos == -1:
from pypy.objspace.std.bytearrayobject import W_BytearrayObject
@@ -456,7 +464,7 @@
w_sub = self._new_from_buffer(sub)
return space.newtuple(
[self._sliced(space, value, 0, pos, self), w_sub,
- self._sliced(space, value, pos+len(sub), len(value), self)])
+ self._sliced(space, value, pos + sublen, len(value), self)])
def descr_rpartition(self, space, w_sub):
value = self._val(space)
@@ -474,7 +482,7 @@
if self._use_rstr_ops(space, w_sub):
pos = value.rfind(sub)
else:
- pos = _search_slowpath(value, sub, 0, len(value), FAST_RFIND)
+ pos = search(value, sub, 0, len(value), SEARCH_RFIND)
if pos == -1:
from pypy.objspace.std.bytearrayobject import W_BytearrayObject
@@ -487,7 +495,7 @@
w_sub = self._new_from_buffer(sub)
return space.newtuple(
[self._sliced(space, value, 0, pos, self), w_sub,
- self._sliced(space, value, pos+len(sub), len(value), self)])
+ self._sliced(space, value, pos + sublen, len(value), self)])
@unwrap_spec(count=int)
def descr_replace(self, space, w_old, w_new, count=-1):
@@ -735,113 +743,3 @@
return space.buffer_w(w_obj, space.BUF_SIMPLE)
-
-# Stolen form rpython.rtyper.lltypesytem.rstr
-# TODO: Ask about what to do with this...
-
-FAST_COUNT = 0
-FAST_FIND = 1
-FAST_RFIND = 2
-
-from rpython.rlib.rarithmetic import LONG_BIT as BLOOM_WIDTH
-
-def bloom_add(mask, c):
- return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1)))
-
-
-def bloom(mask, c):
- return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1)))
-
-@specialize.argtype(0, 1)
-def _search_slowpath(value, buffer, start, end, mode):
- if start < 0:
- start = 0
- if end > len(value):
- end = len(value)
- if start > end:
- return -1
-
- count = 0
- n = end - start
- m = buffer.getlength()
-
- if m == 0:
- if mode == FAST_COUNT:
- return end - start + 1
- elif mode == FAST_RFIND:
- return end
- else:
- return start
-
- w = n - m
-
- if w < 0:
- return -1
-
- mlast = m - 1
- skip = mlast - 1
- mask = 0
-
- if mode != FAST_RFIND:
- for i in range(mlast):
- mask = bloom_add(mask, buffer.getitem(i))
- if buffer.getitem(i) == buffer.getitem(mlast):
- skip = mlast - i - 1
- mask = bloom_add(mask, buffer.getitem(mlast))
-
- i = start - 1
- while i + 1 <= start + w:
- i += 1
- if value[i + m - 1] == buffer.getitem(m - 1):
- for j in range(mlast):
- if value[i + j] != buffer.getitem(j):
- break
- else:
- if mode != FAST_COUNT:
- return i
- count += 1
- i += mlast
- continue
-
- if i + m < len(value):
- c = value[i + m]
- else:
- c = '\0'
- if not bloom(mask, c):
- i += m
- else:
- i += skip
- else:
- if i + m < len(value):
- c = value[i + m]
- else:
- c = '\0'
- if not bloom(mask, c):
- i += m
- else:
- mask = bloom_add(mask, buffer.getitem(0))
- for i in range(mlast, 0, -1):
- mask = bloom_add(mask, buffer.getitem(i))
- if buffer.getitem(i) == buffer.getitem(0):
- skip = i - 1
-
- i = start + w + 1
- while i - 1 >= start:
- i -= 1
- if value[i] == buffer.getitem(0):
- for j in xrange(mlast, 0, -1):
- if value[i + j] != buffer.getitem(j):
- break
- else:
- return i
- if i - 1 >= 0 and not bloom(mask, value[i - 1]):
- i -= m
- else:
- i -= skip
- else:
- if i - 1 >= 0 and not bloom(mask, value[i - 1]):
- i -= m
-
- if mode != FAST_COUNT:
- return -1
- return count
diff --git a/pypy/objspace/std/test/test_bytearrayobject.py b/pypy/objspace/std/test/test_bytearrayobject.py
--- a/pypy/objspace/std/test/test_bytearrayobject.py
+++ b/pypy/objspace/std/test/test_bytearrayobject.py
@@ -442,6 +442,7 @@
u = b.decode('utf-8')
assert isinstance(u, unicode)
assert u == u'abcdefghi'
+ assert b.decode()
def test_int(self):
assert int(bytearray('-1234')) == -1234
diff --git a/rpython/rlib/rstring.py b/rpython/rlib/rstring.py
--- a/rpython/rlib/rstring.py
+++ b/rpython/rlib/rstring.py
@@ -7,7 +7,8 @@
from rpython.rtyper.llannotation import SomePtr
from rpython.rlib import jit
from rpython.rlib.objectmodel import newlist_hint, specialize
-from rpython.rlib.rarithmetic import ovfcheck
+from rpython.rlib.rarithmetic import ovfcheck, LONG_BIT as BLOOM_WIDTH
+from rpython.rlib.buffer import Buffer
from rpython.rlib.unicodedata import unicodedb_5_2_0 as unicodedb
from rpython.rtyper.extregistry import ExtRegistryEntry
from rpython.tool.pairtype import pairtype
@@ -15,6 +16,35 @@
# -------------- public API for string functions -----------------------
+@specialize.argtype(0, 1)
+def _get_access_functions(value, other):
+ if isinstance(other, (str, unicode, list)):
+ def getitem(obj, i):
+ return obj[i]
+ def getlength(obj):
+ return len(obj)
+ else:
+ assert isinstance(other, Buffer)
+ def getitem(obj, i):
+ return obj.getitem(i)
+ def getlength(obj):
+ return obj.getlength()
+
+ if isinstance(value, list) or isinstance(other, Buffer):
+ def find(obj, other, start, end):
+ return search(obj, other, start, end, SEARCH_FIND)
+ def rfind(obj, other, start, end):
+ return search(obj, other, start, end, SEARCH_RFIND)
+ else:
+ assert isinstance(value, (str, unicode))
+ assert isinstance(other, (str, unicode))
+ def find(obj, other, start, end):
+ return obj.find(other, start, end)
+ def rfind(obj, other, start, end):
+ return obj.rfind(other, start, end)
+
+ return getitem, getlength, find, rfind
+
@specialize.argtype(0)
def _isspace(char):
if isinstance(char, str):
@@ -55,10 +85,11 @@
i = j + 1
return res
- if isinstance(value, str):
+ if isinstance(value, (list, str)):
assert isinstance(by, str)
else:
assert isinstance(by, unicode)
+ _, _, find, _ = _get_access_functions(value, by)
bylen = len(by)
if bylen == 0:
raise ValueError("empty separator")
@@ -72,7 +103,7 @@
count = maxsplit
res = newlist_hint(count + 1)
while count > 0:
- next = value.find(by, start)
+ next = find(value, by, start, len(value))
assert next >= 0 # cannot fail due to the value.count above
res.append(value[start:next])
start = next + bylen
@@ -86,7 +117,7 @@
res = []
while maxsplit != 0:
- next = value.find(by, start)
+ next = find(value, by, start, len(value))
if next < 0:
break
res.append(value[start:next])
@@ -133,7 +164,7 @@
res.reverse()
return res
- if isinstance(value, str):
+ if isinstance(value, (list, str)):
assert isinstance(by, str)
else:
assert isinstance(by, unicode)
@@ -141,13 +172,14 @@
res = newlist_hint(min(maxsplit + 1, len(value)))
else:
res = []
+ _, _, _, rfind = _get_access_functions(value, by)
end = len(value)
bylen = len(by)
if bylen == 0:
raise ValueError("empty separator")
while maxsplit != 0:
- next = value.rfind(by, 0, end)
+ next = rfind(value, by, 0, end)
if next < 0:
break
res.append(value[next + bylen:end])
@@ -166,13 +198,20 @@
assert isinstance(sub, str)
assert isinstance(by, str)
Builder = StringBuilder
- else:
+ elif isinstance(input, unicode):
assert isinstance(sub, unicode)
assert isinstance(by, unicode)
Builder = UnicodeBuilder
+ elif isinstance(input, list):
+ assert isinstance(sub, str)
+ assert isinstance(by, str)
+ # TODO: ????
+ Builder = StringBuilder
if maxsplit == 0:
return input
+ _, _, find, _ = _get_access_functions(input, sub)
+
if not sub:
upper = len(input)
if maxsplit > 0 and maxsplit < upper + 2:
@@ -210,7 +249,7 @@
sublen = len(sub)
while maxsplit != 0:
- next = input.find(sub, start)
+ next = find(input, sub, start, len(input))
if next < 0:
break
builder.append_slice(input, start, next)
@@ -261,6 +300,114 @@
return False
return True
+# Stolen from rpython.rtyper.lltypesystem.rstr
+# TODO: Ask about what to do with this...
+
+SEARCH_COUNT = 0
+SEARCH_FIND = 1
+SEARCH_RFIND = 2
+
+def bloom_add(mask, c):
+ return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1)))
+
+def bloom(mask, c):
+ return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1)))
+
+@specialize.argtype(0, 1)
+def search(value, other, start, end, mode):
+ getitem, getlength, _, _ = _get_access_functions(value, other)
+ if start < 0:
+ start = 0
+ if end > len(value):
+ end = len(value)
+ if start > end:
+ return -1
+
+ count = 0
+ n = end - start
+ m = getlength(other)
+
+ if m == 0:
+ if mode == SEARCH_COUNT:
+ return end - start + 1
+ elif mode == SEARCH_RFIND:
+ return end
+ else:
+ return start
+
+ w = n - m
+
+ if w < 0:
+ return -1
+
+ mlast = m - 1
+ skip = mlast - 1
+ mask = 0
+
+ if mode != SEARCH_RFIND:
+ for i in range(mlast):
+ mask = bloom_add(mask, getitem(other, i))
+ if getitem(other, i) == getitem(other, mlast):
+ skip = mlast - i - 1
+ mask = bloom_add(mask, getitem(other, mlast))
+
+ i = start - 1
+ while i + 1 <= start + w:
+ i += 1
+ if value[i + m - 1] == getitem(other, m - 1):
+ for j in range(mlast):
+ if value[i + j] != getitem(other, j):
+ break
+ else:
+ if mode != SEARCH_COUNT:
+ return i
+ count += 1
+ i += mlast
+ continue
+
+ if i + m < len(value):
+ c = value[i + m]
+ else:
+ c = '\0'
+ if not bloom(mask, c):
+ i += m
+ else:
+ i += skip
+ else:
+ if i + m < len(value):
+ c = value[i + m]
+ else:
+ c = '\0'
+ if not bloom(mask, c):
+ i += m
+ else:
+ mask = bloom_add(mask, getitem(other, 0))
+ for i in range(mlast, 0, -1):
+ mask = bloom_add(mask, getitem(other, i))
+ if getitem(other, i) == getitem(other, 0):
+ skip = i - 1
+
+ i = start + w + 1
+ while i - 1 >= start:
+ i -= 1
+ if value[i] == getitem(other, 0):
+ for j in xrange(mlast, 0, -1):
+ if value[i + j] != getitem(other, j):
+ break
+ else:
+ return i
+ if i - 1 >= 0 and not bloom(mask, value[i - 1]):
+ i -= m
+ else:
+ i -= skip
+ else:
+ if i - 1 >= 0 and not bloom(mask, value[i - 1]):
+ i -= m
+
+ if mode != SEARCH_COUNT:
+ return -1
+ return count
+
# -------------- numeric parsing support --------------------
def strip_spaces(s):
More information about the pypy-commit
mailing list