mirror of
https://github.com/python/cpython.git
synced 2025-12-08 06:10:17 +00:00
[3.14] gh-137530: generate an __annotate__ function for dataclasses __init__ (GH-137711) (#141352)
(cherry picked from commit 12837c6363)
Co-authored-by: David Ellis <ducksual@gmail.com>
This commit is contained in:
parent
9221030909
commit
727cdcba8e
3 changed files with 219 additions and 15 deletions
|
|
@ -441,9 +441,11 @@ def __init__(self, globals):
|
|||
self.locals = {}
|
||||
self.overwrite_errors = {}
|
||||
self.unconditional_adds = {}
|
||||
self.method_annotations = {}
|
||||
|
||||
def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
|
||||
overwrite_error=False, unconditional_add=False, decorator=None):
|
||||
overwrite_error=False, unconditional_add=False, decorator=None,
|
||||
annotation_fields=None):
|
||||
if locals is not None:
|
||||
self.locals.update(locals)
|
||||
|
||||
|
|
@ -464,16 +466,14 @@ def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
|
|||
|
||||
self.names.append(name)
|
||||
|
||||
if return_type is not MISSING:
|
||||
self.locals[f'__dataclass_{name}_return_type__'] = return_type
|
||||
return_annotation = f'->__dataclass_{name}_return_type__'
|
||||
else:
|
||||
return_annotation = ''
|
||||
if annotation_fields is not None:
|
||||
self.method_annotations[name] = (annotation_fields, return_type)
|
||||
|
||||
args = ','.join(args)
|
||||
body = '\n'.join(body)
|
||||
|
||||
# Compute the text of the entire function, add it to the text we're generating.
|
||||
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
|
||||
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}):\n{body}')
|
||||
|
||||
def add_fns_to_class(self, cls):
|
||||
# The source to all of the functions we're generating.
|
||||
|
|
@ -509,6 +509,15 @@ def add_fns_to_class(self, cls):
|
|||
# Now that we've generated the functions, assign them into cls.
|
||||
for name, fn in zip(self.names, fns):
|
||||
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
|
||||
|
||||
try:
|
||||
annotation_fields, return_type = self.method_annotations[name]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
annotate_fn = _make_annotate_function(cls, name, annotation_fields, return_type)
|
||||
fn.__annotate__ = annotate_fn
|
||||
|
||||
if self.unconditional_adds.get(name, False):
|
||||
setattr(cls, name, fn)
|
||||
else:
|
||||
|
|
@ -524,6 +533,44 @@ def add_fns_to_class(self, cls):
|
|||
raise TypeError(error_msg)
|
||||
|
||||
|
||||
def _make_annotate_function(__class__, method_name, annotation_fields, return_type):
|
||||
# Create an __annotate__ function for a dataclass
|
||||
# Try to return annotations in the same format as they would be
|
||||
# from a regular __init__ function
|
||||
|
||||
def __annotate__(format, /):
|
||||
Format = annotationlib.Format
|
||||
match format:
|
||||
case Format.VALUE | Format.FORWARDREF | Format.STRING:
|
||||
cls_annotations = {}
|
||||
for base in reversed(__class__.__mro__):
|
||||
cls_annotations.update(
|
||||
annotationlib.get_annotations(base, format=format)
|
||||
)
|
||||
|
||||
new_annotations = {}
|
||||
for k in annotation_fields:
|
||||
new_annotations[k] = cls_annotations[k]
|
||||
|
||||
if return_type is not MISSING:
|
||||
if format == Format.STRING:
|
||||
new_annotations["return"] = annotationlib.type_repr(return_type)
|
||||
else:
|
||||
new_annotations["return"] = return_type
|
||||
|
||||
return new_annotations
|
||||
|
||||
case _:
|
||||
raise NotImplementedError(format)
|
||||
|
||||
# This is a flag for _add_slots to know it needs to regenerate this method
|
||||
# In order to remove references to the original class when it is replaced
|
||||
__annotate__.__generated_by_dataclasses__ = True
|
||||
__annotate__.__qualname__ = f"{__class__.__qualname__}.{method_name}.__annotate__"
|
||||
|
||||
return __annotate__
|
||||
|
||||
|
||||
def _field_assign(frozen, name, value, self_name):
|
||||
# If we're a frozen class, then assign to our fields in __init__
|
||||
# via object.__setattr__. Otherwise, just use a simple
|
||||
|
|
@ -612,7 +659,7 @@ def _init_param(f):
|
|||
elif f.default_factory is not MISSING:
|
||||
# There's a factory function. Set a marker.
|
||||
default = '=__dataclass_HAS_DEFAULT_FACTORY__'
|
||||
return f'{f.name}:__dataclass_type_{f.name}__{default}'
|
||||
return f'{f.name}{default}'
|
||||
|
||||
|
||||
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
||||
|
|
@ -635,11 +682,10 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
|||
raise TypeError(f'non-default argument {f.name!r} '
|
||||
f'follows default argument {seen_default.name!r}')
|
||||
|
||||
locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
|
||||
**{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
|
||||
'__dataclass_builtins_object__': object,
|
||||
}
|
||||
}
|
||||
annotation_fields = [f.name for f in fields if f.init]
|
||||
|
||||
locals = {'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
|
||||
'__dataclass_builtins_object__': object}
|
||||
|
||||
body_lines = []
|
||||
for f in fields:
|
||||
|
|
@ -670,7 +716,8 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
|||
[self_name] + _init_params,
|
||||
body_lines,
|
||||
locals=locals,
|
||||
return_type=None)
|
||||
return_type=None,
|
||||
annotation_fields=annotation_fields)
|
||||
|
||||
|
||||
def _frozen_get_del_attr(cls, fields, func_builder):
|
||||
|
|
@ -1336,6 +1383,25 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
|
|||
or _update_func_cell_for__class__(member.fdel, cls, newcls)):
|
||||
break
|
||||
|
||||
# Get new annotations to remove references to the original class
|
||||
# in forward references
|
||||
newcls_ann = annotationlib.get_annotations(
|
||||
newcls, format=annotationlib.Format.FORWARDREF)
|
||||
|
||||
# Fix references in dataclass Fields
|
||||
for f in getattr(newcls, _FIELDS).values():
|
||||
try:
|
||||
ann = newcls_ann[f.name]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
f.type = ann
|
||||
|
||||
# Fix the class reference in the __annotate__ method
|
||||
init_annotate = newcls.__init__.__annotate__
|
||||
if getattr(init_annotate, "__generated_by_dataclasses__", False):
|
||||
_update_func_cell_for__class__(init_annotate, cls, newcls)
|
||||
|
||||
return newcls
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue