[pypy-svn] r78043 - in pypy/branch/rsre-jit/pypy: jit/codewriter jit/codewriter/test jit/metainterp jit/metainterp/optimizeopt jit/metainterp/test rlib rlib/rsre rlib/rsre/test

arigo at codespeak.net arigo at codespeak.net
Mon Oct 18 15:49:50 CEST 2010


Author: arigo
Date: Mon Oct 18 15:49:48 2010
New Revision: 78043

Added:
   pypy/branch/rsre-jit/pypy/rlib/rsre/rsre_jit.py
Modified:
   pypy/branch/rsre-jit/pypy/jit/codewriter/jtransform.py
   pypy/branch/rsre-jit/pypy/jit/codewriter/test/test_codewriter.py
   pypy/branch/rsre-jit/pypy/jit/codewriter/test/test_jtransform.py
   pypy/branch/rsre-jit/pypy/jit/metainterp/blackhole.py
   pypy/branch/rsre-jit/pypy/jit/metainterp/executor.py
   pypy/branch/rsre-jit/pypy/jit/metainterp/jitdriver.py
   pypy/branch/rsre-jit/pypy/jit/metainterp/optimizeopt/string.py
   pypy/branch/rsre-jit/pypy/jit/metainterp/optimizeopt/virtualize.py
   pypy/branch/rsre-jit/pypy/jit/metainterp/pyjitpl.py
   pypy/branch/rsre-jit/pypy/jit/metainterp/resoperation.py
   pypy/branch/rsre-jit/pypy/jit/metainterp/resume.py
   pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_basic.py
   pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_greenfield.py
   pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_virtualref.py
   pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_warmspot.py
   pypy/branch/rsre-jit/pypy/jit/metainterp/warmspot.py
   pypy/branch/rsre-jit/pypy/rlib/jit.py
   pypy/branch/rsre-jit/pypy/rlib/rsre/rsre_char.py
   pypy/branch/rsre-jit/pypy/rlib/rsre/rsre_core.py
   pypy/branch/rsre-jit/pypy/rlib/rsre/test/test_zjit.py
Log:
Main change: allow calling jitdriver.jit_merge_point()
without ever calling jitdriver.can_enter_jit().  In that case
the jit_merge_point() plays both roles.  The difference from
explicitly putting a can_enter_jit() just before it is that such
a can_enter_jit() is only seen when we are closing a loop;
in particular, it does not work if there is no loop at all.

Tests and fixes for greenfields.

Implement rlib.jit.jit_debug(), which just causes a JIT_DEBUG
resoperation to show up in the trace of the JIT.

Implement rlib.jit.assert_green(), which fails if the JIT considers
that the argument is not a Const.

When not translated, add the 'FORCE' prefix to the name of resoperations
produced by forcing virtualizables.


Modified: pypy/branch/rsre-jit/pypy/jit/codewriter/jtransform.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/codewriter/jtransform.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/codewriter/jtransform.py	Mon Oct 18 15:49:48 2010
@@ -320,6 +320,8 @@
             prepare = self._handle_str2unicode_call
         elif oopspec_name.startswith('virtual_ref'):
             prepare = self._handle_virtual_ref_call
+        elif oopspec_name.startswith('jit.'):
+            prepare = self._handle_jit_call
         else:
             prepare = self.prepare_builtin_call
         try:
@@ -859,6 +861,15 @@
                     (self.graph,))
         return []
 
+    def _handle_jit_call(self, op, oopspec_name, args):
+        if oopspec_name == 'jit.debug':
+            return SpaceOperation('jit_debug', args, None)
+        elif oopspec_name == 'jit.assert_green':
+            kind = getkind(args[0].concretetype)
+            return SpaceOperation('%s_assert_green' % kind, args, None)
+        else:
+            raise AssertionError("missing support for %r" % oopspec_name)
+
     # ----------
     # Lists.
 

Modified: pypy/branch/rsre-jit/pypy/jit/codewriter/test/test_codewriter.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/codewriter/test/test_codewriter.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/codewriter/test/test_codewriter.py	Mon Oct 18 15:49:48 2010
@@ -45,6 +45,7 @@
         self.portal_graph = portal_graph
         self.portal_runner_ptr = "???"
         self.virtualizable_info = None
+        self.greenfield_info = None
 
 
 def test_loop():

Modified: pypy/branch/rsre-jit/pypy/jit/codewriter/test/test_jtransform.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/codewriter/test/test_jtransform.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/codewriter/test/test_jtransform.py	Mon Oct 18 15:49:48 2010
@@ -721,6 +721,45 @@
     assert list(oplist[4].args[4]) == [v3, v4]
     assert oplist[5].opname == '-live-'
 
+def test_getfield_gc():
+    S = lltype.GcStruct('S', ('x', lltype.Char))
+    v1 = varoftype(lltype.Ptr(S))
+    v2 = varoftype(lltype.Char)
+    op = SpaceOperation('getfield', [v1, Constant('x', lltype.Void)], v2)
+    op1 = Transformer(FakeCPU()).rewrite_operation(op)
+    assert op1.opname == 'getfield_gc_i'
+    assert op1.args == [v1, ('fielddescr', S, 'x')]
+    assert op1.result == v2
+
+def test_getfield_gc_pure():
+    S = lltype.GcStruct('S', ('x', lltype.Char),
+                        hints={'immutable': True})
+    v1 = varoftype(lltype.Ptr(S))
+    v2 = varoftype(lltype.Char)
+    op = SpaceOperation('getfield', [v1, Constant('x', lltype.Void)], v2)
+    op1 = Transformer(FakeCPU()).rewrite_operation(op)
+    assert op1.opname == 'getfield_gc_i_pure'
+    assert op1.args == [v1, ('fielddescr', S, 'x')]
+    assert op1.result == v2
+
+def test_getfield_gc_greenfield():
+    class FakeCC:
+        def get_vinfo(self, v):
+            return None
+        def could_be_green_field(self, S1, name1):
+            assert S1 is S
+            assert name1 == 'x'
+            return True
+    S = lltype.GcStruct('S', ('x', lltype.Char),
+                        hints={'immutable': True})
+    v1 = varoftype(lltype.Ptr(S))
+    v2 = varoftype(lltype.Char)
+    op = SpaceOperation('getfield', [v1, Constant('x', lltype.Void)], v2)
+    op1 = Transformer(FakeCPU(), FakeCC()).rewrite_operation(op)
+    assert op1.opname == 'getfield_gc_i_greenfield'
+    assert op1.args == [v1, ('fielddescr', S, 'x')]
+    assert op1.result == v2
+
 def test_int_abs():
     v1 = varoftype(lltype.Signed)
     v2 = varoftype(lltype.Signed)

Modified: pypy/branch/rsre-jit/pypy/jit/metainterp/blackhole.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/metainterp/blackhole.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/metainterp/blackhole.py	Mon Oct 18 15:49:48 2010
@@ -760,6 +760,20 @@
     def bhimpl_debug_fatalerror(msg):
         llop.debug_fatalerror(lltype.Void, msg)
 
+    @arguments("r", "i", "i", "i", "i")
+    def bhimpl_jit_debug(string, arg1=0, arg2=0, arg3=0, arg4=0):
+        pass
+
+    @arguments("i")
+    def bhimpl_int_assert_green(x):
+        pass
+    @arguments("r")
+    def bhimpl_ref_assert_green(x):
+        pass
+    @arguments("f")
+    def bhimpl_float_assert_green(x):
+        pass
+
     # ----------
     # the main hints and recursive calls
 

Modified: pypy/branch/rsre-jit/pypy/jit/metainterp/executor.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/metainterp/executor.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/metainterp/executor.py	Mon Oct 18 15:49:48 2010
@@ -304,6 +304,7 @@
                          rop.CALL_ASSEMBLER,
                          rop.COND_CALL_GC_WB,
                          rop.DEBUG_MERGE_POINT,
+                         rop.JIT_DEBUG,
                          rop.SETARRAYITEM_RAW,
                          ):      # list of opcodes never executed by pyjitpl
                 continue

Modified: pypy/branch/rsre-jit/pypy/jit/metainterp/jitdriver.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/metainterp/jitdriver.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/metainterp/jitdriver.py	Mon Oct 18 15:49:48 2010
@@ -16,6 +16,7 @@
     #    self.greenfield_info   ... pypy.jit.metainterp.warmspot
     #    self.warmstate         ... pypy.jit.metainterp.warmspot
     #    self.handle_jitexc_from_bh pypy.jit.metainterp.warmspot
+    #    self.no_loop_header    ... pypy.jit.metainterp.warmspot
     #    self.portal_finishtoken... pypy.jit.metainterp.pyjitpl
     #    self.index             ... pypy.jit.codewriter.call
     #    self.mainjitcode       ... pypy.jit.codewriter.call

Modified: pypy/branch/rsre-jit/pypy/jit/metainterp/optimizeopt/string.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/metainterp/optimizeopt/string.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/metainterp/optimizeopt/string.py	Mon Oct 18 15:49:48 2010
@@ -12,7 +12,7 @@
 from pypy.jit.codewriter.effectinfo import EffectInfo, callinfo_for_oopspec
 from pypy.jit.codewriter import heaptracker
 from pypy.rlib.unroll import unrolling_iterable
-from pypy.rlib.objectmodel import specialize
+from pypy.rlib.objectmodel import specialize, we_are_translated
 
 
 class StrOrUnicode(object):
@@ -107,7 +107,10 @@
         self.box = box = self.source_op.result
         newoperations = self.optimizer.newoperations
         lengthbox = self.getstrlen(newoperations, self.mode)
-        newoperations.append(ResOperation(self.mode.NEWSTR, [lengthbox], box))
+        op = ResOperation(self.mode.NEWSTR, [lengthbox], box)
+        if not we_are_translated():
+            op.name = 'FORCE'
+        newoperations.append(op)
         self.string_copy_parts(newoperations, box, CONST_0, self.mode)
 
 

Modified: pypy/branch/rsre-jit/pypy/jit/metainterp/optimizeopt/virtualize.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/metainterp/optimizeopt/virtualize.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/metainterp/optimizeopt/virtualize.py	Mon Oct 18 15:49:48 2010
@@ -74,6 +74,8 @@
         assert self.source_op is not None
         # ^^^ This case should not occur any more (see test_bug_3).
         #
+        if not we_are_translated():
+            self.source_op.name = 'FORCE ' + self.source_op.name
         newoperations = self.optimizer.newoperations
         newoperations.append(self.source_op)
         self.box = box = self.source_op.result
@@ -165,6 +167,8 @@
 
     def _really_force(self):
         assert self.source_op is not None
+        if not we_are_translated():
+            self.source_op.name = 'FORCE ' + self.source_op.name
         newoperations = self.optimizer.newoperations
         newoperations.append(self.source_op)
         self.box = box = self.source_op.result

Modified: pypy/branch/rsre-jit/pypy/jit/metainterp/pyjitpl.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/metainterp/pyjitpl.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/metainterp/pyjitpl.py	Mon Oct 18 15:49:48 2010
@@ -1,4 +1,4 @@
-import py, os
+import py, os, sys
 from pypy.rpython.lltypesystem import lltype, llmemory, rclass
 from pypy.rlib.objectmodel import we_are_translated
 from pypy.rlib.unroll import unrolling_iterable
@@ -816,12 +816,16 @@
 
     @arguments("orgpc", "int", "boxes3", "boxes3")
     def opimpl_jit_merge_point(self, orgpc, jdindex, greenboxes, redboxes):
+        any_operation = len(self.metainterp.history.operations) > 0
         jitdriver_sd = self.metainterp.staticdata.jitdrivers_sd[jdindex]
         self.verify_green_args(jitdriver_sd, greenboxes)
         # xxx we may disable the following line in some context later
         self.debug_merge_point(jitdriver_sd, greenboxes)
         if self.metainterp.seen_loop_header_for_jdindex < 0:
-            return
+            if not jitdriver_sd.no_loop_header or not any_operation:
+                return
+            # automatically add a loop_header if there is none
+            self.metainterp.seen_loop_header_for_jdindex = jdindex
         #
         assert self.metainterp.seen_loop_header_for_jdindex == jdindex, (
             "found a loop_header for a JitDriver that does not match "
@@ -910,6 +914,40 @@
         msg = box.getref(lltype.Ptr(rstr.STR))
         lloperation.llop.debug_fatalerror(msg)
 
+    @arguments("box", "box", "box", "box", "box")
+    def opimpl_jit_debug(self, stringbox, arg1box, arg2box, arg3box, arg4box):
+        from pypy.rpython.lltypesystem import rstr
+        from pypy.rpython.annlowlevel import hlstr
+        msg = stringbox.getref(lltype.Ptr(rstr.STR))
+        debug_print('jit_debug:', hlstr(msg),
+                    arg1box.getint(), arg2box.getint(),
+                    arg3box.getint(), arg4box.getint())
+        args = [stringbox, arg1box, arg2box, arg3box, arg4box]
+        i = 4
+        while i > 0 and args[i].getint() == -sys.maxint-1:
+            i -= 1
+        assert i >= 0
+        op = self.metainterp.history.record(rop.JIT_DEBUG, args[:i+1], None)
+        self.metainterp.attach_debug_info(op)
+
+    @arguments("box")
+    def _opimpl_assert_green(self, box):
+        if not isinstance(box, Const):
+            msg = "assert_green failed at %s:%d" % (
+                self.jitcode.name,
+                self.pc)
+            if we_are_translated():
+                from pypy.rpython.annlowlevel import llstr
+                from pypy.rpython.lltypesystem import lloperation
+                lloperation.llop.debug_fatalerror(lltype.Void, llstr(msg))
+            else:
+                from pypy.rlib.jit import AssertGreenFailed
+                raise AssertGreenFailed(msg)
+
+    opimpl_int_assert_green   = _opimpl_assert_green
+    opimpl_ref_assert_green   = _opimpl_assert_green
+    opimpl_float_assert_green = _opimpl_assert_green
+
     @arguments("box")
     def opimpl_virtual_ref(self, box):
         # Details on the content of metainterp.virtualref_boxes:
@@ -1022,8 +1060,7 @@
         if resumepc >= 0:
             self.pc = resumepc
         resume.capture_resumedata(metainterp.framestack, virtualizable_boxes,
-                                  metainterp.virtualref_boxes,
-                                  resumedescr)
+                                  metainterp.virtualref_boxes, resumedescr)
         self.pc = saved_pc
         self.metainterp.staticdata.profiler.count_ops(opnum, GUARDS)
         # count

Modified: pypy/branch/rsre-jit/pypy/jit/metainterp/resoperation.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/metainterp/resoperation.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/metainterp/resoperation.py	Mon Oct 18 15:49:48 2010
@@ -459,6 +459,7 @@
     #'RUNTIMENEW/1',     # ootype operation    
     'COND_CALL_GC_WB/2d', # [objptr, newvalue]   (for the write barrier)
     'DEBUG_MERGE_POINT/1',      # debugging only
+    'JIT_DEBUG/*',              # debugging only
     'VIRTUAL_REF_FINISH/2',   # removed before it's passed to the backend
     'COPYSTRCONTENT/5',       # src, dst, srcstart, dststart, length
     'COPYUNICODECONTENT/5',

Modified: pypy/branch/rsre-jit/pypy/jit/metainterp/resume.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/metainterp/resume.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/metainterp/resume.py	Mon Oct 18 15:49:48 2010
@@ -738,15 +738,18 @@
         assert (end & 1) == 0
         return [self.decode_ref(nums[i]) for i in range(end)]
 
-    def consume_vref_and_vable_boxes(self, vinfo):
+    def consume_vref_and_vable_boxes(self, vinfo, ginfo):
         nums = self.cur_numb.nums
         self.cur_numb = self.cur_numb.prev
-        if vinfo is None:
-            virtualizable_boxes = None
-            end = len(nums)
-        else:
+        if vinfo is not None:
             virtualizable_boxes = self.consume_virtualizable_boxes(vinfo, nums)
             end = len(nums) - len(virtualizable_boxes)
+        elif ginfo is not None:
+            virtualizable_boxes = [self.decode_ref(nums[-1])]
+            end = len(nums) - 1
+        else:
+            virtualizable_boxes = None
+            end = len(nums)
         virtualref_boxes = self.consume_virtualref_boxes(nums, end)
         return virtualizable_boxes, virtualref_boxes
 

Modified: pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_basic.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_basic.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_basic.py	Mon Oct 18 15:49:48 2010
@@ -2,6 +2,7 @@
 import sys
 from pypy.rlib.jit import JitDriver, we_are_jitted, hint, dont_look_inside
 from pypy.rlib.jit import OPTIMIZER_FULL, OPTIMIZER_SIMPLE, loop_invariant
+from pypy.rlib.jit import jit_debug, assert_green, AssertGreenFailed
 from pypy.jit.metainterp.warmspot import ll_meta_interp, get_stats
 from pypy.jit.backend.llgraph import runner
 from pypy.jit.metainterp import pyjitpl, history
@@ -44,6 +45,7 @@
         num_green_args = 0
         portal_graph = graphs[0]
         virtualizable_info = None
+        greenfield_info = None
         result_type = result_kind
         portal_runner_ptr = "???"
 
@@ -1644,6 +1646,33 @@
         res = self.interp_operations(f, [10, 3.5])
         assert res == 3.5
 
+    def test_jit_debug(self):
+        myjitdriver = JitDriver(greens = [], reds = ['x'])
+        class A:
+            pass
+        def f(x):
+            while x > 0:
+                myjitdriver.can_enter_jit(x=x)
+                myjitdriver.jit_merge_point(x=x)
+                jit_debug("hi there:", x)
+                jit_debug("foobar")
+                x -= 1
+            return x
+        res = self.meta_interp(f, [8])
+        assert res == 0
+        self.check_loops(jit_debug=2)
+
+    def test_assert_green(self):
+        def f(x, promote):
+            if promote:
+                x = hint(x, promote=True)
+            assert_green(x)
+            return x
+        res = self.interp_operations(f, [8, 1])
+        assert res == 8
+        py.test.raises(AssertGreenFailed, self.interp_operations, f, [8, 0])
+
+
 class TestOOtype(BasicTests, OOJitMixin):
 
     def test_oohash(self):

Modified: pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_greenfield.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_greenfield.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_greenfield.py	Mon Oct 18 15:49:48 2010
@@ -27,6 +27,31 @@
         self.check_loop_count(2)
         self.check_loops(guard_value=0)
 
+    def test_green_field_2(self):
+        myjitdriver = JitDriver(greens=['ctx.x'], reds=['ctx'])
+        class Ctx(object):
+            _immutable_fields_ = ['x']
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+        def f(x, y):
+            ctx = Ctx(x, y)
+            while 1:
+                myjitdriver.can_enter_jit(ctx=ctx)
+                myjitdriver.jit_merge_point(ctx=ctx)
+                ctx.y -= 1
+                if ctx.y < 0:
+                    pass     # to just make two paths
+                if ctx.y < -10:
+                    return ctx.y
+        def g(y):
+            return f(5, y) + f(6, y)
+        #
+        res = self.meta_interp(g, [7])
+        assert res == -22
+        self.check_loop_count(4)
+        self.check_loops(guard_value=0)
+
 
 class TestLLtypeGreenFieldsTests(GreenFieldsTests, LLJitMixin):
     pass

Modified: pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_virtualref.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_virtualref.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_virtualref.py	Mon Oct 18 15:49:48 2010
@@ -93,7 +93,7 @@
         lst = []
         vrefinfo.continue_tracing = lambda vref, virtual: \
                                         lst.append((vref, virtual))
-        resumereader.consume_vref_and_vable(vrefinfo, None)
+        resumereader.consume_vref_and_vable(vrefinfo, None, None)
         del vrefinfo.continue_tracing
         assert len(lst) == 1
         lltype.cast_opaque_ptr(lltype.Ptr(JIT_VIRTUAL_REF),

Modified: pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_warmspot.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_warmspot.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/metainterp/test/test_warmspot.py	Mon Oct 18 15:49:48 2010
@@ -296,6 +296,69 @@
         assert res == 1
         self.check_loops(int_add=1)   # I get 13 without the loop_header()
 
+    def test_omit_can_enter_jit(self):
+        # Simple test comparing the effects of always giving a can_enter_jit(),
+        # or not giving any.  Mostly equivalent, except that if given, it is
+        # ignored the first time, and so it ends up taking one extra loop to
+        # start JITting.
+        mydriver = JitDriver(greens=[], reds=['m'])
+        #
+        for i2 in range(10):
+            def f2(m):
+                while m > 0:
+                    mydriver.jit_merge_point(m=m)
+                    m -= 1
+            self.meta_interp(f2, [i2])
+            try:
+                self.check_tree_loop_count(1)
+                break
+            except AssertionError:
+                print "f2: no loop generated for i2==%d" % i2
+        else:
+            raise     # re-raise the AssertionError: check_tree_loop_count never 1
+        #
+        for i1 in range(10):
+            def f1(m):
+                while m > 0:
+                    mydriver.can_enter_jit(m=m)
+                    mydriver.jit_merge_point(m=m)
+                    m -= 1
+            self.meta_interp(f1, [i1])
+            try:
+                self.check_tree_loop_count(1)
+                break
+            except AssertionError:
+                print "f1: no loop generated for i1==%d" % i1
+        else:
+            raise     # re-raise the AssertionError: check_tree_loop_count never 1
+        #
+        assert i1 - 1 == i2
+
+    def test_no_loop_at_all(self):
+        mydriver = JitDriver(greens=[], reds=['m'])
+        def f2(m):
+            mydriver.jit_merge_point(m=m)
+            return m - 1
+        def f1(m):
+            while m > 0:
+                m = f2(m)
+        self.meta_interp(f1, [8])
+        # it should generate one "loop" only, which ends in a FINISH
+        # corresponding to the return from f2.
+        self.check_tree_loop_count(1)
+        self.check_loop_count(0)
+
+    def test_simple_loop(self):
+        mydriver = JitDriver(greens=[], reds=['m'])
+        def f1(m):
+            while m > 0:
+                mydriver.jit_merge_point(m=m)
+                m = m - 1
+        self.meta_interp(f1, [8])
+        self.check_loop_count(1)
+        self.check_loops({'int_sub': 1, 'int_gt': 1, 'guard_true': 1,
+                          'jump': 1})
+
 
 class TestLLWarmspot(WarmspotTests, LLJitMixin):
     CPUClass = runner.LLtypeCPU

Modified: pypy/branch/rsre-jit/pypy/jit/metainterp/warmspot.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/jit/metainterp/warmspot.py	(original)
+++ pypy/branch/rsre-jit/pypy/jit/metainterp/warmspot.py	Mon Oct 18 15:49:48 2010
@@ -115,10 +115,10 @@
     return results
 
 def find_can_enter_jit(graphs):
-    results = _find_jit_marker(graphs, 'can_enter_jit')
-    if not results:
-        raise Exception("no can_enter_jit found!")
-    return results
+    return _find_jit_marker(graphs, 'can_enter_jit')
+
+def find_loop_headers(graphs):
+    return _find_jit_marker(graphs, 'loop_header')
 
 def find_jit_merge_points(graphs):
     results = _find_jit_marker(graphs, 'jit_merge_point')
@@ -483,26 +483,37 @@
             [lltype.Signed, llmemory.GCREF], RESTYPE)
 
     def rewrite_can_enter_jits(self):
-        can_enter_jits = find_can_enter_jit(self.translator.graphs)
         sublists = {}
         for jd in self.jitdrivers_sd:
-            sublists[jd.jitdriver] = []
+            sublists[jd.jitdriver] = jd, []
+            jd.no_loop_header = True
+        #
+        loop_headers = find_loop_headers(self.translator.graphs)
+        for graph, block, index in loop_headers:
+            op = block.operations[index]
+            jitdriver = op.args[1].value
+            assert jitdriver in sublists, \
+                   "loop_header with no matching jit_merge_point"
+            jd, sublist = sublists[jitdriver]
+            jd.no_loop_header = False
+        #
+        can_enter_jits = find_can_enter_jit(self.translator.graphs)
         for graph, block, index in can_enter_jits:
             op = block.operations[index]
             jitdriver = op.args[1].value
             assert jitdriver in sublists, \
                    "can_enter_jit with no matching jit_merge_point"
+            jd, sublist = sublists[jitdriver]
             origportalgraph = jd._jit_merge_point_pos[0]
             if graph is not origportalgraph:
-                sublists[jitdriver].append((graph, block, index))
+                sublist.append((graph, block, index))
+                jd.no_loop_header = False
             else:
                 pass   # a 'can_enter_jit' before the 'jit-merge_point', but
                        # originally in the same function: we ignore it here
                        # see e.g. test_jitdriver.test_simple
         for jd in self.jitdrivers_sd:
-            sublist = sublists[jd.jitdriver]
-            assert len(sublist) > 0, \
-                   "found no can_enter_jit for %r" % (jd.jitdriver,)
+            _, sublist = sublists[jd.jitdriver]
             self.rewrite_can_enter_jit(jd, sublist)
 
     def rewrite_can_enter_jit(self, jd, can_enter_jits):
@@ -510,6 +521,19 @@
         FUNCPTR = jd._PTR_JIT_ENTER_FUNCTYPE
         jit_enter_fnptr = self.helper_func(FUNCPTR, jd._maybe_enter_jit_fn)
 
+        if len(can_enter_jits) == 0:
+            # see test_warmspot.test_no_loop_at_all
+            operations = jd.portal_graph.startblock.operations
+            op1 = operations[0]
+            assert (op1.opname == 'jit_marker' and
+                    op1.args[0].value == 'jit_merge_point')
+            op0 = SpaceOperation(
+                'jit_marker',
+                [Constant('can_enter_jit', lltype.Void)] + op1.args[1:],
+                None)
+            operations.insert(0, op0)
+            can_enter_jits = [(jd.portal_graph, jd.portal_graph.startblock, 0)]
+
         for graph, block, index in can_enter_jits:
             if graph is jd._jit_merge_point_pos[0]:
                 continue

Modified: pypy/branch/rsre-jit/pypy/rlib/jit.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/rlib/jit.py	(original)
+++ pypy/branch/rsre-jit/pypy/rlib/jit.py	Mon Oct 18 15:49:48 2010
@@ -139,6 +139,24 @@
         return hop.inputconst(lltype.Signed, _we_are_jitted)
 
 
+def jit_debug(string, arg1=-sys.maxint-1, arg2=-sys.maxint-1,
+                      arg3=-sys.maxint-1, arg4=-sys.maxint-1):
+    """When JITted, cause an extra operation JIT_DEBUG to appear in
+    the graphs.  Should not be left after debugging."""
+    keepalive_until_here(string) # otherwise the whole function call is removed
+jit_debug.oopspec = 'jit.debug(string, arg1, arg2, arg3, arg4)'
+
+def assert_green(value):
+    """Very strong assert: checks that 'value' is a green
+    (a JIT compile-time constant)."""
+    keepalive_until_here(value)
+assert_green._annspecialcase_ = 'specialize:argtype(0)'
+assert_green.oopspec = 'jit.assert_green(value)'
+
+class AssertGreenFailed(Exception):
+    pass
+
+
 ##def force_virtualizable(virtualizable):
 ##    pass
 
@@ -266,7 +284,8 @@
             self.virtualizables = virtualizables
         for v in self.virtualizables:
             assert v in self.reds
-        self._alllivevars = dict.fromkeys(self.greens + self.reds)
+        self._alllivevars = dict.fromkeys(
+            [name for name in self.greens + self.reds if '.' not in name])
         self._make_extregistryentries()
         self.get_jitcell_at = get_jitcell_at
         self.set_jitcell_at = set_jitcell_at

Modified: pypy/branch/rsre-jit/pypy/rlib/rsre/rsre_char.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/rlib/rsre/rsre_char.py	(original)
+++ pypy/branch/rsre-jit/pypy/rlib/rsre/rsre_char.py	Mon Oct 18 15:49:48 2010
@@ -4,6 +4,7 @@
 import sys
 from pypy.rlib.rlocale import tolower, isalnum
 from pypy.rlib.unroll import unrolling_iterable
+from pypy.rlib import jit
 
 # Note: the unicode parts of this module require you to call
 # rsre_char.set_unicode_db() first, to select one of the modules
@@ -43,6 +44,7 @@
 # XXX can we import those safely from sre_constants?
 SRE_INFO_PREFIX = 1
 SRE_INFO_LITERAL = 2
+SRE_INFO_CHARSET = 4
 SRE_FLAG_LOCALE = 4 # honour system locale
 SRE_FLAG_UNICODE = 32 # use unicode locale
 OPCODE_INFO = 17
@@ -64,33 +66,27 @@
 
 #### Category helpers
 
-ascii_char_info = [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 6, 2,
-2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0,
-0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 25, 25, 25, 25, 25, 25, 25, 25,
-25, 25, 0, 0, 0, 0, 0, 0, 0, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
-24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 0, 0,
-0, 0, 16, 0, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
-24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 0, 0, 0, 0, 0 ]
-
+is_a_word = [(chr(i).isalnum() or chr(i) == '_') for i in range(256)]
 linebreak = ord("\n")
 underline = ord("_")
 
 def is_digit(code):
-    return code < 128 and (ascii_char_info[code] & 1 != 0)
+    return code <= 57 and code >= 48
 
 def is_uni_digit(code):
     assert unicodedb is not None
     return unicodedb.isdigit(code)
 
 def is_space(code):
-    return code < 128 and (ascii_char_info[code] & 2 != 0)
+    return code == 32 or (code <= 13 and code >= 9)
 
 def is_uni_space(code):
     assert unicodedb is not None
     return unicodedb.isspace(code)
 
 def is_word(code):
-    return code < 128 and (ascii_char_info[code] & 16 != 0)
+    assert code >= 0
+    return code < 256 and is_a_word[code]
 
 def is_uni_word(code):
     assert unicodedb is not None
@@ -142,6 +138,7 @@
 SET_OK = -1
 SET_NOT_OK = -2
 
+@jit.unroll_safe
 def check_charset(pattern, ppos, char_code):
     """Checks whether a character matches set of arbitrary length.
     The set starts at pattern[ppos]."""

Modified: pypy/branch/rsre-jit/pypy/rlib/rsre/rsre_core.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/rlib/rsre/rsre_core.py	(original)
+++ pypy/branch/rsre-jit/pypy/rlib/rsre/rsre_core.py	Mon Oct 18 15:49:48 2010
@@ -5,6 +5,7 @@
 from pypy.tool.sourcetools import func_with_new_name
 from pypy.rlib.objectmodel import we_are_translated
 from pypy.rlib import jit
+from pypy.rlib.rsre.rsre_jit import install_jitdriver, install_jitdriver_spec
 
 
 OPCODE_FAILURE            = 0
@@ -69,7 +70,7 @@
         return getattr(ctx, specname)(*args)
     dispatch._annspecialcase_ = 'specialize:argtype(0)'
     dispatch._specialized_methods_ = specialized_methods
-    return dispatch
+    return func_with_new_name(dispatch, specname)
 
 # ____________________________________________________________
 
@@ -79,7 +80,7 @@
 
 class AbstractMatchContext(object):
     """Abstract base class"""
-    _immutable_fields_ = ['pattern[*]', 'flags']
+    _immutable_fields_ = ['pattern[*]', 'flags', 'end']
     match_start = 0
     match_end = 0
     match_marks = None
@@ -245,7 +246,7 @@
 
     @jit.unroll_safe
     def find_first_result(self, ctx):
-        ppos = self.ppos
+        ppos = jit.hint(self.ppos, promote=True)
         while ctx.pat(ppos):
             result = sre_match(ctx, ppos + 1, self.start_ptr, self.start_marks)
             ppos += ctx.pat(ppos)
@@ -256,8 +257,10 @@
     find_next_result = find_first_result
 
 class RepeatOneMatchResult(MatchResult):
-    jitdriver = jit.JitDriver(greens=['nextppos', 'pattern'],
-                              reds=['ptr', 'self', 'ctx'])
+    install_jitdriver('RepeatOne',
+                      greens=['nextppos', 'ctx.pattern'],
+                      reds=['ptr', 'self', 'ctx'],
+                      debugprint=(1, 0))   # indices in 'greens'
 
     def __init__(self, nextppos, minptr, ptr, marks):
         self.nextppos = nextppos
@@ -269,15 +272,8 @@
         ptr = self.start_ptr
         nextppos = self.nextppos
         while ptr >= self.minptr:
-            #
-            pattern = ctx.pattern
-            self.jitdriver.can_enter_jit(self=self, ptr=ptr, ctx=ctx,
-                                         nextppos=nextppos, pattern=pattern)
-            self.jitdriver.jit_merge_point(self=self, ptr=ptr, ctx=ctx,
-                                           nextppos=nextppos, pattern=pattern)
-            if jit.we_are_jitted():
-                ctx.pattern = pattern
-            #
+            ctx.jitdriver_RepeatOne.jit_merge_point(
+                self=self, ptr=ptr, ctx=ctx, nextppos=nextppos)
             result = sre_match(ctx, nextppos, ptr, self.start_marks)
             ptr -= 1
             if result is not None:
@@ -288,8 +284,10 @@
 
 
 class MinRepeatOneMatchResult(MatchResult):
-    jitdriver = jit.JitDriver(greens=['nextppos', 'ppos3', 'pattern'],
-                              reds=['ptr', 'self', 'ctx'])
+    install_jitdriver('MinRepeatOne',
+                      greens=['nextppos', 'ppos3', 'ctx.pattern'],
+                      reds=['ptr', 'self', 'ctx'],
+                      debugprint=(2, 0))   # indices in 'greens'
 
     def __init__(self, nextppos, ppos3, maxptr, ptr, marks):
         self.nextppos = nextppos
@@ -303,17 +301,8 @@
         nextppos = self.nextppos
         ppos3 = self.ppos3
         while ptr <= self.maxptr:
-            #
-            pattern = ctx.pattern
-            self.jitdriver.can_enter_jit(self=self, ptr=ptr, ctx=ctx,
-                                         nextppos=nextppos, pattern=pattern,
-                                         ppos3=ppos3)
-            self.jitdriver.jit_merge_point(self=self, ptr=ptr, ctx=ctx,
-                                           nextppos=nextppos, pattern=pattern,
-                                           ppos3=ppos3)
-            if jit.we_are_jitted():
-                ctx.pattern = pattern
-            #
+            ctx.jitdriver_MinRepeatOne.jit_merge_point(
+                self=self, ptr=ptr, ctx=ctx, nextppos=nextppos, ppos3=ppos3)
             result = sre_match(ctx, nextppos, ptr, self.start_marks)
             if result is not None:
                 self.subresult = result
@@ -334,7 +323,7 @@
         if ptr == ctx.end:
             return False
         op = ctx.pat(ppos)
-        for op1, (checkerfn, _) in unroll_char_checker:
+        for op1, checkerfn in unroll_char_checker:
             if op1 == op:
                 return checkerfn(ctx, ptr, ppos)
         raise Error("next_char_ok[%d]" % op)
@@ -357,41 +346,34 @@
         self.next = next     # chained list
 
 class MaxUntilMatchResult(AbstractUntilMatchResult):
+    install_jitdriver('MaxUntil',
+                      greens=['ppos', 'tailppos', 'match_more', 'ctx.pattern'],
+                      reds=['ptr', 'marks', 'self', 'ctx'],
+                      debugprint=(3, 0, 2))
 
     def find_first_result(self, ctx):
-        enum = sre_match(ctx, self.ppos + 3, self.cur_ptr, self.cur_marks)
-        return self.search_next(ctx, enum, resume=False)
+        return self.search_next(ctx, match_more=True)
 
     def find_next_result(self, ctx):
-        return self.search_next(ctx, None, resume=True)
+        return self.search_next(ctx, match_more=False)
 
-    def search_next(self, ctx, enum, resume):
+    def search_next(self, ctx, match_more):
         ppos = self.ppos
-        min = ctx.pat(ppos+1)
-        max = ctx.pat(ppos+2)
+        tailppos = self.tailppos
         ptr = self.cur_ptr
         marks = self.cur_marks
         while True:
-            while True:
-                if (enum is not None and
-                    (ptr != ctx.match_end or self.num_pending < min)):
-                    #               ^^^^^^^^^^ zero-width match protection
-                    # matched one more 'item'.  record it and continue.
-                    self.pending = Pending(ptr, marks, enum, self.pending)
-                    self.num_pending += 1
-                    ptr = ctx.match_end
-                    marks = ctx.match_marks
-                    break
-                # 'item' no longer matches.
-                if not resume and self.num_pending >= min:
-                    # try to match 'tail' if we have enough 'item'
-                    result = sre_match(ctx, self.tailppos, ptr, marks)
-                    if result is not None:
-                        self.subresult = result
-                        self.cur_ptr = ptr
-                        self.cur_marks = marks
-                        return self
-                resume = False
+            ctx.jitdriver_MaxUntil.jit_merge_point(
+                ppos=ppos, tailppos=tailppos, match_more=match_more,
+                ptr=ptr, marks=marks, self=self, ctx=ctx)
+            if match_more:
+                max = ctx.pat(ppos+2)
+                if max == 65535 or self.num_pending < max:
+                    # try to match one more 'item'
+                    enum = sre_match(ctx, ppos + 3, ptr, marks)
+                else:
+                    enum = None    # 'max' reached, no more matches
+            else:
                 p = self.pending
                 if p is None:
                     return
@@ -401,11 +383,27 @@
                 marks = p.marks
                 enum = p.enum.move_to_next_result(ctx)
             #
-            if max == 65535 or self.num_pending < max:
-                # try to match one more 'item'
-                enum = sre_match(ctx, ppos + 3, ptr, marks)
+            min = ctx.pat(ppos+1)
+            if (enum is not None and
+                (ptr != ctx.match_end or self.num_pending < min)):
+                #               ^^^^^^^^^^ zero-width match protection
+                # matched one more 'item'.  record it and continue.
+                self.pending = Pending(ptr, marks, enum, self.pending)
+                self.num_pending += 1
+                ptr = ctx.match_end
+                marks = ctx.match_marks
+                match_more = True
             else:
-                enum = None    # 'max' reached, no more matches
+                # 'item' no longer matches.
+                if self.num_pending >= min:
+                    # try to match 'tail' if we have enough 'item'
+                    result = sre_match(ctx, tailppos, ptr, marks)
+                    if result is not None:
+                        self.subresult = result
+                        self.cur_ptr = ptr
+                        self.cur_marks = marks
+                        return self
+                match_more = False
 
 class MinUntilMatchResult(AbstractUntilMatchResult):
 
@@ -416,6 +414,7 @@
         return self.search_next(ctx, resume=True)
 
     def search_next(self, ctx, resume):
+        # XXX missing jit support here
         ppos = self.ppos
         min = ctx.pat(ppos+1)
         max = ctx.pat(ppos+2)
@@ -471,6 +470,12 @@
         op = ctx.pat(ppos)
         ppos += 1
 
+        #jit.jit_debug("sre_match", op, ppos, ptr)
+        #
+        # When using the JIT, calls to sre_match() must always have a constant
+        # (green) argument for 'ppos'.  If not, the following assert fails.
+        jit.assert_green(op)
+
         if op == OPCODE_FAILURE:
             return
 
@@ -745,13 +750,23 @@
 @specializectx
 def find_repetition_end(ctx, ppos, ptr, maxcount):
     end = ctx.end
-    # adjust end
-    if maxcount != 65535:
+    if maxcount <= 1:
+        if maxcount == 1 and ptr < end:
+            # Relatively common case: maxcount == 1.  If we are not at the
+            # end of the string, it's done by a single direct check.
+            op = ctx.pat(ppos)
+            for op1, checkerfn in unroll_char_checker:
+                if op1 == op:
+                    if checkerfn(ctx, ptr, ppos):
+                        return ptr + 1
+        return ptr
+    elif maxcount != 65535:
+        # adjust end
         end1 = ptr + maxcount
         if end1 <= end:
             end = end1
     op = ctx.pat(ppos)
-    for op1, (_, fre) in unroll_char_checker:
+    for op1, fre in unroll_fre_checker:
         if op1 == op:
             return fre(ctx, ptr, end, ppos)
     raise Error("rsre.find_repetition_end[%d]" % op)
@@ -784,23 +799,60 @@
     if checkerfn == match_ANY_ALL:
         def fre(ctx, ptr, end, ppos):
             return end
+    elif checkerfn == match_IN:
+        install_jitdriver_spec('MatchIn', 
+                               greens=['ppos', 'ctx.pattern'],
+                               reds=['ptr', 'end', 'ctx'],
+                               debugprint=(1, 0))
+        @specializectx
+        def fre(ctx, ptr, end, ppos):
+            while True:
+                ctx.jitdriver_MatchIn.jit_merge_point(ctx=ctx, ptr=ptr,
+                                                      end=end, ppos=ppos)
+                if ptr < end and checkerfn(ctx, ptr, ppos):
+                    ptr += 1
+                else:
+                    return ptr
+    elif checkerfn == match_IN_IGNORE:
+        install_jitdriver_spec('MatchInIgnore', 
+                               greens=['ppos', 'ctx.pattern'],
+                               reds=['ptr', 'end', 'ctx'],
+                               debugprint=(1, 0))
+        @specializectx
+        def fre(ctx, ptr, end, ppos):
+            while True:
+                ctx.jitdriver_MatchInIgnore.jit_merge_point(ctx=ctx, ptr=ptr,
+                                                            end=end, ppos=ppos)
+                if ptr < end and checkerfn(ctx, ptr, ppos):
+                    ptr += 1
+                else:
+                    return ptr
     else:
+        # in the other cases, the fre() function is not JITted at all
+        # and is present as a residual call.
+        @specializectx
         def fre(ctx, ptr, end, ppos):
             while ptr < end and checkerfn(ctx, ptr, ppos):
                 ptr += 1
             return ptr
-    return checkerfn, fre
+    fre = func_with_new_name(fre, 'fre_' + checkerfn.__name__)
+    return fre
+
+unroll_char_checker = [
+    (OPCODE_ANY,                match_ANY),
+    (OPCODE_ANY_ALL,            match_ANY_ALL),
+    (OPCODE_IN,                 match_IN),
+    (OPCODE_IN_IGNORE,          match_IN_IGNORE),
+    (OPCODE_LITERAL,            match_LITERAL),
+    (OPCODE_LITERAL_IGNORE,     match_LITERAL_IGNORE),
+    (OPCODE_NOT_LITERAL,        match_NOT_LITERAL),
+    (OPCODE_NOT_LITERAL_IGNORE, match_NOT_LITERAL_IGNORE),
+    ]
+unroll_fre_checker = [(_op, _make_fre(_fn))
+                      for (_op, _fn) in unroll_char_checker]
 
-unroll_char_checker = unrolling_iterable([
-    (OPCODE_ANY,                _make_fre(match_ANY)),
-    (OPCODE_ANY_ALL,            _make_fre(match_ANY_ALL)),
-    (OPCODE_IN,                 _make_fre(match_IN)),
-    (OPCODE_IN_IGNORE,          _make_fre(match_IN_IGNORE)),
-    (OPCODE_LITERAL,            _make_fre(match_LITERAL)),
-    (OPCODE_LITERAL_IGNORE,     _make_fre(match_LITERAL_IGNORE)),
-    (OPCODE_NOT_LITERAL,        _make_fre(match_NOT_LITERAL)),
-    (OPCODE_NOT_LITERAL_IGNORE, _make_fre(match_NOT_LITERAL_IGNORE)),
-    ])
+unroll_char_checker = unrolling_iterable(unroll_char_checker)
+unroll_fre_checker  = unrolling_iterable(unroll_fre_checker)
 
 ##### At dispatch
 
@@ -906,74 +958,139 @@
     else:
         return None
 
+install_jitdriver('Match',
+                  greens=['ctx.pattern'], reds=['ctx'],
+                  debugprint=(0,))
+
 def match_context(ctx):
     ctx.original_pos = ctx.match_start
     if ctx.end < ctx.match_start:
         return False
+    ctx.jitdriver_Match.jit_merge_point(ctx=ctx)
     return sre_match(ctx, 0, ctx.match_start, None) is not None
 
 def search_context(ctx):
     ctx.original_pos = ctx.match_start
     if ctx.end < ctx.match_start:
         return False
-    if ctx.pat(0) == OPCODE_INFO:
-        if ctx.pat(2) & rsre_char.SRE_INFO_PREFIX and ctx.pat(5) > 1:
-            return fast_search(ctx)
-    return regular_search(ctx)
+    base = 0
+    charset = False
+    if ctx.pat(base) == OPCODE_INFO:
+        flags = ctx.pat(2)
+        if flags & rsre_char.SRE_INFO_PREFIX:
+            if ctx.pat(5) > 1:
+                return fast_search(ctx)
+        else:
+            charset = (flags & rsre_char.SRE_INFO_CHARSET)
+        base += 1 + ctx.pat(1)
+    if ctx.pat(base) == OPCODE_LITERAL:
+        return literal_search(ctx, base)
+    if charset:
+        return charset_search(ctx, base)
+    return regular_search(ctx, base)
+
+install_jitdriver('RegularSearch',
+                  greens=['base', 'ctx.pattern'],
+                  reds=['start', 'ctx'],
+                  debugprint=(1, 0))
 
-def regular_search(ctx):
+def regular_search(ctx, base):
     start = ctx.match_start
     while start <= ctx.end:
-        if sre_match(ctx, 0, start, None) is not None:
+        ctx.jitdriver_RegularSearch.jit_merge_point(ctx=ctx, start=start,
+                                                    base=base)
+        if sre_match(ctx, base, start, None) is not None:
             ctx.match_start = start
             return True
         start += 1
     return False
 
+install_jitdriver_spec("LiteralSearch",
+                       greens=['base', 'character', 'ctx.pattern'],
+                       reds=['start', 'ctx'],
+                       debugprint=(2, 0, 1))
+@specializectx
+def literal_search(ctx, base):
+    # pattern starts with a literal character.  this is used
+    # for short prefixes, and if fast search is disabled
+    character = ctx.pat(base + 1)
+    base += 2
+    start = ctx.match_start
+    while start < ctx.end:
+        ctx.jitdriver_LiteralSearch.jit_merge_point(ctx=ctx, start=start,
+                                          base=base, character=character)
+        if ctx.str(start) == character:
+            if sre_match(ctx, base, start + 1, None) is not None:
+                ctx.match_start = start
+                return True
+        start += 1
+    return False
+
+install_jitdriver_spec("CharsetSearch",
+                       greens=['base', 'ctx.pattern'],
+                       reds=['start', 'ctx'],
+                       debugprint=(1, 0))
+@specializectx
+def charset_search(ctx, base):
+    # pattern starts with a character from a known set
+    start = ctx.match_start
+    while start < ctx.end:
+        ctx.jitdriver_CharsetSearch.jit_merge_point(ctx=ctx, start=start,
+                                                    base=base)
+        if rsre_char.check_charset(ctx.pattern, 5, ctx.str(start)):
+            if sre_match(ctx, base, start, None) is not None:
+                ctx.match_start = start
+                return True
+        start += 1
+    return False
+
+install_jitdriver_spec('FastSearch',
+                       greens=['i', 'prefix_len', 'ctx.pattern'],
+                       reds=['string_position', 'ctx'],
+                       debugprint=(2, 0))
 @specializectx
 def fast_search(ctx):
     # skips forward in a string as fast as possible using information from
     # an optimization info block
     # <INFO> <1=skip> <2=flags> <3=min> <4=...>
     #        <5=length> <6=skip> <7=prefix data> <overlap data>
-    flags = ctx.pat(2)
+    string_position = ctx.match_start
+    if string_position >= ctx.end:
+        return False
     prefix_len = ctx.pat(5)
     assert prefix_len >= 0
-    prefix_skip = ctx.pat(6)
-    assert prefix_skip >= 0
-    overlap_offset = 7 + prefix_len - 1
-    assert overlap_offset >= 0
-    pattern_offset = ctx.pat(1) + 1
-    ppos_start = pattern_offset + 2 * prefix_skip
-    assert ppos_start >= 0
     i = 0
-    string_position = ctx.match_start
-    end = ctx.end
-    while string_position < end:
-        while True:
-            char_ord = ctx.str(string_position)
-            if char_ord != ctx.pat(7 + i):
-                if i == 0:
-                    break
-                else:
-                    i = ctx.pat(overlap_offset + i)
-            else:
-                i += 1
-                if i == prefix_len:
-                    # found a potential match
-                    start = string_position + 1 - prefix_len
-                    assert start >= 0
-                    ptr = start + prefix_skip
-                    if flags & rsre_char.SRE_INFO_LITERAL:
-                        # matched all of pure literal pattern
-                        ctx.match_start = start
-                        ctx.match_end = ptr
-                        ctx.match_marks = None
-                        return True
-                    if sre_match(ctx, ppos_start, ptr, None) is not None:
-                        ctx.match_start = start
-                        return True
-                    i = ctx.pat(overlap_offset + i)
-                break
+    while True:
+        ctx.jitdriver_FastSearch.jit_merge_point(ctx=ctx,
+                string_position=string_position, i=i, prefix_len=prefix_len)
+        char_ord = ctx.str(string_position)
+        if char_ord != ctx.pat(7 + i):
+            if i > 0:
+                overlap_offset = prefix_len + (7 - 1)
+                i = ctx.pat(overlap_offset + i)
+                continue
+        else:
+            i += 1
+            if i == prefix_len:
+                # found a potential match
+                start = string_position + 1 - prefix_len
+                assert start >= 0
+                prefix_skip = ctx.pat(6)
+                ptr = start + prefix_skip
+                #flags = ctx.pat(2)
+                #if flags & rsre_char.SRE_INFO_LITERAL:
+                #    # matched all of pure literal pattern
+                #    ctx.match_start = start
+                #    ctx.match_end = ptr
+                #    ctx.match_marks = None
+                #    return True
+                pattern_offset = ctx.pat(1) + 1
+                ppos_start = pattern_offset + 2 * prefix_skip
+                if sre_match(ctx, ppos_start, ptr, None) is not None:
+                    ctx.match_start = start
+                    return True
+                overlap_offset = prefix_len + (7 - 1)
+                i = ctx.pat(overlap_offset + i)
         string_position += 1
-    return False
+        if string_position >= ctx.end:
+            return False

Added: pypy/branch/rsre-jit/pypy/rlib/rsre/rsre_jit.py
==============================================================================
--- (empty file)
+++ pypy/branch/rsre-jit/pypy/rlib/rsre/rsre_jit.py	Mon Oct 18 15:49:48 2010
@@ -0,0 +1,40 @@
+from pypy.rlib.jit import JitDriver
+
+
+class RSreJitDriver(JitDriver):
+
+    def __init__(self, name, debugprint, **kwds):
+        JitDriver.__init__(self, **kwds)
+        #
+        def get_printable_location(*args):
+            # we print based on indices in 'args'.  We first print
+            # 'ctx.pattern' from the arg number debugprint[0].
+            pattern = args[debugprint[0]]
+            s = str(pattern)
+            if len(s) > 120:
+                s = s[:110] + '...'
+            if len(debugprint) > 1:
+                # then we print numbers from the args number
+                # debugprint[1] and possibly debugprint[2]
+                info = ' at %d' % (args[debugprint[1]],)
+                if len(debugprint) > 2:
+                    info = '%s/%d' % (info, args[debugprint[2]])
+            else:
+                info = ''
+            return '%s%s %s' % (name, info, s)
+        #
+        self.get_printable_location = get_printable_location
+
+
+def install_jitdriver(name, **kwds):
+    from pypy.rlib.rsre.rsre_core import AbstractMatchContext
+    jitdriver = RSreJitDriver(name, **kwds)
+    setattr(AbstractMatchContext, 'jitdriver_' + name, jitdriver)
+
+def install_jitdriver_spec(name, **kwds):
+    from pypy.rlib.rsre.rsre_core import StrMatchContext
+    from pypy.rlib.rsre.rsre_core import UnicodeMatchContext
+    for prefix, concreteclass in [('Str', StrMatchContext),
+                                  ('Uni', UnicodeMatchContext)]:
+        jitdriver = RSreJitDriver(prefix + name, **kwds)
+        setattr(concreteclass, 'jitdriver_' + name, jitdriver)

Modified: pypy/branch/rsre-jit/pypy/rlib/rsre/test/test_zjit.py
==============================================================================
--- pypy/branch/rsre-jit/pypy/rlib/rsre/test/test_zjit.py	(original)
+++ pypy/branch/rsre-jit/pypy/rlib/rsre/test/test_zjit.py	Mon Oct 18 15:49:48 2010
@@ -6,16 +6,30 @@
 from pypy.rpython.lltypesystem import lltype
 from pypy.rpython.annlowlevel import llstr, hlstr
 
-def entrypoint1(r, string):
+def entrypoint1(r, string, repeat):
     r = array2list(r)
     string = hlstr(string)
     make_sure_not_modified(r)
-    match = rsre_core.match(r, string)
+    match = None
+    for i in range(repeat):
+        match = rsre_core.match(r, string)
     if match is None:
         return -1
     else:
         return match.match_end
 
+def entrypoint2(r, string, repeat):
+    r = array2list(r)
+    string = hlstr(string)
+    make_sure_not_modified(r)
+    match = None
+    for i in range(repeat):
+        match = rsre_core.search(r, string)
+    if match is None:
+        return -1
+    else:
+        return match.match_start
+
 def list2array(lst):
     a = lltype.malloc(lltype.GcArray(lltype.Signed), len(lst))
     for i, x in enumerate(lst):
@@ -35,9 +49,16 @@
 
 class TestJitRSre(test_basic.LLJitMixin):
 
-    def meta_interp_match(self, pattern, string):
+    def meta_interp_match(self, pattern, string, repeat=1):
         r = get_code(pattern)
-        return self.meta_interp(entrypoint1, [list2array(r), llstr(string)],
+        return self.meta_interp(entrypoint1, [list2array(r), llstr(string),
+                                              repeat],
+                                listcomp=True, backendopt=True)
+
+    def meta_interp_search(self, pattern, string, repeat=1):
+        r = get_code(pattern)
+        return self.meta_interp(entrypoint2, [list2array(r), llstr(string),
+                                              repeat],
                                 listcomp=True, backendopt=True)
 
     def test_simple_match_1(self):
@@ -48,6 +69,11 @@
         res = self.meta_interp_match(r".*abc", "xxabcyyyyyyyyyyyyy")
         assert res == 5
 
+    def test_simple_match_repeated(self):
+        res = self.meta_interp_match(r"abcdef", "abcdef", repeat=10)
+        assert res == 6
+        self.check_tree_loop_count(1)
+
     def test_match_minrepeat_1(self):
         res = self.meta_interp_match(r".*?abc", "xxxxxxxxxxxxxxabc")
         assert res == 17
@@ -67,3 +93,28 @@
              "xxxxxxxxxxabbbbbbbbbbc")
         res = self.meta_interp_match(r".*?ab+?c", s)
         assert res == len(s)
+
+
+    def test_fast_search(self):
+        res = self.meta_interp_search(r"<foo\w+>", "e<f<f<foxd<f<fh<foobar>ua")
+        assert res == 15
+        self.check_loops(guard_value=0)
+
+    def test_regular_search(self):
+        res = self.meta_interp_search(r"<\w+>", "eiofweoxdiwhdoh<foobar>ua")
+        assert res == 15
+
+    def test_regular_search_upcase(self):
+        res = self.meta_interp_search(r"<\w+>", "EIOFWEOXDIWHDOH<FOOBAR>UA")
+        assert res == 15
+
+    def test_max_until_1(self):
+        res = self.meta_interp_match(r"(ab)*abababababc",
+                                     "ababababababababababc")
+        assert res == 21
+
+    def test_example_1(self):
+        res = self.meta_interp_search(
+            r"Active\s+20\d\d-\d\d-\d\d\s+[[]\d+[]]([^[]+)",
+            "Active"*20 + "Active 2010-04-07 [42] Foobar baz boz blah[43]")
+        assert res == 6*20



More information about the Pypy-commit mailing list