[pypy-svn] r75480 - in pypy/branch/multijit-4/pypy/rlib/rsre: . test

arigo at codespeak.net arigo at codespeak.net
Sun Jun 20 18:53:27 CEST 2010


Author: arigo
Date: Sun Jun 20 18:53:26 2010
New Revision: 75480

Modified:
   pypy/branch/multijit-4/pypy/rlib/rsre/rsre_char.py
   pypy/branch/multijit-4/pypy/rlib/rsre/rsre_core.py
   pypy/branch/multijit-4/pypy/rlib/rsre/test/targetrsre.py
   pypy/branch/multijit-4/pypy/rlib/rsre/test/test_search.py
Log:
Whack whack whack.  Still only modest improvements, although in the example
I am trying, I managed to remove most mallocs.


Modified: pypy/branch/multijit-4/pypy/rlib/rsre/rsre_char.py
==============================================================================
--- pypy/branch/multijit-4/pypy/rlib/rsre/rsre_char.py	(original)
+++ pypy/branch/multijit-4/pypy/rlib/rsre/rsre_char.py	Sun Jun 20 18:53:26 2010
@@ -155,11 +155,9 @@
 SET_OK = -1
 SET_NOT_OK = -2
 
-def check_charset(char_code, context):
+def check_charset(char_code, pattern_codes, index):
     """Checks whether a character matches set of arbitrary length. Currently
     assumes the set starts at the first member of pattern_codes."""
-    pattern_codes = context.pattern_codes
-    index = context.code_position
     negated = SET_OK
     while index >= 0:
         opcode = pattern_codes[index]

Modified: pypy/branch/multijit-4/pypy/rlib/rsre/rsre_core.py
==============================================================================
--- pypy/branch/multijit-4/pypy/rlib/rsre/rsre_core.py	(original)
+++ pypy/branch/multijit-4/pypy/rlib/rsre/rsre_core.py	Sun Jun 20 18:53:26 2010
@@ -21,7 +21,22 @@
 class StateMixin(object):
 
     def reset(self):
-        self.string_position = self.start
+        pass
+
+    def search(self, pattern_codes):
+        return search(self, pattern_codes)
+
+    def match(self, pattern_codes):
+        return match(self, pattern_codes)
+
+
+class DynamicState(object):
+
+    def __init__(self, staticstate, start, end, string_position):
+        self.staticstate = staticstate
+        self.start = start
+        self.end = end
+        self.string_position = string_position
         self.marks = [0, 0, 0, 0]
         self.lastindex = -1
         self.marks_count = 0
@@ -32,15 +47,9 @@
         # with x and y saved indices to allow pops.
         self.saved_marks = []
         self.saved_marks_top = 0
-        self.context_stack = []
+        self.top_context = None
         self.repeat = None
 
-    def search(self, pattern_codes):
-        return search(self, pattern_codes)
-
-    def match(self, pattern_codes):
-        return match(self, pattern_codes)
-
     def create_regs(self, group_count):
         """Creates a tuple of index pairs representing matched groups, a format
         that's convenient for SRE_Match."""
@@ -57,6 +66,9 @@
             regs.append((start, end))
         return regs
 
+    def lower(self, char):
+        return self.staticstate.lower(char)
+
     def set_mark(self, mark_nr, position):
         assert mark_nr >= 0
         if mark_nr & 1:
@@ -81,17 +93,22 @@
     def marks_push(self):
         # Create in saved_marks: [......, m1,..,mn,lastindex,p, ...........]
         #                                 ^p                    ^newsize
-        p = self.saved_marks_top
-        n = self.marks_count
+        newsize = self._marks_push(self.saved_marks_top, self.marks_count,
+                                   self.saved_marks, self.marks,
+                                   self.lastindex)
+        self.saved_marks_top = newsize
+
+    @staticmethod
+    def _marks_push(p, n, saved_marks, marks, lastindex):
         assert p >= 0
         newsize = p + n + 2
-        while len(self.saved_marks) < newsize:
-            self.saved_marks.append(-1)
+        while len(saved_marks) < newsize:
+            saved_marks.append(-1)
         for i in range(n):
-            self.saved_marks[p+i] = self.marks[i]
-        self.saved_marks[p+n] = self.lastindex
-        self.saved_marks[p+n+1] = p
-        self.saved_marks_top = newsize
+            saved_marks[p+i] = marks[i]
+        saved_marks[p+n] = lastindex
+        saved_marks[p+n+1] = p
+        return newsize
 
     def marks_pop(self):
         p0 = self.marks_pop_keep()
@@ -106,11 +123,15 @@
         n = p1 - p0
         assert p0 >= 0 and n >= 0
         self.lastindex = self.saved_marks[p1]
-        for i in range(n):
-            self.marks[i] = self.saved_marks[p0+i]
+        self._marks_pop_keep(self.marks, self.saved_marks, n, p0)
         self.marks_count = n
         return p0
 
+    @staticmethod
+    def _marks_pop_keep(marks, saved_marks, n, p0):
+        for i in range(n):
+            marks[i] = saved_marks[p0+i]
+
     def marks_pop_discard(self):
         p0 = self.saved_marks[self.saved_marks_top-1]
         self.saved_marks_top = p0
@@ -118,14 +139,16 @@
 
 class MatchContext(rsre_char.MatchContextBase):
 
-    def __init__(self, state, pattern_codes, offset=0):
+    def __init__(self, state, pattern_codes, offset=0, prev_context=None):
         self.state = state
+        self.staticstate = state.staticstate
         self.pattern_codes = pattern_codes
         self.string_position = state.string_position
         self.code_position = offset
         self.has_matched = self.UNDECIDED
         self.backup = []
         self.resume_at_opcode = -1
+        self.prev_context = prev_context
 
     def push_new_context(self, pattern_offset):
         """Creates a new child context of this context and pushes it on the
@@ -133,8 +156,9 @@
         start interpreting from."""
         offset = self.code_position + pattern_offset
         assert offset >= 0
-        child_context = MatchContext(self.state, self.pattern_codes, offset)
-        self.state.context_stack.append(child_context)
+        child_context = MatchContext(self.state, self.pattern_codes, offset,
+                                     self.state.top_context)
+        self.state.top_context = child_context
         self.child_context = child_context
         return child_context
 
@@ -150,7 +174,7 @@
         return values
 
     def peek_char(self, peek=0):
-        return self.state.get_char_ord(self.string_position + peek)
+        return self.staticstate.get_char_ord(self.string_position + peek)
 
     def skip_char(self, skip_count):
         self.string_position = self.string_position + skip_count
@@ -192,18 +216,20 @@
 
 #### Main opcode dispatch loop
 
-def search(state, pattern_codes):
+def search(staticstate, pattern_codes):
     flags = 0
     if pattern_codes[0] == OPCODE_INFO:
         # optimization info block
         # <INFO> <1=skip> <2=flags> <3=min> <4=max> <5=prefix info>
         if pattern_codes[2] & SRE_INFO_PREFIX and pattern_codes[5] > 1:
-            return fast_search(state, pattern_codes)
+            return fast_search(staticstate, pattern_codes)
         flags = pattern_codes[2]
         offset = pattern_codes[1] + 1
         assert offset >= 0
         #pattern_codes = pattern_codes[offset:]
 
+    raise NotImplementedError("XXX")
+
     string_position = state.start
     while string_position <= state.end:
         state.reset()
@@ -215,10 +241,10 @@
 
 from pypy.rlib.jit import JitDriver, unroll_safe
 rsrejitdriver = JitDriver(greens=['i', 'overlap_offset', 'pattern_codes'],
-                          reds=['string_position', 'end', 'state'],
+                          reds=['string_position', 'end', 'staticstate'],
                           can_inline=lambda *args: False)
 
-def fast_search(state, pattern_codes):
+def fast_search(staticstate, pattern_codes):
     """Skips forward in a string as fast as possible using information from
     an optimization info block."""
     # pattern starts with a known prefix
@@ -227,8 +253,8 @@
     overlap_offset = 7 + prefix_len - 1
     assert overlap_offset >= 0
     i = 0
-    string_position = state.string_position
-    end = state.end
+    string_position = staticstate.start
+    end = staticstate.end
     while string_position < end:
         while True:
             rsrejitdriver.can_enter_jit(
@@ -236,16 +262,16 @@
                 i=i,
                 string_position=string_position,
                 end=end,
-                state=state,
+                staticstate=staticstate,
                 overlap_offset=overlap_offset)
             rsrejitdriver.jit_merge_point(
                 pattern_codes=pattern_codes,
                 i=i,
                 string_position=string_position,
                 end=end,
-                state=state,
+                staticstate=staticstate,
                 overlap_offset=overlap_offset)
-            char_ord = state.get_char_ord(string_position)
+            char_ord = staticstate.get_char_ord(string_position)
             if char_ord != pattern_codes[7+i]:
                 if i == 0:
                     break
@@ -257,21 +283,24 @@
                 if i == prefix_len:
                     # found a potential match
                     prefix_skip = pattern_codes[6]
-                    state.start = string_position + 1 - prefix_len
-                    state.string_position = string_position + 1 \
-                                                 - prefix_len + prefix_skip
+                    state = DynamicState(
+                        staticstate,
+                        start = string_position + 1 - prefix_len,
+                        end = staticstate.end,
+                        string_position = (string_position + 1
+                                           - prefix_len + prefix_skip))
                     flags = pattern_codes[2]
                     if flags & SRE_INFO_LITERAL:
-                        return True # matched all of pure literal pattern
+                        return state # matched all of pure literal pattern
                     pattern_offset = pattern_codes[1] + 1
                     start = pattern_offset + 2 * prefix_skip
                     assert start >= 0
                     if match(state, pattern_codes, start):
-                        return True
+                        return state
                     i = pattern_codes[overlap_offset + i]
                 break
         string_position += 1
-    return False
+    return None
 
 @unroll_safe
 def match(state, pattern_codes, pstart=0):
@@ -283,16 +312,16 @@
         if state.end - state.string_position < pattern_codes[pstart+3]:
             return False
         pstart += pattern_codes[pstart+1] + 1
-    state.context_stack.append(MatchContext(state, pattern_codes, pstart))
+    state.top_context = MatchContext(state, pattern_codes, pstart)
     has_matched = MatchContext.UNDECIDED
-    while len(state.context_stack) > 0:
-        context = state.context_stack[-1]
+    while state.top_context is not None:
+        context = state.top_context
         if context.has_matched == context.UNDECIDED:
             has_matched = dispatch_loop(context)
         else:
             has_matched = context.has_matched
         if has_matched != context.UNDECIDED: # don't pop if context isn't done
-            state.context_stack.pop()
+            state.top_context = context.prev_context
     return has_matched == MatchContext.MATCHED
 
 @unroll_safe
@@ -417,6 +446,7 @@
     char_code = ctx.peek_char()
     if ignore:
         char_code = ctx.state.lower(char_code)
+    raise NotImplementedError("XXX")
     if not rsre_char.check_charset(char_code, ctx):
         ctx.has_matched = ctx.NOT_MATCHED
         return
@@ -499,7 +529,19 @@
 
     # Initialize the actual backtracking
     if count >= mincount:
-        count = quickly_skip_unmatchable_positions(ctx, count, mincount)
+        # <optimization>
+        nextidx = ctx.peek_code(1)
+        if ctx.peek_code(nextidx + 1) == OPCODE_LITERAL:
+            # tail starts with a literal. skip positions where
+            # the rest of the pattern cannot possibly match
+            chr = ctx.peek_code(nextidx + 2)
+            sp1 = ctx.string_position
+            sp2 = quickly_skip_unmatchable_positions(chr, sp1,
+                                                     count - mincount,
+                                                     ctx.staticstate)
+            ctx.string_position = sp2
+            count -= (sp1 - sp2)
+        # </optimization>
         if count >= mincount:
             ctx.state.string_position = ctx.string_position
             ctx.push_new_context(ctx.peek_code(1) + 1)
@@ -512,22 +554,17 @@
     ctx.has_matched = ctx.NOT_MATCHED
     return True
 
-def quickly_skip_unmatchable_positions(ctx, count, mincount):
-    # this is only an optimization
-    nextidx = ctx.peek_code(1)
-    if ctx.peek_code(nextidx + 1) == OPCODE_LITERAL:
-        # tail starts with a literal. skip positions where
-        # the rest of the pattern cannot possibly match
-        chr = ctx.peek_code(nextidx + 2)
-        if ctx.at_end():
-            ctx.skip_char(-1)
-            count -= 1
-        while count >= mincount:
-            if ctx.peek_char() == chr:
-                break
-            ctx.skip_char(-1)
-            count -= 1
-    return count
+def quickly_skip_unmatchable_positions(chr, string_position, reduce_up_to,
+                                       staticstate):
+    string_position_min = string_position - reduce_up_to
+    assert string_position_min >= 0
+    if string_position == staticstate.end:
+        string_position -= 1
+    while string_position >= string_position_min:
+        if staticstate.get_char_ord(string_position) == chr:
+            break
+        string_position -= 1
+    return string_position
 
 def op_min_repeat_one(ctx):
     # match repeated sequence (minimizing)
@@ -803,7 +840,7 @@
         return True
     while group_start < group_end:
         new_char = ctx.peek_char()
-        old_char = ctx.state.get_char_ord(group_start)
+        old_char = ctx.staticstate.get_char_ord(group_start)
         if ctx.at_end() or (not ignore and old_char != new_char) \
                 or (ignore and ctx.state.lower(old_char) != ctx.state.lower(new_char)):
             ctx.has_matched = ctx.NOT_MATCHED
@@ -929,27 +966,30 @@
 
 ##### count_repetitions dispatch
 
-def general_cr_in(ctx, maxcount, ignore):
-    code_position = ctx.code_position
+def general_cr_in(maxcount, ignore, pattern_codes, code_position,
+                  staticstate, string_position):
+    code_position += 6    # set op pointer to the set code
     count = 0
     while count < maxcount:
-        ctx.code_position = code_position
-        ctx.skip_code(6) # set op pointer to the set code
-        char_code = ctx.peek_char(count)
+        char_code = staticstate.get_char_ord(string_position + count)
         if ignore:
-            char_code = ctx.state.lower(char_code)
-        if not rsre_char.check_charset(char_code, ctx):
+            char_code = staticstate.lower(char_code)
+        if not rsre_char.check_charset(char_code, pattern_codes,
+                                       code_position):
             break
         count += 1
-    ctx.code_position = code_position
     return count
-general_cr_in._annspecialcase_ = 'specialize:arg(2)'
+general_cr_in._annspecialcase_ = 'specialize:arg(1)'
 
 def cr_in(ctx, maxcount):
-    return general_cr_in(ctx, maxcount, False)
+    return general_cr_in(maxcount, False,
+                         ctx.pattern_codes, ctx.code_position,
+                         ctx.staticstate, ctx.string_position)
 
 def cr_in_ignore(ctx, maxcount):
-    return general_cr_in(ctx, maxcount, True)
+    return general_cr_in(maxcount, True,
+                         ctx.pattern_codes, ctx.code_position,
+                         ctx.staticstate, ctx.string_position)
 
 def cr_any(ctx, maxcount):
     count = 0

Modified: pypy/branch/multijit-4/pypy/rlib/rsre/test/targetrsre.py
==============================================================================
--- pypy/branch/multijit-4/pypy/rlib/rsre/test/targetrsre.py	(original)
+++ pypy/branch/multijit-4/pypy/rlib/rsre/test/targetrsre.py	Sun Jun 20 18:53:26 2010
@@ -32,9 +32,9 @@
     while True:
         state = rsre.SimpleStringState(data, p)
         res = state.search(r)
-        if not res:
+        if res is None:
             break
-        groups = state.create_regs(1)
+        groups = res.create_regs(1)
         matchstart, matchstop = groups[1]
         assert 0 <= matchstart <= matchstop
         print '%s: %s' % (filename, data[matchstart:matchstop])

Modified: pypy/branch/multijit-4/pypy/rlib/rsre/test/test_search.py
==============================================================================
--- pypy/branch/multijit-4/pypy/rlib/rsre/test/test_search.py	(original)
+++ pypy/branch/multijit-4/pypy/rlib/rsre/test/test_search.py	Sun Jun 20 18:53:26 2010
@@ -24,8 +24,8 @@
         r_code1 = self.get_code(r'<item>\s*<title>(.*?)</title>')
         state = rsre.SimpleStringState("foo<item>  <title>abc</title>def")
         res = state.search(r_code1)
-        assert res is True
-        groups = state.create_regs(1)
+        assert res is not None
+        groups = res.create_regs(1)
         assert groups[0] == (3, 29)
         assert groups[1] == (18, 21)
 



More information about the Pypy-commit mailing list