[pypy-commit] pypy speed-up-stringsearch: try to split ll_search into two elidable functions, one of which only depends on the search string

cfbolz pypy.commits at gmail.com
Thu Mar 3 05:53:57 EST 2016


Author: Carl Friedrich Bolz <cfbolz at gmx.de>
Branch: speed-up-stringsearch
Changeset: r82672:37cc19e3dcb4
Date: 2016-02-27 17:01 +0100
http://bitbucket.org/pypy/pypy/changeset/37cc19e3dcb4/

Log:	try to split ll_search into two elidable functions, one of which
	only depends on the search string

diff --git a/rpython/rtyper/lltypesystem/rstr.py b/rpython/rtyper/lltypesystem/rstr.py
--- a/rpython/rtyper/lltypesystem/rstr.py
+++ b/rpython/rtyper/lltypesystem/rstr.py
@@ -721,12 +721,9 @@
         return res
 
     @staticmethod
-    @jit.elidable
     def ll_search(s1, s2, start, end, mode):
-        count = 0
         n = end - start
         m = len(s2.chars)
-
         if m == 0:
             if mode == FAST_COUNT:
                 return end - start + 1
@@ -743,74 +740,112 @@
             return -1
 
         mlast = m - 1
+
+        if mode != FAST_RFIND:
+            skip, mask = LLHelpers._precompute_skip_mask_forward(s2)
+            return LLHelpers._str_search_forward(s1, s2, start, end, mode, skip, mask)
+        else:
+            skip, mask = LLHelpers._precompute_skip_mask_backward(s2)
+            return LLHelpers._str_search_backward(s1, s2, start, end, skip, mask)
+
+    @staticmethod
+    @jit.elidable
+    def _precompute_skip_mask_forward(s2):
+        mlast = len(s2.chars) - 1
         skip = mlast - 1
         mask = 0
+        lastchar = s2.chars[mlast]
+        for i in range(mlast):
+            mask = bloom_add(mask, s2.chars[i])
+            if s2.chars[i] == lastchar:
+                skip = mlast - i - 1
+        mask = bloom_add(mask, lastchar)
+        return skip, mask
 
-        if mode != FAST_RFIND:
-            for i in range(mlast):
-                mask = bloom_add(mask, s2.chars[i])
-                if s2.chars[i] == s2.chars[mlast]:
-                    skip = mlast - i - 1
-            mask = bloom_add(mask, s2.chars[mlast])
+    @staticmethod
+    @jit.elidable
+    def _precompute_skip_mask_backward(s2):
+        mlast = len(s2.chars) - 1
+        skip = mlast - 1
+        firstchar = s2.chars[0]
+        mask = bloom_add(0, firstchar)
+        for i in range(mlast, 0, -1):
+            mask = bloom_add(mask, s2.chars[i])
+            if s2.chars[i] == firstchar:
+                skip = i - 1
+        return skip, mask
 
-            i = start - 1
-            while i + 1 <= start + w:
-                i += 1
-                if s1.chars[i + m - 1] == s2.chars[m - 1]:
-                    for j in range(mlast):
-                        if s1.chars[i + j] != s2.chars[j]:
-                            break
-                    else:
-                        if mode != FAST_COUNT:
-                            return i
-                        count += 1
-                        i += mlast
-                        continue
+    @staticmethod
+    @jit.elidable
+    def _str_search_forward(s1, s2, start, end, mode, skip, mask):
+        count = 0
+        n = end - start
+        m = len(s2.chars)
 
-                    if i + m < len(s1.chars):
-                        c = s1.chars[i + m]
-                    else:
-                        c = '\0'
-                    if not bloom(mask, c):
-                        i += m
-                    else:
-                        i += skip
+        w = n - m
+        mlast = m - 1
+        i = start - 1
+        lastchar = s2.chars[mlast]
+        while i + 1 <= start + w:
+            i += 1
+            if s1.chars[i + m - 1] == lastchar:
+                for j in range(mlast):
+                    if s1.chars[i + j] != s2.chars[j]:
+                        break
                 else:
-                    if i + m < len(s1.chars):
-                        c = s1.chars[i + m]
-                    else:
-                        c = '\0'
-                    if not bloom(mask, c):
-                        i += m
-        else:
-            mask = bloom_add(mask, s2.chars[0])
-            for i in range(mlast, 0, -1):
-                mask = bloom_add(mask, s2.chars[i])
-                if s2.chars[i] == s2.chars[0]:
-                    skip = i - 1
+                    if mode != FAST_COUNT:
+                        return i
+                    count += 1
+                    i += mlast
+                    continue
 
-            i = start + w + 1
-            while i - 1 >= start:
-                i -= 1
-                if s1.chars[i] == s2.chars[0]:
-                    for j in xrange(mlast, 0, -1):
-                        if s1.chars[i + j] != s2.chars[j]:
-                            break
-                    else:
-                        return i
-                    if i - 1 >= 0 and not bloom(mask, s1.chars[i - 1]):
-                        i -= m
-                    else:
-                        i -= skip
+                if i + m < len(s1.chars):
+                    c = s1.chars[i + m]
                 else:
-                    if i - 1 >= 0 and not bloom(mask, s1.chars[i - 1]):
-                        i -= m
-
+                    c = '\0'
+                if not bloom(mask, c):
+                    i += m
+                else:
+                    i += skip
+            else:
+                if i + m < len(s1.chars):
+                    c = s1.chars[i + m]
+                else:
+                    c = '\0'
+                if not bloom(mask, c):
+                    i += m
         if mode != FAST_COUNT:
             return -1
         return count
 
     @staticmethod
+    @jit.elidable
+    def _str_search_backward(s1, s2, start, end, skip, mask):
+        n = end - start
+        m = len(s2.chars)
+
+        w = n - m
+        mlast = m - 1
+        i = start + w + 1
+        firstchar = s2.chars[0]
+        while i - 1 >= start:
+            i -= 1
+            if s1.chars[i] == firstchar:
+                for j in xrange(mlast, 0, -1):
+                    if s1.chars[i + j] != s2.chars[j]:
+                        break
+                else:
+                    return i
+                if i - 1 >= 0 and not bloom(mask, s1.chars[i - 1]):
+                    i -= m
+                else:
+                    i -= skip
+            else:
+                if i - 1 >= 0 and not bloom(mask, s1.chars[i - 1]):
+                    i -= m
+        return -1
+
+    @staticmethod
     @signature(types.int(), types.any(), returns=types.any())
     @jit.look_inside_iff(lambda length, items: jit.loop_unrolling_heuristic(
         items, length))


More information about the pypy-commit mailing list