diff --git a/docarray/document/document.py b/docarray/document/document.py index f0eb363a624..17497414429 100644 --- a/docarray/document/document.py +++ b/docarray/document/document.py @@ -1,14 +1,15 @@ import os from typing import Type +import orjson from pydantic import BaseModel, Field from docarray.document.abstract_document import AbstractDocument from docarray.document.base_node import BaseNode +from docarray.document.io.json import orjson_dumps +from docarray.document.mixins import ProtoMixin from docarray.typing import ID -from .mixins import ProtoMixin - class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): """ @@ -17,6 +18,10 @@ class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): id: ID = Field(default_factory=lambda: ID.validate(os.urandom(16).hex())) + class Config: + json_loads = orjson.loads + json_dumps = orjson_dumps + @classmethod def _get_nested_document_class(cls, field: str) -> Type['BaseDocument']: """ diff --git a/docarray/document/io/__init__.py b/docarray/document/io/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/document/io/json.py b/docarray/document/io/json.py new file mode 100644 index 00000000000..16e875fa359 --- /dev/null +++ b/docarray/document/io/json.py @@ -0,0 +1,22 @@ +import orjson + + +def _default_orjson(obj): + """ + default option for orjson dumps. It will call _to_json_compatible + from docarray typing object that expose such method. 
+ :param obj: the object to convert + :return: a json compatible object + """ + + if getattr(obj, '_to_json_compatible', None): + return obj._to_json_compatible() + else: + return obj + + +def orjson_dumps(v, *, default=None): + # orjson.dumps returns bytes, to match standard json.dumps we need to decode + return orjson.dumps( + v, default=_default_orjson, option=orjson.OPT_SERIALIZE_NUMPY + ).decode() diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py index bb37c39e52a..356fdf23cdd 100644 --- a/docarray/typing/tensor/tensor.py +++ b/docarray/typing/tensor/tensor.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, cast import numpy as np @@ -22,26 +22,71 @@ def __get_validators__(cls): @classmethod def validate( - cls: Type[T], value: Union[T, Any], field: 'ModelField', config: 'BaseConfig' + cls: Type[T], + value: Union[T, np.ndarray, List[Any], Tuple[Any], Any], + field: 'ModelField', + config: 'BaseConfig', ) -> T: if isinstance(value, np.ndarray): return cls.from_ndarray(value) elif isinstance(value, Tensor): return cast(T, value) + elif isinstance(value, list) or isinstance(value, tuple): + try: + arr_from_list: np.ndarray = np.asarray(value) + return cls.from_ndarray(arr_from_list) + except Exception: + pass # handled below else: try: arr: np.ndarray = np.ndarray(value) return cls.from_ndarray(arr) except Exception: pass # handled below - raise ValueError(f'Expected a numpy.ndarray, got {type(value)}') + raise ValueError(f'Expected a numpy.ndarray compatible type, got {type(value)}') @classmethod def from_ndarray(cls: Type[T], value: np.ndarray) -> T: return value.view(cls) + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + # this is needed to dump to json + field_schema.update(type='string', format='tensor') + + def _to_json_compatible(self) -> np.ndarray: + """ + Convert tensor into a json compatible 
object + :return: an ndarray representation of the tensor + """ + return self.unwrap() + + def unwrap(self) -> np.ndarray: + """ + Return the original ndarray without any memory copy. + + The original view rests intact and is still a Document Tensor + but the returned object is a pure np.ndarray and both objects share + the same memory layout. + + EXAMPLE USAGE + .. code-block:: python + from docarray.typing import Tensor + import numpy as np + + t1 = Tensor.validate(np.zeros((3, 224, 224)), None, None) + # here t1 is a docarray Tensor + t2 = t1.unwrap() + # here t2 is a pure np.ndarray but t1 is still a Docarray Tensor + # But both share the same underlying memory + + + :return: a numpy ndarray + """ + return self.view(np.ndarray) + def _to_node_protobuf(self: T, field: str = 'tensor') -> NodeProto: - """Convert Document into a NodeProto protobuf message. This function should + """Convert itself into a NodeProto protobuf message. This function should be called when the Document is nested into another Document that need to be converted into a protobuf :param field: field in which to store the content in the node proto diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 6bef83605cf..691b4e3f1c8 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Any, Type, TypeVar, Union, cast +from copy import copy +from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Union, cast import numpy as np import torch # type: ignore @@ -49,7 +50,47 @@ def validate( return cls.from_native_torch_tensor(arr) except Exception: pass # handled below - raise ValueError(f'Expected a torch.Tensor, got {type(value)}') + raise ValueError(f'Expected a torch.Tensor compatible type, got {type(value)}') + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + # this is needed to dump to json + field_schema.update(type='string', 
format='tensor') + + def _to_json_compatible(self) -> np.ndarray: + """ + Convert torch Tensor into a json compatible object + :return: a list representation of the torch tensor + """ + return self.numpy() ## might need to check device later + + def unwrap(self) -> torch.Tensor: + """ + Return the original torch.Tensor without any memory copy. + + The original view rest intact and is still a Document Tensor + but the return object is a pure torch Tensor but both object share + the same memory layout. + + EXAMPLE USAGE + .. code-block:: python + from docarray.typing import TorchTensor + import torch + + t = Tensor.validate(torch.zeros(3, 224, 224), None, None) + # here t is a docarray Tensor + t2 = t.unwrap() + # here t2 is a pure torch.Tensor but t1 is still a Docarray Tensor + # But both share the same underlying memory + + + :return: a torch Tensor + """ + value = copy(self) # as unintuitive as it sounds, this + # does not do any relevant memory copying, just shallow + # reference to the torch data + value.__class__ = torch.Tensor # type: ignore + return value @classmethod def from_native_torch_tensor(cls: Type[T], value: torch.Tensor) -> T: diff --git a/poetry.lock b/poetry.lock index c1f4138a6cc..f1dcad3686b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -270,6 +270,24 @@ python-versions = "*" [package.extras] tests = ["asttokens", "littleutils", "pytest", "rich"] +[[package]] +name = "fastapi" +version = "0.87.0" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +pydantic = ">=1.6.2,<1.7 || >1.7,<1.7.1 || >1.7.1,<1.7.2 || >1.7.2,<1.7.3 || >1.7.3,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0" +starlette = "0.21.0" + +[package.extras] +all = ["email-validator (>=1.1.1)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson 
(>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +dev = ["pre-commit (>=2.17.0,<3.0.0)", "ruff (==0.0.114)", "uvicorn[standard] (>=0.12.0,<0.19.0)"] +doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer[all] (>=0.6.1,<0.7.0)"] +test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==22.8.0)", "coverage[toml] (>=6.5.0,<7.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.6)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.114)", "sqlalchemy (>=1.3.18,<=1.4.41)", "types-orjson (==3.6.2)", "types-ujson (==5.5.0)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"] + [[package]] name = "fastjsonschema" version = "2.16.2" @@ -293,6 +311,52 @@ python-versions = ">=3.7" docs = ["furo (>=2022.6.21)", "sphinx (>=5.1.1)", "sphinx-autodoc-typehints (>=1.19.1)"] testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pytest-cov (>=3)", "pytest-timeout (>=2.1)"] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +category = "dev" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "httpcore" +version = "0.16.1" +description = "A minimal low-level HTTP client." 
+category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +anyio = ">=3.0,<5.0" +certifi = "*" +h11 = ">=0.13,<0.15" +sniffio = ">=1.0.0,<2.0.0" + +[package.extras] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] + +[[package]] +name = "httpx" +version = "0.23.1" +description = "The next generation HTTP client." +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +certifi = "*" +httpcore = ">=0.15.0,<0.17.0" +rfc3986 = {version = ">=1.3,<2", extras = ["idna2008"]} +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<13)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] + [[package]] name = "identify" version = "2.5.8" @@ -343,6 +407,14 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)"] testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] +[[package]] +name = "iniconfig" +version = "1.1.1" +description = "iniconfig: brain-dead simple config-ini parsing" +category = "dev" +optional = false +python-versions = "*" + [[package]] name = "ipykernel" version = "6.17.1" @@ -631,14 +703,6 @@ category = "dev" optional = false python-versions = "*" -[[package]] -name = "more-itertools" -version = "9.0.0" -description = "More routines for operating on iterables, beyond itertools" -category = "dev" -optional = false -python-versions = ">=3.7" - [[package]] name = "mypy" version = "0.990" @@ -888,6 +952,14 @@ python-versions = ">=3" setuptools = "*" wheel = "*" +[[package]] +name = "orjson" +version = "3.8.2" +description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" +category = "main" +optional = false 
+python-versions = ">=3.7" + [[package]] name = "packaging" version = "21.3" @@ -1128,26 +1200,39 @@ python-versions = ">=3.7" [[package]] name = "pytest" -version = "5.4.3" +version = "6.2.5" description = "pytest: simple powerful testing with Python" category = "dev" optional = false -python-versions = ">=3.5" +python-versions = ">=3.6" [package.dependencies] atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} -attrs = ">=17.4.0" +attrs = ">=19.2.0" colorama = {version = "*", markers = "sys_platform == \"win32\""} -more-itertools = ">=4.0.0" +iniconfig = "*" packaging = "*" -pluggy = ">=0.12,<1.0" -py = ">=1.5.0" -wcwidth = "*" +pluggy = ">=0.12,<2.0" +py = ">=1.8.2" +toml = "*" [package.extras] -checkqa-mypy = ["mypy (==v0.761)"] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.20.2" +description = "Pytest support for asyncio" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +pytest = ">=6.1.0" + +[package.extras] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -1221,6 +1306,20 @@ urllib3 = ">=1.21.1,<1.27" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rfc3986" +version = "1.5.0" +description = "Validating URI References per RFC 3986" +category = "dev" +optional = false +python-versions = "*" + +[package.dependencies] +idna = {version = "*", optional = true, markers = "extra == \"idna2008\""} + +[package.extras] +idna2008 = ["idna"] + [[package]] name = "ruff" version = "0.0.117" @@ -1295,6 +1394,21 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "starlette" +version = "0.21.0" +description = "The little ASGI library that shines." 
+category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +anyio = ">=3.4.0,<5" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"] + [[package]] name = "terminado" version = "0.17.0" @@ -1418,6 +1532,21 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "uvicorn" +version = "0.19.0" +description = "The lightning-fast ASGI server." +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" + +[package.extras] +standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.0)"] + [[package]] name = "virtualenv" version = "20.16.7" @@ -1495,7 +1624,7 @@ torch = ["torch"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "7fabdc150fb15e67a5eff53967ccf3846e464d78544b12784953635d2866a64a" +content-hash = "ee0a07a121a7e55e90055a58152acf241f2b3fb9d4769bc883a4484da30dcd92" [metadata.files] anyio = [ @@ -1709,6 +1838,10 @@ executing = [ {file = "executing-1.2.0-py2.py3-none-any.whl", hash = "sha256:0314a69e37426e3608aada02473b4161d4caf5a4b244d1d0c48072b8fee7bacc"}, {file = "executing-1.2.0.tar.gz", hash = "sha256:19da64c18d2d851112f09c287f8d3dbbdf725ab0e569077efb6cdcbd3497c107"}, ] +fastapi = [ + {file = "fastapi-0.87.0-py3-none-any.whl", hash = "sha256:254453a2e22f64e2a1b4e1d8baf67d239e55b6c8165c079d25746a5220c81bb4"}, + {file = "fastapi-0.87.0.tar.gz", hash = "sha256:07032e53df9a57165047b4f38731c38bdcc3be5493220471015e2b4b51b486a4"}, +] fastjsonschema = [ {file = 
"fastjsonschema-2.16.2-py3-none-any.whl", hash = "sha256:21f918e8d9a1a4ba9c22e09574ba72267a6762d47822db9add95f6454e51cc1c"}, {file = "fastjsonschema-2.16.2.tar.gz", hash = "sha256:01e366f25d9047816fe3d288cbfc3e10541daf0af2044763f3d0ade42476da18"}, @@ -1717,6 +1850,18 @@ filelock = [ {file = "filelock-3.8.0-py3-none-any.whl", hash = "sha256:617eb4e5eedc82fc5f47b6d61e4d11cb837c56cb4544e39081099fa17ad109d4"}, {file = "filelock-3.8.0.tar.gz", hash = "sha256:55447caa666f2198c5b6b13a26d2084d26fa5b115c00d065664b2124680c4edc"}, ] +h11 = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] +httpcore = [ + {file = "httpcore-0.16.1-py3-none-any.whl", hash = "sha256:8d393db683cc8e35cc6ecb02577c5e1abfedde52b38316d038932a84b4875ecb"}, + {file = "httpcore-0.16.1.tar.gz", hash = "sha256:3d3143ff5e1656a5740ea2f0c167e8e9d48c5a9bbd7f00ad1f8cff5711b08543"}, +] +httpx = [ + {file = "httpx-0.23.1-py3-none-any.whl", hash = "sha256:0b9b1f0ee18b9978d637b0776bfd7f54e2ca278e063e3586d8f01cda89e042a8"}, + {file = "httpx-0.23.1.tar.gz", hash = "sha256:202ae15319be24efe9a8bd4ed4360e68fde7b38bcc2ce87088d416f026667d19"}, +] identify = [ {file = "identify-2.5.8-py2.py3-none-any.whl", hash = "sha256:48b7925fe122720088aeb7a6c34f17b27e706b72c61070f27fe3789094233440"}, {file = "identify-2.5.8.tar.gz", hash = "sha256:7a214a10313b9489a0d61467db2856ae8d0b8306fc923e03a9effa53d8aedc58"}, @@ -1733,6 +1878,10 @@ importlib-resources = [ {file = "importlib_resources-5.10.0-py3-none-any.whl", hash = "sha256:ee17ec648f85480d523596ce49eae8ead87d5631ae1551f913c0100b5edd3437"}, {file = "importlib_resources-5.10.0.tar.gz", hash = "sha256:c01b1b94210d9849f286b86bb51bcea7cd56dde0600d8db721d7b81330711668"}, ] +iniconfig = [ + {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = 
"sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, + {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, +] ipykernel = [ {file = "ipykernel-6.17.1-py3-none-any.whl", hash = "sha256:3a9a1b2ad6dbbd5879855aabb4557f08e63fa2208bffed897f03070e2bb436f6"}, {file = "ipykernel-6.17.1.tar.gz", hash = "sha256:e178c1788399f93a459c241fe07c3b810771c607b1fb064a99d2c5d40c90c5d4"}, @@ -1839,10 +1988,6 @@ mistune = [ {file = "mistune-2.0.4-py2.py3-none-any.whl", hash = "sha256:182cc5ee6f8ed1b807de6b7bb50155df7b66495412836b9a74c8fbdfc75fe36d"}, {file = "mistune-2.0.4.tar.gz", hash = "sha256:9ee0a66053e2267aba772c71e06891fa8f1af6d4b01d5e84e267b4570d4d9808"}, ] -more-itertools = [ - {file = "more-itertools-9.0.0.tar.gz", hash = "sha256:5a6257e40878ef0520b1803990e3e22303a41b5714006c32a3fd8304b26ea1ab"}, - {file = "more_itertools-9.0.0-py3-none-any.whl", hash = "sha256:250e83d7e81d0c87ca6bd942e6aeab8cc9daa6096d12c5308f3f92fa5e5c1f41"}, -] mypy = [ {file = "mypy-0.990-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:aaf1be63e0207d7d17be942dcf9a6b641745581fe6c64df9a38deb562a7dbafa"}, {file = "mypy-0.990-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d555aa7f44cecb7ea3c0ac69d58b1a5afb92caa017285a8e9c4efbf0518b61b4"}, @@ -1958,6 +2103,57 @@ nvidia-cudnn-cu11 = [ {file = "nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7"}, {file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"}, ] +orjson = [ + {file = "orjson-3.8.2-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:43e69b360c2851b45c7dbab3b95f7fa8469df73fab325a683f7389c4db63aa71"}, + {file = "orjson-3.8.2-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = 
"sha256:64c5da5c9679ef3d85e9bbcbb62f4ccdc1f1975780caa20f2ec1e37b4da6bd36"}, + {file = "orjson-3.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c632a2157fa9ec098d655287e9e44809615af99837c49f53d96bfbca453c5bd"}, + {file = "orjson-3.8.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f63da6309c282a2b58d4a846f0717f6440356b4872838b9871dc843ed1fe2b38"}, + {file = "orjson-3.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c9be25c313ba2d5478829d949165445c3bd36c62e07092b4ba8dbe5426574d1"}, + {file = "orjson-3.8.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:4bcce53e9e088f82633f784f79551fcd7637943ab56c51654aaf9d4c1d5cfa54"}, + {file = "orjson-3.8.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:33edb5379c6e6337f9383c85fe4080ce3aa1057cc2ce29345b7239461f50cbd6"}, + {file = "orjson-3.8.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:da35d347115758bbc8bfaf39bb213c42000f2a54e3f504c84374041d20835cd6"}, + {file = "orjson-3.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d755d94a90a941b91b4d39a6b02e289d8ba358af2d1a911edf266be7942609dc"}, + {file = "orjson-3.8.2-cp310-none-win_amd64.whl", hash = "sha256:7ea96923e26390b2142602ebb030e2a4db9351134696e0b219e5106bddf9b48e"}, + {file = "orjson-3.8.2-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:a0d89de876e6f1cef917a2338378a60a98584e1c2e1c67781e20b6ed1c512478"}, + {file = "orjson-3.8.2-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:8d47e7592fe938aec898eb22ea4946298c018133df084bc78442ff18e2c6347c"}, + {file = "orjson-3.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3d9f1043f618d0c64228aab9711e5bd822253c50b6c56223951e32b51f81d62"}, + {file = "orjson-3.8.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed10600e8b08f1e87b656ad38ab316191ce94f2c9adec57035680c0dc9e93c81"}, + {file = 
"orjson-3.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99c49e49a04bf61fee7aaea6d92ac2b1fcf6507aea894bbdf3fbb25fe792168c"}, + {file = "orjson-3.8.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:1463674f8efe6984902473d7b5ce3edf444c1fcd09dc8aa4779638a28fb9ca01"}, + {file = "orjson-3.8.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:c1ef75f1d021d817e5c60a42da0b4b7e3123b1b37415260b8415666ddacc7cd7"}, + {file = "orjson-3.8.2-cp311-none-win_amd64.whl", hash = "sha256:b6007e1ac8564b13b2521720929e8bb3ccd3293d9fdf38f28728dcc06db6248f"}, + {file = "orjson-3.8.2-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:a02c13ae523221576b001071354380e277346722cc6b7fdaacb0fd6db5154b3e"}, + {file = "orjson-3.8.2-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:fa2e565cf8ffdb37ce1887bd1592709ada7f701e61aa4b1e710be94b0aecbab4"}, + {file = "orjson-3.8.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1d8864288f7c5fccc07b43394f83b721ddc999f25dccfb5d0651671a76023f5"}, + {file = "orjson-3.8.2-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1874c05d0bb994601fa2d51605cb910d09343c6ebd36e84a573293523fab772a"}, + {file = "orjson-3.8.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:349387ed6989e5db22e08c9af8d7ca14240803edc50de451d48d41a0e7be30f6"}, + {file = "orjson-3.8.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:4e42b19619d6e97e201053b865ca4e62a48da71165f4081508ada8e1b91c6a30"}, + {file = "orjson-3.8.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:bc112c17e607c59d1501e72afb44226fa53d947d364aed053f0c82d153e29616"}, + {file = "orjson-3.8.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:6fda669211f2ed1fc2c8130187ec90c96b4f77b6a250004e666d2ef8ed524e5f"}, + {file = "orjson-3.8.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = 
"sha256:aebd4e80fea0f20578fd0452908b9206a6a0d5ae9f5c99b6e665bbcd989e56cd"}, + {file = "orjson-3.8.2-cp37-none-win_amd64.whl", hash = "sha256:9f3cd0394eb6d265beb2a1572b5663bc910883ddbb5cdfbcb660f5a0444e7fd8"}, + {file = "orjson-3.8.2-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:74e7d54d11b3da42558d69a23bf92c2c48fabf69b38432d5eee2c5b09cd4c433"}, + {file = "orjson-3.8.2-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:8cbadc9be748a823f9c743c7631b1ee95d3925a9c0b21de4e862a1d57daa10ec"}, + {file = "orjson-3.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a07d5a8c69a2947d9554a00302734fe3d8516415c8b280963c92bc1033477890"}, + {file = "orjson-3.8.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b364ea01d1b71b9f97bf97af9eb79ebee892df302e127a9e2e4f8eaa74d6b98"}, + {file = "orjson-3.8.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b98a8c825a59db94fbe8e0cce48618624c5a6fb1436467322d90667c08a0bf80"}, + {file = "orjson-3.8.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:ab63103f60b516c0fce9b62cb4773f689a82ab56e19ef2387b5a3182f80c0d78"}, + {file = "orjson-3.8.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:73ab3f4288389381ae33ab99f914423b69570c88d626d686764634d5e0eeb909"}, + {file = "orjson-3.8.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2ab3fd8728e12c36e20c6d9d70c9e15033374682ce5acb6ed6a08a80dacd254d"}, + {file = "orjson-3.8.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cde11822cf71a7f0daaa84223249b2696a2b6cda7fa587e9fd762dff1a8848e4"}, + {file = "orjson-3.8.2-cp38-none-win_amd64.whl", hash = "sha256:b14765ea5aabfeab1a194abfaa0be62c9fee6480a75ac8c6974b4eeede3340b4"}, + {file = "orjson-3.8.2-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:6068a27d59d989d4f2864c2fc3440eb7126a0cfdfaf8a4ad136b0ffd932026ae"}, + {file = "orjson-3.8.2-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = 
"sha256:6bf36fa759a1b941fc552ad76b2d7fb10c1d2a20c056be291ea45eb6ae1da09b"}, + {file = "orjson-3.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f436132e62e647880ca6988974c8e3165a091cb75cbed6c6fd93e931630c22fa"}, + {file = "orjson-3.8.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3ecd8936259a5920b52a99faf62d4efeb9f5e25a0aacf0cce1e9fa7c37af154f"}, + {file = "orjson-3.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c13114b345cda33644f64e92fe5d8737828766cf02fbbc7d28271a95ea546832"}, + {file = "orjson-3.8.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6e43cdc3ddf96bdb751b748b1984b701125abacca8fc2226b808d203916e8cba"}, + {file = "orjson-3.8.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ee39071da2026b11e4352d6fc3608a7b27ee14bc699fd240f4e604770bc7a255"}, + {file = "orjson-3.8.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:1c3833976ebbeb3b5b6298cb22e23bf18453f6b80802103b7d08f7dd8a61611d"}, + {file = "orjson-3.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:b9a34519d3d70935e1cd3797fbed8fbb6f61025182bea0140ca84d95b6f8fbe5"}, + {file = "orjson-3.8.2-cp39-none-win_amd64.whl", hash = "sha256:2734086d9a3dd9591c4be7d05aff9beccc086796d3f243685e56b7973ebac5bc"}, + {file = "orjson-3.8.2.tar.gz", hash = "sha256:a2fb95a45031ccf278e44341027b3035ab99caa32aa173279b1f0a06324f434b"}, +] packaging = [ {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, @@ -2188,8 +2384,12 @@ pyrsistent = [ {file = "pyrsistent-0.19.2.tar.gz", hash = "sha256:bfa0351be89c9fcbcb8c9879b826f4353be10f58f8a677efab0c017bf7137ec2"}, ] pytest = [ - {file = "pytest-5.4.3-py3-none-any.whl", hash = "sha256:5c0db86b698e8f170ba4582a492248919255fcd4c79b1ee64ace34301fb589a1"}, - {file = "pytest-5.4.3.tar.gz", hash = 
"sha256:7979331bfcba207414f5e1263b5a0f8f521d0f457318836a7355531ed1a4c7d8"}, + {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, + {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, +] +pytest-asyncio = [ + {file = "pytest-asyncio-0.20.2.tar.gz", hash = "sha256:32a87a9836298a881c0ec637ebcc952cfe23a56436bdc0d09d1511941dd8a812"}, + {file = "pytest_asyncio-0.20.2-py3-none-any.whl", hash = "sha256:07e0abf9e6e6b95894a39f688a4a875d63c2128f76c02d03d16ccbc35bcc0f8a"}, ] python-dateutil = [ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, @@ -2345,6 +2545,10 @@ requests = [ {file = "requests-2.28.1-py3-none-any.whl", hash = "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349"}, {file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"}, ] +rfc3986 = [ + {file = "rfc3986-1.5.0-py2.py3-none-any.whl", hash = "sha256:a86d6e1f5b1dc238b218b012df0aa79409667bb209e58da56d0b94704e712a97"}, + {file = "rfc3986-1.5.0.tar.gz", hash = "sha256:270aaf10d87d0d4e095063c65bf3ddbc6ee3d0b226328ce21e036f946e421835"}, +] ruff = [ {file = "ruff-0.0.117-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:cb274e0447e91a1a7844b85cd0ef243d32732a4772140129f33bbc8891ca8577"}, {file = "ruff-0.0.117-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:a670ccea678f1ddd9f9cac1a5681be8bf3b33010e56058d16b270dfb9b893b95"}, @@ -2387,6 +2591,10 @@ stack-data = [ {file = "stack_data-0.6.1-py3-none-any.whl", hash = "sha256:960cb054d6a1b2fdd9cbd529e365b3c163e8dabf1272e02cfe36b58403cff5c6"}, {file = "stack_data-0.6.1.tar.gz", hash = "sha256:6c9a10eb5f342415fe085db551d673955611afb821551f554d91772415464315"}, ] +starlette = [ + {file = "starlette-0.21.0-py3-none-any.whl", hash = 
"sha256:0efc058261bbcddeca93cad577efd36d0c8a317e44376bcfc0e097a2b3dc24a7"}, + {file = "starlette-0.21.0.tar.gz", hash = "sha256:b1b52305ee8f7cfc48cde383496f7c11ab897cd7112b33d998b1317dc8ef9027"}, +] terminado = [ {file = "terminado-0.17.0-py3-none-any.whl", hash = "sha256:bf6fe52accd06d0661d7611cc73202121ec6ee51e46d8185d489ac074ca457c2"}, {file = "terminado-0.17.0.tar.gz", hash = "sha256:520feaa3aeab8ad64a69ca779be54be9234edb2d0d6567e76c93c2c9a4e6e43f"}, @@ -2459,6 +2667,10 @@ urllib3 = [ {file = "urllib3-1.26.12-py2.py3-none-any.whl", hash = "sha256:b930dd878d5a8afb066a637fbb35144fe7901e3b209d1cd4f524bd0e9deee997"}, {file = "urllib3-1.26.12.tar.gz", hash = "sha256:3fa96cf423e6987997fc326ae8df396db2a8b7c667747d47ddd8ecba91f4a74e"}, ] +uvicorn = [ + {file = "uvicorn-0.19.0-py3-none-any.whl", hash = "sha256:cc277f7e73435748e69e075a721841f7c4a95dba06d12a72fe9874acced16f6f"}, + {file = "uvicorn-0.19.0.tar.gz", hash = "sha256:cf538f3018536edb1f4a826311137ab4944ed741d52aeb98846f52215de57f25"}, +] virtualenv = [ {file = "virtualenv-20.16.7-py3-none-any.whl", hash = "sha256:efd66b00386fdb7dbe4822d172303f40cd05e50e01740b19ea42425cbe653e29"}, {file = "virtualenv-20.16.7.tar.gz", hash = "sha256:8691e3ff9387f743e00f6bb20f70121f5e4f596cae754531f2b3b3a1b1ac696e"}, diff --git a/pyproject.toml b/pyproject.toml index 22f2edcbeb4..15278aea9f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ pydantic = "^1.10.2" numpy = "^1.23.4" protobuf = { version = "^4.21.9", optional = true } torch = { version = "^1.0.0", optional = true } +orjson = "^3.8.2" pillow = {version = "^9.3.0", optional = true } types-pillow = {version = "^9.3.0.1", optional = true } @@ -20,7 +21,7 @@ torch = ["torch"] image = ["pillow", "types-pillow"] [tool.poetry.dev-dependencies] -pytest = "^5.2" +pytest = "^6.1" pre-commit = "^2.20.0" jupyterlab = "^3.5.0" mypy = "^0.990" @@ -29,6 +30,12 @@ black = "^22.10.0" isort = "^5.10.1" ruff = "^0.0.117" +[tool.poetry.group.dev.dependencies] +fastapi = 
"^0.87.0" +uvicorn = "^0.19.0" +httpx = "^0.23.0" +pytest-asyncio = "^0.20.2" + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/tests/integrations/document/test_to_json.py b/tests/integrations/document/test_to_json.py new file mode 100644 index 00000000000..2aa642d0589 --- /dev/null +++ b/tests/integrations/document/test_to_json.py @@ -0,0 +1,43 @@ +import numpy as np +import torch + +from docarray.document import BaseDocument +from docarray.typing import AnyUrl, Tensor, TorchTensor + + +def test_to_json(): + class Mmdoc(BaseDocument): + img: Tensor + url: AnyUrl + txt: str + torch_tensor: TorchTensor + + doc = Mmdoc( + img=np.zeros((3, 224, 224)), + url='http://doccaray.io', + txt='hello', + torch_tensor=torch.zeros(3, 224, 224), + ) + doc.json() + + +def test_from_json(): + class Mmdoc(BaseDocument): + img: Tensor + url: AnyUrl + txt: str + torch_tensor: TorchTensor + + doc = Mmdoc( + img=np.zeros((2, 2)), + url='http://doccaray.io', + txt='hello', + torch_tensor=torch.zeros(3, 224, 224), + ) + new_doc = Mmdoc.parse_raw(doc.json()) + + for (field, field2) in zip(doc.dict().keys(), new_doc.dict().keys()): + if field in ['torch_tensor', 'img']: + assert (getattr(doc, field) == getattr(doc, field2)).all() + else: + assert getattr(doc, field) == getattr(doc, field2) diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py new file mode 100644 index 00000000000..0ebb6682518 --- /dev/null +++ b/tests/integrations/externals/test_fastapi.py @@ -0,0 +1,94 @@ +import numpy as np +import pytest +from fastapi import FastAPI +from httpx import AsyncClient + +from docarray import Document, Image, Text +from docarray.typing import Tensor + + +@pytest.mark.asyncio +async def test_fast_api(): + class Mmdoc(Document): + img: Image + text: Text + title: str + + input_doc = Mmdoc( + img=Image(tensor=np.zeros((3, 224, 224))), text=Text(), title='hello' + ) + + app = FastAPI() + + 
@app.post("/doc/") + async def create_item(doc: Mmdoc): + return doc + + async with AsyncClient(app=app, base_url="http://test") as ac: + response = await ac.post("/doc/", data=input_doc.json()) + resp_doc = await ac.get("/docs") + resp_redoc = await ac.get("/redoc") + + assert response.status_code == 200 + assert resp_doc.status_code == 200 + assert resp_redoc.status_code == 200 + + +@pytest.mark.asyncio +async def test_image(): + class InputDoc(Document): + img: Image + + class OutputDoc(Document): + embedding_clip: Tensor + embedding_bert: Tensor + + input_doc = InputDoc(img=Image(tensor=np.zeros((3, 224, 224)))) + + app = FastAPI() + + @app.post("/doc/", response_model=OutputDoc) + async def create_item(doc: InputDoc) -> OutputDoc: + ## call my fancy model to generate the embeddings + return OutputDoc( + embedding_clip=np.zeros((100, 1)), embedding_bert=np.zeros((100, 1)) + ) + + async with AsyncClient(app=app, base_url="http://test") as ac: + response = await ac.post("/doc/", data=input_doc.json()) + resp_doc = await ac.get("/docs") + resp_redoc = await ac.get("/redoc") + + assert response.status_code == 200 + assert resp_doc.status_code == 200 + assert resp_redoc.status_code == 200 + + +@pytest.mark.asyncio +async def test_sentence_to_embeddings(): + class InputDoc(Document): + text: str + + class OutputDoc(Document): + embedding_clip: Tensor + embedding_bert: Tensor + + input_doc = InputDoc(text='hello') + + app = FastAPI() + + @app.post("/doc/", response_model=OutputDoc) + async def create_item(doc: InputDoc) -> OutputDoc: + ## call my fancy model to generate the embeddings + return OutputDoc( + embedding_clip=np.zeros((100, 1)), embedding_bert=np.zeros((100, 1)) + ) + + async with AsyncClient(app=app, base_url="http://test") as ac: + response = await ac.post("/doc/", data=input_doc.json()) + resp_doc = await ac.get("/docs") + resp_redoc = await ac.get("/redoc") + + assert response.status_code == 200 + assert resp_doc.status_code == 200 + assert 
resp_redoc.status_code == 200 diff --git a/tests/units/typing/tensor/test_embedding.py b/tests/units/typing/tensor/test_embedding.py index 61d31abe079..57334dd49eb 100644 --- a/tests/units/typing/tensor/test_embedding.py +++ b/tests/units/typing/tensor/test_embedding.py @@ -1,11 +1,21 @@ import numpy as np -from pydantic.tools import parse_obj_as +from pydantic.tools import parse_obj_as, schema_json_of +from docarray.document.io.json import orjson_dumps from docarray.typing import Embedding def test_proto_embedding(): - uri = parse_obj_as(Embedding, np.zeros((3, 224, 224))) + embedding = parse_obj_as(Embedding, np.zeros((3, 224, 224))) - uri._to_node_protobuf() + embedding._to_node_protobuf() + + +def test_json_schema(): + schema_json_of(Embedding) + + +def test_dump_json(): + tensor = parse_obj_as(Embedding, np.zeros((3, 224, 224))) + orjson_dumps(tensor) diff --git a/tests/units/typing/tensor/test_tensor.py b/tests/units/typing/tensor/test_tensor.py index 9bace9e8841..fbe5d72a50e 100644 --- a/tests/units/typing/tensor/test_tensor.py +++ b/tests/units/typing/tensor/test_tensor.py @@ -1,11 +1,49 @@ import numpy as np -from pydantic.tools import parse_obj_as +import orjson +from pydantic.tools import parse_obj_as, schema_json_of +from docarray.document.io.json import orjson_dumps from docarray.typing import Tensor def test_proto_tensor(): - uri = parse_obj_as(Tensor, np.zeros((3, 224, 224))) + tensor = parse_obj_as(Tensor, np.zeros((3, 224, 224))) - uri._to_node_protobuf() + tensor._to_node_protobuf() + + +def test_from_list(): + tensor = parse_obj_as(Tensor, [[0.0, 0.0], [0.0, 0.0]]) + + assert (tensor == np.zeros((2, 2))).all() + + +def test_json_schema(): + schema_json_of(Tensor) + + +def test_dump_json(): + tensor = parse_obj_as(Tensor, np.zeros((3, 224, 224))) + orjson_dumps(tensor) + + +def test_load_json(): + tensor = parse_obj_as(Tensor, np.zeros((2, 2))) + + json = orjson_dumps(tensor) + print(json) + print(type(json)) + new_tensor = orjson.loads(json) + + 
assert (new_tensor == tensor).all() + + +def test_unwrap(): + tensor = parse_obj_as(Tensor, np.zeros((3, 224, 224))) + ndarray = tensor.unwrap() + + assert not isinstance(ndarray, Tensor) + assert isinstance(ndarray, np.ndarray) + assert isinstance(tensor, Tensor) + assert (ndarray == np.zeros((3, 224, 224))).all() diff --git a/tests/units/typing/test_id.py b/tests/units/typing/test_id.py index 5c3476bc82a..5919b351209 100644 --- a/tests/units/typing/test_id.py +++ b/tests/units/typing/test_id.py @@ -1,8 +1,10 @@ from uuid import UUID import pytest +from pydantic import schema_json_of from pydantic.tools import parse_obj_as +from docarray.document.io.json import orjson_dumps from docarray.typing import ID @@ -14,3 +16,12 @@ def test_id_validation(id): parsed_id = parse_obj_as(ID, id) assert parsed_id == str(id) + + +def test_json_schema(): + schema_json_of(ID) + + +def test_dump_json(): + id = parse_obj_as(ID, 1234) + orjson_dumps(id) diff --git a/tests/units/typing/test_torch_tensor.py b/tests/units/typing/test_torch_tensor.py new file mode 100644 index 00000000000..7d3081f86e3 --- /dev/null +++ b/tests/units/typing/test_torch_tensor.py @@ -0,0 +1,34 @@ +import torch +from pydantic.tools import parse_obj_as, schema_json_of + +from docarray.document.io.json import orjson_dumps +from docarray.typing import TorchTensor + + +def test_proto_tensor(): + + tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224)) + + tensor._to_node_protobuf() + + +def test_json_schema(): + schema_json_of(TorchTensor) + + +def test_dump_json(): + tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224)) + orjson_dumps(tensor) + + +def test_unwrap(): + tensor = parse_obj_as(TorchTensor, torch.zeros(3, 224, 224)) + ndarray = tensor.unwrap() + + assert not isinstance(ndarray, TorchTensor) + assert isinstance(tensor, TorchTensor) + assert isinstance(ndarray, torch.Tensor) + + assert tensor.data_ptr() == ndarray.data_ptr() + + assert (ndarray == torch.zeros(3, 224, 224)).all() diff 
--git a/tests/units/typing/url/test_any_url.py b/tests/units/typing/url/test_any_url.py index ad593b58519..1cd4988437d 100644 --- a/tests/units/typing/url/test_any_url.py +++ b/tests/units/typing/url/test_any_url.py @@ -1,10 +1,20 @@ -from pydantic.tools import parse_obj_as +from pydantic.tools import parse_obj_as, schema_json_of -from docarray.typing import ImageUrl +from docarray.document.io.json import orjson_dumps +from docarray.typing import AnyUrl def test_proto_any_url(): - uri = parse_obj_as(ImageUrl, 'http://jina.ai/img.png') + uri = parse_obj_as(AnyUrl, 'http://jina.ai/img.png') uri._to_node_protobuf() + + +def test_json_schema(): + schema_json_of(AnyUrl) + + +def test_dump_json(): + url = parse_obj_as(AnyUrl, 'http://jina.ai/img.png') + orjson_dumps(url) diff --git a/tests/units/typing/url/test_image_url.py b/tests/units/typing/url/test_image_url.py index 22280398358..0b4a49f2cc1 100644 --- a/tests/units/typing/url/test_image_url.py +++ b/tests/units/typing/url/test_image_url.py @@ -4,8 +4,9 @@ import numpy as np import PIL import pytest -from pydantic.tools import parse_obj_as +from pydantic.tools import parse_obj_as, schema_json_of +from docarray.document.io.json import orjson_dumps from docarray.typing import ImageUrl CUR_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -36,6 +37,15 @@ def test_proto_image_url(): uri._to_node_protobuf() +def test_json_schema(): + schema_json_of(ImageUrl) + + +def test_dump_json(): + url = parse_obj_as(ImageUrl, 'http://jina.ai/img.png') + orjson_dumps(url) + + @pytest.mark.parametrize( 'image_format,path_to_img', [ diff --git a/tests/units/typing/url/test_text_url.py b/tests/units/typing/url/test_text_url.py index 2544358388f..21f1732df6a 100644 --- a/tests/units/typing/url/test_text_url.py +++ b/tests/units/typing/url/test_text_url.py @@ -2,8 +2,9 @@ import urllib import pytest -from pydantic import parse_obj_as +from pydantic import parse_obj_as, schema_json_of +from docarray.document.io.json import orjson_dumps 
from docarray.typing import TextUrl REMOTE_TXT = 'https://de.wikipedia.org/wiki/Brixen' @@ -43,3 +44,12 @@ def test_load_timeout(): _ = url.load(timeout=0.001) with pytest.raises(urllib.error.URLError): _ = url.load_to_bytes(timeout=0.001) + + +def test_json_schema(): + schema_json_of(TextUrl) + + +def test_dump_json(): + url = parse_obj_as(TextUrl, REMOTE_TXT) + orjson_dumps(url)