[pypy-commit] pypy unicode-utf8: Rewrite unicode.index/find to use rutf8.codepoints_in_utf8(). It should ensure these methods always have the correct complexity.

arigo pypy.commits at gmail.com
Sat Oct 14 05:43:30 EDT 2017


Author: Armin Rigo <arigo at tunes.org>
Branch: unicode-utf8
Changeset: r92752:6bc169e4d079
Date: 2017-10-14 10:21 +0200
http://bitbucket.org/pypy/pypy/changeset/6bc169e4d079/

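Log:	Rewrite unicode.index/find/rindex/rfind to use
	rutf8.codepoints_in_utf8(). This should ensure that the complexity
	of these methods is always correct. Previously, converting the byte
	offset of a match back into a codepoint index rescanned the UTF-8
	string from its very beginning, so the cost was O(position of the
	result) instead of O(chars searched); this can get very bad if
	.find() is called with a non-zero starting position.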

diff --git a/pypy/objspace/std/test/test_unicodeobject.py b/pypy/objspace/std/test/test_unicodeobject.py
--- a/pypy/objspace/std/test/test_unicodeobject.py
+++ b/pypy/objspace/std/test/test_unicodeobject.py
@@ -2,6 +2,8 @@
 # -*- encoding: utf-8 -*-
 import py
 import sys
+from hypothesis import given, strategies, settings, example
+from pypy.interpreter.error import OperationError
 
 
 class TestUnicodeObject:
@@ -35,6 +37,49 @@
                 space.w_unicode, "__new__", space.w_unicode, w_uni)
         assert w_new is w_uni
 
+    @given(strategies.text(), strategies.integers(min_value=0, max_value=10),
+                              strategies.integers(min_value=-1, max_value=10))
+    def test_hypo_index_find(self, u, start, len1):
+        if start + len1 < 0:
+            return   # skip this case
+        v = u[start : start + len1]
+        space = self.space
+        w_u = space.wrap(u)
+        w_v = space.wrap(v)
+        expected = u.find(v, start, start + len1)
+        try:
+            w_index = space.call_method(w_u, 'index', w_v,
+                                        space.newint(start),
+                                        space.newint(start + len1))
+        except OperationError as e:
+            if not e.match(space, space.w_ValueError):
+                raise
+            assert expected == -1
+        else:
+            assert space.int_w(w_index) == expected >= 0
+
+        w_index = space.call_method(w_u, 'find', w_v,
+                                    space.newint(start),
+                                    space.newint(start + len1))
+        assert space.int_w(w_index) == expected
+
+        rexpected = u.rfind(v, start, start + len1)
+        try:
+            w_index = space.call_method(w_u, 'rindex', w_v,
+                                        space.newint(start),
+                                        space.newint(start + len1))
+        except OperationError as e:
+            if not e.match(space, space.w_ValueError):
+                raise
+            assert rexpected == -1
+        else:
+            assert space.int_w(w_index) == rexpected >= 0
+
+        w_index = space.call_method(w_u, 'rfind', w_v,
+                                    space.newint(start),
+                                    space.newint(start + len1))
+        assert space.int_w(w_index) == rexpected
+
 
 class AppTestUnicodeStringStdOnly:
     def test_compares(self):
@@ -698,6 +743,7 @@
     def test_index(self):
         assert u"rrarrrrrrrrra".index(u'a', 4, None) == 12
         assert u"rrarrrrrrrrra".index(u'a', None, 6) == 2
+        assert u"\u1234\u4321\u5678".index(u'\u5678', 1) == 2
 
     def test_rindex(self):
         from sys import maxint
@@ -707,6 +753,7 @@
         assert u'abcdefghiabc'.rindex(u'abc', 0, -1) == 0
         assert u'abcdefghiabc'.rindex(u'abc', -4*maxint, 4*maxint) == 9
         assert u'rrarrrrrrrrra'.rindex(u'a', 4, None) == 12
+        assert u"\u1234\u5678".rindex(u'\u5678') == 1
 
         raises(ValueError, u'abcdefghiabc'.rindex, u'hib')
         raises(ValueError, u'defghiabc'.rindex, u'def', 1)
@@ -721,6 +768,7 @@
         assert u'abcdefghiabc'.rfind(u'') == 12
         assert u'abcdefghiabc'.rfind(u'abcd') == 0
         assert u'abcdefghiabc'.rfind(u'abcz') == -1
+        assert u"\u1234\u5678".rfind(u'\u5678') == 1
 
     def test_rfind_corner_case(self):
         assert u'abc'.rfind('', 4) == -1
@@ -736,6 +784,7 @@
         assert 'abcdefghiabc'.rindex(u'abc') == 9
         raises(UnicodeDecodeError, '\x80'.index, u'')
         raises(UnicodeDecodeError, '\x80'.rindex, u'')
+        assert u"\u1234\u5678".find(u'\u5678') == 1
 
     def test_count(self):
         assert u"".count(u"x") ==0
diff --git a/pypy/objspace/std/unicodeobject.py b/pypy/objspace/std/unicodeobject.py
--- a/pypy/objspace/std/unicodeobject.py
+++ b/pypy/objspace/std/unicodeobject.py
@@ -1,7 +1,7 @@
 """The builtin unicode implementation"""
 
 from rpython.rlib.objectmodel import (
-    compute_hash, compute_unique_id, import_from_mixin,
+    compute_hash, compute_unique_id, import_from_mixin, always_inline,
     enforceargs, newlist_hint, specialize, we_are_translated)
 from rpython.rlib.buffer import StringBuffer
 from rpython.rlib.mutbuffer import MutableStringBuffer
@@ -427,58 +427,32 @@
         return W_UnicodeObject(result.build(), result_length)
 
     def descr_find(self, space, w_sub, w_start=None, w_end=None):
-        w_sub = self.convert_arg_to_w_unicode(space, w_sub)
-        start_index, end_index = self._unwrap_and_compute_idx_params(
-            space, w_start, w_end)
-
-        res_index = self._utf8.find(w_sub._utf8, start_index, end_index)
-        if res_index == -1:
-            return space.newint(-1)
-
-        res = rutf8.check_utf8(self._utf8, allow_surrogates=True,
-                               force_len=res_index) # can't raise
-        return space.newint(res)
+        w_result = self._unwrap_and_search(space, w_sub, w_start, w_end)
+        if w_result is None:
+            w_result = space.newint(-1)
+        return w_result
 
     def descr_rfind(self, space, w_sub, w_start=None, w_end=None):
-        w_sub = self.convert_arg_to_w_unicode(space, w_sub)
-        start_index, end_index = self._unwrap_and_compute_idx_params(
-            space, w_start, w_end)
-
-        res_index = self._utf8.rfind(w_sub._utf8, start_index, end_index)
-        if res_index == -1:
-            return space.newint(-1)
-
-        res = rutf8.check_utf8(self._utf8, allow_surrogates=True,
-                               force_len=res_index) # can't raise
-        return space.newint(res)
+        w_result = self._unwrap_and_search(space, w_sub, w_start, w_end,
+                                           forward=False)
+        if w_result is None:
+            w_result = space.newint(-1)
+        return w_result
 
     def descr_index(self, space, w_sub, w_start=None, w_end=None):
-        w_sub = self.convert_arg_to_w_unicode(space, w_sub)
-        start_index, end_index = self._unwrap_and_compute_idx_params(
-            space, w_start, w_end)
-
-        res_index = self._utf8.find(w_sub._utf8, start_index, end_index)
-        if res_index == -1:
+        w_result = self._unwrap_and_search(space, w_sub, w_start, w_end)
+        if w_result is None:
             raise oefmt(space.w_ValueError,
                         "substring not found in string.index")
-
-        res = rutf8.check_utf8(self._utf8, allow_surrogates=True,
-                               force_len=res_index) # can't raise
-        return space.newint(res)
+        return w_result
 
     def descr_rindex(self, space, w_sub, w_start=None, w_end=None):
-        w_sub = self.convert_arg_to_w_unicode(space, w_sub)
-        start_index, end_index = self._unwrap_and_compute_idx_params(
-            space, w_start, w_end)
-
-        res_index = self._utf8.rfind(w_sub._utf8, start_index, end_index)
-        if res_index == -1:
+        w_result = self._unwrap_and_search(space, w_sub, w_start, w_end,
+                                           forward=False)
+        if w_result is None:
             raise oefmt(space.w_ValueError,
                         "substring not found in string.rindex")
-
-        res = rutf8.check_utf8(self._utf8, allow_surrogates=True,
-                               force_len=res_index) # can't raise
-        return space.newint(res)
+        return w_result
 
     @specialize.arg(2)
     def _is_generic(self, space, func_name):
@@ -908,11 +882,48 @@
             index += self._length
         if index < 0 or index >= self._length:
             raise oefmt(space.w_IndexError, "string index out of range")
-        storage = self._get_index_storage()
-        start = rutf8.codepoint_position_at_index(self._utf8, storage, index)
+        start = self._index_to_byte(index)
         end = rutf8.next_codepoint_pos(self._utf8, start)
         return W_UnicodeObject(self._utf8[start:end], 1)
 
+    def _index_to_byte(self, index):
+        return rutf8.codepoint_position_at_index(
+            self._utf8, self._get_index_storage(), index)
+
+    @always_inline
+    def _unwrap_and_search(self, space, w_sub, w_start, w_end, forward=True):
+        w_sub = self.convert_arg_to_w_unicode(space, w_sub)
+        start, end = unwrap_start_stop(space, self._length, w_start, w_end)
+        if start == 0:
+            start_index = 0
+        elif start > self._length:
+            return None
+        else:
+            start_index = self._index_to_byte(start)
+
+        if end >= self._length:
+            end = self._length
+            end_index = len(self._utf8)
+        else:
+            end_index = self._index_to_byte(end)
+
+        if forward:
+            res_index = self._utf8.find(w_sub._utf8, start_index, end_index)
+            if res_index < 0:
+                return None
+            skip = rutf8.codepoints_in_utf8(self._utf8, start_index, res_index)
+            res = start + skip
+            assert res >= 0
+            return space.newint(res)
+        else:
+            res_index = self._utf8.rfind(w_sub._utf8, start_index, end_index)
+            if res_index < 0:
+                return None
+            skip = rutf8.codepoints_in_utf8(self._utf8, res_index, end_index)
+            res = end - skip
+            assert res >= 0
+            return space.newint(res)
+
     def _unwrap_and_compute_idx_params(self, space, w_start, w_end):
         start, end = unwrap_start_stop(space, self._length, w_start, w_end)
         # XXX for now just create index
diff --git a/rpython/rlib/rutf8.py b/rpython/rlib/rutf8.py
--- a/rpython/rlib/rutf8.py
+++ b/rpython/rlib/rutf8.py
@@ -15,11 +15,12 @@
 extra code in the middle for error handlers and so on.
 """
 
+import sys
 from rpython.rlib.objectmodel import enforceargs
 from rpython.rlib.rstring import StringBuilder
 from rpython.rlib import jit
 from rpython.rlib.rarithmetic import r_uint, intmask
-from rpython.rtyper.lltypesystem import lltype
+from rpython.rtyper.lltypesystem import lltype, rffi
 
 
 def unichr_as_utf8(code, allow_surrogates=False):
@@ -290,17 +291,16 @@
 
 
 #@jit.elidable
-def check_utf8(s, allow_surrogates, force_len=-1):
+def check_utf8(s, allow_surrogates):
     """Check that 's' is a utf-8-encoded byte string.
     Returns the length (number of chars) or raise CheckError.
-    Note that surrogates are not handled specially here.
+    If allow_surrogates is False, then also raise if we see any.
+    Note also codepoints_in_utf8(), which also computes the length
+    faster by assuming that 's' is valid utf-8.
     """
     pos = 0
     continuation_bytes = 0
-    if force_len == -1:
-        end = len(s)
-    else:
-        end = force_len
+    end = len(s)
     while pos < end:
         ordch1 = ord(s[pos])
         pos += 1
@@ -359,6 +359,23 @@
     return pos - continuation_bytes
 
 @jit.elidable
+def codepoints_in_utf8(value, start=0, end=sys.maxint):
+    """Return the number of codepoints in the UTF-8 byte string
+    'value[start:end]'.  Assumes 0 <= start <= len(value) and start <= end.
+    """
+    if end > len(value):
+        end = len(value)
+    assert 0 <= start <= end
+    length = 0
+    for i in range(start, end):
+        # we want to count the number of chars not between 0x80 and 0xBF;
+        # we do that by casting the char to a signed integer
+        signedchar = rffi.cast(rffi.SIGNEDCHAR, ord(value[i]))
+        if rffi.cast(lltype.Signed, signedchar) >= -0x40:
+            length += 1
+    return length
+
+ at jit.elidable
 def surrogate_in_utf8(value):
     """Check if the UTF-8 byte string 'value' contains a surrogate.
     The 'value' argument must be otherwise correctly formed for UTF-8.
diff --git a/rpython/rlib/test/test_rutf8.py b/rpython/rlib/test/test_rutf8.py
--- a/rpython/rlib/test/test_rutf8.py
+++ b/rpython/rlib/test/test_rutf8.py
@@ -85,6 +85,19 @@
     r = (ch in txt)
     assert r == response
 
+ at given(strategies.text(), strategies.integers(min_value=0),
+                          strategies.integers(min_value=0))
+def test_codepoints_in_utf8(u, start, len1):
+    end = start + len1
+    if end > len(u):
+        extra = end - len(u)
+    else:
+        extra = 0
+    count = rutf8.codepoints_in_utf8(u.encode('utf8'),
+                                     len(u[:start].encode('utf8')),
+                                     len(u[:end].encode('utf8')) + extra)
+    assert count == len(u[start:end])
+
 @given(strategies.text())
 def test_utf8_index_storage(u):
     index = rutf8.create_utf8_index_storage(u.encode('utf8'), len(u))


More information about the pypy-commit mailing list