diff --git a/docarray/utils/find.py b/docarray/utils/find.py index c7eeb787159..a626134d1b6 100644 --- a/docarray/utils/find.py +++ b/docarray/utils/find.py @@ -130,39 +130,39 @@ def find_batched( --- ```python - # from docarray import DocArray, BaseDoc - # from docarray.typing import TorchTensor - # from docarray.utils.find import find - # import torch - # - # - # class MyDocument(BaseDoc): - # embedding: TorchTensor - # - # - # index = DocArray[MyDocument]( - # [MyDocument(embedding=torch.rand(128)) for _ in range(100)] - # ) - # - # # use DocArray as query - # query = DocArray[MyDocument]([MyDocument(embedding=torch.rand(128)) for _ in range(3)]) - # results = find( - # index=index, - # query=query, - # embedding_field='embedding', - # metric='cosine_sim', - # ) - # top_matches, scores = results[0] - # - # # use tensor as query - # query = torch.rand(3, 128) - # results, scores = find( - # index=index, - # query=query, - # embedding_field='embedding', - # metric='cosine_sim', - # ) - # top_matches, scores = results[0] + from docarray import DocArray, BaseDoc + from docarray.typing import TorchTensor + from docarray.utils.find import find_batched + import torch + + + class MyDocument(BaseDoc): + embedding: TorchTensor + + + index = DocArray[MyDocument]( + [MyDocument(embedding=torch.rand(128)) for _ in range(100)] + ) + + # use DocArray as query + query = DocArray[MyDocument]([MyDocument(embedding=torch.rand(128)) for _ in range(3)]) + results = find_batched( + index=index, + query=query, + embedding_field='embedding', + metric='cosine_sim', + ) + top_matches, scores = results[0] + + # use tensor as query + query = torch.rand(3, 128) + results = find_batched( + index=index, + query=query, + embedding_field='embedding', + metric='cosine_sim', + ) + top_matches, scores = results[0] ``` ---