diff --git a/docarray/utils/find.py b/docarray/utils/find.py index 49ffb922de5..f522a78f297 100644 --- a/docarray/utils/find.py +++ b/docarray/utils/find.py @@ -213,10 +213,11 @@ class MyDocument(BaseDoc): for _, (indices_per_query, scores_per_query) in enumerate( zip(top_indices, top_scores) ): - docs_per_query: DocList = DocList([]) + doc_type = cast(Type[BaseDoc], index.doc_type) + docs_per_query: DocList = DocList.__class_getitem__(doc_type)() for idx in indices_per_query: # workaround until #930 is fixed - docs_per_query.append(index[idx]) - batched_docs.append(DocList(docs_per_query)) + docs_per_query.append(index[int(idx)]) + batched_docs.append(docs_per_query) scores.append(scores_per_query) return FindResultBatched(documents=batched_docs, scores=scores) diff --git a/tests/units/util/test_find.py b/tests/units/util/test_find.py index deebea7835a..11cab69c312 100644 --- a/tests/units/util/test_find.py +++ b/tests/units/util/test_find.py @@ -58,6 +58,8 @@ def test_find_torch(random_torch_query, random_torch_index, metric): ) assert len(top_k) == 7 assert len(scores) == 7 + assert top_k.doc_type == random_torch_index.doc_type + if metric.endswith('_dist'): assert (torch.stack(sorted(scores)) == scores).all() else: @@ -151,6 +153,8 @@ def test_find_batched_torch(random_torch_batch_query, random_torch_index, metric for top_k, top_scores in zip(documents, scores): assert len(top_k) == 7 assert len(top_scores) == 7 + assert top_k.doc_type == random_torch_index.doc_type + for sc in scores: if metric.endswith('_dist'): assert (torch.stack(sorted(sc)) == sc).all()