[pypy-svn] r76679 - pypy/branch/rsre2/pypy/rlib/rsre

arigo at codespeak.net arigo at codespeak.net
Fri Aug 20 17:03:11 CEST 2010


Author: arigo
Date: Fri Aug 20 17:03:09 2010
New Revision: 76679

Modified:
   pypy/branch/rsre2/pypy/rlib/rsre/rsre_core.py
Log:
A version with far fewer @specializectx decorators around.
Only some of the functions are now specialized; the
calls from non-specialized to specialized code are
indirect calls (so only a bit slower than direct
calls), and they occur at points where the overhead
is not really noticeable (as opposed to e.g. every
time we call ctx.str()).


Modified: pypy/branch/rsre2/pypy/rlib/rsre/rsre_core.py
==============================================================================
--- pypy/branch/rsre2/pypy/rlib/rsre/rsre_core.py	(original)
+++ pypy/branch/rsre2/pypy/rlib/rsre/rsre_core.py	Fri Aug 20 17:03:09 2010
@@ -1,7 +1,7 @@
 import sys
 from pypy.rlib.debug import check_nonneg
+from pypy.rlib.unroll import unrolling_iterable
 from pypy.rlib.rsre import rsre_char
-from pypy.rlib import jit
 from pypy.tool.sourcetools import func_with_new_name
 
 
@@ -41,19 +41,29 @@
 # ____________________________________________________________
 
 def specializectx(func):
-    """A decorator that specializes 'func' for each concrete subclass
-    XyzMatchContext.  It must then be called as func(ctx,...) where
-    ctx is known to be of a specific subclass.
+    """A decorator that specializes 'func(ctx,...)' for each concrete subclass
+    of AbstractMatchContext.  During annotation, if 'ctx' is known to be a
+    specific subclass, calling 'func' is a direct call; if 'ctx' is only known
+    to be of class AbstractMatchContext, calling 'func' is an indirect call.
     """
-    i = list(func.func_code.co_varnames).index('ctx')
-    func._annspecialcase_ = 'specialize:argtype(%d)' % i
-    return func
+    assert func.func_code.co_varnames[0] == 'ctx'
+    specname = '_spec_' + func.func_name
+    # Install a copy of the function under the name '_spec_funcname' in each
+    # concrete subclass
+    for prefix, concreteclass in [('str', StrMatchContext),
+                                  ('uni', UnicodeMatchContext)]:
+        newfunc = func_with_new_name(func, prefix + specname)
+        setattr(concreteclass, specname, newfunc)
+    # Return a dispatcher function, specialized on the exact type of 'ctx'
+    def dispatch(ctx, *args):
+        return getattr(ctx, specname)(*args)
+    dispatch._annspecialcase_ = 'specialize:argtype(0)'
+    return dispatch
 
 # ____________________________________________________________
 
 class AbstractMatchContext(object):
     """Abstract base class"""
-    _must_specialize_ = True
     match_start = 0
     match_end = 0
     match_marks = None
@@ -74,16 +84,15 @@
 
     def str(self, index):
         """NOT_RPYTHON: Must be overridden in a concrete subclass.
-        The line below is used to generate a translation-time crash
+        The tag ^^^ here is used to generate a translation-time crash
         if there is a call to str() that is indirect.  All calls must
         be direct for performance reasons; you need to specialize the
         caller with @specializectx."""
         raise NotImplementedError
 
-    @specializectx
-    def lowstr(ctx, index):
-        c = ctx.str(index)
-        return rsre_char.getlower(c, ctx.flags)
+    def lowstr(self, index):
+        """NOT_RPYTHON: Similar to str()."""
+        raise NotImplementedError
 
     def get_mark(self, gid):
         return find_mark(self.match_marks, gid)
@@ -136,6 +145,10 @@
         check_nonneg(index)
         return ord(self._string[index])
 
+    def lowstr(self, index):
+        c = self.str(index)
+        return rsre_char.getlower(c, self.flags)
+
 class UnicodeMatchContext(AbstractMatchContext):
     """Concrete subclass for matching in a unicode string."""
 
@@ -149,6 +162,10 @@
         check_nonneg(index)
         return ord(self._unicodestr[index])
 
+    def lowstr(self, index):
+        c = self.str(index)
+        return rsre_char.getlower(c, self.flags)
+
 # ____________________________________________________________
 
 class Mark(object):
@@ -171,7 +188,6 @@
 class MatchResult(object):
     subresult = None
 
-    @specializectx
     def move_to_next_result(self, ctx):
         result = self.subresult
         if result is None:
@@ -192,8 +208,6 @@
         self.start_ptr = ptr
         self.start_marks = marks
 
-    @specializectx
-    @jit.unroll_safe     # there are only a few branch alternatives
     def find_first_result(self, ctx):
         ppos = self.ppos
         while ctx.pat(ppos):
@@ -213,7 +227,6 @@
         self.start_ptr = ptr
         self.start_marks = marks
 
-    @specializectx
     def find_first_result(self, ctx):
         ptr = self.start_ptr
         while ptr >= self.minptr:
@@ -235,7 +248,6 @@
         self.start_ptr = ptr
         self.start_marks = marks
 
-    @specializectx
     def find_first_result(self, ctx):
         ptr = self.start_ptr
         while ptr <= self.maxptr:
@@ -244,20 +256,27 @@
                 self.subresult = result
                 self.start_ptr = ptr
                 return self
-            ptr1 = find_repetition_end(ctx, self.ppos3, ptr, 1)
-            if ptr1 == ptr:
+            if not self.next_char_ok(ctx, ptr):
                 break
-            ptr = ptr1
+            ptr += 1
 
-    @specializectx
     def find_next_result(self, ctx):
         ptr = self.start_ptr
-        ptr1 = find_repetition_end(ctx, self.ppos3, ptr, 1)
-        if ptr1 == ptr:
+        if not self.next_char_ok(ctx, ptr):
             return
-        self.start_ptr = ptr1
+        self.start_ptr = ptr + 1
         return self.find_first_result(ctx)
 
+    def next_char_ok(self, ctx, ptr):
+        if ptr == ctx.end:
+            return False
+        ppos = self.ppos3
+        op = ctx.pat(ppos)
+        for op1, (checkerfn, _) in unroll_char_checker:
+            if op1 == op:
+                return checkerfn(ctx, ptr, ppos)
+        raise NotImplementedError("next_char_ok[%d]" % op)
+
 class AbstractUntilMatchResult(MatchResult):
 
     def __init__(self, ppos, tailppos, ptr, marks):
@@ -277,16 +296,13 @@
 
 class MaxUntilMatchResult(AbstractUntilMatchResult):
 
-    @specializectx
     def find_first_result(self, ctx):
         enum = sre_match(ctx, self.ppos + 3, self.cur_ptr, self.cur_marks)
         return self.search_next(ctx, enum, resume=False)
 
-    @specializectx
     def find_next_result(self, ctx):
         return self.search_next(ctx, None, resume=True)
 
-    @specializectx
     def search_next(self, ctx, enum, resume):
         ppos = self.ppos
         min = ctx.pat(ppos+1)
@@ -330,15 +346,12 @@
 
 class MinUntilMatchResult(AbstractUntilMatchResult):
 
-    @specializectx
     def find_first_result(self, ctx):
         return self.search_next(ctx, resume=False)
 
-    @specializectx
     def find_next_result(self, ctx):
         return self.search_next(ctx, resume=True)
 
-    @specializectx
     def search_next(self, ctx, resume):
         ppos = self.ppos
         min = ctx.pat(ppos+1)
@@ -383,8 +396,6 @@
 # ____________________________________________________________
 
 @specializectx
-@jit.unroll_safe      # it's safe to unroll the main 'while' loop:
-                      # 'ppos' is only ever incremented in this function
 def sre_match(ctx, ppos, ptr, marks):
     """Returns either None or a MatchResult object.  Usually we only need
     the first result, but there is the case of REPEAT...UNTIL where we
@@ -666,69 +677,57 @@
         end1 = ptr + maxcount
         if end1 <= end:
             end = end1
-
     op = ctx.pat(ppos)
-    if op == OPCODE_ANY:            return fre_ANY(ctx, ptr, end)
-    if op == OPCODE_ANY_ALL:        return end
-    if op == OPCODE_IN:             return fre_IN(ctx, ptr, end, ppos)
-    if op == OPCODE_IN_IGNORE:      return fre_IN_IGNORE(ctx, ptr, end, ppos)
-    if op == OPCODE_LITERAL:        return fre_LITERAL(ctx, ptr, end, ppos)
-    if op == OPCODE_LITERAL_IGNORE: return fre_LITERAL_IGNORE(ctx,ptr,end,ppos)
-    if op == OPCODE_NOT_LITERAL:    return fre_NOT_LITERAL(ctx, ptr, end, ppos)
-    if op == OPCODE_NOT_LITERAL_IGNORE: return fre_NOT_LITERAL_IGNORE(ctx,ptr,
-                                                                      end,ppos)
+    for op1, (_, fre) in unroll_char_checker:
+        if op1 == op:
+            return fre(ctx, ptr, end, ppos)
     raise NotImplementedError("rsre.find_repetition_end[%d]" % op)
 
 @specializectx
-def fre_ANY(ctx, ptr, end):
-    # repeated dot wildcard.
-    while ptr < end and not rsre_char.is_linebreak(ctx.str(ptr)):
-        ptr += 1
-    return ptr
-
-@specializectx
-def fre_IN(ctx, ptr, end, ppos):
-    # repeated set
-    while ptr < end and rsre_char.check_charset(ctx.pattern, ppos+2,
-                                                ctx.str(ptr)):
-        ptr += 1
-    return ptr
-
-@specializectx
-def fre_IN_IGNORE(ctx, ptr, end, ppos):
-    # repeated set
-    while ptr < end and rsre_char.check_charset(ctx.pattern, ppos+2,
-                                                ctx.lowstr(ptr)):
-        ptr += 1
-    return ptr
-
-@specializectx
-def fre_LITERAL(ctx, ptr, end, ppos):
-    chr = ctx.pat(ppos+1)
-    while ptr < end and ctx.str(ptr) == chr:
-        ptr += 1
-    return ptr
-
-@specializectx
-def fre_LITERAL_IGNORE(ctx, ptr, end, ppos):
-    chr = ctx.pat(ppos+1)
-    while ptr < end and ctx.lowstr(ptr) == chr:
-        ptr += 1
-    return ptr
-
-@specializectx
-def fre_NOT_LITERAL(ctx, ptr, end, ppos):
-    chr = ctx.pat(ppos+1)
-    while ptr < end and ctx.str(ptr) != chr:
-        ptr += 1
-    return ptr
-
-@specializectx
-def fre_NOT_LITERAL_IGNORE(ctx, ptr, end, ppos):
-    chr = ctx.pat(ppos+1)
-    while ptr < end and ctx.lowstr(ptr) != chr:
-        ptr += 1
-    return ptr
+def match_ANY(ctx, ptr, ppos):   # dot wildcard.
+    return not rsre_char.is_linebreak(ctx.str(ptr))
+def match_ANY_ALL(ctx, ptr, ppos):
+    return True    # match anything (including a newline)
+@specializectx
+def match_IN(ctx, ptr, ppos):
+    return rsre_char.check_charset(ctx.pattern, ppos+2, ctx.str(ptr))
+@specializectx
+def match_IN_IGNORE(ctx, ptr, ppos):
+    return rsre_char.check_charset(ctx.pattern, ppos+2, ctx.lowstr(ptr))
+@specializectx
+def match_LITERAL(ctx, ptr, ppos):
+    return ctx.str(ptr) == ctx.pat(ppos+1)
+@specializectx
+def match_LITERAL_IGNORE(ctx, ptr, ppos):
+    return ctx.lowstr(ptr) == ctx.pat(ppos+1)
+@specializectx
+def match_NOT_LITERAL(ctx, ptr, ppos):
+    return ctx.str(ptr) != ctx.pat(ppos+1)
+@specializectx
+def match_NOT_LITERAL_IGNORE(ctx, ptr, ppos):
+    return ctx.lowstr(ptr) != ctx.pat(ppos+1)
+
+def _make_fre(checkerfn):
+    if checkerfn == match_ANY_ALL:
+        def fre(ctx, ptr, end, ppos):
+            return end
+    else:
+        def fre(ctx, ptr, end, ppos):
+            while ptr < end and checkerfn(ctx, ptr, ppos):
+                ptr += 1
+            return ptr
+    return checkerfn, fre
+
+unroll_char_checker = unrolling_iterable([
+    (OPCODE_ANY,                _make_fre(match_ANY)),
+    (OPCODE_ANY_ALL,            _make_fre(match_ANY_ALL)),
+    (OPCODE_IN,                 _make_fre(match_IN)),
+    (OPCODE_IN_IGNORE,          _make_fre(match_IN_IGNORE)),
+    (OPCODE_LITERAL,            _make_fre(match_LITERAL)),
+    (OPCODE_LITERAL_IGNORE,     _make_fre(match_LITERAL_IGNORE)),
+    (OPCODE_NOT_LITERAL,        _make_fre(match_NOT_LITERAL)),
+    (OPCODE_NOT_LITERAL_IGNORE, _make_fre(match_NOT_LITERAL_IGNORE)),
+    ])
 
 ##### At dispatch
 
@@ -813,17 +812,12 @@
 
 def match(pattern, string, start=0, end=sys.maxint, flags=0):
     ctx = StrMatchContext(pattern, string, start, end, flags)
-    if match_context(ctx) is not None:
-        return ctx
-    return None
+    return match_context(ctx)
 
 def search(pattern, string, start=0, end=sys.maxint, flags=0):
     ctx = StrMatchContext(pattern, string, start, end, flags)
-    if search_context(ctx) is not None:
-        return ctx
-    return None
+    return search_context(ctx)
 
-@specializectx
 def match_context(ctx):
     if ctx.end < ctx.match_start:
         return None
@@ -831,7 +825,6 @@
         return ctx
     return None
 
-@specializectx
 def search_context(ctx):
     if ctx.end < ctx.match_start:
         return None
@@ -840,7 +833,6 @@
             return fast_search(ctx)
     return regular_search(ctx)
 
-@specializectx
 def regular_search(ctx):
     start = ctx.match_start
     while start <= ctx.end:
@@ -864,7 +856,8 @@
     overlap_offset = 7 + prefix_len - 1
     assert overlap_offset >= 0
     pattern_offset = ctx.pat(1) + 1
-    assert pattern_offset >= 0
+    ppos_start = pattern_offset + 2 * prefix_skip
+    assert ppos_start >= 0
     i = 0
     string_position = ctx.match_start
     end = ctx.end
@@ -889,8 +882,7 @@
                         ctx.match_end = ptr
                         ctx.match_marks = None
                         return ctx
-                    ppos = pattern_offset + 2 * prefix_skip
-                    if sre_match(ctx, ppos, ptr, None) is not None:
+                    if sre_match(ctx, ppos_start, ptr, None) is not None:
                         ctx.match_start = start
                         return ctx
                     i = ctx.pat(overlap_offset + i)



More information about the Pypy-commit mailing list