From 8dcd33eea8b506cfb0f47b6bf64df79a3814ee8b Mon Sep 17 00:00:00 2001
From: Joan Fontanals Martinez
Date: Tue, 11 Jan 2022 16:37:54 +0100
Subject: [PATCH] fix: fix match score

---
Review notes (text between the `---` line and the diffstat is not part of
the commit message):

- match.py: when `only_id` is false, the match is now built as
  `Document(rhv[int(_id)], copy=True)` instead of using `rhv[int(_id)]`
  directly.
- Presumably this keeps one Document from `rhv` from being shared between the
  matches of several queries, where a later score assignment would overwrite
  an earlier one. The score assignment itself is outside this hunk -- confirm
  against the full file.
- test_match.py: the new test matches 4 queries against 5 candidates with the
  euclidean metric and checks that every query gets 5 matches whose scores
  are non-decreasing.

 docarray/array/mixins/match.py        |  2 +-
 tests/unit/array/mixins/test_match.py | 31 +++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/docarray/array/mixins/match.py b/docarray/array/mixins/match.py
index 930454e18e5..6a8ae335df6 100644
--- a/docarray/array/mixins/match.py
+++ b/docarray/array/mixins/match.py
@@ -121,7 +121,7 @@ def match(
                 if only_id:
                     d = Document(id=rhv[_id].id)
                 else:
-                    d = rhv[int(_id)]  # type: Document
+                    d = Document(rhv[int(_id)], copy=True)  # type: Document
 
                 if d.id in lhv:
                     d = Document(
diff --git a/tests/unit/array/mixins/test_match.py b/tests/unit/array/mixins/test_match.py
index 43592259593..86a2b0ceae6 100644
--- a/tests/unit/array/mixins/test_match.py
+++ b/tests/unit/array/mixins/test_match.py
@@ -515,3 +515,34 @@ def test_diff_framework_match(ndarray_val):
     da = DocumentArray.empty(10)
     da.embeddings = ndarray_val
     da.match(da)
+
+
+def test_match_ensure_scores_unique():
+    import numpy as np
+    from docarray import DocumentArray
+
+    da1 = DocumentArray.empty(4)
+    da1.embeddings = np.array(
+        [[0, 0, 0, 0, 1], [1, 0, 0, 0, 0], [1, 1, 1, 1, 0], [1, 2, 2, 1, 0]]
+    )
+
+    da2 = DocumentArray.empty(5)
+    da2.embeddings = np.array(
+        [
+            [0.0, 0.1, 0.0, 0.0, 0.0],
+            [1.0, 0.1, 0.0, 0.0, 0.0],
+            [1.0, 1.2, 1.0, 1.0, 0.0],
+            [1.0, 2.2, 2.0, 1.0, 0.0],
+            [4.0, 5.2, 2.0, 1.0, 0.0],
+        ]
+    )
+
+    da1.match(da2, metric='euclidean', only_id=False, limit=5)
+
+    assert len(da1) == 4
+    for query in da1:
+        previous_score = -10000
+        assert len(query.matches) == 5
+        for m in query.matches:
+            assert m.scores['euclidean'].value >= previous_score
+            previous_score = m.scores['euclidean'].value