From 77a2055a3aac675dbe603a2d5ba09540fb117b4a Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 16 Nov 2022 14:21:11 +0100 Subject: [PATCH 01/26] feat: allow da bulk access to return da for document Signed-off-by: Sami Jaghouar --- docarray/array/abstract_array.py | 4 +--- docarray/array/array.py | 2 +- docarray/array/mixins/attribute.py | 7 +------ docarray/document/mixins/proto.py | 2 ++ 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 4812d69d4ca..764da3b0fd7 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -13,7 +13,5 @@ def __init__(self, docs: Iterable[BaseDocument]): ... @abstractmethod - def __class_getitem__( - cls, item: Type[BaseDocument] - ) -> Type['AbstractDocumentArray']: + def __class_getitem__(cls, item: Type[BaseDocument]) -> Type['AbstractDocument']: ... diff --git a/docarray/array/array.py b/docarray/array/array.py index 681bbe19e02..344edb547c8 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -6,7 +6,7 @@ class DocumentArray( - list, + list[AbstractDocument], ProtoArrayMixin, GetAttributeArrayMixin, AbstractDocumentArray, diff --git a/docarray/array/mixins/attribute.py b/docarray/array/mixins/attribute.py index 712b5b6524e..a22e172e6c9 100644 --- a/docarray/array/mixins/attribute.py +++ b/docarray/array/mixins/attribute.py @@ -20,11 +20,6 @@ def _get_documents_attribute( field_type = self.__class__.document_type._get_nested_document_class(field) if issubclass(field_type, BaseDocument): - # calling __class_getitem__ ourselves is a hack otherwise mypy complain - # most likely a bug in mypy though - # bug reported here https://github.com/python/mypy/issues/14111 - return self.__class__.__class_getitem__(field_type)( - (getattr(doc, field) for doc in self) - ) + return self.__class__[field_type]((getattr(doc, field) for doc in self)) else: return [getattr(doc, field) for doc in self] diff --git a/docarray/document/mixins/proto.py b/docarray/document/mixins/proto.py index aaa3e6157dc..73c0733f363 100644 --- a/docarray/document/mixins/proto.py +++ b/docarray/document/mixins/proto.py @@ -9,6 +9,8 @@ T = TypeVar('T', bound='ProtoMixin') +T = TypeVar('T', bound='ProtoMixin') + class ProtoMixin(AbstractDocument, BaseNode): @classmethod From bd6d128081b6ded4e2e391959b7a7b21d3797a65 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 16 Nov 2022 15:27:27 +0100 Subject: [PATCH 02/26] fix: fix mypy type pb Signed-off-by: Sami Jaghouar --- docarray/array/abstract_array.py | 4 +++- docarray/array/mixins/attribute.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 764da3b0fd7..4812d69d4ca 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -13,5 +13,7 @@ def __init__(self, docs: Iterable[BaseDocument]): ... @abstractmethod - def __class_getitem__(cls, item: Type[BaseDocument]) -> Type['AbstractDocument']: + def __class_getitem__( + cls, item: Type[BaseDocument] + ) -> Type['AbstractDocumentArray']: ... 
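The __class_getitem__ hook restored above, together with the attribute mixin revisited in the next hunk, is what lets the bulk access of patch 01 ("allow da bulk access to return da for document") hand back a typed DocumentArray whenever the accessed field is itself a Document. A rough sketch of the intended behaviour, purely for orientation (the example document classes, the DocumentArray[...] parametrization and the direct call to the private helper are assumptions for illustration, not taken from the diff):

    from docarray.array.array import DocumentArray
    from docarray.document import BaseDocument


    class Inner(BaseDocument):       # hypothetical nested document
        text: str


    class Outer(BaseDocument):       # hypothetical outer document
        inner: Inner
        caption: str


    docs = DocumentArray[Outer](
        Outer(inner=Inner(text='hi'), caption=str(i)) for i in range(3)
    )

    # a field that is itself a Document comes back as a typed DocumentArray ...
    inners = docs._get_documents_attribute('inner')      # DocumentArray[Inner]
    # ... while a plain field comes back as an ordinary list
    captions = docs._get_documents_attribute('caption')  # list of str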
diff --git a/docarray/array/mixins/attribute.py b/docarray/array/mixins/attribute.py index a22e172e6c9..438f08916e7 100644 --- a/docarray/array/mixins/attribute.py +++ b/docarray/array/mixins/attribute.py @@ -20,6 +20,10 @@ def _get_documents_attribute( field_type = self.__class__.document_type._get_nested_document_class(field) if issubclass(field_type, BaseDocument): - return self.__class__[field_type]((getattr(doc, field) for doc in self)) + # calling __class_getitem__ ourselves is a hack otherwise mypy complain + # most likely a bug in mypy though + return self.__class__.__class_getitem__(field_type)( + (getattr(doc, field) for doc in self) + ) else: return [getattr(doc, field) for doc in self] From d778cecb2ef7a9764f5cde9feeb3a319f6ada140 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 16 Nov 2022 16:26:31 +0100 Subject: [PATCH 03/26] fix: add link to the mypy issue Signed-off-by: Sami Jaghouar --- docarray/array/mixins/attribute.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docarray/array/mixins/attribute.py b/docarray/array/mixins/attribute.py index 438f08916e7..712b5b6524e 100644 --- a/docarray/array/mixins/attribute.py +++ b/docarray/array/mixins/attribute.py @@ -22,6 +22,7 @@ def _get_documents_attribute( if issubclass(field_type, BaseDocument): # calling __class_getitem__ ourselves is a hack otherwise mypy complain # most likely a bug in mypy though + # bug reported here https://github.com/python/mypy/issues/14111 return self.__class__.__class_getitem__(field_type)( (getattr(doc, field) for doc in self) ) From 717e7ee4f5d97ed6517a78f8463d3784e02a7fc1 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 16 Nov 2022 16:30:04 +0100 Subject: [PATCH 04/26] fix: remove useless list type hint Signed-off-by: Sami Jaghouar --- docarray/array/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/array.py b/docarray/array/array.py index 344edb547c8..681bbe19e02 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -6,7 +6,7 @@ class DocumentArray( - list[AbstractDocument], + list, ProtoArrayMixin, GetAttributeArrayMixin, AbstractDocumentArray, From 2c3d5f103f824ea58794ba5a7cb5ae7e4555c178 Mon Sep 17 00:00:00 2001 From: Johannes Messner <44071807+JohannesMessner@users.noreply.github.com> Date: Thu, 17 Nov 2022 08:51:48 +0100 Subject: [PATCH 05/26] feat: torch tensor type (#800) * feat: add tensor type for ndarray * fix: fix mypy typing * feat: torch tensor type Signed-off-by: Johannes Messner * fix: protobuf for pytorch type Signed-off-by: Johannes Messner * ci: install all extras in the ci Signed-off-by: Johannes Messner * refactor: make nice looking * docs: update docarray/typing/tensor/torch_tensor.py Co-authored-by: samsja <55492238+samsja@users.noreply.github.com> Signed-off-by: Johannes Messner <44071807+JohannesMessner@users.noreply.github.com> * refactor: code style Signed-off-by: Johannes Messner * fix: black and mypy Signed-off-by: Johannes Messner * fix: suppress mypy import error * ci: fix ci install Signed-off-by: Johannes Messner Signed-off-by: Johannes Messner Signed-off-by: Johannes Messner <44071807+JohannesMessner@users.noreply.github.com> Co-authored-by: samsja <55492238+samsja@users.noreply.github.com> --- docarray/document/mixins/proto.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/docarray/document/mixins/proto.py b/docarray/document/mixins/proto.py index 73c0733f363..aaa3e6157dc 100644 --- a/docarray/document/mixins/proto.py +++ b/docarray/document/mixins/proto.py @@ -9,8 +9,6 @@ T = 
TypeVar('T', bound='ProtoMixin') -T = TypeVar('T', bound='ProtoMixin') - class ProtoMixin(AbstractDocument, BaseNode): @classmethod From e5d678c6d90a081bcf4eceffc3cd44120e3a8fda Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 17 Nov 2022 15:15:19 +0100 Subject: [PATCH 06/26] feat: add fastapi to dependency Signed-off-by: Sami Jaghouar --- poetry.lock | 74 +++++++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 4 +++ 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 88db66d8d36..890b85f40a2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -270,6 +270,24 @@ python-versions = "*" [package.extras] tests = ["asttokens", "littleutils", "pytest", "rich"] +[[package]] +name = "fastapi" +version = "0.87.0" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +pydantic = ">=1.6.2,<1.7 || >1.7,<1.7.1 || >1.7.1,<1.7.2 || >1.7.2,<1.7.3 || >1.7.3,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0" +starlette = "0.21.0" + +[package.extras] +all = ["email-validator (>=1.1.1)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +dev = ["pre-commit (>=2.17.0,<3.0.0)", "ruff (==0.0.114)", "uvicorn[standard] (>=0.12.0,<0.19.0)"] +doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer[all] (>=0.6.1,<0.7.0)"] +test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==22.8.0)", "coverage[toml] (>=6.5.0,<7.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.6)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.114)", "sqlalchemy (>=1.3.18,<=1.4.41)", "types-orjson (==3.6.2)", "types-ujson (==5.5.0)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"] + [[package]] name = "fastjsonschema" version = "2.16.2" @@ -293,6 +311,14 @@ python-versions = ">=3.7" docs = ["furo (>=2022.6.21)", "sphinx (>=5.1.1)", "sphinx-autodoc-typehints (>=1.19.1)"] testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pytest-cov (>=3)", "pytest-timeout (>=2.1)"] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +category = "dev" +optional = false +python-versions = ">=3.7" + [[package]] name = "identify" version = "2.5.8" @@ -1283,6 +1309,21 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "starlette" +version = "0.21.0" +description = "The little ASGI library that shines." 
+category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +anyio = ">=3.4.0,<5" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"] + [[package]] name = "terminado" version = "0.17.0" @@ -1398,6 +1439,21 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "uvicorn" +version = "0.19.0" +description = "The lightning-fast ASGI server." +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" + +[package.extras] +standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.0)"] + [[package]] name = "virtualenv" version = "20.16.7" @@ -1474,7 +1530,7 @@ torch = ["torch"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "9f20e49f31a6f56c379c1dca4b3a327dabff31cf217980ca9731deea7a4b821c" +content-hash = "4fa68f08685cb4a52ff58ec6b0dd0de63d6fea01a8159bbbd41b9c057bd38512" [metadata.files] anyio = [ @@ -1688,6 +1744,10 @@ executing = [ {file = "executing-1.2.0-py2.py3-none-any.whl", hash = "sha256:0314a69e37426e3608aada02473b4161d4caf5a4b244d1d0c48072b8fee7bacc"}, {file = "executing-1.2.0.tar.gz", hash = "sha256:19da64c18d2d851112f09c287f8d3dbbdf725ab0e569077efb6cdcbd3497c107"}, ] +fastapi = [ + {file = "fastapi-0.87.0-py3-none-any.whl", hash = "sha256:254453a2e22f64e2a1b4e1d8baf67d239e55b6c8165c079d25746a5220c81bb4"}, + {file = "fastapi-0.87.0.tar.gz", hash = "sha256:07032e53df9a57165047b4f38731c38bdcc3be5493220471015e2b4b51b486a4"}, +] fastjsonschema = [ {file = "fastjsonschema-2.16.2-py3-none-any.whl", hash = "sha256:21f918e8d9a1a4ba9c22e09574ba72267a6762d47822db9add95f6454e51cc1c"}, {file = "fastjsonschema-2.16.2.tar.gz", hash = "sha256:01e366f25d9047816fe3d288cbfc3e10541daf0af2044763f3d0ade42476da18"}, @@ -1696,6 +1756,10 @@ filelock = [ {file = "filelock-3.8.0-py3-none-any.whl", hash = "sha256:617eb4e5eedc82fc5f47b6d61e4d11cb837c56cb4544e39081099fa17ad109d4"}, {file = "filelock-3.8.0.tar.gz", hash = "sha256:55447caa666f2198c5b6b13a26d2084d26fa5b115c00d065664b2124680c4edc"}, ] +h11 = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] identify = [ {file = "identify-2.5.8-py2.py3-none-any.whl", hash = "sha256:48b7925fe122720088aeb7a6c34f17b27e706b72c61070f27fe3789094233440"}, {file = "identify-2.5.8.tar.gz", hash = "sha256:7a214a10313b9489a0d61467db2856ae8d0b8306fc923e03a9effa53d8aedc58"}, @@ -2303,6 +2367,10 @@ stack-data = [ {file = "stack_data-0.6.1-py3-none-any.whl", hash = "sha256:960cb054d6a1b2fdd9cbd529e365b3c163e8dabf1272e02cfe36b58403cff5c6"}, {file = "stack_data-0.6.1.tar.gz", hash = "sha256:6c9a10eb5f342415fe085db551d673955611afb821551f554d91772415464315"}, ] +starlette = [ + {file = "starlette-0.21.0-py3-none-any.whl", hash = "sha256:0efc058261bbcddeca93cad577efd36d0c8a317e44376bcfc0e097a2b3dc24a7"}, + {file = "starlette-0.21.0.tar.gz", hash = "sha256:b1b52305ee8f7cfc48cde383496f7c11ab897cd7112b33d998b1317dc8ef9027"}, +] 
terminado = [ {file = "terminado-0.17.0-py3-none-any.whl", hash = "sha256:bf6fe52accd06d0661d7611cc73202121ec6ee51e46d8185d489ac074ca457c2"}, {file = "terminado-0.17.0.tar.gz", hash = "sha256:520feaa3aeab8ad64a69ca779be54be9234edb2d0d6567e76c93c2c9a4e6e43f"}, @@ -2371,6 +2439,10 @@ urllib3 = [ {file = "urllib3-1.26.12-py2.py3-none-any.whl", hash = "sha256:b930dd878d5a8afb066a637fbb35144fe7901e3b209d1cd4f524bd0e9deee997"}, {file = "urllib3-1.26.12.tar.gz", hash = "sha256:3fa96cf423e6987997fc326ae8df396db2a8b7c667747d47ddd8ecba91f4a74e"}, ] +uvicorn = [ + {file = "uvicorn-0.19.0-py3-none-any.whl", hash = "sha256:cc277f7e73435748e69e075a721841f7c4a95dba06d12a72fe9874acced16f6f"}, + {file = "uvicorn-0.19.0.tar.gz", hash = "sha256:cf538f3018536edb1f4a826311137ab4944ed741d52aeb98846f52215de57f25"}, +] virtualenv = [ {file = "virtualenv-20.16.7-py3-none-any.whl", hash = "sha256:efd66b00386fdb7dbe4822d172303f40cd05e50e01740b19ea42425cbe653e29"}, {file = "virtualenv-20.16.7.tar.gz", hash = "sha256:8691e3ff9387f743e00f6bb20f70121f5e4f596cae754531f2b3b3a1b1ac696e"}, diff --git a/pyproject.toml b/pyproject.toml index 1ed1d039fea..94093f0a312 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,10 @@ black = "^22.10.0" isort = "^5.10.1" ruff = "^0.0.117" +[tool.poetry.group.dev.dependencies] +fastapi = "^0.87.0" +uvicorn = "^0.19.0" + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" From 414a04c9df5ea49b913e4435e50c43d50f257386 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 17 Nov 2022 17:08:24 +0100 Subject: [PATCH 07/26] feat(wip): add fake method to dump tensor to json Signed-off-by: Sami Jaghouar --- docarray/typing/tensor/tensor.py | 7 ++++++- docarray/typing/tensor/torch_tensor.py | 7 ++++++- tests/units/typing/test_to_json_schema.py | 12 ++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) create mode 100644 tests/units/typing/test_to_json_schema.py diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py index c031207cdaa..df72651e8e6 100644 --- a/docarray/typing/tensor/tensor.py +++ b/docarray/typing/tensor/tensor.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Union, cast import numpy as np @@ -40,6 +40,11 @@ def validate( def from_ndarray(cls: Type[T], value: np.ndarray) -> T: return value.view(cls) + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + # this is needed to dump to json + field_schema.update(type='string', format='uuidhello') + def _to_node_protobuf(self: T, field: str = 'tensor') -> NodeProto: """Convert Document into a NodeProto protobuf message. 
This function should be called when the Document is nested into another Document that need to be diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 4812d52ee38..f9ea4ae0cbd 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Union, cast import numpy as np import torch # type: ignore @@ -51,6 +51,11 @@ def validate( pass # handled below raise ValueError(f'Expected a torch.Tensor, got {type(value)}') + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + # this is needed to dump to json + field_schema.update(type='string', format='uuidhello') + @classmethod def from_native_torch_tensor(cls: Type[T], value: torch.Tensor) -> T: """Create a TorchTensor from a native torch.Tensor diff --git a/tests/units/typing/test_to_json_schema.py b/tests/units/typing/test_to_json_schema.py new file mode 100644 index 00000000000..d9315ef0c11 --- /dev/null +++ b/tests/units/typing/test_to_json_schema.py @@ -0,0 +1,12 @@ +import pytest +from pydantic import schema_json_of + +from docarray.typing import ID, AnyUrl, Embedding, ImageUrl, Tensor, TorchTensor + + +@pytest.mark.parametrize( + 'type_', [Tensor, Embedding, ImageUrl, AnyUrl, ID, TorchTensor] +) +def test_json(type_): + # this test verify that all of our type can be dumped to json + schema_json_of(type_) From 42c6fe48068a41fabc03cd3a885b37ed7debbca9 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 17 Nov 2022 18:01:36 +0100 Subject: [PATCH 08/26] feat(wip): add fastapi test Signed-off-by: Sami Jaghouar --- poetry.lock | 133 +++++++++++++++---- pyproject.toml | 4 +- tests/integrations/externals/test_fastapi.py | 30 +++++ 3 files changed, 140 insertions(+), 27 deletions(-) create mode 100644 tests/integrations/externals/test_fastapi.py diff --git a/poetry.lock b/poetry.lock index 890b85f40a2..3cf3d658f3e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -313,12 +313,50 @@ testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pyt [[package]] name = "h11" -version = "0.14.0" +version = "0.12.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" category = "dev" optional = false +python-versions = ">=3.6" + +[[package]] +name = "httpcore" +version = "0.15.0" +description = "A minimal low-level HTTP client." +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +anyio = ">=3.0.0,<4.0.0" +certifi = "*" +h11 = ">=0.11,<0.13" +sniffio = ">=1.0.0,<2.0.0" + +[package.extras] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] + +[[package]] +name = "httpx" +version = "0.23.0" +description = "The next generation HTTP client." 
+category = "dev" +optional = false python-versions = ">=3.7" +[package.dependencies] +certifi = "*" +httpcore = ">=0.15.0,<0.16.0" +rfc3986 = {version = ">=1.3,<2", extras = ["idna2008"]} +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<13)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] + [[package]] name = "identify" version = "2.5.8" @@ -369,6 +407,14 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)"] testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] +[[package]] +name = "iniconfig" +version = "1.1.1" +description = "iniconfig: brain-dead simple config-ini parsing" +category = "dev" +optional = false +python-versions = "*" + [[package]] name = "ipykernel" version = "6.17.1" @@ -657,14 +703,6 @@ category = "dev" optional = false python-versions = "*" -[[package]] -name = "more-itertools" -version = "9.0.0" -description = "More routines for operating on iterables, beyond itertools" -category = "dev" -optional = false -python-versions = ">=3.7" - [[package]] name = "mypy" version = "0.990" @@ -1142,26 +1180,39 @@ python-versions = ">=3.7" [[package]] name = "pytest" -version = "5.4.3" +version = "6.2.5" description = "pytest: simple powerful testing with Python" category = "dev" optional = false -python-versions = ">=3.5" +python-versions = ">=3.6" [package.dependencies] atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} -attrs = ">=17.4.0" +attrs = ">=19.2.0" colorama = {version = "*", markers = "sys_platform == \"win32\""} -more-itertools = ">=4.0.0" +iniconfig = "*" packaging = "*" -pluggy = ">=0.12,<1.0" -py = ">=1.5.0" -wcwidth = "*" +pluggy = ">=0.12,<2.0" +py = ">=1.8.2" +toml = "*" [package.extras] -checkqa-mypy = ["mypy (==v0.761)"] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.20.2" +description = "Pytest support for asyncio" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +pytest = ">=6.1.0" + +[package.extras] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -1235,6 +1286,20 @@ urllib3 = ">=1.21.1,<1.27" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rfc3986" +version = "1.5.0" +description = "Validating URI References per RFC 3986" +category = "dev" +optional = false +python-versions = "*" + +[package.dependencies] +idna = {version = "*", optional = true, markers = "extra == \"idna2008\""} + +[package.extras] +idna2008 = ["idna"] + [[package]] name = "ruff" version = "0.0.117" @@ -1530,7 +1595,7 @@ torch = ["torch"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "4fa68f08685cb4a52ff58ec6b0dd0de63d6fea01a8159bbbd41b9c057bd38512" +content-hash = "135e9a70bbe29f64fddb9a7e2efae8405e34ba75188cc957b27268d47194dcea" [metadata.files] anyio = [ @@ -1757,8 +1822,16 @@ filelock = [ {file = "filelock-3.8.0.tar.gz", hash = "sha256:55447caa666f2198c5b6b13a26d2084d26fa5b115c00d065664b2124680c4edc"}, ] h11 = [ - {file = "h11-0.14.0-py3-none-any.whl", hash = 
"sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, - {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, + {file = "h11-0.12.0-py3-none-any.whl", hash = "sha256:36a3cb8c0a032f56e2da7084577878a035d3b61d104230d4bd49c0c6b555a9c6"}, + {file = "h11-0.12.0.tar.gz", hash = "sha256:47222cb6067e4a307d535814917cd98fd0a57b6788ce715755fa2b6c28b56042"}, +] +httpcore = [ + {file = "httpcore-0.15.0-py3-none-any.whl", hash = "sha256:1105b8b73c025f23ff7c36468e4432226cbb959176eab66864b8e31c4ee27fa6"}, + {file = "httpcore-0.15.0.tar.gz", hash = "sha256:18b68ab86a3ccf3e7dc0f43598eaddcf472b602aba29f9aa6ab85fe2ada3980b"}, +] +httpx = [ + {file = "httpx-0.23.0-py3-none-any.whl", hash = "sha256:42974f577483e1e932c3cdc3cd2303e883cbfba17fe228b0f63589764d7b9c4b"}, + {file = "httpx-0.23.0.tar.gz", hash = "sha256:f28eac771ec9eb4866d3fb4ab65abd42d38c424739e80c08d8d20570de60b0ef"}, ] identify = [ {file = "identify-2.5.8-py2.py3-none-any.whl", hash = "sha256:48b7925fe122720088aeb7a6c34f17b27e706b72c61070f27fe3789094233440"}, @@ -1776,6 +1849,10 @@ importlib-resources = [ {file = "importlib_resources-5.10.0-py3-none-any.whl", hash = "sha256:ee17ec648f85480d523596ce49eae8ead87d5631ae1551f913c0100b5edd3437"}, {file = "importlib_resources-5.10.0.tar.gz", hash = "sha256:c01b1b94210d9849f286b86bb51bcea7cd56dde0600d8db721d7b81330711668"}, ] +iniconfig = [ + {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, + {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, +] ipykernel = [ {file = "ipykernel-6.17.1-py3-none-any.whl", hash = "sha256:3a9a1b2ad6dbbd5879855aabb4557f08e63fa2208bffed897f03070e2bb436f6"}, {file = "ipykernel-6.17.1.tar.gz", hash = "sha256:e178c1788399f93a459c241fe07c3b810771c607b1fb064a99d2c5d40c90c5d4"}, @@ -1882,10 +1959,6 @@ mistune = [ {file = "mistune-2.0.4-py2.py3-none-any.whl", hash = "sha256:182cc5ee6f8ed1b807de6b7bb50155df7b66495412836b9a74c8fbdfc75fe36d"}, {file = "mistune-2.0.4.tar.gz", hash = "sha256:9ee0a66053e2267aba772c71e06891fa8f1af6d4b01d5e84e267b4570d4d9808"}, ] -more-itertools = [ - {file = "more-itertools-9.0.0.tar.gz", hash = "sha256:5a6257e40878ef0520b1803990e3e22303a41b5714006c32a3fd8304b26ea1ab"}, - {file = "more_itertools-9.0.0-py3-none-any.whl", hash = "sha256:250e83d7e81d0c87ca6bd942e6aeab8cc9daa6096d12c5308f3f92fa5e5c1f41"}, -] mypy = [ {file = "mypy-0.990-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:aaf1be63e0207d7d17be942dcf9a6b641745581fe6c64df9a38deb562a7dbafa"}, {file = "mypy-0.990-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d555aa7f44cecb7ea3c0ac69d58b1a5afb92caa017285a8e9c4efbf0518b61b4"}, @@ -2168,8 +2241,12 @@ pyrsistent = [ {file = "pyrsistent-0.19.2.tar.gz", hash = "sha256:bfa0351be89c9fcbcb8c9879b826f4353be10f58f8a677efab0c017bf7137ec2"}, ] pytest = [ - {file = "pytest-5.4.3-py3-none-any.whl", hash = "sha256:5c0db86b698e8f170ba4582a492248919255fcd4c79b1ee64ace34301fb589a1"}, - {file = "pytest-5.4.3.tar.gz", hash = "sha256:7979331bfcba207414f5e1263b5a0f8f521d0f457318836a7355531ed1a4c7d8"}, + {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, + {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, +] +pytest-asyncio = [ + {file = "pytest-asyncio-0.20.2.tar.gz", hash = 
"sha256:32a87a9836298a881c0ec637ebcc952cfe23a56436bdc0d09d1511941dd8a812"}, + {file = "pytest_asyncio-0.20.2-py3-none-any.whl", hash = "sha256:07e0abf9e6e6b95894a39f688a4a875d63c2128f76c02d03d16ccbc35bcc0f8a"}, ] python-dateutil = [ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, @@ -2325,6 +2402,10 @@ requests = [ {file = "requests-2.28.1-py3-none-any.whl", hash = "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349"}, {file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"}, ] +rfc3986 = [ + {file = "rfc3986-1.5.0-py2.py3-none-any.whl", hash = "sha256:a86d6e1f5b1dc238b218b012df0aa79409667bb209e58da56d0b94704e712a97"}, + {file = "rfc3986-1.5.0.tar.gz", hash = "sha256:270aaf10d87d0d4e095063c65bf3ddbc6ee3d0b226328ce21e036f946e421835"}, +] ruff = [ {file = "ruff-0.0.117-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:cb274e0447e91a1a7844b85cd0ef243d32732a4772140129f33bbc8891ca8577"}, {file = "ruff-0.0.117-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:a670ccea678f1ddd9f9cac1a5681be8bf3b33010e56058d16b270dfb9b893b95"}, diff --git a/pyproject.toml b/pyproject.toml index 94093f0a312..c61f6b460e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ common = ["protobuf"] torch = ["torch"] [tool.poetry.dev-dependencies] -pytest = "^5.2" +pytest = "^6.1" pre-commit = "^2.20.0" jupyterlab = "^3.5.0" mypy = "^0.990" @@ -29,6 +29,8 @@ ruff = "^0.0.117" [tool.poetry.group.dev.dependencies] fastapi = "^0.87.0" uvicorn = "^0.19.0" +httpx = "^0.23.0" +pytest-asyncio = "^0.20.2" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py new file mode 100644 index 00000000000..a1b16164452 --- /dev/null +++ b/tests/integrations/externals/test_fastapi.py @@ -0,0 +1,30 @@ +import pytest +from fastapi import FastAPI +from httpx import AsyncClient + +from docarray import Document, Image, Text + + +class Mmdoc(Document): + img: Image + text: Text + title: str + + +@pytest.mark.asyncio +async def test_fast_api(): + + app = FastAPI() + + @app.post("/doc/") + async def create_item(doc: Mmdoc): + return doc + + async with AsyncClient(app=app, base_url="http://test") as ac: + # response = await ac.get("/doc") + response2 = await ac.get("/docs") + + # assert response.status_code == 200 + assert response2.status_code == 200 + + # assert response.json() == {"message": "Tomato"} From 56497b9d537a52e190a5cefde76c7459951fe6e9 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Fri, 18 Nov 2022 15:21:12 +0100 Subject: [PATCH 09/26] feat: add json dump for type and document Signed-off-by: Sami Jaghouar --- docarray/document/document.py | 17 +++++++++++++++++ docarray/typing/tensor/tensor.py | 9 ++++++++- docarray/typing/tensor/torch_tensor.py | 7 +++++++ tests/units/document/test_to_json.py | 21 +++++++++++++++++++++ tests/units/typing/conftest.py | 20 ++++++++++++++++++++ tests/units/typing/test_embedding.py | 17 ++++++++++++++--- tests/units/typing/test_id.py | 11 +++++++++++ tests/units/typing/test_tensor.py | 17 ++++++++++++++--- tests/units/typing/test_to_json_schema.py | 12 ------------ tests/units/typing/url/test_any_url.py | 17 ++++++++++++++--- tests/units/typing/url/test_image_url.py | 13 ++++++++++++- 11 files changed, 138 insertions(+), 23 deletions(-) create mode 100644 tests/units/document/test_to_json.py create 
mode 100644 tests/units/typing/conftest.py delete mode 100644 tests/units/typing/test_to_json_schema.py diff --git a/docarray/document/document.py b/docarray/document/document.py index f0eb363a624..57348103d19 100644 --- a/docarray/document/document.py +++ b/docarray/document/document.py @@ -1,4 +1,5 @@ import os +from json import JSONEncoder from typing import Type from pydantic import BaseModel, Field @@ -10,6 +11,19 @@ from .mixins import ProtoMixin +class _DocumentJsonEncoder(JSONEncoder): + """ + This is a custom JSONEncoder that will call the + _to_json_compatible method of type. This Encoder will be + used when calling doc.json() + """ + + def default(self, obj): + if hasattr(obj, '_to_json_compatible'): + return obj._to_json_compatible() + return JSONEncoder.default(self, obj) + + class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): """ The base class for Document @@ -17,6 +31,9 @@ class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): id: ID = Field(default_factory=lambda: ID.validate(os.urandom(16).hex())) + class Config: + json_loads = _DocumentJsonEncoder + @classmethod def _get_nested_document_class(cls, field: str) -> Type['BaseDocument']: """ diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py index df72651e8e6..1f0464b46c7 100644 --- a/docarray/typing/tensor/tensor.py +++ b/docarray/typing/tensor/tensor.py @@ -45,8 +45,15 @@ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: # this is needed to dump to json field_schema.update(type='string', format='uuidhello') + def _to_json_compatible(self): + """ + Convert tensor into a json compatible object + :return: a list representation of the tensor + """ + return self.tolist() + def _to_node_protobuf(self: T, field: str = 'tensor') -> NodeProto: - """Convert Document into a NodeProto protobuf message. This function should + """Convert itself into a NodeProto protobuf message. 
This function should be called when the Document is nested into another Document that need to be converted into a protobuf :param field: field in which to store the content in the node proto diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index f9ea4ae0cbd..e5c5795a0e3 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -56,6 +56,13 @@ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: # this is needed to dump to json field_schema.update(type='string', format='uuidhello') + def _to_json_compatible(self): + """ + Convert tensor into a json compatible object + :return: a list representation of the tensor + """ + return self.tolist() + @classmethod def from_native_torch_tensor(cls: Type[T], value: torch.Tensor) -> T: """Create a TorchTensor from a native torch.Tensor diff --git a/tests/units/document/test_to_json.py b/tests/units/document/test_to_json.py new file mode 100644 index 00000000000..d0afd742bb9 --- /dev/null +++ b/tests/units/document/test_to_json.py @@ -0,0 +1,21 @@ +import numpy as np +import torch + +from docarray.document import BaseDocument +from docarray.typing import AnyUrl, Tensor, TorchTensor + + +def test_to_json(): + class Mmdoc(BaseDocument): + img: Tensor + url: AnyUrl + txt: str + torch_tensor: TorchTensor + + doc = Mmdoc( + img=np.zeros((3, 224, 224)), + url='http://doccaray.io', + txt='hello', + torch_tensor=torch.zeros(3, 224, 224), + ) + doc.json() diff --git a/tests/units/typing/conftest.py b/tests/units/typing/conftest.py new file mode 100644 index 00000000000..01dfb51a998 --- /dev/null +++ b/tests/units/typing/conftest.py @@ -0,0 +1,20 @@ +from json import JSONEncoder + +import pytest + + +@pytest.fixture(scope='module') +def json_encoder(): + class TestJsonEncoder(JSONEncoder): + """ + This is a custom JSONEncoder that will call the + _to_json_compatible method of type. 
This Encoder will be + used when calling doc.json() + """ + + def default(self, obj): + if hasattr(obj, '_to_json_compatible'): + return obj._to_json_compatible() + return JSONEncoder.default(self, obj) + + return TestJsonEncoder diff --git a/tests/units/typing/test_embedding.py b/tests/units/typing/test_embedding.py index 61d31abe079..3ea188d4801 100644 --- a/tests/units/typing/test_embedding.py +++ b/tests/units/typing/test_embedding.py @@ -1,11 +1,22 @@ +import json + import numpy as np -from pydantic.tools import parse_obj_as +from pydantic.tools import parse_obj_as, schema_json_of from docarray.typing import Embedding def test_proto_embedding(): - uri = parse_obj_as(Embedding, np.zeros((3, 224, 224))) + embedding = parse_obj_as(Embedding, np.zeros((3, 224, 224))) + + embedding._to_node_protobuf() + + +def test_json_schema(): + schema_json_of(Embedding) + - uri._to_node_protobuf() +def test_dump_json(json_encoder): + tensor = parse_obj_as(Embedding, np.zeros((3, 224, 224))) + json.dumps(tensor, cls=json_encoder) diff --git a/tests/units/typing/test_id.py b/tests/units/typing/test_id.py index 5c3476bc82a..1f13e66e09a 100644 --- a/tests/units/typing/test_id.py +++ b/tests/units/typing/test_id.py @@ -1,6 +1,8 @@ +import json from uuid import UUID import pytest +from pydantic import schema_json_of from pydantic.tools import parse_obj_as from docarray.typing import ID @@ -14,3 +16,12 @@ def test_id_validation(id): parsed_id = parse_obj_as(ID, id) assert parsed_id == str(id) + + +def test_json_schema(): + schema_json_of(ID) + + +def test_dump_json(json_encoder): + id = parse_obj_as(ID, 1234) + json.dumps(id, cls=json_encoder) diff --git a/tests/units/typing/test_tensor.py b/tests/units/typing/test_tensor.py index 9bace9e8841..5d5b59e9306 100644 --- a/tests/units/typing/test_tensor.py +++ b/tests/units/typing/test_tensor.py @@ -1,11 +1,22 @@ +import json + import numpy as np -from pydantic.tools import parse_obj_as +from pydantic.tools import parse_obj_as, schema_json_of from docarray.typing import Tensor def test_proto_tensor(): - uri = parse_obj_as(Tensor, np.zeros((3, 224, 224))) + tensor = parse_obj_as(Tensor, np.zeros((3, 224, 224))) + + tensor._to_node_protobuf() + + +def test_json_schema(): + schema_json_of(Tensor) + - uri._to_node_protobuf() +def test_dump_json(json_encoder): + tensor = parse_obj_as(Tensor, np.zeros((3, 224, 224))) + json.dumps(tensor, cls=json_encoder) diff --git a/tests/units/typing/test_to_json_schema.py b/tests/units/typing/test_to_json_schema.py deleted file mode 100644 index d9315ef0c11..00000000000 --- a/tests/units/typing/test_to_json_schema.py +++ /dev/null @@ -1,12 +0,0 @@ -import pytest -from pydantic import schema_json_of - -from docarray.typing import ID, AnyUrl, Embedding, ImageUrl, Tensor, TorchTensor - - -@pytest.mark.parametrize( - 'type_', [Tensor, Embedding, ImageUrl, AnyUrl, ID, TorchTensor] -) -def test_json(type_): - # this test verify that all of our type can be dumped to json - schema_json_of(type_) diff --git a/tests/units/typing/url/test_any_url.py b/tests/units/typing/url/test_any_url.py index ad593b58519..38f85f6c2f7 100644 --- a/tests/units/typing/url/test_any_url.py +++ b/tests/units/typing/url/test_any_url.py @@ -1,10 +1,21 @@ -from pydantic.tools import parse_obj_as +import json -from docarray.typing import ImageUrl +from pydantic.tools import parse_obj_as, schema_json_of + +from docarray.typing import AnyUrl def test_proto_any_url(): - uri = parse_obj_as(ImageUrl, 'http://jina.ai/img.png') + uri = parse_obj_as(AnyUrl, 
'http://jina.ai/img.png') uri._to_node_protobuf() + + +def test_json_schema(): + schema_json_of(AnyUrl) + + +def test_dump_json(json_encoder): + url = parse_obj_as(AnyUrl, 'http://jina.ai/img.png') + json.dumps(url, cls=json_encoder) diff --git a/tests/units/typing/url/test_image_url.py b/tests/units/typing/url/test_image_url.py index 37fcf525d23..e7f3609c4c1 100644 --- a/tests/units/typing/url/test_image_url.py +++ b/tests/units/typing/url/test_image_url.py @@ -1,5 +1,7 @@ +import json + import numpy as np -from pydantic.tools import parse_obj_as +from pydantic.tools import parse_obj_as, schema_json_of from docarray.typing import ImageUrl @@ -17,3 +19,12 @@ def test_proto_image_url(): uri = parse_obj_as(ImageUrl, 'http://jina.ai/img.png') uri._to_node_protobuf() + + +def test_json_schema(): + schema_json_of(ImageUrl) + + +def test_dump_json(json_encoder): + url = parse_obj_as(ImageUrl, 'http://jina.ai/img.png') + json.dumps(url, cls=json_encoder) From 945a72ae27c6bca3f7f0e49de09675134b9ca4f8 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 21 Nov 2022 20:09:57 +0100 Subject: [PATCH 10/26] feat: add json compatible with orjson Signed-off-by: Sami Jaghouar --- docarray/document/document.py | 27 +++++++---- docarray/typing/tensor/tensor.py | 28 ++++++++++-- docarray/typing/tensor/torch_tensor.py | 29 ++++++++++++ poetry.lock | 61 ++++++++++++++++++++++++- pyproject.toml | 1 + tests/units/typing/test_tensor.py | 10 ++++ tests/units/typing/test_torch_tensor.py | 31 +++++++++++++ 7 files changed, 173 insertions(+), 14 deletions(-) create mode 100644 tests/units/typing/test_torch_tensor.py diff --git a/docarray/document/document.py b/docarray/document/document.py index 57348103d19..adf2f865b0a 100644 --- a/docarray/document/document.py +++ b/docarray/document/document.py @@ -1,7 +1,7 @@ import os -from json import JSONEncoder from typing import Type +import orjson from pydantic import BaseModel, Field from docarray.document.abstract_document import AbstractDocument @@ -11,17 +11,23 @@ from .mixins import ProtoMixin -class _DocumentJsonEncoder(JSONEncoder): +def _default_orjson(obj): """ - This is a custom JSONEncoder that will call the - _to_json_compatible method of type. This Encoder will be - used when calling doc.json() + default option for orjson dumps. It will call _to_json_compatible + from docarray typing object that expose such method. 
+ :param obj: + :return: return a json compatible object """ - def default(self, obj): - if hasattr(obj, '_to_json_compatible'): - return obj._to_json_compatible() - return JSONEncoder.default(self, obj) + if getattr(obj, '_to_json_compatible'): + return obj._to_json_compatible() + + +def _orjson_dumps(v, *, default): + # orjson.dumps returns bytes, to match standard json.dumps we need to decode + return orjson.dumps( + v, default=_default_orjson, option=orjson.OPT_SERIALIZE_NUMPY + ).decode() class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): @@ -32,7 +38,8 @@ class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): id: ID = Field(default_factory=lambda: ID.validate(os.urandom(16).hex())) class Config: - json_loads = _DocumentJsonEncoder + json_loads = orjson.loads + json_dumps = _orjson_dumps @classmethod def _get_nested_document_class(cls, field: str) -> Type['BaseDocument']: diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py index 1f0464b46c7..8267a132b58 100644 --- a/docarray/typing/tensor/tensor.py +++ b/docarray/typing/tensor/tensor.py @@ -22,7 +22,10 @@ def __get_validators__(cls): @classmethod def validate( - cls: Type[T], value: Union[T, Any], field: 'ModelField', config: 'BaseConfig' + cls: Type[T], + value: Union[T, np.ndarray, Any], + field: 'ModelField', + config: 'BaseConfig', ) -> T: if isinstance(value, np.ndarray): return cls.from_ndarray(value) @@ -45,12 +48,31 @@ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: # this is needed to dump to json field_schema.update(type='string', format='uuidhello') - def _to_json_compatible(self): + def _to_json_compatible(self) -> np.ndarray: """ Convert tensor into a json compatible object :return: a list representation of the tensor """ - return self.tolist() + return self.unwrap() + + def unwrap(self) -> np.ndarray: + """ + Return the original ndarray without any memory copy + + EXAMPLE USAGE + .. code-block:: python + from docarray.typing import Tensor + import numpy as np + + t = Tensor.validate(np.zeros((3, 224, 224)), None, None) + # here t is a docarray Tensor + t = t.unwrap() + # here t is a pure np.ndarray + + + :return: a numpy ndarray + """ + return self.view(np.ndarray) def _to_node_protobuf(self: T, field: str = 'tensor') -> NodeProto: """Convert itself into a NodeProto protobuf message. This function should diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index e5c5795a0e3..4a6f55e4862 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -63,6 +63,35 @@ def _to_json_compatible(self): """ return self.tolist() + def unwrap(self) -> torch.Tensor: + """ + Return the original ndarray without any memory copy + + EXAMPLE USAGE + .. 
code-block:: python + from docarray.typing import TorchTensor + import torch + + t = Tensor.validate(torch.zeros(3, 224, 224), None, None) + # here t is a docarray Tensor + t = t.unwrap() + # here t is a pure torch Tensor + + + :return: a torch Tensor + """ + ## might need to check device later + value = torch.tensor(self) + value.__class__ = torch.Tensor + return value + + def _to_json_compatible(self) -> np.ndarray: + """ + Convert torch Tensor into a json compatible object + :return: a list representation of the torch tensor + """ + return self.numpy() ## might need to check device later + @classmethod def from_native_torch_tensor(cls: Type[T], value: torch.Tensor) -> T: """Create a TorchTensor from a native torch.Tensor diff --git a/poetry.lock b/poetry.lock index 3cf3d658f3e..238498fbd46 100644 --- a/poetry.lock +++ b/poetry.lock @@ -952,6 +952,14 @@ python-versions = ">=3" setuptools = "*" wheel = "*" +[[package]] +name = "orjson" +version = "3.8.2" +description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" +category = "main" +optional = false +python-versions = ">=3.7" + [[package]] name = "packaging" version = "21.3" @@ -1595,7 +1603,7 @@ torch = ["torch"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "135e9a70bbe29f64fddb9a7e2efae8405e34ba75188cc957b27268d47194dcea" +content-hash = "70d89905c1a4f89de69be2bd25f7ef20ad9fe391920593a1b79df733e3b6c941" [metadata.files] anyio = [ @@ -2074,6 +2082,57 @@ nvidia-cudnn-cu11 = [ {file = "nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7"}, {file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"}, ] +orjson = [ + {file = "orjson-3.8.2-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:43e69b360c2851b45c7dbab3b95f7fa8469df73fab325a683f7389c4db63aa71"}, + {file = "orjson-3.8.2-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:64c5da5c9679ef3d85e9bbcbb62f4ccdc1f1975780caa20f2ec1e37b4da6bd36"}, + {file = "orjson-3.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c632a2157fa9ec098d655287e9e44809615af99837c49f53d96bfbca453c5bd"}, + {file = "orjson-3.8.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f63da6309c282a2b58d4a846f0717f6440356b4872838b9871dc843ed1fe2b38"}, + {file = "orjson-3.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c9be25c313ba2d5478829d949165445c3bd36c62e07092b4ba8dbe5426574d1"}, + {file = "orjson-3.8.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:4bcce53e9e088f82633f784f79551fcd7637943ab56c51654aaf9d4c1d5cfa54"}, + {file = "orjson-3.8.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:33edb5379c6e6337f9383c85fe4080ce3aa1057cc2ce29345b7239461f50cbd6"}, + {file = "orjson-3.8.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:da35d347115758bbc8bfaf39bb213c42000f2a54e3f504c84374041d20835cd6"}, + {file = "orjson-3.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d755d94a90a941b91b4d39a6b02e289d8ba358af2d1a911edf266be7942609dc"}, + {file = "orjson-3.8.2-cp310-none-win_amd64.whl", hash = "sha256:7ea96923e26390b2142602ebb030e2a4db9351134696e0b219e5106bddf9b48e"}, + {file = "orjson-3.8.2-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:a0d89de876e6f1cef917a2338378a60a98584e1c2e1c67781e20b6ed1c512478"}, + {file = 
"orjson-3.8.2-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:8d47e7592fe938aec898eb22ea4946298c018133df084bc78442ff18e2c6347c"}, + {file = "orjson-3.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3d9f1043f618d0c64228aab9711e5bd822253c50b6c56223951e32b51f81d62"}, + {file = "orjson-3.8.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed10600e8b08f1e87b656ad38ab316191ce94f2c9adec57035680c0dc9e93c81"}, + {file = "orjson-3.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99c49e49a04bf61fee7aaea6d92ac2b1fcf6507aea894bbdf3fbb25fe792168c"}, + {file = "orjson-3.8.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:1463674f8efe6984902473d7b5ce3edf444c1fcd09dc8aa4779638a28fb9ca01"}, + {file = "orjson-3.8.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:c1ef75f1d021d817e5c60a42da0b4b7e3123b1b37415260b8415666ddacc7cd7"}, + {file = "orjson-3.8.2-cp311-none-win_amd64.whl", hash = "sha256:b6007e1ac8564b13b2521720929e8bb3ccd3293d9fdf38f28728dcc06db6248f"}, + {file = "orjson-3.8.2-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:a02c13ae523221576b001071354380e277346722cc6b7fdaacb0fd6db5154b3e"}, + {file = "orjson-3.8.2-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:fa2e565cf8ffdb37ce1887bd1592709ada7f701e61aa4b1e710be94b0aecbab4"}, + {file = "orjson-3.8.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1d8864288f7c5fccc07b43394f83b721ddc999f25dccfb5d0651671a76023f5"}, + {file = "orjson-3.8.2-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1874c05d0bb994601fa2d51605cb910d09343c6ebd36e84a573293523fab772a"}, + {file = "orjson-3.8.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:349387ed6989e5db22e08c9af8d7ca14240803edc50de451d48d41a0e7be30f6"}, + {file = "orjson-3.8.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:4e42b19619d6e97e201053b865ca4e62a48da71165f4081508ada8e1b91c6a30"}, + {file = "orjson-3.8.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:bc112c17e607c59d1501e72afb44226fa53d947d364aed053f0c82d153e29616"}, + {file = "orjson-3.8.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:6fda669211f2ed1fc2c8130187ec90c96b4f77b6a250004e666d2ef8ed524e5f"}, + {file = "orjson-3.8.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:aebd4e80fea0f20578fd0452908b9206a6a0d5ae9f5c99b6e665bbcd989e56cd"}, + {file = "orjson-3.8.2-cp37-none-win_amd64.whl", hash = "sha256:9f3cd0394eb6d265beb2a1572b5663bc910883ddbb5cdfbcb660f5a0444e7fd8"}, + {file = "orjson-3.8.2-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:74e7d54d11b3da42558d69a23bf92c2c48fabf69b38432d5eee2c5b09cd4c433"}, + {file = "orjson-3.8.2-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:8cbadc9be748a823f9c743c7631b1ee95d3925a9c0b21de4e862a1d57daa10ec"}, + {file = "orjson-3.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a07d5a8c69a2947d9554a00302734fe3d8516415c8b280963c92bc1033477890"}, + {file = "orjson-3.8.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b364ea01d1b71b9f97bf97af9eb79ebee892df302e127a9e2e4f8eaa74d6b98"}, + {file = "orjson-3.8.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b98a8c825a59db94fbe8e0cce48618624c5a6fb1436467322d90667c08a0bf80"}, + {file = "orjson-3.8.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = 
"sha256:ab63103f60b516c0fce9b62cb4773f689a82ab56e19ef2387b5a3182f80c0d78"}, + {file = "orjson-3.8.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:73ab3f4288389381ae33ab99f914423b69570c88d626d686764634d5e0eeb909"}, + {file = "orjson-3.8.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2ab3fd8728e12c36e20c6d9d70c9e15033374682ce5acb6ed6a08a80dacd254d"}, + {file = "orjson-3.8.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cde11822cf71a7f0daaa84223249b2696a2b6cda7fa587e9fd762dff1a8848e4"}, + {file = "orjson-3.8.2-cp38-none-win_amd64.whl", hash = "sha256:b14765ea5aabfeab1a194abfaa0be62c9fee6480a75ac8c6974b4eeede3340b4"}, + {file = "orjson-3.8.2-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:6068a27d59d989d4f2864c2fc3440eb7126a0cfdfaf8a4ad136b0ffd932026ae"}, + {file = "orjson-3.8.2-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:6bf36fa759a1b941fc552ad76b2d7fb10c1d2a20c056be291ea45eb6ae1da09b"}, + {file = "orjson-3.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f436132e62e647880ca6988974c8e3165a091cb75cbed6c6fd93e931630c22fa"}, + {file = "orjson-3.8.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3ecd8936259a5920b52a99faf62d4efeb9f5e25a0aacf0cce1e9fa7c37af154f"}, + {file = "orjson-3.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c13114b345cda33644f64e92fe5d8737828766cf02fbbc7d28271a95ea546832"}, + {file = "orjson-3.8.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6e43cdc3ddf96bdb751b748b1984b701125abacca8fc2226b808d203916e8cba"}, + {file = "orjson-3.8.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ee39071da2026b11e4352d6fc3608a7b27ee14bc699fd240f4e604770bc7a255"}, + {file = "orjson-3.8.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:1c3833976ebbeb3b5b6298cb22e23bf18453f6b80802103b7d08f7dd8a61611d"}, + {file = "orjson-3.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:b9a34519d3d70935e1cd3797fbed8fbb6f61025182bea0140ca84d95b6f8fbe5"}, + {file = "orjson-3.8.2-cp39-none-win_amd64.whl", hash = "sha256:2734086d9a3dd9591c4be7d05aff9beccc086796d3f243685e56b7973ebac5bc"}, + {file = "orjson-3.8.2.tar.gz", hash = "sha256:a2fb95a45031ccf278e44341027b3035ab99caa32aa173279b1f0a06324f434b"}, +] packaging = [ {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, diff --git a/pyproject.toml b/pyproject.toml index c61f6b460e8..a943bf82c79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ pydantic = "^1.10.2" numpy = "^1.23.4" protobuf = { version = "^4.21.9", optional = true } torch = { version = "^1.0.0", optional = true } +orjson = "^3.8.2" [tool.poetry.extras] common = ["protobuf"] diff --git a/tests/units/typing/test_tensor.py b/tests/units/typing/test_tensor.py index 5d5b59e9306..899b6f97bef 100644 --- a/tests/units/typing/test_tensor.py +++ b/tests/units/typing/test_tensor.py @@ -20,3 +20,13 @@ def test_json_schema(): def test_dump_json(json_encoder): tensor = parse_obj_as(Tensor, np.zeros((3, 224, 224))) json.dumps(tensor, cls=json_encoder) + + +def test_unwrap(): + tensor = parse_obj_as(Tensor, np.zeros((3, 224, 224))) + ndarray = tensor.unwrap() + + assert not isinstance(ndarray, Tensor) + assert isinstance(ndarray, np.ndarray) + assert isinstance(tensor, Tensor) + assert (ndarray == np.zeros((3, 224, 224))).all() 
diff --git a/tests/units/typing/test_torch_tensor.py b/tests/units/typing/test_torch_tensor.py new file mode 100644 index 00000000000..07c87038416 --- /dev/null +++ b/tests/units/typing/test_torch_tensor.py @@ -0,0 +1,31 @@ +import torch +from pydantic.tools import parse_obj_as, schema_json_of + +from docarray.typing import TorchTensor + + +def test_proto_tensor(): + + tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224)) + + tensor._to_node_protobuf() + + +def test_json_schema(): + schema_json_of(TorchTensor) + + +# def test_dump_json(json_encoder): +# tensor = parse_obj_as(Tensor, torch.zeros(3, 224, 224)) +# json.dumps(tensor, cls=json_encoder) + + +def test_unwrap(): + tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224)) + ndarray = tensor.unwrap() + + assert not isinstance(ndarray, TorchTensor) + assert isinstance(tensor, TorchTensor) + assert isinstance(ndarray, torch.Tensor) + + assert (ndarray == torch.zeros(3, 224, 224)).all() From 83236cd9bc7f553dde4c5de010405519c4bf4777 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 21 Nov 2022 20:29:53 +0100 Subject: [PATCH 11/26] refactor: clean tests Signed-off-by: Sami Jaghouar --- docarray/document/document.py | 25 +++--------------------- docarray/document/io/__init__.py | 0 docarray/document/io/json.py | 20 +++++++++++++++++++ tests/units/typing/conftest.py | 20 ------------------- tests/units/typing/test_embedding.py | 7 +++---- tests/units/typing/test_id.py | 6 +++--- tests/units/typing/test_tensor.py | 7 +++---- tests/units/typing/test_torch_tensor.py | 7 ++++--- tests/units/typing/url/test_any_url.py | 7 +++---- tests/units/typing/url/test_image_url.py | 7 +++---- 10 files changed, 42 insertions(+), 64 deletions(-) create mode 100644 docarray/document/io/__init__.py create mode 100644 docarray/document/io/json.py delete mode 100644 tests/units/typing/conftest.py diff --git a/docarray/document/document.py b/docarray/document/document.py index adf2f865b0a..17497414429 100644 --- a/docarray/document/document.py +++ b/docarray/document/document.py @@ -6,29 +6,10 @@ from docarray.document.abstract_document import AbstractDocument from docarray.document.base_node import BaseNode +from docarray.document.io.json import orjson_dumps +from docarray.document.mixins import ProtoMixin from docarray.typing import ID -from .mixins import ProtoMixin - - -def _default_orjson(obj): - """ - default option for orjson dumps. It will call _to_json_compatible - from docarray typing object that expose such method. 
- :param obj: - :return: return a json compatible object - """ - - if getattr(obj, '_to_json_compatible'): - return obj._to_json_compatible() - - -def _orjson_dumps(v, *, default): - # orjson.dumps returns bytes, to match standard json.dumps we need to decode - return orjson.dumps( - v, default=_default_orjson, option=orjson.OPT_SERIALIZE_NUMPY - ).decode() - class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): """ @@ -39,7 +20,7 @@ class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): class Config: json_loads = orjson.loads - json_dumps = _orjson_dumps + json_dumps = orjson_dumps @classmethod def _get_nested_document_class(cls, field: str) -> Type['BaseDocument']: diff --git a/docarray/document/io/__init__.py b/docarray/document/io/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/document/io/json.py b/docarray/document/io/json.py new file mode 100644 index 00000000000..95961189759 --- /dev/null +++ b/docarray/document/io/json.py @@ -0,0 +1,20 @@ +import orjson + + +def _default_orjson(obj): + """ + default option for orjson dumps. It will call _to_json_compatible + from docarray typing object that expose such method. + :param obj: + :return: return a json compatible object + """ + + if getattr(obj, '_to_json_compatible'): + return obj._to_json_compatible() + + +def orjson_dumps(v, *, default=None): + # orjson.dumps returns bytes, to match standard json.dumps we need to decode + return orjson.dumps( + v, default=_default_orjson, option=orjson.OPT_SERIALIZE_NUMPY + ).decode() diff --git a/tests/units/typing/conftest.py b/tests/units/typing/conftest.py deleted file mode 100644 index 01dfb51a998..00000000000 --- a/tests/units/typing/conftest.py +++ /dev/null @@ -1,20 +0,0 @@ -from json import JSONEncoder - -import pytest - - -@pytest.fixture(scope='module') -def json_encoder(): - class TestJsonEncoder(JSONEncoder): - """ - This is a custom JSONEncoder that will call the - _to_json_compatible method of type. 
This Encoder will be - used when calling doc.json() - """ - - def default(self, obj): - if hasattr(obj, '_to_json_compatible'): - return obj._to_json_compatible() - return JSONEncoder.default(self, obj) - - return TestJsonEncoder diff --git a/tests/units/typing/test_embedding.py b/tests/units/typing/test_embedding.py index 3ea188d4801..57334dd49eb 100644 --- a/tests/units/typing/test_embedding.py +++ b/tests/units/typing/test_embedding.py @@ -1,8 +1,7 @@ -import json - import numpy as np from pydantic.tools import parse_obj_as, schema_json_of +from docarray.document.io.json import orjson_dumps from docarray.typing import Embedding @@ -17,6 +16,6 @@ def test_json_schema(): schema_json_of(Embedding) -def test_dump_json(json_encoder): +def test_dump_json(): tensor = parse_obj_as(Embedding, np.zeros((3, 224, 224))) - json.dumps(tensor, cls=json_encoder) + orjson_dumps(tensor) diff --git a/tests/units/typing/test_id.py b/tests/units/typing/test_id.py index 1f13e66e09a..5919b351209 100644 --- a/tests/units/typing/test_id.py +++ b/tests/units/typing/test_id.py @@ -1,10 +1,10 @@ -import json from uuid import UUID import pytest from pydantic import schema_json_of from pydantic.tools import parse_obj_as +from docarray.document.io.json import orjson_dumps from docarray.typing import ID @@ -22,6 +22,6 @@ def test_json_schema(): schema_json_of(ID) -def test_dump_json(json_encoder): +def test_dump_json(): id = parse_obj_as(ID, 1234) - json.dumps(id, cls=json_encoder) + orjson_dumps(id) diff --git a/tests/units/typing/test_tensor.py b/tests/units/typing/test_tensor.py index 899b6f97bef..32086f0e402 100644 --- a/tests/units/typing/test_tensor.py +++ b/tests/units/typing/test_tensor.py @@ -1,8 +1,7 @@ -import json - import numpy as np from pydantic.tools import parse_obj_as, schema_json_of +from docarray.document.io.json import orjson_dumps from docarray.typing import Tensor @@ -17,9 +16,9 @@ def test_json_schema(): schema_json_of(Tensor) -def test_dump_json(json_encoder): +def test_dump_json(): tensor = parse_obj_as(Tensor, np.zeros((3, 224, 224))) - json.dumps(tensor, cls=json_encoder) + orjson_dumps(tensor) def test_unwrap(): diff --git a/tests/units/typing/test_torch_tensor.py b/tests/units/typing/test_torch_tensor.py index 07c87038416..addd7297203 100644 --- a/tests/units/typing/test_torch_tensor.py +++ b/tests/units/typing/test_torch_tensor.py @@ -1,6 +1,7 @@ import torch from pydantic.tools import parse_obj_as, schema_json_of +from docarray.document.io.json import orjson_dumps from docarray.typing import TorchTensor @@ -15,9 +16,9 @@ def test_json_schema(): schema_json_of(TorchTensor) -# def test_dump_json(json_encoder): -# tensor = parse_obj_as(Tensor, torch.zeros(3, 224, 224)) -# json.dumps(tensor, cls=json_encoder) +def test_dump_json(): + tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224)) + orjson_dumps(tensor) def test_unwrap(): diff --git a/tests/units/typing/url/test_any_url.py b/tests/units/typing/url/test_any_url.py index 38f85f6c2f7..1cd4988437d 100644 --- a/tests/units/typing/url/test_any_url.py +++ b/tests/units/typing/url/test_any_url.py @@ -1,7 +1,6 @@ -import json - from pydantic.tools import parse_obj_as, schema_json_of +from docarray.document.io.json import orjson_dumps from docarray.typing import AnyUrl @@ -16,6 +15,6 @@ def test_json_schema(): schema_json_of(AnyUrl) -def test_dump_json(json_encoder): +def test_dump_json(): url = parse_obj_as(AnyUrl, 'http://jina.ai/img.png') - json.dumps(url, cls=json_encoder) + orjson_dumps(url) diff --git 
a/tests/units/typing/url/test_image_url.py b/tests/units/typing/url/test_image_url.py index e7f3609c4c1..07207bb31f4 100644 --- a/tests/units/typing/url/test_image_url.py +++ b/tests/units/typing/url/test_image_url.py @@ -1,8 +1,7 @@ -import json - import numpy as np from pydantic.tools import parse_obj_as, schema_json_of +from docarray.document.io.json import orjson_dumps from docarray.typing import ImageUrl @@ -25,6 +24,6 @@ def test_json_schema(): schema_json_of(ImageUrl) -def test_dump_json(json_encoder): +def test_dump_json(): url = parse_obj_as(ImageUrl, 'http://jina.ai/img.png') - json.dumps(url, cls=json_encoder) + orjson_dumps(url) From f1a4d7bf005e39a95431de356871cbe83615a487 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 21 Nov 2022 20:31:18 +0100 Subject: [PATCH 12/26] fix: remove duplicate Signed-off-by: Sami Jaghouar --- docarray/typing/tensor/torch_tensor.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 4a6f55e4862..389928e6514 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -56,12 +56,12 @@ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: # this is needed to dump to json field_schema.update(type='string', format='uuidhello') - def _to_json_compatible(self): + def _to_json_compatible(self) -> np.ndarray: """ - Convert tensor into a json compatible object - :return: a list representation of the tensor + Convert torch Tensor into a json compatible object + :return: a list representation of the torch tensor """ - return self.tolist() + return self.numpy() ## might need to check device later def unwrap(self) -> torch.Tensor: """ @@ -85,13 +85,6 @@ def unwrap(self) -> torch.Tensor: value.__class__ = torch.Tensor return value - def _to_json_compatible(self) -> np.ndarray: - """ - Convert torch Tensor into a json compatible object - :return: a list representation of the torch tensor - """ - return self.numpy() ## might need to check device later - @classmethod def from_native_torch_tensor(cls: Type[T], value: torch.Tensor) -> T: """Create a TorchTensor from a native torch.Tensor From c983246ce4002ff36542ae14b0b80b4317ca8d6e Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 21 Nov 2022 21:27:01 +0100 Subject: [PATCH 13/26] fix: better json schema for tensor Signed-off-by: Sami Jaghouar --- docarray/typing/tensor/tensor.py | 2 +- docarray/typing/tensor/torch_tensor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py index 8267a132b58..3e997fe6577 100644 --- a/docarray/typing/tensor/tensor.py +++ b/docarray/typing/tensor/tensor.py @@ -46,7 +46,7 @@ def from_ndarray(cls: Type[T], value: np.ndarray) -> T: @classmethod def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: # this is needed to dump to json - field_schema.update(type='string', format='uuidhello') + field_schema.update(type='string', format='tensor') def _to_json_compatible(self) -> np.ndarray: """ diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 389928e6514..56ba375e88a 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -54,7 +54,7 @@ def validate( @classmethod def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: # this is needed to dump to json - field_schema.update(type='string', format='uuidhello') + 
field_schema.update(type='string', format='tensor') def _to_json_compatible(self) -> np.ndarray: """ From 3873a9ac1b858a442a2aca242f01c7d27f82a574 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 22 Nov 2022 09:00:42 +0100 Subject: [PATCH 14/26] fix: fix fast api test Signed-off-by: Sami Jaghouar --- tests/integrations/externals/test_fastapi.py | 24 ++++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py index a1b16164452..f31641608d7 100644 --- a/tests/integrations/externals/test_fastapi.py +++ b/tests/integrations/externals/test_fastapi.py @@ -5,14 +5,14 @@ from docarray import Document, Image, Text -class Mmdoc(Document): - img: Image - text: Text - title: str - - @pytest.mark.asyncio async def test_fast_api(): + class Mmdoc(Document): + img: Image + text: Text + title: str + + input_doc = Mmdoc(img=Image(), text=Text(), title='hello') app = FastAPI() @@ -21,10 +21,10 @@ async def create_item(doc: Mmdoc): return doc async with AsyncClient(app=app, base_url="http://test") as ac: - # response = await ac.get("/doc") - response2 = await ac.get("/docs") - - # assert response.status_code == 200 - assert response2.status_code == 200 + response = await ac.post("/doc/", data=input_doc.json()) + resp_doc = await ac.get("/docs") + resp_redoc = await ac.get("/redoc") - # assert response.json() == {"message": "Tomato"} + assert response.status_code == 200 + assert resp_doc.status_code == 200 + assert resp_redoc.status_code == 200 From 51a402d1cfbcfc63e812a50b038db2e7c9fa6b2c Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 22 Nov 2022 09:04:37 +0100 Subject: [PATCH 15/26] refactor: move to json test to integration Signed-off-by: Sami Jaghouar --- tests/{units => integrations}/document/test_to_json.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{units => integrations}/document/test_to_json.py (100%) diff --git a/tests/units/document/test_to_json.py b/tests/integrations/document/test_to_json.py similarity index 100% rename from tests/units/document/test_to_json.py rename to tests/integrations/document/test_to_json.py From 2e92866402a15bb1d5c4355ee94a3fa931f3a13a Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 22 Nov 2022 14:14:03 +0100 Subject: [PATCH 16/26] fix: json laod from tensor type now working Signed-off-by: Sami Jaghouar --- docarray/typing/tensor/tensor.py | 12 ++++++++--- docarray/typing/tensor/torch_tensor.py | 2 +- tests/integrations/document/test_to_json.py | 22 +++++++++++++++++++++ tests/units/typing/test_tensor.py | 18 +++++++++++++++++ 4 files changed, 50 insertions(+), 4 deletions(-) diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py index 3e997fe6577..969af4f9a7c 100644 --- a/docarray/typing/tensor/tensor.py +++ b/docarray/typing/tensor/tensor.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, cast import numpy as np @@ -23,7 +23,7 @@ def __get_validators__(cls): @classmethod def validate( cls: Type[T], - value: Union[T, np.ndarray, Any], + value: Union[T, np.ndarray, List[Any], Tuple[Any], Any], field: 'ModelField', config: 'BaseConfig', ) -> T: @@ -31,13 +31,19 @@ def validate( return cls.from_ndarray(value) elif isinstance(value, Tensor): return cast(T, value) + elif isinstance(value, list) or isinstance(value, tuple): + try: + arr: np.ndarray = np.asarray(value) + 
return cls.from_ndarray(arr) + except Exception: + pass # handled below else: try: arr: np.ndarray = np.ndarray(value) return cls.from_ndarray(arr) except Exception: pass # handled below - raise ValueError(f'Expected a numpy.ndarray, got {type(value)}') + raise ValueError(f'Expected a numpy.ndarray compatible type, got {type(value)}') @classmethod def from_ndarray(cls: Type[T], value: np.ndarray) -> T: diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 56ba375e88a..e4a86a160f6 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -49,7 +49,7 @@ def validate( return cls.from_native_torch_tensor(arr) except Exception: pass # handled below - raise ValueError(f'Expected a torch.Tensor, got {type(value)}') + raise ValueError(f'Expected a torch.Tensor compatible type, got {type(value)}') @classmethod def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: diff --git a/tests/integrations/document/test_to_json.py b/tests/integrations/document/test_to_json.py index d0afd742bb9..2aa642d0589 100644 --- a/tests/integrations/document/test_to_json.py +++ b/tests/integrations/document/test_to_json.py @@ -19,3 +19,25 @@ class Mmdoc(BaseDocument): torch_tensor=torch.zeros(3, 224, 224), ) doc.json() + + +def test_from_json(): + class Mmdoc(BaseDocument): + img: Tensor + url: AnyUrl + txt: str + torch_tensor: TorchTensor + + doc = Mmdoc( + img=np.zeros((2, 2)), + url='http://doccaray.io', + txt='hello', + torch_tensor=torch.zeros(3, 224, 224), + ) + new_doc = Mmdoc.parse_raw(doc.json()) + + for (field, field2) in zip(doc.dict().keys(), new_doc.dict().keys()): + if field in ['torch_tensor', 'img']: + assert (getattr(doc, field) == getattr(doc, field2)).all() + else: + assert getattr(doc, field) == getattr(doc, field2) diff --git a/tests/units/typing/test_tensor.py b/tests/units/typing/test_tensor.py index 32086f0e402..fbe5d72a50e 100644 --- a/tests/units/typing/test_tensor.py +++ b/tests/units/typing/test_tensor.py @@ -1,4 +1,5 @@ import numpy as np +import orjson from pydantic.tools import parse_obj_as, schema_json_of from docarray.document.io.json import orjson_dumps @@ -12,6 +13,12 @@ def test_proto_tensor(): tensor._to_node_protobuf() +def test_from_list(): + tensor = parse_obj_as(Tensor, [[0.0, 0.0], [0.0, 0.0]]) + + assert (tensor == np.zeros((2, 2))).all() + + def test_json_schema(): schema_json_of(Tensor) @@ -21,6 +28,17 @@ def test_dump_json(): orjson_dumps(tensor) +def test_load_json(): + tensor = parse_obj_as(Tensor, np.zeros((2, 2))) + + json = orjson_dumps(tensor) + print(json) + print(type(json)) + new_tensor = orjson.loads(json) + + assert (new_tensor == tensor).all() + + def test_unwrap(): tensor = parse_obj_as(Tensor, np.zeros((3, 224, 224))) ndarray = tensor.unwrap() From 606134b94565d08cd430266826d136cd5ce79357 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 22 Nov 2022 14:15:47 +0100 Subject: [PATCH 17/26] fix: add tensor to fastapi test Signed-off-by: Sami Jaghouar --- tests/integrations/externals/test_fastapi.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py index f31641608d7..54e3c44bc2b 100644 --- a/tests/integrations/externals/test_fastapi.py +++ b/tests/integrations/externals/test_fastapi.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from fastapi import FastAPI from httpx import AsyncClient @@ -12,7 +13,9 @@ class Mmdoc(Document): text: Text title: str - 
input_doc = Mmdoc(img=Image(), text=Text(), title='hello') + input_doc = Mmdoc( + img=Image(tensor=np.zeros((3, 224, 224))), text=Text(), title='hello' + ) app = FastAPI() From 731feff20d0347494ee0e514bc09b3b87306bd2c Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 22 Nov 2022 14:42:40 +0100 Subject: [PATCH 18/26] fix: add new fastapi test Signed-off-by: Sami Jaghouar --- tests/integrations/externals/test_fastapi.py | 61 ++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py index 54e3c44bc2b..0ebb6682518 100644 --- a/tests/integrations/externals/test_fastapi.py +++ b/tests/integrations/externals/test_fastapi.py @@ -4,6 +4,7 @@ from httpx import AsyncClient from docarray import Document, Image, Text +from docarray.typing import Tensor @pytest.mark.asyncio @@ -31,3 +32,63 @@ async def create_item(doc: Mmdoc): assert response.status_code == 200 assert resp_doc.status_code == 200 assert resp_redoc.status_code == 200 + + +@pytest.mark.asyncio +async def test_image(): + class InputDoc(Document): + img: Image + + class OutputDoc(Document): + embedding_clip: Tensor + embedding_bert: Tensor + + input_doc = InputDoc(img=Image(tensor=np.zeros((3, 224, 224)))) + + app = FastAPI() + + @app.post("/doc/", response_model=OutputDoc) + async def create_item(doc: InputDoc) -> OutputDoc: + ## call my fancy model to generate the embeddings + return OutputDoc( + embedding_clip=np.zeros((100, 1)), embedding_bert=np.zeros((100, 1)) + ) + + async with AsyncClient(app=app, base_url="http://test") as ac: + response = await ac.post("/doc/", data=input_doc.json()) + resp_doc = await ac.get("/docs") + resp_redoc = await ac.get("/redoc") + + assert response.status_code == 200 + assert resp_doc.status_code == 200 + assert resp_redoc.status_code == 200 + + +@pytest.mark.asyncio +async def test_sentence_to_embeddings(): + class InputDoc(Document): + text: str + + class OutputDoc(Document): + embedding_clip: Tensor + embedding_bert: Tensor + + input_doc = InputDoc(text='hello') + + app = FastAPI() + + @app.post("/doc/", response_model=OutputDoc) + async def create_item(doc: InputDoc) -> OutputDoc: + ## call my fancy model to generate the embeddings + return OutputDoc( + embedding_clip=np.zeros((100, 1)), embedding_bert=np.zeros((100, 1)) + ) + + async with AsyncClient(app=app, base_url="http://test") as ac: + response = await ac.post("/doc/", data=input_doc.json()) + resp_doc = await ac.get("/docs") + resp_redoc = await ac.get("/redoc") + + assert response.status_code == 200 + assert resp_doc.status_code == 200 + assert resp_redoc.status_code == 200 From 061b05e42648cc420a9b7e6784ba3de32776508b Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 22 Nov 2022 15:31:46 +0100 Subject: [PATCH 19/26] fix: fix mypy Signed-off-by: Sami Jaghouar --- docarray/typing/tensor/tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py index 53b457316fe..146a86bf86b 100644 --- a/docarray/typing/tensor/tensor.py +++ b/docarray/typing/tensor/tensor.py @@ -33,8 +33,8 @@ def validate( return cast(T, value) elif isinstance(value, list) or isinstance(value, tuple): try: - arr: np.ndarray = np.asarray(value) - return cls.from_ndarray(arr) + arr_from_list: np.ndarray = np.asarray(value) + return cls.from_ndarray(arr_from_list) except Exception: pass # handled below else: From 94c0069d8a7060e0570922ec9ce25bc23d7e6af2 Mon Sep 17 00:00:00 2001 
From: Sami Jaghouar
Date: Tue, 22 Nov 2022 15:46:48 +0100
Subject: [PATCH 20/26] feat: add more JSON testing for TextUrl

Signed-off-by: Sami Jaghouar
---
 tests/units/typing/url/test_text_url.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/tests/units/typing/url/test_text_url.py b/tests/units/typing/url/test_text_url.py
index 2544358388f..21f1732df6a 100644
--- a/tests/units/typing/url/test_text_url.py
+++ b/tests/units/typing/url/test_text_url.py
@@ -2,8 +2,9 @@
 import urllib
 
 import pytest
-from pydantic import parse_obj_as
+from pydantic import parse_obj_as, schema_json_of
 
+from docarray.document.io.json import orjson_dumps
 from docarray.typing import TextUrl
 
 REMOTE_TXT = 'https://de.wikipedia.org/wiki/Brixen'
@@ -43,3 +44,12 @@ def test_load_timeout():
         _ = url.load(timeout=0.001)
     with pytest.raises(urllib.error.URLError):
         _ = url.load_to_bytes(timeout=0.001)
+
+
+def test_json_schema():
+    schema_json_of(TextUrl)
+
+
+def test_dump_json():
+    url = parse_obj_as(TextUrl, REMOTE_TXT)
+    orjson_dumps(url)

From e04285a47032ecd9a64289ed3b01b38636914387 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Tue, 22 Nov 2022 16:56:23 +0100
Subject: [PATCH 21/26] fix: fix default orjson handler not returning

Signed-off-by: Sami Jaghouar
---
 docarray/document/io/json.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docarray/document/io/json.py b/docarray/document/io/json.py
index 95961189759..16e875fa359 100644
--- a/docarray/document/io/json.py
+++ b/docarray/document/io/json.py
@@ -11,6 +11,8 @@ def _default_orjson(obj):
 
     if getattr(obj, '_to_json_compatible'):
         return obj._to_json_compatible()
+    else:
+        return obj
 
 
 def orjson_dumps(v, *, default=None):

From b8de8bc4414ee61a69e75a4f41c1488f6a1a67cf Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Tue, 22 Nov 2022 17:15:34 +0100
Subject: [PATCH 22/26] fix: apply Johannes' suggestion on docstring

Signed-off-by: Sami Jaghouar
---
 docarray/typing/tensor/tensor.py       | 13 +++++++++----
 docarray/typing/tensor/torch_tensor.py | 11 ++++++++---
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py
index 146a86bf86b..356fdf23cdd 100644
--- a/docarray/typing/tensor/tensor.py
+++ b/docarray/typing/tensor/tensor.py
@@ -63,17 +63,22 @@ def _to_json_compatible(self) -> np.ndarray:
 
     def unwrap(self) -> np.ndarray:
         """
-        Return the original ndarray without any memory copy
+        Return the original ndarray without any memory copy.
+
+        The original view rests intact and is still a Document Tensor,
+        but the returned object is a pure np.ndarray and both objects share
+        the same memory layout.
 
         EXAMPLE USAGE
         .. code-block:: python
 
             from docarray.typing import Tensor
             import numpy as np
 
-            t = Tensor.validate(np.zeros((3, 224, 224)), None, None)
+            t1 = Tensor.validate(np.zeros((3, 224, 224)), None, None)
             # here t is a docarray Tensor
-            t = t.unwrap()
-            # here t is a pure np.ndarray
+            t2 = t1.unwrap()
+            # here t2 is a pure np.ndarray but t1 is still a Docarray Tensor
+            # But both share the same underlying memory
 
         :return: a numpy ndarray
diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py
index d84963b3378..376dc175fec 100644
--- a/docarray/typing/tensor/torch_tensor.py
+++ b/docarray/typing/tensor/torch_tensor.py
@@ -65,7 +65,11 @@ def _to_json_compatible(self) -> np.ndarray:
 
     def unwrap(self) -> torch.Tensor:
         """
-        Return the original ndarray without any memory copy
+        Return the original torch.Tensor without any memory copy.
+
+        The original view rests intact and is still a Document Tensor,
+        but the returned object is a pure torch Tensor and both objects share
+        the same memory layout.
 
         EXAMPLE USAGE
         .. code-block:: python
@@ -74,8 +78,9 @@ def unwrap(self) -> torch.Tensor:
 
             t = Tensor.validate(torch.zeros(3, 224, 224), None, None)
             # here t is a docarray Tensor
-            t = t.unwrap()
-            # here t is a pure torch Tensor
+            t2 = t.unwrap()
+            # here t2 is a pure torch.Tensor but t is still a Docarray Tensor
+            # But both share the same underlying memory
 
 
         :return: a torch Tensor

From 19ffb0b5221ba849ec52ac317cefc1d5652c22e1 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Tue, 22 Nov 2022 17:44:46 +0100
Subject: [PATCH 23/26] fix: do not perform a copy anymore on torch tensor unwrap

Signed-off-by: Sami Jaghouar
---
 docarray/typing/tensor/torch_tensor.py  | 5 +++--
 tests/units/typing/test_torch_tensor.py | 2 ++
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py
index 376dc175fec..304a7f8ec9f 100644
--- a/docarray/typing/tensor/torch_tensor.py
+++ b/docarray/typing/tensor/torch_tensor.py
@@ -1,3 +1,4 @@
+from copy import copy
 from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Union, cast
 
 import numpy as np
@@ -85,8 +86,8 @@ def unwrap(self) -> torch.Tensor:
 
         :return: a torch Tensor
         """
-        ## might need to check device later
-        value = torch.tensor(self)
+        value = copy(self)  ## as intuitivly as it sounds this
+        # does not do any memory copy just shallow reference copy
         value.__class__ = torch.Tensor
         return value
diff --git a/tests/units/typing/test_torch_tensor.py b/tests/units/typing/test_torch_tensor.py
index addd7297203..7d3081f86e3 100644
--- a/tests/units/typing/test_torch_tensor.py
+++ b/tests/units/typing/test_torch_tensor.py
@@ -29,4 +29,6 @@ def test_unwrap():
     assert isinstance(tensor, TorchTensor)
     assert isinstance(ndarray, torch.Tensor)
 
+    assert tensor.data_ptr() == ndarray.data_ptr()
+
     assert (ndarray == torch.zeros(3, 224, 224)).all()

From 0b7dac7da467fbb70c658fbc4c6deefde38c1685 Mon Sep 17 00:00:00 2001
From: samsja <55492238+samsja@users.noreply.github.com>
Date: Tue, 22 Nov 2022 17:52:51 +0100
Subject: [PATCH 24/26] fix: add Johannes' suggestion

Co-authored-by: Johannes Messner <44071807+JohannesMessner@users.noreply.github.com>
Signed-off-by: samsja <55492238+samsja@users.noreply.github.com>
---
 docarray/typing/tensor/torch_tensor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py
index 304a7f8ec9f..39895fab1b0 100644
--- a/docarray/typing/tensor/torch_tensor.py
+++ b/docarray/typing/tensor/torch_tensor.py
@@ -86,8 +86,8 @@ def unwrap(self) -> torch.Tensor:
 
         :return: a torch Tensor
         """
-        value = copy(self)  ## as intuitivly as it sounds this
-        # does not do any memory copy just shallow reference copy
+        value = copy(self)  # as unintuitive as it sounds, this
+        # does not do any relevant memory copying, just shallow reference to the torch data
         value.__class__ = torch.Tensor
         return value

From a334c65ef33b3cda19e0ab44ca7e81c881ddb00d Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Wed, 23 Nov 2022 08:47:50 +0100
Subject: [PATCH 25/26] fix: fix ruff line length

Signed-off-by: Sami Jaghouar
---
 docarray/typing/tensor/torch_tensor.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py
index 39895fab1b0..bdc5400790c 100644
--- a/docarray/typing/tensor/torch_tensor.py
+++ b/docarray/typing/tensor/torch_tensor.py
@@ -87,7 +87,8 @@ def unwrap(self) -> torch.Tensor:
 
         :return: a torch Tensor
         """
         value = copy(self)  # as unintuitive as it sounds, this
-        # does not do any relevant memory copying, just shallow reference to the torch data
+        # does not do any relevant memory copying, just shallow
+        # reference to the torch data
         value.__class__ = torch.Tensor
         return value

From 909ee746135627007490984785ea548ce505e3f0 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Wed, 23 Nov 2022 08:58:53 +0100
Subject: [PATCH 26/26] fix: fix mypy pb

Signed-off-by: Sami Jaghouar
---
 docarray/typing/tensor/torch_tensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py
index bdc5400790c..691b4e3f1c8 100644
--- a/docarray/typing/tensor/torch_tensor.py
+++ b/docarray/typing/tensor/torch_tensor.py
@@ -89,7 +89,7 @@ def unwrap(self) -> torch.Tensor:
         value = copy(self)  # as unintuitive as it sounds, this
         # does not do any relevant memory copying, just shallow
         # reference to the torch data
-        value.__class__ = torch.Tensor
+        value.__class__ = torch.Tensor  # type: ignore
        return value
 
    @classmethod
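
A minimal end-to-end sketch of the behavior this series builds up, mirroring the new tests above. Only the imports and calls that appear in the diffs (Tensor, TorchTensor, parse_obj_as, orjson_dumps, unwrap, data_ptr) are taken from the patches; the variable names are illustrative, not part of the patch set.

# Sketch based on the tests added in this series; not part of the patches themselves.
import torch
from pydantic.tools import parse_obj_as

from docarray.document.io.json import orjson_dumps
from docarray.typing import Tensor, TorchTensor

# Validation accepts native arrays/tensors and, after patch 16, plain lists or tuples.
np_tensor = parse_obj_as(Tensor, [[0.0, 0.0], [0.0, 0.0]])
torch_tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224))

# unwrap() hands back the underlying native object without copying memory,
# so both views share the same storage (see test_unwrap after patch 23).
native = torch_tensor.unwrap()
assert isinstance(native, torch.Tensor)
assert native.data_ptr() == torch_tensor.data_ptr()

# JSON dumping goes through orjson_dumps, which serializes tensor fields
# via their _to_json_compatible hook (see the test_dump_json tests).
json_str = orjson_dumps(np_tensor)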