diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index 11c130086b4..734fac06728 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -415,6 +415,7 @@ def find( :return: a named tuple containing `documents` and `scores` """ self._logger.debug(f'Executing `find` for search field {search_field}') + self._validate_search_field(search_field) if isinstance(query, BaseDoc): query_vec = self._get_values_by_column([query], search_field)[0] else: @@ -449,6 +450,7 @@ def find_batched( :return: a named tuple containing `documents` and `scores` """ self._logger.debug(f'Executing `find_batched` for search field {search_field}') + self._validate_search_field(search_field) if isinstance(queries, Sequence): query_vec_list = self._get_values_by_column(queries, search_field) query_vec_np = np.stack( @@ -523,6 +525,7 @@ def text_search( :return: a named tuple containing `documents` and `scores` """ self._logger.debug(f'Executing `text_search` for search field {search_field}') + self._validate_search_field(search_field) if isinstance(query, BaseDoc): query_text = self._get_values_by_column([query], search_field)[0] else: @@ -553,6 +556,7 @@ def text_search_batched( self._logger.debug( f'Executing `text_search_batched` for search field {search_field}' ) + self._validate_search_field(search_field) if isinstance(queries[0], BaseDoc): query_docs: Sequence[BaseDoc] = cast(Sequence[BaseDoc], queries) query_texts: Sequence[str] = self._get_values_by_column( @@ -816,6 +820,27 @@ def _validate_docs( return DocArray[BaseDoc].construct(out_docs) + def _validate_search_field(self, search_field: Union[str, None]) -> bool: + """ + Validate if the given `search_field` corresponds to one of the + columns that was parsed from the schema. + + Some backends, like weaviate, don't use search fields, so the function + returns True if `search_field` is empty or None. + + :param search_field: search field to validate. 
+ :return: True if the field exists, False otherwise. + """ + if not search_field or search_field in self._column_infos.keys(): + if not search_field: + self._logger.info('Empty search field was passed') + return True + else: + valid_search_fields = ', '.join(self._column_infos.keys()) + raise ValueError( + f'{search_field} is not a valid search field. Valid search fields are: {valid_search_fields}' + ) + def _to_numpy(self, val: Any, allow_passthrough=False) -> Any: """ Converts a value to a numpy array, if possible. diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index deefc3b2a86..08c29c150d2 100644 --- a/docarray/index/backends/elastic.py +++ b/docarray/index/backends/elastic.py @@ -136,6 +136,7 @@ def find( search_field: str = 'embedding', limit: int = 10, ): + self._outer_instance._validate_search_field(search_field) if isinstance(query, BaseDoc): query_vec = BaseDocIndex._get_values_by_column([query], search_field)[0] else: @@ -154,6 +155,7 @@ def filter(self, query: Dict[str, Any], limit: int = 10): return self def text_search(self, query: str, search_field: str = 'text', limit: int = 10): + self._outer_instance._validate_search_field(search_field) self._query['size'] = limit self._query['query']['bool']['must'].append( {'match': {search_field: query}} @@ -261,7 +263,6 @@ def _index( refresh: bool = True, chunk_size: Optional[int] = None, ): - data = self._transpose_col_value_dict(column_to_data) requests = [] @@ -420,7 +421,6 @@ def _text_search( limit: int, search_field: str = '', ) -> _FindResult: - body = self._form_text_search_body(query, limit, search_field) resp = self._client.search( diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index b5774020524..a6d8fdf4e79 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -558,3 +558,17 @@ class MyDoc2(BaseDoc): doc = 
def test_validate_search_fields():
    """Check that `_validate_search_field` accepts known or empty search
    fields and raises `ValueError` for an unknown one."""
    store = DummyDocIndex[SimpleDoc]()
    assert list(store._column_infos.keys()) == ['id', 'tens']

    # 'tens' is a valid field
    assert store._validate_search_field(search_field='tens')
    # should not fail when an empty string or None is passed; both are
    # accepted and reported as valid
    assert store._validate_search_field(search_field='')
    assert store._validate_search_field(search_field=None)
    # 'ten' is not a valid field; pin the error message so an unrelated
    # ValueError cannot make this test pass
    with pytest.raises(ValueError, match='ten is not a valid search field'):
        store._validate_search_field('ten')