Skip to content
Closed
59 changes: 51 additions & 8 deletions Lib/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,21 +781,36 @@ def _find_impl(cls, registry):
*object* type, this function may return None.

"""
mro = _compose_mro(cls, registry.keys())
from typing import get_args, get_origin
# Distinguish between funcs for type[A] and A
if get_origin(cls) is type:
to_type_like_given = lambda t: type[t]
class_ = get_args(cls)[0]
registry_classes = {
get_args(key)[0]
for key in registry.keys() if get_origin(key) is type
}
else:
to_type_like_given = lambda t: t
class_ = cls
registry_classes = {
key for key in registry.keys() if get_origin(key) is not type
}
mro = _compose_mro(class_, registry_classes)
match = None
for t in mro:
if match is not None:
# If *match* is an implicit ABC but there is another unrelated,
# equally matching implicit ABC, refuse the temptation to guess.
if (t in registry and t not in cls.__mro__
and match not in cls.__mro__
if (to_type_like_given(t) in registry and t not in class_.__mro__
and match not in class_.__mro__
and not issubclass(match, t)):
raise RuntimeError("Ambiguous dispatch: {} or {}".format(
match, t))
break
if t in registry:
if to_type_like_given(t) in registry:
match = t
return registry.get(match)
return registry.get(to_type_like_given(match))

def singledispatch(func):
"""Single-dispatch generic function decorator.
Expand Down Expand Up @@ -842,12 +857,24 @@ def _is_union_type(cls):
from typing import get_origin, Union
return get_origin(cls) in {Union, types.UnionType}

def _is_type_type(cls):
# checks if cls is something like type[A]
from typing import get_origin, Type
if get_origin(cls) in (type, Type):
return True

def _is_valid_dispatch_type(cls):
if _is_type_type(cls):
from typing import get_args
cls = get_args(cls)[0]
if isinstance(cls, type):
return True
from typing import get_args
return (_is_union_type(cls) and
all(isinstance(arg, type) for arg in get_args(cls)))
all(isinstance(arg, type) if not _is_type_type(arg)
else isinstance(get_args(arg)[0],
type)
for arg in get_args(cls)))

def register(cls, func=None):
"""generic_func.register(cls, func) -> func
Expand Down Expand Up @@ -893,7 +920,20 @@ def register(cls, func=None):
from typing import get_args

for arg in get_args(cls):
registry[arg] = func
if _is_type_type(arg):
registry[type[get_args(arg)[0]]] = func
else:
registry[arg] = func
elif _is_type_type(cls):
from typing import get_args

inner = get_args(cls)[0]

if _is_union_type(inner):
for arg in get_args(inner):
registry[type[arg]] = func
else:
registry[type[inner]] = func
else:
registry[cls] = func
if cache_token is None and hasattr(cls, '__abstractmethods__'):
Expand All @@ -906,10 +946,13 @@ def wrapper(*args, **kw):
raise TypeError(f'{funcname} requires at least '
'1 positional argument')

return dispatch(args[0].__class__)(*args, **kw)
type_arg = (type[arg1] if isinstance(arg1 := args[0], type)
else arg1.__class__)
return dispatch(type_arg)(*args, **kw)

funcname = getattr(func, '__name__', 'singledispatch function')
registry[object] = func
registry[type[object]] = func
wrapper.register = register
wrapper.dispatch = dispatch
wrapper.registry = types.MappingProxyType(registry)
Expand Down
82 changes: 82 additions & 0 deletions Lib/test/test_functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2905,6 +2905,88 @@ def _(arg: typing.List[float] | bytes):
self.assertEqual(f(""), "default")
self.assertEqual(f(b""), "default")

def test_type_argument(self):
@functools.singledispatch
def f(arg):
return "default"

@f.register
def _(arg: type[int]):
return "type[int]"

@f.register
def _(arg: typing.Type[float]):
return "type[float]"

@f.register(type[str])
def _(arg):
return "type[str]"

@f.register(typing.Type[bytes])
def _(arg):
return "type[bytes]"

self.assertEqual(f(int), "type[int]")
self.assertEqual(f(float), "type[float]")
self.assertEqual(f(str), "type[str]")
self.assertEqual(f(bytes), "type[bytes]")
self.assertEqual(f(2), "default")

def test_type_argument_mro(self):
class A:
pass

class B(A):
pass

class C:
pass

@functools.singledispatch
def f(arg):
return "default"

@f.register
def _(arg: type[A]):
return "type[A]"

@f.register
def _(arg: B):
return "B"

self.assertEqual(f(B), "type[A]")
self.assertEqual(f(C), "default")

def test_type_argument_unions(self):
@functools.singledispatch
def f(arg):
return "default"

@f.register
def _(arg: type[list|dict]):
return "type[list|dict]"

@f.register
def _(arg: type[set]|typing.Type[type(None)]):
return "type[set]|type[NoneType]"

self.assertEqual(f(list), "type[list|dict]")
self.assertEqual(f(type(None)), "type[set]|type[NoneType]")
Comment on lines +2965 to +2974
Copy link
Copy Markdown
Contributor Author

@smheidrich smheidrich Dec 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if treating type[a|b] and type[a]|type[b] the same actually makes sense and even if so, whether union types should be supported for type[...] arguments at all:

E.g. if we have a single-dispatch function with an implementation for type[a|b|c], wouldn't people expect to be able to pass e.g. a|b to it and have it dispatch to that implementation? That would be much harder to implement than my current naive implementation of just splitting up unions into their individual constituent types, especially considering issubclass(a|b, a|b|c) isn't even possible.

So maybe unions of type[...]s and type[...]s of unions should just not be allowed for now, deferring them to when (if ever) issubclass supports these kinds of checks (or Python gets another issubtype function).

OTOH, it's unlikely that introducing support for dispatching to type[a|b|c] given a|b later on would be much of a breaking change... If only dispatching on single types is supported for now then that is all people will use, nobody will rely on the fact that a|b dispatches to the default implementation, they'll just never pass a|b to a single-dispatch function in the first place I should think. So it might be fine to implement it like this for now and leave only the "proper" handling for later.


def test_type_argument_invalid_types(self):
@functools.singledispatch
def f(arg):
return "default"

with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
@f.register
def _(arg: type[2]):
pass

with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
@f.register
def _(arg: typing.Type[int]|type[3]):
pass

class CachedCostItem:
_cost = 1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for dispatching on ``type[...]`` arguments to :func:`functools.singledispatch`.