From ac95df838fccebb45c53dc8ac131ab6e1872dbc2 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Mon, 3 Apr 2023 10:51:22 +0200 Subject: [PATCH 1/4] refactor: dummy change Signed-off-by: jupyterjazz --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8d4b45ae264..18f8b4113bb 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ DocArray is a library for **representing, sending and storing multi-modal data**, with a focus on applications in **ML** and **Neural Search**. -This means that DocArray lets you do the following things: +This means that `DocArray` lets you do the following things: ## Represent From bf595ba4117f9df4b37d344e610618a1778b5b55 Mon Sep 17 00:00:00 2001 From: Sriniketh J <81156510+srini047@users.noreply.github.com> Date: Mon, 3 Apr 2023 14:25:17 +0530 Subject: [PATCH 2/4] feat: add validate search field method (#1319) Co-authored-by: Saba Sturua <45267439+jupyterjazz@users.noreply.github.com> --- docarray/index/abstract.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index 11c130086b4..fbf1a8485d3 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -815,6 +815,16 @@ def _validate_docs( ) return DocArray[BaseDoc].construct(out_docs) + + def _validate_search_field(self, search_field): + if search_field in self._column_infos.keys(): + return True + else: + raise ValueError( + {search_field} + + 'is not a valid search field. Valid search fields are: ' + + {', '.join(self._column_infos.keys())} + ) def _to_numpy(self, val: Any, allow_passthrough=False) -> Any: """ From 5a2e7d5de57666a0efe2fb93b7dda80266a13b9b Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Mon, 3 Apr 2023 11:16:03 +0200 Subject: [PATCH 3/4] feat: validate search fields Signed-off-by: jupyterjazz --- docarray/index/abstract.py | 27 ++++++++++++++----- .../index/base_classes/test_base_doc_store.py | 14 ++++++++++ 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index fbf1a8485d3..734fac06728 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -415,6 +415,7 @@ def find( :return: a named tuple containing `documents` and `scores` """ self._logger.debug(f'Executing `find` for search field {search_field}') + self._validate_search_field(search_field) if isinstance(query, BaseDoc): query_vec = self._get_values_by_column([query], search_field)[0] else: @@ -449,6 +450,7 @@ def find_batched( :return: a named tuple containing `documents` and `scores` """ self._logger.debug(f'Executing `find_batched` for search field {search_field}') + self._validate_search_field(search_field) if isinstance(queries, Sequence): query_vec_list = self._get_values_by_column(queries, search_field) query_vec_np = np.stack( @@ -523,6 +525,7 @@ def text_search( :return: a named tuple containing `documents` and `scores` """ self._logger.debug(f'Executing `text_search` for search field {search_field}') + self._validate_search_field(search_field) if isinstance(query, BaseDoc): query_text = self._get_values_by_column([query], search_field)[0] else: @@ -553,6 +556,7 @@ def text_search_batched( self._logger.debug( f'Executing `text_search_batched` for search field {search_field}' ) + self._validate_search_field(search_field) if isinstance(queries[0], BaseDoc): query_docs: Sequence[BaseDoc] = cast(Sequence[BaseDoc], queries) query_texts: Sequence[str] = self._get_values_by_column( @@ -815,15 +819,26 @@ def _validate_docs( ) return DocArray[BaseDoc].construct(out_docs) - - def _validate_search_field(self, search_field): - if search_field in self._column_infos.keys(): + + def _validate_search_field(self, search_field: Union[str, None]) -> bool: + """ + Validate if the given `search_field` corresponds to one of the + columns that was parsed from the schema. + + Some backends, like weaviate, don't use search fields, so the function + returns True if `search_field` is empty or None. + + :param search_field: search field to validate. + :return: True if the field exists, False otherwise. + """ + if not search_field or search_field in self._column_infos.keys(): + if not search_field: + self._logger.info('Empty search field was passed') return True else: + valid_search_fields = ', '.join(self._column_infos.keys()) raise ValueError( - {search_field} - + 'is not a valid search field. Valid search fields are: ' - + {', '.join(self._column_infos.keys())} + f'{search_field} is not a valid search field. Valid search fields are: {valid_search_fields}' ) def _to_numpy(self, val: Any, allow_passthrough=False) -> Any: diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index b5774020524..a6d8fdf4e79 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -558,3 +558,17 @@ class MyDoc2(BaseDoc): doc = store._convert_dict_to_doc(doc_dict, store._schema) assert doc.id == doc_dict_copy['id'] assert np.all(doc.tens == doc_dict_copy['tens']) + + +def test_validate_search_fields(): + store = DummyDocIndex[SimpleDoc]() + assert list(store._column_infos.keys()) == ['id', 'tens'] + + # 'tens' is a valid field + assert store._validate_search_field(search_field='tens') + # should not fail when an empty string or None is passed + assert store._validate_search_field(search_field='') + store._validate_search_field(search_field=None) + # 'ten' is not a valid field + with pytest.raises(ValueError): + store._validate_search_field('ten') From f7d6c8e6203acd17db50d0e87948444846884472 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Mon, 3 Apr 2023 11:26:10 +0200 Subject: [PATCH 4/4] refactor: validate fields in es Signed-off-by: jupyterjazz --- README.md | 2 +- docarray/index/backends/elastic.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 18f8b4113bb..8d4b45ae264 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ DocArray is a library for **representing, sending and storing multi-modal data**, with a focus on applications in **ML** and **Neural Search**. -This means that `DocArray` lets you do the following things: +This means that DocArray lets you do the following things: ## Represent diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index deefc3b2a86..08c29c150d2 100644 --- a/docarray/index/backends/elastic.py +++ b/docarray/index/backends/elastic.py @@ -136,6 +136,7 @@ def find( search_field: str = 'embedding', limit: int = 10, ): + self._outer_instance._validate_search_field(search_field) if isinstance(query, BaseDoc): query_vec = BaseDocIndex._get_values_by_column([query], search_field)[0] else: @@ -154,6 +155,7 @@ def filter(self, query: Dict[str, Any], limit: int = 10): return self def text_search(self, query: str, search_field: str = 'text', limit: int = 10): + self._outer_instance._validate_search_field(search_field) self._query['size'] = limit self._query['query']['bool']['must'].append( {'match': {search_field: query}} @@ -261,7 +263,6 @@ def _index( refresh: bool = True, chunk_size: Optional[int] = None, ): - data = self._transpose_col_value_dict(column_to_data) requests = [] @@ -420,7 +421,6 @@ def _text_search( limit: int, search_field: str = '', ) -> _FindResult: - body = self._form_text_search_body(query, limit, search_field) resp = self._client.search(