[Python-checkins] bpo-28416: Break reference cycles in Pickler and Unpickler subclasses (GH-4080) (#4653)

Serhiy Storchaka webhook-mailer at python.org
Thu Nov 30 16:30:41 EST 2017


https://github.com/python/cpython/commit/c91bf742e542dceaf71042a44b5a04fb08bdda70
commit: c91bf742e542dceaf71042a44b5a04fb08bdda70
branch: 3.6
author: Miss Islington (bot) <31488909+miss-islington at users.noreply.github.com>
committer: Serhiy Storchaka <storchaka at gmail.com>
date: 2017-11-30T23:30:39+02:00
summary:

bpo-28416: Break reference cycles in Pickler and Unpickler subclasses (GH-4080) (#4653)

with the persistent_id() and persistent_load() methods.
(cherry picked from commit 986375ebde0dd5ff2b7349e445a06bd28a3a8ee2)

files:
A Misc/NEWS.d/next/Library/2017-10-23-12-05-33.bpo-28416.Ldnw8X.rst
M Lib/test/test_pickle.py
M Modules/_pickle.c

diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py
index ee71c632733..895ed48df1d 100644
--- a/Lib/test/test_pickle.py
+++ b/Lib/test/test_pickle.py
@@ -6,6 +6,7 @@
 import collections
 import struct
 import sys
+import weakref
 
 import unittest
 from test import support
@@ -117,6 +118,66 @@ class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
     pickler = pickle._Pickler
     unpickler = pickle._Unpickler
 
+    @support.cpython_only
+    def test_pickler_reference_cycle(self):
+        def check(Pickler):
+            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+                f = io.BytesIO()
+                pickler = Pickler(f, proto)
+                pickler.dump('abc')
+                self.assertEqual(self.loads(f.getvalue()), 'abc')
+            pickler = Pickler(io.BytesIO())
+            self.assertEqual(pickler.persistent_id('def'), 'def')
+            r = weakref.ref(pickler)
+            del pickler
+            self.assertIsNone(r())
+
+        class PersPickler(self.pickler):
+            def persistent_id(subself, obj):
+                return obj
+        check(PersPickler)
+
+        class PersPickler(self.pickler):
+            @classmethod
+            def persistent_id(cls, obj):
+                return obj
+        check(PersPickler)
+
+        class PersPickler(self.pickler):
+            @staticmethod
+            def persistent_id(obj):
+                return obj
+        check(PersPickler)
+
+    @support.cpython_only
+    def test_unpickler_reference_cycle(self):
+        def check(Unpickler):
+            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+                unpickler = Unpickler(io.BytesIO(self.dumps('abc', proto)))
+                self.assertEqual(unpickler.load(), 'abc')
+            unpickler = Unpickler(io.BytesIO())
+            self.assertEqual(unpickler.persistent_load('def'), 'def')
+            r = weakref.ref(unpickler)
+            del unpickler
+            self.assertIsNone(r())
+
+        class PersUnpickler(self.unpickler):
+            def persistent_load(subself, pid):
+                return pid
+        check(PersUnpickler)
+
+        class PersUnpickler(self.unpickler):
+            @classmethod
+            def persistent_load(cls, pid):
+                return pid
+        check(PersUnpickler)
+
+        class PersUnpickler(self.unpickler):
+            @staticmethod
+            def persistent_load(pid):
+                return pid
+        check(PersUnpickler)
+
 
 class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
 
@@ -197,7 +258,7 @@ class SizeofTests(unittest.TestCase):
         check_sizeof = support.check_sizeof
 
         def test_pickler(self):
-            basesize = support.calcobjsize('5P2n3i2n3iP')
+            basesize = support.calcobjsize('6P2n3i2n3iP')
             p = _pickle.Pickler(io.BytesIO())
             self.assertEqual(object.__sizeof__(p), basesize)
             MT_size = struct.calcsize('3nP0n')
@@ -214,7 +275,7 @@ def test_pickler(self):
                 0)  # Write buffer is cleared after every dump().
 
         def test_unpickler(self):
-            basesize = support.calcobjsize('2Pn2P 2P2n2i5P 2P3n6P2n2i')
+            basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n6P2n2i')
             unpickler = _pickle.Unpickler
             P = struct.calcsize('P')  # Size of memo table entry.
             n = struct.calcsize('n')  # Size of mark table entry.
diff --git a/Misc/NEWS.d/next/Library/2017-10-23-12-05-33.bpo-28416.Ldnw8X.rst b/Misc/NEWS.d/next/Library/2017-10-23-12-05-33.bpo-28416.Ldnw8X.rst
new file mode 100644
index 00000000000..b1014827af6
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2017-10-23-12-05-33.bpo-28416.Ldnw8X.rst
@@ -0,0 +1,3 @@
+Instances of pickle.Pickler subclasses that define the persistent_id() method
+and of pickle.Unpickler subclasses that define the persistent_load() method no
+longer create reference cycles.
diff --git a/Modules/_pickle.c b/Modules/_pickle.c
index 198474d88bf..7e9bb9895ca 100644
--- a/Modules/_pickle.c
+++ b/Modules/_pickle.c
@@ -353,6 +353,69 @@ _Pickle_FastCall(PyObject *func, PyObject *obj)
 
 /*************************************************************************/
 
+/* Retrieve and deconstruct a method to avoid a reference cycle
+   (pickler -> bound method of pickler -> pickler) */
+static int
+init_method_ref(PyObject *self, _Py_Identifier *name,
+                PyObject **method_func, PyObject **method_self)
+{
+    PyObject *func, *func2;
+
+    /* *method_func and *method_self should be consistent.  All refcount decrements
+       should occur after setting *method_self and *method_func. */
+    func = _PyObject_GetAttrId(self, name);
+    if (func == NULL) {
+        *method_self = NULL;
+        Py_CLEAR(*method_func);
+        if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
+            return -1;
+        }
+        PyErr_Clear();
+        return 0;
+    }
+
+    if (PyMethod_Check(func) && PyMethod_GET_SELF(func) == self) {
+        /* Deconstruct a bound Python method */
+        func2 = PyMethod_GET_FUNCTION(func);
+        Py_INCREF(func2);
+        *method_self = self; /* borrowed */
+        Py_XSETREF(*method_func, func2);
+        Py_DECREF(func);
+        return 0;
+    }
+    else {
+        *method_self = NULL;
+        Py_XSETREF(*method_func, func);
+        return 0;
+    }
+}
+
+/* Bind a method if it was deconstructed */
+static PyObject *
+reconstruct_method(PyObject *func, PyObject *self)
+{
+    if (self) {
+        return PyMethod_New(func, self);
+    }
+    else {
+        Py_INCREF(func);
+        return func;
+    }
+}
+
+static PyObject *
+call_method(PyObject *func, PyObject *self, PyObject *obj)
+{
+    if (self) {
+        return PyObject_CallFunctionObjArgs(func, self, obj, NULL);
+    }
+    else {
+        return PyObject_CallFunctionObjArgs(func, obj, NULL);
+    }
+}
+
+/*************************************************************************/
+
 /* Internal data type used as the unpickling stack. */
 typedef struct {
     PyObject_VAR_HEAD
@@ -545,6 +608,8 @@ typedef struct PicklerObject {
                                    objects to support self-referential objects
                                    pickling. */
     PyObject *pers_func;        /* persistent_id() method, can be NULL */
+    PyObject *pers_func_self;   /* borrowed reference to self if pers_func
+                                   is an unbound method, NULL otherwise */
     PyObject *dispatch_table;   /* private dispatch_table, can be NULL */
 
     PyObject *write;            /* write() method of the output stream. */
@@ -583,6 +648,8 @@ typedef struct UnpicklerObject {
     Py_ssize_t memo_len;        /* Number of objects in the memo */
 
     PyObject *pers_func;        /* persistent_load() method, can be NULL. */
+    PyObject *pers_func_self;   /* borrowed reference to self if pers_func
+                                   is an unbound method, NULL otherwise */
 
     Py_buffer buffer;
     char *input_buffer;
@@ -3401,7 +3468,7 @@ save_type(PicklerObject *self, PyObject *obj)
 }
 
 static int
-save_pers(PicklerObject *self, PyObject *obj, PyObject *func)
+save_pers(PicklerObject *self, PyObject *obj)
 {
     PyObject *pid = NULL;
     int status = 0;
@@ -3409,8 +3476,7 @@ save_pers(PicklerObject *self, PyObject *obj, PyObject *func)
     const char persid_op = PERSID;
     const char binpersid_op = BINPERSID;
 
-    Py_INCREF(obj);
-    pid = _Pickle_FastCall(func, obj);
+    pid = call_method(self->pers_func, self->pers_func_self, obj);
     if (pid == NULL)
         return -1;
 
@@ -3788,7 +3854,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
              0   if it did nothing successfully;
              1   if a persistent id was saved.
          */
-        if ((status = save_pers(self, obj, self->pers_func)) != 0)
+        if ((status = save_pers(self, obj)) != 0)
             goto done;
     }
 
@@ -4203,13 +4269,10 @@ _pickle_Pickler___init___impl(PicklerObject *self, PyObject *file,
     self->fast_nesting = 0;
     self->fast_memo = NULL;
 
-    self->pers_func = _PyObject_GetAttrId((PyObject *)self,
-                                          &PyId_persistent_id);
-    if (self->pers_func == NULL) {
-        if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
-            return -1;
-        }
-        PyErr_Clear();
+    if (init_method_ref((PyObject *)self, &PyId_persistent_id,
+                        &self->pers_func, &self->pers_func_self) < 0)
+    {
+        return -1;
     }
 
     self->dispatch_table = _PyObject_GetAttrId((PyObject *)self,
@@ -4476,11 +4539,11 @@ Pickler_set_memo(PicklerObject *self, PyObject *obj)
 static PyObject *
 Pickler_get_persid(PicklerObject *self)
 {
-    if (self->pers_func == NULL)
+    if (self->pers_func == NULL) {
         PyErr_SetString(PyExc_AttributeError, "persistent_id");
-    else
-        Py_INCREF(self->pers_func);
-    return self->pers_func;
+        return NULL;
+    }
+    return reconstruct_method(self->pers_func, self->pers_func_self);
 }
 
 static int
@@ -4497,6 +4560,7 @@ Pickler_set_persid(PicklerObject *self, PyObject *value)
         return -1;
     }
 
+    self->pers_func_self = NULL;
     Py_INCREF(value);
     Py_XSETREF(self->pers_func, value);
 
@@ -5446,7 +5510,7 @@ load_stack_global(UnpicklerObject *self)
 static int
 load_persid(UnpicklerObject *self)
 {
-    PyObject *pid;
+    PyObject *pid, *obj;
     Py_ssize_t len;
     char *s;
 
@@ -5466,13 +5530,12 @@ load_persid(UnpicklerObject *self)
             return -1;
         }
 
-        /* This does not leak since _Pickle_FastCall() steals the reference
-           to pid first. */
-        pid = _Pickle_FastCall(self->pers_func, pid);
-        if (pid == NULL)
+        obj = call_method(self->pers_func, self->pers_func_self, pid);
+        Py_DECREF(pid);
+        if (obj == NULL)
             return -1;
 
-        PDATA_PUSH(self->stack, pid, -1);
+        PDATA_PUSH(self->stack, obj, -1);
         return 0;
     }
     else {
@@ -5487,20 +5550,19 @@ load_persid(UnpicklerObject *self)
 static int
 load_binpersid(UnpicklerObject *self)
 {
-    PyObject *pid;
+    PyObject *pid, *obj;
 
     if (self->pers_func) {
         PDATA_POP(self->stack, pid);
         if (pid == NULL)
             return -1;
 
-        /* This does not leak since _Pickle_FastCall() steals the
-           reference to pid first. */
-        pid = _Pickle_FastCall(self->pers_func, pid);
-        if (pid == NULL)
+        obj = call_method(self->pers_func, self->pers_func_self, pid);
+        Py_DECREF(pid);
+        if (obj == NULL)
             return -1;
 
-        PDATA_PUSH(self->stack, pid, -1);
+        PDATA_PUSH(self->stack, obj, -1);
         return 0;
     }
     else {
@@ -6637,13 +6699,10 @@ _pickle_Unpickler___init___impl(UnpicklerObject *self, PyObject *file,
 
     self->fix_imports = fix_imports;
 
-    self->pers_func = _PyObject_GetAttrId((PyObject *)self,
-                                          &PyId_persistent_load);
-    if (self->pers_func == NULL) {
-        if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
-            return -1;
-        }
-        PyErr_Clear();
+    if (init_method_ref((PyObject *)self, &PyId_persistent_load,
+                        &self->pers_func, &self->pers_func_self) < 0)
+    {
+        return -1;
     }
 
     self->stack = (Pdata *)Pdata_New();
@@ -6930,11 +6989,11 @@ Unpickler_set_memo(UnpicklerObject *self, PyObject *obj)
 static PyObject *
 Unpickler_get_persload(UnpicklerObject *self)
 {
-    if (self->pers_func == NULL)
+    if (self->pers_func == NULL) {
         PyErr_SetString(PyExc_AttributeError, "persistent_load");
-    else
-        Py_INCREF(self->pers_func);
-    return self->pers_func;
+        return NULL;
+    }
+    return reconstruct_method(self->pers_func, self->pers_func_self);
 }
 
 static int
@@ -6952,6 +7011,7 @@ Unpickler_set_persload(UnpicklerObject *self, PyObject *value)
         return -1;
     }
 
+    self->pers_func_self = NULL;
     Py_INCREF(value);
     Py_XSETREF(self->pers_func, value);
 



More information about the Python-checkins mailing list