diff --git a/docarray/document/pydantic_model.py b/docarray/document/pydantic_model.py index 596fa568b78..b332c578a7c 100644 --- a/docarray/document/pydantic_model.py +++ b/docarray/document/pydantic_model.py @@ -20,6 +20,13 @@ def _convert_ndarray_to_list(v: 'ArrayType'): return to_list(v) +class _NamedScore(BaseModel): + value: Optional[float] = None + op_name: Optional[str] = None + description: Optional[str] = None + ref_id: Optional[str] = None + + class PydanticDocument(BaseModel): id: str parent_id: Optional[str] @@ -36,8 +43,8 @@ class PydanticDocument(BaseModel): location: Optional[List[float]] embedding: Optional[Any] modality: Optional[str] - evaluations: Optional[Dict[str, Dict[str, '_StructValueType']]] - scores: Optional[Dict[str, Dict[str, '_StructValueType']]] + evaluations: Optional[Dict[str, '_NamedScore']] + scores: Optional[Dict[str, '_NamedScore']] chunks: Optional[List['PydanticDocument']] matches: Optional[List['PydanticDocument']] diff --git a/tests/unit/test_pydantic.py b/tests/unit/test_pydantic.py index cdd8a87da4d..074f162852f 100644 --- a/tests/unit/test_pydantic.py +++ b/tests/unit/test_pydantic.py @@ -111,6 +111,7 @@ def test_match_to_from_pydantic(): dap = da.to_pydantic_model() da_r = DocumentArray.from_pydantic_model(dap) assert da_r[0].matches[0].scores['cosine'] + assert isinstance(da_r[0].matches[0].scores['cosine'], NamedScore) assert isinstance(da_r[0].matches[0].scores, defaultdict) assert isinstance(da_r[0].matches[0].scores['random_score'], NamedScore)