[Python-checkins] bpo-35900: Enable custom reduction callback registration in _pickle (GH-12499)

Antoine Pitrou webhook-mailer at python.org
Wed May 8 17:08:32 EDT 2019


https://github.com/python/cpython/commit/289f1f80ee87a4baf4567a86b3425fb3bf73291d
commit: 289f1f80ee87a4baf4567a86b3425fb3bf73291d
branch: master
author: Pierre Glaser <pierreglaser at msn.com>
committer: Antoine Pitrou <antoine at python.org>
date: 2019-05-08T23:08:25+02:00
summary:

bpo-35900: Enable custom reduction callback registration in _pickle (GH-12499)

Enable custom reduction callback registration for functions and classes in
_pickle.c, using the new Pickler's attribute ``reducer_override``.

files:
A Misc/NEWS.d/next/Library/2019-03-22-22-40-00.bpo-35900.oiee0o.rst
M Doc/library/pickle.rst
M Lib/pickle.py
M Lib/test/pickletester.py
M Lib/test/test_pickle.py
M Modules/_pickle.c

diff --git a/Doc/library/pickle.rst b/Doc/library/pickle.rst
index 3d89536d7d11..55005f009431 100644
--- a/Doc/library/pickle.rst
+++ b/Doc/library/pickle.rst
@@ -356,6 +356,18 @@ The :mod:`pickle` module exports two classes, :class:`Pickler` and
 
       .. versionadded:: 3.3
 
+   .. method:: reducer_override(self, obj)
+
+      Special reducer that can be defined in :class:`Pickler` subclasses. This
+      method has priority over any reducer in the :attr:`dispatch_table`.  It
+      should conform to the same interface as a :meth:`__reduce__` method, and
+      can optionally return ``NotImplemented`` to fall back on
+      :attr:`dispatch_table`-registered reducers to pickle ``obj``.
+
+      For a detailed example, see :ref:`reducer_override`.
+
+      .. versionadded:: 3.8
+
    .. attribute:: fast
 
       Deprecated. Enable fast mode if set to a true value.  The fast mode
@@ -791,6 +803,65 @@ A sample usage might be something like this::
    >>> new_reader.readline()
    '3: Goodbye!'
 
+.. _reducer_override:
+
+Custom Reduction for Types, Functions, and Other Objects
+--------------------------------------------------------
+
+.. versionadded:: 3.8
+
+Sometimes, :attr:`~Pickler.dispatch_table` may not be flexible enough.
+In particular we may want to customize pickling based on another criterion
+than the object's type, or we may want to customize the pickling of
+functions and classes.
+
+For those cases, it is possible to subclass from the :class:`Pickler` class and
+implement a :meth:`~Pickler.reducer_override` method. This method can return an
+arbitrary reduction tuple (see :meth:`__reduce__`). It can alternatively return
+``NotImplemented`` to fall back to the traditional behavior.
+
+If both the :attr:`~Pickler.dispatch_table` and
+:meth:`~Pickler.reducer_override` are defined, then
+:meth:`~Pickler.reducer_override` method takes priority.
+
+.. Note::
+   For performance reasons, :meth:`~Pickler.reducer_override` may not be
+   called for the following objects: ``None``, ``True``, ``False``, and
+   exact instances of :class:`int`, :class:`float`, :class:`bytes`,
+   :class:`str`, :class:`dict`, :class:`set`, :class:`frozenset`, :class:`list`
+   and :class:`tuple`.
+
+Here is a simple example where we allow pickling and reconstructing
+a given class::
+
+   import io
+   import pickle
+
+   class MyClass:
+       my_attribute = 1
+
+   class MyPickler(pickle.Pickler):
+       def reducer_override(self, obj):
+           """Custom reducer for MyClass."""
+           if getattr(obj, "__name__", None) == "MyClass":
+               return type, (obj.__name__, obj.__bases__,
+                             {'my_attribute': obj.my_attribute})
+           else:
+               # For any other object, fall back to usual reduction
+               return NotImplemented
+
+   f = io.BytesIO()
+   p = MyPickler(f)
+   p.dump(MyClass)
+
+   del MyClass
+
+   unpickled_class = pickle.loads(f.getvalue())
+
+   assert isinstance(unpickled_class, type)
+   assert unpickled_class.__name__ == "MyClass"
+   assert unpickled_class.my_attribute == 1
+
 
 .. _pickle-restrict:
 
diff --git a/Lib/pickle.py b/Lib/pickle.py
index 47f0d280efc9..595beda4765a 100644
--- a/Lib/pickle.py
+++ b/Lib/pickle.py
@@ -497,34 +497,42 @@ def save(self, obj, save_persistent_id=True):
             self.write(self.get(x[0]))
             return
 
-        # Check the type dispatch table
-        t = type(obj)
-        f = self.dispatch.get(t)
-        if f is not None:
-            f(self, obj) # Call unbound method with explicit self
-            return
-
-        # Check private dispatch table if any, or else copyreg.dispatch_table
-        reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
+        rv = NotImplemented
+        reduce = getattr(self, "reducer_override", None)
         if reduce is not None:
             rv = reduce(obj)
-        else:
-            # Check for a class with a custom metaclass; treat as regular class
-            if issubclass(t, type):
-                self.save_global(obj)
+
+        if rv is NotImplemented:
+            # Check the type dispatch table
+            t = type(obj)
+            f = self.dispatch.get(t)
+            if f is not None:
+                f(self, obj)  # Call unbound method with explicit self
                 return
 
-            # Check for a __reduce_ex__ method, fall back to __reduce__
-            reduce = getattr(obj, "__reduce_ex__", None)
+            # Check private dispatch table if any, or else
+            # copyreg.dispatch_table
+            reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
             if reduce is not None:
-                rv = reduce(self.proto)
+                rv = reduce(obj)
             else:
-                reduce = getattr(obj, "__reduce__", None)
+                # Check for a class with a custom metaclass; treat as regular
+                # class
+                if issubclass(t, type):
+                    self.save_global(obj)
+                    return
+
+                # Check for a __reduce_ex__ method, fall back to __reduce__
+                reduce = getattr(obj, "__reduce_ex__", None)
                 if reduce is not None:
-                    rv = reduce()
+                    rv = reduce(self.proto)
                 else:
-                    raise PicklingError("Can't pickle %r object: %r" %
-                                        (t.__name__, obj))
+                    reduce = getattr(obj, "__reduce__", None)
+                    if reduce is not None:
+                        rv = reduce()
+                    else:
+                        raise PicklingError("Can't pickle %r object: %r" %
+                                            (t.__name__, obj))
 
         # Check for string returned by reduce(), meaning "save as global"
         if isinstance(rv, str):
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index 19e8823a7310..4f8c2942df93 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -4,6 +4,7 @@
 import io
 import functools
 import os
+import math
 import pickle
 import pickletools
 import shutil
@@ -3013,6 +3014,73 @@ def setstate_bbb(obj, state):
     obj.a = "custom state_setter"
 
 
+
+class AbstractCustomPicklerClass:
+    """Pickler implementing a reducing hook using reducer_override."""
+    def reducer_override(self, obj):
+        obj_name = getattr(obj, "__name__", None)
+
+        if obj_name == 'f':
+            # asking the pickler to save f as 5
+            return int, (5, )
+
+        if obj_name == 'MyClass':
+            return str, ('some str',)
+
+        elif obj_name == 'g':
+            # in this case, the callback returns an invalid result (not a 2-5
+            # tuple or a string), the pickler should raise a proper error.
+            return False
+
+        elif obj_name == 'h':
+            # Simulate a case when the reducer fails. The error should
+            # be propagated to the original ``dump`` call.
+            raise ValueError('The reducer just failed')
+
+        return NotImplemented
+
+class AbstractHookTests(unittest.TestCase):
+    def test_pickler_hook(self):
+        # test the ability of a custom, user-defined CPickler subclass to
+        # override the default reducing routines of any type using the method
+        # reducer_override
+
+        def f():
+            pass
+
+        def g():
+            pass
+
+        def h():
+            pass
+
+        class MyClass:
+            pass
+
+        for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
+            with self.subTest(proto=proto):
+                bio = io.BytesIO()
+                p = self.pickler_class(bio, proto)
+
+                p.dump([f, MyClass, math.log])
+                new_f, some_str, math_log = pickle.loads(bio.getvalue())
+
+                self.assertEqual(new_f, 5)
+                self.assertEqual(some_str, 'some str')
+                # math.log does not have its usual reducer overridden, so the
+                # custom reduction callback should silently direct the pickler
+                # to the default pickling by attribute, by returning
+                # NotImplemented
+                self.assertIs(math_log, math.log)
+
+                with self.assertRaises(pickle.PicklingError):
+                    p.dump(g)
+
+                with self.assertRaisesRegex(
+                        ValueError, 'The reducer just failed'):
+                    p.dump(h)
+
+
 class AbstractDispatchTableTests(unittest.TestCase):
 
     def test_default_dispatch_table(self):
diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py
index b4bce7e6aceb..435c248802d3 100644
--- a/Lib/test/test_pickle.py
+++ b/Lib/test/test_pickle.py
@@ -11,6 +11,7 @@
 import unittest
 from test import support
 
+from test.pickletester import AbstractHookTests
 from test.pickletester import AbstractUnpickleTests
 from test.pickletester import AbstractPickleTests
 from test.pickletester import AbstractPickleModuleTests
@@ -18,6 +19,7 @@
 from test.pickletester import AbstractIdentityPersistentPicklerTests
 from test.pickletester import AbstractPicklerUnpicklerObjectTests
 from test.pickletester import AbstractDispatchTableTests
+from test.pickletester import AbstractCustomPicklerClass
 from test.pickletester import BigmemPickleTests
 
 try:
@@ -253,12 +255,23 @@ class CChainDispatchTableTests(AbstractDispatchTableTests):
         def get_dispatch_table(self):
             return collections.ChainMap({}, pickle.dispatch_table)
 
+    class PyPicklerHookTests(AbstractHookTests):
+        class CustomPyPicklerClass(pickle._Pickler,
+                                   AbstractCustomPicklerClass):
+            pass
+        pickler_class = CustomPyPicklerClass
+
+    class CPicklerHookTests(AbstractHookTests):
+        class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass):
+            pass
+        pickler_class = CustomCPicklerClass
+
     @support.cpython_only
     class SizeofTests(unittest.TestCase):
         check_sizeof = support.check_sizeof
 
         def test_pickler(self):
-            basesize = support.calcobjsize('6P2n3i2n3iP')
+            basesize = support.calcobjsize('6P2n3i2n3i2P')
             p = _pickle.Pickler(io.BytesIO())
             self.assertEqual(object.__sizeof__(p), basesize)
             MT_size = struct.calcsize('3nP0n')
@@ -498,7 +511,7 @@ def test_main():
     tests = [PyPickleTests, PyUnpicklerTests, PyPicklerTests,
              PyPersPicklerTests, PyIdPersPicklerTests,
              PyDispatchTableTests, PyChainDispatchTableTests,
-             CompatPickleTests]
+             CompatPickleTests, PyPicklerHookTests]
     if has_c_implementation:
         tests.extend([CPickleTests, CUnpicklerTests, CPicklerTests,
                       CPersPicklerTests, CIdPersPicklerTests,
@@ -506,6 +519,7 @@ def test_main():
                       PyPicklerUnpicklerObjectTests,
                       CPicklerUnpicklerObjectTests,
                       CDispatchTableTests, CChainDispatchTableTests,
+                      CPicklerHookTests,
                       InMemoryPickleTests, SizeofTests])
     support.run_unittest(*tests)
     support.run_doctest(pickle)
diff --git a/Misc/NEWS.d/next/Library/2019-03-22-22-40-00.bpo-35900.oiee0o.rst b/Misc/NEWS.d/next/Library/2019-03-22-22-40-00.bpo-35900.oiee0o.rst
new file mode 100644
index 000000000000..641572649694
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2019-03-22-22-40-00.bpo-35900.oiee0o.rst
@@ -0,0 +1,2 @@
+Enable custom reduction callback registration for functions and classes in
+_pickle.c, using the new Pickler's attribute ``reducer_override``.
diff --git a/Modules/_pickle.c b/Modules/_pickle.c
index 897bbe1f24e4..87f3cf7b614a 100644
--- a/Modules/_pickle.c
+++ b/Modules/_pickle.c
@@ -616,6 +616,9 @@ typedef struct PicklerObject {
     PyObject *pers_func_self;   /* borrowed reference to self if pers_func
                                    is an unbound method, NULL otherwise */
     PyObject *dispatch_table;   /* private dispatch_table, can be NULL */
+    PyObject *reducer_override; /* hook for invoking user-defined callbacks
+                                   instead of save_global when pickling
+                                   functions and classes */
 
     PyObject *write;            /* write() method of the output stream. */
     PyObject *output_buffer;    /* Write into a local bytearray buffer before
@@ -1110,6 +1113,7 @@ _Pickler_New(void)
     self->fast_memo = NULL;
     self->max_output_len = WRITE_BUF_SIZE;
     self->output_len = 0;
+    self->reducer_override = NULL;
 
     self->memo = PyMemoTable_New();
     self->output_buffer = PyBytes_FromStringAndSize(NULL,
@@ -2220,7 +2224,7 @@ save_bytes(PicklerObject *self, PyObject *obj)
            Python 2 *and* the appropriate 'bytes' object when unpickled
            using Python 3. Again this is a hack and we don't need to do this
            with newer protocols. */
-        PyObject *reduce_value = NULL;
+        PyObject *reduce_value;
         int status;
 
         if (PyBytes_GET_SIZE(obj) == 0) {
@@ -4058,7 +4062,25 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
         status = save_tuple(self, obj);
         goto done;
     }
-    else if (type == &PyType_Type) {
+
+    /* Now, check reducer_override.  If it returns NotImplemented,
+     * fall back to save_type or save_global, and then perhaps to the
+     * regular reduction mechanism.
+     */
+    if (self->reducer_override != NULL) {
+        reduce_value = PyObject_CallFunctionObjArgs(self->reducer_override,
+                                                    obj, NULL);
+        if (reduce_value == NULL) {
+            goto error;
+        }
+        if (reduce_value != Py_NotImplemented) {
+            goto reduce;
+        }
+        Py_DECREF(reduce_value);
+        reduce_value = NULL;
+    }
+
+    if (type == &PyType_Type) {
         status = save_type(self, obj);
         goto done;
     }
@@ -4149,6 +4171,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
     if (reduce_value == NULL)
         goto error;
 
+  reduce:
     if (PyUnicode_Check(reduce_value)) {
         status = save_global(self, obj, reduce_value);
         goto done;
@@ -4180,6 +4203,20 @@ static int
 dump(PicklerObject *self, PyObject *obj)
 {
     const char stop_op = STOP;
+    PyObject *tmp;
+    _Py_IDENTIFIER(reducer_override);
+
+    if (_PyObject_LookupAttrId((PyObject *)self, &PyId_reducer_override,
+                               &tmp) < 0) {
+        return -1;
+    }
+    /* Cache the reducer_override method, if it exists. */
+    if (tmp != NULL) {
+        Py_XSETREF(self->reducer_override, tmp);
+    }
+    else {
+        Py_CLEAR(self->reducer_override);
+    }
 
     if (self->proto >= 2) {
         char header[2];
@@ -4304,6 +4341,7 @@ Pickler_dealloc(PicklerObject *self)
     Py_XDECREF(self->pers_func);
     Py_XDECREF(self->dispatch_table);
     Py_XDECREF(self->fast_memo);
+    Py_XDECREF(self->reducer_override);
 
     PyMemoTable_Del(self->memo);
 
@@ -4317,6 +4355,7 @@ Pickler_traverse(PicklerObject *self, visitproc visit, void *arg)
     Py_VISIT(self->pers_func);
     Py_VISIT(self->dispatch_table);
     Py_VISIT(self->fast_memo);
+    Py_VISIT(self->reducer_override);
     return 0;
 }
 
@@ -4328,6 +4367,7 @@ Pickler_clear(PicklerObject *self)
     Py_CLEAR(self->pers_func);
     Py_CLEAR(self->dispatch_table);
     Py_CLEAR(self->fast_memo);
+    Py_CLEAR(self->reducer_override);
 
     if (self->memo != NULL) {
         PyMemoTable *memo = self->memo;



More information about the Python-checkins mailing list