[Python-checkins] cpython: Issue #23611: Serializing more "lookupable" objects (such as unbound methods

serhiy.storchaka python-checkins at python.org
Tue Mar 31 13:08:33 CEST 2015


https://hg.python.org/cpython/rev/ca12465418bd
changeset:   95320:ca12465418bd
user:        Serhiy Storchaka <storchaka at gmail.com>
date:        Tue Mar 31 14:07:24 2015 +0300
summary:
  Issue #23611: Serializing more "lookupable" objects (such as unbound methods
or nested classes) is now supported with pickle protocols < 4.

files:
  Doc/whatsnew/3.5.rst     |    7 +
  Lib/pickle.py            |   33 ++++---
  Lib/test/pickletester.py |   19 +++-
  Misc/NEWS                |    3 +
  Modules/_pickle.c        |  120 ++++++++++++++++----------
  5 files changed, 115 insertions(+), 67 deletions(-)


diff --git a/Doc/whatsnew/3.5.rst b/Doc/whatsnew/3.5.rst
--- a/Doc/whatsnew/3.5.rst
+++ b/Doc/whatsnew/3.5.rst
@@ -370,6 +370,13 @@
 * :class:`os.stat_result` now has a :attr:`~os.stat_result.st_file_attributes`
   attribute on Windows.  (Contributed by Ben Hoyt in :issue:`21719`.)
 
+pickle
+------
+
+* Serializing more "lookupable" objects (such as unbound methods or nested
+  classes) is now supported with pickle protocols < 4.
+  (Contributed by Serhiy Storchaka in :issue:`23611`.)
+
 re
 --
 
diff --git a/Lib/pickle.py b/Lib/pickle.py
--- a/Lib/pickle.py
+++ b/Lib/pickle.py
@@ -258,24 +258,20 @@
 
 # Tools used for pickling.
 
-def _getattribute(obj, name, allow_qualname=False):
-    dotted_path = name.split(".")
-    if not allow_qualname and len(dotted_path) > 1:
-        raise AttributeError("Can't get qualified attribute {!r} on {!r}; " +
-                             "use protocols >= 4 to enable support"
-                             .format(name, obj))
-    for subpath in dotted_path:
+def _getattribute(obj, name):
+    for subpath in name.split('.'):
         if subpath == '<locals>':
             raise AttributeError("Can't get local attribute {!r} on {!r}"
                                  .format(name, obj))
         try:
+            parent = obj
             obj = getattr(obj, subpath)
         except AttributeError:
             raise AttributeError("Can't get attribute {!r} on {!r}"
                                  .format(name, obj))
-    return obj
+    return obj, parent
 
-def whichmodule(obj, name, allow_qualname=False):
+def whichmodule(obj, name):
     """Find the module an object belong to."""
     module_name = getattr(obj, '__module__', None)
     if module_name is not None:
@@ -286,7 +282,7 @@
         if module_name == '__main__' or module is None:
             continue
         try:
-            if _getattribute(module, name, allow_qualname) is obj:
+            if _getattribute(module, name)[0] is obj:
                 return module_name
         except AttributeError:
             pass
@@ -899,16 +895,16 @@
         write = self.write
         memo = self.memo
 
-        if name is None and self.proto >= 4:
+        if name is None:
             name = getattr(obj, '__qualname__', None)
         if name is None:
             name = obj.__name__
 
-        module_name = whichmodule(obj, name, allow_qualname=self.proto >= 4)
+        module_name = whichmodule(obj, name)
         try:
             __import__(module_name, level=0)
             module = sys.modules[module_name]
-            obj2 = _getattribute(module, name, allow_qualname=self.proto >= 4)
+            obj2, parent = _getattribute(module, name)
         except (ImportError, KeyError, AttributeError):
             raise PicklingError(
                 "Can't pickle %r: it's not found as %s.%s" %
@@ -930,11 +926,16 @@
                 else:
                     write(EXT4 + pack("<i", code))
                 return
+        lastname = name.rpartition('.')[2]
+        if parent is module:
+            name = lastname
         # Non-ASCII identifiers are supported only with protocols >= 3.
         if self.proto >= 4:
             self.save(module_name)
             self.save(name)
             write(STACK_GLOBAL)
+        elif parent is not module:
+            self.save_reduce(getattr, (parent, lastname))
         elif self.proto >= 3:
             write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
                   bytes(name, "utf-8") + b'\n')
@@ -1373,8 +1374,10 @@
             elif module in _compat_pickle.IMPORT_MAPPING:
                 module = _compat_pickle.IMPORT_MAPPING[module]
         __import__(module, level=0)
-        return _getattribute(sys.modules[module], name,
-                             allow_qualname=self.proto >= 4)
+        if self.proto >= 4:
+            return _getattribute(sys.modules[module], name)[0]
+        else:
+            return getattr(sys.modules[module], name)
 
     def load_reduce(self):
         stack = self.stack
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -1602,13 +1602,24 @@
                 class B:
                     class C:
                         pass
-
-        for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
             for obj in [Nested.A, Nested.A.B, Nested.A.B.C]:
                 with self.subTest(proto=proto, obj=obj):
                     unpickled = self.loads(self.dumps(obj, proto))
                     self.assertIs(obj, unpickled)
 
+    def test_recursive_nested_names(self):
+        global Recursive
+        class Recursive:
+            pass
+        Recursive.mod = sys.modules[Recursive.__module__]
+        Recursive.__qualname__ = 'Recursive.mod.Recursive'
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            with self.subTest(proto=proto):
+                unpickled = self.loads(self.dumps(Recursive, proto))
+                self.assertIs(unpickled, Recursive)
+        del Recursive.mod # break reference loop
+
     def test_py_methods(self):
         global PyMethodsTest
         class PyMethodsTest:
@@ -1647,7 +1658,7 @@
             (PyMethodsTest.biscuits, PyMethodsTest),
             (PyMethodsTest.Nested.pie, PyMethodsTest.Nested)
         )
-        for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
             for method in py_methods:
                 with self.subTest(proto=proto, method=method):
                     unpickled = self.loads(self.dumps(method, proto))
@@ -1687,7 +1698,7 @@
             (Subclass.Nested("sweet").count, ("e",)),
             (Subclass.Nested.count, (Subclass.Nested("sweet"), "e")),
         )
-        for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
             for method, args in c_methods:
                 with self.subTest(proto=proto, method=method):
                     unpickled = self.loads(self.dumps(method, proto))
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -13,6 +13,9 @@
 Library
 -------
 
+- Issue #23611: Serializing more "lookupable" objects (such as unbound methods
+  or nested classes) is now supported with pickle protocols < 4.
+
 - Issue #13583: sqlite3.Row now supports slice indexing.
 
 - Issue #18473: Fixed 2to3 and 3to2 compatible pickle mappings.  Fixed
diff --git a/Modules/_pickle.c b/Modules/_pickle.c
--- a/Modules/_pickle.c
+++ b/Modules/_pickle.c
@@ -152,6 +152,8 @@
 
     /* codecs.encode, used for saving bytes in older protocols */
     PyObject *codecs_encode;
+    /* builtins.getattr, used for saving nested names with protocol < 4 */
+    PyObject *getattr;
 } PickleState;
 
 /* Forward declaration of the _pickle module definition. */
@@ -188,16 +190,26 @@
     Py_CLEAR(st->name_mapping_3to2);
     Py_CLEAR(st->import_mapping_3to2);
     Py_CLEAR(st->codecs_encode);
+    Py_CLEAR(st->getattr);
 }
 
 /* Initialize the given pickle module state. */
 static int
 _Pickle_InitState(PickleState *st)
 {
+    PyObject *builtins;
     PyObject *copyreg = NULL;
     PyObject *compat_pickle = NULL;
     PyObject *codecs = NULL;
 
+    builtins = PyEval_GetBuiltins();
+    if (builtins == NULL)
+        goto error;
+    st->getattr = PyDict_GetItemString(builtins, "getattr");
+    if (st->getattr == NULL)
+        goto error;
+    Py_INCREF(st->getattr);
+
     copyreg = PyImport_ImportModule("copyreg");
     if (!copyreg)
         goto error;
@@ -1535,7 +1547,7 @@
 }
 
 static PyObject *
-get_dotted_path(PyObject *obj, PyObject *name, int allow_qualname) {
+get_dotted_path(PyObject *obj, PyObject *name) {
     _Py_static_string(PyId_dot, ".");
     _Py_static_string(PyId_locals, "<locals>");
     PyObject *dotted_path;
@@ -1546,20 +1558,6 @@
         return NULL;
     n = PyList_GET_SIZE(dotted_path);
     assert(n >= 1);
-    if (!allow_qualname && n > 1) {
-        if (obj == NULL)
-            PyErr_Format(PyExc_AttributeError,
-                         "Can't pickle qualified object %R; "
-                         "use protocols >= 4 to enable support",
-                         name);
-        else
-            PyErr_Format(PyExc_AttributeError,
-                         "Can't pickle qualified attribute %R on %R; "
-                         "use protocols >= 4 to enable support",
-                         name, obj);
-        Py_DECREF(dotted_path);
-        return NULL;
-    }
     for (i = 0; i < n; i++) {
         PyObject *subpath = PyList_GET_ITEM(dotted_path, i);
         PyObject *result = PyUnicode_RichCompare(
@@ -1582,22 +1580,28 @@
 }
 
 static PyObject *
-get_deep_attribute(PyObject *obj, PyObject *names)
+get_deep_attribute(PyObject *obj, PyObject *names, PyObject **pparent)
 {
     Py_ssize_t i, n;
+    PyObject *parent = NULL;
 
     assert(PyList_CheckExact(names));
     Py_INCREF(obj);
     n = PyList_GET_SIZE(names);
     for (i = 0; i < n; i++) {
         PyObject *name = PyList_GET_ITEM(names, i);
-        PyObject *tmp;
-        tmp = PyObject_GetAttr(obj, name);
-        Py_DECREF(obj);
-        if (tmp == NULL)
+        Py_XDECREF(parent);
+        parent = obj;
+        obj = PyObject_GetAttr(parent, name);
+        if (obj == NULL) {
+            Py_DECREF(parent);
             return NULL;
-        obj = tmp;
-    }
+        }
+    }
+    if (pparent != NULL)
+        *pparent = parent;
+    else
+        Py_XDECREF(parent);
     return obj;
 }
 
@@ -1617,18 +1621,22 @@
 {
     PyObject *dotted_path, *attr;
 
-    dotted_path = get_dotted_path(obj, name, allow_qualname);
-    if (dotted_path == NULL)
-        return NULL;
-    attr = get_deep_attribute(obj, dotted_path);
-    Py_DECREF(dotted_path);
+    if (allow_qualname) {
+        dotted_path = get_dotted_path(obj, name);
+        if (dotted_path == NULL)
+            return NULL;
+        attr = get_deep_attribute(obj, dotted_path, NULL);
+        Py_DECREF(dotted_path);
+    }
+    else
+        attr = PyObject_GetAttr(obj, name);
     if (attr == NULL)
         reformat_attribute_error(obj, name);
     return attr;
 }
 
 static PyObject *
-whichmodule(PyObject *global, PyObject *global_name, int allow_qualname)
+whichmodule(PyObject *global, PyObject *dotted_path)
 {
     PyObject *module_name;
     PyObject *modules_dict;
@@ -1637,7 +1645,6 @@
     _Py_IDENTIFIER(__module__);
     _Py_IDENTIFIER(modules);
     _Py_IDENTIFIER(__main__);
-    PyObject *dotted_path;
 
     module_name = _PyObject_GetAttrId(global, &PyId___module__);
 
@@ -1663,10 +1670,6 @@
         return NULL;
     }
 
-    dotted_path = get_dotted_path(NULL, global_name, allow_qualname);
-    if (dotted_path == NULL)
-        return NULL;
-
     i = 0;
     while (PyDict_Next(modules_dict, &i, &module_name, &module)) {
         PyObject *candidate;
@@ -1676,19 +1679,16 @@
         if (module == Py_None)
             continue;
 
-        candidate = get_deep_attribute(module, dotted_path);
+        candidate = get_deep_attribute(module, dotted_path, NULL);
         if (candidate == NULL) {
-            if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
-                Py_DECREF(dotted_path);
+            if (!PyErr_ExceptionMatches(PyExc_AttributeError))
                 return NULL;
-            }
             PyErr_Clear();
             continue;
         }
 
         if (candidate == global) {
             Py_INCREF(module_name);
-            Py_DECREF(dotted_path);
             Py_DECREF(candidate);
             return module_name;
         }
@@ -1698,7 +1698,6 @@
     /* If no module is found, use __main__. */
     module_name = _PyUnicode_FromId(&PyId___main__);
     Py_INCREF(module_name);
-    Py_DECREF(dotted_path);
     return module_name;
 }
 
@@ -3105,6 +3104,9 @@
     PyObject *global_name = NULL;
     PyObject *module_name = NULL;
     PyObject *module = NULL;
+    PyObject *parent = NULL;
+    PyObject *dotted_path = NULL;
+    PyObject *lastname = NULL;
     PyObject *cls;
     PickleState *st = _Pickle_GetGlobalState();
     int status = 0;
@@ -3118,13 +3120,11 @@
         global_name = name;
     }
     else {
-        if (self->proto >= 4) {
-            global_name = _PyObject_GetAttrId(obj, &PyId___qualname__);
-            if (global_name == NULL) {
-                if (!PyErr_ExceptionMatches(PyExc_AttributeError))
-                    goto error;
-                PyErr_Clear();
-            }
+        global_name = _PyObject_GetAttrId(obj, &PyId___qualname__);
+        if (global_name == NULL) {
+            if (!PyErr_ExceptionMatches(PyExc_AttributeError))
+                goto error;
+            PyErr_Clear();
         }
         if (global_name == NULL) {
             global_name = _PyObject_GetAttrId(obj, &PyId___name__);
@@ -3133,7 +3133,10 @@
         }
     }
 
-    module_name = whichmodule(obj, global_name, self->proto >= 4);
+    dotted_path = get_dotted_path(module, global_name);
+    if (dotted_path == NULL)
+        goto error;
+    module_name = whichmodule(obj, dotted_path);
     if (module_name == NULL)
         goto error;
 
@@ -3152,7 +3155,10 @@
                      obj, module_name);
         goto error;
     }
-    cls = getattribute(module, global_name, self->proto >= 4);
+    lastname = PyList_GET_ITEM(dotted_path, PyList_GET_SIZE(dotted_path)-1);
+    Py_INCREF(lastname);
+    cls = get_deep_attribute(module, dotted_path, &parent);
+    Py_CLEAR(dotted_path);
     if (cls == NULL) {
         PyErr_Format(st->PicklingError,
                      "Can't pickle %R: attribute lookup %S on %S failed",
@@ -3239,6 +3245,11 @@
     }
     else {
   gen_global:
+        if (parent == module) {
+            Py_INCREF(lastname);
+            Py_DECREF(global_name);
+            global_name = lastname;
+        }
         if (self->proto >= 4) {
             const char stack_global_op = STACK_GLOBAL;
 
@@ -3250,6 +3261,15 @@
             if (_Pickler_Write(self, &stack_global_op, 1) < 0)
                 goto error;
         }
+        else if (parent != module) {
+            PickleState *st = _Pickle_GetGlobalState();
+            PyObject *reduce_value = Py_BuildValue("(O(OO))",
+                                        st->getattr, parent, lastname);
+            status = save_reduce(self, reduce_value, NULL);
+            Py_DECREF(reduce_value);
+            if (status < 0)
+                goto error;
+        }
         else {
             /* Generate a normal global opcode if we are using a pickle
                protocol < 4, or if the object is not registered in the
@@ -3328,6 +3348,9 @@
     Py_XDECREF(module_name);
     Py_XDECREF(global_name);
     Py_XDECREF(module);
+    Py_XDECREF(parent);
+    Py_XDECREF(dotted_path);
+    Py_XDECREF(lastname);
 
     return status;
 }
@@ -7150,6 +7173,7 @@
     Py_VISIT(st->name_mapping_3to2);
     Py_VISIT(st->import_mapping_3to2);
     Py_VISIT(st->codecs_encode);
+    Py_VISIT(st->getattr);
     return 0;
 }
 

-- 
Repository URL: https://hg.python.org/cpython


More information about the Python-checkins mailing list