diff --git a/docarray/array/storage/milvus/backend.py b/docarray/array/storage/milvus/backend.py index e09a24c093b..3ec8da1bc6a 100644 --- a/docarray/array/storage/milvus/backend.py +++ b/docarray/array/storage/milvus/backend.py @@ -18,6 +18,7 @@ from docarray import Document, DocumentArray from docarray.array.storage.base.backend import BaseBackendMixin, TypeMap from docarray.helper import dataclass_from_dict, _safe_cast_int +from docarray.score import NamedScore if TYPE_CHECKING: from docarray.typing import ( @@ -266,16 +267,15 @@ def _docs_from_query_response(response): return DocumentArray([Document.from_base64(d['serialized']) for d in response]) @staticmethod - def _docs_from_search_response( - responses, - ) -> 'List[DocumentArray]': + def _docs_from_search_response(responses, distance: str) -> 'List[DocumentArray]': das = [] for r in responses: - das.append( - DocumentArray( - [Document.from_base64(hit.entity.get('serialized')) for hit in r] - ) - ) + da = [] + for hit in r: + doc = Document.from_base64(hit.entity.get('serialized')) + doc.scores[distance] = NamedScore(value=hit.score) + da.append(doc) + das.append(DocumentArray(da)) return das def _update_kwargs_from_config(self, field_to_update, **kwargs): diff --git a/docarray/array/storage/milvus/find.py b/docarray/array/storage/milvus/find.py index fa9f42709ad..3bff6f598e9 100644 --- a/docarray/array/storage/milvus/find.py +++ b/docarray/array/storage/milvus/find.py @@ -42,7 +42,7 @@ def _find( output_fields=['serialized'], **kwargs, ) - return self._docs_from_search_response(results) + return self._docs_from_search_response(results, distance=self._config.distance) def _filter(self, filter, limit=10, **kwargs): kwargs = self._update_kwargs_from_config('consistency_level', **kwargs) diff --git a/docs/advanced/document-store/extend.md b/docs/advanced/document-store/extend.md index 912efa41910..a65d5ac32bb 100644 --- a/docs/advanced/document-store/extend.md +++ b/docs/advanced/document-store/extend.md @@ 
-289,6 +289,7 @@ class FindMixin: ... ``` +Make sure to store each distance score in the `.scores` dictionary of the returned Documents, using the `distance` value as the key. ## Step 6: summarize everything in `__init__.py`. diff --git a/docs/advanced/document-store/milvus.md b/docs/advanced/document-store/milvus.md index 6fdff41eb0f..e37f60b6001 100644 --- a/docs/advanced/document-store/milvus.md +++ b/docs/advanced/document-store/milvus.md @@ -253,6 +253,21 @@ Embeddings Nearest Neighbours with "price" at most 7: embedding=[5. 5. 5.], price=5 embedding=[4. 4. 4.], price=4 ``` + +You can access the scores as follows: + +````python +for doc in results: + print(f"score = {doc.scores[distance].value}") +```` + +``` +score = 3.0 +score = 12.0 +score = 27.0 +score = 48.0 +``` + ### Example of `.find` with only a filter The following example shows how to use DocArray with Milvus Document Store in order to filter text documents. diff --git a/tests/unit/array/storage/milvus/test_milvus.py b/tests/unit/array/storage/milvus/test_milvus.py index 42ac17173ea..48c05fb8958 100644 --- a/tests/unit/array/storage/milvus/test_milvus.py +++ b/tests/unit/array/storage/milvus/test_milvus.py @@ -73,6 +73,8 @@ def test_memory_cntxt_mngr(start_storage): @pytest.fixture() def mock_response(): class MockHit: + score = 1.0 + @property def entity(self): return {'serialized': Document().to_base64()}