[pypy-commit] pypy default: hg merge improve-gc-tracing-hooks

arigo noreply at buildbot.pypy.org
Tue Nov 11 18:52:05 CET 2014


Author: Armin Rigo <arigo at tunes.org>
Branch: 
Changeset: r74457:9069f9c784ba
Date: 2014-11-11 18:51 +0100
http://bitbucket.org/pypy/pypy/changeset/9069f9c784ba/

Log:	hg merge improve-gc-tracing-hooks

	A small branch (fijal, arigo) that cleans up the custom tracers
	called by the GC.

diff --git a/rpython/jit/backend/llsupport/jitframe.py b/rpython/jit/backend/llsupport/jitframe.py
--- a/rpython/jit/backend/llsupport/jitframe.py
+++ b/rpython/jit/backend/llsupport/jitframe.py
@@ -3,6 +3,7 @@
 from rpython.rlib.objectmodel import specialize
 from rpython.rlib.debug import ll_assert
 from rpython.rlib.objectmodel import enforceargs
+from rpython.rlib import rgc
 
 SIZEOFSIGNED = rffi.sizeof(lltype.Signed)
 IS_32BIT = (SIZEOFSIGNED == 4)
@@ -45,6 +46,7 @@
 # detailed explanation how it is on your architecture
 
 def jitframe_allocate(frame_info):
+    rgc.register_custom_trace_hook(JITFRAME, lambda_jitframe_trace)
     frame = lltype.malloc(JITFRAME, frame_info.jfi_frame_depth)
     frame.jf_frame_info = frame_info
     frame.jf_extra_stack_depth = 0
@@ -80,8 +82,6 @@
     ('jf_guard_exc', llmemory.GCREF),
     # in case the frame got reallocated, we have to forward it somewhere
     ('jf_forward', lltype.Ptr(JITFRAME)),
-    # absolutely useless field used to make up for tracing hooks inflexibilities
-    ('jf_gc_trace_state', lltype.Signed),
     # the actual frame
     ('jf_frame', lltype.Array(lltype.Signed)),
     # note that we keep length field, because it's crucial to have the data
@@ -105,75 +105,38 @@
 UNSIGN_SIZE = llmemory.sizeof(lltype.Unsigned)
 STACK_DEPTH_OFS = getofs('jf_extra_stack_depth')
 
-def jitframe_trace(obj_addr, prev):
-    if prev == llmemory.NULL:
-        (obj_addr + getofs('jf_gc_trace_state')).signed[0] = -1
-        return obj_addr + getofs('jf_descr')
-    fld = (obj_addr + getofs('jf_gc_trace_state')).signed[0]
-    if fld < 0:
-        if fld == -1:
-            (obj_addr + getofs('jf_gc_trace_state')).signed[0] = -2
-            return obj_addr + getofs('jf_force_descr')
-        elif fld == -2:
-            (obj_addr + getofs('jf_gc_trace_state')).signed[0] = -3
-            return obj_addr + getofs('jf_savedata')
-        elif fld == -3:
-            (obj_addr + getofs('jf_gc_trace_state')).signed[0] = -4
-            return obj_addr + getofs('jf_guard_exc')
-        elif fld == -4:
-            (obj_addr + getofs('jf_gc_trace_state')).signed[0] = -5
-            return obj_addr + getofs('jf_forward')
-        else:
-            if not (obj_addr + getofs('jf_gcmap')).address[0]:
-                return llmemory.NULL    # done
-            else:
-                fld = 0    # fall-through
-    # bit pattern
-    # decode the pattern
+def jitframe_trace(gc, obj_addr, callback, arg):
+    gc._trace_callback(callback, arg, obj_addr + getofs('jf_descr'))
+    gc._trace_callback(callback, arg, obj_addr + getofs('jf_force_descr'))
+    gc._trace_callback(callback, arg, obj_addr + getofs('jf_savedata'))
+    gc._trace_callback(callback, arg, obj_addr + getofs('jf_guard_exc'))
+    gc._trace_callback(callback, arg, obj_addr + getofs('jf_forward'))
+
     if IS_32BIT:
-        # 32 possible bits
-        state = fld & 0x1f
-        no = fld >> 5
         MAX = 32
     else:
-        # 64 possible bits
-        state = fld & 0x3f
-        no = fld >> 6
         MAX = 64
     gcmap = (obj_addr + getofs('jf_gcmap')).address[0]
+    if not gcmap:
+        return      # done
     gcmap_lgt = (gcmap + GCMAPLENGTHOFS).signed[0]
+    no = 0
     while no < gcmap_lgt:
         cur = (gcmap + GCMAPBASEOFS + UNSIGN_SIZE * no).unsigned[0]
-        while not (cur & (1 << state)):
-            state += 1
-            if state == MAX:
-                no += 1
-                state = 0
-                break      # next iteration of the outermost loop
-        else:
-            # found it
-            index = no * SIZEOFSIGNED * 8 + state
-            # save new state
-            state += 1
-            if state == MAX:
-                no += 1
-                state = 0
-            if IS_32BIT:
-                new_state = state | (no << 5)
-            else:
-                new_state = state | (no << 6)
-            (obj_addr + getofs('jf_gc_trace_state')).signed[0] = new_state
-            # sanity check
-            frame_lgt = (obj_addr + getofs('jf_frame') + LENGTHOFS).signed[0]
-            ll_assert(index < frame_lgt, "bogus frame field get")
-            return (obj_addr + getofs('jf_frame') + BASEITEMOFS + SIGN_SIZE *
-                    (index))
-    return llmemory.NULL
-
-CUSTOMTRACEFUNC = lltype.FuncType([llmemory.Address, llmemory.Address],
-                                  llmemory.Address)
-jitframe_trace_ptr = llhelper(lltype.Ptr(CUSTOMTRACEFUNC), jitframe_trace)
-
-lltype.attachRuntimeTypeInfo(JITFRAME, customtraceptr=jitframe_trace_ptr)
+        bitindex = 0
+        while bitindex < MAX:
+            if cur & (1 << bitindex):
+                # the 'bitindex' is set in 'cur'
+                index = no * SIZEOFSIGNED * 8 + bitindex
+                # sanity check
+                frame_lgt = (obj_addr + getofs('jf_frame') + LENGTHOFS) \
+                    .signed[0]
+                ll_assert(index < frame_lgt, "bogus frame field get")
+                gc._trace_callback(callback, arg,
+                                   obj_addr + getofs('jf_frame') +
+                                   BASEITEMOFS + SIGN_SIZE * index)
+            bitindex += 1
+        no += 1
+lambda_jitframe_trace = lambda: jitframe_trace
 
 JITFRAMEPTR = lltype.Ptr(JITFRAME)
diff --git a/rpython/jit/backend/llsupport/test/test_gc.py b/rpython/jit/backend/llsupport/test/test_gc.py
--- a/rpython/jit/backend/llsupport/test/test_gc.py
+++ b/rpython/jit/backend/llsupport/test/test_gc.py
@@ -254,11 +254,15 @@
     frame.jf_gcmap[2] = r_uint(2 | 16 | 32 | 128)
     frame.jf_gcmap[3] = r_uint(0)
     frame_adr = llmemory.cast_ptr_to_adr(frame)
+    #
     all_addrs = []
-    next = jitframe.jitframe_trace(frame_adr, llmemory.NULL)
-    while next:
-        all_addrs.append(next)
-        next = jitframe.jitframe_trace(frame_adr, next)
+    class FakeGC:
+        def _trace_callback(self, callback, arg, addr):
+            assert callback == "hello"
+            assert arg == "world"
+            all_addrs.append(addr)
+    jitframe.jitframe_trace(FakeGC(), frame_adr, "hello", "world")
+    #
     counter = 0
     for name in jitframe.JITFRAME._names:
         TP = getattr(jitframe.JITFRAME, name)
@@ -297,12 +301,12 @@
     frame.jf_gcmap[0] = r_uint(18446744073441116160)
     frame.jf_gcmap[1] = r_uint(18446740775107559407)
     frame.jf_gcmap[2] = r_uint(3)
-    all_addrs = []
     frame_adr = llmemory.cast_ptr_to_adr(frame)
-    next = jitframe.jitframe_trace(frame_adr, llmemory.NULL)
-    while next:
-        all_addrs.append(next)
-        next = jitframe.jitframe_trace(frame_adr, next)
+    class FakeGC:
+        def _trace_callback(self, callback, arg, addr):
+            assert callback == "hello"
+            assert arg == "world"
+    jitframe.jitframe_trace(FakeGC(), frame_adr, "hello", "world")
     # assert did not hang
 
     lltype.free(frame_info, flavor='raw')
diff --git a/rpython/memory/gc/base.py b/rpython/memory/gc/base.py
--- a/rpython/memory/gc/base.py
+++ b/rpython/memory/gc/base.py
@@ -71,7 +71,6 @@
                             member_index,
                             is_rpython_class,
                             has_custom_trace,
-                            get_custom_trace,
                             fast_path_tracing,
                             has_gcptr,
                             cannot_pin):
@@ -90,7 +89,6 @@
         self.member_index = member_index
         self.is_rpython_class = is_rpython_class
         self.has_custom_trace = has_custom_trace
-        self.get_custom_trace = get_custom_trace
         self.fast_path_tracing = fast_path_tracing
         self.has_gcptr = has_gcptr
         self.cannot_pin = cannot_pin
@@ -235,16 +233,14 @@
                 item += itemlength
                 length -= 1
         if self.has_custom_trace(typeid):
-            generator = self.get_custom_trace(typeid)
-            item = llmemory.NULL
-            while True:
-                item = generator(obj, item)
-                if not item:
-                    break
-                if self.points_to_valid_gc_object(item):
-                    callback(item, arg)
+            self.custom_trace_dispatcher(obj, typeid, callback, arg)
     _trace_slow_path._annspecialcase_ = 'specialize:arg(2)'
 
+    def _trace_callback(self, callback, arg, addr):
+        if self.is_valid_gc_object(addr.address[0]):
+            callback(addr, arg)
+    _trace_callback._annspecialcase_ = 'specialize:arg(1)'
+
     def trace_partial(self, obj, start, stop, callback, arg):
         """Like trace(), but only walk the array part, for indices in
         range(start, stop).  Must only be called if has_gcptr_in_varsize().
diff --git a/rpython/memory/gctransform/framework.py b/rpython/memory/gctransform/framework.py
--- a/rpython/memory/gctransform/framework.py
+++ b/rpython/memory/gctransform/framework.py
@@ -1,9 +1,11 @@
 from rpython.annotator import model as annmodel
 from rpython.rtyper.llannotation import SomeAddress, SomePtr
 from rpython.rlib import rgc
+from rpython.rlib.objectmodel import specialize
+from rpython.rlib.unroll import unrolling_iterable
 from rpython.rtyper import rmodel, annlowlevel
 from rpython.rtyper.lltypesystem import lltype, llmemory, rffi, llgroup
-from rpython.rtyper.lltypesystem.lloperation import LL_OPERATIONS
+from rpython.rtyper.lltypesystem.lloperation import LL_OPERATIONS, llop
 from rpython.memory import gctypelayout
 from rpython.memory.gctransform.log import log
 from rpython.memory.gctransform.support import get_rtti, ll_call_destructor
@@ -239,6 +241,7 @@
             root_walker.need_stacklet_support(self, getfn)
 
         self.layoutbuilder.encode_type_shapes_now()
+        self.create_custom_trace_funcs(gcdata.gc, translator.rtyper)
 
         annhelper.finish()   # at this point, annotate all mix-level helpers
         annhelper.backend_optimize()
@@ -502,6 +505,29 @@
                                                    [SomeAddress()],
                                                    annmodel.s_None)
 
+    def create_custom_trace_funcs(self, gc, rtyper):
+        custom_trace_funcs = tuple(rtyper.custom_trace_funcs)
+        rtyper.custom_trace_funcs = custom_trace_funcs
+        # too late to register new custom trace functions afterwards
+
+        custom_trace_funcs_unrolled = unrolling_iterable(
+            [(self.get_type_id(TP), func) for TP, func in custom_trace_funcs])
+
+        @specialize.arg(2)
+        def custom_trace_dispatcher(obj, typeid, callback, arg):
+            for type_id_exp, func in custom_trace_funcs_unrolled:
+                if (llop.combine_ushort(lltype.Signed, typeid, 0) ==
+                    llop.combine_ushort(lltype.Signed, type_id_exp, 0)):
+                    func(gc, obj, callback, arg)
+                    return
+            else:
+                assert False
+
+        gc.custom_trace_dispatcher = custom_trace_dispatcher
+
+        for TP, func in custom_trace_funcs:
+            self.gcdata._has_got_custom_trace(self.get_type_id(TP))
+            specialize.arg(2)(func)
 
     def consider_constant(self, TYPE, value):
         self.layoutbuilder.consider_constant(TYPE, value, self.gcdata.gc)
diff --git a/rpython/memory/gctransform/shadowstack.py b/rpython/memory/gctransform/shadowstack.py
--- a/rpython/memory/gctransform/shadowstack.py
+++ b/rpython/memory/gctransform/shadowstack.py
@@ -73,16 +73,13 @@
             return top
         self.decr_stack = decr_stack
 
-        root_iterator = get_root_iterator(gctransformer)
         def walk_stack_root(callback, start, end):
-            root_iterator.setcontext(NonConstant(llmemory.NULL))
             gc = self.gc
             addr = end
-            while True:
-                addr = root_iterator.nextleft(gc, start, addr)
-                if addr == llmemory.NULL:
-                    return
-                callback(gc, addr)
+            while addr != start:
+                addr -= sizeofaddr
+                if gc.points_to_valid_gc_object(addr):
+                    callback(gc, addr)
         self.rootstackhook = walk_stack_root
 
         self.shadow_stack_pool = ShadowStackPool(gcdata)
@@ -349,25 +346,6 @@
                 raise MemoryError
 
 
-def get_root_iterator(gctransformer):
-    if hasattr(gctransformer, '_root_iterator'):
-        return gctransformer._root_iterator     # if already built
-    class RootIterator(object):
-        def _freeze_(self):
-            return True
-        def setcontext(self, context):
-            pass
-        def nextleft(self, gc, start, addr):
-            while addr != start:
-                addr -= sizeofaddr
-                if gc.points_to_valid_gc_object(addr):
-                    return addr
-            return llmemory.NULL
-    result = RootIterator()
-    gctransformer._root_iterator = result
-    return result
-
-
 def get_shadowstackref(root_walker, gctransformer):
     if hasattr(gctransformer, '_SHADOWSTACKREF'):
         return gctransformer._SHADOWSTACKREF
@@ -381,19 +359,19 @@
                                      rtti=True)
     SHADOWSTACKREFPTR.TO.become(SHADOWSTACKREF)
 
+    def customtrace(gc, obj, callback, arg):
+        obj = llmemory.cast_adr_to_ptr(obj, SHADOWSTACKREFPTR)
+        addr = obj.top
+        start = obj.base
+        while addr != start:
+            addr -= sizeofaddr
+            gc._trace_callback(callback, arg, addr)
+
     gc = gctransformer.gcdata.gc
-    root_iterator = get_root_iterator(gctransformer)
-
-    def customtrace(obj, prev):
-        obj = llmemory.cast_adr_to_ptr(obj, SHADOWSTACKREFPTR)
-        if not prev:
-            root_iterator.setcontext(obj.context)
-            prev = obj.top
-        return root_iterator.nextleft(gc, obj.base, prev)
-
-    CUSTOMTRACEFUNC = lltype.FuncType([llmemory.Address, llmemory.Address],
-                                      llmemory.Address)
-    customtraceptr = llhelper(lltype.Ptr(CUSTOMTRACEFUNC), customtrace)
+    assert not hasattr(gc, 'custom_trace_dispatcher')
+    # ^^^ create_custom_trace_funcs() must not run before this
+    gctransformer.translator.rtyper.custom_trace_funcs.append(
+        (SHADOWSTACKREF, customtrace))
 
     def shadowstack_destructor(shadowstackref):
         if root_walker.stacklet_support:
@@ -414,8 +392,7 @@
     destrptr = gctransformer.annotate_helper(shadowstack_destructor,
                                              [SHADOWSTACKREFPTR], lltype.Void)
 
-    lltype.attachRuntimeTypeInfo(SHADOWSTACKREF, customtraceptr=customtraceptr,
-                                 destrptr=destrptr)
+    lltype.attachRuntimeTypeInfo(SHADOWSTACKREF, destrptr=destrptr)
 
     gctransformer._SHADOWSTACKREF = SHADOWSTACKREF
     return SHADOWSTACKREF
diff --git a/rpython/memory/gctypelayout.py b/rpython/memory/gctypelayout.py
--- a/rpython/memory/gctypelayout.py
+++ b/rpython/memory/gctypelayout.py
@@ -21,18 +21,12 @@
     # It is called with the object as first argument, and the previous
     # returned address (or NULL the first time) as the second argument.
     FINALIZER_FUNC = lltype.FuncType([llmemory.Address], lltype.Void)
-    CUSTOMTRACER_FUNC = lltype.FuncType([llmemory.Address, llmemory.Address],
-                                        llmemory.Address)
     FINALIZER = lltype.Ptr(FINALIZER_FUNC)
-    CUSTOMTRACER = lltype.Ptr(CUSTOMTRACER_FUNC)
-    EXTRA = lltype.Struct("type_info_extra",
-                          ('finalizer', FINALIZER),
-                          ('customtracer', CUSTOMTRACER))
 
     # structure describing the layout of a typeid
     TYPE_INFO = lltype.Struct("type_info",
         ("infobits",       lltype.Signed),    # combination of the T_xxx consts
-        ("extra",          lltype.Ptr(EXTRA)),
+        ("finalizer",      FINALIZER),
         ("fixedsize",      lltype.Signed),
         ("ofstoptrs",      lltype.Ptr(OFFSETS_TO_GC_PTR)),
         hints={'immutable': True},
@@ -92,18 +86,13 @@
         return (infobits & ANY) != 0
 
     def q_finalizer(self, typeid):
-        typeinfo = self.get(typeid)
-        if typeinfo.infobits & T_HAS_FINALIZER:
-            return typeinfo.extra.finalizer
-        else:
-            return lltype.nullptr(GCData.FINALIZER_FUNC)
+        return self.get(typeid).finalizer
 
     def q_light_finalizer(self, typeid):
         typeinfo = self.get(typeid)
         if typeinfo.infobits & T_HAS_LIGHTWEIGHT_FINALIZER:
-            return typeinfo.extra.finalizer
-        else:
-            return lltype.nullptr(GCData.FINALIZER_FUNC)
+            return typeinfo.finalizer
+        return lltype.nullptr(GCData.FINALIZER_FUNC)
 
     def q_offsets_to_gc_pointers(self, typeid):
         return self.get(typeid).ofstoptrs
@@ -141,12 +130,6 @@
         infobits = self.get(typeid).infobits
         return infobits & T_HAS_CUSTOM_TRACE != 0
 
-    def q_get_custom_trace(self, typeid):
-        ll_assert(self.q_has_custom_trace(typeid),
-                  "T_HAS_CUSTOM_TRACE missing")
-        typeinfo = self.get(typeid)
-        return typeinfo.extra.customtracer
-
     def q_fast_path_tracing(self, typeid):
         # return True if none of the flags T_HAS_GCPTR_IN_VARSIZE,
         # T_IS_GCARRAY_OF_GCPTR or T_HAS_CUSTOM_TRACE is set
@@ -173,11 +156,14 @@
             self.q_member_index,
             self.q_is_rpython_class,
             self.q_has_custom_trace,
-            self.q_get_custom_trace,
             self.q_fast_path_tracing,
             self.q_has_gcptr,
             self.q_cannot_pin)
 
+    def _has_got_custom_trace(self, typeid):
+        type_info = self.get(typeid)
+        type_info.infobits |= (T_HAS_CUSTOM_TRACE | T_HAS_GCPTR)
+
 
 # the lowest 16bits are used to store group member index
 T_MEMBER_INDEX              =   0xffff
@@ -186,9 +172,8 @@
 T_IS_GCARRAY_OF_GCPTR       = 0x040000
 T_IS_WEAKREF                = 0x080000
 T_IS_RPYTHON_INSTANCE       = 0x100000 # the type is a subclass of OBJECT
-T_HAS_FINALIZER             = 0x200000
-T_HAS_CUSTOM_TRACE          = 0x400000
-T_HAS_LIGHTWEIGHT_FINALIZER = 0x800000
+T_HAS_CUSTOM_TRACE          = 0x200000
+T_HAS_LIGHTWEIGHT_FINALIZER = 0x400000
 T_HAS_GCPTR                 = 0x1000000
 T_KEY_MASK                  = intmask(0xFE000000) # bug detection only
 T_KEY_VALUE                 = intmask(0x5A000000) # bug detection only
@@ -217,18 +202,11 @@
     #
     fptrs = builder.special_funcptr_for_type(TYPE)
     if fptrs:
-        extra = lltype.malloc(GCData.EXTRA, zero=True, immortal=True,
-                              flavor='raw')
         if "finalizer" in fptrs:
-            extra.finalizer = fptrs["finalizer"]
-            infobits |= T_HAS_FINALIZER
+            info.finalizer = fptrs["finalizer"]
         if "light_finalizer" in fptrs:
-            extra.finalizer = fptrs["light_finalizer"]
-            infobits |= T_HAS_FINALIZER | T_HAS_LIGHTWEIGHT_FINALIZER
-        if "custom_trace" in fptrs:
-            extra.customtracer = fptrs["custom_trace"]
-            infobits |= T_HAS_CUSTOM_TRACE | T_HAS_GCPTR
-        info.extra = extra
+            info.finalizer = fptrs["light_finalizer"]
+            infobits |= T_HAS_LIGHTWEIGHT_FINALIZER
     #
     if not TYPE._is_varsize():
         info.fixedsize = llarena.round_up_for_allocation(
@@ -420,7 +398,9 @@
         return None
 
     def initialize_gc_query_function(self, gc):
-        return GCData(self.type_info_group).set_query_functions(gc)
+        gcdata = GCData(self.type_info_group)
+        gcdata.set_query_functions(gc)
+        return gcdata
 
     def consider_constant(self, TYPE, value, gc):
         if value is not lltype.top_container(value):
diff --git a/rpython/memory/gcwrapper.py b/rpython/memory/gcwrapper.py
--- a/rpython/memory/gcwrapper.py
+++ b/rpython/memory/gcwrapper.py
@@ -29,7 +29,7 @@
                                                lltype2vtable,
                                                self.llinterp)
         self.get_type_id = layoutbuilder.get_type_id
-        layoutbuilder.initialize_gc_query_function(self.gc)
+        gcdata = layoutbuilder.initialize_gc_query_function(self.gc)
 
         constants = collect_constants(flowgraphs)
         for obj in constants:
@@ -38,8 +38,25 @@
 
         self.constantroots = layoutbuilder.addresses_of_static_ptrs
         self.constantrootsnongc = layoutbuilder.addresses_of_static_ptrs_in_nongc
+        self.prepare_custom_trace_funcs(gcdata)
         self._all_prebuilt_gc = layoutbuilder.all_prebuilt_gc
 
+    def prepare_custom_trace_funcs(self, gcdata):
+        custom_trace_funcs = self.llinterp.typer.custom_trace_funcs
+
+        def custom_trace(obj, typeid, callback, arg):
+            for TP, func in custom_trace_funcs:
+                if typeid == self.get_type_id(TP):
+                    func(self.gc, obj, callback, arg)
+                    return
+            else:
+                assert False
+        
+        for TP, func in custom_trace_funcs:
+            gcdata._has_got_custom_trace(self.get_type_id(TP))
+
+        self.gc.custom_trace_dispatcher = custom_trace
+
     # ____________________________________________________________
     #
     # Interface for the llinterp
diff --git a/rpython/memory/test/gc_test_base.py b/rpython/memory/test/gc_test_base.py
--- a/rpython/memory/test/gc_test_base.py
+++ b/rpython/memory/test/gc_test_base.py
@@ -6,7 +6,7 @@
 from rpython.rtyper.test.test_llinterp import get_interpreter
 from rpython.rtyper.lltypesystem import lltype
 from rpython.rtyper.lltypesystem.lloperation import llop
-from rpython.rlib.objectmodel import we_are_translated
+from rpython.rlib.objectmodel import we_are_translated, keepalive_until_here
 from rpython.rlib.objectmodel import compute_unique_id
 from rpython.rlib import rgc
 from rpython.rlib.rstring import StringBuilder
@@ -237,26 +237,20 @@
         assert 160 <= res <= 165
 
     def test_custom_trace(self):
-        from rpython.rtyper.annlowlevel import llhelper
         from rpython.rtyper.lltypesystem import llmemory
         from rpython.rtyper.lltypesystem.llarena import ArenaError
         #
         S = lltype.GcStruct('S', ('x', llmemory.Address),
-                                 ('y', llmemory.Address), rtti=True)
+                                 ('y', llmemory.Address))
         T = lltype.GcStruct('T', ('z', lltype.Signed))
         offset_of_x = llmemory.offsetof(S, 'x')
-        def customtrace(obj, prev):
-            if not prev:
-                return obj + offset_of_x
-            else:
-                return llmemory.NULL
-        CUSTOMTRACEFUNC = lltype.FuncType([llmemory.Address, llmemory.Address],
-                                          llmemory.Address)
-        customtraceptr = llhelper(lltype.Ptr(CUSTOMTRACEFUNC), customtrace)
-        lltype.attachRuntimeTypeInfo(S, customtraceptr=customtraceptr)
+        def customtrace(gc, obj, callback, arg):
+            gc._trace_callback(callback, arg, obj + offset_of_x)
+        lambda_customtrace = lambda: customtrace
         #
         for attrname in ['x', 'y']:
             def setup():
+                rgc.register_custom_trace_hook(S, lambda_customtrace)
                 s1 = lltype.malloc(S)
                 tx = lltype.malloc(T)
                 tx.z = 42
@@ -762,6 +756,23 @@
             assert rgc.get_gcflag_extra(a1) == False
             assert rgc.get_gcflag_extra(a2) == False
         self.interpret(fn, [])
+    
+    def test_register_custom_trace_hook(self):
+        S = lltype.GcStruct('S', ('x', lltype.Signed))
+        called = []
+
+        def trace_hook(gc, obj, callback, arg):
+            called.append("called")
+        lambda_trace_hook = lambda: trace_hook
+
+        def f():
+            rgc.register_custom_trace_hook(S, lambda_trace_hook)
+            s = lltype.malloc(S)
+            rgc.collect()
+            keepalive_until_here(s)
+
+        self.interpret(f, [])
+        assert called # not empty, can contain more than one item
 
     def test_pinning(self):
         def fn(n):
diff --git a/rpython/memory/test/test_transformed_gc.py b/rpython/memory/test/test_transformed_gc.py
--- a/rpython/memory/test/test_transformed_gc.py
+++ b/rpython/memory/test/test_transformed_gc.py
@@ -14,7 +14,7 @@
 from rpython.conftest import option
 from rpython.rlib.rstring import StringBuilder
 from rpython.rlib.rarithmetic import LONG_BIT
-import pdb
+
 
 WORD = LONG_BIT // 8
 
@@ -385,26 +385,20 @@
         assert 160 <= res <= 165
 
     def define_custom_trace(cls):
-        from rpython.rtyper.annlowlevel import llhelper
-        from rpython.rtyper.lltypesystem import llmemory
         #
-        S = lltype.GcStruct('S', ('x', llmemory.Address), rtti=True)
+        S = lltype.GcStruct('S', ('x', llmemory.Address))
         T = lltype.GcStruct('T', ('z', lltype.Signed))
         offset_of_x = llmemory.offsetof(S, 'x')
-        def customtrace(obj, prev):
-            if not prev:
-                return obj + offset_of_x
-            else:
-                return llmemory.NULL
-        CUSTOMTRACEFUNC = lltype.FuncType([llmemory.Address, llmemory.Address],
-                                          llmemory.Address)
-        customtraceptr = llhelper(lltype.Ptr(CUSTOMTRACEFUNC), customtrace)
-        lltype.attachRuntimeTypeInfo(S, customtraceptr=customtraceptr)
+        def customtrace(gc, obj, callback, arg):
+            gc._trace_callback(callback, arg, obj + offset_of_x)
+        lambda_customtrace = lambda: customtrace
+
         #
         def setup():
-            s1 = lltype.malloc(S)
+            rgc.register_custom_trace_hook(S, lambda_customtrace)
             tx = lltype.malloc(T)
             tx.z = 4243
+            s1 = lltype.malloc(S)
             s1.x = llmemory.cast_ptr_to_adr(tx)
             return s1
         def f():
diff --git a/rpython/rlib/_stacklet_asmgcc.py b/rpython/rlib/_stacklet_asmgcc.py
--- a/rpython/rlib/_stacklet_asmgcc.py
+++ b/rpython/rlib/_stacklet_asmgcc.py
@@ -1,4 +1,6 @@
 from rpython.rlib.debug import ll_assert
+from rpython.rlib import rgc
+from rpython.rlib.objectmodel import specialize
 from rpython.rtyper.lltypesystem import lltype, llmemory, rffi
 from rpython.rtyper.lltypesystem.lloperation import llop
 from rpython.rtyper.annlowlevel import llhelper, MixLevelHelperAnnotator
@@ -11,6 +13,10 @@
 _stackletrootwalker = None
 
 def get_stackletrootwalker():
+    # XXX this is too complicated now; we don't need a StackletRootWalker
+    # instance to store global state.  We could rewrite it all in one big
+    # function.  We don't care enough for now.
+
     # lazily called, to make the following imports lazy
     global _stackletrootwalker
     if _stackletrootwalker is not None:
@@ -25,8 +31,6 @@
     class StackletRootWalker(object):
         _alloc_flavor_ = "raw"
 
-        enumerating = False
-
         def setup(self, obj):
             # initialization: read the SUSPSTACK object
             p = llmemory.cast_adr_to_ptr(obj, lltype.Ptr(SUSPSTACK))
@@ -66,7 +70,8 @@
                 self.fill_initial_frame(self.curframe, anchor)
                 return True
 
-        def next(self, obj, prev):
+        @specialize.arg(3)
+        def customtrace(self, gc, obj, callback, arg):
             #
             # Pointers to the stack can be "translated" or not:
             #
@@ -79,29 +84,20 @@
             # Note that 'curframe' contains non-translated pointers, and
             # of course the stack itself is full of non-translated pointers.
             #
+            if not self.setup(obj):
+                return
+
             while True:
-                if not self.enumerating:
-                    if not prev:
-                        if not self.setup(obj):      # one-time initialization
-                            return llmemory.NULL
-                        prev = obj   # random value, but non-NULL
-                    callee = self.curframe
-                    retaddraddr = self.translateptr(callee.frame_address)
-                    retaddr = retaddraddr.address[0]
-                    ebp_in_caller = callee.regs_stored_at[INDEX_OF_EBP]
-                    ebp_in_caller = self.translateptr(ebp_in_caller)
-                    ebp_in_caller = ebp_in_caller.address[0]
-                    basewalker.locate_caller_based_on_retaddr(retaddr,
-                                                              ebp_in_caller)
-                    self.enumerating = True
-                else:
-                    callee = self.curframe
-                    ebp_in_caller = callee.regs_stored_at[INDEX_OF_EBP]
-                    ebp_in_caller = self.translateptr(ebp_in_caller)
-                    ebp_in_caller = ebp_in_caller.address[0]
-                #
-                # not really a loop, but kept this way for similarity
-                # with asmgcroot:
+                callee = self.curframe
+                retaddraddr = self.translateptr(callee.frame_address)
+                retaddr = retaddraddr.address[0]
+                ebp_in_caller = callee.regs_stored_at[INDEX_OF_EBP]
+                ebp_in_caller = self.translateptr(ebp_in_caller)
+                ebp_in_caller = ebp_in_caller.address[0]
+                basewalker.locate_caller_based_on_retaddr(retaddr,
+                                                          ebp_in_caller)
+
+                # see asmgcroot for similarity:
                 while True:
                     location = basewalker._shape_decompressor.next()
                     if location == 0:
@@ -109,9 +105,9 @@
                     addr = basewalker.getlocation(callee, ebp_in_caller,
                                                   location)
                     # yield the translated addr of the next GCREF in the stack
-                    return self.translateptr(addr)
-                #
-                self.enumerating = False
+                    addr = self.translateptr(addr)
+                    gc._trace_callback(callback, arg, addr)
+
                 caller = self.otherframe
                 reg = CALLEE_SAVED_REGS - 1
                 while reg >= 0:
@@ -129,7 +125,7 @@
                 if caller.frame_address == llmemory.NULL:
                     # completely done with this piece of stack
                     if not self.fetch_next_stack_piece():
-                        return llmemory.NULL
+                        return
                     continue
                 #
                 self.otherframe = callee
@@ -154,9 +150,10 @@
     lltype.attachRuntimeTypeInfo(SUSPSTACK, destrptr=destrptr)
 
 
-def customtrace(obj, prev):
+def customtrace(gc, obj, callback, arg):
     stackletrootwalker = get_stackletrootwalker()
-    return stackletrootwalker.next(obj, prev)
+    stackletrootwalker.customtrace(gc, obj, callback, arg)
+lambda_customtrace = lambda: customtrace
 
 def suspstack_destructor(suspstack):
     h = suspstack.handle
@@ -170,10 +167,6 @@
                             ('callback_pieces', llmemory.Address),
                             rtti=True)
 NULL_SUSPSTACK = lltype.nullptr(SUSPSTACK)
-CUSTOMTRACEFUNC = lltype.FuncType([llmemory.Address, llmemory.Address],
-                                  llmemory.Address)
-customtraceptr = llhelper(lltype.Ptr(CUSTOMTRACEFUNC), customtrace)
-lltype.attachRuntimeTypeInfo(SUSPSTACK, customtraceptr=customtraceptr)
 
 ASM_FRAMEDATA_HEAD_PTR = lltype.Ptr(lltype.ForwardReference())
 ASM_FRAMEDATA_HEAD_PTR.TO.become(lltype.Struct('ASM_FRAMEDATA_HEAD',
@@ -263,6 +256,7 @@
         self.runfn = callback
         self.arg = arg
         # make a fresh new clean SUSPSTACK
+        rgc.register_custom_trace_hook(SUSPSTACK, lambda_customtrace)
         newsuspstack = lltype.malloc(SUSPSTACK)
         newsuspstack.handle = _c.null_handle
         self.suspstack = newsuspstack
diff --git a/rpython/rlib/rgc.py b/rpython/rlib/rgc.py
--- a/rpython/rlib/rgc.py
+++ b/rpython/rlib/rgc.py
@@ -643,3 +643,22 @@
 
 def lltype_is_gc(TP):
     return getattr(getattr(TP, "TO", None), "_gckind", "?") == 'gc'
+
+def register_custom_trace_hook(TP, lambda_func):
+    """ This function does not do anything by itself but, when called from
+    any annotated place, declares that "func" is used to trace GC roots inside
+    any instance of the type TP.  The func must be passed as "lambda: func" in
+    this call, for internal reasons.
+    """
+
+class RegisterGcTraceEntry(ExtRegistryEntry):
+    _about_ = register_custom_trace_hook
+
+    def compute_result_annotation(self, *args_s):
+        pass
+
+    def specialize_call(self, hop):
+        TP = hop.args_s[0].const
+        lambda_func = hop.args_s[1].const
+        hop.exception_cannot_occur()
+        hop.rtyper.custom_trace_funcs.append((TP, lambda_func()))
diff --git a/rpython/rlib/test/test_rgc.py b/rpython/rlib/test/test_rgc.py
--- a/rpython/rlib/test/test_rgc.py
+++ b/rpython/rlib/test/test_rgc.py
@@ -228,3 +228,17 @@
     x1 = X()
     n = rgc.get_rpy_memory_usage(rgc.cast_instance_to_gcref(x1))
     assert n >= 8 and n <= 64
+
+def test_register_custom_trace_hook():
+    TP = lltype.GcStruct('X')
+
+    def trace_func():
+        xxx # should not be annotated here
+    lambda_trace_func = lambda: trace_func
+    
+    def f():
+        rgc.register_custom_trace_hook(TP, lambda_trace_func)
+    
+    t, typer, graph = gengraph(f, [])
+
+    assert typer.custom_trace_funcs == [(TP, trace_func)]
diff --git a/rpython/rtyper/lltypesystem/lltype.py b/rpython/rtyper/lltypesystem/lltype.py
--- a/rpython/rtyper/lltypesystem/lltype.py
+++ b/rpython/rtyper/lltypesystem/lltype.py
@@ -383,8 +383,7 @@
                                                 about=self)._obj
         Struct._install_extras(self, **kwds)
 
-    def _attach_runtime_type_info_funcptr(self, funcptr, destrptr,
-                                          customtraceptr):
+    def _attach_runtime_type_info_funcptr(self, funcptr, destrptr):
         if self._runtime_type_info is None:
             raise TypeError("attachRuntimeTypeInfo: %r must have been built "
                             "with the rtti=True argument" % (self,))
@@ -408,18 +407,6 @@
                 raise TypeError("expected a destructor function "
                                 "implementation, got: %s" % destrptr)
             self._runtime_type_info.destructor_funcptr = destrptr
-        if customtraceptr is not None:
-            from rpython.rtyper.lltypesystem import llmemory
-            T = typeOf(customtraceptr)
-            if (not isinstance(T, Ptr) or
-                not isinstance(T.TO, FuncType) or
-                len(T.TO.ARGS) != 2 or
-                T.TO.RESULT != llmemory.Address or
-                T.TO.ARGS[0] != llmemory.Address or
-                T.TO.ARGS[1] != llmemory.Address):
-                raise TypeError("expected a custom trace function "
-                                "implementation, got: %s" % customtraceptr)
-            self._runtime_type_info.custom_trace_funcptr = customtraceptr
 
 class GcStruct(RttiStruct):
     _gckind = 'gc'
@@ -2288,12 +2275,10 @@
     return SomePtr(ll_ptrtype=PtrT.const)
 
 
-def attachRuntimeTypeInfo(GCSTRUCT, funcptr=None, destrptr=None,
-                          customtraceptr=None):
+def attachRuntimeTypeInfo(GCSTRUCT, funcptr=None, destrptr=None):
     if not isinstance(GCSTRUCT, RttiStruct):
         raise TypeError("expected a RttiStruct: %s" % GCSTRUCT)
-    GCSTRUCT._attach_runtime_type_info_funcptr(funcptr, destrptr,
-                                               customtraceptr)
+    GCSTRUCT._attach_runtime_type_info_funcptr(funcptr, destrptr)
     return _ptr(Ptr(RuntimeTypeInfo), GCSTRUCT._runtime_type_info)
 
 def getRuntimeTypeInfo(GCSTRUCT):
diff --git a/rpython/rtyper/lltypesystem/opimpl.py b/rpython/rtyper/lltypesystem/opimpl.py
--- a/rpython/rtyper/lltypesystem/opimpl.py
+++ b/rpython/rtyper/lltypesystem/opimpl.py
@@ -82,13 +82,11 @@
         else:
             def op_function(x, y):
                 if not isinstance(x, argtype):
-                    if not (isinstance(x, AddressAsInt) and argtype is int):
-                        raise TypeError("%r arg 1 must be %s, got %r instead"% (
-                            fullopname, typname, type(x).__name__))
+                    raise TypeError("%r arg 1 must be %s, got %r instead"% (
+                        fullopname, typname, type(x).__name__))
                 if not isinstance(y, argtype):
-                    if not (isinstance(y, AddressAsInt) and argtype is int):
-                        raise TypeError("%r arg 2 must be %s, got %r instead"% (
-                            fullopname, typname, type(y).__name__))
+                    raise TypeError("%r arg 2 must be %s, got %r instead"% (
+                        fullopname, typname, type(y).__name__))
                 return adjust_result(func(x, y))
 
     return func_with_new_name(op_function, 'op_' + fullopname)
@@ -104,6 +102,19 @@
             lltype.typeOf(adr),))
 
 
+def op_int_eq(x, y):
+    if not isinstance(x, (int, long)):
+        from rpython.rtyper.lltypesystem import llgroup
+        assert isinstance(x, llgroup.CombinedSymbolic), (
+            "'int_eq' arg 1 must be int-like, got %r instead" % (
+                type(x).__name__,))
+    if not isinstance(y, (int, long)):
+        from rpython.rtyper.lltypesystem import llgroup
+        assert isinstance(y, llgroup.CombinedSymbolic), (
+            "'int_eq' arg 2 must be int-like, got %r instead" % (
+                type(y).__name__,))
+    return x == y
+
 def op_ptr_eq(ptr1, ptr2):
     checkptr(ptr1)
     checkptr(ptr2)
diff --git a/rpython/rtyper/rtyper.py b/rpython/rtyper/rtyper.py
--- a/rpython/rtyper/rtyper.py
+++ b/rpython/rtyper/rtyper.py
@@ -60,6 +60,7 @@
         # make the primitive_to_repr constant mapping
         self.primitive_to_repr = {}
         self.exceptiondata = ExceptionData(self)
+        self.custom_trace_funcs = []
 
         try:
             self.seed = int(os.getenv('RTYPERSEED'))
@@ -645,7 +646,7 @@
             raise TyperError("runtime type info function %r returns %r, "
                              "excepted Ptr(RuntimeTypeInfo)" % (func, s))
         funcptr = self.getcallable(graph)
-        attachRuntimeTypeInfo(GCSTRUCT, funcptr, destrptr, None)
+        attachRuntimeTypeInfo(GCSTRUCT, funcptr, destrptr)
 
 # register operations from annotation model
 RPythonTyper._registeroperations(unaryop.UNARY_OPERATIONS, binaryop.BINARY_OPERATIONS)
diff --git a/rpython/translator/c/test/test_newgc.py b/rpython/translator/c/test/test_newgc.py
--- a/rpython/translator/c/test/test_newgc.py
+++ b/rpython/translator/c/test/test_newgc.py
@@ -443,19 +443,14 @@
     def define_custom_trace(cls):
         from rpython.rtyper.annlowlevel import llhelper
         #
-        S = lltype.GcStruct('S', ('x', llmemory.Address), rtti=True)
+        S = lltype.GcStruct('S', ('x', llmemory.Address))
         offset_of_x = llmemory.offsetof(S, 'x')
-        def customtrace(obj, prev):
-            if not prev:
-                return obj + offset_of_x
-            else:
-                return llmemory.NULL
-        CUSTOMTRACEFUNC = lltype.FuncType([llmemory.Address, llmemory.Address],
-                                          llmemory.Address)
-        customtraceptr = llhelper(lltype.Ptr(CUSTOMTRACEFUNC), customtrace)
-        lltype.attachRuntimeTypeInfo(S, customtraceptr=customtraceptr)
+        def customtrace(gc, obj, callback, arg):
+            gc._trace_callback(callback, arg, obj + offset_of_x)
+        lambda_customtrace = lambda: customtrace
         #
         def setup():
+            rgc.register_custom_trace_hook(S, lambda_customtrace)
             s = lltype.nullptr(S)
             for i in range(10000):
                 t = lltype.malloc(S)


More information about the pypy-commit mailing list