diff --git a/Lib/functools.py b/Lib/functools.py index 43ead512e1ea4e..1201cefbdfac59 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -781,21 +781,36 @@ def _find_impl(cls, registry): *object* type, this function may return None. """ - mro = _compose_mro(cls, registry.keys()) + from typing import get_args, get_origin + # Distinguish between funcs for type[A] and A + if get_origin(cls) is type: + to_type_like_given = lambda t: type[t] + class_ = get_args(cls)[0] + registry_classes = { + get_args(key)[0] + for key in registry.keys() if get_origin(key) is type + } + else: + to_type_like_given = lambda t: t + class_ = cls + registry_classes = { + key for key in registry.keys() if get_origin(key) is not type + } + mro = _compose_mro(class_, registry_classes) match = None for t in mro: if match is not None: # If *match* is an implicit ABC but there is another unrelated, # equally matching implicit ABC, refuse the temptation to guess. - if (t in registry and t not in cls.__mro__ - and match not in cls.__mro__ + if (to_type_like_given(t) in registry and t not in class_.__mro__ + and match not in class_.__mro__ and not issubclass(match, t)): raise RuntimeError("Ambiguous dispatch: {} or {}".format( match, t)) break - if t in registry: + if to_type_like_given(t) in registry: match = t - return registry.get(match) + return registry.get(to_type_like_given(match)) def singledispatch(func): """Single-dispatch generic function decorator. @@ -842,12 +857,24 @@ def _is_union_type(cls): from typing import get_origin, Union return get_origin(cls) in {Union, types.UnionType} + def _is_type_type(cls): + # checks if cls is something like type[A] + from typing import get_origin, Type + if get_origin(cls) in (type, Type): + return True + def _is_valid_dispatch_type(cls): + if _is_type_type(cls): + from typing import get_args + cls = get_args(cls)[0] if isinstance(cls, type): return True from typing import get_args return (_is_union_type(cls) and - all(isinstance(arg, type) for arg in get_args(cls))) + all(isinstance(arg, type) if not _is_type_type(arg) + else isinstance(get_args(arg)[0], + type) + for arg in get_args(cls))) def register(cls, func=None): """generic_func.register(cls, func) -> func @@ -893,7 +920,20 @@ def register(cls, func=None): from typing import get_args for arg in get_args(cls): - registry[arg] = func + if _is_type_type(arg): + registry[type[get_args(arg)[0]]] = func + else: + registry[arg] = func + elif _is_type_type(cls): + from typing import get_args + + inner = get_args(cls)[0] + + if _is_union_type(inner): + for arg in get_args(inner): + registry[type[arg]] = func + else: + registry[type[inner]] = func else: registry[cls] = func if cache_token is None and hasattr(cls, '__abstractmethods__'): @@ -906,10 +946,13 @@ def wrapper(*args, **kw): raise TypeError(f'{funcname} requires at least ' '1 positional argument') - return dispatch(args[0].__class__)(*args, **kw) + type_arg = (type[arg1] if isinstance(arg1 := args[0], type) + else arg1.__class__) + return dispatch(type_arg)(*args, **kw) funcname = getattr(func, '__name__', 'singledispatch function') registry[object] = func + registry[type[object]] = func wrapper.register = register wrapper.dispatch = dispatch wrapper.registry = types.MappingProxyType(registry) diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 730ab1f595f22c..c194d4980da89b 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -2905,6 +2905,88 @@ def _(arg: typing.List[float] | bytes): self.assertEqual(f(""), "default") self.assertEqual(f(b""), "default") + def test_type_argument(self): + @functools.singledispatch + def f(arg): + return "default" + + @f.register + def _(arg: type[int]): + return "type[int]" + + @f.register + def _(arg: typing.Type[float]): + return "type[float]" + + @f.register(type[str]) + def _(arg): + return "type[str]" + + @f.register(typing.Type[bytes]) + def _(arg): + return "type[bytes]" + + self.assertEqual(f(int), "type[int]") + self.assertEqual(f(float), "type[float]") + self.assertEqual(f(str), "type[str]") + self.assertEqual(f(bytes), "type[bytes]") + self.assertEqual(f(2), "default") + + def test_type_argument_mro(self): + class A: + pass + + class B(A): + pass + + class C: + pass + + @functools.singledispatch + def f(arg): + return "default" + + @f.register + def _(arg: type[A]): + return "type[A]" + + @f.register + def _(arg: B): + return "B" + + self.assertEqual(f(B), "type[A]") + self.assertEqual(f(C), "default") + + def test_type_argument_unions(self): + @functools.singledispatch + def f(arg): + return "default" + + @f.register + def _(arg: type[list|dict]): + return "type[list|dict]" + + @f.register + def _(arg: type[set]|typing.Type[type(None)]): + return "type[set]|type[NoneType]" + + self.assertEqual(f(list), "type[list|dict]") + self.assertEqual(f(type(None)), "type[set]|type[NoneType]") + + def test_type_argument_invalid_types(self): + @functools.singledispatch + def f(arg): + return "default" + + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + @f.register + def _(arg: type[2]): + pass + + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + @f.register + def _(arg: typing.Type[int]|type[3]): + pass class CachedCostItem: _cost = 1 diff --git a/Misc/NEWS.d/next/Library/2022-12-30-15-21-50.gh-issue-100623.3hdA1o.rst b/Misc/NEWS.d/next/Library/2022-12-30-15-21-50.gh-issue-100623.3hdA1o.rst new file mode 100644 index 00000000000000..ca9c9c67054d89 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-12-30-15-21-50.gh-issue-100623.3hdA1o.rst @@ -0,0 +1 @@ +Add support for dispatching on ``type[...]`` arguments to :func:`functools.singledispatch`.