Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions docarray/array/array/io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import base64
import csv
import io
import json
import os
import pathlib
import pickle
Expand All @@ -25,7 +24,10 @@
Union,
)

import orjson

from docarray.base_doc import AnyDoc, BaseDoc
from docarray.base_doc.io.json import orjson_dumps
from docarray.helper import (
_access_path_dict_to_nested_dict,
_all_access_paths_valid,
Expand All @@ -41,6 +43,7 @@
from docarray.proto import DocumentArrayProto

T = TypeVar('T', bound='IOMixinArray')
T_doc = TypeVar('T_doc', bound=BaseDoc)

ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array'}
SINGLE_PROTOCOLS = {'pickle', 'protobuf'}
Expand Down Expand Up @@ -93,9 +96,9 @@ def __getitem__(self, item: slice):
return self.content[item]


class IOMixinArray(Iterable[BaseDoc]):

document_type: Type[BaseDoc]
class IOMixinArray(Iterable[T_doc]):
document_type: Type[T_doc]
_data: List[T_doc]

@abstractmethod
def __len__(self):
Expand Down Expand Up @@ -251,7 +254,7 @@ def to_bytes(
:return: the binary serialization in bytes or None if file_ctx is passed where to store
"""

with (file_ctx or io.BytesIO()) as bf:
with file_ctx or io.BytesIO() as bf:
self._write_bytes(
bf=bf,
protocol=protocol,
Expand Down Expand Up @@ -318,14 +321,21 @@ def from_json(
:param file: JSON object from where to deserialize a DocArray
:return: the deserialized DocArray
"""
json_docs = json.loads(file)
return cls([cls.document_type.parse_raw(v) for v in json_docs])
json_docs = orjson.loads(file)
return cls([cls.document_type(**v) for v in json_docs])

def to_json(self) -> str:
"""Convert the object into a JSON string. Can be loaded via :meth:`.from_json`.
def to_json(self) -> bytes:
"""Convert the object into JSON bytes. Can be loaded via :meth:`.from_json`.
:return: JSON serialization of DocArray
"""
return json.dumps([doc.json() for doc in self])
return orjson_dumps(self._data)

def _docarray_to_json_compatible(self) -> List[T_doc]:
"""
Convert itself into a json compatible object
:return: A list of documents
"""
return self._data

@classmethod
def from_csv(
Expand Down Expand Up @@ -615,7 +625,7 @@ def _load_binary_stream(
protocol: str = 'protobuf',
compress: Optional[str] = None,
show_progress: bool = False,
) -> Generator['BaseDoc', None, None]:
) -> Generator['T_doc', None, None]:
"""Yield `Document` objects from a binary file

:param protocol: protocol to use. It can be 'pickle' or 'protobuf'
Expand Down Expand Up @@ -672,7 +682,7 @@ def load_binary(
compress: Optional[str] = None,
show_progress: bool = False,
streaming: bool = False,
) -> Union[T, Generator['BaseDoc', None, None]]:
) -> Union[T, Generator['T_doc', None, None]]:
"""Load array elements from a compressed binary file.

:param file: File or filename or serialized bytes where the data is stored.
Expand Down
6 changes: 3 additions & 3 deletions docarray/base_doc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@


def __getattr__(name: str):
if name == 'DocResponse':
if name == 'DocArrayResponse':
import_library('fastapi', raise_error=True)
from docarray.base_doc.doc_response import DocResponse
from docarray.base_doc.docarray_response import DocArrayResponse

if name not in __all__:
__all__.append(name)

return DocResponse
return DocArrayResponse
else:
raise ImportError(
f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\''
Expand Down
17 changes: 16 additions & 1 deletion docarray/base_doc/base_node.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar, Optional, Type

if TYPE_CHECKING:
from docarray.proto import NodeProto

T = TypeVar('T')


class BaseNode(ABC):
"""
A DocumentNode is an object than can be nested inside a Document.
A Document itself is a DocumentNode as well as prebuilt type
"""

_proto_type_name: Optional[str] = None

@abstractmethod
def _to_node_protobuf(self) -> 'NodeProto':
"""Convert itself into a NodeProto message. This function should
Expand All @@ -20,3 +24,14 @@ def _to_node_protobuf(self) -> 'NodeProto':
:return: the nested item protobuf message
"""
...

@classmethod
@abstractmethod
def from_protobuf(cls: Type[T], pb_msg: T) -> T:
...

def _docarray_to_json_compatible(self):
"""
Convert itself into a json compatible object
"""
...
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should not it still return self by default ?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or maybe not

Copy link
Copy Markdown
Contributor Author

@jupyterjazz jupyterjazz Mar 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to leave this abstract and define for each class that uses, otherwise we say self by default but override it in both IOmixin and BaseDoc class

17 changes: 14 additions & 3 deletions docarray/base_doc/doc.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Dict

import orjson
from pydantic import BaseModel, Field
from rich.console import Console

from docarray.base_doc.base_node import BaseNode
from docarray.base_doc.io.json import orjson_dumps, orjson_dumps_and_decode
from docarray.base_doc.io.json import orjson_dumps_and_decode
from docarray.base_doc.mixins import IOMixin, UpdateMixin
from docarray.typing import ID
from docarray.typing.tensor.abstract_tensor import AbstractTensor

if TYPE_CHECKING:
from docarray.array.stacked.column_storage import ColumnStorageView
Expand All @@ -28,7 +29,10 @@ class BaseDoc(BaseModel, IOMixin, UpdateMixin, BaseNode):
class Config:
json_loads = orjson.loads
json_dumps = orjson_dumps_and_decode
json_encoders = {dict: orjson_dumps}
# `DocArrayResponse` is able to handle tensors by itself.
# Therefore, we stop FastAPI from doing any transformations
# on tensors by setting an identity function as a custom encoder.
json_encoders = {AbstractTensor: lambda x: x}

validate_assignment = True

Expand Down Expand Up @@ -97,3 +101,10 @@ def __setattr__(self, field, value) -> None:
for key, val in self.__dict__.items():
dict_ref[key] = val
object.__setattr__(self, '__dict__', dict_ref)

def _docarray_to_json_compatible(self) -> Dict:
"""
Convert itself into a json compatible object
:return: A dictionary of the BaseDoc object
"""
return self.dict()
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Any

from docarray.base_doc.io.json import orjson_dumps
from docarray.utils._internal.misc import import_library

if TYPE_CHECKING:
Expand All @@ -9,7 +10,7 @@
JSONResponse = fastapi.responses.JSONResponse


class DocResponse(JSONResponse):
class DocArrayResponse(JSONResponse):
"""
This is a custom Response class for FastAPI and starlette. This is needed
to handle serialization of the Document types when using FastAPI
Expand All @@ -26,7 +27,4 @@ async def create_item(doc: Text) -> Text:
"""

def render(self, content: Any) -> bytes:
if isinstance(content, bytes):
return content
else:
raise ValueError(f'{self.__class__} only work with json bytes content')
return orjson_dumps(content)
5 changes: 2 additions & 3 deletions docarray/base_doc/io/json.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import orjson
from pydantic.json import ENCODERS_BY_TYPE

from docarray.typing.abstract_type import AbstractType


def _default_orjson(obj):
"""
default option for orjson dumps.
:param obj:
:return: return a json compatible object
"""
from docarray.base_doc import BaseNode

if isinstance(obj, AbstractType):
if isinstance(obj, BaseNode):
return obj._docarray_to_json_compatible()
else:
for cls_, encoder in ENCODERS_BY_TYPE.items():
Expand Down
23 changes: 1 addition & 22 deletions docarray/typing/abstract_type.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar
from typing import Any, Type, TypeVar

from pydantic import BaseConfig
from pydantic.fields import ModelField

from docarray.base_doc.base_node import BaseNode

if TYPE_CHECKING:
from docarray.proto import NodeProto

T = TypeVar('T')


class AbstractType(BaseNode):
_proto_type_name: Optional[str] = None

@classmethod
def __get_validators__(cls):
yield cls.validate
Expand All @@ -28,19 +23,3 @@ def validate(
config: 'BaseConfig',
) -> T:
...

@classmethod
@abstractmethod
def from_protobuf(cls: Type[T], pb_msg: T) -> T:
...

@abstractmethod
def _to_node_protobuf(self: T) -> 'NodeProto':
...

def _docarray_to_json_compatible(self):
"""
Convert itself into a json compatible object
:return: a representation of the tensor compatible with orjson
"""
return self
2 changes: 1 addition & 1 deletion docarray/typing/tensor/abstract_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,4 @@ def _docarray_to_json_compatible(self):
Convert tensor into a json compatible object
:return: a representation of the tensor compatible with orjson
"""
...
return self
40 changes: 34 additions & 6 deletions tests/integrations/externals/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List

import numpy as np
import pytest
from fastapi import FastAPI
from httpx import AsyncClient

from docarray import BaseDoc
from docarray.base_doc import DocResponse
from docarray import BaseDoc, DocArray
from docarray.base_doc import DocArrayResponse
from docarray.documents import ImageDoc, TextDoc
from docarray.typing import NdArray

Expand All @@ -22,8 +24,8 @@ class Mmdoc(BaseDoc):

app = FastAPI()

@app.post("/doc/", response_model=Mmdoc, response_class=DocResponse)
async def create_item(doc: Mmdoc):
@app.post("/doc/", response_model=Mmdoc, response_class=DocArrayResponse)
async def create_item(doc: Mmdoc) -> Mmdoc:
return doc

async with AsyncClient(app=app, base_url="http://test") as ac:
Expand All @@ -49,7 +51,7 @@ class OutputDoc(BaseDoc):

app = FastAPI()

@app.post("/doc/", response_model=OutputDoc, response_class=DocResponse)
@app.post("/doc/", response_model=OutputDoc, response_class=DocArrayResponse)
async def create_item(doc: InputDoc) -> OutputDoc:
## call my fancy model to generate the embeddings
doc = OutputDoc(
Expand Down Expand Up @@ -86,7 +88,7 @@ class OutputDoc(BaseDoc):

app = FastAPI()

@app.post("/doc/", response_model=OutputDoc, response_class=DocResponse)
@app.post("/doc/", response_model=OutputDoc, response_class=DocArrayResponse)
async def create_item(doc: InputDoc) -> OutputDoc:
## call my fancy model to generate the embeddings
return OutputDoc(
Expand All @@ -107,3 +109,29 @@ async def create_item(doc: InputDoc) -> OutputDoc:
assert isinstance(doc, OutputDoc)
assert doc.embedding_clip.shape == (100, 1)
assert doc.embedding_bert.shape == (100, 1)


@pytest.mark.asyncio
async def test_docarray():
doc = ImageDoc(tensor=np.zeros((3, 224, 224)))
docs = DocArray[ImageDoc]([doc, doc])

app = FastAPI()

@app.post("/doc/", response_class=DocArrayResponse)
async def func(fastapi_docs: List[ImageDoc]) -> List[ImageDoc]:
docarray_docs = DocArray[ImageDoc].construct(fastapi_docs)
return list(docarray_docs)

async with AsyncClient(app=app, base_url="http://test") as ac:
response = await ac.post("/doc/", data=docs.to_json())
resp_doc = await ac.get("/docs")
resp_redoc = await ac.get("/redoc")

assert response.status_code == 200
assert resp_doc.status_code == 200
assert resp_redoc.status_code == 200

docs = DocArray[ImageDoc].from_json(response.content.decode())
assert len(docs) == 2
assert docs[0].tensor.shape == (3, 224, 224)