gh-99631: Add custom loads and dumps support for the shelve module (#118065)

Co-authored-by: Pieter Eendebak <pieter.eendebak@gmail.com>
Co-authored-by: Petr Viktorin <encukou@gmail.com>
Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com>
This commit is contained in:
Furkan Onder 2025-07-12 15:27:32 +03:00 committed by GitHub
parent c564847e98
commit dda70fa771
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 336 additions and 36 deletions

View file

@ -17,7 +17,8 @@ This includes most class instances, recursive data types, and objects containing
lots of shared sub-objects. The keys are ordinary strings. lots of shared sub-objects. The keys are ordinary strings.
.. function:: open(filename, flag='c', protocol=None, writeback=False) .. function:: open(filename, flag='c', protocol=None, writeback=False, *, \
serializer=None, deserializer=None)
Open a persistent dictionary. The filename specified is the base filename for Open a persistent dictionary. The filename specified is the base filename for
the underlying database. As a side-effect, an extension may be added to the the underlying database. As a side-effect, an extension may be added to the
@ -41,6 +42,21 @@ lots of shared sub-objects. The keys are ordinary strings.
determine which accessed entries are mutable, nor which ones were actually determine which accessed entries are mutable, nor which ones were actually
mutated). mutated).
By default, :mod:`shelve` uses :func:`pickle.dumps` and :func:`pickle.loads`
for serializing and deserializing. This can be changed by supplying
*serializer* and *deserializer*, respectively.
The *serializer* argument must be a callable which takes an object ``obj``
and the *protocol* as inputs and returns the representation ``obj`` as a
:term:`bytes-like object`; the *protocol* value may be ignored by the
serializer.
The *deserializer* argument must be callable which takes a serialized object
given as a :class:`bytes` object and returns the corresponding object.
A :exc:`ShelveError` is raised if *serializer* is given but *deserializer*
is not, or vice-versa.
.. versionchanged:: 3.10 .. versionchanged:: 3.10
:const:`pickle.DEFAULT_PROTOCOL` is now used as the default pickle :const:`pickle.DEFAULT_PROTOCOL` is now used as the default pickle
protocol. protocol.
@ -48,6 +64,10 @@ lots of shared sub-objects. The keys are ordinary strings.
.. versionchanged:: 3.11 .. versionchanged:: 3.11
Accepts :term:`path-like object` for filename. Accepts :term:`path-like object` for filename.
.. versionchanged:: next
Accepts custom *serializer* and *deserializer* functions in place of
:func:`pickle.dumps` and :func:`pickle.loads`.
.. note:: .. note::
Do not rely on the shelf being closed automatically; always call Do not rely on the shelf being closed automatically; always call
@ -129,7 +149,8 @@ Restrictions
explicitly. explicitly.
.. class:: Shelf(dict, protocol=None, writeback=False, keyencoding='utf-8') .. class:: Shelf(dict, protocol=None, writeback=False, \
keyencoding='utf-8', *, serializer=None, deserializer=None)
A subclass of :class:`collections.abc.MutableMapping` which stores pickled A subclass of :class:`collections.abc.MutableMapping` which stores pickled
values in the *dict* object. values in the *dict* object.
@ -147,6 +168,9 @@ Restrictions
The *keyencoding* parameter is the encoding used to encode keys before they The *keyencoding* parameter is the encoding used to encode keys before they
are used with the underlying dict. are used with the underlying dict.
The *serializer* and *deserializer* parameters have the same interpretation
as in :func:`~shelve.open`.
A :class:`Shelf` object can also be used as a context manager, in which A :class:`Shelf` object can also be used as a context manager, in which
case it will be automatically closed when the :keyword:`with` block ends. case it will be automatically closed when the :keyword:`with` block ends.
@ -161,8 +185,13 @@ Restrictions
:const:`pickle.DEFAULT_PROTOCOL` is now used as the default pickle :const:`pickle.DEFAULT_PROTOCOL` is now used as the default pickle
protocol. protocol.
.. versionchanged:: next
Added the *serializer* and *deserializer* parameters.
.. class:: BsdDbShelf(dict, protocol=None, writeback=False, keyencoding='utf-8')
.. class:: BsdDbShelf(dict, protocol=None, writeback=False, \
keyencoding='utf-8', *, \
serializer=None, deserializer=None)
A subclass of :class:`Shelf` which exposes :meth:`!first`, :meth:`!next`, A subclass of :class:`Shelf` which exposes :meth:`!first`, :meth:`!next`,
:meth:`!previous`, :meth:`!last` and :meth:`!set_location` methods. :meth:`!previous`, :meth:`!last` and :meth:`!set_location` methods.
@ -172,18 +201,27 @@ Restrictions
modules. The *dict* object passed to the constructor must support those modules. The *dict* object passed to the constructor must support those
methods. This is generally accomplished by calling one of methods. This is generally accomplished by calling one of
:func:`!bsddb.hashopen`, :func:`!bsddb.btopen` or :func:`!bsddb.rnopen`. The :func:`!bsddb.hashopen`, :func:`!bsddb.btopen` or :func:`!bsddb.rnopen`. The
optional *protocol*, *writeback*, and *keyencoding* parameters have the same optional *protocol*, *writeback*, *keyencoding*, *serializer* and *deserializer*
interpretation as for the :class:`Shelf` class. parameters have the same interpretation as in :func:`~shelve.open`.
.. versionchanged:: next
Added the *serializer* and *deserializer* parameters.
.. class:: DbfilenameShelf(filename, flag='c', protocol=None, writeback=False) .. class:: DbfilenameShelf(filename, flag='c', protocol=None, \
writeback=False, *, serializer=None, \
deserializer=None)
A subclass of :class:`Shelf` which accepts a *filename* instead of a dict-like A subclass of :class:`Shelf` which accepts a *filename* instead of a dict-like
object. The underlying file will be opened using :func:`dbm.open`. By object. The underlying file will be opened using :func:`dbm.open`. By
default, the file will be created and opened for both read and write. The default, the file will be created and opened for both read and write. The
optional *flag* parameter has the same interpretation as for the :func:`.open` optional *flag* parameter has the same interpretation as for the
function. The optional *protocol* and *writeback* parameters have the same :func:`.open` function. The optional *protocol*, *writeback*, *serializer*
interpretation as for the :class:`Shelf` class. and *deserializer* parameters have the same interpretation as in
:func:`~shelve.open`.
.. versionchanged:: next
Added the *serializer* and *deserializer* parameters.
.. _shelve-example: .. _shelve-example:
@ -225,6 +263,20 @@ object)::
d.close() # close it d.close() # close it
Exceptions
----------
.. exception:: ShelveError
Exception raised when one of the arguments *deserializer* and *serializer*
is missing in the :func:`~shelve.open`, :class:`Shelf`, :class:`BsdDbShelf`
and :class:`DbfilenameShelf`.
The *deserializer* and *serializer* arguments must be given together.
.. versionadded:: next
.. seealso:: .. seealso::
Module :mod:`dbm` Module :mod:`dbm`

View file

@ -56,12 +56,17 @@
the persistent dictionary on disk, if feasible). the persistent dictionary on disk, if feasible).
""" """
from pickle import DEFAULT_PROTOCOL, Pickler, Unpickler from pickle import DEFAULT_PROTOCOL, dumps, loads
from io import BytesIO from io import BytesIO
import collections.abc import collections.abc
__all__ = ["Shelf", "BsdDbShelf", "DbfilenameShelf", "open"] __all__ = ["ShelveError", "Shelf", "BsdDbShelf", "DbfilenameShelf", "open"]
class ShelveError(Exception):
pass
class _ClosedDict(collections.abc.MutableMapping): class _ClosedDict(collections.abc.MutableMapping):
'Marker for a closed dict. Access attempts raise a ValueError.' 'Marker for a closed dict. Access attempts raise a ValueError.'
@ -82,7 +87,7 @@ class Shelf(collections.abc.MutableMapping):
""" """
def __init__(self, dict, protocol=None, writeback=False, def __init__(self, dict, protocol=None, writeback=False,
keyencoding="utf-8"): keyencoding="utf-8", *, serializer=None, deserializer=None):
self.dict = dict self.dict = dict
if protocol is None: if protocol is None:
protocol = DEFAULT_PROTOCOL protocol = DEFAULT_PROTOCOL
@ -91,6 +96,16 @@ def __init__(self, dict, protocol=None, writeback=False,
self.cache = {} self.cache = {}
self.keyencoding = keyencoding self.keyencoding = keyencoding
if serializer is None and deserializer is None:
self.serializer = dumps
self.deserializer = loads
elif (serializer is None) ^ (deserializer is None):
raise ShelveError("serializer and deserializer must be "
"defined together")
else:
self.serializer = serializer
self.deserializer = deserializer
def __iter__(self): def __iter__(self):
for k in self.dict.keys(): for k in self.dict.keys():
yield k.decode(self.keyencoding) yield k.decode(self.keyencoding)
@ -110,8 +125,8 @@ def __getitem__(self, key):
try: try:
value = self.cache[key] value = self.cache[key]
except KeyError: except KeyError:
f = BytesIO(self.dict[key.encode(self.keyencoding)]) f = self.dict[key.encode(self.keyencoding)]
value = Unpickler(f).load() value = self.deserializer(f)
if self.writeback: if self.writeback:
self.cache[key] = value self.cache[key] = value
return value return value
@ -119,10 +134,8 @@ def __getitem__(self, key):
def __setitem__(self, key, value): def __setitem__(self, key, value):
if self.writeback: if self.writeback:
self.cache[key] = value self.cache[key] = value
f = BytesIO() serialized_value = self.serializer(value, self._protocol)
p = Pickler(f, self._protocol) self.dict[key.encode(self.keyencoding)] = serialized_value
p.dump(value)
self.dict[key.encode(self.keyencoding)] = f.getvalue()
def __delitem__(self, key): def __delitem__(self, key):
del self.dict[key.encode(self.keyencoding)] del self.dict[key.encode(self.keyencoding)]
@ -191,33 +204,29 @@ class BsdDbShelf(Shelf):
""" """
def __init__(self, dict, protocol=None, writeback=False, def __init__(self, dict, protocol=None, writeback=False,
keyencoding="utf-8"): keyencoding="utf-8", *, serializer=None, deserializer=None):
Shelf.__init__(self, dict, protocol, writeback, keyencoding) Shelf.__init__(self, dict, protocol, writeback, keyencoding,
serializer=serializer, deserializer=deserializer)
def set_location(self, key): def set_location(self, key):
(key, value) = self.dict.set_location(key) (key, value) = self.dict.set_location(key)
f = BytesIO(value) return (key.decode(self.keyencoding), self.deserializer(value))
return (key.decode(self.keyencoding), Unpickler(f).load())
def next(self): def next(self):
(key, value) = next(self.dict) (key, value) = next(self.dict)
f = BytesIO(value) return (key.decode(self.keyencoding), self.deserializer(value))
return (key.decode(self.keyencoding), Unpickler(f).load())
def previous(self): def previous(self):
(key, value) = self.dict.previous() (key, value) = self.dict.previous()
f = BytesIO(value) return (key.decode(self.keyencoding), self.deserializer(value))
return (key.decode(self.keyencoding), Unpickler(f).load())
def first(self): def first(self):
(key, value) = self.dict.first() (key, value) = self.dict.first()
f = BytesIO(value) return (key.decode(self.keyencoding), self.deserializer(value))
return (key.decode(self.keyencoding), Unpickler(f).load())
def last(self): def last(self):
(key, value) = self.dict.last() (key, value) = self.dict.last()
f = BytesIO(value) return (key.decode(self.keyencoding), self.deserializer(value))
return (key.decode(self.keyencoding), Unpickler(f).load())
class DbfilenameShelf(Shelf): class DbfilenameShelf(Shelf):
@ -227,9 +236,11 @@ class DbfilenameShelf(Shelf):
See the module's __doc__ string for an overview of the interface. See the module's __doc__ string for an overview of the interface.
""" """
def __init__(self, filename, flag='c', protocol=None, writeback=False): def __init__(self, filename, flag='c', protocol=None, writeback=False, *,
serializer=None, deserializer=None):
import dbm import dbm
Shelf.__init__(self, dbm.open(filename, flag), protocol, writeback) Shelf.__init__(self, dbm.open(filename, flag), protocol, writeback,
serializer=serializer, deserializer=deserializer)
def clear(self): def clear(self):
"""Remove all items from the shelf.""" """Remove all items from the shelf."""
@ -238,8 +249,8 @@ def clear(self):
self.cache.clear() self.cache.clear()
self.dict.clear() self.dict.clear()
def open(filename, flag='c', protocol=None, writeback=False, *,
def open(filename, flag='c', protocol=None, writeback=False): serializer=None, deserializer=None):
"""Open a persistent dictionary for reading and writing. """Open a persistent dictionary for reading and writing.
The filename parameter is the base filename for the underlying The filename parameter is the base filename for the underlying
@ -252,4 +263,5 @@ def open(filename, flag='c', protocol=None, writeback=False):
See the module's __doc__ string for an overview of the interface. See the module's __doc__ string for an overview of the interface.
""" """
return DbfilenameShelf(filename, flag, protocol, writeback) return DbfilenameShelf(filename, flag, protocol, writeback,
serializer=serializer, deserializer=deserializer)

View file

@ -1,10 +1,11 @@
import array
import unittest import unittest
import dbm import dbm
import shelve import shelve
import pickle import pickle
import os import os
from test.support import os_helper from test.support import import_helper, os_helper
from collections.abc import MutableMapping from collections.abc import MutableMapping
from test.test_dbm import dbm_iterator from test.test_dbm import dbm_iterator
@ -165,6 +166,239 @@ def test_default_protocol(self):
with shelve.Shelf({}) as s: with shelve.Shelf({}) as s:
self.assertEqual(s._protocol, pickle.DEFAULT_PROTOCOL) self.assertEqual(s._protocol, pickle.DEFAULT_PROTOCOL)
def test_custom_serializer_and_deserializer(self):
os.mkdir(self.dirname)
self.addCleanup(os_helper.rmtree, self.dirname)
def serializer(obj, protocol):
if isinstance(obj, (bytes, bytearray, str)):
if protocol == 5:
return obj
return type(obj).__name__
elif isinstance(obj, array.array):
return obj.tobytes()
raise TypeError(f"Unsupported type for serialization: {type(obj)}")
def deserializer(data):
if isinstance(data, (bytes, bytearray, str)):
return data.decode("utf-8")
elif isinstance(data, array.array):
return array.array("b", data)
raise TypeError(
f"Unsupported type for deserialization: {type(data)}"
)
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto), shelve.open(
self.fn,
protocol=proto,
serializer=serializer,
deserializer=deserializer
) as s:
bar = "bar"
bytes_data = b"Hello, world!"
bytearray_data = bytearray(b"\x00\x01\x02\x03\x04")
array_data = array.array("i", [1, 2, 3, 4, 5])
s["foo"] = bar
s["bytes_data"] = bytes_data
s["bytearray_data"] = bytearray_data
s["array_data"] = array_data
if proto == 5:
self.assertEqual(s["foo"], str(bar))
self.assertEqual(s["bytes_data"], "Hello, world!")
self.assertEqual(
s["bytearray_data"], bytearray_data.decode()
)
self.assertEqual(
s["array_data"], array_data.tobytes().decode()
)
else:
self.assertEqual(s["foo"], "str")
self.assertEqual(s["bytes_data"], "bytes")
self.assertEqual(s["bytearray_data"], "bytearray")
self.assertEqual(
s["array_data"], array_data.tobytes().decode()
)
def test_custom_incomplete_serializer_and_deserializer(self):
dbm_sqlite3 = import_helper.import_module("dbm.sqlite3")
os.mkdir(self.dirname)
self.addCleanup(os_helper.rmtree, self.dirname)
with self.assertRaises(dbm_sqlite3.error):
def serializer(obj, protocol=None):
pass
def deserializer(data):
return data.decode("utf-8")
with shelve.open(self.fn, serializer=serializer,
deserializer=deserializer) as s:
s["foo"] = "bar"
def serializer(obj, protocol=None):
return type(obj).__name__.encode("utf-8")
def deserializer(data):
pass
with shelve.open(self.fn, serializer=serializer,
deserializer=deserializer) as s:
s["foo"] = "bar"
self.assertNotEqual(s["foo"], "bar")
self.assertIsNone(s["foo"])
def test_custom_serializer_and_deserializer_bsd_db_shelf(self):
berkeleydb = import_helper.import_module("berkeleydb")
os.mkdir(self.dirname)
self.addCleanup(os_helper.rmtree, self.dirname)
def serializer(obj, protocol=None):
data = obj.__class__.__name__
if protocol == 5:
data = str(len(data))
return data.encode("utf-8")
def deserializer(data):
return data.decode("utf-8")
def type_name_len(obj):
return f"{(len(type(obj).__name__))}"
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto), shelve.BsdDbShelf(
berkeleydb.btopen(self.fn),
protocol=proto,
serializer=serializer,
deserializer=deserializer,
) as s:
bar = "bar"
bytes_obj = b"Hello, world!"
bytearray_obj = bytearray(b"\x00\x01\x02\x03\x04")
arr_obj = array.array("i", [1, 2, 3, 4, 5])
s["foo"] = bar
s["bytes_data"] = bytes_obj
s["bytearray_data"] = bytearray_obj
s["array_data"] = arr_obj
if proto == 5:
self.assertEqual(s["foo"], type_name_len(bar))
self.assertEqual(s["bytes_data"], type_name_len(bytes_obj))
self.assertEqual(s["bytearray_data"],
type_name_len(bytearray_obj))
self.assertEqual(s["array_data"], type_name_len(arr_obj))
k, v = s.set_location(b"foo")
self.assertEqual(k, "foo")
self.assertEqual(v, type_name_len(bar))
k, v = s.previous()
self.assertEqual(k, "bytes_data")
self.assertEqual(v, type_name_len(bytes_obj))
k, v = s.previous()
self.assertEqual(k, "bytearray_data")
self.assertEqual(v, type_name_len(bytearray_obj))
k, v = s.previous()
self.assertEqual(k, "array_data")
self.assertEqual(v, type_name_len(arr_obj))
k, v = s.next()
self.assertEqual(k, "bytearray_data")
self.assertEqual(v, type_name_len(bytearray_obj))
k, v = s.next()
self.assertEqual(k, "bytes_data")
self.assertEqual(v, type_name_len(bytes_obj))
k, v = s.first()
self.assertEqual(k, "array_data")
self.assertEqual(v, type_name_len(arr_obj))
else:
k, v = s.set_location(b"foo")
self.assertEqual(k, "foo")
self.assertEqual(v, "str")
k, v = s.previous()
self.assertEqual(k, "bytes_data")
self.assertEqual(v, "bytes")
k, v = s.previous()
self.assertEqual(k, "bytearray_data")
self.assertEqual(v, "bytearray")
k, v = s.previous()
self.assertEqual(k, "array_data")
self.assertEqual(v, "array")
k, v = s.next()
self.assertEqual(k, "bytearray_data")
self.assertEqual(v, "bytearray")
k, v = s.next()
self.assertEqual(k, "bytes_data")
self.assertEqual(v, "bytes")
k, v = s.first()
self.assertEqual(k, "array_data")
self.assertEqual(v, "array")
self.assertEqual(s["foo"], "str")
self.assertEqual(s["bytes_data"], "bytes")
self.assertEqual(s["bytearray_data"], "bytearray")
self.assertEqual(s["array_data"], "array")
def test_custom_incomplete_serializer_and_deserializer_bsd_db_shelf(self):
berkeleydb = import_helper.import_module("berkeleydb")
os.mkdir(self.dirname)
self.addCleanup(os_helper.rmtree, self.dirname)
def serializer(obj, protocol=None):
return type(obj).__name__.encode("utf-8")
def deserializer(data):
pass
with shelve.BsdDbShelf(berkeleydb.btopen(self.fn),
serializer=serializer,
deserializer=deserializer) as s:
s["foo"] = "bar"
self.assertIsNone(s["foo"])
self.assertNotEqual(s["foo"], "bar")
def serializer(obj, protocol=None):
pass
def deserializer(data):
return data.decode("utf-8")
with shelve.BsdDbShelf(berkeleydb.btopen(self.fn),
serializer=serializer,
deserializer=deserializer) as s:
s["foo"] = "bar"
self.assertEqual(s["foo"], "")
self.assertNotEqual(s["foo"], "bar")
def test_missing_custom_deserializer(self):
def serializer(obj, protocol=None):
pass
kwargs = dict(protocol=2, writeback=False, serializer=serializer)
self.assertRaises(shelve.ShelveError, shelve.Shelf, {}, **kwargs)
self.assertRaises(shelve.ShelveError, shelve.BsdDbShelf, {}, **kwargs)
def test_missing_custom_serializer(self):
def deserializer(data):
pass
kwargs = dict(protocol=2, writeback=False, deserializer=deserializer)
self.assertRaises(shelve.ShelveError, shelve.Shelf, {}, **kwargs)
self.assertRaises(shelve.ShelveError, shelve.BsdDbShelf, {}, **kwargs)
class TestShelveBase: class TestShelveBase:
type2test = shelve.Shelf type2test = shelve.Shelf

View file

@ -0,0 +1,2 @@
The :mod:`shelve` module now accepts custom serialization
and deserialization functions.