Skip to content

Commit 3da3603

Browse files
authored
fix: allow config extension in pydantic v2 (docarray#1814)
Signed-off-by: samsja <[email protected]>
1 parent 9a6b1e6 commit 3da3603

File tree

5 files changed

+80
-36
lines changed

5 files changed

+80
-36
lines changed

docarray/base_doc/doc.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None'
5252
)
5353

54+
from pydantic import ConfigDict
55+
5456

5557
_console: Console = Console()
5658

@@ -71,10 +73,14 @@ class BaseDocWithoutId(BaseModel, IOMixin, UpdateMixin, BaseNode):
7173

7274
if is_pydantic_v2:
7375

74-
class Config:
75-
validate_assignment = True
76-
_load_extra_fields_from_protobuf = False
77-
json_encoders = {AbstractTensor: lambda x: x}
76+
class ConfigDocArray(ConfigDict):
77+
_load_extra_fields_from_protobuf: bool
78+
79+
model_config = ConfigDocArray(
80+
validate_assignment=True,
81+
_load_extra_fields_from_protobuf=False,
82+
json_encoders={AbstractTensor: lambda x: x},
83+
)
7884

7985
else:
8086

docarray/base_doc/mixins/io.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,14 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocProto') -> T:
238238
"""
239239

240240
fields: Dict[str, Any] = {}
241-
241+
load_extra_field = (
242+
cls.model_config['_load_extra_fields_from_protobuf']
243+
if is_pydantic_v2
244+
else cls.Config._load_extra_fields_from_protobuf
245+
)
242246
for field_name in pb_msg.data:
243247
if (
244-
not (cls.Config._load_extra_fields_from_protobuf)
248+
not (load_extra_field)
245249
and field_name not in cls._docarray_fields().keys()
246250
):
247251
continue # optimization we don't even load the data if the key does not

docs/user_guide/representing/first_step.md

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,22 +119,37 @@ This representation can be used to [send](../sending/first_step.md) or [store](.
119119

120120
## Setting a Pydantic `Config` class
121121

122-
Documents support setting a `Config` [like any other Pydantic `BaseModel`](https://docs.pydantic.dev/latest/usage/model_config/).
122+
Documents support setting a custom `configuration` [like any other Pydantic `BaseModel`](https://docs.pydantic.dev/latest/api/config/).
123123

124-
However, if you set a config, you should inherit from the `BaseDoc` config class:
124+
Here is an example to extend the Config of a Document dependong on which version of Pydantic you are using.
125125

126-
```python
127-
from docarray import BaseDoc
128126

129127

130-
class MyDoc(BaseDoc):
131-
class Config(BaseDoc.Config):
132-
arbitrary_types_allowed = True # just an example setting
133-
```
128+
=== "Pydantic v1"
129+
```python
130+
from docarray import BaseDoc
131+
132+
133+
class MyDoc(BaseDoc):
134+
class Config(BaseDoc.Config):
135+
arbitrary_types_allowed = True # just an example setting
136+
```
137+
138+
=== "Pydantic v2"
139+
```python
140+
from docarray import BaseDoc
141+
142+
143+
class MyDoc(BaseDoc):
144+
model_config = BaseDoc.ConfigDocArray.ConfigDict(
145+
arbitrary_types_allowed=True
146+
) # just an example setting
147+
```
134148

135149
See also:
136150

137151
* The [next part](./array.md) of the representing section
138152
* API reference for the [BaseDoc][docarray.base_doc.doc.BaseDoc] class
139153
* The [Storing](../storing/first_step.md) section on how to store your data
140154
* The [Sending](../sending/first_step.md) section on how to send your data
155+

tests/units/document/test_any_document.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,10 @@
22

33
import numpy as np
44
import pytest
5-
from orjson import orjson
65

76
from docarray import DocList
87
from docarray.base_doc import AnyDoc, BaseDoc
9-
from docarray.base_doc.io.json import orjson_dumps_and_decode
108
from docarray.typing import NdArray
11-
from docarray.typing.tensor.abstract_tensor import AbstractTensor
12-
from docarray.utils._internal.pydantic import is_pydantic_v2
139

1410

1511
def test_any_doc():
@@ -94,21 +90,3 @@ class DocTest(BaseDoc):
9490
assert isinstance(d.ld[0], dict)
9591
assert d.ld[0]['text'] == 'I am inner'
9692
assert d.ld[0]['t'] == {'a': 'b'}
97-
98-
99-
@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
100-
def test_subclass_config():
101-
class MyDoc(BaseDoc):
102-
x: str
103-
104-
class Config(BaseDoc.Config):
105-
arbitrary_types_allowed = True # just an example setting
106-
107-
assert MyDoc.Config.json_loads == orjson.loads
108-
assert MyDoc.Config.json_dumps == orjson_dumps_and_decode
109-
assert (
110-
MyDoc.Config.json_encoders[AbstractTensor](3) == 3
111-
) # dirty check that it is identity
112-
assert MyDoc.Config.validate_assignment
113-
assert not MyDoc.Config._load_extra_fields_from_protobuf
114-
assert MyDoc.Config.arbitrary_types_allowed

tests/units/document/test_base_document.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from typing import Any, List, Optional, Tuple
22

33
import numpy as np
4+
import orjson
45
import pytest
56

67
from docarray import DocList, DocVec
78
from docarray.base_doc.doc import BaseDoc
9+
from docarray.base_doc.io.json import orjson_dumps_and_decode
810
from docarray.typing import NdArray
11+
from docarray.typing.tensor.abstract_tensor import AbstractTensor
12+
from docarray.utils._internal.pydantic import is_pydantic_v2
913

1014

1115
def test_base_document_init():
@@ -146,3 +150,40 @@ class MyDoc(BaseDoc):
146150
field_type = MyDoc._get_field_inner_type("tuple_")
147151

148152
assert field_type == Any
153+
154+
155+
@pytest.mark.skipif(
156+
is_pydantic_v2, reason="syntax only working with pydantic v1 for now"
157+
)
158+
def test_subclass_config():
159+
class MyDoc(BaseDoc):
160+
x: str
161+
162+
class Config(BaseDoc.Config):
163+
arbitrary_types_allowed = True # just an example setting
164+
165+
assert MyDoc.Config.json_loads == orjson.loads
166+
assert MyDoc.Config.json_dumps == orjson_dumps_and_decode
167+
assert (
168+
MyDoc.Config.json_encoders[AbstractTensor](3) == 3
169+
) # dirty check that it is identity
170+
assert MyDoc.Config.validate_assignment
171+
assert not MyDoc.Config._load_extra_fields_from_protobuf
172+
assert MyDoc.Config.arbitrary_types_allowed
173+
174+
175+
@pytest.mark.skipif(not (is_pydantic_v2), reason="syntax only working with pydantic v2")
176+
def test_subclass_config_v2():
177+
class MyDoc(BaseDoc):
178+
x: str
179+
180+
model_config = BaseDoc.ConfigDocArray(
181+
arbitrary_types_allowed=True
182+
) # just an example setting
183+
184+
assert (
185+
MyDoc.model_config['json_encoders'][AbstractTensor](3) == 3
186+
) # dirty check that it is identity
187+
assert MyDoc.model_config['validate_assignment']
188+
assert not MyDoc.model_config['_load_extra_fields_from_protobuf']
189+
assert MyDoc.model_config['arbitrary_types_allowed']

0 commit comments

Comments
 (0)