Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docarray/base_document/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from docarray.base_document.any_document import AnyDocument
from docarray.base_document.base_node import BaseNode
from docarray.base_document.document import BaseDocument
from docarray.base_document.document_response import DocumentResponse

__all__ = ['AnyDocument', 'BaseDocument', 'BaseNode']
__all__ = ['AnyDocument', 'BaseDocument', 'BaseNode', 'DocumentResponse']
6 changes: 4 additions & 2 deletions docarray/base_document/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from docarray.base_document.abstract_document import AbstractDocument
from docarray.base_document.base_node import BaseNode
from docarray.base_document.io.json import orjson_dumps
from docarray.base_document.io.json import orjson_dumps, orjson_dumps_and_decode
from docarray.base_document.mixins import ProtoMixin
from docarray.typing import ID

Expand All @@ -20,7 +20,9 @@ class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode):

class Config:
json_loads = orjson.loads
json_dumps = orjson_dumps
json_dumps = orjson_dumps_and_decode
json_encoders = {dict: orjson_dumps}

validate_assignment = True

@classmethod
Expand Down
34 changes: 34 additions & 0 deletions docarray/base_document/document_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Any

try:
    from fastapi.responses import JSONResponse, Response
except ImportError:

    class NoImportResponse:
        """
        Placeholder used when the optional ``fastapi`` dependency is missing.

        It allows this module to be imported without fastapi installed; any
        attempt to instantiate it (or a subclass of it) raises ``ImportError``.
        """

        def __init__(self, *args, **kwargs):
            # The exception has to be raised, not just constructed, otherwise
            # the placeholder would silently behave like a usable response.
            raise ImportError('fastapi is not installed')

    Response = JSONResponse = NoImportResponse  # type: ignore


class DocumentResponse(JSONResponse):
    """
    Custom Response class for FastAPI and starlette, needed to handle the
    serialization of Document types when using FastAPI.

    The response body must already be JSON encoded as ``bytes``; it is sent
    as-is instead of being serialized a second time.

    EXAMPLE USAGE

    .. code-block:: python

        from docarray.documents import Text
        from docarray.base_document import DocumentResponse


        @app.post("/doc/", response_model=Text, response_class=DocumentResponse)
        async def create_item(doc: Text) -> Text:
            return doc
    """

    def render(self, content: Any) -> bytes:
        """
        Return the response body.

        :param content: the already serialized json payload
        :return: ``content`` unchanged
        :raises ValueError: if ``content`` is not a ``bytes`` object
        """
        if not isinstance(content, bytes):
            raise ValueError(f'{self.__class__} only work with json bytes content')
        return content
22 changes: 13 additions & 9 deletions docarray/base_document/io/json.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
import orjson

from docarray.typing.tensor.abstract_tensor import AbstractTensor


def _default_orjson(obj):
"""
default option for orjson dumps. It will call _to_json_compatible
from docarray typing object that expose such method.
default option for orjson dumps.
:param obj:
:return: return a json compatible object
"""

if getattr(obj, '_to_json_compatible'):
return obj._to_json_compatible()
if isinstance(obj, AbstractTensor):
return obj._docarray_to_json_compatible()
else:
return obj


def orjson_dumps(v, *, default=None):
# orjson.dumps returns bytes, to match standard json.dumps we need to decode
return orjson.dumps(
v, default=_default_orjson, option=orjson.OPT_SERIALIZE_NUMPY
).decode()
def orjson_dumps(v, *, default=None) -> bytes:
# dumps to bytes using orjson
return orjson.dumps(v, default=_default_orjson, option=orjson.OPT_SERIALIZE_NUMPY)


def orjson_dumps_and_decode(v, *, default=None) -> str:
    """
    Serialize ``v`` with ``orjson_dumps`` and decode the resulting bytes.

    :param v: the object to serialize
    :param default: forwarded unchanged to ``orjson_dumps``
    :return: the json document as a ``str``
    """
    # orjson produces bytes; decode so the result is a str like json.dumps returns
    encoded = orjson_dumps(v, default=default)
    return encoded.decode()
8 changes: 7 additions & 1 deletion docarray/typing/tensor/abstract_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def get_comp_backend() -> Type[AbstractComputationalBackend]:
"""The computational backend compatible with this tensor type."""
...

@abc.abstractmethod
def __getitem__(self, item):
"""Get a slice of this tensor."""
...
Expand All @@ -136,4 +135,11 @@ def to_protobuf(self) -> 'NdArrayProto':

def unwrap(self):
"""Return the native tensor object that this DocArray tensor wraps."""

@abc.abstractmethod
def _docarray_to_json_compatible(self):
"""
Convert tensor into a json compatible object
:return: a representation of the tensor compatible with orjson
"""
...
4 changes: 2 additions & 2 deletions docarray/typing/tensor/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
# this is needed to dump to json
field_schema.update(type='string', format='tensor')

def _to_json_compatible(self) -> np.ndarray:
def _docarray_to_json_compatible(self) -> np.ndarray:
"""
Convert tensor into a json compatible object
:return: a list representation of the tensor
:return: a representation of the tensor compatible with orjson
"""
return self.unwrap()

Expand Down
4 changes: 2 additions & 2 deletions docarray/typing/tensor/torch_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
# this is needed to dump to json
field_schema.update(type='string', format='tensor')

def _to_json_compatible(self) -> np.ndarray:
def _docarray_to_json_compatible(self) -> np.ndarray:
"""
Convert torchTensor into a json compatible object
:return: a list representation of the torch tensor
:return: a representation of the tensor compatible with orjson
"""
return self.numpy() ## might need to check device later

Expand Down
17 changes: 9 additions & 8 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ types-pillow = {version = "^9.3.0.1", optional = true }
trimesh = {version = "^3.17.1", optional = true}
typing-inspect = "^0.8.0"
types-requests = "^2.28.11.6"
fastapi = {version = "^0.87.0", optional = true }

[tool.poetry.extras]
common = ["protobuf"]
torch = ["torch"]
image = ["pillow", "types-pillow"]
mesh = ["trimesh"]
web = ["fastapi"]

[tool.poetry.dev-dependencies]
pytest = "^6.1"
Expand All @@ -35,7 +37,6 @@ isort = "^5.10.1"
ruff = "^0.0.165"

[tool.poetry.group.dev.dependencies]
fastapi = "^0.87.0"
uvicorn = "^0.19.0"
httpx = "^0.23.0"
pytest-asyncio = "^0.20.2"
Expand Down
40 changes: 24 additions & 16 deletions tests/integrations/document/test_to_json.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,51 @@
import numpy as np
import pytest
import torch

from docarray.base_document import BaseDocument
from docarray.base_document.io.json import orjson_dumps
from docarray.typing import AnyUrl, NdArray, TorchTensor


def test_to_json():
@pytest.fixture()
def doc_and_class():
class Mmdoc(BaseDocument):
img: NdArray
url: AnyUrl
txt: str
torch_tensor: TorchTensor

doc = Mmdoc(
img=np.zeros((3, 224, 224)),
img=np.zeros((10)),
url='http://doccaray.io',
txt='hello',
torch_tensor=torch.zeros(3, 224, 224),
torch_tensor=torch.zeros(10),
)
doc.json()
return doc, Mmdoc


def test_from_json():
class Mmdoc(BaseDocument):
img: NdArray
url: AnyUrl
txt: str
torch_tensor: TorchTensor
def test_to_json(doc_and_class):
    """Check that a document holding tensor fields serializes to a json string."""
    doc, _ = doc_and_class
    json_str = doc.json()
    # A bare call only proves "does not raise"; also pin the returned type.
    assert isinstance(json_str, str)

doc = Mmdoc(
img=np.zeros((2, 2)),
url='http://doccaray.io',
txt='hello',
torch_tensor=torch.zeros(3, 224, 224),
)

def test_from_json(doc_and_class):
    """Round-trip a document through ``json()`` / ``parse_raw`` and compare fields."""
    doc, Mmdoc = doc_and_class
    new_doc = Mmdoc.parse_raw(doc.json())

    for (field, field2) in zip(doc.dict().keys(), new_doc.dict().keys()):
        # Compare the original document against the re-parsed one; reading both
        # sides from ``doc`` would make the assertions trivially true.
        if field in ['torch_tensor', 'img']:
            # tensor equality is element-wise, so reduce it with .all()
            assert (getattr(doc, field) == getattr(new_doc, field2)).all()
        else:
            assert getattr(doc, field) == getattr(new_doc, field2)


def test_to_dict_to_json(doc_and_class):
    """Round-trip a document through ``dict()`` -> ``orjson_dumps`` -> ``parse_raw``."""
    doc, Mmdoc = doc_and_class
    new_doc = Mmdoc.parse_raw(orjson_dumps(doc.dict()))

    for (field, field2) in zip(doc.dict().keys(), new_doc.dict().keys()):
        # Compare the original document against the re-parsed one; reading both
        # sides from ``doc`` would make the assertions trivially true.
        if field in ['torch_tensor', 'img']:
            # tensor equality is element-wise, so reduce it with .all()
            assert (getattr(doc, field) == getattr(new_doc, field2)).all()
        else:
            assert getattr(doc, field) == getattr(new_doc, field2)
22 changes: 18 additions & 4 deletions tests/integrations/externals/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from httpx import AsyncClient

from docarray import BaseDocument
from docarray.base_document import DocumentResponse
from docarray.documents import Image, Text
from docarray.typing import NdArray

Expand All @@ -21,7 +22,7 @@ class Mmdoc(BaseDocument):

app = FastAPI()

@app.post("/doc/")
@app.post("/doc/", response_model=Mmdoc, response_class=DocumentResponse)
async def create_item(doc: Mmdoc):
return doc

Expand All @@ -48,12 +49,13 @@ class OutputDoc(BaseDocument):

app = FastAPI()

@app.post("/doc/", response_model=OutputDoc)
@app.post("/doc/", response_model=OutputDoc, response_class=DocumentResponse)
async def create_item(doc: InputDoc) -> OutputDoc:
## call my fancy model to generate the embeddings
return OutputDoc(
doc = OutputDoc(
embedding_clip=np.zeros((100, 1)), embedding_bert=np.zeros((100, 1))
)
return doc

async with AsyncClient(app=app, base_url="http://test") as ac:
response = await ac.post("/doc/", data=input_doc.json())
Expand All @@ -64,6 +66,12 @@ async def create_item(doc: InputDoc) -> OutputDoc:
assert resp_doc.status_code == 200
assert resp_redoc.status_code == 200

doc = OutputDoc.parse_raw(response.content.decode())

assert isinstance(doc, OutputDoc)
assert doc.embedding_clip.shape == (100, 1)
assert doc.embedding_bert.shape == (100, 1)


@pytest.mark.asyncio
async def test_sentence_to_embeddings():
Expand All @@ -78,7 +86,7 @@ class OutputDoc(BaseDocument):

app = FastAPI()

@app.post("/doc/", response_model=OutputDoc)
@app.post("/doc/", response_model=OutputDoc, response_class=DocumentResponse)
async def create_item(doc: InputDoc) -> OutputDoc:
## call my fancy model to generate the embeddings
return OutputDoc(
Expand All @@ -93,3 +101,9 @@ async def create_item(doc: InputDoc) -> OutputDoc:
assert response.status_code == 200
assert resp_doc.status_code == 200
assert resp_redoc.status_code == 200

doc = OutputDoc.parse_raw(response.content.decode())

assert isinstance(doc, OutputDoc)
assert doc.embedding_clip.shape == (100, 1)
assert doc.embedding_bert.shape == (100, 1)