diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bccbe2f206d..9df8e8a06d2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,4 +32,4 @@ repos: args: - -S additional_dependencies: - - black==22.3.0 \ No newline at end of file + - black==22.3.0 diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index 4feb576ed76..3c423137259 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -49,12 +49,12 @@ class FindResultBatched(NamedTuple): documents: List[DocList] - scores: np.ndarray + scores: List[np.ndarray] class _FindResultBatched(NamedTuple): documents: Union[List[DocList], List[List[Dict[str, Any]]]] - scores: np.ndarray + scores: List[np.ndarray] def _raise_not_composable(name): @@ -571,7 +571,9 @@ def text_search_batched( if len(da_list) > 0 and isinstance(da_list[0], List): docs = [self._dict_list_to_docarray(docs) for docs in da_list] - return FindResultBatched(documents=docs, scores=scores) + return FindResultBatched(documents=docs, scores=scores) + + return FindResultBatched(documents=da_list, scores=scores) ########################################################## # Helper methods # diff --git a/docarray/index/backends/weaviate.py b/docarray/index/backends/weaviate.py new file mode 100644 index 00000000000..c54d3e76f47 --- /dev/null +++ b/docarray/index/backends/weaviate.py @@ -0,0 +1,833 @@ +import base64 +import copy +import logging +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import ( + Any, + Dict, + Generator, + Generic, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +import numpy as np +import weaviate +from pydantic import parse_obj_as +from typing_extensions import Literal + +import docarray +from docarray import BaseDoc, DocList +from docarray.index.abstract import BaseDocIndex, FindResultBatched, _FindResultBatched +from docarray.typing import AnyTensor +from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.typing.tensor.ndarray import NdArray +from docarray.utils.find import FindResult, _FindResult + +TSchema = TypeVar('TSchema', bound=BaseDoc) +T = TypeVar('T', bound='WeaviateDocumentIndex') + + +DEFAULT_BATCH_CONFIG = { + "batch_size": 20, + "dynamic": False, + "timeout_retries": 3, + "num_workers": 1, +} + +DEFAULT_BINARY_PATH = str(Path.home() / ".cache/weaviate-embedded/") +DEFAULT_PERSISTENCE_DATA_PATH = str(Path.home() / ".local/share/weaviate") + + +@dataclass +class EmbeddedOptions: + persistence_data_path: str = os.environ.get( + "XDG_DATA_HOME", DEFAULT_PERSISTENCE_DATA_PATH + ) + binary_path: str = os.environ.get("XDG_CACHE_HOME", DEFAULT_BINARY_PATH) + version: str = "latest" + port: int = 6666 + hostname: str = "127.0.0.1" + additional_env_vars: Optional[Dict[str, str]] = None + + +# TODO: add more types and figure out how to handle text vs string type +# see https://weaviate.io/developers/weaviate/configuration/datatypes +WEAVIATE_PY_VEC_TYPES = [list, np.ndarray, AbstractTensor] +WEAVIATE_PY_TYPES = [bool, int, float, str, docarray.typing.ID] + +# "id" and "_id" are reserved names in weaviate so we need to use a different +# name for the id column in a BaseDocument +DOCUMENTID = "docarrayid" + + +class WeaviateDocumentIndex(BaseDocIndex, Generic[TSchema]): + def __init__(self, db_config=None, **kwargs) -> None: + self.embedding_column: Optional[str] = None + self.properties: Optional[List[str]] = None + # keep track of the column name that contains the bytes + # type because we will store them as a base64 encoded string + # in weaviate + self.bytes_columns: List[str] = [] + # keep track of the array columns that are not embeddings because we will + # convert them to python lists before uploading to weaviate + self.nonembedding_array_columns: List[str] = [] + super().__init__(db_config=db_config, **kwargs) + self._db_config: WeaviateDocumentIndex.DBConfig = cast( + WeaviateDocumentIndex.DBConfig, self._db_config + ) + self._runtime_config: WeaviateDocumentIndex.RuntimeConfig = cast( + WeaviateDocumentIndex.RuntimeConfig, self._runtime_config + ) + + if self._db_config.embedded_options: + self._client = weaviate.Client( + embedded_options=self._db_config.embedded_options + ) + else: + self._client = weaviate.Client( + self._db_config.host, auth_client_secret=self._build_auth_credentials() + ) + + self._configure_client() + self._validate_columns() + self._set_embedding_column() + self._set_properties() + self._create_schema() + + def _set_properties(self) -> None: + field_overwrites = {"id": DOCUMENTID} + + self.properties = [ + field_overwrites.get(k, k) + for k, v in self._column_infos.items() + if v.config.get('is_embedding', False) is False + ] + + def _validate_columns(self) -> None: + # must have at most one column with property is_embedding=True + # and that column must be of type WEAVIATE_PY_VEC_TYPES + # TODO: update when https://github.com/weaviate/weaviate/issues/2424 + # is implemented and discuss best interface to signal which column(s) + # should be used for embeddings + num_embedding_columns = 0 + + for column_name, column_info in self._column_infos.items(): + if column_info.config.get('is_embedding', False): + num_embedding_columns += 1 + # if db_type is not 'number[]', then that means the type of the column in + # the given schema is not one of WEAVIATE_PY_VEC_TYPES + # note: the mapping between a column's type in the schema to a weaviate type + # is handled by the python_type_to_db_type method + if column_info.db_type != 'number[]': + raise ValueError( + f'Column {column_name} is marked as embedding but is not of type {WEAVIATE_PY_VEC_TYPES}' + ) + + if num_embedding_columns > 1: + raise ValueError( + f'Only one column can be marked as embedding but found {num_embedding_columns} columns marked as embedding' + ) + + def _set_embedding_column(self) -> None: + for column_name, column_info in self._column_infos.items(): + if column_info.config.get('is_embedding', False): + self.embedding_column = column_name + break + + def _configure_client(self) -> None: + self._client.batch.configure(**self._runtime_config.batch_config) + + def _build_auth_credentials(self): + dbconfig = self._db_config + + if dbconfig.auth_api_key: + return weaviate.auth.AuthApiKey(api_key=dbconfig.auth_api_key) + elif dbconfig.username and dbconfig.password: + return weaviate.auth.AuthClientPassword( + dbconfig.username, dbconfig.password, dbconfig.scopes + ) + else: + return None + + def configure(self, runtime_config=None, **kwargs) -> None: + super().configure(runtime_config, **kwargs) + self._configure_client() + + def _create_schema(self) -> None: + schema: Dict[str, Any] = {} + + properties = [] + column_infos = self._column_infos + + for column_name, column_info in column_infos.items(): + # in weaviate, we do not create a property for the doc's embeddings + if column_name == self.embedding_column: + continue + if column_info.db_type == 'blob': + self.bytes_columns.append(column_name) + if column_info.db_type == 'number[]': + self.nonembedding_array_columns.append(column_name) + prop = { + "name": column_name + if column_name != 'id' + else DOCUMENTID, # in weaviate, id and _id is a reserved keyword + "dataType": [column_info.db_type], + } + properties.append(prop) + + # TODO: What is the best way to specify other config that is part of schema? + # e.g. invertedIndexConfig, shardingConfig, moduleConfig, vectorIndexConfig + # and configure replication + # we will update base on user feedback + schema["properties"] = properties + schema["class"] = self._db_config.index_name + + # TODO: Use exists() instead of contains() when available + # see https://github.com/weaviate/weaviate-python-client/issues/232 + if self._client.schema.contains(schema): + logging.warning( + f"Found index {self._db_config.index_name} with schema {schema}. Will reuse existing schema." + ) + else: + self._client.schema.create_class(schema) + + @dataclass + class DBConfig(BaseDocIndex.DBConfig): + host: str = 'http://localhost:8080' + index_name: str = 'Document' + username: Optional[str] = None + password: Optional[str] = None + scopes: List[str] = field(default_factory=lambda: ["offline_access"]) + auth_api_key: Optional[str] = None + embedded_options: Optional[EmbeddedOptions] = None + + @dataclass + class RuntimeConfig(BaseDocIndex.RuntimeConfig): + default_column_config: Dict[Any, Dict[str, Any]] = field( + default_factory=lambda: { + np.ndarray: {}, + docarray.typing.ID: {}, + 'string': {}, + 'text': {}, + 'int': {}, + 'number': {}, + 'boolean': {}, + 'number[]': {}, + 'blob': {}, + } + ) + + batch_config: Dict[str, Any] = field( + default_factory=lambda: DEFAULT_BATCH_CONFIG + ) + + def _del_items(self, doc_ids: Sequence[str]): + has_matches = True + + operands = [ + {"path": [DOCUMENTID], "operator": "Equal", "valueString": doc_id} + for doc_id in doc_ids + ] + where_filter = { + "operator": "Or", + "operands": operands, + } + + # do a loop because there is a limit to how many objects can be deleted at + # in a single query + # see: https://weaviate.io/developers/weaviate/api/rest/batch#maximum-number-of-deletes-per-query + while has_matches: + results = self._client.batch.delete_objects( + class_name=self._db_config.index_name, + where=where_filter, + ) + + has_matches = results["results"]["matches"] + + def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]: + self._overwrite_id(filter_query) + + results = ( + self._client.query.get(self._db_config.index_name, self.properties) + .with_additional("vector") + .with_where(filter_query) + .with_limit(limit) + .do() + ) + + docs = results["data"]["Get"][self._db_config.index_name] + + return [self._parse_weaviate_result(doc) for doc in docs] + + def _filter_batched( + self, filter_queries: Any, limit: int + ) -> Union[List[DocList], List[List[Dict]]]: + for filter_query in filter_queries: + self._overwrite_id(filter_query) + + qs = [ + self._client.query.get(self._db_config.index_name, self.properties) + .with_additional("vector") + .with_where(filter_query) + .with_limit(limit) + .with_alias(f'query_{i}') + for i, filter_query in enumerate(filter_queries) + ] + + batched_results = self._client.query.multi_get(qs).do() + + return [ + [self._parse_weaviate_result(doc) for doc in batched_result] + for batched_result in batched_results["data"]["Get"].values() + ] + + def find( + self, + query: Union[AnyTensor, BaseDoc], + search_field: str = '', + limit: int = 10, + **kwargs, + ): + self._logger.debug('Executing `find`') + if search_field != '': + raise ValueError( + 'Argument search_field is not supported for WeaviateDocumentIndex.\nSet search_field to an empty string to proceed.' + ) + embedding_field = self._get_embedding_field() + if isinstance(query, BaseDoc): + query_vec = self._get_values_by_column([query], embedding_field)[0] + else: + query_vec = query + query_vec_np = self._to_numpy(query_vec) + docs, scores = self._find( + query_vec_np, search_field=search_field, limit=limit, **kwargs + ) + + if isinstance(docs, List): + docs = self._dict_list_to_docarray(docs) + + return FindResult(documents=docs, scores=scores) + + def _overwrite_id(self, where_filter): + """ + Overwrite the id field in the where filter to DOCUMENTID + if the "id" field is present in the path + """ + for key, value in where_filter.items(): + if key == "path" and value == ["id"]: + where_filter[key] = [DOCUMENTID] + elif isinstance(value, dict): + self._overwrite_id(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + self._overwrite_id(item) + + def _find( + self, + query: np.ndarray, + limit: int, + search_field: str = '', + score_name: Literal["certainty", "distance"] = "certainty", + score_threshold: Optional[float] = None, + ) -> _FindResult: + index_name = self._db_config.index_name + if search_field: + logging.warning( + 'Argument search_field is not supported for WeaviateDocumentIndex. Ignoring.' + ) + near_vector: Dict[str, Any] = { + "vector": query, + } + if score_threshold: + near_vector[score_name] = score_threshold + + results = ( + self._client.query.get(index_name, self.properties) + .with_near_vector( + near_vector, + ) + .with_limit(limit) + .with_additional([score_name, "vector"]) + .do() + ) + + docs, scores = self._format_response( + results["data"]["Get"][index_name], score_name + ) + return _FindResult(docs, parse_obj_as(NdArray, scores)) + + def _format_response( + self, results, score_name + ) -> Tuple[List[Dict[Any, Any]], List[Any]]: + """ + Format the response from Weaviate into a Tuple of DocList and scores + """ + + documents = [] + scores = [] + + for result in results: + score = result["_additional"][score_name] + scores.append(score) + + document = self._parse_weaviate_result(result) + documents.append(document) + + return documents, scores + + def find_batched( + self, + queries: Union[AnyTensor, DocList], + search_field: str = '', + limit: int = 10, + **kwargs, + ) -> FindResultBatched: + self._logger.debug('Executing `find_batched`') + if search_field != '': + raise ValueError( + 'Argument search_field is not supported for WeaviateDocumentIndex.\nSet search_field to an empty string to proceed.' + ) + embedding_field = self._get_embedding_field() + + if isinstance(queries, Sequence): + query_vec_list = self._get_values_by_column(queries, embedding_field) + query_vec_np = np.stack( + tuple(self._to_numpy(query_vec) for query_vec in query_vec_list) + ) + else: + query_vec_np = self._to_numpy(queries) + + da_list, scores = self._find_batched( + query_vec_np, search_field=search_field, limit=limit, **kwargs + ) + + if len(da_list) > 0 and isinstance(da_list[0], List): + da_list = [self._dict_list_to_docarray(docs) for docs in da_list] + + return FindResultBatched(documents=da_list, scores=scores) # type: ignore + + def _find_batched( + self, + queries: np.ndarray, + limit: int, + search_field: str = '', + score_name: Literal["certainty", "distance"] = "certainty", + score_threshold: Optional[float] = None, + ) -> _FindResultBatched: + qs = [] + for i, query in enumerate(queries): + near_vector: Dict[str, Any] = {"vector": query} + + if score_threshold: + near_vector[score_name] = score_threshold + + q = ( + self._client.query.get(self._db_config.index_name, self.properties) + .with_near_vector(near_vector) + .with_limit(limit) + .with_additional([score_name, "vector"]) + .with_alias(f'query_{i}') + ) + + qs.append(q) + + results = self._client.query.multi_get(qs).do() + + docs_and_scores = [ + self._format_response(result, score_name) + for result in results["data"]["Get"].values() + ] + + docs, scores = zip(*docs_and_scores) + return _FindResultBatched(list(docs), list(scores)) + + def _get_items(self, doc_ids: Sequence[str]) -> List[Dict]: + # TODO: warn when doc_ids > QUERY_MAXIMUM_RESULTS after + # https://github.com/weaviate/weaviate/issues/2792 + # is implemented + operands = [ + {"path": [DOCUMENTID], "operator": "Equal", "valueString": doc_id} + for doc_id in doc_ids + ] + where_filter = { + "operator": "Or", + "operands": operands, + } + + results = ( + self._client.query.get(self._db_config.index_name, self.properties) + .with_where(where_filter) + .with_additional("vector") + .do() + ) + + docs = [ + self._parse_weaviate_result(doc) + for doc in results["data"]["Get"][self._db_config.index_name] + ] + + return docs + + def _rewrite_documentid(self, document: Dict): + doc = document.copy() + + # rewrite the id to DOCUMENTID + document_id = doc.pop('id') + doc[DOCUMENTID] = document_id + + return doc + + def _parse_weaviate_result(self, result: Dict) -> Dict: + """ + Parse the result from weaviate to a format that is compatible with the schema + that was used to initialize weaviate with. + """ + + result = result.copy() + + # rewrite the DOCUMENTID to id + if DOCUMENTID in result: + result['id'] = result.pop(DOCUMENTID) + + # take the vector from the _additional field + if '_additional' in result and self.embedding_column: + additional_fields = result.pop('_additional') + if 'vector' in additional_fields: + result[self.embedding_column] = additional_fields['vector'] + + # convert any base64 encoded bytes column to bytes + self._decode_base64_properties_to_bytes(result) + + return result + + def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): + docs = self._transpose_col_value_dict(column_to_data) + index_name = self._db_config.index_name + + with self._client.batch as batch: + for doc in docs: + parsed_doc = self._rewrite_documentid(doc) + self._encode_bytes_columns_to_base64(parsed_doc) + self._convert_nonembedding_array_to_list(parsed_doc) + vector = ( + parsed_doc.pop(self.embedding_column) + if self.embedding_column + else None + ) + + batch.add_data_object( + uuid=weaviate.util.generate_uuid5(parsed_doc, index_name), + data_object=parsed_doc, + class_name=index_name, + vector=vector, + ) + + def _text_search( + self, query: str, limit: int, search_field: str = '' + ) -> _FindResult: + index_name = self._db_config.index_name + bm25 = {"query": query, "properties": [search_field]} + + results = ( + self._client.query.get(index_name, self.properties) + .with_bm25(bm25) + .with_limit(limit) + .with_additional(["score", "vector"]) + .do() + ) + + docs, scores = self._format_response( + results["data"]["Get"][index_name], "score" + ) + + return _FindResult(documents=docs, scores=parse_obj_as(NdArray, scores)) + + def _text_search_batched( + self, queries: Sequence[str], limit: int, search_field: str = '' + ) -> _FindResultBatched: + qs = [] + for i, query in enumerate(queries): + bm25 = {"query": query, "properties": [search_field]} + + q = ( + self._client.query.get(self._db_config.index_name, self.properties) + .with_bm25(bm25) + .with_limit(limit) + .with_additional(["score", "vector"]) + .with_alias(f'query_{i}') + ) + + qs.append(q) + + results = self._client.query.multi_get(qs).do() + + docs_and_scores = [ + self._format_response(result, "score") + for result in results["data"]["Get"].values() + ] + + docs, scores = zip(*docs_and_scores) + return _FindResultBatched(list(docs), list(scores)) + + def execute_query(self, query: Any, *args, **kwargs) -> Any: + da_class = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema)) + + if isinstance(query, self.QueryBuilder): + batched_results = self._client.query.multi_get(query._queries).do() + batched_docs = batched_results["data"]["Get"].values() + + def f(doc): + # TODO: use + # return self._schema(**self._parse_weaviate_result(doc)) + # when https://github.com/weaviate/weaviate/issues/2858 + # is fixed + return self._schema.from_view(self._parse_weaviate_result(doc)) # type: ignore + + results = [ + da_class([f(doc) for doc in batched_doc]) + for batched_doc in batched_docs + ] + return results if len(results) > 1 else results[0] + + # TODO: validate graphql query string before sending it to weaviate + if isinstance(query, str): + return self._client.query.raw(query) + + def num_docs(self) -> int: + index_name = self._db_config.index_name + result = self._client.query.aggregate(index_name).with_meta_count().do() + # TODO: decorator to check for errors + total_docs = result["data"]["Aggregate"][index_name][0]["meta"]["count"] + + return total_docs + + def python_type_to_db_type(self, python_type: Type) -> Any: + """Map python type to database type.""" + for allowed_type in WEAVIATE_PY_VEC_TYPES: + if issubclass(python_type, allowed_type): + return 'number[]' + + py_weaviate_type_map = { + docarray.typing.ID: 'string', + str: 'text', + int: 'int', + float: 'number', + bool: 'boolean', + np.ndarray: 'number[]', + bytes: 'blob', + } + + for py_type, weaviate_type in py_weaviate_type_map.items(): + if issubclass(python_type, py_type): + return weaviate_type + + raise ValueError(f'Unsupported column type for {type(self)}: {python_type}') + + def build_query(self) -> BaseDocIndex.QueryBuilder: + return self.QueryBuilder(self) + + def _get_embedding_field(self): + for colname, colinfo in self._column_infos.items(): + # no need to check for missing is_embedding attribute because this check + # is done when the index is created + if colinfo.config.get('is_embedding', None): + return colname + + # just to pass mypy + return "" + + def _encode_bytes_columns_to_base64(self, doc): + for column in self.bytes_columns: + if doc[column] is not None: + doc[column] = base64.b64encode(doc[column]).decode("utf-8") + + def _decode_base64_properties_to_bytes(self, doc): + for column in self.bytes_columns: + if doc[column] is not None: + doc[column] = base64.b64decode(doc[column]) + + def _convert_nonembedding_array_to_list(self, doc): + for column in self.nonembedding_array_columns: + if doc[column] is not None: + doc[column] = doc[column].tolist() + + class QueryBuilder(BaseDocIndex.QueryBuilder): + def __init__(self, document_index): + self._queries = [ + document_index._client.query.get( + document_index._db_config.index_name, document_index.properties + ) + ] + + def build(self) -> Any: + num_queries = len(self._queries) + + for i in range(num_queries): + q = self._queries[i] + if self._is_hybrid_query(q): + self._make_proper_hybrid_query(q) + q.with_additional(["vector"]).with_alias(f'query_{i}') + + return self + + def _is_hybrid_query(self, query: weaviate.gql.get.GetBuilder) -> bool: + """ + Checks if a query has been composed with both a with_bm25 and a with_near_vector verb + """ + if not query._near_ask: + return False + else: + return query._bm25 and query._near_ask._content.get("vector", None) + + def _make_proper_hybrid_query( + self, query: weaviate.gql.get.GetBuilder + ) -> weaviate.gql.get.GetBuilder: + """ + Modifies a query to be a proper hybrid query. + + In weaviate, a query with with_bm25 and with_near_vector verb is not a hybrid query. + We need to use the with_hybrid verb to make it a hybrid query. + """ + + text_query = query._bm25.query + vector_query = query._near_ask._content["vector"] + hybrid_query = weaviate.gql.get.Hybrid( + query=text_query, vector=vector_query, alpha=0.5 + ) + + query._bm25 = None + query._near_ask = None + query._hybrid = hybrid_query + + def _overwrite_id(self, where_filter): + """ + Overwrite the id field in the where filter to DOCUMENTID + if the "id" field is present in the path + """ + for key, value in where_filter.items(): + if key == "path" and value == ["id"]: + where_filter[key] = [DOCUMENTID] + elif isinstance(value, dict): + self._overwrite_id(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + self._overwrite_id(item) + + def find( + self, + query, + score_name: Literal["certainty", "distance"] = "certainty", + score_threshold: Optional[float] = None, + ) -> Any: + near_vector = { + "vector": query, + } + if score_threshold: + near_vector[score_name] = score_threshold + + self._queries[0] = self._queries[0].with_near_vector(near_vector) + return self + + def find_batched( + self, + queries, + score_name: Literal["certainty", "distance"] = "certainty", + score_threshold: Optional[float] = None, + ) -> Any: + adj_queries, adj_clauses = self._resize_queries_and_clauses( + self._queries, queries + ) + new_queries = [] + + for query, clause in zip(adj_queries, adj_clauses): + near_vector = { + "vector": clause, + } + if score_threshold: + near_vector[score_name] = score_threshold + + new_queries.append(query.with_near_vector(near_vector)) + + self._queries = new_queries + + return self + + def filter(self, where_filter) -> Any: + where_filter = where_filter.copy() + self._overwrite_id(where_filter) + self._queries[0] = self._queries[0].with_where(where_filter) + return self + + def filter_batched(self, filters) -> Any: + adj_queries, adj_clauses = self._resize_queries_and_clauses( + self._queries, filters + ) + new_queries = [] + + for query, clause in zip(adj_queries, adj_clauses): + clause = clause.copy() + self._overwrite_id(clause) + new_queries.append(query.with_where(clause)) + + self._queries = new_queries + + return self + + def text_search(self, query, search_field) -> Any: + bm25 = {"query": query, "properties": [search_field]} + self._queries[0] = self._queries[0].with_bm25(**bm25) + return self + + def text_search_batched(self, queries, search_field) -> Any: + adj_queries, adj_clauses = self._resize_queries_and_clauses( + self._queries, queries + ) + new_queries = [] + + for query, clause in zip(adj_queries, adj_clauses): + bm25 = {"query": clause, "properties": [search_field]} + new_queries.append(query.with_bm25(**bm25)) + + self._queries = new_queries + + return self + + def limit(self, limit: int) -> Any: + self._queries = [query.with_limit(limit) for query in self._queries] + return self + + def _resize_queries_and_clauses(self, queries, clauses): + """ + Adjust the length and content of queries and clauses so that we can compose + them element-wise + """ + num_clauses = len(clauses) + num_queries = len(queries) + + # if there's only one clause, then we assume that it should be applied + # to every query + if num_clauses == 1: + return queries, clauses * num_queries + # if there's only one query, then we can lengthen it to match the number + # of clauses + elif num_queries == 1: + return [copy.deepcopy(queries[0]) for _ in range(num_clauses)], clauses + # if the number of queries and clauses is the same, then we can just + # return them as-is + elif num_clauses == num_queries: + return queries, clauses + else: + raise ValueError( + f"Can't compose {num_clauses} clauses with {num_queries} queries" + ) diff --git a/poetry.lock b/poetry.lock index cd46e05c897..398a9ec992d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -264,6 +264,21 @@ docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"] tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] +[[package]] +name = "authlib" +version = "1.2.0" +description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients." +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "Authlib-1.2.0-py2.py3-none-any.whl", hash = "sha256:4ddf4fd6cfa75c9a460b361d4bd9dac71ffda0be879dbe4292a02e92349ad55a"}, + {file = "Authlib-1.2.0.tar.gz", hash = "sha256:4fa3e80883a5915ef9f5bc28630564bc4ed5b5af39812a3ff130ec76bd631e9d"}, +] + +[package.dependencies] +cryptography = ">=3.2" + [[package]] name = "av" version = "10.0.0" @@ -525,7 +540,7 @@ files = [ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." -category = "dev" +category = "main" optional = false python-versions = "*" files = [ @@ -698,6 +713,48 @@ files = [ [package.extras] test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"] +[[package]] +name = "cryptography" +version = "40.0.1" +description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "cryptography-40.0.1-cp36-abi3-macosx_10_12_universal2.whl", hash = "sha256:918cb89086c7d98b1b86b9fdb70c712e5a9325ba6f7d7cfb509e784e0cfc6917"}, + {file = "cryptography-40.0.1-cp36-abi3-macosx_10_12_x86_64.whl", hash = "sha256:9618a87212cb5200500e304e43691111570e1f10ec3f35569fdfcd17e28fd797"}, + {file = "cryptography-40.0.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a4805a4ca729d65570a1b7cac84eac1e431085d40387b7d3bbaa47e39890b88"}, + {file = "cryptography-40.0.1-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63dac2d25c47f12a7b8aa60e528bfb3c51c5a6c5a9f7c86987909c6c79765554"}, + {file = "cryptography-40.0.1-cp36-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:0a4e3406cfed6b1f6d6e87ed243363652b2586b2d917b0609ca4f97072994405"}, + {file = "cryptography-40.0.1-cp36-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1e0af458515d5e4028aad75f3bb3fe7a31e46ad920648cd59b64d3da842e4356"}, + {file = "cryptography-40.0.1-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:d8aa3609d337ad85e4eb9bb0f8bcf6e4409bfb86e706efa9a027912169e89122"}, + {file = "cryptography-40.0.1-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:cf91e428c51ef692b82ce786583e214f58392399cf65c341bc7301d096fa3ba2"}, + {file = "cryptography-40.0.1-cp36-abi3-win32.whl", hash = "sha256:650883cc064297ef3676b1db1b7b1df6081794c4ada96fa457253c4cc40f97db"}, + {file = "cryptography-40.0.1-cp36-abi3-win_amd64.whl", hash = "sha256:a805a7bce4a77d51696410005b3e85ae2839bad9aa38894afc0aa99d8e0c3160"}, + {file = "cryptography-40.0.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:cd033d74067d8928ef00a6b1327c8ea0452523967ca4463666eeba65ca350d4c"}, + {file = "cryptography-40.0.1-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d36bbeb99704aabefdca5aee4eba04455d7a27ceabd16f3b3ba9bdcc31da86c4"}, + {file = "cryptography-40.0.1-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:32057d3d0ab7d4453778367ca43e99ddb711770477c4f072a51b3ca69602780a"}, + {file = "cryptography-40.0.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:f5d7b79fa56bc29580faafc2ff736ce05ba31feaa9d4735048b0de7d9ceb2b94"}, + {file = "cryptography-40.0.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:7c872413353c70e0263a9368c4993710070e70ab3e5318d85510cc91cce77e7c"}, + {file = "cryptography-40.0.1-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:28d63d75bf7ae4045b10de5413fb1d6338616e79015999ad9cf6fc538f772d41"}, + {file = "cryptography-40.0.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6f2bbd72f717ce33100e6467572abaedc61f1acb87b8d546001328d7f466b778"}, + {file = "cryptography-40.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cc3a621076d824d75ab1e1e530e66e7e8564e357dd723f2533225d40fe35c60c"}, + {file = "cryptography-40.0.1.tar.gz", hash = "sha256:2803f2f8b1e95f614419926c7e6f55d828afc614ca5ed61543877ae668cc3472"}, +] + +[package.dependencies] +cffi = ">=1.12" + +[package.extras] +docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=1.1.1)"] +docstest = ["pyenchant (>=1.6.11)", "sphinxcontrib-spelling (>=4.0.1)", "twine (>=1.12.0)"] +pep8test = ["black", "check-manifest", "mypy", "ruff"] +sdist = ["setuptools-rust (>=0.11.4)"] +ssh = ["bcrypt (>=3.1.5)"] +test = ["iso8601", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-shard (>=0.1.2)", "pytest-subtests", "pytest-xdist"] +test-randomorder = ["pytest-randomly"] +tox = ["tox"] + [[package]] name = "debugpy" version = "1.6.3" @@ -730,7 +787,7 @@ files = [ name = "decorator" version = "5.1.1" description = "Decorators for Humans" -category = "dev" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -3004,7 +3061,7 @@ validation = ["lxml"] name = "pycparser" version = "2.21" description = "C parser in Python" -category = "dev" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3541,25 +3598,25 @@ files = [ [[package]] name = "requests" -version = "2.27.1" +version = "2.28.2" description = "Python HTTP for Humans." category = "main" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +python-versions = ">=3.7, <4" files = [ - {file = "requests-2.27.1-py2.py3-none-any.whl", hash = "sha256:f22fa1e554c9ddfd16e6e41ac79759e17be9e492b3587efa038054674760e72d"}, - {file = "requests-2.27.1.tar.gz", hash = "sha256:68d7c56fd5a8999887728ef304a6d12edc7be74f1cfa47714fc8b414525c9a61"}, + {file = "requests-2.28.2-py3-none-any.whl", hash = "sha256:64299f4909223da747622c030b781c0d7811e359c37124b4bd368fb8c6518baa"}, + {file = "requests-2.28.2.tar.gz", hash = "sha256:98b1b2782e3c6c4904938b84c0eb932721069dfdb9134313beff7c83c2df24bf"}, ] [package.dependencies] certifi = ">=2017.4.17" -charset-normalizer = {version = ">=2.0.0,<2.1.0", markers = "python_version >= \"3\""} -idna = {version = ">=2.5,<4", markers = "python_version >= \"3\""} +charset-normalizer = ">=2,<4" +idna = ">=2.5,<4" urllib3 = ">=1.21.1,<1.27" [package.extras] -socks = ["PySocks (>=1.5.6,!=1.5.7)", "win-inet-pton"] -use-chardet-on-py3 = ["chardet (>=3.0.2,<5)"] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "rfc3986" @@ -4073,6 +4130,27 @@ files = [ {file = "tornado-6.2.tar.gz", hash = "sha256:9b630419bde84ec666bfd7ea0a4cb2a8a651c2d5cccdbdd1972a0c859dfc3c13"}, ] +[[package]] +name = "tqdm" +version = "4.65.0" +description = "Fast, Extensible Progress Meter" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tqdm-4.65.0-py3-none-any.whl", hash = "sha256:c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671"}, + {file = "tqdm-4.65.0.tar.gz", hash = "sha256:1871fb68a86b8fb3b59ca4cdd3dcccbc7e6d613eeed31f4c332531977b89beb5"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["py-make (>=0.1.0)", "twine", "wheel"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + [[package]] name = "traitlets" version = "5.5.0" @@ -4275,6 +4353,23 @@ typing-extensions = {version = "*", markers = "python_version < \"3.8\""} [package.extras] standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.0)"] +[[package]] +name = "validators" +version = "0.20.0" +description = "Python Data Validation for Humans™." +category = "main" +optional = false +python-versions = ">=3.4" +files = [ + {file = "validators-0.20.0.tar.gz", hash = "sha256:24148ce4e64100a2d5e267233e23e7afeb55316b47d30faae7eb6e7292bc226a"}, +] + +[package.dependencies] +decorator = ">=3.4.0" + +[package.extras] +test = ["flake8 (>=2.4.0)", "isort (>=4.2.2)", "pytest (>=2.2.3)"] + [[package]] name = "virtualenv" version = "20.16.7" @@ -4365,6 +4460,24 @@ files = [ {file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"}, ] +[[package]] +name = "weaviate-client" +version = "3.15.5" +description = "A python native weaviate client" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "weaviate-client-3.15.5.tar.gz", hash = "sha256:6da7e5d08dc9bb8b7879661d1a457c50af7d73e621a5305efe131160e83da69e"}, + {file = "weaviate_client-3.15.5-py3-none-any.whl", hash = "sha256:24d0be614e5494534e758cc67a45e7e15f3929a89bf512afd642de53d08723c7"}, +] + +[package.dependencies] +authlib = ">=1.1.0" +requests = ">=2.28.0,<2.29.0" +tqdm = ">=4.59.0,<5.0.0" +validators = ">=0.18.2,<=0.21.0" + [[package]] name = "webencodings" version = "0.5.1" @@ -4625,14 +4738,14 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" [extras] audio = ["pydub"] aws = ["smart-open"] -elasticsearch = ["elasticsearch", "elastic-transport"] -full = ["protobuf", "lz4", "pandas", "pillow", "types-pillow", "av", "pydub", "trimesh"] +elasticsearch = ["elastic-transport", "elasticsearch"] +full = ["av", "lz4", "pandas", "pillow", "protobuf", "pydub", "trimesh", "types-pillow"] hnswlib = ["hnswlib"] image = ["pillow", "types-pillow"] jac = ["jina-hubble-sdk"] mesh = ["trimesh"] pandas = ["pandas"] -proto = ["protobuf", "lz4"] +proto = ["lz4", "protobuf"] torch = ["torch"] video = ["av"] web = ["fastapi"] @@ -4640,4 +4753,4 @@ web = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.7,<4.0" -content-hash = "a5bae8ca8239347d066e7566dfea56f08d42950f7037e50870cee226809f4b01" +content-hash = "5a07acb92ae45bc42e49e68af897444874d6facd4ed81af4bd9e8d37d7737037" diff --git a/pyproject.toml b/pyproject.toml index ecc72c74719..2b5bc301296 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ hnswlib = {version = ">=0.6.2", optional = true } lz4 = {version= ">=1.0.0", optional = true} pydub = {version = "^0.25.1", optional = true } pandas = {version = ">=1.1.0", optional = true } +weaviate-client = {version = ">=3.15", extras = ["weaviate"]} elasticsearch = {version = ">=7.10.1", optional = true } smart-open = {version = ">=6.3.0", extras = ["s3"], optional = true} jina-hubble-sdk = {version = ">=0.34.0", optional = true} @@ -92,6 +93,7 @@ module = [ "trimesh", "pandas", "av", + "weaviate" ] ignore_missing_imports = true diff --git a/tests/integrations/doc_index/weaviate/docker-compose.yml b/tests/integrations/doc_index/weaviate/docker-compose.yml new file mode 100644 index 00000000000..5cca1e722eb --- /dev/null +++ b/tests/integrations/doc_index/weaviate/docker-compose.yml @@ -0,0 +1,27 @@ +version: '3.8' + +services: + + weaviate: + command: + - --host + - 0.0.0.0 + - --port + - '8080' + - --scheme + - http + image: semitechnologies/weaviate:1.18.3 + ports: + - "8080:8080" + restart: on-failure:0 + environment: + QUERY_DEFAULTS_LIMIT: 25 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' + PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + DEFAULT_VECTORIZER_MODULE: 'none' + ENABLE_MODULES: '' + CLUSTER_HOSTNAME: 'node1' + LOG_LEVEL: debug # verbose + LOG_FORMAT: text + # LOG_LEVEL: trace # very verbose + GODEBUG: gctrace=1 # make go garbage collector verbose \ No newline at end of file diff --git a/tests/integrations/doc_index/weaviate/fixture_weaviate.py b/tests/integrations/doc_index/weaviate/fixture_weaviate.py new file mode 100644 index 00000000000..786a92b2a00 --- /dev/null +++ b/tests/integrations/doc_index/weaviate/fixture_weaviate.py @@ -0,0 +1,41 @@ +import os +import time + +import pytest +import requests +import weaviate + +HOST = "http://localhost:8080" + + +cur_dir = os.path.dirname(os.path.abspath(__file__)) +weaviate_yml = os.path.abspath(os.path.join(cur_dir, 'docker-compose.yml')) + + +@pytest.fixture(scope='session', autouse=True) +def start_storage(): + os.system(f"docker-compose -f {weaviate_yml} up -d --remove-orphans") + _wait_for_weaviate() + + yield + os.system(f"docker-compose -f {weaviate_yml} down --remove-orphans") + + +def _wait_for_weaviate(): + while True: + try: + response = requests.get(f"{HOST}/v1/.well-known/ready") + if response.status_code == 200: + return + else: + time.sleep(0.5) + except requests.exceptions.ConnectionError: + time.sleep(1) + + +@pytest.fixture +def weaviate_client(start_storage): + client = weaviate.Client(HOST) + client.schema.delete_all() + yield client + client.schema.delete_all() diff --git a/tests/integrations/doc_index/weaviate/test_column_config_weaviate.py b/tests/integrations/doc_index/weaviate/test_column_config_weaviate.py new file mode 100644 index 00000000000..a3050e9b9ba --- /dev/null +++ b/tests/integrations/doc_index/weaviate/test_column_config_weaviate.py @@ -0,0 +1,33 @@ +# TODO: enable ruff qa on this file when we figure out why it thinks weaviate_client is +# redefined at each test that fixture +# ruff: noqa +from pydantic import Field + +from docarray import BaseDoc +from docarray.index.backends.weaviate import WeaviateDocumentIndex +from tests.integrations.doc_index.weaviate.fixture_weaviate import ( # noqa: F401 + start_storage, + weaviate_client, +) + + +def test_column_config(weaviate_client): + def get_text_field_data_type(store, index_name): + props = store._client.schema.get(index_name)["properties"] + text_field = [p for p in props if p["name"] == "text"][0] + + return text_field["dataType"][0] + + class TextDoc(BaseDoc): + text: str = Field() + + class StringDoc(BaseDoc): + text: str = Field(col_type="string") + + dbconfig = WeaviateDocumentIndex.DBConfig(index_name="TextDoc") + store = WeaviateDocumentIndex[TextDoc](db_config=dbconfig) + assert get_text_field_data_type(store, "TextDoc") == "text" + + dbconfig = WeaviateDocumentIndex.DBConfig(index_name="StringDoc") + store = WeaviateDocumentIndex[StringDoc](db_config=dbconfig) + assert get_text_field_data_type(store, "StringDoc") == "string" diff --git a/tests/integrations/doc_index/weaviate/test_find_weaviate.py b/tests/integrations/doc_index/weaviate/test_find_weaviate.py new file mode 100644 index 00000000000..c54d167c634 --- /dev/null +++ b/tests/integrations/doc_index/weaviate/test_find_weaviate.py @@ -0,0 +1,66 @@ +# TODO: enable ruff qa on this file when we figure out why it thinks weaviate_client is +# redefined at each test that fixture +# ruff: noqa +import numpy as np +import pytest +import torch +from pydantic import Field + +from docarray import BaseDoc +from docarray.index.backends.weaviate import WeaviateDocumentIndex +from docarray.typing import TorchTensor +from tests.integrations.doc_index.weaviate.fixture_weaviate import ( # noqa: F401 + start_storage, + weaviate_client, +) + + +def test_find_torch(weaviate_client): + class TorchDoc(BaseDoc): + tens: TorchTensor[10] = Field(dims=10, is_embedding=True) + + store = WeaviateDocumentIndex[TorchDoc]() + + index_docs = [ + TorchDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) + ] + store.index(index_docs) + + query = index_docs[-1] + docs, scores = store.find(query, limit=5) + + assert len(docs) == 5 + assert len(scores) == 5 + for doc in docs: + assert isinstance(doc.tens, TorchTensor) + + assert docs[0].id == index_docs[-1].id + assert torch.allclose(docs[0].tens, index_docs[-1].tens) + + +@pytest.mark.tensorflow +def test_find_tensorflow(): + from docarray.typing import TensorFlowTensor + + class TfDoc(BaseDoc): + tens: TensorFlowTensor[10] = Field(dims=10, is_embedding=True) + + store = WeaviateDocumentIndex[TfDoc]() + + index_docs = [ + TfDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) + ] + store.index(index_docs) + + query = index_docs[-1] + docs, scores = store.find(query, limit=5) + + assert len(docs) == 5 + assert len(scores) == 5 + for doc in docs: + assert isinstance(doc.tens, TensorFlowTensor) + + assert docs[0].id == index_docs[-1].id + assert np.allclose( + docs[0].tens.unwrap().numpy(), index_docs[-1].tens.unwrap().numpy() + ) diff --git a/tests/integrations/doc_index/weaviate/test_index_get_del_weaviate.py b/tests/integrations/doc_index/weaviate/test_index_get_del_weaviate.py new file mode 100644 index 00000000000..e9c218d45a4 --- /dev/null +++ b/tests/integrations/doc_index/weaviate/test_index_get_del_weaviate.py @@ -0,0 +1,452 @@ +# TODO: enable ruff qa on this file when we figure out why it thinks weaviate_client is +# redefined at each test that fixture +# ruff: noqa +import logging + +import numpy as np +import pytest +from pydantic import Field + +from docarray import BaseDoc +from docarray.documents import ImageDoc, TextDoc +from docarray.index.backends.weaviate import ( + DOCUMENTID, + EmbeddedOptions, + WeaviateDocumentIndex, +) +from docarray.typing import NdArray +from tests.integrations.doc_index.weaviate.fixture_weaviate import ( # noqa: F401 + HOST, + start_storage, + weaviate_client, +) + + +class SimpleDoc(BaseDoc): + tens: NdArray[10] = Field(dim=1000, is_embedding=True) + + +class Document(BaseDoc): + embedding: NdArray[2] = Field(dim=2, is_embedding=True) + text: str = Field() + + +class NestedDocument(BaseDoc): + text: str = Field() + child: Document + + +@pytest.fixture +def ten_simple_docs(): + return [SimpleDoc(tens=np.random.randn(10)) for _ in range(10)] + + +@pytest.fixture +def documents(): + texts = ["lorem ipsum", "dolor sit amet", "consectetur adipiscing elit"] + embeddings = [[10, 10], [10.5, 10.5], [-100, -100]] + + # create the docs by enumerating from 1 and use that as the id + docs = [ + Document(id=str(i), embedding=embedding, text=text) + for i, (embedding, text) in enumerate(zip(embeddings, texts)) + ] + + yield docs + + +@pytest.fixture +def test_store(weaviate_client, documents): + store = WeaviateDocumentIndex[Document]() + store.index(documents) + yield store + + +def test_index_simple_schema(weaviate_client, ten_simple_docs): + store = WeaviateDocumentIndex[SimpleDoc]() + store.index(ten_simple_docs) + assert store.num_docs() == 10 + + for doc in ten_simple_docs: + doc_id = doc.id + doc_embedding = doc.tens + + result = ( + weaviate_client.query.get("Document", DOCUMENTID) + .with_additional("vector") + .with_where( + {"path": [DOCUMENTID], "operator": "Equal", "valueString": doc_id} + ) + .do() + ) + + result = result["data"]["Get"]["Document"][0] + assert result[DOCUMENTID] == doc_id + assert np.allclose(result["_additional"]["vector"], doc_embedding) + + +def test_validate_columns(weaviate_client): + dbconfig = WeaviateDocumentIndex.DBConfig(host=HOST) + + class InvalidDoc1(BaseDoc): + tens: NdArray[10] = Field(dim=1000, is_embedding=True) + tens2: NdArray[10] = Field(dim=1000, is_embedding=True) + + class InvalidDoc2(BaseDoc): + tens: int = Field(dim=1000, is_embedding=True) + + with pytest.raises(ValueError, match=r"Only one column can be marked as embedding"): + WeaviateDocumentIndex[InvalidDoc1](db_config=dbconfig) + + with pytest.raises(ValueError, match=r"marked as embedding but is not of type"): + WeaviateDocumentIndex[InvalidDoc2](db_config=dbconfig) + + +def test_find(weaviate_client, caplog): + class Document(BaseDoc): + embedding: NdArray[2] = Field(dim=2, is_embedding=True) + + vectors = [[10, 10], [10.5, 10.5], [-100, -100]] + docs = [Document(embedding=vector) for vector in vectors] + + store = WeaviateDocumentIndex[Document]() + store.index(docs) + + query = [10.1, 10.1] + + results = store.find( + query, search_field='', limit=3, score_name="distance", score_threshold=1e-2 + ) + assert len(results) == 2 + + results = store.find(query, search_field='', limit=3, score_threshold=0.99) + assert len(results) == 2 + + with pytest.raises( + ValueError, + match=r"Argument search_field is not supported for WeaviateDocumentIndex", + ): + store.find(query, search_field="foo", limit=10) + + +def test_find_batched(weaviate_client, caplog): + class Document(BaseDoc): + embedding: NdArray[2] = Field(dim=2, is_embedding=True) + + vectors = [[10, 10], [10.5, 10.5], [-100, -100]] + docs = [Document(embedding=vector) for vector in vectors] + + store = WeaviateDocumentIndex[Document]() + store.index(docs) + + queries = np.array([[10.1, 10.1], [-100, -100]]) + + results = store.find_batched( + queries, search_field='', limit=3, score_name="distance", score_threshold=1e-2 + ) + assert len(results) == 2 + assert len(results.documents[0]) == 2 + assert len(results.documents[1]) == 1 + + results = store.find_batched( + queries, search_field='', limit=3, score_name="certainty" + ) + assert len(results) == 2 + assert len(results.documents[0]) == 3 + assert len(results.documents[1]) == 3 + + with pytest.raises( + ValueError, + match=r"Argument search_field is not supported for WeaviateDocumentIndex", + ): + store.find_batched(queries, search_field="foo", limit=10) + + +@pytest.mark.parametrize( + "filter_query, expected_num_docs", + [ + ({"path": ["text"], "operator": "Equal", "valueText": "lorem ipsum"}, 1), + ({"path": ["text"], "operator": "Equal", "valueText": "foo"}, 0), + ({"path": ["id"], "operator": "Equal", "valueString": "1"}, 1), + ], +) +def test_filter(test_store, filter_query, expected_num_docs): + docs = test_store.filter(filter_query, limit=3) + actual_num_docs = len(docs) + + assert actual_num_docs == expected_num_docs + + +@pytest.mark.parametrize( + "filter_queries, expected_num_docs", + [ + ( + [ + {"path": ["text"], "operator": "Equal", "valueText": "lorem ipsum"}, + {"path": ["text"], "operator": "Equal", "valueText": "foo"}, + ], + [1, 0], + ), + ( + [ + {"path": ["id"], "operator": "Equal", "valueString": "1"}, + {"path": ["id"], "operator": "Equal", "valueString": "2"}, + ], + [1, 0], + ), + ], +) +def test_filter_batched(test_store, filter_queries, expected_num_docs): + filter_queries = [ + {"path": ["text"], "operator": "Equal", "valueText": "lorem ipsum"}, + {"path": ["text"], "operator": "Equal", "valueText": "foo"}, + ] + + results = test_store.filter_batched(filter_queries, limit=3) + actual_num_docs = [len(docs) for docs in results] + assert actual_num_docs == expected_num_docs + + +def test_text_search(test_store): + results = test_store.text_search(query="lorem", search_field="text", limit=3) + assert len(results.documents) == 1 + + +def test_text_search_batched(test_store): + text_queries = ["lorem", "foo"] + + results = test_store.text_search_batched( + queries=text_queries, search_field="text", limit=3 + ) + assert len(results.documents[0]) == 1 + assert len(results.documents[1]) == 0 + + +def test_del_items(test_store): + del test_store[["1", "2"]] + assert test_store.num_docs() == 1 + + +def test_get_items(test_store): + docs = test_store[["1", "2"]] + assert len(docs) == 2 + assert set(doc.id for doc in docs) == {'1', '2'} + + +def test_index_nested_documents(weaviate_client): + store = WeaviateDocumentIndex[NestedDocument]() + document = NestedDocument( + text="lorem ipsum", child=Document(embedding=[10, 10], text="dolor sit amet") + ) + store.index([document]) + assert store.num_docs() == 1 + + +@pytest.mark.parametrize( + "search_field, query, expected_num_docs", + [ + ("text", "lorem", 1), + ("child__text", "dolor", 1), + ("text", "foo", 0), + ("child__text", "bar", 0), + ], +) +def test_text_search_nested_documents( + weaviate_client, search_field, query, expected_num_docs +): + store = WeaviateDocumentIndex[NestedDocument]() + document = NestedDocument( + text="lorem ipsum", child=Document(embedding=[10, 10], text="dolor sit amet") + ) + store.index([document]) + + results = store.text_search(query=query, search_field=search_field, limit=3) + + assert len(results.documents) == expected_num_docs + + +def test_reuse_existing_schema(weaviate_client, caplog): + WeaviateDocumentIndex[SimpleDoc]() + + with caplog.at_level(logging.DEBUG): + WeaviateDocumentIndex[SimpleDoc]() + assert "Will reuse existing schema" in caplog.text + + +def test_query_builder(test_store): + query_embedding = [10.25, 10.25] + query_text = "ipsum" + where_filter = {"path": ["id"], "operator": "Equal", "valueString": "1"} + q = ( + test_store.build_query() + .find(query=query_embedding) + .filter(where_filter) + .build() + ) + + docs = test_store.execute_query(q) + assert len(docs) == 1 + + q = ( + test_store.build_query() + .text_search(query=query_text, search_field="text") + .build() + ) + + docs = test_store.execute_query(q) + assert len(docs) == 1 + + +def test_batched_query_builder(test_store): + query_embeddings = [[10.25, 10.25], [-100, -100]] + query_texts = ["ipsum", "foo"] + where_filters = [{"path": ["id"], "operator": "Equal", "valueString": "1"}] + + q = ( + test_store.build_query() + .find_batched( + queries=query_embeddings, score_name="certainty", score_threshold=0.99 + ) + .filter_batched(filters=where_filters) + .build() + ) + + docs = test_store.execute_query(q) + assert len(docs[0]) == 1 + assert len(docs[1]) == 0 + + q = ( + test_store.build_query() + .text_search_batched(queries=query_texts, search_field="text") + .build() + ) + + docs = test_store.execute_query(q) + assert len(docs[0]) == 1 + assert len(docs[1]) == 0 + + +def test_raw_graphql(test_store): + graphql_query = """ + { + Aggregate { + Document { + meta { + count + } + } + } + } + """ + + results = test_store.execute_query(graphql_query) + num_docs = results["data"]["Aggregate"]["Document"][0]["meta"]["count"] + + assert num_docs == 3 + + +def test_hybrid_query(test_store): + query_embedding = [10.25, 10.25] + query_text = "ipsum" + where_filter = {"path": ["id"], "operator": "Equal", "valueString": "1"} + + q = ( + test_store.build_query() + .find(query=query_embedding) + .text_search(query=query_text, search_field="text") + .filter(where_filter) + .build() + ) + + docs = test_store.execute_query(q) + assert len(docs) == 1 + + +def test_hybrid_query_batched(test_store): + query_embeddings = [[10.25, 10.25], [-100, -100]] + query_texts = ["dolor", "elit"] + + q = ( + test_store.build_query() + .find_batched( + queries=query_embeddings, score_name="certainty", score_threshold=0.99 + ) + .text_search_batched(queries=query_texts, search_field="text") + .build() + ) + + docs = test_store.execute_query(q) + assert docs[0][0].id == '1' + assert docs[1][0].id == '2' + + +def test_index_multi_modal_doc(): + class MyMultiModalDoc(BaseDoc): + image: ImageDoc + text: TextDoc + + store = WeaviateDocumentIndex[MyMultiModalDoc]() + + doc = [ + MyMultiModalDoc( + image=ImageDoc(embedding=np.random.randn(128)), text=TextDoc(text='hello') + ) + ] + store.index(doc) + + id_ = doc[0].id + assert store[id_].id == id_ + assert np.all(store[id_].image.embedding == doc[0].image.embedding) + assert store[id_].text.text == doc[0].text.text + + +def test_index_document_with_bytes(weaviate_client): + doc = ImageDoc(id="1", url="www.foo.com", bytes_=b"foo") + + store = WeaviateDocumentIndex[ImageDoc]() + store.index([doc]) + + results = store.filter( + filter_query={"path": ["id"], "operator": "Equal", "valueString": "1"} + ) + + assert doc == results[0] + + +def test_index_document_with_no_embeddings(weaviate_client): + # define a document that does not have any field where is_embedding=True + class Document(BaseDoc): + not_embedding: NdArray[2] = Field(dim=2) + text: str + + doc = Document(not_embedding=[2, 5], text="dolor sit amet", id="1") + + store = WeaviateDocumentIndex[Document]() + + store.index([doc]) + + results = store.filter( + filter_query={"path": ["id"], "operator": "Equal", "valueString": "1"} + ) + + assert doc == results[0] + + +def test_limit_query_builder(test_store): + query_vector = [10.25, 10.25] + q = test_store.build_query().find(query=query_vector).limit(2) + + docs = test_store.execute_query(q) + assert len(docs) == 2 + + +@pytest.mark.linux +def test_embedded_weaviate(): + class Document(BaseDoc): + text: str + + embedded_options = EmbeddedOptions() + db_config = WeaviateDocumentIndex.DBConfig(embedded_options=embedded_options) + store = WeaviateDocumentIndex[Document](db_config=db_config) + + assert store._client._connection.embedded_db