Skip to content

Commit 21e107b

Browse files
Joan Fontanals authored
fix: fix issue serializing and deserializing complex schemas (docarray#1836)
Signed-off-by: Joan Martinez <[email protected]>
1 parent 3cfa0b8 commit 21e107b

File tree

4 files changed: +112 −10 lines changed

docarray/base_doc/mixins/io.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -285,7 +285,6 @@ def _get_content_from_node_proto(
285285
)
286286

287287
return_field: Any
288-
289288
if docarray_type in content_type_dict:
290289
return_field = content_type_dict[docarray_type].from_protobuf(
291290
getattr(value, content_key)
@@ -308,13 +307,18 @@ def _get_content_from_node_proto(
308307
f'{field_type} is not supported for proto deserialization'
309308
)
310309
elif content_key == 'doc_array':
311-
if field_name is None:
310+
if field_type is not None and field_name is None:
311+
return_field = field_type.from_protobuf(getattr(value, content_key))
312+
elif field_name is not None:
313+
return_field = cls._get_field_annotation_array(
314+
field_name
315+
).from_protobuf(
316+
getattr(value, content_key)
317+
) # we get to the parent class
318+
else:
312319
raise ValueError(
313-
'field_name cannot be None when trying to deserialize a BaseDoc'
320+
'field_name and field_type cannot be None when trying to deserialize a DocArray'
314321
)
315-
return_field = cls._get_field_annotation_array(field_name).from_protobuf(
316-
getattr(value, content_key)
317-
) # we get to the parent class
318322
elif content_key is None:
319323
return_field = None
320324
elif docarray_type is None:
@@ -330,8 +334,6 @@ def _get_content_from_node_proto(
330334
elif content_key in arg_to_container.keys():
331335
if field_name and field_name in cls._docarray_fields():
332336
field_type = cls._get_field_inner_type(field_name)
333-
else:
334-
field_type = None
335337

336338
if isinstance(field_type, GenericAlias):
337339
field_type = get_args(field_type)[0]

tests/units/array/test_array_from_to_json.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, Dict, List
22

33
import numpy as np
44
import pytest

tests/units/array/test_array_proto.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import pytest
3+
from typing import Dict, List
34

45
from docarray import BaseDoc, DocList
56
from docarray.base_doc import AnyDoc
@@ -111,3 +112,41 @@ class BasisUnion(BaseDoc):
111112
docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")])
112113
docs_copy = DocList[BasisUnion].from_protobuf(docs_basic.to_protobuf())
113114
assert docs_copy == docs_basic
115+
116+
117+
class MySimpleDoc(BaseDoc):
118+
title: str
119+
120+
121+
class MyComplexDoc(BaseDoc):
122+
content_dict_doclist: Dict[str, DocList[MySimpleDoc]]
123+
content_dict_list: Dict[str, List[MySimpleDoc]]
124+
aux_dict: Dict[str, int]
125+
126+
127+
def test_to_from_proto_complex():
128+
da = DocList[MyComplexDoc](
129+
[
130+
MyComplexDoc(
131+
content_dict_doclist={
132+
'test1': DocList[MySimpleDoc](
133+
[MySimpleDoc(title='123'), MySimpleDoc(title='456')]
134+
)
135+
},
136+
content_dict_list={
137+
'test1': [MySimpleDoc(title='123'), MySimpleDoc(title='456')]
138+
},
139+
aux_dict={'a': 0},
140+
)
141+
]
142+
)
143+
da2 = DocList[MyComplexDoc].from_protobuf(da.to_protobuf())
144+
assert len(da2) == 1
145+
d2 = da2[0]
146+
assert d2.aux_dict == {'a': 0}
147+
assert len(d2.content_dict_doclist['test1']) == 2
148+
assert d2.content_dict_doclist['test1'][0].title == '123'
149+
assert d2.content_dict_doclist['test1'][1].title == '456'
150+
assert len(d2.content_dict_list['test1']) == 2
151+
assert d2.content_dict_list['test1'][0].title == '123'
152+
assert d2.content_dict_list['test1'][1].title == '456'

tests/units/document/test_from_to_bytes.py

Lines changed: 62 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import pytest
2+
from typing import Dict, List
23

3-
from docarray import BaseDoc
4+
from docarray import BaseDoc, DocList
45
from docarray.documents import ImageDoc
56
from docarray.typing import NdArray
67

@@ -11,6 +12,16 @@ class MyDoc(BaseDoc):
1112
image: ImageDoc
1213

1314

15+
class MySimpleDoc(BaseDoc):
16+
title: str
17+
18+
19+
class MyComplexDoc(BaseDoc):
20+
content_dict_doclist: Dict[str, DocList[MySimpleDoc]]
21+
content_dict_list: Dict[str, List[MySimpleDoc]]
22+
aux_dict: Dict[str, int]
23+
24+
1425
@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
1526
@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
1627
def test_to_from_bytes(protocol, compress):
@@ -39,3 +50,53 @@ def test_to_from_base64(protocol, compress):
3950
assert d2.text == 'hello'
4051
assert d2.embedding.tolist() == [1, 2, 3, 4, 5]
4152
assert d2.image.url == 'aux.png'
53+
54+
55+
@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
56+
@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
57+
def test_to_from_bytes_complex(protocol, compress):
58+
d = MyComplexDoc(
59+
content_dict_doclist={
60+
'test1': DocList[MySimpleDoc](
61+
[MySimpleDoc(title='123'), MySimpleDoc(title='456')]
62+
)
63+
},
64+
content_dict_list={
65+
'test1': [MySimpleDoc(title='123'), MySimpleDoc(title='456')]
66+
},
67+
aux_dict={'a': 0},
68+
)
69+
bstr = d.to_bytes(protocol=protocol, compress=compress)
70+
d2 = MyComplexDoc.from_bytes(bstr, protocol=protocol, compress=compress)
71+
assert d2.aux_dict == {'a': 0}
72+
assert len(d2.content_dict_doclist['test1']) == 2
73+
assert d2.content_dict_doclist['test1'][0].title == '123'
74+
assert d2.content_dict_doclist['test1'][1].title == '456'
75+
assert len(d2.content_dict_list['test1']) == 2
76+
assert d2.content_dict_list['test1'][0].title == '123'
77+
assert d2.content_dict_list['test1'][1].title == '456'
78+
79+
80+
@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
81+
@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
82+
def test_to_from_base64_complex(protocol, compress):
83+
d = MyComplexDoc(
84+
content_dict_doclist={
85+
'test1': DocList[MySimpleDoc](
86+
[MySimpleDoc(title='123'), MySimpleDoc(title='456')]
87+
)
88+
},
89+
content_dict_list={
90+
'test1': [MySimpleDoc(title='123'), MySimpleDoc(title='456')]
91+
},
92+
aux_dict={'a': 0},
93+
)
94+
bstr = d.to_base64(protocol=protocol, compress=compress)
95+
d2 = MyComplexDoc.from_base64(bstr, protocol=protocol, compress=compress)
96+
assert d2.aux_dict == {'a': 0}
97+
assert len(d2.content_dict_doclist['test1']) == 2
98+
assert d2.content_dict_doclist['test1'][0].title == '123'
99+
assert d2.content_dict_doclist['test1'][1].title == '456'
100+
assert len(d2.content_dict_list['test1']) == 2
101+
assert d2.content_dict_list['test1'][0].title == '123'
102+
assert d2.content_dict_list['test1'][1].title == '456'

0 commit comments

Comments
 (0)