From 523875cea7f40a38835be153f3252775d362ace7 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Wed, 29 Mar 2023 14:34:46 +0200 Subject: [PATCH] docs: fix docstring example of find_batched Signed-off-by: Johannes Messner --- docarray/utils/find.py | 66 +++++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/docarray/utils/find.py b/docarray/utils/find.py index c7eeb787159..a626134d1b6 100644 --- a/docarray/utils/find.py +++ b/docarray/utils/find.py @@ -130,39 +130,39 @@ def find_batched( --- ```python - # from docarray import DocArray, BaseDoc - # from docarray.typing import TorchTensor - # from docarray.utils.find import find - # import torch - # - # - # class MyDocument(BaseDoc): - # embedding: TorchTensor - # - # - # index = DocArray[MyDocument]( - # [MyDocument(embedding=torch.rand(128)) for _ in range(100)] - # ) - # - # # use DocArray as query - # query = DocArray[MyDocument]([MyDocument(embedding=torch.rand(128)) for _ in range(3)]) - # results = find( - # index=index, - # query=query, - # embedding_field='embedding', - # metric='cosine_sim', - # ) - # top_matches, scores = results[0] - # - # # use tensor as query - # query = torch.rand(3, 128) - # results, scores = find( - # index=index, - # query=query, - # embedding_field='embedding', - # metric='cosine_sim', - # ) - # top_matches, scores = results[0] + from docarray import DocArray, BaseDoc + from docarray.typing import TorchTensor + from docarray.utils.find import find_batched + import torch + + + class MyDocument(BaseDoc): + embedding: TorchTensor + + + index = DocArray[MyDocument]( + [MyDocument(embedding=torch.rand(128)) for _ in range(100)] + ) + + # use DocArray as query + query = DocArray[MyDocument]([MyDocument(embedding=torch.rand(128)) for _ in range(3)]) + results = find_batched( + index=index, + query=query, + embedding_field='embedding', + metric='cosine_sim', + ) + top_matches, scores = results[0] + + # use tensor as query + query = torch.rand(3, 128) + results = find_batched( + index=index, + query=query, + embedding_field='embedding', + metric='cosine_sim', + ) + top_matches, scores = results[0] ``` ---