diff --git a/docarray/document/pydantic_model.py b/docarray/document/pydantic_model.py index b332c578a7c..6661de63f2c 100644 --- a/docarray/document/pydantic_model.py +++ b/docarray/document/pydantic_model.py @@ -59,6 +59,9 @@ def _blob2base64(cls, v): else: raise ValueError('must be bytes') + class Config: + smart_union = True + PydanticDocument.update_forward_refs() diff --git a/tests/unit/test_pydantic.py b/tests/unit/test_pydantic.py index 074f162852f..174caafedad 100644 --- a/tests/unit/test_pydantic.py +++ b/tests/unit/test_pydantic.py @@ -123,7 +123,13 @@ def test_with_embedding_no_tensor(): @pytest.mark.parametrize( 'tag_value, tag_type', - [(3, float), (3.4, float), ('hello', str), (True, bool), (False, bool)], + [ + (3.0, float), + ('hello', str), + ('1', str), + (True, bool), + (False, bool), + ], ) @pytest.mark.parametrize('protocol', ['protobuf', 'jsonschema']) def test_tags_int_float_str_bool(tag_type, tag_value, protocol):