| 1 | """Helper to provide extensibility for pickle/cPickle.
 | 
| 2 | 
 | 
| 3 | This is only useful to add pickle support for extension types defined in
 | 
| 4 | C, not for instances of user-defined classes.
 | 
| 5 | """
 | 
| 6 | 
 | 
| 7 | from types import ClassType as _ClassType
 | 
| 8 | 
 | 
__all__ = [
    "pickle",
    "constructor",
    "add_extension",
    "remove_extension",
    "clear_extension_cache",
]

# Maps a type to the reduction function registered for it by pickle().
dispatch_table = {}
 | 
| 13 | 
 | 
def pickle(ob_type, pickle_function, constructor_ob=None):
    """Register *pickle_function* as the reduction function for *ob_type*.

    Arguments:
    ob_type -- the type whose instances pickle_function reduces.  Objects
               whose type is types.ClassType (classic classes) are rejected.
    pickle_function -- callable stored as dispatch_table[ob_type].
    constructor_ob -- optional and vestigial; when given it is only checked
                      for callability.

    Raises TypeError if ob_type is a classic class, or if pickle_function
    or constructor_ob is not callable.  All arguments are validated before
    dispatch_table is modified, so a failing call registers nothing.
    """
    if type(ob_type) is _ClassType:
        raise TypeError("copy_reg is not intended for use with classes")

    if not hasattr(pickle_function, '__call__'):
        raise TypeError("reduction functions must be callable")

    # The constructor_ob argument is a vestige of the former "safe for
    # unpickling" checks; there is no reason for callers to pass it anymore.
    # It is still validated, but *before* registering, so a bad value does
    # not leave a half-completed registration behind.
    if constructor_ob is not None:
        constructor(constructor_ob)

    dispatch_table[ob_type] = pickle_function
 | 
| 26 | 
 | 
def constructor(object):
    """Check that *object* is callable, raising TypeError if it is not.

    Nothing is recorded; this only validates the argument.
    """
    if hasattr(object, '__call__'):
        return
    raise TypeError("constructors must be callable")
 | 
| 30 | 
 | 
| 31 | # Example: provide pickling support for complex numbers.
 | 
| 32 | 
 | 
try:
    complex
except NameError:
    # The name complex is not available, so there is nothing to register.
    pass
else:

    def pickle_complex(c):
        """Reduce complex number *c* to ``(complex, (real, imag))``."""
        real_part = c.real
        imag_part = c.imag
        return complex, (real_part, imag_part)

    # complex itself doubles as the (vestigial) constructor argument.
    pickle(complex, pickle_complex, complex)
 | 
| 43 | 
 | 
| 44 | # Support for pickling new-style objects
 | 
| 45 | 
 | 
def _reconstructor(cls, base, state):
    """Rebuild an instance of *cls* from the arguments _reduce_ex emits.

    If *base* is ``object`` a bare instance is created.  Otherwise the
    instance is allocated with ``base.__new__(cls, state)`` and, when *base*
    has its own ``__init__``, initialised with ``base.__init__(obj, state)``.
    """
    if base is not object:
        instance = base.__new__(cls, state)
        # Skip the call when __init__ is just object's default.
        if base.__init__ != object.__init__:
            base.__init__(instance, state)
        return instance
    return object.__new__(cls)
 | 
| 54 | 
 | 
# Heap-type bit of a type's __flags__ (mirrors CPython's Py_TPFLAGS_HEAPTYPE,
# i.e. 1 << 9).  _reduce_ex uses it to find the first non-heap base class.
_HEAPTYPE = 0x200
 | 
| 56 | 
 | 
| 57 | # Python code for object.__reduce_ex__ for protocols 0 and 1
 | 
| 58 | 
 | 
def _reduce_ex(self, proto):
    """Python implementation of object.__reduce_ex__ for protocols 0 and 1.

    Returns ``(_reconstructor, (cls, base, state))``, or
    ``(_reconstructor, (cls, base, state), instance_state)`` when the
    instance state is non-empty.  Here:

    - *cls* is ``self.__class__``.
    - *base* is the first non-heap type in its MRO (``object`` in the
      common case).
    - *state* is ``base(self)``, or None when *base* is ``object``.
    - *instance_state* comes from ``__getstate__()`` if defined, else
      from ``__dict__``.

    Raises TypeError if the object's own class is not a heap type, or if the
    object defines non-empty __slots__ without __getstate__.
    """
    # This fallback only covers protocols 0 and 1.
    assert proto < 2
    cls = self.__class__
    for base in cls.__mro__:
        # Stop at the first base that is not a heap type (e.g. a C type).
        if hasattr(base, '__flags__') and not base.__flags__ & _HEAPTYPE:
            break
    else:
        base = object  # not really reachable
    if base is object:
        state = None
    else:
        if base is cls:
            raise TypeError("can't pickle %s objects" % base.__name__)
        # Capture the value of the non-heap base part, e.g. int(self).
        state = base(self)
    args = (cls, base, state)
    try:
        getstate = self.__getstate__
    except AttributeError:
        if getattr(self, "__slots__", None):
            raise TypeError("a class that defines __slots__ without "
                            "defining __getstate__ cannot be pickled")
        try:
            instance_state = self.__dict__
        except AttributeError:
            instance_state = None
    else:
        instance_state = getstate()
    # An empty or None instance state is omitted from the reduce tuple.
    if instance_state:
        return _reconstructor, args, instance_state
    return _reconstructor, args
 | 
| 89 | 
 | 
| 90 | # Helper for __reduce_ex__ protocol 2
 | 
| 91 | 
 | 
def __newobj__(cls, *args):
    """Create a *cls* instance via ``cls.__new__(cls, *args)``.

    Only __new__ is invoked; __init__ is not called.
    """
    allocate = cls.__new__
    return allocate(cls, *args)
 | 
| 94 | 
 | 
def _slotnames(cls):
    """Return a list of slot names for a given class.

    This needs to find slots defined by the class and its bases, so we
    can't simply return the __slots__ attribute.  We must walk down
    the Method Resolution Order and concatenate the __slots__ of each
    class found there.  (This assumes classes don't modify their
    __slots__ attribute to misrepresent their slots after the class is
    defined.)

    Private slot names (``__x``) are returned in mangled form using the
    interpreter's rule:
    - Leading underscores are stripped from the class name, so class
      ``_Foo`` gives ``_Foo__x``.
    - No mangling happens when the class name consists only of underscores.

    The ``__dict__`` and ``__weakref__`` pseudo-slots are skipped.  The
    result is cached on the class as ``__slotnames__`` when the class
    allows it.
    """

    # Get the value from a cache in the class if possible
    names = cls.__dict__.get("__slotnames__")
    if names is not None:
        return names

    # Not cached -- calculate the value
    names = []
    if hasattr(cls, "__slots__"):
        # Slots found -- gather slot names from all base classes
        for c in cls.__mro__:
            if "__slots__" not in c.__dict__:
                continue
            slots = c.__dict__['__slots__']
            # if class has a single slot, it can be given as a string
            if isinstance(slots, basestring):
                slots = (slots,)
            for name in slots:
                # special descriptors
                if name in ("__dict__", "__weakref__"):
                    continue
                # mangled names
                elif name.startswith('__') and not name.endswith('__'):
                    stripped = c.__name__.lstrip('_')
                    if stripped:
                        names.append('_%s%s' % (stripped, name))
                    else:
                        # An all-underscore class name disables mangling.
                        names.append(name)
                else:
                    names.append(name)

    # Cache the outcome in the class if at all possible
    try:
        cls.__slotnames__ = names
    except Exception:
        pass  # Best effort: e.g. builtin types refuse new attributes.

    return names
 | 
| 141 | 
 | 
| 142 | # A registry of extension codes.  This is an ad-hoc compression
 | 
| 143 | # mechanism.  Whenever a global reference to <module>, <name> is about
 | 
| 144 | # to be pickled, the (<module>, <name>) tuple is looked up here to see
 | 
# whether an extension code is registered for it.  Extension codes are
 | 
| 146 | # universal, so that the meaning of a pickle does not depend on
 | 
| 147 | # context.  (There are also some codes reserved for local use that
 | 
| 148 | # don't have this restriction.)  Codes are positive ints; 0 is
 | 
| 149 | # reserved.
 | 
| 150 | 
 | 
# Extension-code registry state.  These dicts are only ever mutated in
# place: cPickle grabs a reference to them when it's initialized and won't
# see a rebinding, so never assign new objects to these names.
_extension_registry = {}    # (module, name) key -> extension code
_inverted_registry = {}     # extension code -> (module, name) key
_extension_cache = {}       # extension code -> object
 | 
| 156 | 
 | 
def add_extension(module, name, code):
    """Register an extension code.

    Associates the key ``(module, name)`` with ``int(code)`` in both
    _extension_registry and _inverted_registry.  Registering an identical
    pair again is a no-op.

    Raises ValueError if the code is outside 1..0x7fffffff, if the key is
    already registered with a different code, or if the code is already
    used by a different key.
    """
    code = int(code)
    if not 1 <= code <= 0x7fffffff:
        raise ValueError("code out of range")
    key = (module, name)
    if (_extension_registry.get(key) == code and
        _inverted_registry.get(code) == key):
        return  # Redundant registrations are benign
    if key in _extension_registry:
        raise ValueError("key %s is already registered with code %s" %
                         (key, _extension_registry[key]))
    if code in _inverted_registry:
        raise ValueError("code %s is already in use for key %s" %
                         (code, _inverted_registry[code]))
    _extension_registry[key] = code
    _inverted_registry[code] = key
 | 
| 174 | 
 | 
def remove_extension(module, name, code):
    """Unregister an extension code.  For testing only.

    Removes the ``(module, name)`` <-> *code* association from both
    registries and drops any cached object stored under *code*.  Raises
    ValueError unless exactly that pair is currently registered.
    """
    key = (module, name)
    mismatch = (_extension_registry.get(key) != code or
                _inverted_registry.get(code) != key)
    if mismatch:
        raise ValueError("key %s is not registered with code %s" %
                         (key, code))
    del _extension_registry[key]
    del _inverted_registry[code]
    _extension_cache.pop(code, None)
 | 
| 186 | 
 | 
def clear_extension_cache():
    """Drop every object held in the extension-code cache.

    The registrations themselves are left untouched.  The cache dict is
    emptied in place rather than rebound, so any holder of a reference to
    it sees the change.
    """
    cache = _extension_cache
    cache.clear()
 | 
| 189 | 
 | 
| 190 | # Standard extension code assignments
 | 
| 191 | 
 | 
| 192 | # Reserved ranges
 | 
| 193 | 
 | 
| 194 | # First  Last Count  Purpose
 | 
| 195 | #     1   127   127  Reserved for Python standard library
 | 
| 196 | #   128   191    64  Reserved for Zope
 | 
| 197 | #   192   239    48  Reserved for 3rd parties
 | 
| 198 | #   240   255    16  Reserved for private use (will never be assigned)
 | 
| 199 | #   256   Inf   Inf  Reserved for future assignment
 | 
| 200 | 
 | 
| 201 | # Extension codes are assigned by the Python Software Foundation.
 |