diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9df8e8a06d..23993cc072 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: exclude: ^(docarray/proto/pb/docarray_pb2.py|docarray/proto/pb/docarray_pb2.py|docs/|docarray/resources/) - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.243 + rev: v0.0.250 hooks: - id: ruff diff --git a/README.md b/README.md index 79202079e0..06acc4f516 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ DocArray is a Python library expertly crafted for the [representation](#represen - :fire: Offers native support for **[NumPy](https://github.com/numpy/numpy)**, **[PyTorch](https://github.com/pytorch/pytorch)**, **[TensorFlow](https://github.com/tensorflow/tensorflow)**, and **[JAX](https://github.com/google/jax)**, catering specifically to **model training scenarios**. - :zap: Based on **[Pydantic](https://github.com/pydantic/pydantic)**, and instantly compatible with web and microservice frameworks like **[FastAPI](https://github.com/tiangolo/fastapi/)** and **[Jina](https://github.com/jina-ai/jina/)**. -- :package: Provides support for vector databases such as **[Weaviate](https://weaviate.io/), [Qdrant](https://qdrant.tech/), [ElasticSearch](https://www.elastic.co/de/elasticsearch/), [Redis](https://redis.io/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**. +- :package: Provides support for vector databases such as **[Weaviate](https://weaviate.io/), [Qdrant](https://qdrant.tech/), [ElasticSearch](https://www.elastic.co/de/elasticsearch/), [Redis](https://redis.io/), [MongoDB Atlas](https://www.mongodb.com/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**. - :chains: Allows data transmission as JSON over **HTTP** or as **[Protobuf](https://protobuf.dev/)** over **[gRPC](https://grpc.io/)**. 
## Installation @@ -350,7 +350,7 @@ This is useful for: - :mag: **Neural search** applications - :bulb: **Recommender systems** -Currently, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come! +Currently, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, **[Mongo Atlas](https://www.mongodb.com/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come! The Document Index interface lets you index and retrieve Documents from multiple vector databases, all with the same user interface. @@ -421,7 +421,7 @@ They are now called **Document Indexes** and offer the following improvements (s - **Production-ready:** The new Document Indexes are a much thinner wrapper around the various vector DB libraries, making them more robust and easier to maintain - **Increased flexibility:** We strive to support any configuration or setting that you could perform through the DB's first-party client -For now, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, Exact Nearest Neighbour search and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come. +For now, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, **[Mongo Atlas](https://www.mongodb.com/)**, Exact Nearest Neighbour search and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come. 
@@ -844,6 +844,7 @@ Currently, DocArray supports the following vector databases: - [Milvus](https://milvus.io) - ExactNNMemorySearch as a local alternative with exact kNN search. - [HNSWlib](https://github.com/nmslib/hnswlib) as a local-first ANN alternative +- [Mongo Atlas](https://www.mongodb.com/) An integration of [OpenSearch](https://opensearch.org/) is currently in progress. @@ -874,6 +875,7 @@ from langchain.embeddings.openai import OpenAIEmbeddings embeddings = OpenAIEmbeddings() + # Define a document schema class MovieDoc(BaseDoc): title: str @@ -903,6 +905,7 @@ from docarray.index import ( QdrantDocumentIndex, ElasticDocIndex, RedisDocumentIndex, + MongoDBAtlasDocumentIndex, ) # Select a suitable backend and initialize it with data diff --git a/docarray/index/__init__.py b/docarray/index/__init__.py index 72596cd73a..aa20ff5db8 100644 --- a/docarray/index/__init__.py +++ b/docarray/index/__init__.py @@ -13,6 +13,9 @@ from docarray.index.backends.epsilla import EpsillaDocumentIndex # noqa: F401 from docarray.index.backends.hnswlib import HnswDocumentIndex # noqa: F401 from docarray.index.backends.milvus import MilvusDocumentIndex # noqa: F401 + from docarray.index.backends.mongodb_atlas import ( # noqa: F401 + MongoDBAtlasDocumentIndex, + ) from docarray.index.backends.qdrant import QdrantDocumentIndex # noqa: F401 from docarray.index.backends.redis import RedisDocumentIndex # noqa: F401 from docarray.index.backends.weaviate import WeaviateDocumentIndex # noqa: F401 @@ -26,6 +29,7 @@ 'WeaviateDocumentIndex', 'RedisDocumentIndex', 'MilvusDocumentIndex', + 'MongoDBAtlasDocumentIndex', ] @@ -55,6 +59,9 @@ def __getattr__(name: str): elif name == 'RedisDocumentIndex': import_library('redis', raise_error=True) import docarray.index.backends.redis as lib + elif name == 'MongoDBAtlasDocumentIndex': + import_library('pymongo', raise_error=True) + import docarray.index.backends.mongodb_atlas as lib else: raise ImportError( f'cannot import name \'{name}\' from 
\'{_get_path_from_docarray_root_level(__file__)}\'' diff --git a/docarray/index/backends/mongodb_atlas.py b/docarray/index/backends/mongodb_atlas.py new file mode 100644 index 0000000000..caaa82742f --- /dev/null +++ b/docarray/index/backends/mongodb_atlas.py @@ -0,0 +1,517 @@ +import collections +import logging +from collections import defaultdict +from dataclasses import dataclass, field +from functools import cached_property + +from typing import ( + Any, + Dict, + Generator, + Generic, + List, + Optional, + Sequence, + Type, + TypeVar, + Union, + Tuple, +) + +import bson +import numpy as np +from pymongo import MongoClient + +from docarray import BaseDoc, DocList +from docarray.index.abstract import BaseDocIndex, _raise_not_composable +from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.utils._internal._typing import safe_issubclass +from docarray.utils.find import _FindResult, _FindResultBatched + +MAX_CANDIDATES = 10_000 +OVERSAMPLING_FACTOR = 10 +TSchema = TypeVar('TSchema', bound=BaseDoc) + + +class MongoDBAtlasDocumentIndex(BaseDocIndex, Generic[TSchema]): + def __init__(self, db_config=None, **kwargs): + super().__init__(db_config=db_config, **kwargs) + self._logger = logging.getLogger(__name__) + self._create_indexes() + self._logger.info(f'{self.__class__.__name__} has been initialized') + + @property + def _collection(self): + if self._is_subindex: + return self._db_config.index_name + + if not self._schema: + raise ValueError( + 'A MongoDBAtlasDocumentIndex must be typed with a Document type.' 
+ 'To do so, use the syntax: MongoDBAtlasDocumentIndex[DocumentType]' + ) + + return self._schema.__name__.lower() + + @property + def index_name(self): + """Return the name of the index in the database.""" + return self._collection + + @property + def _database_name(self): + return self._db_config.database_name + + @cached_property + def _client(self): + return self._connect_to_mongodb_atlas( + atlas_connection_uri=self._db_config.mongo_connection_uri + ) + + @property + def _doc_collection(self): + return self._client[self._database_name][self._collection] + + @staticmethod + def _connect_to_mongodb_atlas(atlas_connection_uri: str): + """ + Establish a connection to MongoDB Atlas. + """ + + client = MongoClient( + atlas_connection_uri, + # driver=DriverInfo(name="docarray", version=version("docarray")) + ) + return client + + def _create_indexes(self): + """Create a new index in the MongoDB database if it doesn't already exist.""" + self._logger.warning( + "Search Indexes in MongoDB Atlas must be created manually. " + "Currently, client-side creation of vector indexes is not allowed on free clusters." + "Please follow instructions in docs/API_reference/doc_index/backends/mongodb.md" + ) + + class QueryBuilder(BaseDocIndex.QueryBuilder): + ... + + find = _raise_not_composable('find') + filter = _raise_not_composable('filter') + text_search = _raise_not_composable('text_search') + find_batched = _raise_not_composable('find_batched') + filter_batched = _raise_not_composable('filter_batched') + text_search_batched = _raise_not_composable('text_search_batched') + + def execute_query(self, query: Any, *args, **kwargs) -> _FindResult: + """ + Execute a query on the database. + Can take two kinds of inputs: + 1. A native query of the underlying database. This is meant as a passthrough so that you + can enjoy any functionality that is not available through the Document index API. + 2. The output of this Document index' `QueryBuilder.build()` method. 
+ :param query: the query to execute + :param args: positional arguments to pass to the query + :param kwargs: keyword arguments to pass to the query + :return: the result of the query + """ + ... + + @dataclass + class DBConfig(BaseDocIndex.DBConfig): + mongo_connection_uri: str = 'localhost' + index_name: Optional[str] = None + database_name: Optional[str] = "db" + default_column_config: Dict[Type, Dict[str, Any]] = field( + default_factory=lambda: defaultdict( + dict, + { + bson.BSONARR: { + 'distance': 'COSINE', + 'oversample_factor': OVERSAMPLING_FACTOR, + 'max_candidates': MAX_CANDIDATES, + 'indexed': False, + 'index_name': None, + 'penalty': 1, + }, + bson.BSONSTR: { + 'indexed': False, + 'index_name': None, + 'operator': 'phrase', + 'penalty': 10, + }, + }, + ) + ) + + @dataclass + class RuntimeConfig(BaseDocIndex.RuntimeConfig): + pass + + def python_type_to_db_type(self, python_type: Type) -> Any: + """Map python type to database type. + Takes any python type and returns the corresponding database column type. + + :param python_type: a python type. + :return: the corresponding database column type, + or None if ``python_type`` is not supported. 
+ """ + + type_map = { + int: bson.BSONNUM, + float: bson.BSONDEC, + collections.OrderedDict: bson.BSONOBJ, + str: bson.BSONSTR, + bytes: bson.BSONBIN, + dict: bson.BSONOBJ, + np.ndarray: bson.BSONARR, + AbstractTensor: bson.BSONARR, + } + + for py_type, mongo_types in type_map.items(): + if safe_issubclass(python_type, py_type): + return mongo_types + raise ValueError(f'Unsupported column type for {type(self)}: {python_type}') + + def _doc_to_mongo(self, doc): + result = doc.copy() + + for name in result: + if self._column_infos[name].db_type == bson.BSONARR: + result[name] = list(result[name]) + + result["_id"] = result.pop("id") + return result + + def _docs_to_mongo(self, docs): + return [self._doc_to_mongo(doc) for doc in docs] + + @staticmethod + def _mongo_to_doc(mongo_doc: dict) -> Tuple[dict, float]: + result = mongo_doc.copy() + result["id"] = result.pop("_id") + score = result.pop("score", None) + return result, score + + @staticmethod + def _mongo_to_docs( + mongo_docs: Generator[Dict, None, None] + ) -> Tuple[List[dict], List[float]]: + docs = [] + scores = [] + for mongo_doc in mongo_docs: + doc, score = MongoDBAtlasDocumentIndex._mongo_to_doc(mongo_doc) + docs.append(doc) + scores.append(score) + + return docs, scores + + def _get_oversampling_factor(self, search_field: str) -> int: + return self._column_infos[search_field].config["oversample_factor"] + + def _get_max_candidates(self, search_field: str) -> int: + return self._column_infos[search_field].config["max_candidates"] + + def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): + """index a document into the store""" + # `column_to_data` is a dictionary from column name to a generator + # that yields the data for that column. + # If you want to work directly on documents, you can implement index() instead + # If you implement index(), _index() only needs a dummy implementation. 
+ self._index_subindex(column_to_data) + docs: List[Dict[str, Any]] = [] + while True: + try: + doc = {key: next(column_to_data[key]) for key in column_to_data} + mongo_doc = self._doc_to_mongo(doc) + docs.append(mongo_doc) + except StopIteration: + break + self._doc_collection.insert_many(docs) + + def num_docs(self) -> int: + """Return the number of indexed documents""" + return self._doc_collection.count_documents({}) + + @property + def _is_index_empty(self) -> bool: + """ + Check if index is empty by comparing the number of documents to zero. + :return: True if the index is empty, False otherwise. + """ + return self.num_docs() == 0 + + def _del_items(self, doc_ids: Sequence[str]) -> None: + """Delete Documents from the index. + + :param doc_ids: ids to delete from the Document Store + """ + mg_filter = {"_id": {"$in": doc_ids}} + self._doc_collection.delete_many(mg_filter) + + def _get_items( + self, doc_ids: Sequence[str] + ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]: + """Get Documents from the index, by `id`. + If no document is found, a KeyError is raised. + + :param doc_ids: ids to get from the Document index + :return: Sequence of Documents, sorted corresponding to the order of `doc_ids`. Duplicate `doc_ids` can be omitted in the output. 
+ """ + mg_filter = {"_id": {"$in": doc_ids}} + docs = self._doc_collection.find(mg_filter) + docs, _ = self._mongo_to_docs(docs) + + if not docs: + raise KeyError(f'No document with id {doc_ids} found') + return docs + + def _vector_stage_search( + self, + query: np.ndarray, + search_field: str, + limit: int, + filters: List[Dict[str, Any]] = [], + ) -> Dict[str, Any]: + + index_name = self._get_column_db_index(search_field) + oversampling_factor = self._get_oversampling_factor(search_field) + max_candidates = self._get_max_candidates(search_field) + query = query.astype(np.float64).tolist() + + return { + '$vectorSearch': { + 'index': index_name, + 'path': search_field, + 'queryVector': query, + 'numCandidates': min(limit * oversampling_factor, max_candidates), + 'limit': limit, + 'filter': {"$and": filters} if filters else None, + } + } + + def _filter_query( + self, + query: Any, + ) -> Dict[str, Any]: + return query + + def _text_stage_step( + self, + query: str, + search_field: str, + ) -> Dict[str, Any]: + operator = self._column_infos[search_field].config["operator"] + index = self._get_column_db_index(search_field) + return { + "$search": { + "index": index, + operator: {"query": query, "path": search_field}, + } + } + + def _doc_exists(self, doc_id: str) -> bool: + """ + Checks if a given document exists in the index. + + :param doc_id: The id of a document to check. + :return: True if the document exists in the index, False otherwise. + """ + doc = self._doc_collection.find_one({"_id": doc_id}) + return bool(doc) + + def _find( + self, + query: np.ndarray, + limit: int, + search_field: str = '', + ) -> _FindResult: + """Find documents in the index + + :param query: query vector for KNN/ANN search. Has single axis. 
+ :param limit: maximum number of documents to return per query + :param search_field: name of the field to search on + :return: a named NamedTuple containing `documents` and `scores` + """ + # NOTE: in standard implementations, + # `search_field` is equal to the column name to search on + + vector_search_stage = self._vector_stage_search(query, search_field, limit) + + pipeline = [ + vector_search_stage, + { + '$project': self._project_fields( + extra_fields={"score": {'$meta': 'vectorSearchScore'}} + ) + }, + ] + + with self._doc_collection.aggregate(pipeline) as cursor: + documents, scores = self._mongo_to_docs(cursor) + + return _FindResult(documents=documents, scores=scores) + + def _find_batched( + self, queries: np.ndarray, limit: int, search_field: str = '' + ) -> _FindResultBatched: + """Find documents in the index + + :param queries: query vectors for KNN/ANN search. + Has shape (batch_size, vector_dim) + :param limit: maximum number of documents to return + :param search_field: name of the field to search on + :return: a named NamedTuple containing `documents` and `scores` + """ + docs, scores = [], [] + for query in queries: + results = self._find(query=query, search_field=search_field, limit=limit) + docs.append(results.documents) + scores.append(results.scores) + + return _FindResultBatched(documents=docs, scores=scores) + + def _get_column_db_index(self, column_name: str) -> Optional[str]: + """ + Retrieve the index name associated with the specified column name. + + Parameters: + column_name (str): The name of the column. + + Returns: + Optional[str]: The index name associated with the specified column name, or None if not found. 
+ """ + index_name = self._column_infos[column_name].config.get("index_name") + + is_vector_index = safe_issubclass( + self._column_infos[column_name].docarray_type, AbstractTensor + ) + is_text_index = safe_issubclass( + self._column_infos[column_name].docarray_type, str + ) + + if index_name is None or not isinstance(index_name, str): + if is_vector_index: + raise ValueError( + f'The column {column_name} for MongoDBAtlasDocumentIndex should be associated ' + 'with an Atlas Vector Index.' + ) + elif is_text_index: + raise ValueError( + f'The column {column_name} for MongoDBAtlasDocumentIndex should be associated ' + 'with an Atlas Index.' + ) + if not (is_vector_index or is_text_index): + raise ValueError( + f'The column {column_name} for MongoDBAtlasDocumentIndex cannot be associated to an index' + ) + + return index_name + + def _project_fields(self, extra_fields: Dict[str, Any] = None) -> dict: + """ + Create a projection dictionary to include all fields defined in the column information. + + Returns: + dict: A dictionary where each field key from the column information is mapped to the value 1, + indicating that the field should be included in the projection. + """ + + fields = {key: 1 for key in self._column_infos.keys() if key != "id"} + fields["_id"] = 1 + if extra_fields: + fields.update(extra_fields) + return fields + + def _filter( + self, + filter_query: Any, + limit: int, + ) -> Union[DocList, List[Dict]]: + """Find documents in the index based on a filter query + + :param filter_query: the DB specific filter query to execute + :param limit: maximum number of documents to return + :return: a DocList containing the documents that match the filter query + """ + with self._doc_collection.find(filter_query, limit=limit) as cursor: + return self._mongo_to_docs(cursor)[0] + + def _filter_batched( + self, + filter_queries: Any, + limit: int, + ) -> Union[List[DocList], List[List[Dict]]]: + """Find documents in the index based on multiple filter queries. 
+ Each query is considered individually, and results are returned per query. + + :param filter_queries: the DB specific filter queries to execute + :param limit: maximum number of documents to return per query + :return: List of DocLists containing the documents that match the filter + queries + """ + return [self._filter(query, limit) for query in filter_queries] + + def _text_search( + self, + query: str, + limit: int, + search_field: str = '', + ) -> _FindResult: + """Find documents in the index based on a text search query + + :param query: The text to search for + :param limit: maximum number of documents to return + :param search_field: name of the field to search on + :return: a named Tuple containing `documents` and `scores` + """ + text_stage = self._text_stage_step(query=query, search_field=search_field) + + pipeline = [ + text_stage, + { + '$project': self._project_fields( + extra_fields={'score': {'$meta': 'searchScore'}} + ) + }, + {"$limit": limit}, + ] + + with self._doc_collection.aggregate(pipeline) as cursor: + documents, scores = self._mongo_to_docs(cursor) + + return _FindResult(documents=documents, scores=scores) + + def _text_search_batched( + self, + queries: Sequence[str], + limit: int, + search_field: str = '', + ) -> _FindResultBatched: + """Find documents in the index based on a text search query + + :param queries: The texts to search for + :param limit: maximum number of documents to return per query + :param search_field: name of the field to search on + :return: a named Tuple containing `documents` and `scores` + """ + # NOTE: in standard implementations, + # `search_field` is equal to the column name to search on + documents, scores = [], [] + for query in queries: + results = self._text_search( + query=query, search_field=search_field, limit=limit + ) + documents.append(results.documents) + scores.append(results.scores) + return _FindResultBatched(documents=documents, scores=scores) + + def _filter_by_parent_id(self, id: str) -> 
Optional[List[str]]: + """Filter the ids of the subindex documents given id of root document. + + :param id: the root document id to filter by + :return: a list of ids of the subindex documents + """ + with self._doc_collection.find( + {"parent_id": id}, projection={"_id": 1} + ) as cursor: + return [doc["_id"] for doc in cursor] diff --git a/docarray/utils/_internal/misc.py b/docarray/utils/_internal/misc.py index bb1e4ffe1d..b44da92dc7 100644 --- a/docarray/utils/_internal/misc.py +++ b/docarray/utils/_internal/misc.py @@ -2,7 +2,7 @@ import os import re import types -from typing import Any, Optional, Literal +from typing import Any, Literal, Optional import numpy as np @@ -50,6 +50,7 @@ 'botocore': '"docarray[aws]"', 'redis': '"docarray[redis]"', 'pymilvus': '"docarray[milvus]"', + "pymongo": '"docarray[mongo]"', } ProtocolType = Literal[ diff --git a/docs/API_reference/doc_index/backends/mongodb.md b/docs/API_reference/doc_index/backends/mongodb.md new file mode 100644 index 0000000000..0a7dc2f6ec --- /dev/null +++ b/docs/API_reference/doc_index/backends/mongodb.md @@ -0,0 +1,134 @@ +# MongoDBAtlasDocumentIndex + +::: docarray.index.backends.mongodb_atlas.MongoDBAtlasDocumentIndex + +# Setting up MongoDB Atlas as the Document Index + +MongoDB Atlas is a multi-cloud database service made by the same people that build MongoDB. +Atlas simplifies deploying and managing your databases while offering the versatility you need +to build resilient and performant global applications on the cloud providers of your choice. + +You can perform semantic search on data in your Atlas cluster running MongoDB v6.0.11 +or later using Atlas Vector Search. You can store vector embeddings for any kind of data along +with other data in your collection on the Atlas cluster. + +In the section, we set up a cluster, a database, test it, and finally create an Atlas Vector Search Index. 
+ +### Deploy a Cluster + +Follow the [Getting-Started](https://www.mongodb.com/basics/mongodb-atlas-tutorial) documentation +to create an account, deploy an Atlas cluster, and connect to a database. + + +### Retrieve the URI used by Python to connect to the Cluster + +When you deploy, this will be stored as the environment variable: `MONGODB_URI` +It will look something like the following. The username and password, if not provided, +can be configured in *Database Access* under Security in the left panel. + +``` +export MONGODB_URI="mongodb+srv://<username>:<password>@cluster0.foo.mongodb.net/?retryWrites=true&w=majority" +``` + +There are a number of ways to navigate the Atlas UI. Keep your eye out for "Connect" and "Driver". + +On the left panel, navigate and click 'Database' under DEPLOYMENT. +Click the Connect button that appears, then Drivers. Select Python. +(Have no concern for the version. This is the PyMongo, not Python, version.) +Once you have got the Connect Window open, you will see an instruction to `pip install pymongo`. +You will also see a **connection string**. +This is the `uri` that a `pymongo.MongoClient` uses to connect to the Database. + + +### Test the connection + +Atlas provides a simple check. Once you have your `uri` and `pymongo` installed, +try the following in a python console. + +```python +from pymongo.mongo_client import MongoClient +client = MongoClient(uri) # Create a new client and connect to the server +try: + client.admin.command('ping') # Send a ping to confirm a successful connection + print("Pinged your deployment. You successfully connected to MongoDB!") +except Exception as e: + print(e) +``` + +**Troubleshooting** +* You can edit a Database's users and passwords on the 'Database Access' page, under Security. +* Remember to add your IP address. (Try `curl -4 ifconfig.co`) + +### Create a Database and Collection + +As mentioned, Vector Databases provide two functions.
In addition to being the data store, +they provide very efficient search based on natural language queries. +With Vector Search, one will index and query data with a powerful vector search algorithm +using Hierarchical Navigable Small World (HNSW) graphs to find vector similarity. + +The indexing runs beside the data as a separate service asynchronously. +The Search index monitors changes to the Collection that it applies to. +Subsequently, one need not upload the data first. +We will create an empty collection now, which will simplify setup in the example notebook. + +Back in the UI, navigate to the Database Deployments page by clicking Database on the left panel. +Click the "Browse Collections" and then "+ Create Database" buttons. +This will open a window where you choose Database and Collection names. (No additional preferences.) +Remember these values as they will be used as the environment variable +`MONGODB_DATABASE`. + +### MongoDBAtlasDocumentIndex + +To connect to the MongoDB Cluster and Database, define the following environment variables. +You can confirm that the required ones have been set like this: `assert "MONGODB_URI" in os.environ` + +**IMPORTANT** It is crucial that the choices are consistent between setup in Atlas and Python environment(s). + +| Name | Description | Example | |-----------------------|-----------------------------|--------------------------------------------------------------| | `MONGODB_URI` | Connection String | mongodb+srv://<username>:<password>@cluster0.bar.mongodb.net | | `MONGODB_DATABASE` | Database name | docarray_test_db | + + +```python + +from docarray.index.backends.mongodb_atlas import MongoDBAtlasDocumentIndex +import os + +index = MongoDBAtlasDocumentIndex( +    mongo_connection_uri=os.environ["MONGODB_URI"], +    database_name=os.environ["MONGODB_DATABASE"]) +``` + + +### Create an Atlas Vector Search Index + +The final step to configure a MongoDBAtlasDocumentIndex is to create a Vector Search Index.
+The procedure is described [here](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure). + +Under Services on the left panel, choose Atlas Search > Create Search Index > +Atlas Vector Search JSON Editor. An index definition looks like the following. + + +```json +{ + "fields": [ + { + "numDimensions": 1536, + "path": "embedding", + "similarity": "cosine", + "type": "vector" + } + ] +} +``` + + +### Running MongoDB Atlas Integration Tests + +Setup is described in detail here `tests/index/mongo_atlas/README.md`. +There are actually a number of different collections and indexes to be created within your cluster's database. + +```bash +MONGODB_URI= MONGODB_DATABASE= py.test tests/index/mongo_atlas/ +``` diff --git a/poetry.lock b/poetry.lock index 161e708cf9..9980ec6627 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. 
[[package]] name = "aiofiles" @@ -884,6 +884,26 @@ files = [ {file = "distlib-0.3.6.tar.gz", hash = "sha256:14bad2d9b04d3a36127ac97f30b12a19268f211063d8f8ee4f47108896e11b46"}, ] +[[package]] +name = "dnspython" +version = "2.6.1" +description = "DNS toolkit" +optional = true +python-versions = ">=3.8" +files = [ + {file = "dnspython-2.6.1-py3-none-any.whl", hash = "sha256:5ef3b9680161f6fa89daf8ad451b5f1a33b18ae8a1c6778cdf4b43f08c0a6e50"}, + {file = "dnspython-2.6.1.tar.gz", hash = "sha256:e8f0f9c23a7b7cb99ded64e6c3a6f3e701d78f50c55e002b839dea7225cff7cc"}, +] + +[package.extras] +dev = ["black (>=23.1.0)", "coverage (>=7.0)", "flake8 (>=7)", "mypy (>=1.8)", "pylint (>=3)", "pytest (>=7.4)", "pytest-cov (>=4.1.0)", "sphinx (>=7.2.0)", "twine (>=4.0.0)", "wheel (>=0.42.0)"] +dnssec = ["cryptography (>=41)"] +doh = ["h2 (>=4.1.0)", "httpcore (>=1.0.0)", "httpx (>=0.26.0)"] +doq = ["aioquic (>=0.9.25)"] +idna = ["idna (>=3.6)"] +trio = ["trio (>=0.23)"] +wmi = ["wmi (>=1.5.1)"] + [[package]] name = "docker" version = "6.0.1" @@ -3583,6 +3603,109 @@ pandas = ">=1.2.4" protobuf = ">=3.20.0" ujson = ">=2.0.0" +[[package]] +name = "pymongo" +version = "4.6.2" +description = "Python driver for MongoDB " +optional = true +python-versions = ">=3.7" +files = [ + {file = "pymongo-4.6.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7640d176ee5b0afec76a1bda3684995cb731b2af7fcfd7c7ef8dc271c5d689af"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux1_i686.whl", hash = "sha256:4e2129ec8f72806751b621470ac5d26aaa18fae4194796621508fa0e6068278a"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:c43205e85cbcbdf03cff62ad8f50426dd9d20134a915cfb626d805bab89a1844"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux2014_i686.whl", hash = "sha256:91ddf95cedca12f115fbc5f442b841e81197d85aa3cc30b82aee3635a5208af2"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux2014_ppc64le.whl", hash = "sha256:0fbdbf2fba1b4f5f1522e9f11e21c306e095b59a83340a69e908f8ed9b450070"}, 
+ {file = "pymongo-4.6.2-cp310-cp310-manylinux2014_s390x.whl", hash = "sha256:097791d5a8d44e2444e0c8c4d6e14570ac11e22bcb833808885a5db081c3dc2a"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:e0b208ebec3b47ee78a5c836e2e885e8c1e10f8ffd101aaec3d63997a4bdcd04"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1849fd6f1917b4dc5dbf744b2f18e41e0538d08dd8e9ba9efa811c5149d665a3"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa0bbbfbd1f8ebbd5facaa10f9f333b20027b240af012748555148943616fdf3"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4522ad69a4ab0e1b46a8367d62ad3865b8cd54cf77518c157631dac1fdc97584"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:397949a9cc85e4a1452f80b7f7f2175d557237177120954eff00bf79553e89d3"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9d511db310f43222bc58d811037b176b4b88dc2b4617478c5ef01fea404f8601"}, + {file = "pymongo-4.6.2-cp310-cp310-win32.whl", hash = "sha256:991e406db5da4d89fb220a94d8caaf974ffe14ce6b095957bae9273c609784a0"}, + {file = "pymongo-4.6.2-cp310-cp310-win_amd64.whl", hash = "sha256:94637941fe343000f728e28d3fe04f1f52aec6376b67b85583026ff8dab2a0e0"}, + {file = "pymongo-4.6.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:84593447a5c5fe7a59ba86b72c2c89d813fbac71c07757acdf162fbfd5d005b9"}, + {file = "pymongo-4.6.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9aebddb2ec2128d5fc2fe3aee6319afef8697e0374f8a1fcca3449d6f625e7b4"}, + {file = "pymongo-4.6.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f706c1a644ed33eaea91df0a8fb687ce572b53eeb4ff9b89270cb0247e5d0e1"}, + {file = 
"pymongo-4.6.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18c422e6b08fa370ed9d8670c67e78d01f50d6517cec4522aa8627014dfa38b6"}, + {file = "pymongo-4.6.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d002ae456a15b1d790a78bb84f87af21af1cb716a63efb2c446ab6bcbbc48ca"}, + {file = "pymongo-4.6.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f86ba0c781b497a3c9c886765d7b6402a0e3ae079dd517365044c89cd7abb06"}, + {file = "pymongo-4.6.2-cp311-cp311-win32.whl", hash = "sha256:ac20dd0c7b42555837c86f5ea46505f35af20a08b9cf5770cd1834288d8bd1b4"}, + {file = "pymongo-4.6.2-cp311-cp311-win_amd64.whl", hash = "sha256:e78af59fd0eb262c2a5f7c7d7e3b95e8596a75480d31087ca5f02f2d4c6acd19"}, + {file = "pymongo-4.6.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:6125f73503407792c8b3f80165f8ab88a4e448d7d9234c762681a4d0b446fcb4"}, + {file = "pymongo-4.6.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba052446a14bd714ec83ca4e77d0d97904f33cd046d7bb60712a6be25eb31dbb"}, + {file = "pymongo-4.6.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2b65433c90e07dc252b4a55dfd885ca0df94b1cf77c5b8709953ec1983aadc03"}, + {file = "pymongo-4.6.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2160d9c8cd20ce1f76a893f0daf7c0d38af093f36f1b5c9f3dcf3e08f7142814"}, + {file = "pymongo-4.6.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f251f287e6d42daa3654b686ce1fcb6d74bf13b3907c3ae25954978c70f2cd4"}, + {file = "pymongo-4.6.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d7d227a60b00925dd3aeae4675575af89c661a8e89a1f7d1677e57eba4a3693c"}, + {file = "pymongo-4.6.2-cp312-cp312-win32.whl", hash = "sha256:311794ef3ccae374aaef95792c36b0e5c06e8d5cf04a1bdb1b2bf14619ac881f"}, + {file = 
"pymongo-4.6.2-cp312-cp312-win_amd64.whl", hash = "sha256:f673b64a0884edcc56073bda0b363428dc1bf4eb1b5e7d0b689f7ec6173edad6"}, + {file = "pymongo-4.6.2-cp37-cp37m-macosx_10_6_intel.whl", hash = "sha256:fe010154dfa9e428bd2fb3e9325eff2216ab20a69ccbd6b5cac6785ca2989161"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:1f5f4cd2969197e25b67e24d5b8aa2452d381861d2791d06c493eaa0b9c9fcfe"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:c9519c9d341983f3a1bd19628fecb1d72a48d8666cf344549879f2e63f54463b"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:c68bf4a399e37798f1b5aa4f6c02886188ef465f4ac0b305a607b7579413e366"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:a509db602462eb736666989739215b4b7d8f4bb8ac31d0bffd4be9eae96c63ef"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux2014_ppc64le.whl", hash = "sha256:362a5adf6f3f938a8ff220a4c4aaa93e84ef932a409abecd837c617d17a5990f"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux2014_s390x.whl", hash = "sha256:ee30a9d4c27a88042d0636aca0275788af09cc237ae365cd6ebb34524bddb9cc"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:477914e13501bb1d4608339ee5bb618be056d2d0e7267727623516cfa902e652"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebd343ca44982d480f1e39372c48e8e263fc6f32e9af2be456298f146a3db715"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c3797e0a628534e07a36544d2bfa69e251a578c6d013e975e9e3ed2ac41f2d95"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97d81d357e1a2a248b3494d52ebc8bf15d223ee89d59ee63becc434e07438a24"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed694c0d1977cb54281cb808bc2b247c17fb64b678a6352d3b77eb678ebe1bd9"}, + {file = 
"pymongo-4.6.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ceaaff4b812ae368cf9774989dea81b9bbb71e5bed666feca6a9f3087c03e49"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7dd63f7c2b3727541f7f37d0fb78d9942eb12a866180fbeb898714420aad74e2"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:e571434633f99a81e081738721bb38e697345281ed2f79c2f290f809ba3fbb2f"}, + {file = "pymongo-4.6.2-cp37-cp37m-win32.whl", hash = "sha256:3e9f6e2f3da0a6af854a3e959a6962b5f8b43bbb8113cd0bff0421c5059b3106"}, + {file = "pymongo-4.6.2-cp37-cp37m-win_amd64.whl", hash = "sha256:3a5280f496297537301e78bde250c96fadf4945e7b2c397d8bb8921861dd236d"}, + {file = "pymongo-4.6.2-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:5f6bcd2d012d82d25191a911a239fd05a8a72e8c5a7d81d056c0f3520cad14d1"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:4fa30494601a6271a8b416554bd7cde7b2a848230f0ec03e3f08d84565b4bf8c"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:bea62f03a50f363265a7a651b4e2a4429b4f138c1864b2d83d4bf6f9851994be"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:b2d445f1cf147331947cc35ec10342f898329f29dd1947a3f8aeaf7e0e6878d1"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:5db133d6ec7a4f7fc7e2bd098e4df23d7ad949f7be47b27b515c9fb9301c61e4"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux2014_ppc64le.whl", hash = "sha256:9eec7140cf7513aa770ea51505d312000c7416626a828de24318fdcc9ac3214c"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux2014_s390x.whl", hash = "sha256:5379ca6fd325387a34cda440aec2bd031b5ef0b0aa2e23b4981945cff1dab84c"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:579508536113dbd4c56e4738955a18847e8a6c41bf3c0b4ab18b51d81a6b7be8"}, + {file = 
"pymongo-4.6.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3bae553ca39ed52db099d76acd5e8566096064dc7614c34c9359bb239ec4081"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0257e0eebb50f242ca28a92ef195889a6ad03dcdde5bf1c7ab9f38b7e810801"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbafe3a1df21eeadb003c38fc02c1abf567648b6477ec50c4a3c042dca205371"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aaecfafb407feb6f562c7f2f5b91f22bfacba6dd739116b1912788cff7124c4a"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e942945e9112075a84d2e2d6e0d0c98833cdcdfe48eb8952b917f996025c7ffa"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2f7b98f8d2cf3eeebde738d080ae9b4276d7250912d9751046a9ac1efc9b1ce2"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:8110b78fc4b37dced85081d56795ecbee6a7937966e918e05e33a3900e8ea07d"}, + {file = "pymongo-4.6.2-cp38-cp38-win32.whl", hash = "sha256:df813f0c2c02281720ccce225edf39dc37855bf72cdfde6f789a1d1cf32ffb4b"}, + {file = "pymongo-4.6.2-cp38-cp38-win_amd64.whl", hash = "sha256:64ec3e2dcab9af61bdbfcb1dd863c70d1b0c220b8e8ac11df8b57f80ee0402b3"}, + {file = "pymongo-4.6.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bff601fbfcecd2166d9a2b70777c2985cb9689e2befb3278d91f7f93a0456cae"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux1_i686.whl", hash = "sha256:f1febca6f79e91feafc572906871805bd9c271b6a2d98a8bb5499b6ace0befed"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:d788cb5cc947d78934be26eef1623c78cec3729dc93a30c23f049b361aa6d835"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux2014_aarch64.whl", hash = 
"sha256:5c2f258489de12a65b81e1b803a531ee8cf633fa416ae84de65cd5f82d2ceb37"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux2014_i686.whl", hash = "sha256:fb24abcd50501b25d33a074c1790a1389b6460d2509e4b240d03fd2e5c79f463"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux2014_ppc64le.whl", hash = "sha256:4d982c6db1da7cf3018183891883660ad085de97f21490d314385373f775915b"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux2014_s390x.whl", hash = "sha256:b2dd8c874927a27995f64a3b44c890e8a944c98dec1ba79eab50e07f1e3f801b"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:4993593de44c741d1e9f230f221fe623179f500765f9855936e4ff6f33571bad"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:658f6c028edaeb02761ebcaca8d44d519c22594b2a51dcbc9bd2432aa93319e3"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:68109c13176749fbbbbbdb94dd4a58dcc604db6ea43ee300b2602154aebdd55f"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:707d28a822b918acf941cff590affaddb42a5d640614d71367c8956623a80cbc"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f251db26c239aec2a4d57fbe869e0a27b7f6b5384ec6bf54aeb4a6a5e7408234"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57c05f2e310701fc17ae358caafd99b1830014e316f0242d13ab6c01db0ab1c2"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2b575fbe6396bbf21e4d0e5fd2e3cdb656dc90c930b6c5532192e9a89814f72d"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:ca5877754f3fa6e4fe5aacf5c404575f04c2d9efc8d22ed39576ed9098d555c8"}, + {file = "pymongo-4.6.2-cp39-cp39-win32.whl", hash = "sha256:8caa73fb19070008e851a589b744aaa38edd1366e2487284c61158c77fdf72af"}, + {file = 
"pymongo-4.6.2-cp39-cp39-win_amd64.whl", hash = "sha256:3e03c732cb64b96849310e1d8688fb70d75e2571385485bf2f1e7ad1d309fa53"}, + {file = "pymongo-4.6.2.tar.gz", hash = "sha256:ab7d01ac832a1663dad592ccbd92bb0f0775bc8f98a1923c5e1a7d7fead495af"}, +] + +[package.dependencies] +dnspython = ">=1.16.0,<3.0.0" + +[package.extras] +aws = ["pymongo-auth-aws (<2.0.0)"] +encryption = ["certifi", "pymongo[aws]", "pymongocrypt (>=1.6.0,<2.0.0)"] +gssapi = ["pykerberos", "winkerberos (>=0.5.0)"] +ocsp = ["certifi", "cryptography (>=2.5)", "pyopenssl (>=17.2.0)", "requests (<3.0.0)", "service-identity (>=18.1.0)"] +snappy = ["python-snappy"] +test = ["pytest (>=7)"] +zstd = ["zstandard"] + [[package]] name = "pyparsing" version = "3.0.9" @@ -5461,6 +5584,7 @@ jac = ["jina-hubble-sdk"] jax = ["jax"] mesh = ["trimesh"] milvus = ["pymilvus"] +mongo = ["pymongo"] pandas = ["pandas"] proto = ["lz4", "protobuf"] qdrant = ["qdrant-client"] @@ -5473,4 +5597,4 @@ web = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "469714891dd7e3e6ddb406402602f0b1bb09215bfbd3fd8d237a061a0f6b3167" +content-hash = "afd26d2453ce8edd6f5021193af4bfd2a449de2719e5fe67bcaea2fbcc98d055" diff --git a/pyproject.toml b/pyproject.toml index 7e9837fe9a..26d1a04766 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ pymilvus = {version = "^2.2.12", optional = true } redis = {version = "^4.6.0", optional = true} jax = {version = ">=0.4.10", optional = true} pyepsilla = {version = ">=0.2.3", optional = true} +pymongo = {version = ">=4.6.2", optional = true} [tool.poetry.extras] proto = ["protobuf", "lz4"] @@ -82,6 +83,7 @@ milvus = ["pymilvus"] redis = ['redis'] jax = ["jaxlib","jax"] epsilla = ["pyepsilla"] +mongo = ["pymongo"] # all full = ["protobuf", "lz4", "pandas", "pillow", "types-pillow", "av", "pydub", "trimesh", "jax"] diff --git a/tests/index/mongo_atlas/README.md b/tests/index/mongo_atlas/README.md new file mode 100644 index 0000000000..fd14ff491f --- 
/dev/null +++ b/tests/index/mongo_atlas/README.md @@ -0,0 +1,159 @@ +# Setup of Atlas Required + +To run Integration tests, one will first need to create the following **Collections** and **Search Indexes** +with the `MONGODB_DATABASE` in the cluster connected to with your `MONGODB_URI`. + +Instructions of how to accomplish this in your browser are given in +`docs/API_reference/doc_index/backends/mongodb.md`. + + +Below is the mapping of collections to indexes along with their definitions. + +| Collection | Index Name | JSON Definition | Tests +|---------------------------|----------------|--------------------|---------------------------------| +| simpleschema | vector_index | [1] | test_filter,test_find,test_index_get_del, test_persist_data, test_text_search | +| mydoc__docs | vector_index | [2] | test_subindex | +| mydoc__list_docs__docs | vector_index | [3] | test_subindex | +| flatschema | vector_index_1 | [4] | test_find | +| flatschema | vector_index_2 | [5] | test_find | +| nesteddoc | vector_index_1 | [6] | test_find | +| nesteddoc | vector_index | [7] | test_find | +| simpleschema | text_index | [8] | test_text_search | + + +And here are the JSON definition references: + +[1] Collection: `simpleschema` Index name: `vector_index` +```json +{ + "fields": [ + { + "numDimensions": 10, + "path": "embedding", + "similarity": "cosine", + "type": "vector" + }, + { + "path": "number", + "type": "filter" + }, + { + "path": "text", + "type": "filter" + } + ] +} +``` + +[2] Collection: `mydoc__docs` Index name: `vector_index` +```json +{ + "fields": [ + { + "numDimensions": 10, + "path": "simple_tens", + "similarity": "euclidean", + "type": "vector" + } + ] +} +``` + +[3] Collection: `mydoc__list_docs__docs` Index name: `vector_index` +```json +{ + "fields": [ + { + "numDimensions": 10, + "path": "simple_tens", + "similarity": "euclidean", + "type": "vector" + } + ] +} +``` + +[4] Collection: `flatschema` Index name: `vector_index_1` +```json +{ + "fields": [ + { + 
"numDimensions": 10, + "path": "embedding1", + "similarity": "cosine", + "type": "vector" + } + ] +} +``` + +[5] Collection: `flatschema` Index name: `vector_index_2` +```json +{ + "fields": [ + { + "numDimensions": 50, + "path": "embedding2", + "similarity": "cosine", + "type": "vector" + } + ] +} +``` + +[6] Collection: `nesteddoc` Index name: `vector_index_1` +```json +{ + "fields": [ + { + "numDimensions": 10, + "path": "d__embedding", + "similarity": "cosine", + "type": "vector" + } + ] +} +``` + +[7] Collection: `nesteddoc` Index name: `vector_index` +```json +{ + "fields": [ + { + "numDimensions": 10, + "path": "embedding", + "similarity": "cosine", + "type": "vector" + } + ] +} +``` + +[8] Collection: `simpleschema` Index name: `text_index` + +```json +{ + "mappings": { + "dynamic": false, + "fields": { + "text": [ + { + "type": "string" + } + ] + } + } +} +``` + +NOTE: that all but this final one (8) are Vector Search Indexes. 8 is a Text Search Index. + + +With these in place you should be able to successfully run all of the tests as follows. + +```bash +MONGODB_URI= MONGODB_DATABASE= py.test tests/index/mongo_atlas/ +``` + +IMPORTANT: FREE clusters are limited to 3 search indexes. +As such, you may have to (re)create accordingly. 
\ No newline at end of file diff --git a/tests/index/mongo_atlas/__init__.py b/tests/index/mongo_atlas/__init__.py new file mode 100644 index 0000000000..352060a305 --- /dev/null +++ b/tests/index/mongo_atlas/__init__.py @@ -0,0 +1,46 @@ +import time +from typing import Callable + +from pydantic import Field + +from docarray import BaseDoc +from docarray.typing import NdArray + +N_DIM = 10 + + +class SimpleSchema(BaseDoc): + text: str = Field(index_name='text_index') + number: int + embedding: NdArray[10] = Field(dim=10, index_name="vector_index") + + +class SimpleDoc(BaseDoc): + embedding: NdArray[N_DIM] = Field(dim=N_DIM, index_name="vector_index_1") + + +class NestedDoc(BaseDoc): + d: SimpleDoc + embedding: NdArray[N_DIM] = Field(dim=N_DIM, index_name="vector_index") + + +class FlatSchema(BaseDoc): + embedding1: NdArray = Field(dim=N_DIM, index_name="vector_index_1") + # the dim and N_DIM are setted different on propouse. to check the correct handling of n_dim + embedding2: NdArray[50] = Field(dim=N_DIM, index_name="vector_index_2") + + +def assert_when_ready(callable: Callable, tries: int = 5, interval: float = 2): + """ + Retry callable to account for time taken to change data on the cluster + """ + while True: + try: + callable() + except AssertionError: + tries -= 1 + if tries == 0: + raise + time.sleep(interval) + else: + return diff --git a/tests/index/mongo_atlas/conftest.py b/tests/index/mongo_atlas/conftest.py new file mode 100644 index 0000000000..727fabb1f5 --- /dev/null +++ b/tests/index/mongo_atlas/conftest.py @@ -0,0 +1,103 @@ +import os + +import numpy as np +import pytest + +from docarray.index import MongoDBAtlasDocumentIndex + +from . 
import NestedDoc, SimpleDoc, SimpleSchema + + +@pytest.fixture(scope='session') +def mongodb_index_config(): + return { + "mongo_connection_uri": os.environ["MONGODB_URI"], + "database_name": os.environ["MONGODB_DATABASE"], + } + + +@pytest.fixture +def simple_index(mongodb_index_config): + + index = MongoDBAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) + return index + + +@pytest.fixture +def nested_index(mongodb_index_config): + index = MongoDBAtlasDocumentIndex[NestedDoc](**mongodb_index_config) + return index + + +@pytest.fixture(scope='module') +def random_simple_documents(): + N_DIM = 10 + docs_text = [ + "Text processing with Python is a valuable skill for data analysis.", + "Gardening tips for a beautiful backyard oasis.", + "Explore the wonders of deep-sea diving in tropical locations.", + "The history and art of classical music compositions.", + "An introduction to the world of gourmet cooking.", + "Integer pharetra, leo quis aliquam hendrerit, arcu ante sagittis massa, nec tincidunt arcu.", + "Sed luctus convallis velit sit amet laoreet. Morbi sit amet magna pellentesque urna tincidunt", + "luctus enim interdum lacinia. Morbi maximus diam id justo egestas pellentesque. Suspendisse", + "id laoreet odio gravida vitae. Vivamus feugiat nisi quis est pellentesque interdum. Integer", + "eleifend eros non, accumsan lectus. Curabitur porta auctor tellus at pharetra. 
Phasellus ut condimentum", + ] + return [ + SimpleSchema(embedding=np.random.rand(N_DIM), number=i, text=docs_text[i]) + for i in range(10) + ] + + +@pytest.fixture +def nested_documents(): + N_DIM = 10 + docs = [ + NestedDoc( + d=SimpleDoc(embedding=np.random.rand(N_DIM)), + embedding=np.random.rand(N_DIM), + ) + for _ in range(10) + ] + docs.append( + NestedDoc( + d=SimpleDoc(embedding=np.zeros(N_DIM)), + embedding=np.ones(N_DIM), + ) + ) + docs.append( + NestedDoc( + d=SimpleDoc(embedding=np.ones(N_DIM)), + embedding=np.zeros(N_DIM), + ) + ) + docs.append( + NestedDoc( + d=SimpleDoc(embedding=np.zeros(N_DIM)), + embedding=np.ones(N_DIM), + ) + ) + return docs + + +@pytest.fixture +def simple_index_with_docs(simple_index, random_simple_documents): + """ + Setup and teardown of simple_index. Accesses the underlying MongoDB collection directly. + """ + simple_index._doc_collection.delete_many({}) + simple_index.index(random_simple_documents) + yield simple_index, random_simple_documents + simple_index._doc_collection.delete_many({}) + + +@pytest.fixture +def nested_index_with_docs(nested_index, nested_documents): + """ + Setup and teardown of simple_index. Accesses the underlying MongoDB collection directly. + """ + nested_index._doc_collection.delete_many({}) + nested_index.index(nested_documents) + yield nested_index, nested_documents + nested_index._doc_collection.delete_many({}) diff --git a/tests/index/mongo_atlas/test_configurations.py b/tests/index/mongo_atlas/test_configurations.py new file mode 100644 index 0000000000..20b4d5f979 --- /dev/null +++ b/tests/index/mongo_atlas/test_configurations.py @@ -0,0 +1,16 @@ +from . import assert_when_ready + + +# move +def test_num_docs(simple_index_with_docs): # noqa: F811 + index, docs = simple_index_with_docs + + def pred(): + assert index.num_docs() == 10 + + assert_when_ready(pred) + + +# Currently, pymongo cannot create atlas vector search indexes. 
+def test_configure_index(simple_index): # noqa: F811 + pass diff --git a/tests/index/mongo_atlas/test_filter.py b/tests/index/mongo_atlas/test_filter.py new file mode 100644 index 0000000000..e9ed21bd32 --- /dev/null +++ b/tests/index/mongo_atlas/test_filter.py @@ -0,0 +1,22 @@ +def test_filter(simple_index_with_docs): # noqa: F811 + + db, base_docs = simple_index_with_docs + + docs = db.filter(filter_query={"number": {"$lt": 1}}) + assert len(docs) == 1 + assert docs[0].number == 0 + + docs = db.filter(filter_query={"number": {"$gt": 8}}) + assert len(docs) == 1 + assert docs[0].number == 9 + + docs = db.filter(filter_query={"number": {"$lt": 8, "$gt": 3}}) + assert len(docs) == 4 + + docs = db.filter(filter_query={"text": {"$regex": "introduction"}}) + assert len(docs) == 1 + assert 'introduction' in docs[0].text.lower() + + docs = db.filter(filter_query={"text": {"$not": {"$regex": "Explore"}}}) + assert len(docs) == 9 + assert all("Explore" not in doc.text for doc in docs) diff --git a/tests/index/mongo_atlas/test_find.py b/tests/index/mongo_atlas/test_find.py new file mode 100644 index 0000000000..aadfacb454 --- /dev/null +++ b/tests/index/mongo_atlas/test_find.py @@ -0,0 +1,147 @@ +import numpy as np +import pytest +from pydantic import Field + +from docarray import BaseDoc +from docarray.index import MongoDBAtlasDocumentIndex +from docarray.typing import NdArray + +from . 
import NestedDoc, SimpleDoc, SimpleSchema, assert_when_ready + +N_DIM = 10 + + +def test_find_simple_schema(simple_index_with_docs): # noqa: F811 + + simple_index, random_simple_documents = simple_index_with_docs # noqa: F811 + query = np.ones(N_DIM) + + # Insert one doc that identically matches query's embedding + expected_matching_document = SimpleSchema(embedding=query, text="other", number=10) + simple_index.index(expected_matching_document) + + def pred(): + docs, scores = simple_index.find(query, search_field='embedding', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert np.allclose(docs[0].embedding, expected_matching_document.embedding) + + assert_when_ready(pred) + + +def test_find_empty_index(simple_index): # noqa: F811 + query = np.random.rand(N_DIM) + + def pred(): + docs, scores = simple_index.find(query, search_field='embedding', limit=5) + assert len(docs) == 0 + assert len(scores) == 0 + + assert_when_ready(pred) + + +def test_find_limit_larger_than_index(simple_index_with_docs): # noqa: F811 + simple_index, random_simple_documents = simple_index_with_docs # noqa: F811 + + query = np.ones(N_DIM) + new_doc = SimpleSchema(embedding=query, text="other", number=10) + + simple_index.index(new_doc) + + def pred(): + docs, scores = simple_index.find(query, search_field='embedding', limit=20) + assert len(docs) == 11 + assert len(scores) == 11 + + assert_when_ready(pred) + + +def test_find_flat_schema(mongodb_index_config): # noqa: F811 + class FlatSchema(BaseDoc): + embedding1: NdArray = Field(dim=N_DIM, index_name="vector_index_1") + # the dim and N_DIM are setted different on propouse. 
to check the correct handling of n_dim + embedding2: NdArray[50] = Field(dim=N_DIM, index_name="vector_index_2") + + index = MongoDBAtlasDocumentIndex[FlatSchema](**mongodb_index_config) + + index._doc_collection.delete_many({}) + + index_docs = [ + FlatSchema(embedding1=np.random.rand(N_DIM), embedding2=np.random.rand(50)) + for _ in range(10) + ] + + index_docs.append(FlatSchema(embedding1=np.zeros(N_DIM), embedding2=np.ones(50))) + index_docs.append(FlatSchema(embedding1=np.ones(N_DIM), embedding2=np.zeros(50))) + index.index(index_docs) + + def pred1(): + + # find on embedding1 + query = np.ones(N_DIM) + docs, scores = index.find(query, search_field='embedding1', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert np.allclose(docs[0].embedding1, index_docs[-1].embedding1) + assert np.allclose(docs[0].embedding2, index_docs[-1].embedding2) + + assert_when_ready(pred1) + + def pred2(): + # find on embedding2 + query = np.ones(50) + docs, scores = index.find(query, search_field='embedding2', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert np.allclose(docs[0].embedding1, index_docs[-2].embedding1) + assert np.allclose(docs[0].embedding2, index_docs[-2].embedding2) + + assert_when_ready(pred2) + + +def test_find_batches(simple_index_with_docs): # noqa: F811 + + simple_index, docs = simple_index_with_docs # noqa: F811 + queries = np.array([np.random.rand(10) for _ in range(3)]) + + def pred(): + resp = simple_index.find_batched( + queries=queries, search_field='embedding', limit=10 + ) + docs_responses = resp.documents + assert len(docs_responses) == 3 + for matches in docs_responses: + assert len(matches) == 10 + + assert_when_ready(pred) + + +def test_find_nested_schema(nested_index_with_docs): # noqa: F811 + db, base_docs = nested_index_with_docs + + query = NestedDoc(d=SimpleDoc(embedding=np.ones(N_DIM)), embedding=np.ones(N_DIM)) + + # find on root level + def pred(): + docs, scores = db.find(query, search_field='embedding', 
limit=5)
+        assert len(docs) == 5
+        assert len(scores) == 5
+        assert np.allclose(docs[0].embedding, base_docs[-1].embedding)
+
+        # find on first nesting level
+        docs, scores = db.find(query, search_field='d__embedding', limit=5)
+        assert len(docs) == 5
+        assert len(scores) == 5
+        assert np.allclose(docs[0].d.embedding, base_docs[-2].d.embedding)
+
+    assert_when_ready(pred)
+
+
+def test_find_schema_without_index(mongodb_index_config):  # noqa: F811
+    class Schema(BaseDoc):
+        vec: NdArray = Field(dim=N_DIM)
+
+    index = MongoDBAtlasDocumentIndex[Schema](**mongodb_index_config)
+    query = np.ones(N_DIM)
+    with pytest.raises(ValueError):
+        index.find(query, search_field='vec', limit=2)
diff --git a/tests/index/mongo_atlas/test_index_get_del.py b/tests/index/mongo_atlas/test_index_get_del.py
new file mode 100644
index 0000000000..81935ebd1d
--- /dev/null
+++ b/tests/index/mongo_atlas/test_index_get_del.py
@@ -0,0 +1,112 @@
+import numpy as np
+import pytest
+
+from . import SimpleSchema, assert_when_ready
+
+N_DIM = 10
+
+
+def test_num_docs(simple_index_with_docs):  # noqa: F811
+    index, docs = simple_index_with_docs
+    query = np.ones(N_DIM)
+
+    def check_n_elements(n):
+        def pred():
+            # assert_when_ready retries only on AssertionError, so the
+            # predicate must *assert*. The previous `return index.num_docs() == 10`
+            # ignored `n` and its boolean result was never checked (vacuous wait).
+            assert index.num_docs() == n
+
+        return pred
+
+    assert_when_ready(check_n_elements(10))
+
+    del index[docs[0].id]
+
+    assert_when_ready(check_n_elements(9))
+
+    del index[docs[3].id, docs[5].id]
+
+    assert_when_ready(check_n_elements(7))
+
+    elems = [SimpleSchema(embedding=query, text="other", number=10) for _ in range(3)]
+    index.index(elems)
+
+    assert_when_ready(check_n_elements(10))
+
+    del index[elems[0].id, elems[1].id]
+
+    def check_ramaining_ids():
+        assert index.num_docs() == 8
+        # get everything
+        elem_ids = set(
+            doc.id
+            for doc in index.find(query, search_field='embedding', limit=30).documents
+        )
+        expected_ids = {doc.id for i, doc in enumerate(docs) if i not in (3, 5, 0)}
+        expected_ids.add(elems[2].id)
+        assert elem_ids == expected_ids
+
+    
assert_when_ready(check_ramaining_ids) + + +def test_get_single(simple_index_with_docs): # noqa: F811 + + index, docs = simple_index_with_docs + + expected_doc = docs[5] + retrieved_doc = index[expected_doc.id] + + assert retrieved_doc.id == expected_doc.id + assert np.allclose(retrieved_doc.embedding, expected_doc.embedding) + + with pytest.raises(KeyError): + index['An id that does not exist'] + + +def test_get_multiple(simple_index_with_docs): # noqa: F811 + index, docs = simple_index_with_docs + + # get the odd documents + docs_to_get = [doc for i, doc in enumerate(docs) if i % 2 == 1] + retrieved_docs = index[[doc.id for doc in docs_to_get]] + assert set(doc.id for doc in docs_to_get) == set(doc.id for doc in retrieved_docs) + + +def test_del_single(simple_index_with_docs): # noqa: F811 + index, docs = simple_index_with_docs + del index[docs[1].id] + + def pred(): + assert index.num_docs() == 9 + + assert_when_ready(pred) + + with pytest.raises(KeyError): + index[docs[1].id] + + +def test_del_multiple(simple_index_with_docs): # noqa: F811 + index, docs = simple_index_with_docs + + # get the odd documents + docs_to_del = [doc for i, doc in enumerate(docs) if i % 2 == 1] + + del index[[d.id for d in docs_to_del]] + for i, doc in enumerate(docs): + if i % 2 == 1: + with pytest.raises(KeyError): + index[doc.id] + else: + assert index[doc.id].id == doc.id + assert np.allclose(index[doc.id].embedding, doc.embedding) + + +def test_contains(simple_index_with_docs): # noqa: F811 + index, docs = simple_index_with_docs + + for doc in docs: + assert doc in index + + other_doc = SimpleSchema(embedding=[1.0] * N_DIM, text="other", number=10) + assert other_doc not in index diff --git a/tests/index/mongo_atlas/test_persist_data.py b/tests/index/mongo_atlas/test_persist_data.py new file mode 100644 index 0000000000..62ff02348d --- /dev/null +++ b/tests/index/mongo_atlas/test_persist_data.py @@ -0,0 +1,46 @@ +from docarray.index import MongoDBAtlasDocumentIndex + +from . 
import SimpleSchema, assert_when_ready + + +def test_persist(mongodb_index_config, random_simple_documents): # noqa: F811 + index = MongoDBAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) + index._doc_collection.delete_many({}) + + def cleaned_database(): + assert index.num_docs() == 0 + + assert_when_ready(cleaned_database) + + index.index(random_simple_documents) + + def pred(): + # check if there are elements in the database and if the index is up to date. + assert index.num_docs() == len(random_simple_documents) + assert ( + len( + index.find( + random_simple_documents[0].embedding, + search_field='embedding', + limit=1, + ).documents + ) + > 0 + ) + + assert_when_ready(pred) + + doc_before = index.find( + random_simple_documents[0].embedding, search_field='embedding', limit=1 + ).documents[0] + del index + + index = MongoDBAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) + + doc_after = index.find( + random_simple_documents[0].embedding, search_field='embedding', limit=1 + ).documents[0] + + assert index.num_docs() == len(random_simple_documents) + assert doc_before.id == doc_after.id + assert (doc_before.embedding == doc_after.embedding).all() diff --git a/tests/index/mongo_atlas/test_subindex.py b/tests/index/mongo_atlas/test_subindex.py new file mode 100644 index 0000000000..82f8744221 --- /dev/null +++ b/tests/index/mongo_atlas/test_subindex.py @@ -0,0 +1,267 @@ +from typing import Optional + +import numpy as np +import pytest +from pydantic import Field + +from docarray import BaseDoc, DocList +from docarray.index import MongoDBAtlasDocumentIndex +from docarray.typing import NdArray +from docarray.typing.tensor import AnyTensor + +from . 
import assert_when_ready + +pytestmark = [pytest.mark.slow, pytest.mark.index] + + +class MetaPathDoc(BaseDoc): + path_id: str + level: int + text: str + embedding: Optional[AnyTensor] = Field(space='cosine', dim=128) + + +class MetaCategoryDoc(BaseDoc): + node_id: Optional[str] + node_name: Optional[str] + name: Optional[str] + product_type_definitions: Optional[str] + leaf: bool + paths: Optional[DocList[MetaPathDoc]] + embedding: Optional[AnyTensor] = Field(space='cosine', dim=128) + channel: str + lang: str + + +class SimpleDoc(BaseDoc): + simple_tens: NdArray[10] = Field(index_name='vector_index') + simple_text: str + + +class ListDoc(BaseDoc): + docs: DocList[SimpleDoc] + simple_doc: SimpleDoc + list_tens: NdArray[20] = Field(space='l2') + + +class MyDoc(BaseDoc): + docs: DocList[SimpleDoc] + list_docs: DocList[ListDoc] + my_tens: NdArray[30] = Field(space='l2') + + +def clean_subindex(index): + for subindex in index._subindices.values(): + clean_subindex(subindex) + index._doc_collection.delete_many({}) + + +@pytest.fixture(scope='session') +def index(mongodb_index_config): # noqa: F811 + index = MongoDBAtlasDocumentIndex[MyDoc](**mongodb_index_config) + clean_subindex(index) + + my_docs = [ + MyDoc( + id=f'{i}', + docs=DocList[SimpleDoc]( + [ + SimpleDoc( + id=f'docs-{i}-{j}', + simple_tens=np.ones(10) * (j + 1), + simple_text=f'hello {j}', + ) + for j in range(2) + ] + ), + list_docs=DocList[ListDoc]( + [ + ListDoc( + id=f'list_docs-{i}-{j}', + docs=DocList[SimpleDoc]( + [ + SimpleDoc( + id=f'list_docs-docs-{i}-{j}-{k}', + simple_tens=np.ones(10) * (k + 1), + simple_text=f'hello {k}', + ) + for k in range(2) + ] + ), + simple_doc=SimpleDoc( + id=f'list_docs-simple_doc-{i}-{j}', + simple_tens=np.ones(10) * (j + 1), + simple_text=f'hello {j}', + ), + list_tens=np.ones(20) * (j + 1), + ) + for j in range(2) + ] + ), + my_tens=np.ones((30,)) * (i + 1), + ) + for i in range(2) + ] + + index.index(my_docs) + yield index + clean_subindex(index) + + +def 
test_subindex_init(index): + assert isinstance(index._subindices['docs'], MongoDBAtlasDocumentIndex) + assert isinstance(index._subindices['list_docs'], MongoDBAtlasDocumentIndex) + assert isinstance( + index._subindices['list_docs']._subindices['docs'], MongoDBAtlasDocumentIndex + ) + + +def test_subindex_index(index): + assert index.num_docs() == 2 + assert index._subindices['docs'].num_docs() == 4 + assert index._subindices['list_docs'].num_docs() == 4 + assert index._subindices['list_docs']._subindices['docs'].num_docs() == 8 + + +def test_subindex_get(index): + doc = index['1'] + assert isinstance(doc, MyDoc) + assert doc.id == '1' + + assert len(doc.docs) == 2 + assert isinstance(doc.docs[0], SimpleDoc) + for d in doc.docs: + i = int(d.id.split('-')[-1]) + assert d.id == f'docs-1-{i}' + assert np.allclose(d.simple_tens, np.ones(10) * (i + 1)) + + assert len(doc.list_docs) == 2 + assert isinstance(doc.list_docs[0], ListDoc) + assert set([d.id for d in doc.list_docs]) == set( + [f'list_docs-1-{i}' for i in range(2)] + ) + assert len(doc.list_docs[0].docs) == 2 + assert isinstance(doc.list_docs[0].docs[0], SimpleDoc) + i = int(doc.list_docs[0].docs[0].id.split('-')[-2]) + j = int(doc.list_docs[0].docs[0].id.split('-')[-1]) + assert doc.list_docs[0].docs[0].id == f'list_docs-docs-1-{i}-{j}' + assert np.allclose(doc.list_docs[0].docs[0].simple_tens, np.ones(10) * (j + 1)) + assert doc.list_docs[0].docs[0].simple_text == f'hello {j}' + assert isinstance(doc.list_docs[0].simple_doc, SimpleDoc) + assert doc.list_docs[0].simple_doc.id == f'list_docs-simple_doc-1-{i}' + assert np.allclose(doc.list_docs[0].simple_doc.simple_tens, np.ones(10) * (i + 1)) + assert doc.list_docs[0].simple_doc.simple_text == f'hello {i}' + assert np.allclose(doc.list_docs[0].list_tens, np.ones(20) * (i + 1)) + + assert np.allclose(doc.my_tens, np.ones(30) * 2) + + +def test_subindex_contain(index, mongodb_index_config): # noqa: F811 + # Checks for individual simple_docs within list_docs + + 
doc = index['0'] + for simple_doc in doc.list_docs: + assert index.subindex_contains(simple_doc) is True + for nested_doc in simple_doc.docs: + assert index.subindex_contains(nested_doc) is True + + invalid_doc = SimpleDoc( + id='non_existent', + simple_tens=np.zeros(10), + simple_text='invalid', + ) + assert index.subindex_contains(invalid_doc) is False + + # Checks for an empty doc + empty_doc = SimpleDoc( + id='', + simple_tens=np.zeros(10), + simple_text='', + ) + assert index.subindex_contains(empty_doc) is False + + # Empty index + empty_index = MongoDBAtlasDocumentIndex[MyDoc](**mongodb_index_config) + assert (empty_doc in empty_index) is False + + +def test_find_empty_subindex(index): + query = np.ones((30,)) + with pytest.raises(ValueError): + index.find_subindex(query, subindex='', search_field='my_tens', limit=5) + + +def test_find_subindex_sublevel(index): + query = np.ones((10,)) + + def pred(): + root_docs, docs, scores = index.find_subindex( + query, subindex='docs', search_field='simple_tens', limit=4 + ) + assert len(root_docs) == 4 + assert isinstance(root_docs[0], MyDoc) + assert isinstance(docs[0], SimpleDoc) + assert len(scores) == 4 + assert sum(score == 1.0 for score in scores) == 2 + + for root_doc, doc, score in zip(root_docs, docs, scores): + assert root_doc.id == f'{doc.id.split("-")[1]}' + + if score == 1.0: + assert np.allclose(doc.simple_tens, np.ones(10)) + else: + assert np.allclose(doc.simple_tens, np.ones(10) * 2) + + assert_when_ready(pred) + + +def test_find_subindex_subsublevel(index): + # sub sub level + def predicate(): + query = np.ones((10,)) + root_docs, docs, scores = index.find_subindex( + query, subindex='list_docs__docs', search_field='simple_tens', limit=2 + ) + assert len(docs) == 2 + assert isinstance(root_docs[0], MyDoc) + assert isinstance(docs[0], SimpleDoc) + for root_doc, doc, score in zip(root_docs, docs, scores): + assert np.allclose(doc.simple_tens, np.ones(10)) + assert root_doc.id == 
f'{doc.id.split("-")[2]}' + assert score == 1.0 + + assert_when_ready(predicate) + + +def test_subindex_filter(index): + def predicate(): + query = {"simple_doc__simple_text": {"$eq": "hello 1"}} + docs = index.filter_subindex(query, subindex='list_docs', limit=4) + assert len(docs) == 2 + assert isinstance(docs[0], ListDoc) + for doc in docs: + assert doc.id.split('-')[-1] == '1' + + query = {"simple_text": {"$eq": "hello 0"}} + docs = index.filter_subindex(query, subindex='list_docs__docs', limit=5) + assert len(docs) == 4 + assert isinstance(docs[0], SimpleDoc) + for doc in docs: + assert doc.id.split('-')[-1] == '0' + + assert_when_ready(predicate) + + +def test_subindex_del(index): + del index['0'] + assert index.num_docs() == 1 + assert index._subindices['docs'].num_docs() == 2 + assert index._subindices['list_docs'].num_docs() == 2 + assert index._subindices['list_docs']._subindices['docs'].num_docs() == 4 + + +def test_subindex_collections(mongodb_index_config): # noqa: F811 + doc_index = MongoDBAtlasDocumentIndex[MetaCategoryDoc](**mongodb_index_config) + + assert doc_index._subindices["paths"].index_name == 'metacategorydoc__paths' + assert doc_index._subindices["paths"]._collection == 'metacategorydoc__paths' diff --git a/tests/index/mongo_atlas/test_text_search.py b/tests/index/mongo_atlas/test_text_search.py new file mode 100644 index 0000000000..cbc6db8058 --- /dev/null +++ b/tests/index/mongo_atlas/test_text_search.py @@ -0,0 +1,39 @@ +from . 
import assert_when_ready


def test_text_search(simple_index_with_docs):  # noqa: F811
    """Full-text search finds the matching document with a positive score."""
    simple_index, docs = simple_index_with_docs

    query_string = "Python is a valuable skill"
    expected_text = docs[0].text

    def check():
        found, scores = simple_index.text_search(
            query=query_string, search_field='text', limit=1
        )
        assert len(found) == 1
        assert found[0].text == expected_text
        assert scores[0] > 0

    assert_when_ready(check)


def test_text_search_batched(simple_index_with_docs):  # noqa: F811
    """Batched text search returns one result/score list per query."""
    index, docs = simple_index_with_docs

    queries = ['processing with Python', 'tips', 'for']

    def check():
        found, scores = index.text_search_batched(queries, search_field='text', limit=5)

        # one batch per query, with the expected number of hits in each
        expected_hits = [1, 1, 2]
        assert [len(batch) for batch in found] == expected_hits
        assert [len(batch) for batch in scores] == expected_hits

    assert_when_ready(check)