[pypy-commit] pypy custom-trace: Rewrite the jit+shadowstack root enumerator in an iterator-like

arigo noreply at buildbot.pypy.org
Wed Aug 10 13:51:00 CEST 2011


Author: Armin Rigo <arigo at tunes.org>
Branch: custom-trace
Changeset: r46418:618f8d473cd6
Date: 2011-08-10 13:53 +0200
http://bitbucket.org/pypy/pypy/changeset/618f8d473cd6/

Log:	Rewrite the jit+shadowstack root enumerator in an iterator-like
	style. Goal: support custom_trace in order to do threads or
	stacklets.

diff --git a/pypy/jit/backend/llsupport/gc.py b/pypy/jit/backend/llsupport/gc.py
--- a/pypy/jit/backend/llsupport/gc.py
+++ b/pypy/jit/backend/llsupport/gc.py
@@ -367,44 +367,93 @@
 
     def add_jit2gc_hooks(self, jit2gc):
         #
-        def collect_jit_stack_root(callback, gc, addr):
-            # Note: first check with 'points_to_valid_gc_object' if the
-            # addr.address[0] appears to be a valid pointer.  It returns
-            # False if it's NULL, and may also check for tagged integers.
-            # The important part here is that it will return True for a
-            # pointer to a MARKER (which is word-aligned), even though it's
-            # not pointing to a valid GC object.
-            if gc.points_to_valid_gc_object(addr):
-                if addr.address[0].signed[0] != GcRootMap_shadowstack.MARKER:
-                    # common case
-                    callback(gc, addr)
-                else:
-                    # points to a MARKER
-                    follow_stack_frame_of_assembler(callback, gc, addr)
+        # ---------------
+        # This is used to enumerate the shadowstack in the presence
+        # of the JIT.  It is written as an iterator that can also be
+        # used with a custom_trace.
         #
-        def follow_stack_frame_of_assembler(callback, gc, addr):
-            frame_addr = addr.signed[0] - self.marker_ofs
-            addr = llmemory.cast_int_to_adr(frame_addr + self.force_index_ofs)
-            force_index = addr.signed[0]
-            if force_index < 0:
-                force_index = ~force_index
-            callshape = self._callshapes[force_index]
-            # NB: the previous line reads a still-alive _callshapes,
-            # because we ensure that just before we called this piece of
-            # assembler, we put on the (same) stack a pointer to a
-            # loop_token that keeps the force_index alive.
-            n = 0
-            while True:
-                offset = rffi.cast(lltype.Signed, callshape[n])
-                if offset == 0:
-                    break
-                addr = llmemory.cast_int_to_adr(frame_addr + offset)
-                if gc.points_to_valid_gc_object(addr):
-                    callback(gc, addr)
-                n += 1
+        class RootIterator:
+            _alloc_flavor_ = "raw"
+
+            def setup(iself, gc):
+                iself.gc = gc
+                iself.frame_addr = 0
+
+            def next(iself, prev, range_lowest):
+                # Return the "next" valid GC object's address.  We enumerate
+                # backwards, starting from the high addresses, until we reach
+                # the 'range_lowest'.  The 'prev' argument is the previous
+                # result (or the high end of the shadowstack to start with).
+                #
+                while True:
+                    #
+                    # If we are not iterating right now in a JIT frame
+                    if iself.frame_addr == 0:
+                        #
+                        # Look for the previous shadowstack address that
+                        # contains a valid pointer
+                        while prev != range_lowest:
+                            prev -= llmemory.sizeof(llmemory.Address)
+                            if iself.gc.points_to_valid_gc_object(prev):
+                                break
+                        else:
+                            return llmemory.NULL
+                        #
+                        # Now a "valid" pointer can be either really valid, or
+                        # it can be a pointer to a JIT frame in the stack.  The
+                        # important part here is that points_to_valid_gc_object
+                        # above returns True even for a pointer to a MARKER
+                        # (which is word-aligned).
+                        if prev.address[0].signed[0] != self.MARKER:
+                            return prev
+                        #
+                        # It's a JIT frame.  Save away 'prev' for later, and
+                        # go into JIT-frame-exploring mode.
+                        iself.saved_prev = prev
+                        frame_addr = prev.signed[0] - self.marker_ofs
+                        iself.frame_addr = frame_addr
+                        addr = llmemory.cast_int_to_adr(frame_addr +
+                                                        self.force_index_ofs)
+                        force_index = addr.signed[0]
+                        if force_index < 0:
+                            force_index = ~force_index
+                        # NB: the next line reads a still-alive _callshapes,
+                        # because we ensure that just before we called this
+                        # piece of assembler, we put on the (same) stack a
+                        # pointer to a loop_token that keeps the force_index
+                        # alive.
+                        callshape = self._callshapes[force_index]
+                    else:
+                        # Continuing to explore this JIT frame
+                        callshape = iself.callshape
+                    #
+                    # 'callshape' points to the next INT of the callshape.
+                    # If it's zero we are done with the JIT frame.
+                    while callshape[0] != 0:
+                        #
+                        # Non-zero: it's an offset inside the JIT frame.
+                        # Read it and increment 'callshape'.
+                        offset = callshape[0]
+                        callshape = lltype.direct_ptradd(callshape, 1)
+                        addr = llmemory.cast_int_to_adr(iself.frame_addr +
+                                                        offset)
+                        if iself.gc.points_to_valid_gc_object(addr):
+                            #
+                            # The JIT frame contains a valid GC pointer at
+                            # this address (as opposed to NULL).  Save
+                            # 'callshape' for the next call, and return the
+                            # address.
+                            iself.callshape = callshape
+                            return addr
+                    #
+                    # Restore 'prev' and loop back to the start.
+                    prev = iself.saved_prev
+                    iself.frame_addr = 0
+
+        # ---------------
         #
         jit2gc.update({
-            'rootstackhook': collect_jit_stack_root,
+            'root_iterator': RootIterator(),
             })
 
     def initialize(self):
diff --git a/pypy/rpython/memory/gctransform/framework.py b/pypy/rpython/memory/gctransform/framework.py
--- a/pypy/rpython/memory/gctransform/framework.py
+++ b/pypy/rpython/memory/gctransform/framework.py
@@ -152,13 +152,8 @@
             # for regular translation: pick the GC from the config
             GCClass, GC_PARAMS = choose_gc_from_config(translator.config)
 
-        self.root_stack_jit_hook = None
         if hasattr(translator, '_jit2gc'):
             self.layoutbuilder = translator._jit2gc['layoutbuilder']
-            try:
-                self.root_stack_jit_hook = translator._jit2gc['rootstackhook']
-            except KeyError:
-                pass
         else:
             self.layoutbuilder = TransformerLayoutBuilder(translator, GCClass)
         self.layoutbuilder.transformer = self
diff --git a/pypy/rpython/memory/gctransform/shadowstack.py b/pypy/rpython/memory/gctransform/shadowstack.py
--- a/pypy/rpython/memory/gctransform/shadowstack.py
+++ b/pypy/rpython/memory/gctransform/shadowstack.py
@@ -27,12 +27,27 @@
             return top
         self.decr_stack = decr_stack
 
-        self.rootstackhook = gctransformer.root_stack_jit_hook
-        if self.rootstackhook is None:
-            def collect_stack_root(callback, gc, addr):
-                if gc.points_to_valid_gc_object(addr):
-                    callback(gc, addr)
-            self.rootstackhook = collect_stack_root
+        translator = gctransformer.translator
+        if (hasattr(translator, '_jit2gc') and
+                'root_iterator' in translator._jit2gc):
+            root_iterator = translator._jit2gc['root_iterator']
+            def jit_walk_stack_root(callback, addr, end):
+                gc = self.gc
+                root_iterator.setup(gc)
+                while True:
+                    end = root_iterator.next(end, addr)
+                    if end == llmemory.NULL:
+                        return
+                    callback(gc, end)
+            self.rootstackhook = jit_walk_stack_root
+        else:
+            def default_walk_stack_root(callback, addr, end):
+                gc = self.gc
+                while addr != end:
+                    if gc.points_to_valid_gc_object(addr):
+                        callback(gc, addr)
+                    addr += sizeofaddr
+            self.rootstackhook = default_walk_stack_root
 
     def push_stack(self, addr):
         top = self.incr_stack(1)
@@ -54,17 +69,13 @@
 
     def walk_stack_roots(self, collect_stack_root):
         gcdata = self.gcdata
-        gc = self.gc
-        rootstackhook = self.rootstackhook
-        addr = gcdata.root_stack_base
-        end = gcdata.root_stack_top
-        while addr != end:
-            rootstackhook(collect_stack_root, gc, addr)
-            addr += sizeofaddr
+        self.rootstackhook(collect_stack_root,
+                           gcdata.root_stack_base, gcdata.root_stack_top)
         if self.collect_stacks_from_other_threads is not None:
             self.collect_stacks_from_other_threads(collect_stack_root)
 
     def need_thread_support(self, gctransformer, getfn):
+        XXXXXX   # FIXME
         from pypy.module.thread import ll_thread    # xxx fish
         from pypy.rpython.memory.support import AddressDict
         from pypy.rpython.memory.support import copy_without_null_values


More information about the pypy-commit mailing list