gh-102103: add module argument to dataclasses.make_dataclass (#102104)

This commit is contained in:
Nikita Sobolev 2023-03-11 03:26:46 +03:00 committed by GitHub
parent ee6f8413a9
commit b48be8fa18
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 60 additions and 2 deletions

View file

@ -389,7 +389,7 @@ Module contents
:func:`astuple` raises :exc:`TypeError` if ``obj`` is not a dataclass
instance.
.. function:: make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False)
.. function:: make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False, module=None)
Creates a new dataclass with name ``cls_name``, fields as defined
in ``fields``, base classes as given in ``bases``, and initialized
@ -401,6 +401,10 @@ Module contents
``match_args``, ``kw_only``, ``slots``, and ``weakref_slot`` have
the same meaning as they do in :func:`dataclass`.
If ``module`` is defined, the ``__module__`` attribute
of the dataclass is set to that value.
By default, it is set to the module name of the caller.
This function is not strictly required, because any Python
mechanism for creating a new class with ``__annotations__`` can
then apply the :func:`dataclass` function to convert that class to

View file

@ -1391,7 +1391,7 @@ def _astuple_inner(obj, tuple_factory):
def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
repr=True, eq=True, order=False, unsafe_hash=False,
frozen=False, match_args=True, kw_only=False, slots=False,
weakref_slot=False):
weakref_slot=False, module=None):
"""Return a new dynamically created dataclass.
The dataclass name will be 'cls_name'. 'fields' is an iterable
@ -1455,6 +1455,19 @@ def exec_body_callback(ns):
# of generic dataclasses.
cls = types.new_class(cls_name, bases, {}, exec_body_callback)
# For pickling to work, the __module__ variable needs to be set to the frame
# where the dataclass is created.
if module is None:
try:
module = sys._getframemodulename(1) or '__main__'
except AttributeError:
try:
module = sys._getframe(1).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError):
pass
if module is not None:
cls.__module__ = module
# Apply the normal decorator.
return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
unsafe_hash=unsafe_hash, frozen=frozen,

View file

@ -3606,6 +3606,15 @@ def test_text_annotations(self):
'return': type(None)})
ByMakeDataClass = make_dataclass('ByMakeDataClass', [('x', int)])
ManualModuleMakeDataClass = make_dataclass('ManualModuleMakeDataClass',
[('x', int)],
module='test.test_dataclasses')
WrongNameMakeDataclass = make_dataclass('Wrong', [('x', int)])
WrongModuleMakeDataclass = make_dataclass('WrongModuleMakeDataclass',
[('x', int)],
module='custom')
class TestMakeDataclass(unittest.TestCase):
def test_simple(self):
C = make_dataclass('C',
@ -3715,6 +3724,36 @@ def test_no_types(self):
'y': int,
'z': 'typing.Any'})
def test_module_attr(self):
self.assertEqual(ByMakeDataClass.__module__, __name__)
self.assertEqual(ByMakeDataClass(1).__module__, __name__)
self.assertEqual(WrongModuleMakeDataclass.__module__, "custom")
Nested = make_dataclass('Nested', [])
self.assertEqual(Nested.__module__, __name__)
self.assertEqual(Nested().__module__, __name__)
def test_pickle_support(self):
for klass in [ByMakeDataClass, ManualModuleMakeDataClass]:
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
self.assertEqual(
pickle.loads(pickle.dumps(klass, proto)),
klass,
)
self.assertEqual(
pickle.loads(pickle.dumps(klass(1), proto)),
klass(1),
)
def test_cannot_be_pickled(self):
for klass in [WrongNameMakeDataclass, WrongModuleMakeDataclass]:
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
with self.assertRaises(pickle.PickleError):
pickle.dumps(klass, proto)
with self.assertRaises(pickle.PickleError):
pickle.dumps(klass(1), proto)
def test_invalid_type_specification(self):
for bad_field in [(),
(1, 2, 3, 4),

View file

@ -0,0 +1,2 @@
Add ``module`` argument to :func:`dataclasses.make_dataclass` and make
classes produced by it pickleable.