[pypy-commit] pypy fix-bytearray-complexity: Reorganize rlib.rstring a bit; add a test

waedt noreply at buildbot.pypy.org
Mon Jun 2 19:47:15 CEST 2014


Author: Tyler Wade <wayedt at gmail.com>
Branch: fix-bytearray-complexity
Changeset: r71881:6cbc2ccc583e
Date: 2014-06-02 11:40 -0500
http://bitbucket.org/pypy/pypy/changeset/6cbc2ccc583e/

Log:	Reorganize rlib.rstring a bit; add a test

diff --git a/rpython/rlib/rstring.py b/rpython/rlib/rstring.py
--- a/rpython/rlib/rstring.py
+++ b/rpython/rlib/rstring.py
@@ -16,29 +16,6 @@
 
 # -------------- public API for string functions -----------------------
 
-@specialize.argtype(0, 1)
-def _get_access_functions(value, other):
-    if (not (isinstance(value, str) or isinstance(value, unicode)) or
-        not (isinstance(other, str) or isinstance(other, unicode))):
-
-        def find(obj, other, start, end):
-            return search(obj, other, start, end, SEARCH_FIND)
-        def rfind(obj, other, start, end):
-            return search(obj, other, start, end, SEARCH_RFIND)
-        def count(obj, other, start, end):
-            return search(obj, other, start, end, SEARCH_COUNT)
-    else:
-        assert isinstance(value, str) or isinstance(value, unicode)
-        assert isinstance(other, str) or isinstance(other, unicode)
-        def find(obj, other, start, end):
-            return obj.find(other, start, end)
-        def rfind(obj, other, start, end):
-            return obj.rfind(other, start, end)
-        def count(obj, other, start, end):
-            return obj.count(other, start, end)
-
-    return find, rfind, count
-
 @specialize.argtype(0)
 def _isspace(char):
     if isinstance(char, str):
@@ -79,7 +56,6 @@
             i = j + 1
         return res
 
-    find, _, count = _get_access_functions(value, by)
     bylen = len(by)
     if bylen == 0:
         raise ValueError("empty separator")
@@ -88,16 +64,16 @@
     if bylen == 1:
         # fast path: uses str.rfind(character) and str.count(character)
         by = by[0]    # annotator hack: string -> char
-        count = count(value, by, 0, len(value))
-        if 0 <= maxsplit < count:
-            count = maxsplit
-        res = newlist_hint(count + 1)
-        while count > 0:
+        cnt = count(value, by, 0, len(value))
+        if 0 <= maxsplit < cnt:
+            cnt = maxsplit
+        res = newlist_hint(cnt + 1)
+        while cnt > 0:
             next = find(value, by, start, len(value))
             assert next >= 0 # cannot fail due to the value.count above
             res.append(value[start:next])
             start = next + bylen
-            count -= 1
+            cnt -= 1
         res.append(value[start:len(value)])
         return res
 
@@ -110,6 +86,7 @@
         next = find(value, by, start, len(value))
         if next < 0:
             break
+        assert start >= 0
         res.append(value[start:next])
         start = next + bylen
         maxsplit -= 1   # NB. if it's already < 0, it stays < 0
@@ -158,7 +135,6 @@
         res = newlist_hint(min(maxsplit + 1, len(value)))
     else:
         res = []
-    _, rfind, _ = _get_access_functions(value, by)
     end = len(value)
     bylen = len(by)
     if bylen == 0:
@@ -190,7 +166,6 @@
     if maxsplit == 0:
         return input
 
-    find, _, count = _get_access_functions(input, sub)
 
     if not sub:
         upper = len(input)
@@ -214,12 +189,12 @@
         builder.append_slice(input, upper, len(input))
     else:
         # First compute the exact result size
-        count = count(input, sub, 0, len(input))
-        if count > maxsplit and maxsplit > 0:
-            count = maxsplit
+        cnt = count(input, sub, 0, len(input))
+        if cnt > maxsplit and maxsplit > 0:
+            cnt = maxsplit
         diff_len = len(by) - len(sub)
         try:
-            result_size = ovfcheck(diff_len * count)
+            result_size = ovfcheck(diff_len * cnt)
             result_size = ovfcheck(result_size + len(input))
         except OverflowError:
             raise
@@ -280,8 +255,28 @@
             return False
     return True
 
-# Stolen form rpython.rtyper.lltypesytem.rstr
-# TODO: Ask about what to do with this...
+@specialize.argtype(0, 1)
+def find(value, other, start, end):
+    if ((isinstance(value, str) or isinstance(value, unicode)) and
+        (isinstance(other, str) or isinstance(other, unicode))):
+        return value.find(other, start, end)
+    return _search(value, other, start, end, SEARCH_FIND)
+
+@specialize.argtype(0, 1)
+def rfind(value, other, start, end):
+    if ((isinstance(value, str) or isinstance(value, unicode)) and
+        (isinstance(other, str) or isinstance(other, unicode))):
+        return value.rfind(other, start, end)
+    return _search(value, other, start, end, SEARCH_RFIND)
+
+@specialize.argtype(0, 1)
+def count(value, other, start, end):
+    if ((isinstance(value, str) or isinstance(value, unicode)) and
+        (isinstance(other, str) or isinstance(other, unicode))):
+        return value.count(other, start, end)
+    return _search(value, other, start, end, SEARCH_COUNT)
+
+# -------------- substring searching helper ----------------
 
 SEARCH_COUNT = 0
 SEARCH_FIND = 1
@@ -294,7 +289,7 @@
     return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1)))
 
 @specialize.argtype(0, 1)
-def search(value, other, start, end, mode):
+def _search(value, other, start, end, mode):
     if start < 0:
         start = 0
     if end > len(value):
diff --git a/rpython/rlib/test/test_rstring.py b/rpython/rlib/test/test_rstring.py
--- a/rpython/rlib/test/test_rstring.py
+++ b/rpython/rlib/test/test_rstring.py
@@ -2,7 +2,8 @@
 
 from rpython.rlib.rstring import StringBuilder, UnicodeBuilder, split, rsplit
 from rpython.rlib.rstring import replace, startswith, endswith
-from rpython.rlib.rstring import search, SEARCH_FIND, SEARCH_RFIND, SEARCH_COUNT
+from rpython.rlib.rstring import find, rfind, count
+from rpython.rlib.buffer import StringBuffer
 from rpython.rtyper.test.tool import BaseRtypingTest
 
 def test_split():
@@ -216,22 +217,22 @@
     assert isinstance(s.build(), unicode)
 
 def test_search():
-    def check_search(value, sub, *args, **kwargs):
+    def check_search(func, value, sub, *args, **kwargs):
         result = kwargs['res']
-        assert search(value, sub, *args) == result
-        assert search(list(value), sub, *args) == result
+        assert func(value, sub, *args) == result
+        assert func(list(value), sub, *args) == result
 
-    check_search('one two three', 'ne', 0, 13, SEARCH_FIND, res=1)
-    check_search('one two three', 'ne', 5, 13, SEARCH_FIND, res=-1)
-    check_search('one two three', '', 0, 13, SEARCH_FIND, res=0)
+    check_search(find, 'one two three', 'ne', 0, 13, res=1)
+    check_search(find, 'one two three', 'ne', 5, 13, res=-1)
+    check_search(find, 'one two three', '', 0, 13, res=0)
 
-    check_search('one two three', 'e', 0, 13, SEARCH_RFIND, res=12)
-    check_search('one two three', 'e', 0, 1, SEARCH_RFIND, res=-1)
-    check_search('one two three', '', 0, 13, SEARCH_RFIND, res=13)
+    check_search(rfind, 'one two three', 'e', 0, 13, res=12)
+    check_search(rfind, 'one two three', 'e', 0, 1, res=-1)
+    check_search(rfind, 'one two three', '', 0, 13, res=13)
 
-    check_search('one two three', 'e', 0, 13, SEARCH_COUNT, res=3)
-    check_search('one two three', 'e', 0, 1, SEARCH_COUNT, res=0)
-    check_search('one two three', '', 0, 13, SEARCH_RFIND, res=13)
+    check_search(count, 'one two three', 'e', 0, 13, res=3)
+    check_search(count, 'one two three', 'e', 0, 1, res=0)
+    check_search(count, 'one two three', '', 0, 13, res=14)
 
 
 class TestTranslates(BaseRtypingTest):
@@ -252,6 +253,20 @@
         res = self.interpret(fn, [])
         assert res
 
+    def test_buffer_parameter(self):
+        def fn():
+            res = True
+            res = res and split('a//b//c//d', StringBuffer('//')) == ['a', 'b', 'c', 'd']
+            res = res and split(u'a//b//c//d', StringBuffer('//')) == ['a', 'b', 'c', 'd']
+            res = res and rsplit('a//b//c//d', StringBuffer('//')) == ['a', 'b', 'c', 'd']
+            res = res and find('a//b//c//d', StringBuffer('//'), 0, 10) != -1
+            res = res and rfind('a//b//c//d', StringBuffer('//'), 0, 10) != -1
+            res = res and count('a//b//c//d', StringBuffer('//'), 0, 10) != 0
+            return res
+        res = self.interpret(fn, [])
+        assert res
+
+
     def test_replace(self):
         def fn():
             res = True


More information about the pypy-commit mailing list