From d2d09c1dead23f3ff682abca3358e9b5f1e45d23 Mon Sep 17 00:00:00 2001 From: smheidrich Date: Wed, 28 Dec 2022 22:09:31 +0100 Subject: [PATCH 01/11] Implement singledispatch on type arguments POC --- Lib/functools.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index 43ead512e1ea4e..d36f87c2a9ba58 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -781,7 +781,18 @@ def _find_impl(cls, registry): *object* type, this function may return None. """ - mro = _compose_mro(cls, registry.keys()) + from typing import get_args, get_origin + # Distinguish between funcs for type[A] and A, only use appropriate ones + if get_origin(cls) is type: + classes = ( + get_args(key) for key in registry.keys() if get_origin(key) is type + ) + else: + classes = ( + key for key in registry.keys() if get_origin(key) is None + ) + # Everything from here on out works the same regardless of type[A] or A + mro = _compose_mro(cls, classes) match = None for t in mro: if match is not None: @@ -845,7 +856,9 @@ def _is_union_type(cls): def _is_valid_dispatch_type(cls): if isinstance(cls, type): return True - from typing import get_args + from typing import get_args, get_origin + if get_origin(cls) is type and isinstance(get_args(cls)[0], type): + return True return (_is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))) @@ -906,7 +919,12 @@ def wrapper(*args, **kw): raise TypeError(f'{funcname} requires at least ' '1 positional argument') - return dispatch(args[0].__class__)(*args, **kw) + from inspect import isclass + if isclass(args[0]): + cls_arg = type[args[0]] + else: + cls_arg = args[0].__class__ + return dispatch(cls_arg)(*args, **kw) funcname = getattr(func, '__name__', 'singledispatch function') registry[object] = func From 31e1ae65b4db1033f6025c9b3cd0895ce1d14876 Mon Sep 17 00:00:00 2001 From: smheidrich Date: Thu, 29 Dec 2022 14:50:09 +0100 Subject: [PATCH 02/11] Fix wrapping MRO classes in type[...] if necessary --- Lib/functools.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index d36f87c2a9ba58..feebdf5812c0af 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -782,31 +782,35 @@ def _find_impl(cls, registry): """ from typing import get_args, get_origin - # Distinguish between funcs for type[A] and A, only use appropriate ones + # Distinguish between funcs for type[A] and A if get_origin(cls) is type: - classes = ( - get_args(key) for key in registry.keys() if get_origin(key) is type + to_type_like_given = lambda t: type[t] + class_ = get_args(cls)[0] + registry_classes = ( + get_args(key)[0] + for key in registry.keys() if get_origin(key) is type ) else: - classes = ( - key for key in registry.keys() if get_origin(key) is None + to_type_like_given = lambda t: t + class_ = cls + registry_classes = ( + key for key in registry.keys() if get_origin(key) is not type ) - # Everything from here on out works the same regardless of type[A] or A - mro = _compose_mro(cls, classes) + mro = _compose_mro(class_, registry_classes) match = None for t in mro: if match is not None: # If *match* is an implicit ABC but there is another unrelated, # equally matching implicit ABC, refuse the temptation to guess. - if (t in registry and t not in cls.__mro__ - and match not in cls.__mro__ + if (to_type_like_given(t) in registry and t not in class_.__mro__ + and match not in class_.__mro__ and not issubclass(match, t)): raise RuntimeError("Ambiguous dispatch: {} or {}".format( match, t)) break - if t in registry: + if to_type_like_given(t) in registry: match = t - return registry.get(match) + return registry.get(to_type_like_given(match)) def singledispatch(func): """Single-dispatch generic function decorator. From b874773efe7917a22122be5928b2ce43c9b1078c Mon Sep 17 00:00:00 2001 From: smheidrich Date: Fri, 30 Dec 2022 01:46:42 +0100 Subject: [PATCH 03/11] Add support for Type[X] instead of just type[X] --- Lib/functools.py | 21 ++++++++++++------- Lib/test/test_functools.py | 41 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 7 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index feebdf5812c0af..ce856d57d81f9f 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -857,12 +857,18 @@ def _is_union_type(cls): from typing import get_origin, Union return get_origin(cls) in {Union, types.UnionType} + def _is_type_type(cls): + from typing import get_args, get_origin, Type + if (get_origin(cls) in (type, Type) and + isinstance(get_args(cls)[0], type)): + return True + def _is_valid_dispatch_type(cls): if isinstance(cls, type): return True - from typing import get_args, get_origin - if get_origin(cls) is type and isinstance(get_args(cls)[0], type): + if _is_type_type(cls): return True + from typing import get_args return (_is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))) @@ -911,6 +917,10 @@ def register(cls, func=None): for arg in get_args(cls): registry[arg] = func + elif _is_type_type(cls): + from typing import get_args + + registry[type[get_args(cls)[0]]] = func # normalize Type -> type else: registry[cls] = func if cache_token is None and hasattr(cls, '__abstractmethods__'): @@ -924,11 +934,8 @@ def wrapper(*args, **kw): '1 positional argument') from inspect import isclass - if isclass(args[0]): - cls_arg = type[args[0]] - else: - cls_arg = args[0].__class__ - return dispatch(cls_arg)(*args, **kw) + type_arg = type[arg1] if isclass(arg1 := args[0]) else arg1.__class__ + return dispatch(type_arg)(*args, **kw) funcname = getattr(func, '__name__', 'singledispatch function') registry[object] = func diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 730ab1f595f22c..2c08c624b08dce 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -2905,6 +2905,47 @@ def _(arg: typing.List[float] | bytes): self.assertEqual(f(""), "default") self.assertEqual(f(b""), "default") + def test_type_argument(self): + @functools.singledispatch + def f(arg): + return "default" + + @f.register + def _(arg: type[int]): + return "type[int]" + + @f.register + def _(arg: typing.Type[float]): + return "type[float]" + + @f.register(type[str]) + def _(arg): + return "type[str]" + + @f.register(typing.Type[bytes]) + def _(arg): + return "type[bytes]" + + class A: + pass + + class B(A): + pass + + @f.register + def _(arg: type[A]): + return "type[A]" + + @f.register + def _(arg: B): + return "B" + + self.assertEqual(f(int), "type[int]") + self.assertEqual(f(float), "type[float]") + self.assertEqual(f(str), "type[str]") + self.assertEqual(f(bytes), "type[bytes]") + self.assertEqual(f(B), "type[A]") + class CachedCostItem: _cost = 1 From 78d25ce3682893764450e5978b02405484ebb9cb Mon Sep 17 00:00:00 2001 From: smheidrich Date: Fri, 30 Dec 2022 02:05:25 +0100 Subject: [PATCH 04/11] Add test for and fix default func for type arg --- Lib/functools.py | 1 + Lib/test/test_functools.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/Lib/functools.py b/Lib/functools.py index ce856d57d81f9f..cecd78a8d1311d 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -939,6 +939,7 @@ def wrapper(*args, **kw): funcname = getattr(func, '__name__', 'singledispatch function') registry[object] = func + registry[type[object]] = func wrapper.register = register wrapper.dispatch = dispatch wrapper.registry = types.MappingProxyType(registry) diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 2c08c624b08dce..c3bc7b6bf87eea 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -2932,6 +2932,9 @@ class A: class B(A): pass + class C: + pass + @f.register def _(arg: type[A]): return "type[A]" @@ -2945,6 +2948,7 @@ def _(arg: B): self.assertEqual(f(str), "type[str]") self.assertEqual(f(bytes), "type[bytes]") self.assertEqual(f(B), "type[A]") + self.assertEqual(f(C), "default") class CachedCostItem: From 1ec0ad47d7cc19d19909a1f398e81f20392aae41 Mon Sep 17 00:00:00 2001 From: smheidrich Date: Fri, 30 Dec 2022 02:53:44 +0100 Subject: [PATCH 05/11] Use set instead of generator for registry_classes --- Lib/functools.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index cecd78a8d1311d..790e2dc90cf948 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -786,16 +786,16 @@ def _find_impl(cls, registry): if get_origin(cls) is type: to_type_like_given = lambda t: type[t] class_ = get_args(cls)[0] - registry_classes = ( + registry_classes = { get_args(key)[0] for key in registry.keys() if get_origin(key) is type - ) + } else: to_type_like_given = lambda t: t class_ = cls - registry_classes = ( + registry_classes = { key for key in registry.keys() if get_origin(key) is not type - ) + } mro = _compose_mro(class_, registry_classes) match = None for t in mro: From 7da29465826b6c85f45c6c4e2900be5a182ab02e Mon Sep 17 00:00:00 2001 From: smheidrich Date: Fri, 30 Dec 2022 03:13:40 +0100 Subject: [PATCH 06/11] Implement support for type[A|B] unions Still missing: type[A]|type[B] unions (same thing) --- Lib/functools.py | 19 +++++++++++++------ Lib/test/test_functools.py | 5 +++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index 790e2dc90cf948..78d64f8d53887b 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -858,15 +858,16 @@ def _is_union_type(cls): return get_origin(cls) in {Union, types.UnionType} def _is_type_type(cls): - from typing import get_args, get_origin, Type - if (get_origin(cls) in (type, Type) and - isinstance(get_args(cls)[0], type)): + # checks if cls is something like type[A] + from typing import get_origin, Type + if get_origin(cls) in (type, Type): return True def _is_valid_dispatch_type(cls): - if isinstance(cls, type): - return True if _is_type_type(cls): + from typing import get_args + cls = get_args(cls)[0] + if isinstance(cls, type): return True from typing import get_args return (_is_union_type(cls) and @@ -920,7 +921,13 @@ def register(cls, func=None): elif _is_type_type(cls): from typing import get_args - registry[type[get_args(cls)[0]]] = func # normalize Type -> type + inner = get_args(cls)[0] + + if _is_union_type(inner): + for arg in get_args(inner): + registry[type[arg]] = func + else: + registry[type[inner]] = func else: registry[cls] = func if cache_token is None and hasattr(cls, '__abstractmethods__'): diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index c3bc7b6bf87eea..d700ce35907078 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -2943,12 +2943,17 @@ def _(arg: type[A]): def _(arg: B): return "B" + @f.register + def _(arg: type[list|dict]): + return "list|dict" + self.assertEqual(f(int), "type[int]") self.assertEqual(f(float), "type[float]") self.assertEqual(f(str), "type[str]") self.assertEqual(f(bytes), "type[bytes]") self.assertEqual(f(B), "type[A]") self.assertEqual(f(C), "default") + self.assertEqual(f(list), "list|dict") class CachedCostItem: From b5171c533c7957eaf18d47224c9927e4636cbaa2 Mon Sep 17 00:00:00 2001 From: smheidrich Date: Fri, 30 Dec 2022 03:25:55 +0100 Subject: [PATCH 07/11] Implement type[A]|type[B] union types --- Lib/functools.py | 10 ++++++++-- Lib/test/test_functools.py | 9 +++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index 78d64f8d53887b..79b43ff1d921a8 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -871,7 +871,10 @@ def _is_valid_dispatch_type(cls): return True from typing import get_args return (_is_union_type(cls) and - all(isinstance(arg, type) for arg in get_args(cls))) + all(isinstance(arg, type) if not _is_type_type(arg) + else isinstance(get_args(arg)[0], + type) + for arg in get_args(cls))) def register(cls, func=None): """generic_func.register(cls, func) -> func @@ -917,7 +920,10 @@ def register(cls, func=None): from typing import get_args for arg in get_args(cls): - registry[arg] = func + if _is_type_type(arg): + registry[type[get_args(arg)[0]]] = func + else: + registry[arg] = func elif _is_type_type(cls): from typing import get_args diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index d700ce35907078..eb7538c74de634 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -2945,7 +2945,11 @@ def _(arg: B): @f.register def _(arg: type[list|dict]): - return "list|dict" + return "type[list|dict]" + + @f.register + def _(arg: type[set]|typing.Type[type(None)]): + return "type[set]|type[NoneType]" self.assertEqual(f(int), "type[int]") self.assertEqual(f(float), "type[float]") @@ -2953,7 +2957,8 @@ def _(arg: type[list|dict]): self.assertEqual(f(bytes), "type[bytes]") self.assertEqual(f(B), "type[A]") self.assertEqual(f(C), "default") - self.assertEqual(f(list), "list|dict") + self.assertEqual(f(list), "type[list|dict]") + self.assertEqual(f(type(None)), "type[set]|type[NoneType]") class CachedCostItem: From 9761b56f87407dba6afcca07845468179dacd5db Mon Sep 17 00:00:00 2001 From: smheidrich Date: Fri, 30 Dec 2022 03:44:38 +0100 Subject: [PATCH 08/11] Add tests for invalid type[X] types --- Lib/test/test_functools.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index eb7538c74de634..1c3611b0b7846b 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -2960,6 +2960,15 @@ def _(arg: type[set]|typing.Type[type(None)]): self.assertEqual(f(list), "type[list|dict]") self.assertEqual(f(type(None)), "type[set]|type[NoneType]") + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + @f.register + def _(arg: type[2]): + pass + + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + @f.register + def _(arg: typing.Type[int]|type[3]): + pass class CachedCostItem: _cost = 1 From eb23d0561ca169d5454d1f0fe28af2ea31077bae Mon Sep 17 00:00:00 2001 From: smheidrich Date: Fri, 30 Dec 2022 03:53:48 +0100 Subject: [PATCH 09/11] Split up tests --- Lib/test/test_functools.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 1c3611b0b7846b..c194d4980da89b 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -2926,6 +2926,13 @@ def _(arg): def _(arg): return "type[bytes]" + self.assertEqual(f(int), "type[int]") + self.assertEqual(f(float), "type[float]") + self.assertEqual(f(str), "type[str]") + self.assertEqual(f(bytes), "type[bytes]") + self.assertEqual(f(2), "default") + + def test_type_argument_mro(self): class A: pass @@ -2935,6 +2942,10 @@ class B(A): class C: pass + @functools.singledispatch + def f(arg): + return "default" + @f.register def _(arg: type[A]): return "type[A]" @@ -2943,6 +2954,14 @@ def _(arg: type[A]): def _(arg: B): return "B" + self.assertEqual(f(B), "type[A]") + self.assertEqual(f(C), "default") + + def test_type_argument_unions(self): + @functools.singledispatch + def f(arg): + return "default" + @f.register def _(arg: type[list|dict]): return "type[list|dict]" @@ -2951,15 +2970,14 @@ def _(arg: type[list|dict]): def _(arg: type[set]|typing.Type[type(None)]): return "type[set]|type[NoneType]" - self.assertEqual(f(int), "type[int]") - self.assertEqual(f(float), "type[float]") - self.assertEqual(f(str), "type[str]") - self.assertEqual(f(bytes), "type[bytes]") - self.assertEqual(f(B), "type[A]") - self.assertEqual(f(C), "default") self.assertEqual(f(list), "type[list|dict]") self.assertEqual(f(type(None)), "type[set]|type[NoneType]") + def test_type_argument_invalid_types(self): + @functools.singledispatch + def f(arg): + return "default" + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): @f.register def _(arg: type[2]): From 8a648bc0aa38cad32e4fc646ec0f4f26268d2e1d Mon Sep 17 00:00:00 2001 From: "blurb-it[bot]" <43283697+blurb-it[bot]@users.noreply.github.com> Date: Fri, 30 Dec 2022 15:21:50 +0000 Subject: [PATCH 10/11] =?UTF-8?q?=F0=9F=93=9C=F0=9F=A4=96=20Added=20by=20b?= =?UTF-8?q?lurb=5Fit.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../next/Library/2022-12-30-15-21-50.gh-issue-100623.3hdA1o.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 Misc/NEWS.d/next/Library/2022-12-30-15-21-50.gh-issue-100623.3hdA1o.rst diff --git a/Misc/NEWS.d/next/Library/2022-12-30-15-21-50.gh-issue-100623.3hdA1o.rst b/Misc/NEWS.d/next/Library/2022-12-30-15-21-50.gh-issue-100623.3hdA1o.rst new file mode 100644 index 00000000000000..ca9c9c67054d89 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-12-30-15-21-50.gh-issue-100623.3hdA1o.rst @@ -0,0 +1 @@ +Add support for dispatching on ``type[...]`` arguments to :func:`functools.singledispatch`. From 2183609dd395a9420bdf00ee791901c8a31eb372 Mon Sep 17 00:00:00 2001 From: smheidrich Date: Sat, 31 Dec 2022 03:08:31 +0100 Subject: [PATCH 11/11] Replace unnecessary isclass w/ isinstance( , type) --- Lib/functools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index 79b43ff1d921a8..1201cefbdfac59 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -946,8 +946,8 @@ def wrapper(*args, **kw): raise TypeError(f'{funcname} requires at least ' '1 positional argument') - from inspect import isclass - type_arg = type[arg1] if isclass(arg1 := args[0]) else arg1.__class__ + type_arg = (type[arg1] if isinstance(arg1 := args[0], type) + else arg1.__class__) return dispatch(type_arg)(*args, **kw) funcname = getattr(func, '__name__', 'singledispatch function')