gh-99631: Add custom loads and dumps support for the shelve module (#118065)

Co-authored-by: Pieter Eendebak <pieter.eendebak@gmail.com>
Co-authored-by: Petr Viktorin <encukou@gmail.com>
Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com>
This commit is contained in:
Furkan Onder 2025-07-12 15:27:32 +03:00 committed by GitHub
parent c564847e98
commit dda70fa771
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 336 additions and 36 deletions

View file

@ -56,12 +56,17 @@
the persistent dictionary on disk, if feasible).
"""
from pickle import DEFAULT_PROTOCOL, Pickler, Unpickler
from pickle import DEFAULT_PROTOCOL, dumps, loads
from io import BytesIO
import collections.abc
__all__ = ["Shelf", "BsdDbShelf", "DbfilenameShelf", "open"]
__all__ = ["ShelveError", "Shelf", "BsdDbShelf", "DbfilenameShelf", "open"]
class ShelveError(Exception):
pass
class _ClosedDict(collections.abc.MutableMapping):
'Marker for a closed dict. Access attempts raise a ValueError.'
@ -82,7 +87,7 @@ class Shelf(collections.abc.MutableMapping):
"""
def __init__(self, dict, protocol=None, writeback=False,
keyencoding="utf-8"):
keyencoding="utf-8", *, serializer=None, deserializer=None):
self.dict = dict
if protocol is None:
protocol = DEFAULT_PROTOCOL
@ -91,6 +96,16 @@ def __init__(self, dict, protocol=None, writeback=False,
self.cache = {}
self.keyencoding = keyencoding
if serializer is None and deserializer is None:
self.serializer = dumps
self.deserializer = loads
elif (serializer is None) ^ (deserializer is None):
raise ShelveError("serializer and deserializer must be "
"defined together")
else:
self.serializer = serializer
self.deserializer = deserializer
def __iter__(self):
for k in self.dict.keys():
yield k.decode(self.keyencoding)
@ -110,8 +125,8 @@ def __getitem__(self, key):
try:
value = self.cache[key]
except KeyError:
f = BytesIO(self.dict[key.encode(self.keyencoding)])
value = Unpickler(f).load()
f = self.dict[key.encode(self.keyencoding)]
value = self.deserializer(f)
if self.writeback:
self.cache[key] = value
return value
@ -119,10 +134,8 @@ def __getitem__(self, key):
def __setitem__(self, key, value):
if self.writeback:
self.cache[key] = value
f = BytesIO()
p = Pickler(f, self._protocol)
p.dump(value)
self.dict[key.encode(self.keyencoding)] = f.getvalue()
serialized_value = self.serializer(value, self._protocol)
self.dict[key.encode(self.keyencoding)] = serialized_value
def __delitem__(self, key):
del self.dict[key.encode(self.keyencoding)]
@ -191,33 +204,29 @@ class BsdDbShelf(Shelf):
"""
def __init__(self, dict, protocol=None, writeback=False,
keyencoding="utf-8"):
Shelf.__init__(self, dict, protocol, writeback, keyencoding)
keyencoding="utf-8", *, serializer=None, deserializer=None):
Shelf.__init__(self, dict, protocol, writeback, keyencoding,
serializer=serializer, deserializer=deserializer)
def set_location(self, key):
(key, value) = self.dict.set_location(key)
f = BytesIO(value)
return (key.decode(self.keyencoding), Unpickler(f).load())
return (key.decode(self.keyencoding), self.deserializer(value))
def next(self):
(key, value) = next(self.dict)
f = BytesIO(value)
return (key.decode(self.keyencoding), Unpickler(f).load())
return (key.decode(self.keyencoding), self.deserializer(value))
def previous(self):
(key, value) = self.dict.previous()
f = BytesIO(value)
return (key.decode(self.keyencoding), Unpickler(f).load())
return (key.decode(self.keyencoding), self.deserializer(value))
def first(self):
(key, value) = self.dict.first()
f = BytesIO(value)
return (key.decode(self.keyencoding), Unpickler(f).load())
return (key.decode(self.keyencoding), self.deserializer(value))
def last(self):
(key, value) = self.dict.last()
f = BytesIO(value)
return (key.decode(self.keyencoding), Unpickler(f).load())
return (key.decode(self.keyencoding), self.deserializer(value))
class DbfilenameShelf(Shelf):
@ -227,9 +236,11 @@ class DbfilenameShelf(Shelf):
See the module's __doc__ string for an overview of the interface.
"""
def __init__(self, filename, flag='c', protocol=None, writeback=False):
def __init__(self, filename, flag='c', protocol=None, writeback=False, *,
serializer=None, deserializer=None):
import dbm
Shelf.__init__(self, dbm.open(filename, flag), protocol, writeback)
Shelf.__init__(self, dbm.open(filename, flag), protocol, writeback,
serializer=serializer, deserializer=deserializer)
def clear(self):
"""Remove all items from the shelf."""
@ -238,8 +249,8 @@ def clear(self):
self.cache.clear()
self.dict.clear()
def open(filename, flag='c', protocol=None, writeback=False):
def open(filename, flag='c', protocol=None, writeback=False, *,
serializer=None, deserializer=None):
"""Open a persistent dictionary for reading and writing.
The filename parameter is the base filename for the underlying
@ -252,4 +263,5 @@ def open(filename, flag='c', protocol=None, writeback=False):
See the module's __doc__ string for an overview of the interface.
"""
return DbfilenameShelf(filename, flag, protocol, writeback)
return DbfilenameShelf(filename, flag, protocol, writeback,
serializer=serializer, deserializer=deserializer)