diff options
Diffstat (limited to 'Lib/pickle.py')
-rw-r--r-- | Lib/pickle.py | 71 |
1 files changed, 63 insertions, 8 deletions
diff --git a/Lib/pickle.py b/Lib/pickle.py index 863702dc9f..d98bcd5076 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -28,6 +28,7 @@ __version__ = "$Revision$" # Code version from types import * from copy_reg import dispatch_table, _reconstructor +from copy_reg import extension_registry, inverted_registry, extension_cache import marshal import sys import struct @@ -395,16 +396,28 @@ class Pickler: self.memoize(obj) if isinstance(obj, list): - write(MARK) - for x in obj: - save(x) - write(APPENDS) + n = len(obj) + if n > 1: + write(MARK) + for x in obj: + save(x) + write(APPENDS) + elif n == 1: + save(obj[0]) + write(APPEND) elif isinstance(obj, dict): - write(MARK) - for k, v in obj.iteritems(): + n = len(obj) + if n > 1: + write(MARK) + for k, v in obj.iteritems(): + save(k) + save(v) + write(SETITEMS) + elif n == 1: + k, v = obj.items()[0] save(k) save(v) - write(SETITEMS) + write(SETITEM) getstate = getattr(obj, "__getstate__", None) if getstate: @@ -420,6 +433,8 @@ class Pickler: getstate = None if not getstate: state = getattr(obj, "__dict__", None) + if not state: + state = None # If there are slots, the state becomes a tuple of two # items: the first item the regular __dict__ or None, and # the second a dict mapping slot names to slot values @@ -703,7 +718,7 @@ class Pickler: dispatch[InstanceType] = save_inst - def save_global(self, obj, name = None): + def save_global(self, obj, name=None, pack=struct.pack): write = self.write memo = self.memo @@ -729,6 +744,18 @@ class Pickler: "Can't pickle %r: it's not the same object as %s.%s" % (obj, module, name)) + if self.proto >= 2: + code = extension_registry.get((module, name)) + if code: + assert code > 0 + if code <= 0xff: + write(EXT1 + chr(code)) + elif code <= 0xffff: + write(EXT2 + chr(code&0xff) + chr(code>>8)) + else: + write(EXT4 + pack("<i", code)) + return + write(GLOBAL + module + '\n' + name + '\n') self.memoize(obj) @@ -1081,6 +1108,34 @@ class Unpickler: self.append(klass) dispatch[GLOBAL] = load_global + def load_ext1(self): + code = ord(self.read(1)) + self.get_extension(code) + dispatch[EXT1] = load_ext1 + + def load_ext2(self): + code = mloads('i' + self.read(2) + '\000\000') + self.get_extension(code) + dispatch[EXT2] = load_ext2 + + def load_ext4(self): + code = mloads('i' + self.read(4)) + self.get_extension(code) + dispatch[EXT4] = load_ext4 + + def get_extension(self, code): + nil = [] + obj = extension_cache.get(code, nil) + if obj is not nil: + self.append(obj) + return + key = inverted_registry.get(code) + if not key: + raise ValueError("unregistered extension code %d" % code) + obj = self.find_class(*key) + extension_cache[code] = obj + self.append(obj) + def find_class(self, module, name): # Subclasses may override this __import__(module) |