Author: marky1991 Branch: py3.3 Changeset: r81705:cdacad9a627a Date: 2015-12-31 21:06 -0500 http://bitbucket.org/pypy/pypy/changeset/cdacad9a627a/
Log: Fix pickling stuff. Also, when (un)pickling functions, pass qualname correctly. diff --git a/lib-python/3/pickle.py b/lib-python/3/pickle.py --- a/lib-python/3/pickle.py +++ b/lib-python/3/pickle.py @@ -23,7 +23,7 @@ """ -from types import FunctionType, BuiltinFunctionType +from types import FunctionType, BuiltinFunctionType, ModuleType from copyreg import dispatch_table from copyreg import _extension_registry, _inverted_registry, _extension_cache import marshal @@ -295,12 +295,10 @@ #Unbound methods no longer exist, but pyframes rely on being #able to pickle unbound methods #This is a pypy-specific requirement, thus the change in the stdlib - is_unbound_method = t == FunctionType and "." in obj.__qualname__ - if not is_unbound_method: - f = self.dispatch.get(t) - if f: - f(self, obj) # Call unbound method with explicit self - return + f = self.dispatch.get(t) + if f: + f(self, obj) # Call unbound method with explicit self + return # Check private dispatch table if any, or else copyreg.dispatch_table reduce = getattr(self, 'dispatch_table', dispatch_table).get(t) @@ -627,6 +625,9 @@ # else tmp is empty, and we're done def save_dict(self, obj): + modict_saver = self._pickle_maybe_moduledict(obj) + if modict_saver is not None: + return self.save_reduce(*modict_saver) write = self.write if self.bin: @@ -677,6 +678,102 @@ write(SETITEM) # else tmp is empty, and we're done + def _pickle_maybe_moduledict(self, obj): + # save module dictionary as "getattr(module, '__dict__')" + try: + name = obj['__name__'] + if type(name) is not str: + return None + themodule = sys.modules[name] + if type(themodule) is not ModuleType: + return None + if themodule.__dict__ is not obj: + return None + except (AttributeError, KeyError, TypeError): + return None + return getattr, (themodule, '__dict__') + + def save_function(self, obj): + try: + return self.save_global(obj) + except PicklingError: + pass + # Check copy_reg.dispatch_table + reduce = dispatch_table.get(type(obj)) + if reduce: + rv = reduce(obj) + else: + # Check for a __reduce_ex__ method, fall back to __reduce__ + reduce = getattr(obj, "__reduce_ex__", None) + if reduce: + rv = reduce(self.proto) + else: + reduce = getattr(obj, "__reduce__", None) + if reduce: + rv = reduce() + else: + raise e + return self.save_reduce(obj=obj, *rv) + dispatch[FunctionType] = save_function + + def save_global(self, obj, name=None, pack=struct.pack): + write = self.write + memo = self.memo + + #This logic is stolen from the protocol 4 logic from 3.5 + #We need it unconditionally as pypy itself relies on it. + if name is None: + name = getattr(obj, '__qualname__', None) + if name is None: + name = obj.__name__ + + module_name = whichmodule(obj, name, allow_qualname=True) + try: + __import__(module_name, level=0) + module = sys.modules[module_name] + obj2 = _getattribute(module, name, allow_qualname=True) + except (ImportError, KeyError, AttributeError): + raise PicklingError( + "Can't pickle %r: it's not found as %s.%s" % + (obj, module_name, name)) + else: + if obj2 is not obj: + raise PicklingError( + "Can't pickle %r: it's not the same object as %s.%s" % + (obj, module_name, name)) + + if self.proto >= 2: + code = _extension_registry.get((module_name, name)) + if code: + assert code > 0 + if code <= 0xff: + write(EXT1 + bytes([code])) + elif code <= 0xffff: + write(EXT2 + bytes([code&0xff, code>>8])) + else: + write(EXT4 + pack("<i", code)) + return + # Non-ASCII identifiers are supported only with protocols >= 3. + if self.proto >= 3: + write(GLOBAL + bytes(module_name, "utf-8") + b'\n' + + bytes(name, "utf-8") + b'\n') + else: + if self.fix_imports: + r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING + r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING + if (module_name, name) in r_name_mapping: + module_name, name = r_name_mapping[(module_name, name)] + if module_name in r_import_mapping: + module_name = r_import_mapping[module_name] + try: + write(GLOBAL + bytes(module_name, "ascii") + b'\n' + + bytes(name, "ascii") + b'\n') + except UnicodeEncodeError: + raise PicklingError( + "can't pickle global identifier '%s.%s' using " + "pickle protocol %i" % (module, name, self.proto)) + + self.memoize(obj) def save_global(self, obj, name=None, pack=struct.pack): write = self.write memo = self.memo @@ -742,7 +839,6 @@ return self.save_reduce(type, (...,), obj=obj) return self.save_global(obj) - dispatch[FunctionType] = save_global dispatch[BuiltinFunctionType] = save_global dispatch[type] = save_type @@ -764,13 +860,30 @@ # aha, this is the first one :-) memo[id(memo)]=[x] +def _getattribute(obj, name, allow_qualname=False): + dotted_path = name.split(".") + if not allow_qualname and len(dotted_path) > 1: + raise AttributeError("Can't get qualified attribute {!r} on {!r}; " + + "use protocols >= 4 to enable support" + .format(name, obj)) + for subpath in dotted_path: + if subpath == '<locals>': + raise AttributeError("Can't get local attribute {!r} on {!r}" + .format(name, obj)) + try: + obj = getattr(obj, subpath) + except AttributeError: + raise AttributeError("Can't get attribute {!r} on {!r}" + .format(name, obj)) + return obj + # A cache for whichmodule(), mapping a function object to the name of # the module in which the function was found. classmap = {} # called classmap for backwards compatibility -def whichmodule(func, funcname): +def whichmodule(obj, name, allow_qualname=False): """Figure out the module in which a function occurs. Search sys.modules for the module. @@ -779,22 +892,23 @@ If the function cannot be found, return "__main__". """ # Python functions should always get an __module__ from their globals. - mod = getattr(func, "__module__", None) + mod = getattr(obj, "__module__", None) if mod is not None: return mod - if func in classmap: - return classmap[func] + if obj in classmap: + return classmap[obj] - for name, module in list(sys.modules.items()): - if module is None: + for module_name, module in list(sys.modules.items()): + if module_name == '__main__' or module is None: continue # skip dummy package entries - if name != '__main__' and getattr(module, funcname, None) is func: - break - else: - name = '__main__' - classmap[func] = name - return name - + try: + if _getattribute(module, name, allow_qualname) is obj: + classmap[obj] = module_name + return module_name + except AttributeError: + pass + classmap[obj] = '__main__' + return '__main__' # Unpickling machinery diff --git a/pypy/interpreter/function.py b/pypy/interpreter/function.py --- a/pypy/interpreter/function.py +++ b/pypy/interpreter/function.py @@ -306,6 +306,7 @@ tup_base = [] tup_state = [ w(self.name), + w(self.qualname), w_doc, w(self.code), w_func_globals, @@ -319,8 +320,8 @@ def descr_function__setstate__(self, space, w_args): args_w = space.unpackiterable(w_args) try: - (w_name, w_doc, w_code, w_func_globals, w_closure, w_defs, - w_func_dict, w_module) = args_w + (w_name, w_qualname, w_doc, w_code, w_func_globals, w_closure, + w_defs, w_func_dict, w_module) = args_w except ValueError: # wrong args raise OperationError(space.w_ValueError, @@ -328,6 +329,7 @@ self.space = space self.name = space.str_w(w_name) + self.qualname = space.str_w(w_qualname) self.code = space.interp_w(Code, w_code) if not space.is_w(w_closure, space.w_None): from pypy.interpreter.nestedscope import Cell diff --git a/pypy/interpreter/test/test_zzpickle_and_slow.py b/pypy/interpreter/test/test_zzpickle_and_slow.py --- a/pypy/interpreter/test/test_zzpickle_and_slow.py +++ b/pypy/interpreter/test/test_zzpickle_and_slow.py @@ -394,8 +394,10 @@ import pickle tdict = {'2':2, '3':3, '5':5} diter = iter(tdict) - next(diter) - raises(TypeError, pickle.dumps, diter) + seen = next(diter) + pckl = pickle.dumps(diter) + result = pickle.loads(pckl) + assert set(result) == (set('235') - set(seen)) def test_pickle_reversed(self): import pickle _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit