Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions docarray/document/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

if TYPE_CHECKING:
from ..types import ArrayType, StructValueType, DocumentContentType
from .. import DocumentArray
from ..score import NamedScore


class Document(AllMixins, BaseDCType):
Expand Down
10 changes: 9 additions & 1 deletion docarray/document/mixins/pydantic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from collections import defaultdict
from typing import TYPE_CHECKING, Type

Expand Down Expand Up @@ -41,7 +42,6 @@ def from_pydantic_model(cls: Type['T'], model: 'BaseModel') -> 'T':
"""Build a Document object from a Pydantic model

:param model: the pydantic data model object that represents a Document
:param ndarray_as_list: if set to True, `embedding` and `tensor` are auto-casted to ndarray.
:return: a Document object
"""
from ... import Document
Expand All @@ -65,6 +65,14 @@ def from_pydantic_model(cls: Type['T'], model: 'BaseModel') -> 'T':
fields[f_name][k] = NamedScore(v)
elif f_name == 'embedding' or f_name == 'tensor':
fields[f_name] = np.array(value)
elif f_name == 'blob':
# NOTE: the blob ends up base64-encoded twice on the way in:
# the first encoding is the real one, done in `to_dict`/`to_json`,
# which turns the raw bytes into a base64 string; the second happens
# in `from_dict`/`from_json`, where that string is unavoidably treated
# as raw bytes and encoded once more.
# Consequently, we must apply `b64decode` twice here to recover the
# original bytes.
fields[f_name] = base64.b64decode(base64.b64decode(value))
else:
fields[f_name] = value

Expand Down
9 changes: 9 additions & 0 deletions docarray/document/pydantic_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from typing import Optional, List, Dict, Any, TYPE_CHECKING, Union

from pydantic import BaseModel, validator
Expand Down Expand Up @@ -43,6 +44,14 @@ class PydanticDocument(BaseModel):
_tensor2list = validator('tensor', allow_reuse=True)(_convert_ndarray_to_list)
_embedding2list = validator('embedding', allow_reuse=True)(_convert_ndarray_to_list)

@validator('blob')
def _blob2base64(cls, v):
    """Serialize the ``blob`` field: raw ``bytes`` become a base64 UTF-8 string.

    ``None`` is passed through unchanged; any non-bytes value is rejected.

    :param v: the candidate blob value.
    :return: base64-encoded string, or ``None`` if ``v`` is ``None``.
    :raises ValueError: if ``v`` is neither ``None`` nor ``bytes``.
    """
    if v is None:
        return None
    if not isinstance(v, bytes):
        raise ValueError('must be bytes')
    return base64.b64encode(v).decode('utf8')


PydanticDocument.update_forward_refs()

Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from collections import defaultdict
from typing import List, Optional

Expand Down Expand Up @@ -142,3 +143,19 @@ def test_tags_int_float_str_bool(tag_type, tag_value, protocol):
dd = d.to_dict(protocol=protocol)['tags']['hello'][-1]
assert dd == tag_value
assert isinstance(dd, tag_type)


@pytest.mark.parametrize(
    'blob', [None, b'123', bytes(Document()), bytes(bytearray(os.urandom(512 * 4)))]
)
@pytest.mark.parametrize('protocol', ['jsonschema', 'protobuf'])
@pytest.mark.parametrize('to_fn', ['dict', 'json'])
def test_to_from_with_blob(protocol, to_fn, blob):
    """Round-trip a Document's blob through to_/from_ dict/json and verify it survives."""
    original = Document(blob=blob)
    serialized = getattr(original, f'to_{to_fn}')(protocol=protocol)
    restored = getattr(Document, f'from_{to_fn}')(serialized, protocol=protocol)

    # the blob must survive serialization byte-for-byte
    assert original.blob == restored.blob
    # a non-empty blob must come back as real bytes, not a base64 string
    if original.blob:
        assert isinstance(restored.blob, bytes)