From c75d6e0d640235d68691ab463deb96e48a71163a Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Thu, 27 Apr 2023 16:21:15 +0200 Subject: [PATCH 1/4] fix: return doclist of same type as input index in find and findbatched Signed-off-by: anna-charlotte --- docarray/utils/find.py | 5 +++-- tests/units/util/test_find.py | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docarray/utils/find.py b/docarray/utils/find.py index 49ffb922de5..a75a4b45802 100644 --- a/docarray/utils/find.py +++ b/docarray/utils/find.py @@ -213,10 +213,11 @@ class MyDocument(BaseDoc): for _, (indices_per_query, scores_per_query) in enumerate( zip(top_indices, top_scores) ): - docs_per_query: DocList = DocList([]) + doc_type = cast(Type[BaseDoc], index.doc_type) + docs_per_query: DocList = DocList.__class_getitem__(doc_type)() for idx in indices_per_query: # workaround until #930 is fixed docs_per_query.append(index[idx]) - batched_docs.append(DocList(docs_per_query)) + batched_docs.append(docs_per_query) scores.append(scores_per_query) return FindResultBatched(documents=batched_docs, scores=scores) diff --git a/tests/units/util/test_find.py b/tests/units/util/test_find.py index deebea7835a..3a63dae9c75 100644 --- a/tests/units/util/test_find.py +++ b/tests/units/util/test_find.py @@ -58,6 +58,8 @@ def test_find_torch(random_torch_query, random_torch_index, metric): ) assert len(top_k) == 7 assert len(scores) == 7 + assert top_k.doc_type == random_torch_index.doc_type + if metric.endswith('_dist'): assert (torch.stack(sorted(scores)) == scores).all() else: From 3740d8f0b9fad8181a5d46d76afbd9c7edbbe883 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Thu, 27 Apr 2023 16:26:34 +0200 Subject: [PATCH 2/4] test: add assertion to find_batched test Signed-off-by: anna-charlotte --- tests/units/util/test_find.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/units/util/test_find.py b/tests/units/util/test_find.py index 3a63dae9c75..11cab69c312 100644 --- a/tests/units/util/test_find.py +++ b/tests/units/util/test_find.py @@ -153,6 +153,8 @@ def test_find_batched_torch(random_torch_batch_query, random_torch_index, metric for top_k, top_scores in zip(documents, scores): assert len(top_k) == 7 assert len(top_scores) == 7 + assert top_k.doc_type == random_torch_index.doc_type + for sc in scores: if metric.endswith('_dist'): assert (torch.stack(sorted(sc)) == sc).all() From a1dbfdac165335e9f3f61fccb1e41edb04176112 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Thu, 27 Apr 2023 17:05:09 +0200 Subject: [PATCH 3/4] fix: index to int Signed-off-by: anna-charlotte --- docarray/utils/find.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/utils/find.py b/docarray/utils/find.py index a75a4b45802..c1fc3fc3077 100644 --- a/docarray/utils/find.py +++ b/docarray/utils/find.py @@ -210,13 +210,13 @@ class MyDocument(BaseDoc): batched_docs: List[DocList] = [] scores = [] - for _, (indices_per_query, scores_per_query) in enumerate( + for i, (indices_per_query, scores_per_query) in enumerate( zip(top_indices, top_scores) ): doc_type = cast(Type[BaseDoc], index.doc_type) docs_per_query: DocList = DocList.__class_getitem__(doc_type)() for idx in indices_per_query: # workaround until #930 is fixed - docs_per_query.append(index[idx]) + docs_per_query.append(index[int(idx)]) batched_docs.append(docs_per_query) scores.append(scores_per_query) return FindResultBatched(documents=batched_docs, scores=scores) From af9827fe6591697a49c03553a6ba7909c71edfad Mon Sep 17 00:00:00 2001 From: Charlotte Gerhaher Date: Thu, 27 Apr 2023 17:30:12 +0200 Subject: [PATCH 4/4] fix: apply suggestions from code review Co-authored-by: Joan Fontanals Signed-off-by: Charlotte Gerhaher --- docarray/utils/find.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/utils/find.py b/docarray/utils/find.py index c1fc3fc3077..f522a78f297 100644 --- a/docarray/utils/find.py +++ b/docarray/utils/find.py @@ -210,7 +210,7 @@ class MyDocument(BaseDoc): batched_docs: List[DocList] = [] scores = [] - for i, (indices_per_query, scores_per_query) in enumerate( + for _, (indices_per_query, scores_per_query) in enumerate( zip(top_indices, top_scores) ): doc_type = cast(Type[BaseDoc], index.doc_type)