[pypy-commit] pypy py3.5: merge default
cfbolz
pypy.commits at gmail.com
Thu Mar 29 06:03:38 EDT 2018
Author: Carl Friedrich Bolz-Tereick <cfbolz at gmx.de>
Branch: py3.5
Changeset: r94170:ffbcc29df485
Date: 2018-03-29 12:03 +0200
http://bitbucket.org/pypy/pypy/changeset/ffbcc29df485/
Log: merge default
diff --git a/pypy/doc/whatsnew-head.rst b/pypy/doc/whatsnew-head.rst
--- a/pypy/doc/whatsnew-head.rst
+++ b/pypy/doc/whatsnew-head.rst
@@ -72,3 +72,14 @@
Optimize `Py*_Check` for `Bool`, `Float`, `Set`. Also refactor and simplify
`W_PyCWrapperObject` which is used to call slots from the C-API, greatly
improving microbenchmarks in https://github.com/antocuni/cpyext-benchmarks
+
+
+.. branch: fix-sre-problems
+
+Fix two (unrelated) JIT bugs manifesting in the re module:
+
+- green fields are broken and were thus disabled, and their usage was removed
+ from the _sre implementation
+
+- in rare "trace is too long" situations, the JIT could change the program's
+ behaviour in arbitrary ways.
diff --git a/pypy/module/_cffi_backend/ccallback.py b/pypy/module/_cffi_backend/ccallback.py
--- a/pypy/module/_cffi_backend/ccallback.py
+++ b/pypy/module/_cffi_backend/ccallback.py
@@ -232,7 +232,9 @@
"different from the 'ffi.h' file seen at compile-time)")
def py_invoke(self, ll_res, ll_args):
+ key_pycode = self.key_pycode
jitdriver1.jit_merge_point(callback=self,
+ key_pycode=key_pycode,
ll_res=ll_res,
ll_args=ll_args)
self.do_invoke(ll_res, ll_args)
@@ -294,7 +296,7 @@
return 'cffi_callback ' + key_pycode.get_repr()
jitdriver1 = jit.JitDriver(name='cffi_callback',
- greens=['callback.key_pycode'],
+ greens=['key_pycode'],
reds=['ll_res', 'll_args', 'callback'],
get_printable_location=get_printable_location1)
diff --git a/pypy/module/_sre/interp_sre.py b/pypy/module/_sre/interp_sre.py
--- a/pypy/module/_sre/interp_sre.py
+++ b/pypy/module/_sre/interp_sre.py
@@ -76,15 +76,15 @@
w_import = space.getattr(space.builtin, space.newtext("__import__"))
return space.call_function(w_import, space.newtext("re"))
-def matchcontext(space, ctx):
+def matchcontext(space, ctx, pattern):
try:
- return rsre_core.match_context(ctx)
+ return rsre_core.match_context(ctx, pattern)
except rsre_core.Error as e:
raise OperationError(space.w_RuntimeError, space.newtext(e.msg))
-def searchcontext(space, ctx):
+def searchcontext(space, ctx, pattern):
try:
- return rsre_core.search_context(ctx)
+ return rsre_core.search_context(ctx, pattern)
except rsre_core.Error as e:
raise OperationError(space.w_RuntimeError, space.newtext(e.msg))
@@ -189,7 +189,7 @@
raise oefmt(space.w_TypeError,
"can't use a bytes pattern on a string-like "
"object")
- return rsre_core.UnicodeMatchContext(self.code, unicodestr,
+ return rsre_core.UnicodeMatchContext(unicodestr,
pos, endpos, flags)
else:
if self.is_known_unicode():
@@ -197,10 +197,10 @@
"can't use a string pattern on a bytes-like "
"object")
if string is not None:
- return rsre_core.StrMatchContext(self.code, string,
+ return rsre_core.StrMatchContext(string,
pos, endpos, flags)
else:
- return rsre_core.BufMatchContext(self.code, buf,
+ return rsre_core.BufMatchContext(buf,
pos, endpos, flags)
def getmatch(self, ctx, found):
@@ -212,18 +212,18 @@
@unwrap_spec(pos=int, endpos=int)
def match_w(self, w_string, pos=0, endpos=sys.maxint):
ctx = self.make_ctx(w_string, pos, endpos)
- return self.getmatch(ctx, matchcontext(self.space, ctx))
+ return self.getmatch(ctx, matchcontext(self.space, ctx, self.code))
@unwrap_spec(pos=int, endpos=int)
def fullmatch_w(self, w_string, pos=0, endpos=sys.maxint):
ctx = self.make_ctx(w_string, pos, endpos)
ctx.fullmatch_only = True
- return self.getmatch(ctx, matchcontext(self.space, ctx))
+ return self.getmatch(ctx, matchcontext(self.space, ctx, self.code))
@unwrap_spec(pos=int, endpos=int)
def search_w(self, w_string, pos=0, endpos=sys.maxint):
ctx = self.make_ctx(w_string, pos, endpos)
- return self.getmatch(ctx, searchcontext(self.space, ctx))
+ return self.getmatch(ctx, searchcontext(self.space, ctx, self.code))
@unwrap_spec(pos=int, endpos=int)
def findall_w(self, w_string, pos=0, endpos=sys.maxint):
@@ -231,7 +231,7 @@
matchlist_w = []
ctx = self.make_ctx(w_string, pos, endpos)
while ctx.match_start <= ctx.end:
- if not searchcontext(space, ctx):
+ if not searchcontext(space, ctx, self.code):
break
num_groups = self.num_groups
w_emptystr = space.newtext("")
@@ -256,14 +256,15 @@
# this also works as the implementation of the undocumented
# scanner() method.
ctx = self.make_ctx(w_string, pos, endpos)
- scanner = W_SRE_Scanner(self, ctx)
+ scanner = W_SRE_Scanner(self, ctx, self.code)
return scanner
@unwrap_spec(maxsplit=int)
def split_w(self, w_string, maxsplit=0):
space = self.space
- if self.code[0] != rsre_core.OPCODE_INFO or self.code[3] == 0:
- if self.code[0] == rsre_core.OPCODE_INFO and self.code[4] == 0:
+
+ if self.code.pattern[0] != rsre_core.OPCODE_INFO or self.code.pattern[3] == 0:
+ if self.code.pattern[0] == rsre_core.OPCODE_INFO and self.code.pattern[4] == 0:
raise oefmt(space.w_ValueError,
"split() requires a non-empty pattern match.")
space.warn(
@@ -275,7 +276,7 @@
last = 0
ctx = self.make_ctx(w_string)
while not maxsplit or n < maxsplit:
- if not searchcontext(space, ctx):
+ if not searchcontext(space, ctx, self.code):
break
if ctx.match_start == ctx.match_end: # zero-width match
if ctx.match_start == ctx.end: # or end of string
@@ -356,8 +357,8 @@
else:
sublist_w = []
n = last_pos = 0
+ pattern = self.code
while not count or n < count:
- pattern = ctx.pattern
sub_jitdriver.jit_merge_point(
self=self,
use_builder=use_builder,
@@ -374,7 +375,7 @@
n=n, last_pos=last_pos, sublist_w=sublist_w
)
space = self.space
- if not searchcontext(space, ctx):
+ if not searchcontext(space, ctx, pattern):
break
if last_pos < ctx.match_start:
_sub_append_slice(
@@ -474,7 +475,11 @@
space.readbuf_w(w_pattern)
srepat.w_pattern = w_pattern # the original uncompiled pattern
srepat.flags = flags
- srepat.code = code
+ # note: we assume that the app-level is caching SRE_Pattern objects,
+ # so that we don't need to do it here. Creating new SRE_Pattern
+ # objects all the time would be bad for the JIT, which relies on the
+ # identity of the CompiledPattern() object.
+ srepat.code = rsre_core.CompiledPattern(code)
srepat.num_groups = groups
srepat.w_groupindex = w_groupindex
srepat.w_indexgroup = w_indexgroup
@@ -711,10 +716,11 @@
# Our version is also directly iterable, to make finditer() easier.
class W_SRE_Scanner(W_Root):
- def __init__(self, pattern, ctx):
+ def __init__(self, pattern, ctx, code):
self.space = pattern.space
self.srepat = pattern
self.ctx = ctx
+ self.code = code
# 'self.ctx' is always a fresh context in which no searching
# or matching succeeded so far.
@@ -724,19 +730,19 @@
def next_w(self):
if self.ctx.match_start > self.ctx.end:
raise OperationError(self.space.w_StopIteration, self.space.w_None)
- if not searchcontext(self.space, self.ctx):
+ if not searchcontext(self.space, self.ctx, self.code):
raise OperationError(self.space.w_StopIteration, self.space.w_None)
return self.getmatch(True)
def match_w(self):
if self.ctx.match_start > self.ctx.end:
return self.space.w_None
- return self.getmatch(matchcontext(self.space, self.ctx))
+ return self.getmatch(matchcontext(self.space, self.ctx, self.code))
def search_w(self):
if self.ctx.match_start > self.ctx.end:
return self.space.w_None
- return self.getmatch(searchcontext(self.space, self.ctx))
+ return self.getmatch(searchcontext(self.space, self.ctx, self.code))
def getmatch(self, found):
if found:
diff --git a/rpython/jit/metainterp/history.py b/rpython/jit/metainterp/history.py
--- a/rpython/jit/metainterp/history.py
+++ b/rpython/jit/metainterp/history.py
@@ -701,6 +701,9 @@
def length(self):
return self.trace._count - len(self.trace.inputargs)
+ def trace_tag_overflow(self):
+ return self.trace.tag_overflow
+
def get_trace_position(self):
return self.trace.cut_point()
diff --git a/rpython/jit/metainterp/opencoder.py b/rpython/jit/metainterp/opencoder.py
--- a/rpython/jit/metainterp/opencoder.py
+++ b/rpython/jit/metainterp/opencoder.py
@@ -49,13 +49,6 @@
way up to lltype.Signed for indexes everywhere
"""
-def frontend_tag_overflow():
- # Minor abstraction leak: raise directly the right exception
- # expected by the rest of the machinery
- from rpython.jit.metainterp import history
- from rpython.rlib.jit import Counters
- raise history.SwitchToBlackhole(Counters.ABORT_TOO_LONG)
-
class BaseTrace(object):
pass
@@ -293,6 +286,7 @@
self._start = len(inputargs)
self._pos = self._start
self.inputargs = inputargs
+ self.tag_overflow = False
def append(self, v):
model = get_model(self)
@@ -300,12 +294,14 @@
# grow by 2X
self._ops = self._ops + [rffi.cast(model.STORAGE_TP, 0)] * len(self._ops)
if not model.MIN_VALUE <= v <= model.MAX_VALUE:
- raise frontend_tag_overflow()
+ v = 0 # broken value, but that's fine, tracing will stop soon
+ self.tag_overflow = True
self._ops[self._pos] = rffi.cast(model.STORAGE_TP, v)
self._pos += 1
- def done(self):
+ def tracing_done(self):
from rpython.rlib.debug import debug_start, debug_stop, debug_print
+ assert not self.tag_overflow
self._bigints_dict = {}
self._refs_dict = llhelper.new_ref_dict_3()
@@ -317,8 +313,6 @@
debug_print(" ref consts: " + str(self._consts_ptr) + " " + str(len(self._refs)))
debug_print(" descrs: " + str(len(self._descrs)))
debug_stop("jit-trace-done")
- return 0 # completely different than TraceIter.done, but we have to
- # share the base class
def length(self):
return self._pos
@@ -379,6 +373,7 @@
def record_op(self, opnum, argboxes, descr=None):
pos = self._index
+ old_pos = self._pos
self.append(opnum)
expected_arity = oparity[opnum]
if expected_arity == -1:
@@ -397,6 +392,10 @@
self._count += 1
if opclasses[opnum].type != 'v':
self._index += 1
+ if self.tag_overflow:
+ # potentially a broken op is left behind
+ # clean it up
+ self._pos = old_pos
return pos
def _encode_descr(self, descr):
@@ -424,10 +423,11 @@
vref_array = self._list_of_boxes(vref_boxes)
s = TopSnapshot(combine_uint(jitcode.index, pc), array, vable_array,
vref_array)
- assert rffi.cast(lltype.Signed, self._ops[self._pos - 1]) == 0
# guards have no descr
self._snapshots.append(s)
- self._ops[self._pos - 1] = rffi.cast(get_model(self).STORAGE_TP, len(self._snapshots) - 1)
+ if not self.tag_overflow: # otherwise we're broken anyway
+ assert rffi.cast(lltype.Signed, self._ops[self._pos - 1]) == 0
+ self._ops[self._pos - 1] = rffi.cast(get_model(self).STORAGE_TP, len(self._snapshots) - 1)
return s
def create_empty_top_snapshot(self, vable_boxes, vref_boxes):
@@ -436,10 +436,11 @@
vref_array = self._list_of_boxes(vref_boxes)
s = TopSnapshot(combine_uint(2**16 - 1, 0), [], vable_array,
vref_array)
- assert rffi.cast(lltype.Signed, self._ops[self._pos - 1]) == 0
# guards have no descr
self._snapshots.append(s)
- self._ops[self._pos - 1] = rffi.cast(get_model(self).STORAGE_TP, len(self._snapshots) - 1)
+ if not self.tag_overflow: # otherwise we're broken anyway
+ assert rffi.cast(lltype.Signed, self._ops[self._pos - 1]) == 0
+ self._ops[self._pos - 1] = rffi.cast(get_model(self).STORAGE_TP, len(self._snapshots) - 1)
return s
def create_snapshot(self, jitcode, pc, frame, flag):
diff --git a/rpython/jit/metainterp/pyjitpl.py b/rpython/jit/metainterp/pyjitpl.py
--- a/rpython/jit/metainterp/pyjitpl.py
+++ b/rpython/jit/metainterp/pyjitpl.py
@@ -2384,9 +2384,9 @@
def blackhole_if_trace_too_long(self):
warmrunnerstate = self.jitdriver_sd.warmstate
- if self.history.length() > warmrunnerstate.trace_limit:
+ if (self.history.length() > warmrunnerstate.trace_limit or
+ self.history.trace_tag_overflow()):
jd_sd, greenkey_of_huge_function = self.find_biggest_function()
- self.history.trace.done()
self.staticdata.stats.record_aborted(greenkey_of_huge_function)
self.portal_trace_positions = None
if greenkey_of_huge_function is not None:
@@ -2689,7 +2689,9 @@
try_disabling_unroll=False, exported_state=None):
num_green_args = self.jitdriver_sd.num_green_args
greenkey = original_boxes[:num_green_args]
- self.history.trace.done()
+ if self.history.trace_tag_overflow():
+ raise SwitchToBlackhole(Counters.ABORT_TOO_LONG)
+ self.history.trace.tracing_done()
if not self.partial_trace:
ptoken = self.get_procedure_token(greenkey)
if ptoken is not None and ptoken.target_tokens is not None:
@@ -2742,7 +2744,9 @@
self.history.record(rop.JUMP, live_arg_boxes[num_green_args:], None,
descr=target_jitcell_token)
self.history.ends_with_jump = True
- self.history.trace.done()
+ if self.history.trace_tag_overflow():
+ raise SwitchToBlackhole(Counters.ABORT_TOO_LONG)
+ self.history.trace.tracing_done()
try:
target_token = compile.compile_trace(self, self.resumekey,
live_arg_boxes[num_green_args:])
@@ -2776,7 +2780,9 @@
assert False
# FIXME: can we call compile_trace?
self.history.record(rop.FINISH, exits, None, descr=token)
- self.history.trace.done()
+ if self.history.trace_tag_overflow():
+ raise SwitchToBlackhole(Counters.ABORT_TOO_LONG)
+ self.history.trace.tracing_done()
target_token = compile.compile_trace(self, self.resumekey, exits)
if target_token is not token:
compile.giveup()
@@ -2802,7 +2808,9 @@
sd = self.staticdata
token = sd.exit_frame_with_exception_descr_ref
self.history.record(rop.FINISH, [valuebox], None, descr=token)
- self.history.trace.done()
+ if self.history.trace_tag_overflow():
+ raise SwitchToBlackhole(Counters.ABORT_TOO_LONG)
+ self.history.trace.tracing_done()
target_token = compile.compile_trace(self, self.resumekey, [valuebox])
if target_token is not token:
compile.giveup()
diff --git a/rpython/jit/metainterp/test/test_ajit.py b/rpython/jit/metainterp/test/test_ajit.py
--- a/rpython/jit/metainterp/test/test_ajit.py
+++ b/rpython/jit/metainterp/test/test_ajit.py
@@ -4661,3 +4661,36 @@
f() # finishes
self.meta_interp(f, [])
+
+ def test_trace_too_long_bug(self):
+ driver = JitDriver(greens=[], reds=['i'])
+ @unroll_safe
+ def match(s):
+ l = len(s)
+ p = 0
+ for i in range(2500): # produces too long trace
+ c = s[p]
+ if c != 'a':
+ return False
+ p += 1
+ if p >= l:
+ return True
+ c = s[p]
+ if c != '\n':
+ p += 1
+ if p >= l:
+ return True
+ else:
+ return False
+ return True
+
+ def f(i):
+ while i > 0:
+ driver.jit_merge_point(i=i)
+ match('a' * (500 * i))
+ i -= 1
+ return i
+
+ res = self.meta_interp(f, [10])
+ assert res == f(10)
+
diff --git a/rpython/jit/metainterp/test/test_greenfield.py b/rpython/jit/metainterp/test/test_greenfield.py
--- a/rpython/jit/metainterp/test/test_greenfield.py
+++ b/rpython/jit/metainterp/test/test_greenfield.py
@@ -1,6 +1,17 @@
+import pytest
from rpython.jit.metainterp.test.support import LLJitMixin
from rpython.rlib.jit import JitDriver, assert_green
+pytest.skip("this feature is disabled at the moment!")
+
+# note why it is disabled: before d721da4573ad
+# there was a failing assert when inlining python -> sre -> python:
+# https://bitbucket.org/pypy/pypy/issues/2775/
+# this shows that the interaction of greenfields and virtualizables is broken,
+# because greenfields use MetaInterp.virtualizable_boxes, which confuses
+# MetaInterp._nonstandard_virtualizable somehow (and makes no sense
+# conceptually anyway). to fix greenfields, the two mechanisms would have to be
+# disentangled.
class GreenFieldsTests:
diff --git a/rpython/jit/metainterp/test/test_opencoder.py b/rpython/jit/metainterp/test/test_opencoder.py
--- a/rpython/jit/metainterp/test/test_opencoder.py
+++ b/rpython/jit/metainterp/test/test_opencoder.py
@@ -209,5 +209,8 @@
def test_tag_overflow(self):
t = Trace([], metainterp_sd)
i0 = FakeOp(100000)
- py.test.raises(SwitchToBlackhole, t.record_op, rop.FINISH, [i0])
- assert t.unpack() == ([], [])
+ # if we overflow, we can keep recording
+ for i in range(10):
+ t.record_op(rop.FINISH, [i0])
+ assert t.unpack() == ([], [])
+ assert t.tag_overflow
diff --git a/rpython/rlib/jit.py b/rpython/rlib/jit.py
--- a/rpython/rlib/jit.py
+++ b/rpython/rlib/jit.py
@@ -653,6 +653,9 @@
self._make_extregistryentries()
assert get_jitcell_at is None, "get_jitcell_at no longer used"
assert set_jitcell_at is None, "set_jitcell_at no longer used"
+ for green in self.greens:
+ if "." in green:
+ raise ValueError("green fields are buggy! if you need them fixed, please talk to us")
self.get_printable_location = get_printable_location
self.get_location = get_location
self.has_unique_id = (get_unique_id is not None)
diff --git a/rpython/rlib/rsre/rpy/_sre.py b/rpython/rlib/rsre/rpy/_sre.py
--- a/rpython/rlib/rsre/rpy/_sre.py
+++ b/rpython/rlib/rsre/rpy/_sre.py
@@ -1,4 +1,4 @@
-from rpython.rlib.rsre import rsre_char
+from rpython.rlib.rsre import rsre_char, rsre_core
from rpython.rlib.rarithmetic import intmask
VERSION = "2.7.6"
@@ -12,7 +12,7 @@
pass
def compile(pattern, flags, code, *args):
- raise GotIt([intmask(i) for i in code], flags, args)
+ raise GotIt(rsre_core.CompiledPattern([intmask(i) for i in code]), flags, args)
def get_code(regexp, flags=0, allargs=False):
diff --git a/rpython/rlib/rsre/rsre_char.py b/rpython/rlib/rsre/rsre_char.py
--- a/rpython/rlib/rsre/rsre_char.py
+++ b/rpython/rlib/rsre/rsre_char.py
@@ -152,17 +152,16 @@
##### Charset evaluation
@jit.unroll_safe
-def check_charset(ctx, ppos, char_code):
+def check_charset(ctx, pattern, ppos, char_code):
"""Checks whether a character matches set of arbitrary length.
The set starts at pattern[ppos]."""
negated = False
result = False
- pattern = ctx.pattern
while True:
- opcode = pattern[ppos]
+ opcode = pattern.pattern[ppos]
for i, function in set_dispatch_unroll:
if opcode == i:
- newresult, ppos = function(ctx, ppos, char_code)
+ newresult, ppos = function(ctx, pattern, ppos, char_code)
result |= newresult
break
else:
@@ -177,50 +176,44 @@
return not result
return result
-def set_literal(ctx, index, char_code):
+def set_literal(ctx, pattern, index, char_code):
# <LITERAL> <code>
- pat = ctx.pattern
- match = pat[index+1] == char_code
+ match = pattern.pattern[index+1] == char_code
return match, index + 2
-def set_category(ctx, index, char_code):
+def set_category(ctx, pattern, index, char_code):
# <CATEGORY> <code>
- pat = ctx.pattern
- match = category_dispatch(pat[index+1], char_code)
+ match = category_dispatch(pattern.pattern[index+1], char_code)
return match, index + 2
-def set_charset(ctx, index, char_code):
+def set_charset(ctx, pattern, index, char_code):
# <CHARSET> <bitmap> (16 bits per code word)
- pat = ctx.pattern
if CODESIZE == 2:
match = char_code < 256 and \
- (pat[index+1+(char_code >> 4)] & (1 << (char_code & 15)))
+ (pattern.pattern[index+1+(char_code >> 4)] & (1 << (char_code & 15)))
return match, index + 17 # skip bitmap
else:
match = char_code < 256 and \
- (pat[index+1+(char_code >> 5)] & (1 << (char_code & 31)))
+ (pattern.pattern[index+1+(char_code >> 5)] & (1 << (char_code & 31)))
return match, index + 9 # skip bitmap
-def set_range(ctx, index, char_code):
+def set_range(ctx, pattern, index, char_code):
# <RANGE> <lower> <upper>
- pat = ctx.pattern
- match = int_between(pat[index+1], char_code, pat[index+2] + 1)
+ match = int_between(pattern.pattern[index+1], char_code, pattern.pattern[index+2] + 1)
return match, index + 3
-def set_range_ignore(ctx, index, char_code):
+def set_range_ignore(ctx, pattern, index, char_code):
# <RANGE_IGNORE> <lower> <upper>
# the char_code is already lower cased
- pat = ctx.pattern
- lower = pat[index + 1]
- upper = pat[index + 2]
+ lower = pattern.pattern[index + 1]
+ upper = pattern.pattern[index + 2]
match1 = int_between(lower, char_code, upper + 1)
match2 = int_between(lower, getupper(char_code, ctx.flags), upper + 1)
return match1 | match2, index + 3
-def set_bigcharset(ctx, index, char_code):
+def set_bigcharset(ctx, pattern, index, char_code):
# <BIGCHARSET> <blockcount> <256 blockindices> <blocks>
- pat = ctx.pattern
- count = pat[index+1]
+ count = pattern.pattern[index+1]
index += 2
if CODESIZE == 2:
@@ -238,7 +231,7 @@
return False, index
shift = 5
- block = pat[index + (char_code >> (shift + 5))]
+ block = pattern.pattern[index + (char_code >> (shift + 5))]
block_shift = char_code >> 5
if BIG_ENDIAN:
@@ -247,23 +240,22 @@
block = (block >> block_shift) & 0xFF
index += 256 / CODESIZE
- block_value = pat[index+(block * (32 / CODESIZE)
+ block_value = pattern.pattern[index+(block * (32 / CODESIZE)
+ ((char_code & 255) >> shift))]
match = (block_value & (1 << (char_code & ((8 * CODESIZE) - 1))))
index += count * (32 / CODESIZE) # skip blocks
return match, index
-def set_unicode_general_category(ctx, index, char_code):
+def set_unicode_general_category(ctx, pattern, index, char_code):
# Unicode "General category property code" (not used by Python).
- # A general category is two letters. 'pat[index+1]' contains both
+ # A general category is two letters. 'pattern.pattern[index+1]' contains both
# the first character, and the second character shifted by 8.
# http://en.wikipedia.org/wiki/Unicode_character_property#General_Category
# Also supports single-character categories, if the second character is 0.
# Negative matches are triggered by bit number 7.
assert unicodedb is not None
cat = unicodedb.category(char_code)
- pat = ctx.pattern
- category_code = pat[index + 1]
+ category_code = pattern.pattern[index + 1]
first_character = category_code & 0x7F
second_character = (category_code >> 8) & 0x7F
negative_match = category_code & 0x80
diff --git a/rpython/rlib/rsre/rsre_core.py b/rpython/rlib/rsre/rsre_core.py
--- a/rpython/rlib/rsre/rsre_core.py
+++ b/rpython/rlib/rsre/rsre_core.py
@@ -83,35 +83,19 @@
def __init__(self, msg):
self.msg = msg
-class AbstractMatchContext(object):
- """Abstract base class"""
- _immutable_fields_ = ['pattern[*]', 'flags', 'end']
- match_start = 0
- match_end = 0
- match_marks = None
- match_marks_flat = None
- fullmatch_only = False
- def __init__(self, pattern, match_start, end, flags):
- # 'match_start' and 'end' must be known to be non-negative
- # and they must not be more than len(string).
- check_nonneg(match_start)
- check_nonneg(end)
+class CompiledPattern(object):
+ _immutable_fields_ = ['pattern[*]']
+
+ def __init__(self, pattern):
self.pattern = pattern
- self.match_start = match_start
- self.end = end
- self.flags = flags
# check we don't get the old value of MAXREPEAT
# during the untranslated tests
if not we_are_translated():
assert 65535 not in pattern
- def reset(self, start):
- self.match_start = start
- self.match_marks = None
- self.match_marks_flat = None
-
def pat(self, index):
+ jit.promote(self)
check_nonneg(index)
result = self.pattern[index]
# Check that we only return non-negative integers from this helper.
@@ -121,6 +105,29 @@
assert result >= 0
return result
+class AbstractMatchContext(object):
+ """Abstract base class"""
+ _immutable_fields_ = ['flags', 'end']
+ match_start = 0
+ match_end = 0
+ match_marks = None
+ match_marks_flat = None
+ fullmatch_only = False
+
+ def __init__(self, match_start, end, flags):
+ # 'match_start' and 'end' must be known to be non-negative
+ # and they must not be more than len(string).
+ check_nonneg(match_start)
+ check_nonneg(end)
+ self.match_start = match_start
+ self.end = end
+ self.flags = flags
+
+ def reset(self, start):
+ self.match_start = start
+ self.match_marks = None
+ self.match_marks_flat = None
+
@not_rpython
def str(self, index):
"""Must be overridden in a concrete subclass.
@@ -183,8 +190,8 @@
_immutable_fields_ = ["_buffer"]
- def __init__(self, pattern, buf, match_start, end, flags):
- AbstractMatchContext.__init__(self, pattern, match_start, end, flags)
+ def __init__(self, buf, match_start, end, flags):
+ AbstractMatchContext.__init__(self, match_start, end, flags)
self._buffer = buf
def str(self, index):
@@ -196,7 +203,7 @@
return rsre_char.getlower(c, self.flags)
def fresh_copy(self, start):
- return BufMatchContext(self.pattern, self._buffer, start,
+ return BufMatchContext(self._buffer, start,
self.end, self.flags)
class StrMatchContext(AbstractMatchContext):
@@ -204,8 +211,8 @@
_immutable_fields_ = ["_string"]
- def __init__(self, pattern, string, match_start, end, flags):
- AbstractMatchContext.__init__(self, pattern, match_start, end, flags)
+ def __init__(self, string, match_start, end, flags):
+ AbstractMatchContext.__init__(self, match_start, end, flags)
self._string = string
if not we_are_translated() and isinstance(string, unicode):
self.flags |= rsre_char.SRE_FLAG_UNICODE # for rsre_re.py
@@ -219,7 +226,7 @@
return rsre_char.getlower(c, self.flags)
def fresh_copy(self, start):
- return StrMatchContext(self.pattern, self._string, start,
+ return StrMatchContext(self._string, start,
self.end, self.flags)
class UnicodeMatchContext(AbstractMatchContext):
@@ -227,8 +234,8 @@
_immutable_fields_ = ["_unicodestr"]
- def __init__(self, pattern, unicodestr, match_start, end, flags):
- AbstractMatchContext.__init__(self, pattern, match_start, end, flags)
+ def __init__(self, unicodestr, match_start, end, flags):
+ AbstractMatchContext.__init__(self, match_start, end, flags)
self._unicodestr = unicodestr
def str(self, index):
@@ -240,7 +247,7 @@
return rsre_char.getlower(c, self.flags)
def fresh_copy(self, start):
- return UnicodeMatchContext(self.pattern, self._unicodestr, start,
+ return UnicodeMatchContext(self._unicodestr, start,
self.end, self.flags)
# ____________________________________________________________
@@ -265,16 +272,16 @@
class MatchResult(object):
subresult = None
- def move_to_next_result(self, ctx):
+ def move_to_next_result(self, ctx, pattern):
# returns either 'self' or None
result = self.subresult
if result is None:
return
- if result.move_to_next_result(ctx):
+ if result.move_to_next_result(ctx, pattern):
return self
- return self.find_next_result(ctx)
+ return self.find_next_result(ctx, pattern)
- def find_next_result(self, ctx):
+ def find_next_result(self, ctx, pattern):
raise NotImplementedError
MATCHED_OK = MatchResult()
@@ -287,11 +294,11 @@
self.start_marks = marks
@jit.unroll_safe
- def find_first_result(self, ctx):
+ def find_first_result(self, ctx, pattern):
ppos = jit.hint(self.ppos, promote=True)
- while ctx.pat(ppos):
- result = sre_match(ctx, ppos + 1, self.start_ptr, self.start_marks)
- ppos += ctx.pat(ppos)
+ while pattern.pat(ppos):
+ result = sre_match(ctx, pattern, ppos + 1, self.start_ptr, self.start_marks)
+ ppos += pattern.pat(ppos)
if result is not None:
self.subresult = result
self.ppos = ppos
@@ -300,7 +307,7 @@
class RepeatOneMatchResult(MatchResult):
install_jitdriver('RepeatOne',
- greens=['nextppos', 'ctx.pattern'],
+ greens=['nextppos', 'pattern'],
reds=['ptr', 'self', 'ctx'],
debugprint=(1, 0)) # indices in 'greens'
@@ -310,13 +317,14 @@
self.start_ptr = ptr
self.start_marks = marks
- def find_first_result(self, ctx):
+ def find_first_result(self, ctx, pattern):
ptr = self.start_ptr
nextppos = self.nextppos
while ptr >= self.minptr:
ctx.jitdriver_RepeatOne.jit_merge_point(
- self=self, ptr=ptr, ctx=ctx, nextppos=nextppos)
- result = sre_match(ctx, nextppos, ptr, self.start_marks)
+ self=self, ptr=ptr, ctx=ctx, nextppos=nextppos,
+ pattern=pattern)
+ result = sre_match(ctx, pattern, nextppos, ptr, self.start_marks)
ptr -= 1
if result is not None:
self.subresult = result
@@ -327,7 +335,7 @@
class MinRepeatOneMatchResult(MatchResult):
install_jitdriver('MinRepeatOne',
- greens=['nextppos', 'ppos3', 'ctx.pattern'],
+ greens=['nextppos', 'ppos3', 'pattern'],
reds=['ptr', 'self', 'ctx'],
debugprint=(2, 0)) # indices in 'greens'
@@ -338,39 +346,40 @@
self.start_ptr = ptr
self.start_marks = marks
- def find_first_result(self, ctx):
+ def find_first_result(self, ctx, pattern):
ptr = self.start_ptr
nextppos = self.nextppos
ppos3 = self.ppos3
while ptr <= self.maxptr:
ctx.jitdriver_MinRepeatOne.jit_merge_point(
- self=self, ptr=ptr, ctx=ctx, nextppos=nextppos, ppos3=ppos3)
- result = sre_match(ctx, nextppos, ptr, self.start_marks)
+ self=self, ptr=ptr, ctx=ctx, nextppos=nextppos, ppos3=ppos3,
+ pattern=pattern)
+ result = sre_match(ctx, pattern, nextppos, ptr, self.start_marks)
if result is not None:
self.subresult = result
self.start_ptr = ptr
return self
- if not self.next_char_ok(ctx, ptr, ppos3):
+ if not self.next_char_ok(ctx, pattern, ptr, ppos3):
break
ptr += 1
- def find_next_result(self, ctx):
+ def find_next_result(self, ctx, pattern):
ptr = self.start_ptr
- if not self.next_char_ok(ctx, ptr, self.ppos3):
+ if not self.next_char_ok(ctx, pattern, ptr, self.ppos3):
return
self.start_ptr = ptr + 1
- return self.find_first_result(ctx)
+ return self.find_first_result(ctx, pattern)
- def next_char_ok(self, ctx, ptr, ppos):
+ def next_char_ok(self, ctx, pattern, ptr, ppos):
if ptr == ctx.end:
return False
- op = ctx.pat(ppos)
+ op = pattern.pat(ppos)
for op1, checkerfn in unroll_char_checker:
if op1 == op:
- return checkerfn(ctx, ptr, ppos)
+ return checkerfn(ctx, pattern, ptr, ppos)
# obscure case: it should be a single char pattern, but isn't
# one of the opcodes in unroll_char_checker (see test_ext_opcode)
- return sre_match(ctx, ppos, ptr, self.start_marks) is not None
+ return sre_match(ctx, pattern, ppos, ptr, self.start_marks) is not None
class AbstractUntilMatchResult(MatchResult):
@@ -391,17 +400,17 @@
class MaxUntilMatchResult(AbstractUntilMatchResult):
install_jitdriver('MaxUntil',
- greens=['ppos', 'tailppos', 'match_more', 'ctx.pattern'],
+ greens=['ppos', 'tailppos', 'match_more', 'pattern'],
reds=['ptr', 'marks', 'self', 'ctx'],
debugprint=(3, 0, 2))
- def find_first_result(self, ctx):
- return self.search_next(ctx, match_more=True)
+ def find_first_result(self, ctx, pattern):
+ return self.search_next(ctx, pattern, match_more=True)
- def find_next_result(self, ctx):
- return self.search_next(ctx, match_more=False)
+ def find_next_result(self, ctx, pattern):
+ return self.search_next(ctx, pattern, match_more=False)
- def search_next(self, ctx, match_more):
+ def search_next(self, ctx, pattern, match_more):
ppos = self.ppos
tailppos = self.tailppos
ptr = self.cur_ptr
@@ -409,12 +418,13 @@
while True:
ctx.jitdriver_MaxUntil.jit_merge_point(
ppos=ppos, tailppos=tailppos, match_more=match_more,
- ptr=ptr, marks=marks, self=self, ctx=ctx)
+ ptr=ptr, marks=marks, self=self, ctx=ctx,
+ pattern=pattern)
if match_more:
- max = ctx.pat(ppos+2)
+ max = pattern.pat(ppos+2)
if max == rsre_char.MAXREPEAT or self.num_pending < max:
# try to match one more 'item'
- enum = sre_match(ctx, ppos + 3, ptr, marks)
+ enum = sre_match(ctx, pattern, ppos + 3, ptr, marks)
else:
enum = None # 'max' reached, no more matches
else:
@@ -425,9 +435,9 @@
self.num_pending -= 1
ptr = p.ptr
marks = p.marks
- enum = p.enum.move_to_next_result(ctx)
+ enum = p.enum.move_to_next_result(ctx, pattern)
#
- min = ctx.pat(ppos+1)
+ min = pattern.pat(ppos+1)
if enum is not None:
# matched one more 'item'. record it and continue.
last_match_length = ctx.match_end - ptr
@@ -447,7 +457,7 @@
# 'item' no longer matches.
if self.num_pending >= min:
# try to match 'tail' if we have enough 'item'
- result = sre_match(ctx, tailppos, ptr, marks)
+ result = sre_match(ctx, pattern, tailppos, ptr, marks)
if result is not None:
self.subresult = result
self.cur_ptr = ptr
@@ -457,23 +467,23 @@
class MinUntilMatchResult(AbstractUntilMatchResult):
- def find_first_result(self, ctx):
- return self.search_next(ctx, resume=False)
+ def find_first_result(self, ctx, pattern):
+ return self.search_next(ctx, pattern, resume=False)
- def find_next_result(self, ctx):
- return self.search_next(ctx, resume=True)
+ def find_next_result(self, ctx, pattern):
+ return self.search_next(ctx, pattern, resume=True)
- def search_next(self, ctx, resume):
+ def search_next(self, ctx, pattern, resume):
# XXX missing jit support here
ppos = self.ppos
- min = ctx.pat(ppos+1)
- max = ctx.pat(ppos+2)
+ min = pattern.pat(ppos+1)
+ max = pattern.pat(ppos+2)
ptr = self.cur_ptr
marks = self.cur_marks
while True:
# try to match 'tail' if we have enough 'item'
if not resume and self.num_pending >= min:
- result = sre_match(ctx, self.tailppos, ptr, marks)
+ result = sre_match(ctx, pattern, self.tailppos, ptr, marks)
if result is not None:
self.subresult = result
self.cur_ptr = ptr
@@ -483,12 +493,12 @@
if max == rsre_char.MAXREPEAT or self.num_pending < max:
# try to match one more 'item'
- enum = sre_match(ctx, ppos + 3, ptr, marks)
+ enum = sre_match(ctx, pattern, ppos + 3, ptr, marks)
#
# zero-width match protection
if self.num_pending >= min:
while enum is not None and ptr == ctx.match_end:
- enum = enum.move_to_next_result(ctx)
+ enum = enum.move_to_next_result(ctx, pattern)
else:
enum = None # 'max' reached, no more matches
@@ -502,7 +512,7 @@
self.num_pending -= 1
ptr = p.ptr
marks = p.marks
- enum = p.enum.move_to_next_result(ctx)
+ enum = p.enum.move_to_next_result(ctx, pattern)
# matched one more 'item'. record it and continue
self.pending = Pending(ptr, marks, enum, self.pending)
@@ -514,13 +524,13 @@
@specializectx
@jit.unroll_safe
-def sre_match(ctx, ppos, ptr, marks):
+def sre_match(ctx, pattern, ppos, ptr, marks):
"""Returns either None or a MatchResult object. Usually we only need
the first result, but there is the case of REPEAT...UNTIL where we
need all results; in that case we use the method move_to_next_result()
of the MatchResult."""
while True:
- op = ctx.pat(ppos)
+ op = pattern.pat(ppos)
ppos += 1
#jit.jit_debug("sre_match", op, ppos, ptr)
@@ -563,33 +573,33 @@
elif op == OPCODE_ASSERT:
# assert subpattern
# <ASSERT> <0=skip> <1=back> <pattern>
- ptr1 = ptr - ctx.pat(ppos+1)
+ ptr1 = ptr - pattern.pat(ppos+1)
saved = ctx.fullmatch_only
ctx.fullmatch_only = False
- stop = ptr1 < 0 or sre_match(ctx, ppos + 2, ptr1, marks) is None
+ stop = ptr1 < 0 or sre_match(ctx, pattern, ppos + 2, ptr1, marks) is None
ctx.fullmatch_only = saved
if stop:
return
marks = ctx.match_marks
- ppos += ctx.pat(ppos)
+ ppos += pattern.pat(ppos)
elif op == OPCODE_ASSERT_NOT:
# assert not subpattern
# <ASSERT_NOT> <0=skip> <1=back> <pattern>
- ptr1 = ptr - ctx.pat(ppos+1)
+ ptr1 = ptr - pattern.pat(ppos+1)
saved = ctx.fullmatch_only
ctx.fullmatch_only = False
- stop = (ptr1 >= 0 and sre_match(ctx, ppos + 2, ptr1, marks)
+ stop = (ptr1 >= 0 and sre_match(ctx, pattern, ppos + 2, ptr1, marks)
is not None)
ctx.fullmatch_only = saved
if stop:
return
- ppos += ctx.pat(ppos)
+ ppos += pattern.pat(ppos)
elif op == OPCODE_AT:
# match at given position (e.g. at beginning, at boundary, etc.)
# <AT> <code>
- if not sre_at(ctx, ctx.pat(ppos), ptr):
+ if not sre_at(ctx, pattern.pat(ppos), ptr):
return
ppos += 1
@@ -597,14 +607,14 @@
# alternation
# <BRANCH> <0=skip> code <JUMP> ... <NULL>
result = BranchMatchResult(ppos, ptr, marks)
- return result.find_first_result(ctx)
+ return result.find_first_result(ctx, pattern)
elif op == OPCODE_CATEGORY:
# seems to be never produced, but used by some tests from
# pypy/module/_sre/test
# <CATEGORY> <category>
if (ptr == ctx.end or
- not rsre_char.category_dispatch(ctx.pat(ppos), ctx.str(ptr))):
+ not rsre_char.category_dispatch(pattern.pat(ppos), ctx.str(ptr))):
return
ptr += 1
ppos += 1
@@ -612,7 +622,7 @@
elif op == OPCODE_GROUPREF:
# match backreference
# <GROUPREF> <groupnum>
- startptr, length = get_group_ref(marks, ctx.pat(ppos))
+ startptr, length = get_group_ref(marks, pattern.pat(ppos))
if length < 0:
return # group was not previously defined
if not match_repeated(ctx, ptr, startptr, length):
@@ -623,7 +633,7 @@
elif op == OPCODE_GROUPREF_IGNORE:
# match backreference
# <GROUPREF> <groupnum>
- startptr, length = get_group_ref(marks, ctx.pat(ppos))
+ startptr, length = get_group_ref(marks, pattern.pat(ppos))
if length < 0:
return # group was not previously defined
if not match_repeated_ignore(ctx, ptr, startptr, length):
@@ -634,44 +644,44 @@
elif op == OPCODE_GROUPREF_EXISTS:
# conditional match depending on the existence of a group
# <GROUPREF_EXISTS> <group> <skip> codeyes <JUMP> codeno ...
- _, length = get_group_ref(marks, ctx.pat(ppos))
+ _, length = get_group_ref(marks, pattern.pat(ppos))
if length >= 0:
ppos += 2 # jump to 'codeyes'
else:
- ppos += ctx.pat(ppos+1) # jump to 'codeno'
+ ppos += pattern.pat(ppos+1) # jump to 'codeno'
elif op == OPCODE_IN:
# match set member (or non_member)
# <IN> <skip> <set>
- if ptr >= ctx.end or not rsre_char.check_charset(ctx, ppos+1,
+ if ptr >= ctx.end or not rsre_char.check_charset(ctx, pattern, ppos+1,
ctx.str(ptr)):
return
- ppos += ctx.pat(ppos)
+ ppos += pattern.pat(ppos)
ptr += 1
elif op == OPCODE_IN_IGNORE:
# match set member (or non_member), ignoring case
# <IN> <skip> <set>
- if ptr >= ctx.end or not rsre_char.check_charset(ctx, ppos+1,
+ if ptr >= ctx.end or not rsre_char.check_charset(ctx, pattern, ppos+1,
ctx.lowstr(ptr)):
return
- ppos += ctx.pat(ppos)
+ ppos += pattern.pat(ppos)
ptr += 1
elif op == OPCODE_INFO:
# optimization info block
# <INFO> <0=skip> <1=flags> <2=min> ...
- if (ctx.end - ptr) < ctx.pat(ppos+2):
+ if (ctx.end - ptr) < pattern.pat(ppos+2):
return
- ppos += ctx.pat(ppos)
+ ppos += pattern.pat(ppos)
elif op == OPCODE_JUMP:
- ppos += ctx.pat(ppos)
+ ppos += pattern.pat(ppos)
elif op == OPCODE_LITERAL:
# match literal string
# <LITERAL> <code>
- if ptr >= ctx.end or ctx.str(ptr) != ctx.pat(ppos):
+ if ptr >= ctx.end or ctx.str(ptr) != pattern.pat(ppos):
return
ppos += 1
ptr += 1
@@ -679,7 +689,7 @@
elif op == OPCODE_LITERAL_IGNORE:
# match literal string, ignoring case
# <LITERAL_IGNORE> <code>
- if ptr >= ctx.end or ctx.lowstr(ptr) != ctx.pat(ppos):
+ if ptr >= ctx.end or ctx.lowstr(ptr) != pattern.pat(ppos):
return
ppos += 1
ptr += 1
@@ -687,14 +697,14 @@
elif op == OPCODE_MARK:
# set mark
# <MARK> <gid>
- gid = ctx.pat(ppos)
+ gid = pattern.pat(ppos)
marks = Mark(gid, ptr, marks)
ppos += 1
elif op == OPCODE_NOT_LITERAL:
# match if it's not a literal string
# <NOT_LITERAL> <code>
- if ptr >= ctx.end or ctx.str(ptr) == ctx.pat(ppos):
+ if ptr >= ctx.end or ctx.str(ptr) == pattern.pat(ppos):
return
ppos += 1
ptr += 1
@@ -702,7 +712,7 @@
elif op == OPCODE_NOT_LITERAL_IGNORE:
# match if it's not a literal string, ignoring case
# <NOT_LITERAL> <code>
- if ptr >= ctx.end or ctx.lowstr(ptr) == ctx.pat(ppos):
+ if ptr >= ctx.end or ctx.lowstr(ptr) == pattern.pat(ppos):
return
ppos += 1
ptr += 1
@@ -715,22 +725,22 @@
# decode the later UNTIL operator to see if it is actually
# a MAX_UNTIL or MIN_UNTIL
- untilppos = ppos + ctx.pat(ppos)
+ untilppos = ppos + pattern.pat(ppos)
tailppos = untilppos + 1
- op = ctx.pat(untilppos)
+ op = pattern.pat(untilppos)
if op == OPCODE_MAX_UNTIL:
# the hard case: we have to match as many repetitions as
# possible, followed by the 'tail'. we do this by
# remembering each state for each possible number of
# 'item' matching.
result = MaxUntilMatchResult(ppos, tailppos, ptr, marks)
- return result.find_first_result(ctx)
+ return result.find_first_result(ctx, pattern)
elif op == OPCODE_MIN_UNTIL:
# first try to match the 'tail', and if it fails, try
# to match one more 'item' and try again
result = MinUntilMatchResult(ppos, tailppos, ptr, marks)
- return result.find_first_result(ctx)
+ return result.find_first_result(ctx, pattern)
else:
raise Error("missing UNTIL after REPEAT")
@@ -743,17 +753,18 @@
# use the MAX_REPEAT operator.
# <REPEAT_ONE> <skip> <1=min> <2=max> item <SUCCESS> tail
start = ptr
- minptr = start + ctx.pat(ppos+1)
+ minptr = start + pattern.pat(ppos+1)
if minptr > ctx.end:
return # cannot match
- ptr = find_repetition_end(ctx, ppos+3, start, ctx.pat(ppos+2),
+ ptr = find_repetition_end(ctx, pattern, ppos+3, start,
+ pattern.pat(ppos+2),
marks)
# when we arrive here, ptr points to the tail of the target
# string. check if the rest of the pattern matches,
# and backtrack if not.
- nextppos = ppos + ctx.pat(ppos)
+ nextppos = ppos + pattern.pat(ppos)
result = RepeatOneMatchResult(nextppos, minptr, ptr, marks)
- return result.find_first_result(ctx)
+ return result.find_first_result(ctx, pattern)
elif op == OPCODE_MIN_REPEAT_ONE:
# match repeated sequence (minimizing regexp).
@@ -763,26 +774,26 @@
# use the MIN_REPEAT operator.
# <MIN_REPEAT_ONE> <skip> <1=min> <2=max> item <SUCCESS> tail
start = ptr
- min = ctx.pat(ppos+1)
+ min = pattern.pat(ppos+1)
if min > 0:
minptr = ptr + min
if minptr > ctx.end:
return # cannot match
# count using pattern min as the maximum
- ptr = find_repetition_end(ctx, ppos+3, ptr, min, marks)
+ ptr = find_repetition_end(ctx, pattern, ppos+3, ptr, min, marks)
if ptr < minptr:
return # did not match minimum number of times
maxptr = ctx.end
- max = ctx.pat(ppos+2)
+ max = pattern.pat(ppos+2)
if max != rsre_char.MAXREPEAT:
maxptr1 = start + max
if maxptr1 <= maxptr:
maxptr = maxptr1
- nextppos = ppos + ctx.pat(ppos)
+ nextppos = ppos + pattern.pat(ppos)
result = MinRepeatOneMatchResult(nextppos, ppos+3, maxptr,
ptr, marks)
- return result.find_first_result(ctx)
+ return result.find_first_result(ctx, pattern)
else:
raise Error("bad pattern code %d" % op)
@@ -816,7 +827,7 @@
return True
@specializectx
-def find_repetition_end(ctx, ppos, ptr, maxcount, marks):
+def find_repetition_end(ctx, pattern, ppos, ptr, maxcount, marks):
end = ctx.end
ptrp1 = ptr + 1
# First get rid of the cases where we don't have room for any match.
@@ -826,16 +837,16 @@
# The idea is to be fast for cases like re.search("b+"), where we expect
# the common case to be a non-match. It's much faster with the JIT to
# have the non-match inlined here rather than detect it in the fre() call.
- op = ctx.pat(ppos)
+ op = pattern.pat(ppos)
for op1, checkerfn in unroll_char_checker:
if op1 == op:
- if checkerfn(ctx, ptr, ppos):
+ if checkerfn(ctx, pattern, ptr, ppos):
break
return ptr
else:
# obscure case: it should be a single char pattern, but isn't
# one of the opcodes in unroll_char_checker (see test_ext_opcode)
- return general_find_repetition_end(ctx, ppos, ptr, maxcount, marks)
+ return general_find_repetition_end(ctx, pattern, ppos, ptr, maxcount, marks)
# It matches at least once. If maxcount == 1 (relatively common),
# then we are done.
if maxcount == 1:
@@ -846,14 +857,14 @@
end1 = ptr + maxcount
if end1 <= end:
end = end1
- op = ctx.pat(ppos)
+ op = pattern.pat(ppos)
for op1, fre in unroll_fre_checker:
if op1 == op:
- return fre(ctx, ptrp1, end, ppos)
+ return fre(ctx, pattern, ptrp1, end, ppos)
raise Error("rsre.find_repetition_end[%d]" % op)
@specializectx
-def general_find_repetition_end(ctx, ppos, ptr, maxcount, marks):
+def general_find_repetition_end(ctx, pattern, ppos, ptr, maxcount, marks):
# moved into its own JIT-opaque function
end = ctx.end
if maxcount != rsre_char.MAXREPEAT:
@@ -861,63 +872,65 @@
end1 = ptr + maxcount
if end1 <= end:
end = end1
- while ptr < end and sre_match(ctx, ppos, ptr, marks) is not None:
+ while ptr < end and sre_match(ctx, pattern, ppos, ptr, marks) is not None:
ptr += 1
return ptr
@specializectx
-def match_ANY(ctx, ptr, ppos): # dot wildcard.
+def match_ANY(ctx, pattern, ptr, ppos): # dot wildcard.
return not rsre_char.is_linebreak(ctx.str(ptr))
-def match_ANY_ALL(ctx, ptr, ppos):
+def match_ANY_ALL(ctx, pattern, ptr, ppos):
return True # match anything (including a newline)
@specializectx
-def match_IN(ctx, ptr, ppos):
- return rsre_char.check_charset(ctx, ppos+2, ctx.str(ptr))
+def match_IN(ctx, pattern, ptr, ppos):
+ return rsre_char.check_charset(ctx, pattern, ppos+2, ctx.str(ptr))
@specializectx
-def match_IN_IGNORE(ctx, ptr, ppos):
- return rsre_char.check_charset(ctx, ppos+2, ctx.lowstr(ptr))
+def match_IN_IGNORE(ctx, pattern, ptr, ppos):
+ return rsre_char.check_charset(ctx, pattern, ppos+2, ctx.lowstr(ptr))
@specializectx
-def match_LITERAL(ctx, ptr, ppos):
- return ctx.str(ptr) == ctx.pat(ppos+1)
+def match_LITERAL(ctx, pattern, ptr, ppos):
+ return ctx.str(ptr) == pattern.pat(ppos+1)
@specializectx
-def match_LITERAL_IGNORE(ctx, ptr, ppos):
- return ctx.lowstr(ptr) == ctx.pat(ppos+1)
+def match_LITERAL_IGNORE(ctx, pattern, ptr, ppos):
+ return ctx.lowstr(ptr) == pattern.pat(ppos+1)
@specializectx
-def match_NOT_LITERAL(ctx, ptr, ppos):
- return ctx.str(ptr) != ctx.pat(ppos+1)
+def match_NOT_LITERAL(ctx, pattern, ptr, ppos):
+ return ctx.str(ptr) != pattern.pat(ppos+1)
@specializectx
-def match_NOT_LITERAL_IGNORE(ctx, ptr, ppos):
- return ctx.lowstr(ptr) != ctx.pat(ppos+1)
+def match_NOT_LITERAL_IGNORE(ctx, pattern, ptr, ppos):
+ return ctx.lowstr(ptr) != pattern.pat(ppos+1)
def _make_fre(checkerfn):
if checkerfn == match_ANY_ALL:
- def fre(ctx, ptr, end, ppos):
+ def fre(ctx, pattern, ptr, end, ppos):
return end
elif checkerfn == match_IN:
install_jitdriver_spec('MatchIn',
- greens=['ppos', 'ctx.pattern'],
+ greens=['ppos', 'pattern'],
reds=['ptr', 'end', 'ctx'],
debugprint=(1, 0))
@specializectx
- def fre(ctx, ptr, end, ppos):
+ def fre(ctx, pattern, ptr, end, ppos):
while True:
ctx.jitdriver_MatchIn.jit_merge_point(ctx=ctx, ptr=ptr,
- end=end, ppos=ppos)
- if ptr < end and checkerfn(ctx, ptr, ppos):
+ end=end, ppos=ppos,
+ pattern=pattern)
+ if ptr < end and checkerfn(ctx, pattern, ptr, ppos):
ptr += 1
else:
return ptr
elif checkerfn == match_IN_IGNORE:
install_jitdriver_spec('MatchInIgnore',
- greens=['ppos', 'ctx.pattern'],
+ greens=['ppos', 'pattern'],
reds=['ptr', 'end', 'ctx'],
debugprint=(1, 0))
@specializectx
- def fre(ctx, ptr, end, ppos):
+ def fre(ctx, pattern, ptr, end, ppos):
while True:
ctx.jitdriver_MatchInIgnore.jit_merge_point(ctx=ctx, ptr=ptr,
- end=end, ppos=ppos)
- if ptr < end and checkerfn(ctx, ptr, ppos):
+ end=end, ppos=ppos,
+ pattern=pattern)
+ if ptr < end and checkerfn(ctx, pattern, ptr, ppos):
ptr += 1
else:
return ptr
@@ -925,8 +938,8 @@
# in the other cases, the fre() function is not JITted at all
# and is present as a residual call.
@specializectx
- def fre(ctx, ptr, end, ppos):
- while ptr < end and checkerfn(ctx, ptr, ppos):
+ def fre(ctx, pattern, ptr, end, ppos):
+ while ptr < end and checkerfn(ctx, pattern, ptr, ppos):
ptr += 1
return ptr
fre = func_with_new_name(fre, 'fre_' + checkerfn.__name__)
@@ -1037,10 +1050,11 @@
return start, end
def match(pattern, string, start=0, end=sys.maxint, flags=0, fullmatch=False):
+ assert isinstance(pattern, CompiledPattern)
start, end = _adjust(start, end, len(string))
- ctx = StrMatchContext(pattern, string, start, end, flags)
+ ctx = StrMatchContext(string, start, end, flags)
ctx.fullmatch_only = fullmatch
- if match_context(ctx):
+ if match_context(ctx, pattern):
return ctx
else:
return None
@@ -1049,105 +1063,106 @@
return match(pattern, string, start, end, flags, fullmatch=True)
def search(pattern, string, start=0, end=sys.maxint, flags=0):
+ assert isinstance(pattern, CompiledPattern)
start, end = _adjust(start, end, len(string))
- ctx = StrMatchContext(pattern, string, start, end, flags)
- if search_context(ctx):
+ ctx = StrMatchContext(string, start, end, flags)
+ if search_context(ctx, pattern):
return ctx
else:
return None
install_jitdriver('Match',
- greens=['ctx.pattern'], reds=['ctx'],
+ greens=['pattern'], reds=['ctx'],
debugprint=(0,))
-def match_context(ctx):
+def match_context(ctx, pattern):
ctx.original_pos = ctx.match_start
if ctx.end < ctx.match_start:
return False
- ctx.jitdriver_Match.jit_merge_point(ctx=ctx)
- return sre_match(ctx, 0, ctx.match_start, None) is not None
+ ctx.jitdriver_Match.jit_merge_point(ctx=ctx, pattern=pattern)
+ return sre_match(ctx, pattern, 0, ctx.match_start, None) is not None
-def search_context(ctx):
+def search_context(ctx, pattern):
ctx.original_pos = ctx.match_start
if ctx.end < ctx.match_start:
return False
base = 0
charset = False
- if ctx.pat(base) == OPCODE_INFO:
- flags = ctx.pat(2)
+ if pattern.pat(base) == OPCODE_INFO:
+ flags = pattern.pat(2)
if flags & rsre_char.SRE_INFO_PREFIX:
- if ctx.pat(5) > 1:
- return fast_search(ctx)
+ if pattern.pat(5) > 1:
+ return fast_search(ctx, pattern)
else:
charset = (flags & rsre_char.SRE_INFO_CHARSET)
- base += 1 + ctx.pat(1)
- if ctx.pat(base) == OPCODE_LITERAL:
- return literal_search(ctx, base)
+ base += 1 + pattern.pat(1)
+ if pattern.pat(base) == OPCODE_LITERAL:
+ return literal_search(ctx, pattern, base)
if charset:
- return charset_search(ctx, base)
- return regular_search(ctx, base)
+ return charset_search(ctx, pattern, base)
+ return regular_search(ctx, pattern, base)
install_jitdriver('RegularSearch',
- greens=['base', 'ctx.pattern'],
+ greens=['base', 'pattern'],
reds=['start', 'ctx'],
debugprint=(1, 0))
-def regular_search(ctx, base):
+def regular_search(ctx, pattern, base):
start = ctx.match_start
while start <= ctx.end:
ctx.jitdriver_RegularSearch.jit_merge_point(ctx=ctx, start=start,
- base=base)
- if sre_match(ctx, base, start, None) is not None:
+ base=base, pattern=pattern)
+ if sre_match(ctx, pattern, base, start, None) is not None:
ctx.match_start = start
return True
start += 1
return False
install_jitdriver_spec("LiteralSearch",
- greens=['base', 'character', 'ctx.pattern'],
+ greens=['base', 'character', 'pattern'],
reds=['start', 'ctx'],
debugprint=(2, 0, 1))
@specializectx
-def literal_search(ctx, base):
+def literal_search(ctx, pattern, base):
# pattern starts with a literal character. this is used
# for short prefixes, and if fast search is disabled
- character = ctx.pat(base + 1)
+ character = pattern.pat(base + 1)
base += 2
start = ctx.match_start
while start < ctx.end:
ctx.jitdriver_LiteralSearch.jit_merge_point(ctx=ctx, start=start,
- base=base, character=character)
+ base=base, character=character, pattern=pattern)
if ctx.str(start) == character:
- if sre_match(ctx, base, start + 1, None) is not None:
+ if sre_match(ctx, pattern, base, start + 1, None) is not None:
ctx.match_start = start
return True
start += 1
return False
install_jitdriver_spec("CharsetSearch",
- greens=['base', 'ctx.pattern'],
+ greens=['base', 'pattern'],
reds=['start', 'ctx'],
debugprint=(1, 0))
@specializectx
-def charset_search(ctx, base):
+def charset_search(ctx, pattern, base):
# pattern starts with a character from a known set
start = ctx.match_start
while start < ctx.end:
ctx.jitdriver_CharsetSearch.jit_merge_point(ctx=ctx, start=start,
- base=base)
- if rsre_char.check_charset(ctx, 5, ctx.str(start)):
- if sre_match(ctx, base, start, None) is not None:
+ base=base, pattern=pattern)
+ if rsre_char.check_charset(ctx, pattern, 5, ctx.str(start)):
+ if sre_match(ctx, pattern, base, start, None) is not None:
ctx.match_start = start
return True
start += 1
return False
install_jitdriver_spec('FastSearch',
- greens=['i', 'prefix_len', 'ctx.pattern'],
+ greens=['i', 'prefix_len', 'pattern'],
reds=['string_position', 'ctx'],
debugprint=(2, 0))
@specializectx
-def fast_search(ctx):
+def fast_search(ctx, pattern):
# skips forward in a string as fast as possible using information from
# an optimization info block
# <INFO> <1=skip> <2=flags> <3=min> <4=...>
@@ -1155,17 +1170,18 @@
string_position = ctx.match_start
if string_position >= ctx.end:
return False
- prefix_len = ctx.pat(5)
+ prefix_len = pattern.pat(5)
assert prefix_len >= 0
i = 0
while True:
ctx.jitdriver_FastSearch.jit_merge_point(ctx=ctx,
- string_position=string_position, i=i, prefix_len=prefix_len)
+ string_position=string_position, i=i, prefix_len=prefix_len,
+ pattern=pattern)
char_ord = ctx.str(string_position)
- if char_ord != ctx.pat(7 + i):
+ if char_ord != pattern.pat(7 + i):
if i > 0:
overlap_offset = prefix_len + (7 - 1)
- i = ctx.pat(overlap_offset + i)
+ i = pattern.pat(overlap_offset + i)
continue
else:
i += 1
@@ -1173,22 +1189,22 @@
# found a potential match
start = string_position + 1 - prefix_len
assert start >= 0
- prefix_skip = ctx.pat(6)
+ prefix_skip = pattern.pat(6)
ptr = start + prefix_skip
- #flags = ctx.pat(2)
+ #flags = pattern.pat(2)
#if flags & rsre_char.SRE_INFO_LITERAL:
# # matched all of pure literal pattern
# ctx.match_start = start
# ctx.match_end = ptr
# ctx.match_marks = None
# return True
- pattern_offset = ctx.pat(1) + 1
+ pattern_offset = pattern.pat(1) + 1
ppos_start = pattern_offset + 2 * prefix_skip
- if sre_match(ctx, ppos_start, ptr, None) is not None:
+ if sre_match(ctx, pattern, ppos_start, ptr, None) is not None:
ctx.match_start = start
return True
overlap_offset = prefix_len + (7 - 1)
- i = ctx.pat(overlap_offset + i)
+ i = pattern.pat(overlap_offset + i)
string_position += 1
if string_position >= ctx.end:
return False
diff --git a/rpython/rlib/rsre/test/test_char.py b/rpython/rlib/rsre/test/test_char.py
--- a/rpython/rlib/rsre/test/test_char.py
+++ b/rpython/rlib/rsre/test/test_char.py
@@ -1,10 +1,16 @@
-from rpython.rlib.rsre import rsre_char
+from rpython.rlib.rsre import rsre_char, rsre_core
from rpython.rlib.rsre.rsre_char import SRE_FLAG_LOCALE, SRE_FLAG_UNICODE
def setup_module(mod):
from rpython.rlib.unicodedata import unicodedb
rsre_char.set_unicode_db(unicodedb)
+
+def check_charset(pattern, idx, char):
+ p = rsre_core.CompiledPattern(pattern)
+ return rsre_char.check_charset(Ctx(p), p, idx, char)
+
+
UPPER_PI = 0x3a0
LOWER_PI = 0x3c0
INDIAN_DIGIT = 0x966
@@ -157,12 +163,12 @@
pat_neg = [70, ord(cat) | 0x80, 0]
for c in positive:
assert unicodedb.category(ord(c)).startswith(cat)
- assert rsre_char.check_charset(Ctx(pat_pos), 0, ord(c))
- assert not rsre_char.check_charset(Ctx(pat_neg), 0, ord(c))
+ assert check_charset(pat_pos, 0, ord(c))
+ assert not check_charset(pat_neg, 0, ord(c))
for c in negative:
assert not unicodedb.category(ord(c)).startswith(cat)
- assert not rsre_char.check_charset(Ctx(pat_pos), 0, ord(c))
- assert rsre_char.check_charset(Ctx(pat_neg), 0, ord(c))
+ assert not check_charset(pat_pos, 0, ord(c))
+ assert check_charset(pat_neg, 0, ord(c))
def cat2num(cat):
return ord(cat[0]) | (ord(cat[1]) << 8)
@@ -173,17 +179,16 @@
pat_neg = [70, cat2num(cat) | 0x80, 0]
for c in positive:
assert unicodedb.category(ord(c)) == cat
- assert rsre_char.check_charset(Ctx(pat_pos), 0, ord(c))
- assert not rsre_char.check_charset(Ctx(pat_neg), 0, ord(c))
+ assert check_charset(pat_pos, 0, ord(c))
+ assert not check_charset(pat_neg, 0, ord(c))
for c in negative:
assert unicodedb.category(ord(c)) != cat
- assert not rsre_char.check_charset(Ctx(pat_pos), 0, ord(c))
- assert rsre_char.check_charset(Ctx(pat_neg), 0, ord(c))
+ assert not check_charset(pat_pos, 0, ord(c))
+ assert check_charset(pat_neg, 0, ord(c))
# test for how the common 'L&' pattern might be compiled
pat = [70, cat2num('Lu'), 70, cat2num('Ll'), 70, cat2num('Lt'), 0]
- assert rsre_char.check_charset(Ctx(pat), 0, 65) # Lu
- assert rsre_char.check_charset(Ctx(pat), 0, 99) # Ll
- assert rsre_char.check_charset(Ctx(pat), 0, 453) # Lt
- assert not rsre_char.check_charset(Ctx(pat), 0, 688) # Lm
- assert not rsre_char.check_charset(Ctx(pat), 0, 5870) # Nl
+ assert check_charset(pat, 0, 65) # Lu
+ assert check_charset(pat, 0, 99) # Ll
+ assert check_charset(pat, 0, 453) # Lt
+ assert not check_charset(pat, 0, 688) # Lm
+ assert not check_charset(pat, 0, 5870) # Nl
diff --git a/rpython/rlib/rsre/test/test_ext_opcode.py b/rpython/rlib/rsre/test/test_ext_opcode.py
--- a/rpython/rlib/rsre/test/test_ext_opcode.py
+++ b/rpython/rlib/rsre/test/test_ext_opcode.py
@@ -17,10 +17,10 @@
# it's a valid optimization because \1 is always one character long
r = [MARK, 0, ANY, MARK, 1, REPEAT_ONE, 6, 0, MAXREPEAT,
GROUPREF, 0, SUCCESS, SUCCESS]
- assert rsre_core.match(r, "aaa").match_end == 3
+ assert rsre_core.match(rsre_core.CompiledPattern(r), "aaa").match_end == 3
def test_min_repeat_one_with_backref():
# Python 3.5 compiles "(.)\1*?b" using MIN_REPEAT_ONE
r = [MARK, 0, ANY, MARK, 1, MIN_REPEAT_ONE, 6, 0, MAXREPEAT,
GROUPREF, 0, SUCCESS, LITERAL, 98, SUCCESS]
- assert rsre_core.match(r, "aaab").match_end == 4
+ assert rsre_core.match(rsre_core.CompiledPattern(r), "aaab").match_end == 4
diff --git a/rpython/rlib/rsre/test/test_match.py b/rpython/rlib/rsre/test/test_match.py
--- a/rpython/rlib/rsre/test/test_match.py
+++ b/rpython/rlib/rsre/test/test_match.py
@@ -9,7 +9,7 @@
def test_get_code_repetition():
c1 = get_code(r"a+")
c2 = get_code(r"a+")
- assert c1 == c2
+ assert c1.pattern == c2.pattern
class TestMatch:
@@ -305,6 +305,6 @@
rsre_char.set_unicode_db(unicodedb)
#
r = get_code(u"[\U00010428-\U0001044f]", re.I)
- assert r.count(27) == 1 # OPCODE_RANGE
- r[r.index(27)] = 32 # => OPCODE_RANGE_IGNORE
+ assert r.pattern.count(27) == 1 # OPCODE_RANGE
+ r.pattern[r.pattern.index(27)] = 32 # => OPCODE_RANGE_IGNORE
assert rsre_core.match(r, u"\U00010428")
diff --git a/rpython/rlib/rsre/test/test_re.py b/rpython/rlib/rsre/test/test_re.py
--- a/rpython/rlib/rsre/test/test_re.py
+++ b/rpython/rlib/rsre/test/test_re.py
@@ -426,31 +426,6 @@
assert pat.match(p) is not None
assert pat.match(p).span() == (0,256)
- def test_pickling(self):
- import pickle
- self.pickle_test(pickle)
- import cPickle
- self.pickle_test(cPickle)
- # old pickles expect the _compile() reconstructor in sre module
- import warnings
- original_filters = warnings.filters[:]
- try:
- warnings.filterwarnings("ignore", "The sre module is deprecated",
- DeprecationWarning)
- from sre import _compile
- finally:
- warnings.filters = original_filters
-
- def pickle_test(self, pickle):
- oldpat = re.compile('a(?:b|(c|e){1,2}?|d)+?(.)')
- s = pickle.dumps(oldpat)
- newpat = pickle.loads(s)
- # Not using object identity for _sre.py, since some Python builds do
- # not seem to preserve that in all cases (observed on an UCS-4 build
- # of 2.4.1).
- #self.assertEqual(oldpat, newpat)
- assert oldpat.__dict__ == newpat.__dict__
-
def test_constants(self):
assert re.I == re.IGNORECASE
assert re.L == re.LOCALE
diff --git a/rpython/rlib/rsre/test/test_zinterp.py b/rpython/rlib/rsre/test/test_zinterp.py
--- a/rpython/rlib/rsre/test/test_zinterp.py
+++ b/rpython/rlib/rsre/test/test_zinterp.py
@@ -11,6 +11,7 @@
rsre_core.search(pattern, string)
#
unicodestr = unichr(n) * n
+ pattern = rsre_core.CompiledPattern(pattern)
ctx = rsre_core.UnicodeMatchContext(pattern, unicodestr,
0, len(unicodestr), 0)
rsre_core.search_context(ctx)
diff --git a/rpython/rlib/rsre/test/test_zjit.py b/rpython/rlib/rsre/test/test_zjit.py
--- a/rpython/rlib/rsre/test/test_zjit.py
+++ b/rpython/rlib/rsre/test/test_zjit.py
@@ -6,18 +6,20 @@
from rpython.rtyper.annlowlevel import llstr, hlstr
def entrypoint1(r, string, repeat):
- r = array2list(r)
+ r = rsre_core.CompiledPattern(array2list(r))
string = hlstr(string)
match = None
for i in range(repeat):
match = rsre_core.match(r, string)
+ if match is None:
+ return -1
if match is None:
return -1
else:
return match.match_end
def entrypoint2(r, string, repeat):
- r = array2list(r)
+ r = rsre_core.CompiledPattern(array2list(r))
string = hlstr(string)
match = None
for i in range(repeat):
@@ -48,13 +50,13 @@
def meta_interp_match(self, pattern, string, repeat=1):
r = get_code(pattern)
- return self.meta_interp(entrypoint1, [list2array(r), llstr(string),
+ return self.meta_interp(entrypoint1, [list2array(r.pattern), llstr(string),
repeat],
listcomp=True, backendopt=True)
def meta_interp_search(self, pattern, string, repeat=1):
r = get_code(pattern)
- return self.meta_interp(entrypoint2, [list2array(r), llstr(string),
+ return self.meta_interp(entrypoint2, [list2array(r.pattern), llstr(string),
repeat],
listcomp=True, backendopt=True)
@@ -166,3 +168,9 @@
res = self.meta_interp_search(r"b+", "a"*30 + "b")
assert res == 30
self.check_resops(call=0)
+
+ def test_match_jit_bug(self):
+ pattern = ".a" * 2500
+ text = "a" * 6000
+ res = self.meta_interp_match(pattern, text, repeat=10)
+ assert res != -1
diff --git a/rpython/rlib/test/test_jit.py b/rpython/rlib/test/test_jit.py
--- a/rpython/rlib/test/test_jit.py
+++ b/rpython/rlib/test/test_jit.py
@@ -225,8 +225,10 @@
def test_green_field(self):
def get_printable_location(xfoo):
return str(ord(xfoo)) # xfoo must be annotated as a character
- myjitdriver = JitDriver(greens=['x.foo'], reds=['n', 'x'],
+ # green fields are disabled!
+ pytest.raises(ValueError, JitDriver, greens=['x.foo'], reds=['n', 'x'],
get_printable_location=get_printable_location)
+ return
class A(object):
_immutable_fields_ = ['foo']
def fn(n):
More information about the pypy-commit
mailing list