[pypy-commit] pypy libgccjit-backend: Begin implementing code patching

dmalcolm noreply at buildbot.pypy.org
Thu Dec 18 19:01:30 CET 2014


Author: David Malcolm <dmalcolm at redhat.com>
Branch: libgccjit-backend
Changeset: r75019:ee6c48dc603d
Date: 2014-12-18 12:10 -0500
http://bitbucket.org/pypy/pypy/changeset/ee6c48dc603d/

Log:	Begin implementing code patching

	When generating guard failure handling, replace direct call to
	default handler with a jump through a function pointer, so that (in
	theory) we can write to this function pointer to use a different
	handler.

	This requires a couple of extra changes to libgccjit that are
	currently only in my local repo:

	 * a new API entrypoint "gcc_jit_result_get_global"
	 * an extra param to gcc_jit_context_new_global (enum
	   gcc_jit_global_kind)

diff --git a/rpython/jit/backend/libgccjit/assembler.py b/rpython/jit/backend/libgccjit/assembler.py
--- a/rpython/jit/backend/libgccjit/assembler.py
+++ b/rpython/jit/backend/libgccjit/assembler.py
@@ -20,6 +20,93 @@
         self.param_addr = assembler.ctxt.new_param(assembler.t_void_ptr,
                                                    "addr")
         self.paramlist.append(self.param_addr)
+        self.paramtypes = [assembler.t_jit_frame_ptr,
+                           assembler.t_void_ptr]
+
+class Patchpoint:
+    """
+    We need to be able to patch out the generated code that runs when a
+    guard fails; this class handles that.
+
+    We handle guard failure with a tail-call through a function pointer,
+    equivalent to this C code:
+
+      struct JITFRAME * (*guard_failure_fn_ptr_0) (struct JITFRAME *, void *);
+
+      extern struct JITFRAME *
+      anonloop_0 (struct JITFRAME * jitframe, void * addr)
+      {
+        ...various operations...
+
+        if (!guard)
+         goto on_guard_failure;
+
+        ...various operations...
+
+       on_guard_failure:
+          return guard_failure_fn_ptr_0 (jitframe, addr);
+      }
+
+    This can hopefully be optimized to a jump through a ptr, since it's
+    a tail call:
+
+      0x00007fffeb7086d0 <+16>:  jle    0x7fffeb7086c8 <anonloop_0+8>
+      0x00007fffeb7086d2 <+18>:  mov    %rax,0x48(%rdi)
+      0x00007fffeb7086d6 <+22>:  mov    0x2008fb(%rip),%rax        # 0x7fffeb908fd8
+      0x00007fffeb7086dd <+29>:  mov    (%rax),%rax
+      0x00007fffeb7086e0 <+32>:  jmpq   *%rax
+    """
+    def __init__(self, assembler):
+        self.failure_params = Params(assembler)
+        self.serial = assembler.num_guard_failure_fns
+        assembler.num_guard_failure_fns += 1
+        self.t_fn_ptr_type = (
+            assembler.ctxt.new_function_ptr_type (assembler.t_jit_frame_ptr,
+                                                  self.failure_params.paramtypes,
+                                                  r_int(0)))
+        # Create the function ptr
+        # Globals are zero-initialized, so we'll need to
+        # write in the ptr to the initial handler before the loop is called,
+        # or we'll have a jump-through-NULL segfault.
+        self.fn_ptr_name = "guard_failure_fn_ptr_%i" % self.serial
+        self.failure_fn_ptr = (
+            assembler.ctxt.new_global(assembler.lib.GCC_JIT_GLOBAL_EXPORTED,
+                                      self.t_fn_ptr_type,
+                                      self.fn_ptr_name))
+        self.handler_name = "on_guard_failure_%i" % self.serial
+        self.failure_fn = (
+            assembler.ctxt.new_function(assembler.lib.GCC_JIT_FUNCTION_EXPORTED,
+                                        assembler.t_jit_frame_ptr,
+                                        self.handler_name,
+                                        self.failure_params.paramlist,
+                                        r_int(0)))
+
+    def write_initial_handler(self, result):
+        # Get the address of the machine code for the handler;
+        # this is a:
+        #   struct JITFRAME * (*guard_failure_fn) (struct JITFRAME *, void *)
+        handler = result.get_code(self.handler_name)
+
+        # Get the address of the function ptr to be written to;
+        # this is a:
+        #   struct JITFRAME * (**guard_failure_fn) (struct JITFRAME *, void *)
+        # i.e. one extra level of indirection.
+        fn_ptr_ptr = result.get_global(self.fn_ptr_name)
+
+        # We want to write the equivalent of:
+        #    (*fn_ptr_ptr) = handler;
+        #
+        # "fn_ptr_ptr" and "handler" are both currently (void *),
+        # so we need to cast them to a form where we can express
+        # the above.
+
+        # We can't directly express the function ptr ptr in lltype,
+        # so instead pretend we have a (const char **) and a (const char *):
+        fn_ptr_ptr = rffi.cast(rffi.CCHARPP, fn_ptr_ptr)
+        handler = rffi.cast(rffi.CCHARP, handler)
+
+        # ...and write through the ptr:
+        fn_ptr_ptr[0] = handler
 
 class AssemblerLibgccjit(BaseAssembler):
     _regalloc = None
@@ -48,6 +135,7 @@
         self.num_guard_failure_fns = 0
 
         self.sizeof_signed = rffi.sizeof(lltype.Signed)
+        self.patchpoints = []
 
     def make_context(self):
         eci = make_eci()
@@ -65,10 +153,10 @@
         if 1:
             self.ctxt.set_int_option(self.lib.GCC_JIT_INT_OPTION_OPTIMIZATION_LEVEL,
                                      r_int(2))
-        if 0:
+        if 1:
             self.ctxt.set_bool_option(self.lib.GCC_JIT_BOOL_OPTION_KEEP_INTERMEDIATES,
                                       r_int(1))
-        if 0:
+        if 1:
             self.ctxt.set_bool_option(self.lib.GCC_JIT_BOOL_OPTION_DUMP_EVERYTHING,
                                       r_int(1))
         if 0:
@@ -258,6 +346,10 @@
         jit_result = self.ctxt.compile()
         self.ctxt.release()
 
+        # Patch all patchpoints to their initial handlers:
+        for pp in self.patchpoints:
+            pp.write_initial_handler(jit_result)
+
         fn_ptr = jit_result.get_code(loopname)
 
         looptoken._ll_function_addr = fn_ptr
@@ -416,27 +508,21 @@
         # Write out guard failure impl:
         self.b_current = b_guard_failure
 
-        # Implement it as a tail-call to a handler function
-        # This will eventually become a function ptr, allowing
-        # patchability
-        failure_params = Params(self)
-        failure_fn = (
-            self.ctxt.new_function(self.lib.GCC_JIT_FUNCTION_INTERNAL,
-                                   self.t_jit_frame_ptr,
-                                   "on_guard_failure_%i" % self.num_guard_failure_fns,
-                                   failure_params.paramlist,
-                                   r_int(0)))
-        self.num_guard_failure_fns += 1
-        call = self.ctxt.new_call (failure_fn,
-                                   [param.as_rvalue()
-                                    for param in self.loop_params.paramlist])
+        # Implement it as a tail-call through a function ptr to
+        # a handler function, allowing patchability
+        pp = Patchpoint(self)
+        self.patchpoints.append(pp)
+        args = [param.as_rvalue()
+                for param in self.loop_params.paramlist]
+        call = self.ctxt.new_call_through_ptr (pp.failure_fn_ptr.as_rvalue(),
+                                               args)
         self._impl_write_output_args(self.loop_params, resop._fail_args)
         self.b_current.end_with_return(call)
 
-        b_within_failure_fn = failure_fn.new_block("initial")
+        b_within_failure_fn = pp.failure_fn.new_block("initial")
         self.b_current = b_within_failure_fn
-        self._impl_write_jf_descr(failure_params, resop)
-        self.b_current.end_with_return(failure_params.param_frame.as_rvalue ())
+        self._impl_write_jf_descr(pp.failure_params, resop)
+        self.b_current.end_with_return(pp.failure_params.param_frame.as_rvalue ())
         rd_locs = []
         for idx, arg in enumerate(resop._fail_args):
             rd_locs.append(idx * self.sizeof_signed)
diff --git a/rpython/jit/backend/libgccjit/rffi_bindings.py b/rpython/jit/backend/libgccjit/rffi_bindings.py
--- a/rpython/jit/backend/libgccjit/rffi_bindings.py
+++ b/rpython/jit/backend/libgccjit/rffi_bindings.py
@@ -89,6 +89,8 @@
                                                  hints={'nolength': True}))
         self.RVALUE_P_P = lltype.Ptr(lltype.Array(self.GCC_JIT_RVALUE_P,
                                                   hints={'nolength': True}))
+        self.TYPE_P_P = lltype.Ptr(lltype.Array(self.GCC_JIT_TYPE_P,
+                                                hints={'nolength': True}))
 
         # Entrypoints:
         for returntype, name, paramtypes in [
@@ -120,6 +122,10 @@
                  'gcc_jit_result_get_code', [self.GCC_JIT_RESULT_P,
                                              CCHARP]),
 
+                (VOIDP,
+                 'gcc_jit_result_get_global', [self.GCC_JIT_RESULT_P,
+                                               CCHARP]),
+
                 (lltype.Void,
                  'gcc_jit_result_release', [self.GCC_JIT_RESULT_P]),
 
@@ -170,6 +176,14 @@
                                                     INT,
                                                     self.FIELD_P_P]),
 
+                (self.GCC_JIT_TYPE_P,
+                 'gcc_jit_context_new_function_ptr_type', [self.GCC_JIT_CONTEXT_P,
+                                                           self.GCC_JIT_LOCATION_P,
+                                                           self.GCC_JIT_TYPE_P,
+                                                           INT,
+                                                           self.TYPE_P_P,
+                                                           INT]),
+
                 ############################################################
                 # Constructing functions.
                 ############################################################
@@ -205,6 +219,13 @@
                 ############################################################
                 # lvalues, rvalues and expressions.
                 ############################################################
+                (self.GCC_JIT_LVALUE_P,
+                 'gcc_jit_context_new_global', [self.GCC_JIT_CONTEXT_P,
+                                                self.GCC_JIT_LOCATION_P,
+                                                INT, # enum gcc_jit_global_kind
+                                                self.GCC_JIT_TYPE_P,
+                                                CCHARP]),
+
                 (self.GCC_JIT_RVALUE_P,
                  'gcc_jit_lvalue_as_rvalue', [self.GCC_JIT_LVALUE_P]),
 
@@ -262,6 +283,13 @@
                                               self.RVALUE_P_P]),
 
                 (self.GCC_JIT_RVALUE_P,
+                 'gcc_jit_context_new_call_through_ptr',[self.GCC_JIT_CONTEXT_P,
+                                                         self.GCC_JIT_LOCATION_P,
+                                                         self.GCC_JIT_RVALUE_P,
+                                                         INT,
+                                                         self.RVALUE_P_P]),
+
+                (self.GCC_JIT_RVALUE_P,
                  'gcc_jit_context_new_cast', [self.GCC_JIT_CONTEXT_P,
                                               self.GCC_JIT_LOCATION_P,
                                               self.GCC_JIT_RVALUE_P,
@@ -352,6 +380,13 @@
 
         self.make_enum_values(
             """
+            GCC_JIT_GLOBAL_EXPORTED,
+            GCC_JIT_GLOBAL_INTERNAL,
+            GCC_JIT_GLOBAL_IMPORTED
+            """)
+
+        self.make_enum_values(
+            """
             GCC_JIT_UNARY_OP_MINUS,
             GCC_JIT_UNARY_OP_BITWISE_NEGATE,
             GCC_JIT_UNARY_OP_LOGICAL_NEGATE,
@@ -502,6 +537,23 @@
         free_charp(name_charp)
         return Type(self.lib, inner_type)
 
+    def new_function_ptr_type(self, returntype, param_types, is_variadic):
+        raw_type_array = lltype.malloc(self.lib.TYPE_P_P.TO,
+                                       len(param_types),
+                                       flavor='raw') # or maybe gc?
+        for i in range(len(param_types)):
+            raw_type_array[i] = param_types[i].inner_type
+
+        type_ = self.lib.gcc_jit_context_new_function_ptr_type(self.inner_ctxt,
+                                                               self.lib.null_location_ptr,
+                                                               returntype.inner_type,
+                                                               r_int(len(param_types)),
+                                                               raw_type_array,
+                                                               is_variadic)
+        lltype.free(raw_type_array, flavor='raw')
+
+        return Type(self.lib, type_)
+
     def new_rvalue_from_int(self, type_, llvalue):
         return RValue(self.lib,
                       self.lib.gcc_jit_context_new_rvalue_from_int(self.inner_ctxt,
@@ -563,6 +615,20 @@
         lltype.free(raw_arg_array, flavor='raw')
         return RValue(self.lib, rvalue)
 
+    def new_call_through_ptr(self, fn_ptr, args):
+        raw_arg_array = lltype.malloc(self.lib.RVALUE_P_P.TO,
+                                      len(args),
+                                      flavor='raw') # or maybe gc?
+        for i in range(len(args)):
+            raw_arg_array[i] = args[i].inner_rvalue
+        rvalue = self.lib.gcc_jit_context_new_call_through_ptr(self.inner_ctxt,
+                                                               self.lib.null_location_ptr,
+                                                               fn_ptr.inner_rvalue,
+                                                               r_int(len(args)),
+                                                               raw_arg_array)
+        lltype.free(raw_arg_array, flavor='raw')
+        return RValue(self.lib, rvalue)
+
     def new_param(self, type_, name):
         name_charp = str2charp(name)
         param = self.lib.gcc_jit_context_new_param(self.inner_ctxt,
@@ -593,6 +659,16 @@
 
         return Function(self.lib, fn)
 
+    def new_global(self, kind, type_, name):
+        name_charp = str2charp(name)
+        lvalue = self.lib.gcc_jit_context_new_global(self.inner_ctxt,
+                                                     self.lib.null_location_ptr,
+                                                     kind,
+                                                     type_.inner_type,
+                                                     name_charp)
+        free_charp(name_charp)
+        return LValue(self.lib, lvalue)
+
     def new_cast(self, rvalue, type_):
         return RValue(self.lib,
                       self.lib.gcc_jit_context_new_cast(self.inner_ctxt,
@@ -743,5 +819,12 @@
         free_charp(name_charp)
         return fn_ptr
 
+    def get_global(self, name):
+        name_charp = str2charp(name)
+        sym_ptr = self.lib.gcc_jit_result_get_global(self.inner_result,
+                                                     name_charp)
+        free_charp(name_charp)
+        return sym_ptr
+
     def release(self):
         self.lib.gcc_jit_result_release(self.inner_result)


More information about the pypy-commit mailing list