gh-114053: Fix bad interaction of PEP-695, PEP-563 and `get_type_hints` (#118009)

Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
This commit is contained in:
Alex Waygood 2024-04-19 14:03:44 +01:00 committed by GitHub
parent 15b3555e4a
commit 1e3e7ce11e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 81 additions and 10 deletions

View file

@ -399,7 +399,8 @@ def inner(*args, **kwds):
return decorator
def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
def _eval_type(t, globalns, localns, type_params, *, recursive_guard=frozenset()):
"""Evaluate all forward references in the given type t.
For use of globalns and localns see the docstring for get_type_hints().
@ -407,7 +408,7 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
ForwardRef.
"""
if isinstance(t, ForwardRef):
return t._evaluate(globalns, localns, recursive_guard)
return t._evaluate(globalns, localns, type_params, recursive_guard=recursive_guard)
if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)):
if isinstance(t, GenericAlias):
args = tuple(
@ -421,7 +422,13 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
t = t.__origin__[args]
if is_unpacked:
t = Unpack[t]
ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
ev_args = tuple(
_eval_type(
a, globalns, localns, type_params, recursive_guard=recursive_guard
)
for a in t.__args__
)
if ev_args == t.__args__:
return t
if isinstance(t, GenericAlias):
@ -974,7 +981,7 @@ def __init__(self, arg, is_argument=True, module=None, *, is_class=False):
self.__forward_is_class__ = is_class
self.__forward_module__ = module
def _evaluate(self, globalns, localns, recursive_guard):
def _evaluate(self, globalns, localns, type_params, *, recursive_guard):
if self.__forward_arg__ in recursive_guard:
return self
if not self.__forward_evaluated__ or localns is not globalns:
@ -988,14 +995,25 @@ def _evaluate(self, globalns, localns, recursive_guard):
globalns = getattr(
sys.modules.get(self.__forward_module__, None), '__dict__', globalns
)
if type_params:
# "Inject" type parameters into the local namespace
# (unless they are shadowed by assignments *in* the local namespace),
# as a way of emulating annotation scopes when calling `eval()`
locals_to_pass = {param.__name__: param for param in type_params} | localns
else:
locals_to_pass = localns
type_ = _type_check(
eval(self.__forward_code__, globalns, localns),
eval(self.__forward_code__, globalns, locals_to_pass),
"Forward references must evaluate to types.",
is_argument=self.__forward_is_argument__,
allow_special_forms=self.__forward_is_class__,
)
self.__forward_value__ = _eval_type(
type_, globalns, localns, recursive_guard | {self.__forward_arg__}
type_,
globalns,
localns,
type_params,
recursive_guard=(recursive_guard | {self.__forward_arg__}),
)
self.__forward_evaluated__ = True
return self.__forward_value__
@ -2334,7 +2352,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
value = type(None)
if isinstance(value, str):
value = ForwardRef(value, is_argument=False, is_class=True)
value = _eval_type(value, base_globals, base_locals)
value = _eval_type(value, base_globals, base_locals, base.__type_params__)
hints[name] = value
return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}
@ -2360,6 +2378,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
raise TypeError('{!r} is not a module, class, method, '
'or function.'.format(obj))
hints = dict(hints)
type_params = getattr(obj, "__type_params__", ())
for name, value in hints.items():
if value is None:
value = type(None)
@ -2371,7 +2390,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
is_argument=not isinstance(obj, types.ModuleType),
is_class=False,
)
hints[name] = _eval_type(value, globalns, localns)
hints[name] = _eval_type(value, globalns, localns, type_params)
return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}