diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 251afc65a4a..02b250fad4e 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -1,7 +1,6 @@ import base64 import csv import io -import json import os import pathlib import pickle @@ -25,7 +24,10 @@ Union, ) +import orjson + from docarray.base_doc import AnyDoc, BaseDoc +from docarray.base_doc.io.json import orjson_dumps from docarray.helper import ( _access_path_dict_to_nested_dict, _all_access_paths_valid, @@ -41,6 +43,7 @@ from docarray.proto import DocumentArrayProto T = TypeVar('T', bound='IOMixinArray') +T_doc = TypeVar('T_doc', bound=BaseDoc) ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array'} SINGLE_PROTOCOLS = {'pickle', 'protobuf'} @@ -93,9 +96,9 @@ def __getitem__(self, item: slice): return self.content[item] -class IOMixinArray(Iterable[BaseDoc]): - - document_type: Type[BaseDoc] +class IOMixinArray(Iterable[T_doc]): + document_type: Type[T_doc] + _data: List[T_doc] @abstractmethod def __len__(self): @@ -251,7 +254,7 @@ def to_bytes( :return: the binary serialization in bytes or None if file_ctx is passed where to store """ - with (file_ctx or io.BytesIO()) as bf: + with file_ctx or io.BytesIO() as bf: self._write_bytes( bf=bf, protocol=protocol, @@ -318,14 +321,21 @@ def from_json( :param file: JSON object from where to deserialize a DocArray :return: the deserialized DocArray """ - json_docs = json.loads(file) - return cls([cls.document_type.parse_raw(v) for v in json_docs]) + json_docs = orjson.loads(file) + return cls([cls.document_type(**v) for v in json_docs]) - def to_json(self) -> str: - """Convert the object into a JSON string. Can be loaded via :meth:`.from_json`. + def to_json(self) -> bytes: + """Convert the object into JSON bytes. Can be loaded via :meth:`.from_json`. :return: JSON serialization of DocArray """ - return json.dumps([doc.json() for doc in self]) + return orjson_dumps(self._data) + + def _docarray_to_json_compatible(self) -> List[T_doc]: + """ + Convert itself into a json compatible object + :return: A list of documents + """ + return self._data @classmethod def from_csv( @@ -615,7 +625,7 @@ def _load_binary_stream( protocol: str = 'protobuf', compress: Optional[str] = None, show_progress: bool = False, - ) -> Generator['BaseDoc', None, None]: + ) -> Generator['T_doc', None, None]: """Yield `Document` objects from a binary file :param protocol: protocol to use. It can be 'pickle' or 'protobuf' @@ -672,7 +682,7 @@ def load_binary( compress: Optional[str] = None, show_progress: bool = False, streaming: bool = False, - ) -> Union[T, Generator['BaseDoc', None, None]]: + ) -> Union[T, Generator['T_doc', None, None]]: """Load array elements from a compressed binary file. :param file: File or filename or serialized bytes where the data is stored. diff --git a/docarray/base_doc/__init__.py b/docarray/base_doc/__init__.py index 941e76c542b..47e01c1c662 100644 --- a/docarray/base_doc/__init__.py +++ b/docarray/base_doc/__init__.py @@ -10,14 +10,14 @@ def __getattr__(name: str): - if name == 'DocResponse': + if name == 'DocArrayResponse': import_library('fastapi', raise_error=True) - from docarray.base_doc.doc_response import DocResponse + from docarray.base_doc.docarray_response import DocArrayResponse if name not in __all__: __all__.append(name) - return DocResponse + return DocArrayResponse else: raise ImportError( f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\'' diff --git a/docarray/base_doc/base_node.py b/docarray/base_doc/base_node.py index 726403313aa..7cbb76c9e98 100644 --- a/docarray/base_doc/base_node.py +++ b/docarray/base_doc/base_node.py @@ -1,9 +1,11 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar, Optional, Type if TYPE_CHECKING: from docarray.proto import NodeProto +T = TypeVar('T') + class BaseNode(ABC): """ @@ -11,6 +13,8 @@ class BaseNode(ABC): A Document itself is a DocumentNode as well as prebuilt type """ + _proto_type_name: Optional[str] = None + @abstractmethod def _to_node_protobuf(self) -> 'NodeProto': """Convert itself into a NodeProto message. This function should @@ -20,3 +24,14 @@ def _to_node_protobuf(self) -> 'NodeProto': :return: the nested item protobuf message """ ... + + @classmethod + @abstractmethod + def from_protobuf(cls: Type[T], pb_msg: T) -> T: + ... + + def _docarray_to_json_compatible(self): + """ + Convert itself into a json compatible object + """ + ... diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index c828a98f8b2..50abc6722a7 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -1,14 +1,15 @@ import os -from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Dict import orjson from pydantic import BaseModel, Field from rich.console import Console from docarray.base_doc.base_node import BaseNode -from docarray.base_doc.io.json import orjson_dumps, orjson_dumps_and_decode +from docarray.base_doc.io.json import orjson_dumps_and_decode from docarray.base_doc.mixins import IOMixin, UpdateMixin from docarray.typing import ID +from docarray.typing.tensor.abstract_tensor import AbstractTensor if TYPE_CHECKING: from docarray.array.stacked.column_storage import ColumnStorageView @@ -28,7 +29,10 @@ class BaseDoc(BaseModel, IOMixin, UpdateMixin, BaseNode): class Config: json_loads = orjson.loads json_dumps = orjson_dumps_and_decode - json_encoders = {dict: orjson_dumps} + # `DocArrayResponse` is able to handle tensors by itself. + # Therefore, we stop FastAPI from doing any transformations + # on tensors by setting an identity function as a custom encoder. + json_encoders = {AbstractTensor: lambda x: x} validate_assignment = True @@ -97,3 +101,10 @@ def __setattr__(self, field, value) -> None: for key, val in self.__dict__.items(): dict_ref[key] = val object.__setattr__(self, '__dict__', dict_ref) + + def _docarray_to_json_compatible(self) -> Dict: + """ + Convert itself into a json compatible object + :return: A dictionary of the BaseDoc object + """ + return self.dict() diff --git a/docarray/base_doc/doc_response.py b/docarray/base_doc/docarray_response.py similarity index 80% rename from docarray/base_doc/doc_response.py rename to docarray/base_doc/docarray_response.py index 7d2b7b76ef6..3e62cf64f9b 100644 --- a/docarray/base_doc/doc_response.py +++ b/docarray/base_doc/docarray_response.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Any +from docarray.base_doc.io.json import orjson_dumps from docarray.utils._internal.misc import import_library if TYPE_CHECKING: @@ -9,7 +10,7 @@ JSONResponse = fastapi.responses.JSONResponse -class DocResponse(JSONResponse): +class DocArrayResponse(JSONResponse): """ This is a custom Response class for FastAPI and starlette. This is needed to handle serialization of the Document types when using FastAPI @@ -26,7 +27,4 @@ async def create_item(doc: Text) -> Text: """ def render(self, content: Any) -> bytes: - if isinstance(content, bytes): - return content - else: - raise ValueError(f'{self.__class__} only work with json bytes content') + return orjson_dumps(content) diff --git a/docarray/base_doc/io/json.py b/docarray/base_doc/io/json.py index 1558659f061..27468b2b61c 100644 --- a/docarray/base_doc/io/json.py +++ b/docarray/base_doc/io/json.py @@ -1,8 +1,6 @@ import orjson from pydantic.json import ENCODERS_BY_TYPE -from docarray.typing.abstract_type import AbstractType - def _default_orjson(obj): """ @@ -10,8 +8,9 @@ def _default_orjson(obj): :param obj: :return: return a json compatible object """ + from docarray.base_doc import BaseNode - if isinstance(obj, AbstractType): + if isinstance(obj, BaseNode): return obj._docarray_to_json_compatible() else: for cls_, encoder in ENCODERS_BY_TYPE.items(): diff --git a/docarray/typing/abstract_type.py b/docarray/typing/abstract_type.py index fd73c93452e..3193116db08 100644 --- a/docarray/typing/abstract_type.py +++ b/docarray/typing/abstract_type.py @@ -1,20 +1,15 @@ from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar +from typing import Any, Type, TypeVar from pydantic import BaseConfig from pydantic.fields import ModelField from docarray.base_doc.base_node import BaseNode -if TYPE_CHECKING: - from docarray.proto import NodeProto - T = TypeVar('T') class AbstractType(BaseNode): - _proto_type_name: Optional[str] = None - @classmethod def __get_validators__(cls): yield cls.validate @@ -28,19 +23,3 @@ def validate( config: 'BaseConfig', ) -> T: ... - - @classmethod - @abstractmethod - def from_protobuf(cls: Type[T], pb_msg: T) -> T: - ... - - @abstractmethod - def _to_node_protobuf(self: T) -> 'NodeProto': - ... - - def _docarray_to_json_compatible(self): - """ - Convert itself into a json compatible object - :return: a representation of the tensor compatible with orjson - """ - return self diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 08aa0d014ae..f9814b429e4 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -305,4 +305,4 @@ def _docarray_to_json_compatible(self): Convert tensor into a json compatible object :return: a representation of the tensor compatible with orjson """ - ... + return self diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py index 5c5ed0bba60..438d2a86402 100644 --- a/tests/integrations/externals/test_fastapi.py +++ b/tests/integrations/externals/test_fastapi.py @@ -1,10 +1,12 @@ +from typing import List + import numpy as np import pytest from fastapi import FastAPI from httpx import AsyncClient -from docarray import BaseDoc -from docarray.base_doc import DocResponse +from docarray import BaseDoc, DocArray +from docarray.base_doc import DocArrayResponse from docarray.documents import ImageDoc, TextDoc from docarray.typing import NdArray @@ -22,8 +24,8 @@ class Mmdoc(BaseDoc): app = FastAPI() - @app.post("/doc/", response_model=Mmdoc, response_class=DocResponse) - async def create_item(doc: Mmdoc): + @app.post("/doc/", response_model=Mmdoc, response_class=DocArrayResponse) + async def create_item(doc: Mmdoc) -> Mmdoc: return doc async with AsyncClient(app=app, base_url="http://test") as ac: @@ -49,7 +51,7 @@ class OutputDoc(BaseDoc): app = FastAPI() - @app.post("/doc/", response_model=OutputDoc, response_class=DocResponse) + @app.post("/doc/", response_model=OutputDoc, response_class=DocArrayResponse) async def create_item(doc: InputDoc) -> OutputDoc: ## call my fancy model to generate the embeddings doc = OutputDoc( @@ -86,7 +88,7 @@ class OutputDoc(BaseDoc): app = FastAPI() - @app.post("/doc/", response_model=OutputDoc, response_class=DocResponse) + @app.post("/doc/", response_model=OutputDoc, response_class=DocArrayResponse) async def create_item(doc: InputDoc) -> OutputDoc: ## call my fancy model to generate the embeddings return OutputDoc( @@ -107,3 +109,29 @@ async def create_item(doc: InputDoc) -> OutputDoc: assert isinstance(doc, OutputDoc) assert doc.embedding_clip.shape == (100, 1) assert doc.embedding_bert.shape == (100, 1) + + +@pytest.mark.asyncio +async def test_docarray(): + doc = ImageDoc(tensor=np.zeros((3, 224, 224))) + docs = DocArray[ImageDoc]([doc, doc]) + + app = FastAPI() + + @app.post("/doc/", response_class=DocArrayResponse) + async def func(fastapi_docs: List[ImageDoc]) -> List[ImageDoc]: + docarray_docs = DocArray[ImageDoc].construct(fastapi_docs) + return list(docarray_docs) + + async with AsyncClient(app=app, base_url="http://test") as ac: + response = await ac.post("/doc/", data=docs.to_json()) + resp_doc = await ac.get("/docs") + resp_redoc = await ac.get("/redoc") + + assert response.status_code == 200 + assert resp_doc.status_code == 200 + assert resp_redoc.status_code == 200 + + docs = DocArray[ImageDoc].from_json(response.content.decode()) + assert len(docs) == 2 + assert docs[0].tensor.shape == (3, 224, 224)