Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 131 additions & 19 deletions docarray/array/doc_vec/doc_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
overload,
)

import numpy as np
from pydantic import BaseConfig, parse_obj_as
from typing_inspect import typingGenericAlias

Expand All @@ -34,7 +35,12 @@
if TYPE_CHECKING:
from pydantic.fields import ModelField

from docarray.proto import DocVecProto
from docarray.proto import (
DocVecProto,
ListOfDocArrayProto,
ListOfDocVecProto,
NdArrayProto,
)

torch_available = is_torch_available()
if torch_available:
Expand All @@ -54,6 +60,56 @@
T = TypeVar('T', bound='DocVec')
IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]

NONE_NDARRAY_PROTO_SHAPE = (0,)
NONE_NDARRAY_PROTO_DTYPE = 'None'


def _none_ndarray_proto() -> 'NdArrayProto':
from docarray.proto import NdArrayProto

zeros_arr = parse_obj_as(NdArray, np.zeros(NONE_NDARRAY_PROTO_SHAPE))
nd_proto = NdArrayProto()
nd_proto.dense.buffer = zeros_arr.tobytes()
nd_proto.dense.ClearField('shape')
nd_proto.dense.shape.extend(list(zeros_arr.shape))
nd_proto.dense.dtype = NONE_NDARRAY_PROTO_DTYPE

return nd_proto


def _none_docvec_proto() -> 'DocVecProto':
from docarray.proto import DocVecProto

return DocVecProto()


def _none_list_of_docvec_proto() -> 'ListOfDocArrayProto':
from docarray.proto import ListOfDocVecProto

return ListOfDocVecProto()


def _is_none_ndarray_proto(proto: 'NdArrayProto') -> bool:
return (
proto.dense.shape == list(NONE_NDARRAY_PROTO_SHAPE)
and proto.dense.dtype == NONE_NDARRAY_PROTO_DTYPE
)


def _is_none_docvec_proto(proto: 'DocVecProto') -> bool:
return (
proto.tensor_columns == {}
and proto.doc_columns == {}
and proto.docs_vec_columns == {}
and proto.any_columns == {}
)


def _is_none_list_of_docvec_proto(proto: 'ListOfDocVecProto') -> bool:
from docarray.proto import ListOfDocVecProto

return isinstance(proto, ListOfDocVecProto) and len(proto.data) == 0


class DocVec(AnyDocArray[T_doc]):
"""
Expand Down Expand Up @@ -243,7 +299,7 @@ def _check_doc_field_not_none(field_name, doc):
elif issubclass(field_type, AnyDocArray):
if first_doc_is_none:
_verify_optional_field_of_docs(docs)
doc_columns[field_name] = None
docs_vec_columns[field_name] = None
else:
docs_list = list()
for doc in docs:
Expand Down Expand Up @@ -534,48 +590,104 @@ def __len__(self):

@classmethod
def from_protobuf(cls: Type[T], pb_msg: 'DocVecProto') -> T:
"""create a Document from a protobuf message"""
"""create a DocVec from a protobuf message"""

tensor_columns: Dict[str, Optional[AbstractTensor]] = {}
doc_columns: Dict[str, Optional['DocVec']] = {}
docs_vec_columns: Dict[str, Optional[ListAdvancedIndexing['DocVec']]] = {}
any_columns: Dict[str, ListAdvancedIndexing] = {}

for tens_col_name, tens_col_proto in pb_msg.tensor_columns.items():
if _is_none_ndarray_proto(tens_col_proto):
# handle values that were None before serialization
tensor_columns[tens_col_name] = None
else:
# TODO(johannes): handle torch, tf, numpy
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i will do this in a separate PR, it might require a proto change (but not sure)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that it require a proto change. I think here you should look at the tensor type of this field in the doc type and use it to load the proto

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the same way you do it for doc_columns

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's discuss that in a separate PR

tensor_columns[tens_col_name] = NdArray.from_protobuf(tens_col_proto)

for doc_col_name, doc_col_proto in pb_msg.doc_columns.items():
if _is_none_docvec_proto(doc_col_proto):
# handle values that were None before serialization
doc_columns[doc_col_name] = None
else:
col_doc_type: Type = cls.doc_type._get_field_type(doc_col_name)
doc_columns[doc_col_name] = DocVec.__class_getitem__(
col_doc_type
).from_protobuf(doc_col_proto)

for docs_vec_col_name, docs_vec_col_proto in pb_msg.docs_vec_columns.items():
vec_list: Optional[ListAdvancedIndexing]
if _is_none_list_of_docvec_proto(docs_vec_col_proto):
# handle values that were None before serialization
vec_list = None
else:
vec_list = ListAdvancedIndexing()
for doc_list_proto in docs_vec_col_proto.data:
col_doc_type = cls.doc_type._get_field_type(
docs_vec_col_name
).doc_type
vec_list.append(
DocVec.__class_getitem__(col_doc_type).from_protobuf(
doc_list_proto
)
)
docs_vec_columns[docs_vec_col_name] = vec_list

for any_col_name, any_col_proto in pb_msg.any_columns.items():
any_column: ListAdvancedIndexing = ListAdvancedIndexing()
for node_proto in any_col_proto.data:
content = cls.doc_type._get_content_from_node_proto(
node_proto, any_col_name
)
any_column.append(content)
any_columns[any_col_name] = any_column

storage = ColumnStorage(
pb_msg.tensor_columns,
pb_msg.doc_columns,
pb_msg.docs_vec_columns,
pb_msg.any_columns,
tensor_columns=tensor_columns,
doc_columns=doc_columns,
docs_vec_columns=docs_vec_columns,
any_columns=any_columns,
)

return cls.from_columns_storage(storage)

def to_protobuf(self) -> 'DocVecProto':
"""Convert DocVec into a Protobuf message"""
from docarray.proto import (
DocListProto,
DocVecProto,
ListOfAnyProto,
ListOfDocArrayProto,
ListOfDocVecProto,
NdArrayProto,
)

da_proto = DocListProto()
for doc in self:
da_proto.docs.append(doc.to_protobuf())

doc_columns_proto: Dict[str, DocVecProto] = dict()
tensor_columns_proto: Dict[str, NdArrayProto] = dict()
da_columns_proto: Dict[str, ListOfDocArrayProto] = dict()
any_columns_proto: Dict[str, ListOfAnyProto] = dict()

for field, col_doc in self._storage.doc_columns.items():
doc_columns_proto[field] = (
col_doc.to_protobuf() if col_doc is not None else None
)
if col_doc is None:
# put dummy empty DocVecProto for serialization
doc_columns_proto[field] = _none_docvec_proto()
else:
doc_columns_proto[field] = col_doc.to_protobuf()
for field, col_tens in self._storage.tensor_columns.items():
tensor_columns_proto[field] = (
col_tens.to_protobuf() if col_tens is not None else None
)
if col_tens is None:
# put dummy empty NdArrayProto for serialization
tensor_columns_proto[field] = _none_ndarray_proto()
else:
tensor_columns_proto[field] = (
col_tens.to_protobuf() if col_tens is not None else None
)
for field, col_da in self._storage.docs_vec_columns.items():
list_proto = ListOfDocArrayProto()
list_proto = ListOfDocVecProto()
if col_da:
for docs in col_da:
list_proto.data.append(docs.to_protobuf())
else:
# put dummy empty ListOfDocVecProto for serialization
list_proto = _none_list_of_docvec_proto()
da_columns_proto[field] = list_proto
for field, col_any in self._storage.any_columns.items():
list_proto = ListOfAnyProto()
Expand Down
3 changes: 3 additions & 0 deletions docarray/proto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DocVecProto,
ListOfAnyProto,
ListOfDocArrayProto,
ListOfDocVecProto,
NdArrayProto,
NodeProto,
)
Expand All @@ -28,6 +29,7 @@
DocVecProto,
ListOfAnyProto,
ListOfDocArrayProto,
ListOfDocVecProto,
NdArrayProto,
NodeProto,
)
Expand All @@ -40,6 +42,7 @@
'DocVecProto',
'DocListProto',
'ListOfDocArrayProto',
'ListOfDocVecProto',
'ListOfAnyProto',
'DictOfAnyProto',
]
6 changes: 5 additions & 1 deletion docarray/proto/docarray.proto
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,13 @@ message ListOfDocArrayProto {
repeated DocListProto data = 1;
}

message ListOfDocVecProto {
repeated DocVecProto data = 1;
}

message DocVecProto{
map<string, NdArrayProto> tensor_columns = 1; // a dict of document columns
map<string, DocVecProto> doc_columns = 2; // a dict of tensor columns
map<string, ListOfDocArrayProto> docs_vec_columns = 3; // a dict of document array columns
map<string, ListOfDocVecProto> docs_vec_columns = 3; // a dict of document array columns
map<string, ListOfAnyProto> any_columns = 4; // a dict of any columns. Used for the rest of the data
}
24 changes: 13 additions & 11 deletions docarray/proto/pb/docarray_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading