[pypy-commit] stmgc default: Save and restore slices of the shadowstack in addition to slices of the C stack.

arigo noreply at buildbot.pypy.org
Tue Aug 12 15:58:27 CEST 2014


Author: Armin Rigo <arigo at tunes.org>
Branch: 
Changeset: r1297:438a6f00fadc
Date: 2014-08-12 15:58 +0200
http://bitbucket.org/pypy/stmgc/changeset/438a6f00fadc/

Log:	Save and restore slices of the shadowstack in addition to slices of
	the C stack.

diff --git a/c7/stm/rewind_setjmp.c b/c7/stm/rewind_setjmp.c
--- a/c7/stm/rewind_setjmp.c
+++ b/c7/stm/rewind_setjmp.c
@@ -7,7 +7,8 @@
 
 struct _rewind_jmp_moved_s {
     struct _rewind_jmp_moved_s *next;
-    size_t size;
+    size_t stack_size;
+    size_t shadowstack_size;
 };
 #define RJM_HEADER  sizeof(struct _rewind_jmp_moved_s)
 
@@ -20,28 +21,41 @@
 #endif
 
 
-static void copy_stack(rewind_jmp_thread *rjthread, char *base)
+static void copy_stack(rewind_jmp_thread *rjthread, char *base, void *ssbase)
 {
+    /* Copy away part of the stack and shadowstack.
+       The stack is copied between 'base' (lower limit, i.e. newest bytes)
+       and 'rjthread->head->frame_base' (upper limit, i.e. oldest bytes).
+       The shadowstack is copied between 'ssbase' (upper limit, newest)
+       and 'rjthread->head->shadowstack_base' (lower limit, oldest).
+    */
     assert(rjthread->head != NULL);
     char *stop = rjthread->head->frame_base;
-    assert(stop > base);
+    assert(stop >= base);
+    void *ssstop = rjthread->head->shadowstack_base;
+    assert(ssstop <= ssbase);
     struct _rewind_jmp_moved_s *next = (struct _rewind_jmp_moved_s *)
-        rj_malloc(RJM_HEADER + (stop - base));
+        rj_malloc(RJM_HEADER + (stop - base) + (ssbase - ssstop));
     assert(next != NULL);    /* XXX out of memory */
     next->next = rjthread->moved_off;
-    next->size = stop - base;
+    next->stack_size = stop - base;
+    next->shadowstack_size = ssbase - ssstop;
     memcpy(((char *)next) + RJM_HEADER, base, stop - base);
+    memcpy(((char *)next) + RJM_HEADER + (stop - base), ssstop,
+           ssbase - ssstop);
 
     rjthread->moved_off_base = stop;
+    rjthread->moved_off_ssbase = ssstop;
     rjthread->moved_off = next;
 }
 
 __attribute__((noinline))
-long rewind_jmp_setjmp(rewind_jmp_thread *rjthread)
+long rewind_jmp_setjmp(rewind_jmp_thread *rjthread, void *ss)
 {
     if (rjthread->moved_off) {
         _rewind_jmp_free_stack_slices(rjthread);
     }
+    void *volatile ss1 = ss;
     rewind_jmp_thread *volatile rjthread1 = rjthread;
     int result;
     if (__builtin_setjmp(rjthread->jmpbuf) == 0) {
@@ -55,7 +69,7 @@
         result = rjthread->repeat_count + 1;
     }
     rjthread->repeat_count = result;
-    copy_stack(rjthread, (char *)&rjthread1);
+    copy_stack(rjthread, (char *)&rjthread1, ss1);
     return result;
 }
 
@@ -67,13 +81,20 @@
     while (rjthread->moved_off) {
         struct _rewind_jmp_moved_s *p = rjthread->moved_off;
         char *target = rjthread->moved_off_base;
-        target -= p->size;
+        target -= p->stack_size;
         if (target < stack_free) {
             /* need more stack space! */
             do_longjmp(rjthread, alloca(stack_free - target));
         }
-        memcpy(target, ((char *)p) + RJM_HEADER, p->size);
+        memcpy(target, ((char *)p) + RJM_HEADER, p->stack_size);
+
+        char *sstarget = rjthread->moved_off_ssbase;
+        char *ssend = sstarget + p->shadowstack_size;
+        memcpy(sstarget, ((char *)p) + RJM_HEADER + p->stack_size,
+               p->shadowstack_size);
+
         rjthread->moved_off_base = target;
+        rjthread->moved_off_ssbase = ssend;
         rjthread->moved_off = p->next;
         rj_free(p);
     }
@@ -95,7 +116,7 @@
         return;
     }
     assert(rjthread->moved_off_base < (char *)rjthread->head);
-    copy_stack(rjthread, rjthread->moved_off_base);
+    copy_stack(rjthread, rjthread->moved_off_base, rjthread->moved_off_ssbase);
 }
 
 void _rewind_jmp_free_stack_slices(rewind_jmp_thread *rjthread)
diff --git a/c7/stm/rewind_setjmp.h b/c7/stm/rewind_setjmp.h
--- a/c7/stm/rewind_setjmp.h
+++ b/c7/stm/rewind_setjmp.h
@@ -41,6 +41,7 @@
 
 typedef struct _rewind_jmp_buf {
     char *frame_base;
+    char *shadowstack_base;
     struct _rewind_jmp_buf *prev;
 } rewind_jmp_buf;
 
@@ -48,30 +49,36 @@
     rewind_jmp_buf *head;
     rewind_jmp_buf *initial_head;
     char *moved_off_base;
+    char *moved_off_ssbase;
     struct _rewind_jmp_moved_s *moved_off;
     void *jmpbuf[5];
     long repeat_count;
 } rewind_jmp_thread;
 
 
-#define rewind_jmp_enterframe(rjthread, rjbuf)   do {   \
-    (rjbuf)->frame_base = __builtin_frame_address(0);   \
-    (rjbuf)->prev = (rjthread)->head;                   \
-    (rjthread)->head = (rjbuf);                         \
+#define rewind_jmp_enterframe(rjthread, rjbuf, ss)   do {  \
+    (rjbuf)->frame_base = __builtin_frame_address(0);      \
+    (rjbuf)->shadowstack_base = (char *)(ss);              \
+    (rjbuf)->prev = (rjthread)->head;                      \
+    (rjthread)->head = (rjbuf);                            \
 } while (0)
 
-#define rewind_jmp_leaveframe(rjthread, rjbuf)   do {   \
-    (rjthread)->head = (rjbuf)->prev;                   \
-    if ((rjbuf)->frame_base == (rjthread)->moved_off_base) \
-        _rewind_jmp_copy_stack_slice(rjthread);         \
+#define rewind_jmp_leaveframe(rjthread, rjbuf, ss)   do {    \
+    assert((rjbuf)->shadowstack_base == (char *)(ss));       \
+    (rjthread)->head = (rjbuf)->prev;                        \
+    if ((rjbuf)->frame_base == (rjthread)->moved_off_base) { \
+        assert((rjthread)->moved_off_ssbase == (char *)(ss));\
+        _rewind_jmp_copy_stack_slice(rjthread);              \
+    }                                                        \
 } while (0)
 
-long rewind_jmp_setjmp(rewind_jmp_thread *rjthread);
+long rewind_jmp_setjmp(rewind_jmp_thread *rjthread, void *ss);
 void rewind_jmp_longjmp(rewind_jmp_thread *rjthread) __attribute__((noreturn));
 
 #define rewind_jmp_forget(rjthread)  do {                               \
     if ((rjthread)->moved_off) _rewind_jmp_free_stack_slices(rjthread); \
     (rjthread)->moved_off_base = 0;                                     \
+    (rjthread)->moved_off_ssbase = 0;                                   \
 } while (0)
 
 void _rewind_jmp_copy_stack_slice(rewind_jmp_thread *);
diff --git a/c7/test/test_rewind.c b/c7/test/test_rewind.c
--- a/c7/test/test_rewind.c
+++ b/c7/test/test_rewind.c
@@ -43,10 +43,10 @@
 void test1(void)
 {
     rewind_jmp_buf buf;
-    rewind_jmp_enterframe(&gthread, &buf);
+    rewind_jmp_enterframe(&gthread, &buf, NULL);
 
     test1_x = 0;
-    rewind_jmp_setjmp(&gthread);
+    rewind_jmp_setjmp(&gthread, NULL);
 
     test1_x++;
     f1(test1_x);
@@ -59,7 +59,7 @@
     rewind_jmp_forget(&gthread);
     assert(!rewind_jmp_armed(&gthread));
 
-    rewind_jmp_leaveframe(&gthread, &buf);
+    rewind_jmp_leaveframe(&gthread, &buf, NULL);
 }
 
 /************************************************************/
@@ -70,22 +70,22 @@
 int f2(void)
 {
     rewind_jmp_buf buf;
-    rewind_jmp_enterframe(&gthread, &buf);
+    rewind_jmp_enterframe(&gthread, &buf, NULL);
     test2_x = 0;
-    rewind_jmp_setjmp(&gthread);
-    rewind_jmp_leaveframe(&gthread, &buf);
+    rewind_jmp_setjmp(&gthread, NULL);
+    rewind_jmp_leaveframe(&gthread, &buf, NULL);
     return ++test2_x;
 }
 
 void test2(void)
 {
     rewind_jmp_buf buf;
-    rewind_jmp_enterframe(&gthread, &buf);
+    rewind_jmp_enterframe(&gthread, &buf, NULL);
     int x = f2();
     gevent(x);
     if (x < 10)
         rewind_jmp_longjmp(&gthread);
-    rewind_jmp_leaveframe(&gthread, &buf);
+    rewind_jmp_leaveframe(&gthread, &buf, NULL);
     int expected[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
     CHECK(expected);
 }
@@ -104,12 +104,12 @@
 void test3(void)
 {
     rewind_jmp_buf buf;
-    rewind_jmp_enterframe(&gthread, &buf);
+    rewind_jmp_enterframe(&gthread, &buf, NULL);
     int x = f3(50);
     gevent(x);
     if (x < 10)
         rewind_jmp_longjmp(&gthread);
-    rewind_jmp_leaveframe(&gthread, &buf);
+    rewind_jmp_leaveframe(&gthread, &buf, NULL);
     int expected[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
     CHECK(expected);
 }
@@ -120,25 +120,25 @@
 int f4(int rec)
 {
     rewind_jmp_buf buf;
-    rewind_jmp_enterframe(&gthread, &buf);
+    rewind_jmp_enterframe(&gthread, &buf, NULL);
     int res;
     if (rec > 0)
         res = f4(rec - 1);
     else
         res = f2();
-    rewind_jmp_leaveframe(&gthread, &buf);
+    rewind_jmp_leaveframe(&gthread, &buf, NULL);
     return res;
 }
 
 void test4(void)
 {
     rewind_jmp_buf buf;
-    rewind_jmp_enterframe(&gthread, &buf);
+    rewind_jmp_enterframe(&gthread, &buf, NULL);
     int x = f4(5);
     gevent(x);
     if (x < 10)
         rewind_jmp_longjmp(&gthread);
-    rewind_jmp_leaveframe(&gthread, &buf);
+    rewind_jmp_leaveframe(&gthread, &buf, NULL);
     int expected[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
     CHECK(expected);
 }
@@ -148,11 +148,11 @@
 void test5(void)
 {
     struct { int a; rewind_jmp_buf buf; int b; } sbuf;
-    rewind_jmp_enterframe(&gthread, &sbuf.buf);
+    rewind_jmp_enterframe(&gthread, &sbuf.buf, NULL);
     sbuf.a = 42;
     sbuf.b = -42;
     test2_x = 0;
-    rewind_jmp_setjmp(&gthread);
+    rewind_jmp_setjmp(&gthread, NULL);
     sbuf.a++;
     sbuf.b--;
     gevent(sbuf.a);
@@ -163,7 +163,7 @@
     }
     int expected[] = {43, -43, 43, -43};
     CHECK(expected);
-    rewind_jmp_leaveframe(&gthread, &sbuf.buf);
+    rewind_jmp_leaveframe(&gthread, &sbuf.buf, NULL);
 }
 
 /************************************************************/
@@ -178,9 +178,9 @@
         int a8, int a9, int a10, int a11, int a12, int a13)
 {
     rewind_jmp_buf buf;
-    rewind_jmp_enterframe(&gthread, &buf);
+    rewind_jmp_enterframe(&gthread, &buf, NULL);
 
-    rewind_jmp_setjmp(&gthread);
+    rewind_jmp_setjmp(&gthread, NULL);
     gevent(a1); gevent(a2); gevent(a3); gevent(a4);
     gevent(a5); gevent(a6); gevent(a7); gevent(a8);
     gevent(a9); gevent(a10); gevent(a11); gevent(a12);
@@ -201,16 +201,16 @@
         foo(&a13);
         rewind_jmp_longjmp(&gthread);
     }
-    rewind_jmp_leaveframe(&gthread, &buf);
+    rewind_jmp_leaveframe(&gthread, &buf, NULL);
 }
 
 void test6(void)
 {
     rewind_jmp_buf buf;
-    rewind_jmp_enterframe(&gthread, &buf);
+    rewind_jmp_enterframe(&gthread, &buf, NULL);
     test6_x = 0;
     f6(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13);
-    rewind_jmp_leaveframe(&gthread, &buf);
+    rewind_jmp_leaveframe(&gthread, &buf, NULL);
     int expected[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
                       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
                       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
@@ -220,45 +220,64 @@
 
 /************************************************************/
 
-typedef struct { char foo; } object_t;
-struct stm_shadowentry_s { object_t *ss; };
-typedef struct {
-    struct stm_shadowentry_s *shadowstack;
-    struct stm_shadowentry_s _inline[99];
-} stm_thread_local_t;
-#define STM_PUSH_ROOT(tl, p)   ((tl).shadowstack++->ss = (object_t *)(p))
-#define STM_POP_ROOT(tl, p)    ((p) = (typeof(p))((--(tl).shadowstack)->ss))
-void stm_register_thread_local(stm_thread_local_t *tl) {
-    tl->shadowstack = tl->_inline;
-}
-void stm_unregister_thread_local(stm_thread_local_t *tl) { }
-static stm_thread_local_t tl;
-
+static void *ssarray[99];
 
 void testTL1(void)
 {
-    object_t *a1, *a2;
-    stm_register_thread_local(&tl);
+    void *a4, *a5;
+    rewind_jmp_buf buf;
+    rewind_jmp_enterframe(&gthread, &buf, ssarray+5);
 
-    rewind_jmp_buf buf;
-    rewind_jmp_enterframe(&gthread, &buf);
+    a4 = (void *)444444;
+    a5 = (void *)555555;
+    ssarray[4] = a4;
+    ssarray[5] = a5;
 
-    a1 = a2 = (object_t *)123456;
-    STM_PUSH_ROOT(tl, a1);
-
-    if (rewind_jmp_setjmp(&gthread) == 0) {
+    if (rewind_jmp_setjmp(&gthread, ssarray+6) == 0) {
         /* first path */
-        STM_POP_ROOT(tl, a2);
-        assert(a1 == a2);
-        STM_PUSH_ROOT(tl, NULL);
+        assert(ssarray[4] == a4);
+        assert(ssarray[5] == a5);
+        ssarray[4] = NULL;
+        ssarray[5] = NULL;
         rewind_jmp_longjmp(&gthread);
     }
     /* second path */
-    STM_POP_ROOT(tl, a2);
-    assert(a1 == a2);
+    assert(ssarray[4] == NULL);   /* was not saved */
+    assert(ssarray[5] == a5);     /* saved and restored */
 
-    rewind_jmp_leaveframe(&gthread, &buf);
-    stm_unregister_thread_local(&tl);
+    rewind_jmp_leaveframe(&gthread, &buf, ssarray+5);
+}
+
+__attribute__((noinline))
+int gtl2(void)
+{
+    rewind_jmp_buf buf;
+    rewind_jmp_enterframe(&gthread, &buf, ssarray+5);
+    ssarray[5] = (void *)555555;
+
+    int result = rewind_jmp_setjmp(&gthread, ssarray+6);
+
+    assert(ssarray[4] == (void *)444444);
+    assert(ssarray[5] == (void *)555555);
+    ssarray[5] = NULL;
+
+    rewind_jmp_leaveframe(&gthread, &buf, ssarray+5);
+    return result;
+}
+
+void testTL2(void)
+{
+    rewind_jmp_buf buf;
+    rewind_jmp_enterframe(&gthread, &buf, ssarray+4);
+
+    ssarray[4] = (void *)444444;
+    int result = gtl2();
+    ssarray[4] = NULL;
+
+    if (result == 0)
+        rewind_jmp_longjmp(&gthread);
+
+    rewind_jmp_leaveframe(&gthread, &buf, ssarray+4);
 }
 
 /************************************************************/
@@ -292,6 +311,7 @@
     else if (!strcmp(argv[1], "5"))  test5();
     else if (!strcmp(argv[1], "6"))  test6();
     else if (!strcmp(argv[1], "TL1")) testTL1();
+    else if (!strcmp(argv[1], "TL2")) testTL2();
     else
         assert(!"bad argv[1]");
     assert(rj_malloc_count == 0);
diff --git a/c7/test/test_rewind.py b/c7/test/test_rewind.py
--- a/c7/test/test_rewind.py
+++ b/c7/test/test_rewind.py
@@ -6,7 +6,7 @@
                     % (opt, opt))
     if err != 0:
         raise OSError("clang failed on test_rewind.c")
-    for testnum in [1, 2, 3, 4, 5, 6, "TL1"]:
+    for testnum in [1, 2, 3, 4, 5, 6, "TL1", "TL2"]:
         print '=== O%s: RUNNING TEST %s ===' % (opt, testnum)
         err = os.system("./test_rewind_O%s %s" % (opt, testnum))
         if err != 0:


More information about the pypy-commit mailing list