diff --git a/docarray/base_document/__init__.py b/docarray/base_document/__init__.py index d362827bb6f..f7ce8d3d7ab 100644 --- a/docarray/base_document/__init__.py +++ b/docarray/base_document/__init__.py @@ -1,5 +1,6 @@ from docarray.base_document.any_document import AnyDocument from docarray.base_document.base_node import BaseNode from docarray.base_document.document import BaseDocument +from docarray.base_document.document_response import DocumentResponse -__all__ = ['AnyDocument', 'BaseDocument', 'BaseNode'] +__all__ = ['AnyDocument', 'BaseDocument', 'BaseNode', 'DocumentResponse'] diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index 89d3eb47111..a985cd24e32 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -6,7 +6,7 @@ from docarray.base_document.abstract_document import AbstractDocument from docarray.base_document.base_node import BaseNode -from docarray.base_document.io.json import orjson_dumps +from docarray.base_document.io.json import orjson_dumps, orjson_dumps_and_decode from docarray.base_document.mixins import ProtoMixin from docarray.typing import ID @@ -20,7 +20,9 @@ class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): class Config: json_loads = orjson.loads - json_dumps = orjson_dumps + json_dumps = orjson_dumps_and_decode + json_encoders = {dict: orjson_dumps} + validate_assignment = True @classmethod diff --git a/docarray/base_document/document_response.py b/docarray/base_document/document_response.py new file mode 100644 index 00000000000..cb756c32412 --- /dev/null +++ b/docarray/base_document/document_response.py @@ -0,0 +1,34 @@ +from typing import Any + +try: + from fastapi.responses import JSONResponse, Response +except ImportError: + + class NoImportResponse: + def __init__(self, *args, **kwargs): + raise ImportError('fastapi is not installed') + + Response = JSONResponse = NoImportResponse # type: ignore + + +class DocumentResponse(JSONResponse): + 
"""
    This is a custom Response class for FastAPI and starlette. This is needed
    to handle serialization of the Document types when using FastAPI

    EXAMPLE USAGE
    .. code-block:: python
        from docarray.documents import Text
        from docarray.base_document import DocumentResponse


        @app.post("/doc/", response_model=Text, response_class=DocumentResponse)
        async def create_item(doc: Text) -> Text:
            return doc
    """

    def render(self, content: Any) -> bytes:
        if isinstance(content, bytes):
            return content
        else:
            raise ValueError(f'{self.__class__} only works with json bytes content') diff --git a/docarray/base_document/io/json.py b/docarray/base_document/io/json.py index 16e875fa359..3d7809ae11a 100644 --- a/docarray/base_document/io/json.py +++ b/docarray/base_document/io/json.py @@ -1,22 +1,26 @@ import orjson +from docarray.typing.tensor.abstract_tensor import AbstractTensor + def _default_orjson(obj): """ - default option for orjson dumps. It will call _to_json_compatible - from docarray typing object that expose such method. + default option for orjson dumps. 
:param obj: :return: return a json compatible object """ - if getattr(obj, '_to_json_compatible'): - return obj._to_json_compatible() + if isinstance(obj, AbstractTensor): + return obj._docarray_to_json_compatible() else: return obj -def orjson_dumps(v, *, default=None): - # orjson.dumps returns bytes, to match standard json.dumps we need to decode - return orjson.dumps( - v, default=_default_orjson, option=orjson.OPT_SERIALIZE_NUMPY - ).decode() +def orjson_dumps(v, *, default=None) -> bytes: + # dumps to bytes using orjson + return orjson.dumps(v, default=_default_orjson, option=orjson.OPT_SERIALIZE_NUMPY) + + +def orjson_dumps_and_decode(v, *, default=None) -> str: + # dumps to str: serialize with orjson, then decode the bytes + return orjson_dumps(v, default=default).decode() diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 29bd773a9fb..a877861aa4e 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -120,7 +120,6 @@ def get_comp_backend() -> Type[AbstractComputationalBackend]: """The computational backend compatible with this tensor type.""" ... - @abc.abstractmethod def __getitem__(self, item): """Get a slice of this tensor.""" ... @@ -136,4 +135,11 @@ def to_protobuf(self) -> 'NdArrayProto': def unwrap(self): """Return the native tensor object that this DocArray tensor wraps.""" + + @abc.abstractmethod + def _docarray_to_json_compatible(self): + """ + Convert tensor into a json compatible object + :return: a representation of the tensor compatible with orjson + """ ... 
diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index d4dce8c707c..00f7109f966 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -130,10 +130,10 @@ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: # this is needed to dump to json field_schema.update(type='string', format='tensor') - def _to_json_compatible(self) -> np.ndarray: + def _docarray_to_json_compatible(self) -> np.ndarray: """ Convert tensor into a json compatible object - :return: a list representation of the tensor + :return: a representation of the tensor compatible with orjson """ return self.unwrap() diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 236ea1a09cb..07cefc25947 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -126,10 +126,10 @@ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: # this is needed to dump to json field_schema.update(type='string', format='tensor') - def _to_json_compatible(self) -> np.ndarray: + def _docarray_to_json_compatible(self) -> np.ndarray: """ Convert torchTensor into a json compatible object - :return: a list representation of the torch tensor + :return: a representation of the tensor compatible with orjson """ return self.numpy() ## might need to check device later diff --git a/poetry.lock b/poetry.lock index 5996317fe67..bddaefa8750 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,7 +2,7 @@ name = "anyio" version = "3.6.2" description = "High level compatibility layer for multiple asynchronous event loop implementations" -category = "dev" +category = "main" optional = false python-versions = ">=3.6.2" @@ -274,8 +274,8 @@ tests = ["asttokens", "littleutils", "pytest", "rich"] name = "fastapi" version = "0.87.0" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" -category = "dev" -optional = false +category = 
"main" +optional = true python-versions = ">=3.7" [package.dependencies] @@ -372,7 +372,7 @@ license = ["ukkonen"] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" -category = "dev" +category = "main" optional = false python-versions = ">=3.5" @@ -1366,7 +1366,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" name = "sniffio" version = "1.3.0" description = "Sniff out which async library your code is running under" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -1398,8 +1398,8 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] name = "starlette" version = "0.21.0" description = "The little ASGI library that shines." -category = "dev" -optional = false +category = "main" +optional = true python-versions = ">=3.7" [package.dependencies] @@ -1668,11 +1668,12 @@ common = ["protobuf"] image = ["pillow", "types-pillow"] mesh = ["trimesh"] torch = ["torch"] +web = ["fastapi"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "b1aa40aea6ec7f56a8c3b511fd2ce96ed217c6fc81d6f8dd931e519cc0774154" +content-hash = "e9505149fb25b56e7cbccfa923e71070f783ec35fc6b43f00564c6974eab3eae" [metadata.files] anyio = [ diff --git a/pyproject.toml b/pyproject.toml index 1d29532b696..66e16f420a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,12 +17,14 @@ types-pillow = {version = "^9.3.0.1", optional = true } trimesh = {version = "^3.17.1", optional = true} typing-inspect = "^0.8.0" types-requests = "^2.28.11.6" +fastapi = {version = "^0.87.0", optional = true } [tool.poetry.extras] common = ["protobuf"] torch = ["torch"] image = ["pillow", "types-pillow"] mesh = ["trimesh"] +web = ["fastapi"] [tool.poetry.dev-dependencies] pytest = "^6.1" @@ -35,7 +37,6 @@ isort = "^5.10.1" ruff = "^0.0.165" [tool.poetry.group.dev.dependencies] -fastapi = "^0.87.0" uvicorn = "^0.19.0" httpx = "^0.23.0" pytest-asyncio = "^0.20.2" diff --git 
a/tests/integrations/document/test_to_json.py b/tests/integrations/document/test_to_json.py index 4008e4642b8..72935bfcfda 100644 --- a/tests/integrations/document/test_to_json.py +++ b/tests/integrations/document/test_to_json.py @@ -1,11 +1,14 @@ import numpy as np +import pytest import torch from docarray.base_document import BaseDocument +from docarray.base_document.io.json import orjson_dumps from docarray.typing import AnyUrl, NdArray, TorchTensor -def test_to_json(): +@pytest.fixture() +def doc_and_class(): class Mmdoc(BaseDocument): img: NdArray url: AnyUrl @@ -13,27 +16,21 @@ class Mmdoc(BaseDocument): torch_tensor: TorchTensor doc = Mmdoc( - img=np.zeros((3, 224, 224)), + img=np.zeros((10)), url='http://doccaray.io', txt='hello', - torch_tensor=torch.zeros(3, 224, 224), + torch_tensor=torch.zeros(10), ) - doc.json() + return doc, Mmdoc -def test_from_json(): - class Mmdoc(BaseDocument): - img: NdArray - url: AnyUrl - txt: str - torch_tensor: TorchTensor +def test_to_json(doc_and_class): + doc, _ = doc_and_class + doc.json() - doc = Mmdoc( - img=np.zeros((2, 2)), - url='http://doccaray.io', - txt='hello', - torch_tensor=torch.zeros(3, 224, 224), - ) + +def test_from_json(doc_and_class): + doc, Mmdoc = doc_and_class new_doc = Mmdoc.parse_raw(doc.json()) for (field, field2) in zip(doc.dict().keys(), new_doc.dict().keys()): @@ -41,3 +38,14 @@ class Mmdoc(BaseDocument): assert (getattr(doc, field) == getattr(doc, field2)).all() else: assert getattr(doc, field) == getattr(doc, field2) + + +def test_to_dict_to_json(doc_and_class): + doc, Mmdoc = doc_and_class + new_doc = Mmdoc.parse_raw(orjson_dumps(doc.dict())) + + for (field, field2) in zip(doc.dict().keys(), new_doc.dict().keys()): + if field in ['torch_tensor', 'img']: + assert (getattr(doc, field) == getattr(new_doc, field2)).all() + else: + assert getattr(doc, field) == getattr(new_doc, field2) diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py index 
3a99d56dd69..f74f8519a9c 100644 --- a/tests/integrations/externals/test_fastapi.py +++ b/tests/integrations/externals/test_fastapi.py @@ -4,6 +4,7 @@ from httpx import AsyncClient from docarray import BaseDocument +from docarray.base_document import DocumentResponse from docarray.documents import Image, Text from docarray.typing import NdArray @@ -21,7 +22,7 @@ class Mmdoc(BaseDocument): app = FastAPI() - @app.post("/doc/") + @app.post("/doc/", response_model=Mmdoc, response_class=DocumentResponse) async def create_item(doc: Mmdoc): return doc @@ -48,12 +49,13 @@ class OutputDoc(BaseDocument): app = FastAPI() - @app.post("/doc/", response_model=OutputDoc) + @app.post("/doc/", response_model=OutputDoc, response_class=DocumentResponse) async def create_item(doc: InputDoc) -> OutputDoc: ## call my fancy model to generate the embeddings - return OutputDoc( + doc = OutputDoc( embedding_clip=np.zeros((100, 1)), embedding_bert=np.zeros((100, 1)) ) + return doc async with AsyncClient(app=app, base_url="http://test") as ac: response = await ac.post("/doc/", data=input_doc.json()) @@ -64,6 +66,12 @@ async def create_item(doc: InputDoc) -> OutputDoc: assert resp_doc.status_code == 200 assert resp_redoc.status_code == 200 + doc = OutputDoc.parse_raw(response.content.decode()) + + assert isinstance(doc, OutputDoc) + assert doc.embedding_clip.shape == (100, 1) + assert doc.embedding_bert.shape == (100, 1) + @pytest.mark.asyncio async def test_sentence_to_embeddings(): @@ -78,7 +86,7 @@ class OutputDoc(BaseDocument): app = FastAPI() - @app.post("/doc/", response_model=OutputDoc) + @app.post("/doc/", response_model=OutputDoc, response_class=DocumentResponse) async def create_item(doc: InputDoc) -> OutputDoc: ## call my fancy model to generate the embeddings return OutputDoc( @@ -93,3 +101,9 @@ async def create_item(doc: InputDoc) -> OutputDoc: assert response.status_code == 200 assert resp_doc.status_code == 200 assert resp_redoc.status_code == 200 + + doc = 
OutputDoc.parse_raw(response.content.decode()) + + assert isinstance(doc, OutputDoc) + assert doc.embedding_clip.shape == (100, 1) + assert doc.embedding_bert.shape == (100, 1)