[pypy-commit] pypy py3.3: Fix pickling stuff. Also, when (un)pickling functions, pass qualname correctly.

marky1991 pypy.commits at gmail.com
Tue Jan 12 22:24:01 EST 2016


Author: marky1991
Branch: py3.3
Changeset: r81705:cdacad9a627a
Date: 2015-12-31 21:06 -0500
http://bitbucket.org/pypy/pypy/changeset/cdacad9a627a/

Log:	Fix pickling stuff. Also, when (un)pickling functions, pass qualname
	correctly.

diff --git a/lib-python/3/pickle.py b/lib-python/3/pickle.py
--- a/lib-python/3/pickle.py
+++ b/lib-python/3/pickle.py
@@ -23,7 +23,7 @@
 
 """
 
-from types import FunctionType, BuiltinFunctionType
+from types import FunctionType, BuiltinFunctionType, ModuleType
 from copyreg import dispatch_table
 from copyreg import _extension_registry, _inverted_registry, _extension_cache
 import marshal
@@ -295,12 +295,10 @@
         #Unbound methods no longer exist, but pyframes rely on being
         #able to pickle unbound methods
         #This is a pypy-specific requirement, thus the change in the stdlib
-        is_unbound_method = t == FunctionType and "." in obj.__qualname__
-        if not is_unbound_method:
-            f = self.dispatch.get(t)
-            if f:
-                f(self, obj) # Call unbound method with explicit self
-                return
+        f = self.dispatch.get(t)
+        if f:
+            f(self, obj) # Call unbound method with explicit self
+            return
 
         # Check private dispatch table if any, or else copyreg.dispatch_table
         reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
@@ -627,6 +625,9 @@
             # else tmp is empty, and we're done
 
     def save_dict(self, obj):
+        modict_saver = self._pickle_maybe_moduledict(obj)
+        if modict_saver is not None:
+            return self.save_reduce(*modict_saver)
         write = self.write
 
         if self.bin:
@@ -677,6 +678,102 @@
                 write(SETITEM)
             # else tmp is empty, and we're done
 
+    def _pickle_maybe_moduledict(self, obj):
+        # Return (getattr, (module, '__dict__')) if obj is a module's __dict__, else None.
+        try:
+            name = obj['__name__']
+            if type(name) is not str:
+                return None
+            themodule = sys.modules[name]
+            if type(themodule) is not ModuleType:  # exact type only, no subclasses
+                return None
+            if themodule.__dict__ is not obj:  # identity: must be that very dict
+                return None
+        except (AttributeError, KeyError, TypeError):  # any failed lookup: plain dict
+            return None
+        return getattr, (themodule, '__dict__')  # unpacked into save_reduce by save_dict
+
+    def save_function(self, obj):  # pickle by reference, else fall back to reduce
+        try:
+            return self.save_global(obj)
+        except PicklingError as e:
+            error = e  # keep it: the "as" name is unbound once the handler exits
+        # Check copyreg.dispatch_table
+        reduce = dispatch_table.get(type(obj))
+        if reduce:
+            rv = reduce(obj)
+        else:
+            # Check for a __reduce_ex__ method, fall back to __reduce__
+            reduce = getattr(obj, "__reduce_ex__", None)
+            if reduce:
+                rv = reduce(self.proto)
+            else:
+                reduce = getattr(obj, "__reduce__", None)
+                if reduce:
+                    rv = reduce()
+                else:
+                    raise error  # no reducer either: re-raise save_global's error
+        return self.save_reduce(obj=obj, *rv)
+    dispatch[FunctionType] = save_function
+
+    def save_global(self, obj, name=None, pack=struct.pack):
+        write = self.write
+        memo = self.memo
+
+        # Pickle obj by reference (module name + qualified name).  The qualname
+        # lookup is CPython 3.5's protocol 4 logic, used unconditionally: PyPy relies on it.
+        if name is None:
+            name = getattr(obj, '__qualname__', None)
+        if name is None:
+            name = obj.__name__
+
+        module_name = whichmodule(obj, name, allow_qualname=True)
+        try:
+            __import__(module_name, level=0)
+            module = sys.modules[module_name]
+            obj2 = _getattribute(module, name, allow_qualname=True)
+        except (ImportError, KeyError, AttributeError):
+            raise PicklingError(
+                "Can't pickle %r: it's not found as %s.%s" %
+                (obj, module_name, name))
+        else:
+            if obj2 is not obj:  # the name must resolve back to this very object
+                raise PicklingError(
+                    "Can't pickle %r: it's not the same object as %s.%s" %
+                    (obj, module_name, name))
+
+        if self.proto >= 2:
+            code = _extension_registry.get((module_name, name))
+            if code:  # a registered extension code is written instead of GLOBAL
+                assert code > 0
+                if code <= 0xff:
+                    write(EXT1 + bytes([code]))
+                elif code <= 0xffff:
+                    write(EXT2 + bytes([code&0xff, code>>8]))
+                else:
+                    write(EXT4 + pack("<i", code))
+                return
+        # Non-ASCII identifiers are supported only with protocols >= 3.
+        if self.proto >= 3:
+            write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
+                  bytes(name, "utf-8") + b'\n')
+        else:
+            if self.fix_imports:
+                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
+                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
+                if (module_name, name) in r_name_mapping:
+                    module_name, name = r_name_mapping[(module_name, name)]
+                if module_name in r_import_mapping:
+                    module_name = r_import_mapping[module_name]
+            try:
+                write(GLOBAL + bytes(module_name, "ascii") + b'\n' +
+                      bytes(name, "ascii") + b'\n')
+            except UnicodeEncodeError:
+                raise PicklingError(
+                    "can't pickle global identifier '%s.%s' using "
+                    "pickle protocol %i" % (module_name, name, self.proto))
+
+        self.memoize(obj)
     def save_global(self, obj, name=None, pack=struct.pack):
         write = self.write
         memo = self.memo
@@ -742,7 +839,6 @@
             return self.save_reduce(type, (...,), obj=obj)
         return self.save_global(obj)
 
-    dispatch[FunctionType] = save_global
     dispatch[BuiltinFunctionType] = save_global
     dispatch[type] = save_type
 
@@ -764,13 +860,30 @@
         # aha, this is the first one :-)
         memo[id(memo)]=[x]
 
+def _getattribute(obj, name, allow_qualname=False):  # getattr along a dotted name
+    dotted_path = name.split(".")
+    if not allow_qualname and len(dotted_path) > 1:
+        raise AttributeError("Can't get qualified attribute {!r} on {!r}; "
+                             "use protocols >= 4 to enable support"
+                             .format(name, obj))  # literals join before .format
+    for subpath in dotted_path:
+        if subpath == '<locals>':  # qualname of something defined inside a function
+            raise AttributeError("Can't get local attribute {!r} on {!r}"
+                                 .format(name, obj))
+        try:
+            obj = getattr(obj, subpath)
+        except AttributeError:
+            raise AttributeError("Can't get attribute {!r} on {!r}"
+                                 .format(name, obj))
+    return obj
+
 
 # A cache for whichmodule(), mapping a function object to the name of
 # the module in which the function was found.
 
 classmap = {} # called classmap for backwards compatibility
 
-def whichmodule(func, funcname):
+def whichmodule(obj, name, allow_qualname=False):
     """Figure out the module in which a function occurs.
 
     Search sys.modules for the module.
@@ -779,22 +892,23 @@
     If the function cannot be found, return "__main__".
     """
     # Python functions should always get an __module__ from their globals.
-    mod = getattr(func, "__module__", None)
+    mod = getattr(obj, "__module__", None)
     if mod is not None:
         return mod
-    if func in classmap:
-        return classmap[func]
+    if obj in classmap:
+        return classmap[obj]
 
-    for name, module in list(sys.modules.items()):
-        if module is None:
+    for module_name, module in list(sys.modules.items()):
+        if module_name == '__main__' or module is None:
             continue # skip dummy package entries
-        if name != '__main__' and getattr(module, funcname, None) is func:
-            break
-    else:
-        name = '__main__'
-    classmap[func] = name
-    return name
-
+        try:
+            if _getattribute(module, name, allow_qualname) is obj:
+                classmap[obj] = module_name
+                return module_name
+        except AttributeError:
+            pass
+    classmap[obj] = '__main__'
+    return '__main__'
 
 # Unpickling machinery
 
diff --git a/pypy/interpreter/function.py b/pypy/interpreter/function.py
--- a/pypy/interpreter/function.py
+++ b/pypy/interpreter/function.py
@@ -306,6 +306,7 @@
         tup_base = []
         tup_state = [
             w(self.name),
+            w(self.qualname),
             w_doc,
             w(self.code),
             w_func_globals,
@@ -319,8 +320,8 @@
     def descr_function__setstate__(self, space, w_args):
         args_w = space.unpackiterable(w_args)
         try:
-            (w_name, w_doc, w_code, w_func_globals, w_closure, w_defs,
-             w_func_dict, w_module) = args_w
+            (w_name, w_qualname, w_doc, w_code, w_func_globals, w_closure,
+             w_defs, w_func_dict, w_module) = args_w
         except ValueError:
             # wrong args
             raise OperationError(space.w_ValueError,
@@ -328,6 +329,7 @@
 
         self.space = space
         self.name = space.str_w(w_name)
+        self.qualname = space.str_w(w_qualname)
         self.code = space.interp_w(Code, w_code)
         if not space.is_w(w_closure, space.w_None):
             from pypy.interpreter.nestedscope import Cell
diff --git a/pypy/interpreter/test/test_zzpickle_and_slow.py b/pypy/interpreter/test/test_zzpickle_and_slow.py
--- a/pypy/interpreter/test/test_zzpickle_and_slow.py
+++ b/pypy/interpreter/test/test_zzpickle_and_slow.py
@@ -394,8 +394,10 @@
         import pickle
         tdict = {'2':2, '3':3, '5':5}
         diter  = iter(tdict)
-        next(diter)
-        raises(TypeError, pickle.dumps, diter)
+        seen = next(diter)
+        pckl = pickle.dumps(diter)
+        result = pickle.loads(pckl)
+        assert set(result) == (set('235') - set(seen))
 
     def test_pickle_reversed(self):
         import pickle


More information about the pypy-commit mailing list