Skip to content

bpo-45283: Run _type_check on get_type_hints() #28563

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Lib/test/ann_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Optional
from functools import wraps

__annotations__[1] = 2
__annotations__[1] = int

class C:

Expand All @@ -19,7 +19,7 @@ class C:

class M(type):

__annotations__['123'] = 123
__annotations__['123'] = int
o: type = object

(pars): bool = True
Expand Down
6 changes: 6 additions & 0 deletions Lib/test/ann_module7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Ensures that top-level ``ClassVar`` is not allowed.
# We test this explicitly without ``from __future__ import annotations``

from typing import ClassVar, Final

wrong: ClassVar[int] = 1
4 changes: 2 additions & 2 deletions Lib/test/test_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,9 +473,9 @@ class CC(metaclass=CMeta):
def test_var_annot_module_semantics(self):
self.assertEqual(test.__annotations__, {})
self.assertEqual(ann_module.__annotations__,
{1: 2, 'x': int, 'y': str, 'f': typing.Tuple[int, int], 'u': int | float})
{1: int, 'x': int, 'y': str, 'f': typing.Tuple[int, int], 'u': int | float})
self.assertEqual(ann_module.M.__annotations__,
{'123': 123, 'o': type})
{'123': int, 'o': type})
self.assertEqual(ann_module2.__annotations__, {})

def test_var_annot_in_module(self):
Expand Down
84 changes: 76 additions & 8 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2975,7 +2975,10 @@ async def __aexit__(self, etype, eval, tb):

# Definitions needed for features introduced in Python 3.6

from test import ann_module, ann_module2, ann_module3, ann_module5, ann_module6
from test import (
ann_module, ann_module2, ann_module3, ann_module5, ann_module6,
ann_module7,
)
from typing import AsyncContextManager

class A:
Expand Down Expand Up @@ -3070,7 +3073,8 @@ def test_get_type_hints_from_various_objects(self):
gth(None)

def test_get_type_hints_modules(self):
ann_module_type_hints = {1: 2, 'f': Tuple[int, int], 'x': int, 'y': str, 'u': int | float}
ann_module_type_hints = {1: int, 'f': Tuple[int, int],
'x': int, 'y': str, 'u': int | float}
self.assertEqual(gth(ann_module), ann_module_type_hints)
self.assertEqual(gth(ann_module2), {})
self.assertEqual(gth(ann_module3), {})
Expand All @@ -3088,7 +3092,7 @@ def test_get_type_hints_classes(self):
self.assertEqual(gth(ann_module.C), # gth will find the right globalns
{'y': Optional[ann_module.C]})
self.assertIsInstance(gth(ann_module.j_class), dict)
self.assertEqual(gth(ann_module.M), {'123': 123, 'o': type})
self.assertEqual(gth(ann_module.M), {'123': int, 'o': type})
self.assertEqual(gth(ann_module.D),
{'j': str, 'k': str, 'y': Optional[ann_module.C]})
self.assertEqual(gth(ann_module.Y), {'z': int})
Expand Down Expand Up @@ -3276,6 +3280,67 @@ class BadType(BadBase):
self.assertNotIn('bad', sys.modules)
self.assertEqual(get_type_hints(BadType), {'foo': tuple, 'bar': list})

def test_type_check_error_message_during_get_type_hints(self):
class InvalidTupleAnnotation:
x: (1, 2)
with self.assertRaisesRegex(
TypeError,
re.escape(
'get_type_hints() got invalid type annotation. Got (1, 2).',
),
):
get_type_hints(InvalidTupleAnnotation)

def test_invalid_class_level_annotations(self):
class InvalidIntAnnotation:
x: 1 = 1
class InvalidStrAnnotation:
x: '1'
class InvalidListAnnotation1:
x: []
class InvalidListAnnotation2:
x: '[1, 2]'

for fixture in [
InvalidIntAnnotation,
InvalidStrAnnotation,
InvalidListAnnotation1,
InvalidListAnnotation2]:
with self.subTest(fixture=fixture):
with self.assertRaises(TypeError):
get_type_hints(fixture)

def test_invalid_function_arg_annotations(self):
def invalid_arg_type(arg: 1): pass
def invalid_arg_type2(arg: '1'): pass
def invalid_return_type() -> 1: pass
def invalid_return_type2() -> '1': pass
def class_var_arg(arg: ClassVar): pass
def class_var_arg2(arg: ClassVar[int]): pass
def class_var_return_type() -> ClassVar: pass
def class_var_return_type2() -> ClassVar[int]: pass
def final_var_arg(arg: Final): pass
def final_var_arg2(arg: Final[int]): pass
def final_var_return_type() -> Final: pass
def final_var_return_type2() -> Final[int]: pass

for func in [
invalid_arg_type,
invalid_arg_type2,
invalid_return_type,
invalid_return_type2,
class_var_arg,
class_var_arg2,
class_var_return_type,
class_var_return_type2,
final_var_arg,
final_var_arg2,
final_var_return_type,
final_var_return_type2]:
with self.subTest(func=func):
with self.assertRaises(TypeError):
get_type_hints(func)


class GetUtilitiesTestCase(TestCase):
def test_get_origin(self):
Expand Down Expand Up @@ -3349,11 +3414,14 @@ def test_forward_ref_and_final(self):

def test_top_level_class_var(self):
# https://bugs.python.org/issue45166
with self.assertRaisesRegex(
TypeError,
r'typing.ClassVar\[int\] is not valid as type argument',
):
get_type_hints(ann_module6)
# https://bugs.python.org/issue45283
for obj in [ann_module6, ann_module7]:
with self.subTest(obj=obj):
with self.assertRaisesRegex(
TypeError,
r'typing.ClassVar\[int\] is not valid as type argument',
):
get_type_hints(obj)


class CollectionsAbcTests(BaseTestCase):
Expand Down
19 changes: 17 additions & 2 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,12 @@ def _type_check(arg, msg, is_argument=True, module=None, *, is_class=False):
invalid_generic_forms += (Final,)

arg = _type_convert(arg, module=module)
if (isinstance(arg, _GenericAlias) and
arg.__origin__ in invalid_generic_forms):
is_invalid_generic = (
isinstance(arg, _GenericAlias)
and arg.__origin__ in invalid_generic_forms
)
is_invalid_bare_final = is_argument and arg is Final
if is_invalid_generic or is_invalid_bare_final:
raise TypeError(f"{arg} is not valid as type argument")
if arg in (Any, NoReturn, Final):
return arg
Expand Down Expand Up @@ -1780,6 +1784,9 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):

if getattr(obj, '__no_type_check__', None):
return {}

error_msg = "get_type_hints() got invalid type annotation."

# Classes require a special treatment.
if isinstance(obj, type):
hints = {}
Expand All @@ -1805,6 +1812,9 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
value = type(None)
if isinstance(value, str):
value = ForwardRef(value, is_argument=False, is_class=True)
else:
value = _type_check(value, error_msg,
is_argument=False, is_class=True)
value = _eval_type(value, base_globals, base_locals)
hints[name] = value
return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}
Expand Down Expand Up @@ -1843,6 +1853,11 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
is_argument=not isinstance(obj, types.ModuleType),
is_class=False,
)
else:
value = _type_check(value,
error_msg,
is_argument=not isinstance(obj, types.ModuleType),
is_class=False)
value = _eval_type(value, globalns, localns)
if name in defaults and defaults[name] is None:
value = Optional[value]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Run ``_type_check()`` on all ``get_type_hints()`` calls. This catches cases
when any invalid annotations are used.