Skip to content

Commit 25a12aa

Browse files
[3.9] bpo-46032: Check types in singledispatch's register() at declaration time (GH-30050) (GH-30254) (GH-30255)
The registry() method of functools.singledispatch() functions checks now the first argument or the first parameter annotation and raises a TypeError if it is not supported. Previously unsupported "types" were ignored (e.g. typing.List[int]) or caused an error at calling time (e.g. list[int]). (cherry picked from commit 078abb6) (cherry picked from commit 03c7449) Co-authored-by: Serhiy Storchaka <[email protected]>
1 parent 0722905 commit 25a12aa

File tree

3 files changed

+88
-4
lines changed

3 files changed

+88
-4
lines changed

Lib/functools.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,7 @@ def _compose_mro(cls, types):
739739
# Remove entries which are already present in the __mro__ or unrelated.
740740
def is_related(typ):
741741
return (typ not in bases and hasattr(typ, '__mro__')
742+
and not isinstance(typ, GenericAlias)
742743
and issubclass(cls, typ))
743744
types = [n for n in types if is_related(n)]
744745
# Remove entries which are strict bases of other entries (they will end up
@@ -836,16 +837,25 @@ def dispatch(cls):
836837
dispatch_cache[cls] = impl
837838
return impl
838839

840+
def _is_valid_dispatch_type(cls):
841+
return isinstance(cls, type) and not isinstance(cls, GenericAlias)
842+
839843
def register(cls, func=None):
840844
"""generic_func.register(cls, func) -> func
841845
842846
Registers a new implementation for the given *cls* on a *generic_func*.
843847
844848
"""
845849
nonlocal cache_token
846-
if func is None:
847-
if isinstance(cls, type):
850+
if _is_valid_dispatch_type(cls):
851+
if func is None:
848852
return lambda f: register(cls, f)
853+
else:
854+
if func is not None:
855+
raise TypeError(
856+
f"Invalid first argument to `register()`. "
857+
f"{cls!r} is not a class."
858+
)
849859
ann = getattr(cls, '__annotations__', {})
850860
if not ann:
851861
raise TypeError(
@@ -858,11 +868,12 @@ def register(cls, func=None):
858868
# only import typing if annotation parsing is necessary
859869
from typing import get_type_hints
860870
argname, cls = next(iter(get_type_hints(func).items()))
861-
if not isinstance(cls, type):
871+
if not _is_valid_dispatch_type(cls):
862872
raise TypeError(
863873
f"Invalid annotation for {argname!r}. "
864874
f"{cls!r} is not a class."
865875
)
876+
866877
registry[cls] = func
867878
if cache_token is None and hasattr(cls, '__abstractmethods__'):
868879
cache_token = get_cache_token()

Lib/test/test_functools.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2560,7 +2560,7 @@ def decorated_classmethod(cls, arg: int) -> str:
25602560

25612561
self.assertEqual(without_single_dispatch_foo, single_dispatch_foo)
25622562
self.assertEqual(single_dispatch_foo, '5')
2563-
2563+
25642564
self.assertEqual(
25652565
WithoutSingleDispatch.decorated_classmethod(5),
25662566
WithSingleDispatch.decorated_classmethod(5)
@@ -2655,6 +2655,74 @@ def f(*args):
26552655
with self.assertRaisesRegex(TypeError, msg):
26562656
f()
26572657

2658+
def test_register_genericalias(self):
2659+
@functools.singledispatch
2660+
def f(arg):
2661+
return "default"
2662+
2663+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2664+
f.register(list[int], lambda arg: "types.GenericAlias")
2665+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2666+
f.register(typing.List[int], lambda arg: "typing.GenericAlias")
2667+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2668+
f.register(typing.Union[list[int], str], lambda arg: "typing.Union[types.GenericAlias]")
2669+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2670+
f.register(typing.Union[typing.List[float], bytes], lambda arg: "typing.Union[typing.GenericAlias]")
2671+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2672+
f.register(typing.Any, lambda arg: "typing.Any")
2673+
2674+
self.assertEqual(f([1]), "default")
2675+
self.assertEqual(f([1.0]), "default")
2676+
self.assertEqual(f(""), "default")
2677+
self.assertEqual(f(b""), "default")
2678+
2679+
def test_register_genericalias_decorator(self):
2680+
@functools.singledispatch
2681+
def f(arg):
2682+
return "default"
2683+
2684+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2685+
f.register(list[int])
2686+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2687+
f.register(typing.List[int])
2688+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2689+
f.register(typing.Union[list[int], str])
2690+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2691+
f.register(typing.Union[typing.List[int], str])
2692+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2693+
f.register(typing.Any)
2694+
2695+
def test_register_genericalias_annotation(self):
2696+
@functools.singledispatch
2697+
def f(arg):
2698+
return "default"
2699+
2700+
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2701+
@f.register
2702+
def _(arg: list[int]):
2703+
return "types.GenericAlias"
2704+
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2705+
@f.register
2706+
def _(arg: typing.List[float]):
2707+
return "typing.GenericAlias"
2708+
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2709+
@f.register
2710+
def _(arg: typing.Union[list[int], str]):
2711+
return "types.UnionType(types.GenericAlias)"
2712+
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2713+
@f.register
2714+
def _(arg: typing.Union[typing.List[float], bytes]):
2715+
return "typing.Union[typing.GenericAlias]"
2716+
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2717+
@f.register
2718+
def _(arg: typing.Any):
2719+
return "typing.Any"
2720+
2721+
self.assertEqual(f([1]), "default")
2722+
self.assertEqual(f([1.0]), "default")
2723+
self.assertEqual(f(""), "default")
2724+
self.assertEqual(f(b""), "default")
2725+
26582726

26592727
class CachedCostItem:
26602728
_cost = 1
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
The ``registry()`` method of :func:`functools.singledispatch` functions
2+
checks now the first argument or the first parameter annotation and raises a
3+
TypeError if it is not supported. Previously unsupported "types" were
4+
ignored (e.g. ``typing.List[int]``) or caused an error at calling time (e.g.
5+
``list[int]``).

0 commit comments

Comments
 (0)