[pypy-svn] r76516 - in pypy/branch/rsre2/pypy: annotation rlib/rsre rlib/rsre/test

arigo at codespeak.net arigo at codespeak.net
Fri Aug 6 18:14:55 CEST 2010


Author: arigo
Date: Fri Aug  6 18:14:53 2010
New Revision: 76516

Modified:
   pypy/branch/rsre2/pypy/annotation/specialize.py
   pypy/branch/rsre2/pypy/rlib/rsre/rsre_core.py
   pypy/branch/rsre2/pypy/rlib/rsre/rsre_re.py
   pypy/branch/rsre2/pypy/rlib/rsre/test/test_zinterp.py
Log:
Unicode support, with appropriate specializations so
that different versions of some of the functions in
rsre_core are translated, depending on whether they
operate on byte strings or on unicode strings.


Modified: pypy/branch/rsre2/pypy/annotation/specialize.py
==============================================================================
--- pypy/branch/rsre2/pypy/annotation/specialize.py	(original)
+++ pypy/branch/rsre2/pypy/annotation/specialize.py	Fri Aug  6 18:14:53 2010
@@ -354,6 +354,12 @@
 
 def specialize_argtype(funcdesc, args_s, *argindices):
     key = tuple([args_s[i].knowntype for i in argindices])
+    for cls in key:
+        try:
+            assert '_must_specialize_' not in cls.classdesc.pyobj.__dict__, (
+                "%s has the tag _must_specialize_" % (cls,))
+        except AttributeError:
+            pass
     return maybe_star_args(funcdesc, key, args_s)
 
 def specialize_arglistitemtype(funcdesc, args_s, i):

Modified: pypy/branch/rsre2/pypy/rlib/rsre/rsre_core.py
==============================================================================
--- pypy/branch/rsre2/pypy/rlib/rsre/rsre_core.py	(original)
+++ pypy/branch/rsre2/pypy/rlib/rsre/rsre_core.py	Fri Aug  6 18:14:53 2010
@@ -2,6 +2,7 @@
 from pypy.rlib.debug import check_nonneg
 from pypy.rlib.rsre import rsre_char
 from pypy.rlib import jit
+from pypy.tool.sourcetools import func_with_new_name
 
 
 OPCODE_FAILURE            = 0
@@ -37,20 +38,50 @@
 #OPCODE_SUBPATTERN        = 30
 OPCODE_MIN_REPEAT_ONE     = 31
 
+# ____________________________________________________________
+
+def specializectx(func):
+    """A decorator that specializes 'func' for each concrete subclass
+    XyzMatchContext.  It must then be called as func(ctx,...) where
+    ctx is known to be of a specific subclass.  Use this for the case
+    of very common and small methods; for large methods where an
+    indirect call is ok, use @specializectxmethod.
+    """
+    func._annspecialcase_ = 'specialize:argtype(0)'
+    return func
+
+def specializectxmethod(func):
+    """A decorator that puts 'func' as a method on each concrete
+    subclass XyzMatchContext.  This is an annotation trick to allow a
+    different version of the function to be called (as a method call)
+    depending on whether it operates on strings or unicodes.  It is ok
+    to do ctx.func(...) even if ctx is a general AbstractMatchContext;
+    it becomes an indirect call in the C version.
+    """
+    name = func.__name__
+    setattr(StrMatchContext, name,
+            func_with_new_name(func, name + '_str'))
+    setattr(UnicodeMatchContext, name,
+            func_with_new_name(func, name + '_unicode'))
+    return NotImplemented    # the original decorated function is not available
 
-class MatchContext(object):
+# ____________________________________________________________
+
+class AbstractMatchContext(object):
+    """Abstract base class"""
+    _must_specialize_ = True
     match_start = 0
     match_end = 0
     match_marks = None
     match_marks_flat = None
 
-    def __init__(self, pattern, string, match_start, end, flags):
+    def __init__(self, pattern, match_start, end, flags):
+        # here, 'end' must be at most len(string)
         self.pattern = pattern
-        self.string = string
-        if end > len(string):
-            end = len(string)
-        self.end = end
+        if match_start < 0:
+            match_start = 0
         self.match_start = match_start
+        self.end = end
         self.flags = flags
 
     def pat(self, index):
@@ -58,9 +89,14 @@
         return self.pattern[index]
 
     def str(self, index):
-        check_nonneg(index)
-        return ord(self.string[index])
+        """NOT_RPYTHON: Must be overridden in a concrete subclass.
+        The line below is used to generate a translation-time crash
+        if there is a call to str() that is indirect.  All calls must
+        be direct for performance reasons; you need to specialize the
+        caller with @specializectx or @specializectxmethod."""
+        raise NotImplementedError
 
+    @specializectx
     def lowstr(self, index):
         c = self.str(index)
         return rsre_char.getlower(c, self.flags)
@@ -88,7 +124,7 @@
         return self.match_marks_flat
 
     def span(self, groupnum=0):
-        # compatibility
+        "NOT_RPYTHON"   # compatibility
         fmarks = self.flatten_marks()
         groupnum *= 2
         if groupnum >= len(fmarks):
@@ -96,31 +132,40 @@
         return (fmarks[groupnum], fmarks[groupnum+1])
 
     def group(self, groupnum=0):
-        # compatibility
+        "NOT_RPYTHON"   # compatibility
         frm, to = self.span(groupnum)
         if 0 <= frm <= to:
-            return self.string[frm:to]
+            return self._string[frm:to]
         else:
             return None
 
-    def at_boundary(self, ptr, word_checker):
-        if self.end == 0:
-            return False
-        prevptr = ptr - 1
-        that = prevptr >= 0 and word_checker(self.str(prevptr))
-        this = ptr < self.end and word_checker(self.str(ptr))
-        return this != that
-    at_boundary._annspecialcase_ = 'specialize:arg(2)'
+class StrMatchContext(AbstractMatchContext):
+    """Concrete subclass for matching in a plain string."""
 
-    def at_non_boundary(self, ptr, word_checker):
-        if self.end == 0:
-            return False
-        prevptr = ptr - 1
-        that = prevptr >= 0 and word_checker(self.str(prevptr))
-        this = ptr < self.end and word_checker(self.str(ptr))
-        return this == that
-    at_non_boundary._annspecialcase_ = 'specialize:arg(2)'
+    def __init__(self, pattern, string, match_start, end, flags):
+        if end > len(string):
+            end = len(string)
+        AbstractMatchContext.__init__(self, pattern, match_start, end, flags)
+        self._string = string
 
+    def str(self, index):
+        check_nonneg(index)
+        return ord(self._string[index])
+
+class UnicodeMatchContext(AbstractMatchContext):
+    """Concrete subclass for matching in a unicode string."""
+
+    def __init__(self, pattern, unicodestr, match_start, end, flags):
+        if end > len(unicodestr):
+            end = len(unicodestr)
+        AbstractMatchContext.__init__(self, pattern, match_start, end, flags)
+        self._unicodestr = unicodestr
+
+    def str(self, index):
+        check_nonneg(index)
+        return ord(self._unicodestr[index])
+
+# ____________________________________________________________
 
 class Mark(object):
     _immutable_ = True
@@ -137,6 +182,7 @@
         mark = mark.prev
     return -1
 
+# ____________________________________________________________
 
 class MatchResult(object):
     subresult = None
@@ -165,7 +211,7 @@
     def find_first_result(self, ctx):
         ppos = self.ppos
         while ctx.pat(ppos):
-            result = sre_match(ctx, ppos + 1, self.start_ptr, self.start_marks)
+            result = ctx.sre_match(ppos + 1, self.start_ptr, self.start_marks)
             ppos += ctx.pat(ppos)
             if result is not None:
                 self.subresult = result
@@ -184,7 +230,7 @@
     def find_first_result(self, ctx):
         ptr = self.start_ptr
         while ptr >= self.minptr:
-            result = sre_match(ctx, self.nextppos, ptr, self.start_marks)
+            result = ctx.sre_match(self.nextppos, ptr, self.start_marks)
             ptr -= 1
             if result is not None:
                 self.subresult = result
@@ -205,19 +251,19 @@
     def find_first_result(self, ctx):
         ptr = self.start_ptr
         while ptr <= self.maxptr:
-            result = sre_match(ctx, self.nextppos, ptr, self.start_marks)
+            result = ctx.sre_match(self.nextppos, ptr, self.start_marks)
             if result is not None:
                 self.subresult = result
                 self.start_ptr = ptr
                 return self
-            ptr1 = find_repetition_end(ctx, self.ppos3, ptr, 1)
+            ptr1 = ctx.find_repetition_end(self.ppos3, ptr, 1)
             if ptr1 == ptr:
                 break
             ptr = ptr1
 
     def find_next_result(self, ctx):
         ptr = self.start_ptr
-        ptr1 = find_repetition_end(ctx, self.ppos3, ptr, 1)
+        ptr1 = ctx.find_repetition_end(self.ppos3, ptr, 1)
         if ptr1 == ptr:
             return
         self.start_ptr = ptr1
@@ -243,7 +289,7 @@
 class MaxUntilMatchResult(AbstractUntilMatchResult):
 
     def find_first_result(self, ctx):
-        enum = sre_match(ctx, self.ppos + 3, self.cur_ptr, self.cur_marks)
+        enum = ctx.sre_match(self.ppos + 3, self.cur_ptr, self.cur_marks)
         return self.search_next(ctx, enum, resume=False)
 
     def find_next_result(self, ctx):
@@ -268,7 +314,7 @@
                     # 'item' no longer matches.
                     if not resume and self.num_pending >= min:
                         # try to match 'tail' if we have enough 'item'
-                        result = sre_match(ctx, self.tailppos, ptr, marks)
+                        result = ctx.sre_match(self.tailppos, ptr, marks)
                         if result is not None:
                             self.subresult = result
                             self.cur_ptr = ptr
@@ -286,7 +332,7 @@
             #
             if max == 65535 or self.num_pending < max:
                 # try to match one more 'item'
-                enum = sre_match(ctx, ppos + 3, ptr, marks)
+                enum = ctx.sre_match(ppos + 3, ptr, marks)
             else:
                 enum = None    # 'max' reached, no more matches
 
@@ -307,7 +353,7 @@
         while True:
             # try to match 'tail' if we have enough 'item'
             if not resume and self.num_pending >= min:
-                result = sre_match(ctx, self.tailppos, ptr, marks)
+                result = ctx.sre_match(self.tailppos, ptr, marks)
                 if result is not None:
                     self.subresult = result
                     self.cur_ptr = ptr
@@ -317,7 +363,7 @@
 
             if max == 65535 or self.num_pending < max:
                 # try to match one more 'item'
-                enum = sre_match(ctx, ppos + 3, ptr, marks)
+                enum = ctx.sre_match(ppos + 3, ptr, marks)
             else:
                 enum = None    # 'max' reached, no more matches
 
@@ -341,6 +387,7 @@
 
 # ____________________________________________________________
 
+@specializectxmethod
 @jit.unroll_safe      # it's safe to unroll the main 'while' loop:
                       # 'ppos' is only ever incremented in this function
 def sre_match(ctx, ppos, ptr, marks):
@@ -380,7 +427,7 @@
             # assert subpattern
             # <ASSERT> <0=skip> <1=back> <pattern>
             ptr1 = ptr - ctx.pat(ppos+1)
-            if ptr1 < 0 or sre_match(ctx, ppos + 2, ptr1, marks) is None:
+            if ptr1 < 0 or ctx.sre_match(ppos + 2, ptr1, marks) is None:
                 return
             marks = ctx.match_marks
             ppos += ctx.pat(ppos)
@@ -389,7 +436,7 @@
             # assert not subpattern
             # <ASSERT_NOT> <0=skip> <1=back> <pattern>
             ptr1 = ptr - ctx.pat(ppos+1)
-            if ptr1 >= 0 and sre_match(ctx, ppos + 2, ptr1, marks) is not None:
+            if ptr1 >= 0 and ctx.sre_match(ppos + 2, ptr1, marks) is not None:
                 return
             ppos += ctx.pat(ppos)
 
@@ -548,7 +595,7 @@
             minptr = start + ctx.pat(ppos+1)
             if minptr > ctx.end:
                 return    # cannot match
-            ptr = find_repetition_end(ctx, ppos+3, start, ctx.pat(ppos+2))
+            ptr = ctx.find_repetition_end(ppos+3, start, ctx.pat(ppos+2))
             # when we arrive here, ptr points to the tail of the target
             # string.  check if the rest of the pattern matches,
             # and backtrack if not.
@@ -570,7 +617,7 @@
                 if minptr > ctx.end:
                     return   # cannot match
                 # count using pattern min as the maximum
-                ptr = find_repetition_end(ctx, ppos+3, ptr, min)
+                ptr = ctx.find_repetition_end(ppos+3, ptr, min)
                 if ptr < minptr:
                     return   # did not match minimum number of times
 
@@ -598,6 +645,7 @@
     length = endptr - startptr     # < 0 if endptr < startptr (or if endptr=-1)
     return startptr, length
 
+@specializectx
 def match_repeated(ctx, ptr, oldptr, length):
     if ptr + length > ctx.end:
         return False
@@ -606,6 +654,7 @@
             return False
     return True
 
+@specializectx
 def match_repeated_ignore(ctx, ptr, oldptr, length):
     if ptr + length > ctx.end:
         return False
@@ -614,6 +663,7 @@
             return False
     return True
 
+@specializectxmethod
 def find_repetition_end(ctx, ppos, ptr, maxcount):
     end = ctx.end
     # adjust end
@@ -634,12 +684,14 @@
                                                                       end,ppos)
     raise NotImplementedError("rsre.find_repetition_end[%d]" % op)
 
+@specializectx
 def fre_ANY(ctx, ptr, end):
     # repeated dot wildcard.
     while ptr < end and not rsre_char.is_linebreak(ctx.str(ptr)):
         ptr += 1
     return ptr
 
+@specializectx
 def fre_IN(ctx, ptr, end, ppos):
     # repeated set
     while ptr < end and rsre_char.check_charset(ctx.pattern, ppos+2,
@@ -647,6 +699,7 @@
         ptr += 1
     return ptr
 
+@specializectx
 def fre_IN_IGNORE(ctx, ptr, end, ppos):
     # repeated set
     while ptr < end and rsre_char.check_charset(ctx.pattern, ppos+2,
@@ -654,24 +707,28 @@
         ptr += 1
     return ptr
 
+@specializectx
 def fre_LITERAL(ctx, ptr, end, ppos):
     chr = ctx.pat(ppos+1)
     while ptr < end and ctx.str(ptr) == chr:
         ptr += 1
     return ptr
 
+@specializectx
 def fre_LITERAL_IGNORE(ctx, ptr, end, ppos):
     chr = ctx.pat(ppos+1)
     while ptr < end and ctx.lowstr(ptr) == chr:
         ptr += 1
     return ptr
 
+@specializectx
 def fre_NOT_LITERAL(ctx, ptr, end, ppos):
     chr = ctx.pat(ppos+1)
     while ptr < end and ctx.str(ptr) != chr:
         ptr += 1
     return ptr
 
+@specializectx
 def fre_NOT_LITERAL_IGNORE(ctx, ptr, end, ppos):
     chr = ctx.pat(ppos+1)
     while ptr < end and ctx.lowstr(ptr) != chr:
@@ -693,6 +750,7 @@
 AT_UNI_BOUNDARY = 10
 AT_UNI_NON_BOUNDARY = 11
 
+@specializectx
 def sre_at(ctx, atcode, ptr):
     if (atcode == AT_BEGINNING or
         atcode == AT_BEGINNING_STRING):
@@ -703,10 +761,10 @@
         return prevptr < 0 or rsre_char.is_linebreak(ctx.str(prevptr))
 
     elif atcode == AT_BOUNDARY:
-        return ctx.at_boundary(ptr, rsre_char.is_word)
+        return at_boundary(ctx, ptr, rsre_char.is_word)
 
     elif atcode == AT_NON_BOUNDARY:
-        return ctx.at_non_boundary(ptr, rsre_char.is_word)
+        return at_non_boundary(ctx, ptr, rsre_char.is_word)
 
     elif atcode == AT_END:
         remaining_chars = ctx.end - ptr
@@ -720,47 +778,79 @@
         return ptr == ctx.end
 
     elif atcode == AT_LOC_BOUNDARY:
-        return ctx.at_boundary(ptr, rsre_char.is_loc_word)
+        return at_boundary(ctx, ptr, rsre_char.is_loc_word)
 
     elif atcode == AT_LOC_NON_BOUNDARY:
-        return ctx.at_non_boundary(ptr, rsre_char.is_loc_word)
+        return at_non_boundary(ctx, ptr, rsre_char.is_loc_word)
 
     elif atcode == AT_UNI_BOUNDARY:
-        return ctx.at_boundary(ptr, rsre_char.is_uni_word)
+        return at_boundary(ctx, ptr, rsre_char.is_uni_word)
 
     elif atcode == AT_UNI_NON_BOUNDARY:
-        return ctx.at_non_boundary(ptr, rsre_char.is_uni_word)
+        return at_non_boundary(ctx, ptr, rsre_char.is_uni_word)
 
     return False
 
+@specializectx
+def at_boundary(ctx, ptr, word_checker):
+    if ctx.end == 0:
+        return False
+    prevptr = ptr - 1
+    that = prevptr >= 0 and word_checker(ctx.str(prevptr))
+    this = ptr < ctx.end and word_checker(ctx.str(ptr))
+    return this != that
+
+@specializectx
+def at_non_boundary(ctx, ptr, word_checker):
+    if ctx.end == 0:
+        return False
+    prevptr = ptr - 1
+    that = prevptr >= 0 and word_checker(ctx.str(prevptr))
+    this = ptr < ctx.end and word_checker(ctx.str(ptr))
+    return this == that
+
 # ____________________________________________________________
 
 def match(pattern, string, start=0, end=sys.maxint, flags=0):
-    if start < 0: start = 0
-    if end < start: return None
-    ctx = MatchContext(pattern, string, start, end, flags)
-    if sre_match(ctx, 0, start, None) is not None:
+    ctx = StrMatchContext(pattern, string, start, end, flags)
+    if match_context(ctx) is not None:
         return ctx
     return None
 
 def search(pattern, string, start=0, end=sys.maxint, flags=0):
-    if start < 0: start = 0
-    if end < start: return None
-    ctx = MatchContext(pattern, string, start, end, flags)
+    ctx = StrMatchContext(pattern, string, start, end, flags)
+    if search_context(ctx) is not None:
+        return ctx
+    return None
+
+@specializectx
+def match_context(ctx):
+    if ctx.end < ctx.match_start:
+        return None
+    if ctx.sre_match(0, ctx.match_start, None) is not None:
+        return ctx
+    return None
+
+@specializectx
+def search_context(ctx):
+    if ctx.end < ctx.match_start:
+        return None
     if ctx.pat(0) == OPCODE_INFO:
         if ctx.pat(2) & rsre_char.SRE_INFO_PREFIX and ctx.pat(5) > 1:
             return fast_search(ctx)
     return regular_search(ctx)
 
+@specializectx
 def regular_search(ctx):
     start = ctx.match_start
     while start <= ctx.end:
-        if sre_match(ctx, 0, start, None) is not None:
+        if ctx.sre_match(0, start, None) is not None:
             ctx.match_start = start
             return ctx
         start += 1
     return None
 
+@specializectx
 def fast_search(ctx):
     # skips forward in a string as fast as possible using information from
     # an optimization info block
@@ -800,7 +890,7 @@
                         ctx.match_marks = None
                         return ctx
                     ppos = pattern_offset + 2 * prefix_skip
-                    if sre_match(ctx, ppos, ptr, None) is not None:
+                    if ctx.sre_match(ppos, ptr, None) is not None:
                         ctx.match_start = start
                         return ctx
                     i = ctx.pat(overlap_offset + i)

Modified: pypy/branch/rsre2/pypy/rlib/rsre/rsre_re.py
==============================================================================
--- pypy/branch/rsre2/pypy/rlib/rsre/rsre_re.py	(original)
+++ pypy/branch/rsre2/pypy/rlib/rsre/rsre_re.py	Fri Aug  6 18:14:53 2010
@@ -110,7 +110,7 @@
         for group in groups:
             frm, to = self.span(group)
             if 0 <= frm <= to:
-                result.append(self._ctx.string[frm:to])
+                result.append(self._ctx._string[frm:to])
             else:
                 result.append(None)
         if len(result) > 1:
@@ -160,7 +160,7 @@
 
     @property
     def string(self):
-        return self._ctx.string
+        return self._ctx._string
 
     @property
     def pos(self):

Modified: pypy/branch/rsre2/pypy/rlib/rsre/test/test_zinterp.py
==============================================================================
--- pypy/branch/rsre2/pypy/rlib/rsre/test/test_zinterp.py	(original)
+++ pypy/branch/rsre2/pypy/rlib/rsre/test/test_zinterp.py	Fri Aug  6 18:14:53 2010
@@ -8,6 +8,12 @@
     pattern = [n] * n
     string = chr(n) * n
     rsre_core.search(pattern, string)
+    #
+    unicodestr = unichr(n) * n
+    ctx = rsre_core.UnicodeMatchContext(pattern, unicodestr,
+                                        0, len(unicodestr), 0)
+    rsre_core.search_context(ctx)
+    #
     return 0
 
 



More information about the Pypy-commit mailing list