[pypy-commit] pypy unicode-utf8: Rewrite unicode.index/find to use rutf8.codepoints_in_utf8(). It should
arigo
pypy.commits at gmail.com
Sat Oct 14 05:43:30 EDT 2017
Author: Armin Rigo <arigo at tunes.org>
Branch: unicode-utf8
Changeset: r92752:6bc169e4d079
Date: 2017-10-14 10:21 +0200
http://bitbucket.org/pypy/pypy/changeset/6bc169e4d079/
Log: Rewrite unicode.index/find to use rutf8.codepoints_in_utf8(). It
should ensure that the complexity of these methods is always
correct. Previously it was O(result) instead of O(chars-searched),
which can get very bad if .find() is called with a non-zero starting
position.
diff --git a/pypy/objspace/std/test/test_unicodeobject.py b/pypy/objspace/std/test/test_unicodeobject.py
--- a/pypy/objspace/std/test/test_unicodeobject.py
+++ b/pypy/objspace/std/test/test_unicodeobject.py
@@ -2,6 +2,8 @@
# -*- encoding: utf-8 -*-
import py
import sys
+from hypothesis import given, strategies, settings, example
+from pypy.interpreter.error import OperationError
class TestUnicodeObject:
@@ -35,6 +37,49 @@
space.w_unicode, "__new__", space.w_unicode, w_uni)
assert w_new is w_uni
+ @given(strategies.text(), strategies.integers(min_value=0, max_value=10),
+ strategies.integers(min_value=-1, max_value=10))
+ def test_hypo_index_find(self, u, start, len1):
+ if start + len1 < 0:
+ return # skip this case
+ v = u[start : start + len1]
+ space = self.space
+ w_u = space.wrap(u)
+ w_v = space.wrap(v)
+ expected = u.find(v, start, start + len1)
+ try:
+ w_index = space.call_method(w_u, 'index', w_v,
+ space.newint(start),
+ space.newint(start + len1))
+ except OperationError as e:
+ if not e.match(space, space.w_ValueError):
+ raise
+ assert expected == -1
+ else:
+ assert space.int_w(w_index) == expected >= 0
+
+ w_index = space.call_method(w_u, 'find', w_v,
+ space.newint(start),
+ space.newint(start + len1))
+ assert space.int_w(w_index) == expected
+
+ rexpected = u.rfind(v, start, start + len1)
+ try:
+ w_index = space.call_method(w_u, 'rindex', w_v,
+ space.newint(start),
+ space.newint(start + len1))
+ except OperationError as e:
+ if not e.match(space, space.w_ValueError):
+ raise
+ assert rexpected == -1
+ else:
+ assert space.int_w(w_index) == rexpected >= 0
+
+ w_index = space.call_method(w_u, 'rfind', w_v,
+ space.newint(start),
+ space.newint(start + len1))
+ assert space.int_w(w_index) == rexpected
+
class AppTestUnicodeStringStdOnly:
def test_compares(self):
@@ -698,6 +743,7 @@
def test_index(self):
assert u"rrarrrrrrrrra".index(u'a', 4, None) == 12
assert u"rrarrrrrrrrra".index(u'a', None, 6) == 2
+ assert u"\u1234\u4321\u5678".index(u'\u5678', 1) == 2
def test_rindex(self):
from sys import maxint
@@ -707,6 +753,7 @@
assert u'abcdefghiabc'.rindex(u'abc', 0, -1) == 0
assert u'abcdefghiabc'.rindex(u'abc', -4*maxint, 4*maxint) == 9
assert u'rrarrrrrrrrra'.rindex(u'a', 4, None) == 12
+ assert u"\u1234\u5678".rindex(u'\u5678') == 1
raises(ValueError, u'abcdefghiabc'.rindex, u'hib')
raises(ValueError, u'defghiabc'.rindex, u'def', 1)
@@ -721,6 +768,7 @@
assert u'abcdefghiabc'.rfind(u'') == 12
assert u'abcdefghiabc'.rfind(u'abcd') == 0
assert u'abcdefghiabc'.rfind(u'abcz') == -1
+ assert u"\u1234\u5678".rfind(u'\u5678') == 1
def test_rfind_corner_case(self):
assert u'abc'.rfind('', 4) == -1
@@ -736,6 +784,7 @@
assert 'abcdefghiabc'.rindex(u'abc') == 9
raises(UnicodeDecodeError, '\x80'.index, u'')
raises(UnicodeDecodeError, '\x80'.rindex, u'')
+ assert u"\u1234\u5678".find(u'\u5678') == 1
def test_count(self):
assert u"".count(u"x") ==0
diff --git a/pypy/objspace/std/unicodeobject.py b/pypy/objspace/std/unicodeobject.py
--- a/pypy/objspace/std/unicodeobject.py
+++ b/pypy/objspace/std/unicodeobject.py
@@ -1,7 +1,7 @@
"""The builtin unicode implementation"""
from rpython.rlib.objectmodel import (
- compute_hash, compute_unique_id, import_from_mixin,
+ compute_hash, compute_unique_id, import_from_mixin, always_inline,
enforceargs, newlist_hint, specialize, we_are_translated)
from rpython.rlib.buffer import StringBuffer
from rpython.rlib.mutbuffer import MutableStringBuffer
@@ -427,58 +427,32 @@
return W_UnicodeObject(result.build(), result_length)
def descr_find(self, space, w_sub, w_start=None, w_end=None):
- w_sub = self.convert_arg_to_w_unicode(space, w_sub)
- start_index, end_index = self._unwrap_and_compute_idx_params(
- space, w_start, w_end)
-
- res_index = self._utf8.find(w_sub._utf8, start_index, end_index)
- if res_index == -1:
- return space.newint(-1)
-
- res = rutf8.check_utf8(self._utf8, allow_surrogates=True,
- force_len=res_index) # can't raise
- return space.newint(res)
+ w_result = self._unwrap_and_search(space, w_sub, w_start, w_end)
+ if w_result is None:
+ w_result = space.newint(-1)
+ return w_result
def descr_rfind(self, space, w_sub, w_start=None, w_end=None):
- w_sub = self.convert_arg_to_w_unicode(space, w_sub)
- start_index, end_index = self._unwrap_and_compute_idx_params(
- space, w_start, w_end)
-
- res_index = self._utf8.rfind(w_sub._utf8, start_index, end_index)
- if res_index == -1:
- return space.newint(-1)
-
- res = rutf8.check_utf8(self._utf8, allow_surrogates=True,
- force_len=res_index) # can't raise
- return space.newint(res)
+ w_result = self._unwrap_and_search(space, w_sub, w_start, w_end,
+ forward=False)
+ if w_result is None:
+ w_result = space.newint(-1)
+ return w_result
def descr_index(self, space, w_sub, w_start=None, w_end=None):
- w_sub = self.convert_arg_to_w_unicode(space, w_sub)
- start_index, end_index = self._unwrap_and_compute_idx_params(
- space, w_start, w_end)
-
- res_index = self._utf8.find(w_sub._utf8, start_index, end_index)
- if res_index == -1:
+ w_result = self._unwrap_and_search(space, w_sub, w_start, w_end)
+ if w_result is None:
raise oefmt(space.w_ValueError,
"substring not found in string.index")
-
- res = rutf8.check_utf8(self._utf8, allow_surrogates=True,
- force_len=res_index) # can't raise
- return space.newint(res)
+ return w_result
def descr_rindex(self, space, w_sub, w_start=None, w_end=None):
- w_sub = self.convert_arg_to_w_unicode(space, w_sub)
- start_index, end_index = self._unwrap_and_compute_idx_params(
- space, w_start, w_end)
-
- res_index = self._utf8.rfind(w_sub._utf8, start_index, end_index)
- if res_index == -1:
+ w_result = self._unwrap_and_search(space, w_sub, w_start, w_end,
+ forward=False)
+ if w_result is None:
raise oefmt(space.w_ValueError,
"substring not found in string.rindex")
-
- res = rutf8.check_utf8(self._utf8, allow_surrogates=True,
- force_len=res_index) # can't raise
- return space.newint(res)
+ return w_result
@specialize.arg(2)
def _is_generic(self, space, func_name):
@@ -908,11 +882,48 @@
index += self._length
if index < 0 or index >= self._length:
raise oefmt(space.w_IndexError, "string index out of range")
- storage = self._get_index_storage()
- start = rutf8.codepoint_position_at_index(self._utf8, storage, index)
+ start = self._index_to_byte(index)
end = rutf8.next_codepoint_pos(self._utf8, start)
return W_UnicodeObject(self._utf8[start:end], 1)
+ def _index_to_byte(self, index):
+ return rutf8.codepoint_position_at_index(
+ self._utf8, self._get_index_storage(), index)
+
+ @always_inline
+ def _unwrap_and_search(self, space, w_sub, w_start, w_end, forward=True):
+ w_sub = self.convert_arg_to_w_unicode(space, w_sub)
+ start, end = unwrap_start_stop(space, self._length, w_start, w_end)
+ if start == 0:
+ start_index = 0
+ elif start > self._length:
+ return None
+ else:
+ start_index = self._index_to_byte(start)
+
+ if end >= self._length:
+ end = self._length
+ end_index = len(self._utf8)
+ else:
+ end_index = self._index_to_byte(end)
+
+ if forward:
+ res_index = self._utf8.find(w_sub._utf8, start_index, end_index)
+ if res_index < 0:
+ return None
+ skip = rutf8.codepoints_in_utf8(self._utf8, start_index, res_index)
+ res = start + skip
+ assert res >= 0
+ return space.newint(res)
+ else:
+ res_index = self._utf8.rfind(w_sub._utf8, start_index, end_index)
+ if res_index < 0:
+ return None
+ skip = rutf8.codepoints_in_utf8(self._utf8, res_index, end_index)
+ res = end - skip
+ assert res >= 0
+ return space.newint(res)
+
def _unwrap_and_compute_idx_params(self, space, w_start, w_end):
start, end = unwrap_start_stop(space, self._length, w_start, w_end)
# XXX for now just create index
diff --git a/rpython/rlib/rutf8.py b/rpython/rlib/rutf8.py
--- a/rpython/rlib/rutf8.py
+++ b/rpython/rlib/rutf8.py
@@ -15,11 +15,12 @@
extra code in the middle for error handlers and so on.
"""
+import sys
from rpython.rlib.objectmodel import enforceargs
from rpython.rlib.rstring import StringBuilder
from rpython.rlib import jit
from rpython.rlib.rarithmetic import r_uint, intmask
-from rpython.rtyper.lltypesystem import lltype
+from rpython.rtyper.lltypesystem import lltype, rffi
def unichr_as_utf8(code, allow_surrogates=False):
@@ -290,17 +291,16 @@
#@jit.elidable
-def check_utf8(s, allow_surrogates, force_len=-1):
+def check_utf8(s, allow_surrogates):
"""Check that 's' is a utf-8-encoded byte string.
Returns the length (number of chars) or raise CheckError.
- Note that surrogates are not handled specially here.
+ If allow_surrogates is False, then also raise if we see any.
+ Note also codepoints_in_utf8(), which also computes the length
+ faster by assuming that 's' is valid utf-8.
"""
pos = 0
continuation_bytes = 0
- if force_len == -1:
- end = len(s)
- else:
- end = force_len
+ end = len(s)
while pos < end:
ordch1 = ord(s[pos])
pos += 1
@@ -359,6 +359,23 @@
return pos - continuation_bytes
@jit.elidable
+def codepoints_in_utf8(value, start=0, end=sys.maxint):
+ """Return the number of codepoints in the UTF-8 byte string
+ 'value[start:end]'. Assumes 0 <= start <= len(value) and start <= end.
+ """
+ if end > len(value):
+ end = len(value)
+ assert 0 <= start <= end
+ length = 0
+ for i in range(start, end):
+ # we want to count the number of chars not between 0x80 and 0xBF;
+ # we do that by casting the char to a signed integer
+ signedchar = rffi.cast(rffi.SIGNEDCHAR, ord(value[i]))
+ if rffi.cast(lltype.Signed, signedchar) >= -0x40:
+ length += 1
+ return length
+
+@jit.elidable
def surrogate_in_utf8(value):
"""Check if the UTF-8 byte string 'value' contains a surrogate.
The 'value' argument must be otherwise correctly formed for UTF-8.
diff --git a/rpython/rlib/test/test_rutf8.py b/rpython/rlib/test/test_rutf8.py
--- a/rpython/rlib/test/test_rutf8.py
+++ b/rpython/rlib/test/test_rutf8.py
@@ -85,6 +85,19 @@
r = (ch in txt)
assert r == response
+@given(strategies.text(), strategies.integers(min_value=0),
+ strategies.integers(min_value=0))
+def test_codepoints_in_utf8(u, start, len1):
+ end = start + len1
+ if end > len(u):
+ extra = end - len(u)
+ else:
+ extra = 0
+ count = rutf8.codepoints_in_utf8(u.encode('utf8'),
+ len(u[:start].encode('utf8')),
+ len(u[:end].encode('utf8')) + extra)
+ assert count == len(u[start:end])
+
@given(strategies.text())
def test_utf8_index_storage(u):
index = rutf8.create_utf8_index_storage(u.encode('utf8'), len(u))
More information about the pypy-commit
mailing list