Issue 2917: Merge the pickle and cPickle module.

This commit is contained in:
Alexandre Vassalotti 2008-06-11 22:43:06 +00:00
parent 1e637b7373
commit cc313061a5
8 changed files with 4685 additions and 126 deletions

View file

@ -174,7 +174,7 @@ def __init__(self, value):
# Pickling machinery
class Pickler:
class _Pickler:
def __init__(self, file, protocol=None):
"""This takes a binary file for writing a pickle data stream.
@ -182,21 +182,19 @@ def __init__(self, file, protocol=None):
All protocols now read and write bytes.
The optional protocol argument tells the pickler to use the
given protocol; supported protocols are 0, 1, 2. The default
protocol is 2; it's been supported for many years now.
Protocol 1 is more efficient than protocol 0; protocol 2 is
more efficient than protocol 1.
given protocol; supported protocols are 0, 1, 2, 3. The default
protocol is 3; a backward-incompatible protocol designed for
Python 3.0.
Specifying a negative protocol version selects the highest
protocol version supported. The higher the protocol used, the
more recent the version of Python needed to read the pickle
produced.
The file parameter must have a write() method that accepts a single
string argument. It can thus be an open file object, a StringIO
object, or any other custom object that meets this interface.
The file argument must have a write() method that accepts a single
bytes argument. It can thus be a file object opened for binary
writing, an io.BytesIO instance, or any other custom object that
meets this interface.
"""
if protocol is None:
protocol = DEFAULT_PROTOCOL
@ -204,7 +202,10 @@ def __init__(self, file, protocol=None):
protocol = HIGHEST_PROTOCOL
elif not 0 <= protocol <= HIGHEST_PROTOCOL:
raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL)
try:
self.write = file.write
except AttributeError:
raise TypeError("file must have a 'write' attribute")
self.memo = {}
self.proto = int(protocol)
self.bin = protocol >= 1
@ -270,10 +271,10 @@ def get(self, i, pack=struct.pack):
return GET + repr(i).encode("ascii") + b'\n'
def save(self, obj):
def save(self, obj, save_persistent_id=True):
# Check for persistent id (defined by a subclass)
pid = self.persistent_id(obj)
if pid:
if pid is not None and save_persistent_id:
self.save_pers(pid)
return
@ -341,7 +342,7 @@ def persistent_id(self, obj):
def save_pers(self, pid):
# Save a persistent id reference
if self.bin:
self.save(pid)
self.save(pid, save_persistent_id=False)
self.write(BINPERSID)
else:
self.write(PERSID + str(pid).encode("ascii") + b'\n')
@ -350,13 +351,13 @@ def save_reduce(self, func, args, state=None,
listitems=None, dictitems=None, obj=None):
# This API is called by some subclasses
# Assert that args is a tuple or None
# Assert that args is a tuple
if not isinstance(args, tuple):
raise PicklingError("args from reduce() should be a tuple")
raise PicklingError("args from save_reduce() should be a tuple")
# Assert that func is callable
if not hasattr(func, '__call__'):
raise PicklingError("func from reduce should be callable")
raise PicklingError("func from save_reduce() should be callable")
save = self.save
write = self.write
@ -438,31 +439,6 @@ def save_bool(self, obj):
self.write(obj and TRUE or FALSE)
dispatch[bool] = save_bool
def save_int(self, obj, pack=struct.pack):
if self.bin:
# If the int is small enough to fit in a signed 4-byte 2's-comp
# format, we can store it more efficiently than the general
# case.
# First one- and two-byte unsigned ints:
if obj >= 0:
if obj <= 0xff:
self.write(BININT1 + bytes([obj]))
return
if obj <= 0xffff:
self.write(BININT2 + bytes([obj&0xff, obj>>8]))
return
# Next check for 4-byte signed ints:
high_bits = obj >> 31 # note that Python shift sign-extends
if high_bits == 0 or high_bits == -1:
# All high bits are copies of bit 2**31, so the value
# fits in a 4-byte signed int.
self.write(BININT + pack("<i", obj))
return
# Text pickle, or int too big to fit in signed 4-byte format.
self.write(INT + repr(obj).encode("ascii") + b'\n')
# XXX save_int is merged into save_long
# dispatch[int] = save_int
def save_long(self, obj, pack=struct.pack):
if self.bin:
# If the int is small enough to fit in a signed 4-byte 2's-comp
@ -503,7 +479,7 @@ def save_float(self, obj, pack=struct.pack):
def save_bytes(self, obj, pack=struct.pack):
if self.proto < 3:
self.save_reduce(bytes, (list(obj),))
self.save_reduce(bytes, (list(obj),), obj=obj)
return
n = len(obj)
if n < 256:
@ -579,12 +555,6 @@ def save_tuple(self, obj):
dispatch[tuple] = save_tuple
# save_empty_tuple() isn't used by anything in Python 2.3. However, I
# found a Pickler subclass in Zope3 that calls it, so it's not harmless
# to remove it.
def save_empty_tuple(self, obj):
self.write(EMPTY_TUPLE)
def save_list(self, obj):
write = self.write
@ -696,7 +666,7 @@ def save_global(self, obj, name=None, pack=struct.pack):
module = whichmodule(obj, name)
try:
__import__(module)
__import__(module, level=0)
mod = sys.modules[module]
klass = getattr(mod, name)
except (ImportError, KeyError, AttributeError):
@ -720,9 +690,19 @@ def save_global(self, obj, name=None, pack=struct.pack):
else:
write(EXT4 + pack("<i", code))
return
# Non-ASCII identifiers are supported only with protocols >= 3.
if self.proto >= 3:
write(GLOBAL + bytes(module, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n')
else:
try:
write(GLOBAL + bytes(module, "ascii") + b'\n' +
bytes(name, "ascii") + b'\n')
except UnicodeEncodeError:
raise PicklingError(
"can't pickle global identifier '%s.%s' using "
"pickle protocol %i" % (module, name, self.proto))
self.memoize(obj)
dispatch[FunctionType] = save_global
@ -781,7 +761,7 @@ def whichmodule(func, funcname):
# Unpickling machinery
class Unpickler:
class _Unpickler:
def __init__(self, file, *, encoding="ASCII", errors="strict"):
"""This takes a binary file for reading a pickle data stream.
@ -841,6 +821,9 @@ def marker(self):
while stack[k] is not mark: k = k-1
return k
def persistent_load(self, pid):
    """Default handler for persistent IDs found in a pickle stream.

    The unpickler's opcode handlers call this hook with the persistent
    ID they read from the stream and push whatever it returns.  This
    base implementation does not resolve any persistent IDs, so it
    always raises; subclasses that want to support them override it.

    Raises:
        UnpicklingError: always.
    """
    # The exception name was previously misspelled "UnpickingError", which
    # made this path fail with a NameError instead of the intended error.
    raise UnpicklingError("unsupported persistent id encountered")
dispatch = {}
def load_proto(self):
@ -850,7 +833,7 @@ def load_proto(self):
dispatch[PROTO[0]] = load_proto
def load_persid(self):
pid = self.readline()[:-1]
pid = self.readline()[:-1].decode("ascii")
self.append(self.persistent_load(pid))
dispatch[PERSID[0]] = load_persid
@ -879,9 +862,9 @@ def load_int(self):
val = True
else:
try:
val = int(data)
val = int(data, 0)
except ValueError:
val = int(data)
val = int(data, 0)
self.append(val)
dispatch[INT[0]] = load_int
@ -933,7 +916,8 @@ def load_string(self):
break
else:
raise ValueError("insecure string pickle: %r" % orig)
self.append(codecs.escape_decode(rep)[0])
self.append(codecs.escape_decode(rep)[0]
.decode(self.encoding, self.errors))
dispatch[STRING[0]] = load_string
def load_binstring(self):
@ -975,7 +959,7 @@ def load_tuple(self):
dispatch[TUPLE[0]] = load_tuple
def load_empty_tuple(self):
self.stack.append(())
self.append(())
dispatch[EMPTY_TUPLE[0]] = load_empty_tuple
def load_tuple1(self):
@ -991,11 +975,11 @@ def load_tuple3(self):
dispatch[TUPLE3[0]] = load_tuple3
def load_empty_list(self):
self.stack.append([])
self.append([])
dispatch[EMPTY_LIST[0]] = load_empty_list
def load_empty_dictionary(self):
self.stack.append({})
self.append({})
dispatch[EMPTY_DICT[0]] = load_empty_dictionary
def load_list(self):
@ -1022,13 +1006,13 @@ def load_dict(self):
def _instantiate(self, klass, k):
args = tuple(self.stack[k+1:])
del self.stack[k:]
instantiated = 0
instantiated = False
if (not args and
isinstance(klass, type) and
not hasattr(klass, "__getinitargs__")):
value = _EmptyClass()
value.__class__ = klass
instantiated = 1
instantiated = True
if not instantiated:
try:
value = klass(*args)
@ -1038,8 +1022,8 @@ def _instantiate(self, klass, k):
self.append(value)
def load_inst(self):
module = self.readline()[:-1]
name = self.readline()[:-1]
module = self.readline()[:-1].decode("ascii")
name = self.readline()[:-1].decode("ascii")
klass = self.find_class(module, name)
self._instantiate(klass, self.marker())
dispatch[INST[0]] = load_inst
@ -1059,8 +1043,8 @@ def load_newobj(self):
dispatch[NEWOBJ[0]] = load_newobj
def load_global(self):
module = self.readline()[:-1]
name = self.readline()[:-1]
module = self.readline()[:-1].decode("utf-8")
name = self.readline()[:-1].decode("utf-8")
klass = self.find_class(module, name)
self.append(klass)
dispatch[GLOBAL[0]] = load_global
@ -1095,11 +1079,7 @@ def get_extension(self, code):
def find_class(self, module, name):
# Subclasses may override this
if isinstance(module, bytes_types):
module = module.decode("utf-8")
if isinstance(name, bytes_types):
name = name.decode("utf-8")
__import__(module)
__import__(module, level=0)
mod = sys.modules[module]
klass = getattr(mod, name)
return klass
@ -1131,31 +1111,33 @@ def load_dup(self):
dispatch[DUP[0]] = load_dup
def load_get(self):
self.append(self.memo[self.readline()[:-1].decode("ascii")])
i = int(self.readline()[:-1])
self.append(self.memo[i])
dispatch[GET[0]] = load_get
def load_binget(self):
i = ord(self.read(1))
self.append(self.memo[repr(i)])
i = self.read(1)[0]
self.append(self.memo[i])
dispatch[BINGET[0]] = load_binget
def load_long_binget(self):
i = mloads(b'i' + self.read(4))
self.append(self.memo[repr(i)])
self.append(self.memo[i])
dispatch[LONG_BINGET[0]] = load_long_binget
def load_put(self):
self.memo[self.readline()[:-1].decode("ascii")] = self.stack[-1]
i = int(self.readline()[:-1])
self.memo[i] = self.stack[-1]
dispatch[PUT[0]] = load_put
def load_binput(self):
i = ord(self.read(1))
self.memo[repr(i)] = self.stack[-1]
i = self.read(1)[0]
self.memo[i] = self.stack[-1]
dispatch[BINPUT[0]] = load_binput
def load_long_binput(self):
i = mloads(b'i' + self.read(4))
self.memo[repr(i)] = self.stack[-1]
self.memo[i] = self.stack[-1]
dispatch[LONG_BINPUT[0]] = load_long_binput
def load_append(self):
@ -1321,6 +1303,15 @@ def decode_long(data):
n -= 1 << (nbytes * 8)
return n
# Use the faster _pickle if possible
try:
from _pickle import *
except ImportError:
Pickler, Unpickler = _Pickler, _Unpickler
PickleError = _PickleError
PicklingError = _PicklingError
UnpicklingError = _UnpicklingError
# Shorthands
def dump(obj, file, protocol=None):
@ -1333,14 +1324,14 @@ def dumps(obj, protocol=None):
assert isinstance(res, bytes_types)
return res
def load(file):
return Unpickler(file).load()
def load(file, *, encoding="ASCII", errors="strict"):
return Unpickler(file, encoding=encoding, errors=errors).load()
def loads(s):
def loads(s, *, encoding="ASCII", errors="strict"):
if isinstance(s, str):
raise TypeError("Can't load pickle from unicode string")
file = io.BytesIO(s)
return Unpickler(file).load()
return Unpickler(file, encoding=encoding, errors=errors).load()
# Doctest

View file

@ -2079,11 +2079,12 @@ def __init__(self, value):
70: t TUPLE (MARK at 49)
71: p PUT 5
74: R REDUCE
75: V UNICODE 'def'
80: p PUT 6
83: s SETITEM
84: a APPEND
85: . STOP
75: p PUT 6
78: V UNICODE 'def'
83: p PUT 7
86: s SETITEM
87: a APPEND
88: . STOP
highest protocol among opcodes = 0
Try again with a "binary" pickle.
@ -2115,11 +2116,12 @@ def __init__(self, value):
49: t TUPLE (MARK at 37)
50: q BINPUT 5
52: R REDUCE
53: X BINUNICODE 'def'
61: q BINPUT 6
63: s SETITEM
64: e APPENDS (MARK at 3)
65: . STOP
53: q BINPUT 6
55: X BINUNICODE 'def'
63: q BINPUT 7
65: s SETITEM
66: e APPENDS (MARK at 3)
67: . STOP
highest protocol among opcodes = 1
Exercise the INST/OBJ/BUILD family.

View file

@ -362,7 +362,7 @@ def create_data():
return x
class AbstractPickleTests(unittest.TestCase):
# Subclass must define self.dumps, self.loads, self.error.
# Subclass must define self.dumps, self.loads.
_testdata = create_data()
@ -463,8 +463,9 @@ def test_recursive_multi(self):
self.assertEqual(list(x[0].attr.keys()), [1])
self.assert_(x[0].attr[1] is x)
def test_garyp(self):
self.assertRaises(self.error, self.loads, b'garyp')
def test_get(self):
self.assertRaises(KeyError, self.loads, b'g0\np0')
self.assertEquals(self.loads(b'((Kdtp0\nh\x00l.))'), [(100,), (100,)])
def test_insecure_strings(self):
# XXX Some of these tests are temporarily disabled
@ -955,7 +956,7 @@ def test_dump_closed_file(self):
f = open(TESTFN, "wb")
try:
f.close()
self.assertRaises(ValueError, self.module.dump, 123, f)
self.assertRaises(ValueError, pickle.dump, 123, f)
finally:
os.remove(TESTFN)
@ -964,24 +965,24 @@ def test_load_closed_file(self):
f = open(TESTFN, "wb")
try:
f.close()
self.assertRaises(ValueError, self.module.dump, 123, f)
self.assertRaises(ValueError, pickle.dump, 123, f)
finally:
os.remove(TESTFN)
def test_highest_protocol(self):
# Of course this needs to be changed when HIGHEST_PROTOCOL changes.
self.assertEqual(self.module.HIGHEST_PROTOCOL, 3)
self.assertEqual(pickle.HIGHEST_PROTOCOL, 3)
def test_callapi(self):
from io import BytesIO
f = BytesIO()
# With and without keyword arguments
self.module.dump(123, f, -1)
self.module.dump(123, file=f, protocol=-1)
self.module.dumps(123, -1)
self.module.dumps(123, protocol=-1)
self.module.Pickler(f, -1)
self.module.Pickler(f, protocol=-1)
pickle.dump(123, f, -1)
pickle.dump(123, file=f, protocol=-1)
pickle.dumps(123, -1)
pickle.dumps(123, protocol=-1)
pickle.Pickler(f, -1)
pickle.Pickler(f, protocol=-1)
class AbstractPersistentPicklerTests(unittest.TestCase):

View file

@ -7,37 +7,42 @@
from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests
class PickleTests(AbstractPickleTests, AbstractPickleModuleTests):
try:
import _pickle
has_c_implementation = True
except ImportError:
has_c_implementation = False
module = pickle
error = KeyError
def dumps(self, arg, proto=None):
return pickle.dumps(arg, proto)
class PickleTests(AbstractPickleModuleTests):
pass
def loads(self, buf):
return pickle.loads(buf)
class PicklerTests(AbstractPickleTests):
class PyPicklerTests(AbstractPickleTests):
error = KeyError
pickler = pickle._Pickler
unpickler = pickle._Unpickler
def dumps(self, arg, proto=None):
f = io.BytesIO()
p = pickle.Pickler(f, proto)
p = self.pickler(f, proto)
p.dump(arg)
f.seek(0)
return bytes(f.read())
def loads(self, buf):
f = io.BytesIO(buf)
u = pickle.Unpickler(f)
u = self.unpickler(f)
return u.load()
class PersPicklerTests(AbstractPersistentPicklerTests):
class PyPersPicklerTests(AbstractPersistentPicklerTests):
pickler = pickle._Pickler
unpickler = pickle._Unpickler
def dumps(self, arg, proto=None):
class PersPickler(pickle.Pickler):
class PersPickler(self.pickler):
def persistent_id(subself, obj):
return self.persistent_id(obj)
f = io.BytesIO()
@ -47,19 +52,29 @@ def persistent_id(subself, obj):
return f.read()
def loads(self, buf):
class PersUnpickler(pickle.Unpickler):
class PersUnpickler(self.unpickler):
def persistent_load(subself, obj):
return self.persistent_load(obj)
f = io.BytesIO(buf)
u = PersUnpickler(f)
return u.load()
if has_c_implementation:
class CPicklerTests(PyPicklerTests):
pickler = _pickle.Pickler
unpickler = _pickle.Unpickler
class CPersPicklerTests(PyPersPicklerTests):
pickler = _pickle.Pickler
unpickler = _pickle.Unpickler
def test_main():
support.run_unittest(
PickleTests,
PicklerTests,
PersPicklerTests
)
tests = [PickleTests, PyPicklerTests, PyPersPicklerTests]
if has_c_implementation:
tests.extend([CPicklerTests, CPersPicklerTests])
support.run_unittest(*tests)
support.run_doctest(pickle)
if __name__ == "__main__":

View file

@ -12,8 +12,6 @@ def dumps(self, arg, proto=None):
def loads(self, buf):
return pickle.loads(buf)
module = pickle
error = KeyError
def test_main():
support.run_unittest(OptimizedPickleTests)

View file

@ -78,6 +78,10 @@ Extension Modules
Library
-------
- The ``pickle`` module now automatically uses an optimized C
implementation of Pickler and Unpickler when available. The
``cPickle`` module is no longer needed.
- Removed the ``htmllib`` and ``sgmllib`` modules.
- The deprecated ``SmartCookie`` and ``SimpleCookie`` classes have

4546
Modules/_pickle.c Normal file

File diff suppressed because it is too large Load diff

View file

@ -422,6 +422,8 @@ def detect_modules(self):
exts.append( Extension("_functools", ["_functoolsmodule.c"]) )
# Memory-based IO accelerator modules
exts.append( Extension("_bytesio", ["_bytesio.c"]) )
# C-optimized pickle replacement
exts.append( Extension("_pickle", ["_pickle.c"]) )
# atexit
exts.append( Extension("atexit", ["atexitmodule.c"]) )
# _json speedups