diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 37bab63aaa9..6996e768b45 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -187,7 +187,7 @@ jobs:
pytest --suppress-no-test-exit-code --cov=docarray --cov-report=xml \
-v -s -m "not gpu" ${{ matrix.test-path }}
echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
- timeout-minutes: 45
+ timeout-minutes: 60
env:
JINA_AUTH_TOKEN: "${{ secrets.JINA_AUTH_TOKEN }}"
- name: Check codecov file
@@ -238,7 +238,7 @@ jobs:
pytest --suppress-no-test-exit-code --cov=docarray --cov-report=xml \
-v -s -m "not gpu" ${{ matrix.test-path }}
echo "::set-output name=codecov_flag::docarray"
- timeout-minutes: 40
+ timeout-minutes: 60
env:
JINA_AUTH_TOKEN: "${{ secrets.JINA_AUTH_TOKEN }}"
- name: Check codecov file
diff --git a/docarray/array/document.py b/docarray/array/document.py
index 7c535b25330..7ebb0880f89 100644
--- a/docarray/array/document.py
+++ b/docarray/array/document.py
@@ -12,11 +12,13 @@
from docarray.array.weaviate import DocumentArrayWeaviate
from docarray.array.elastic import DocumentArrayElastic
from docarray.array.redis import DocumentArrayRedis
+ from docarray.array.milvus import DocumentArrayMilvus
from docarray.array.storage.sqlite import SqliteConfig
from docarray.array.storage.annlite import AnnliteConfig
from docarray.array.storage.weaviate import WeaviateConfig
from docarray.array.storage.elastic import ElasticConfig
from docarray.array.storage.redis import RedisConfig
+ from docarray.array.storage.milvus import MilvusConfig
class DocumentArray(AllMixins, BaseDocumentArray):
@@ -140,6 +142,16 @@ def __new__(
"""Create a Redis-powered DocumentArray object."""
...
+ @overload
+ def __new__(
+ cls,
+ _docs: Optional['DocumentArraySourceType'] = None,
+ storage: str = 'milvus',
+ config: Optional[Union['MilvusConfig', Dict]] = None,
+ ) -> 'DocumentArrayMilvus':
+ """Create a Milvus-powered DocumentArray object."""
+ ...
+
def __enter__(self):
self._exit_stack = ExitStack()
# Ensure that we sync the data to the storage backend when exiting the context manager
@@ -184,6 +196,10 @@ def __new__(cls, *args, storage: str = 'memory', **kwargs):
from .redis import DocumentArrayRedis
instance = super().__new__(DocumentArrayRedis)
+ elif storage == 'milvus':
+ from .milvus import DocumentArrayMilvus
+
+ instance = super().__new__(DocumentArrayMilvus)
else:
raise ValueError(f'storage=`{storage}` is not supported.')
diff --git a/docarray/array/milvus.py b/docarray/array/milvus.py
new file mode 100644
index 00000000000..d7924d86f4f
--- /dev/null
+++ b/docarray/array/milvus.py
@@ -0,0 +1,46 @@
+from .document import DocumentArray
+
+from .storage.milvus import StorageMixins, MilvusConfig
+
+__all__ = ['MilvusConfig', 'DocumentArrayMilvus']
+
+
+class DocumentArrayMilvus(StorageMixins, DocumentArray):
+ """
+    DocumentArray that stores Documents in a `Milvus <https://milvus.io/>`_ vector search engine.
+
+ .. note::
+ This DocumentArray requires `pymilvus`. You can install it via `pip install "docarray[milvus]"`.
+
+ To use Milvus as storage backend, a Milvus service needs to be running on your machine.
+
+ With this implementation, :meth:`match` and :meth:`find` perform fast (approximate) vector search.
+ Additionally, search with filters is supported.
+
+ Example usage:
+
+ .. code-block:: python
+
+ from docarray import DocumentArray
+
+ # connect to running Milvus service with default configuration (address: http://localhost:19530)
+ da = DocumentArray(storage='milvus', config={'n_dim': 10})
+
+ # connect to a previously persisted DocumentArrayMilvus by specifying collection_name, host, and port
+ da = DocumentArray(
+ storage='milvus',
+ config={
+ 'collection_name': 'persisted',
+ 'host': 'localhost',
+ 'port': '19530',
+ 'n_dim': 10,
+ },
+ )
+
+
+ .. seealso::
+        For further details, see our :ref:`user guide <milvus>`.
+ """
+
+ def __new__(cls, *args, **kwargs):
+ return super().__new__(cls)
diff --git a/docarray/array/mixins/find.py b/docarray/array/mixins/find.py
index b467e945445..1e0a58f765d 100644
--- a/docarray/array/mixins/find.py
+++ b/docarray/array/mixins/find.py
@@ -96,7 +96,7 @@ def find(
limit: Optional[Union[int, float]] = 20,
metric_name: Optional[str] = None,
exclude_self: bool = False,
- filter: Optional[Dict] = None,
+ filter: Union[Dict, str, None] = None,
only_id: bool = False,
index: str = 'text',
on: Optional[str] = None,
diff --git a/docarray/array/storage/base/seqlike.py b/docarray/array/storage/base/seqlike.py
index 73fa80a00b7..d5ef0ebc50c 100644
--- a/docarray/array/storage/base/seqlike.py
+++ b/docarray/array/storage/base/seqlike.py
@@ -14,13 +14,14 @@ def _update_subindices_append_extend(self, value):
if len(docs_selector) > 0:
da.extend(docs_selector)
- def insert(self, index: int, value: 'Document'):
+ def insert(self, index: int, value: 'Document', **kwargs):
"""Insert `doc` at `index`.
:param index: Position of the insertion.
:param value: The doc needs to be inserted.
+ :param kwargs: Additional Arguments that are passed to the Document Store. This has no effect for in-memory DocumentArray.
"""
- self._set_doc_by_id(value.id, value)
+ self._set_doc_by_id(value.id, value, **kwargs)
self._offset2ids.insert(index, value.id)
def append(self, value: 'Document', **kwargs):
diff --git a/docarray/array/storage/milvus/__init__.py b/docarray/array/storage/milvus/__init__.py
new file mode 100644
index 00000000000..b3e2894e607
--- /dev/null
+++ b/docarray/array/storage/milvus/__init__.py
@@ -0,0 +1,12 @@
+from abc import ABC
+
+from .backend import BackendMixin, MilvusConfig
+from .find import FindMixin
+from .getsetdel import GetSetDelMixin
+from .seqlike import SequenceLikeMixin
+
+__all__ = ['StorageMixins', 'MilvusConfig']
+
+
+class StorageMixins(FindMixin, BackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC):
+ ...
diff --git a/docarray/array/storage/milvus/backend.py b/docarray/array/storage/milvus/backend.py
new file mode 100644
index 00000000000..35154ca3c66
--- /dev/null
+++ b/docarray/array/storage/milvus/backend.py
@@ -0,0 +1,360 @@
+import copy
+import uuid
+from typing import Optional, TYPE_CHECKING, Union, Dict, Iterable, List, Tuple
+from dataclasses import dataclass, field
+import re
+
+import numpy as np
+from pymilvus import (
+ connections,
+ Collection,
+ FieldSchema,
+ DataType,
+ CollectionSchema,
+ has_collection,
+ loading_progress,
+)
+
+from docarray import Document, DocumentArray
+from docarray.array.storage.base.backend import BaseBackendMixin, TypeMap
+from docarray.helper import dataclass_from_dict, _safe_cast_int
+
+if TYPE_CHECKING:
+ from docarray.typing import (
+ DocumentArraySourceType,
+ )
+
+
+ID_VARCHAR_LEN = 1024
+SERIALIZED_VARCHAR_LEN = (
+ 65_535 # 65_535 is the maximum that Milvus allows for a VARCHAR field
+)
+COLUMN_VARCHAR_LEN = 1024
+OFFSET_VARCHAR_LEN = 1024
+
+
+def _always_true_expr(primary_key: str) -> str:
+ """
+ Returns a Milvus expression that is always true, thus allowing for the retrieval of all entries in a Collection
+ Assumes that the primary key is of type DataType.VARCHAR
+
+ :param primary_key: the name of the primary key
+ :return: a Milvus expression that is always true for that primary key
+ """
+ return f'({primary_key} in ["1"]) or ({primary_key} not in ["1"])'
+
+
+def _ids_to_milvus_expr(ids):
+ ids = ['"' + _id + '"' for _id in ids]
+ return '[' + ','.join(ids) + ']'
+
+
+def _batch_list(l: List, batch_size: int):
+ """Iterates over a list in batches of size batch_size"""
+ if batch_size < 1:
+ yield l
+ return
+ l_len = len(l)
+ for ndx in range(0, l_len, batch_size):
+ yield l[ndx : min(ndx + batch_size, l_len)]
+
+
+def _sanitize_collection_name(name):
+ """Removes all chars that are not allowed in a Milvus collection name.
+ Thus, it removes all chars that are not alphanumeric or an underscore.
+
+ :param name: the collection name to sanitize
+ :return: the sanitized collection name.
+ """
+ return ''.join(
+ re.findall('[a-zA-Z0-9_]', name)
+ ) # remove everything that is not a letter, number or underscore
+
+
+@dataclass
+class MilvusConfig:
+ n_dim: int
+ collection_name: str = None
+ host: str = 'localhost'
+ port: Optional[Union[str, int]] = 19530 # 19530 for gRPC, 9091 for HTTP
+ distance: str = 'IP' # metric_type in milvus
+ index_type: str = 'HNSW'
+ index_params: Dict = field(
+ default_factory=lambda: {
+ 'M': 4,
+ 'efConstruction': 200,
+ }
+ ) # passed to milvus at index creation time. The default assumes 'HNSW' index type
+ collection_config: Dict = field(
+ default_factory=dict
+ ) # passed to milvus at collection creation time
+ serialize_config: Dict = field(default_factory=dict)
+ consistency_level: str = 'Session'
+ batch_size: int = -1
+ columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None
+ list_like: bool = True
+
+
+class BackendMixin(BaseBackendMixin):
+
+ TYPE_MAP = {
+ 'str': TypeMap(type=DataType.VARCHAR, converter=str),
+ 'float': TypeMap(
+ type=DataType.DOUBLE, converter=float
+ ), # it doesn't like DataType.FLOAT type, perhaps because python floats are double precision?
+ 'double': TypeMap(type=DataType.DOUBLE, converter=float),
+ 'int': TypeMap(type=DataType.INT64, converter=_safe_cast_int),
+ 'bool': TypeMap(type=DataType.BOOL, converter=bool),
+ }
+
+ def _init_storage(
+ self,
+ _docs: Optional['DocumentArraySourceType'] = None,
+ config: Optional[Union[MilvusConfig, Dict]] = None,
+ **kwargs,
+ ):
+ config = copy.deepcopy(config)
+ if not config:
+ raise ValueError('Empty config is not allowed for Milvus storage')
+ elif isinstance(config, dict):
+ config = dataclass_from_dict(MilvusConfig, config)
+
+ if config.collection_name is None:
+ id = uuid.uuid4().hex
+ config.collection_name = 'docarray__' + id
+ self._list_like = config.list_like
+ self._config = config
+ self._config.columns = self._normalize_columns(self._config.columns)
+
+ self._connection_alias = f'docarray_{config.host}_{config.port}'
+ connections.connect(
+ alias=self._connection_alias, host=config.host, port=config.port
+ )
+
+ self._collection = self._create_or_reuse_collection()
+ self._offset2id_collection = self._create_or_reuse_offset2id_collection()
+ self._build_index()
+ super()._init_storage()
+
+ # To align with Sqlite behavior; if `docs` is not `None` and table name
+ # is provided, :class:`DocumentArraySqlite` will clear the existing
+ # table and load the given `docs`
+ if _docs is None:
+ return
+
+ self.clear()
+ if isinstance(_docs, Iterable):
+ self.extend(_docs)
+ else:
+ if isinstance(_docs, Document):
+ self.append(_docs)
+
+ def _create_or_reuse_collection(self):
+ if has_collection(self._config.collection_name, using=self._connection_alias):
+ return Collection(
+ self._config.collection_name, using=self._connection_alias
+ )
+
+ document_id = FieldSchema(
+ name='document_id',
+ dtype=DataType.VARCHAR,
+ max_length=ID_VARCHAR_LEN,
+ is_primary=True,
+ )
+ embedding = FieldSchema(
+ name='embedding', dtype=DataType.FLOAT_VECTOR, dim=self._config.n_dim
+ )
+ serialized = FieldSchema(
+ name='serialized', dtype=DataType.VARCHAR, max_length=SERIALIZED_VARCHAR_LEN
+ )
+
+ additional_columns = []
+ for col, coltype in self._config.columns.items():
+ mapped_type = self._map_type(coltype)
+ if mapped_type == DataType.VARCHAR:
+ field_ = FieldSchema(
+ name=col, dtype=mapped_type, max_length=COLUMN_VARCHAR_LEN
+ )
+ else:
+ field_ = FieldSchema(name=col, dtype=mapped_type)
+ additional_columns.append(field_)
+
+ schema = CollectionSchema(
+ fields=[document_id, embedding, serialized, *additional_columns],
+ description='DocumentArray collection schema',
+ )
+ return Collection(
+ name=self._config.collection_name,
+ schema=schema,
+ using=self._connection_alias,
+ **self._config.collection_config,
+ )
+
+ def _build_index(self):
+ index_params = {
+ 'metric_type': self._config.distance,
+ 'index_type': self._config.index_type,
+ 'params': self._config.index_params,
+ }
+ self._collection.create_index(field_name='embedding', index_params=index_params)
+
+ def _create_or_reuse_offset2id_collection(self):
+ if has_collection(
+ self._config.collection_name + '_offset2id', using=self._connection_alias
+ ):
+ return Collection(
+ self._config.collection_name + '_offset2id',
+ using=self._connection_alias,
+ )
+
+ document_id = FieldSchema(
+ name='document_id', dtype=DataType.VARCHAR, max_length=ID_VARCHAR_LEN
+ )
+ offset = FieldSchema(
+ name='offset',
+ dtype=DataType.VARCHAR,
+ max_length=OFFSET_VARCHAR_LEN,
+ is_primary=True,
+ )
+ dummy_vector = FieldSchema(
+ name='dummy_vector', dtype=DataType.FLOAT_VECTOR, dim=1
+ )
+
+ schema = CollectionSchema(
+ fields=[offset, document_id, dummy_vector],
+ description='offset2id for DocumentArray',
+ )
+
+ return Collection(
+ name=self._config.collection_name + '_offset2id',
+ schema=schema,
+ using=self._connection_alias,
+ # **self._config.collection_config, # we probably don't want to apply the same config here
+ )
+
+ def _ensure_unique_config(
+ self,
+ config_root: dict,
+ config_subindex: dict,
+ config_joined: dict,
+ subindex_name: str,
+ ) -> dict:
+ if 'collection_name' not in config_subindex:
+ config_joined['collection_name'] = _sanitize_collection_name(
+ config_joined['collection_name'] + '_subindex_' + subindex_name
+ )
+ return config_joined
+
+ def _doc_to_milvus_payload(self, doc):
+ return self._docs_to_milvus_payload([doc])
+
+ def _docs_to_milvus_payload(self, docs: 'Iterable[Document]'):
+ extra_columns = [
+ [self._map_column(doc.tags.get(col), col_type) for doc in docs]
+ for col, col_type in self._config.columns.items()
+ ]
+ return [
+ [doc.id for doc in docs],
+ [self._map_embedding(doc.embedding) for doc in docs],
+ [doc.to_base64(**self._config.serialize_config) for doc in docs],
+ *extra_columns,
+ ]
+
+ @staticmethod
+ def _docs_from_query_response(response):
+ return DocumentArray([Document.from_base64(d['serialized']) for d in response])
+
+ @staticmethod
+ def _docs_from_search_response(
+ responses,
+ ) -> 'List[DocumentArray]':
+ das = []
+ for r in responses:
+ das.append(
+ DocumentArray(
+ [Document.from_base64(hit.entity.get('serialized')) for hit in r]
+ )
+ )
+ return das
+
+ def _update_kwargs_from_config(self, field_to_update, **kwargs):
+ kwargs_field_value = kwargs.get(field_to_update, None)
+ config_field_value = getattr(self._config, field_to_update, None)
+
+ if (
+ kwargs_field_value is not None or config_field_value is None
+ ): # no need to update
+ return kwargs
+
+ kwargs[field_to_update] = config_field_value
+ return kwargs
+
+ def _map_embedding(self, embedding):
+ if embedding is not None:
+ from docarray.math.ndarray import to_numpy_array
+
+ embedding = to_numpy_array(embedding)
+
+ if embedding.ndim > 1:
+ embedding = np.asarray(embedding).squeeze()
+ else:
+ embedding = np.zeros(self._config.n_dim)
+ return embedding
+
+ def __getstate__(self):
+ d = dict(self.__dict__)
+ del d['_collection']
+ del d['_offset2id_collection']
+ return d
+
+ def __setstate__(self, state):
+ self.__dict__ = state
+ connections.connect(
+ alias=self._connection_alias, host=self._config.host, port=self._config.port
+ )
+ self._collection = self._create_or_reuse_collection()
+ self._offset2id_collection = self._create_or_reuse_offset2id_collection()
+
+ def __enter__(self):
+ _ = super().__enter__()
+ self._collection.load()
+ self._offset2id_collection.load()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._collection.release()
+ self._offset2id_collection.release()
+ super().__exit__(exc_type, exc_val, exc_tb)
+
+ def loaded_collection(self, collection=None):
+ """
+ Context manager to load a collection and release it after the context is exited.
+ If the collection is already loaded when entering, it will not be released while exiting.
+
+ :param collection: the collection to load. If None, the main collection of this indexer is used.
+ :return: Context manager for the provided collection.
+ """
+
+ class LoadedCollectionManager:
+ def __init__(self, coll, connection_alias):
+ self._collection = coll
+ self._loaded_when_enter = False
+ self._connection_alias = connection_alias
+
+ def __enter__(self):
+ self._loaded_when_enter = (
+ loading_progress(
+ self._collection.name, using=self._connection_alias
+ )['loading_progress']
+ != '0%'
+ )
+ self._collection.load()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if not self._loaded_when_enter:
+ self._collection.release()
+
+ return LoadedCollectionManager(
+ collection if collection else self._collection, self._connection_alias
+ )
diff --git a/docarray/array/storage/milvus/find.py b/docarray/array/storage/milvus/find.py
new file mode 100644
index 00000000000..fa9f42709ad
--- /dev/null
+++ b/docarray/array/storage/milvus/find.py
@@ -0,0 +1,56 @@
+from typing import TYPE_CHECKING, TypeVar, List, Union, Optional, Dict, Sequence
+
+if TYPE_CHECKING:
+ import numpy as np
+ import tensorflow
+ import torch
+
+ # Define the expected input type that your ANN search supports
+ MilvusArrayType = TypeVar(
+ 'MilvusArrayType',
+ np.ndarray,
+ tensorflow.Tensor,
+ torch.Tensor,
+ Sequence[float],
+ )
+ from docarray import Document, DocumentArray
+
+
+class FindMixin:
+ def _find(
+ self,
+ query: 'MilvusArrayType',
+ limit: int = 10,
+ filter: Optional[Dict] = None,
+ param=None,
+ **kwargs
+ ) -> List['DocumentArray']:
+ """Returns `limit` approximate nearest neighbors given a batch of input queries.
+ If the query is a single query, should return a DocumentArray, otherwise a list of DocumentArrays containing
+ the closest Documents for each query.
+ """
+ if param is None:
+ param = dict()
+ kwargs = self._update_kwargs_from_config('consistency_level', **kwargs)
+ with self.loaded_collection():
+ results = self._collection.search(
+ data=query,
+ anns_field='embedding',
+ limit=limit,
+ expr=filter,
+ param=param,
+ output_fields=['serialized'],
+ **kwargs,
+ )
+ return self._docs_from_search_response(results)
+
+ def _filter(self, filter, limit=10, **kwargs):
+ kwargs = self._update_kwargs_from_config('consistency_level', **kwargs)
+ with self.loaded_collection():
+ results = self._collection.query(
+ expr=filter,
+ limit=limit,
+ output_fields=['serialized'],
+ **kwargs,
+ )
+ return self._docs_from_query_response(results)[:limit]
diff --git a/docarray/array/storage/milvus/getsetdel.py b/docarray/array/storage/milvus/getsetdel.py
new file mode 100644
index 00000000000..08757f0f715
--- /dev/null
+++ b/docarray/array/storage/milvus/getsetdel.py
@@ -0,0 +1,109 @@
+from typing import Iterable, Dict, TYPE_CHECKING
+
+import numpy as np
+
+from docarray import DocumentArray
+from docarray.array.storage.base.getsetdel import BaseGetSetDelMixin
+from docarray.array.storage.base.helper import Offset2ID
+from docarray.array.storage.milvus.backend import (
+ _always_true_expr,
+ _ids_to_milvus_expr,
+ _batch_list,
+)
+
+if TYPE_CHECKING:
+ from docarray import Document, DocumentArray
+
+
+class GetSetDelMixin(BaseGetSetDelMixin):
+ def _get_doc_by_id(self, _id: str) -> 'Document':
+ # to be implemented
+ return self._get_docs_by_ids([_id])[0]
+
+ def _del_doc_by_id(self, _id: str):
+ # to be implemented
+ self._del_docs_by_ids([_id])
+
+ def _set_doc_by_id(self, _id: str, value: 'Document', **kwargs):
+ # to be implemented
+ self._set_docs_by_ids([_id], [value], None, **kwargs)
+
+ def _load_offset2ids(self):
+ if self._list_like:
+ collection = self._offset2id_collection
+ kwargs = self._update_kwargs_from_config('consistency_level', **dict())
+ with self.loaded_collection(collection):
+ res = collection.query(
+ expr=_always_true_expr('document_id'),
+ output_fields=['offset', 'document_id'],
+ **kwargs,
+ )
+ sorted_res = sorted(res, key=lambda k: int(k['offset']))
+ self._offset2ids = Offset2ID([r['document_id'] for r in sorted_res])
+
+ def _save_offset2ids(self):
+ if self._list_like:
+ # delete old entries
+ self._clear_offset2ids_milvus()
+ # insert current entries
+ ids = self._offset2ids.ids
+ if not ids:
+ return
+ offsets = [str(i) for i in range(len(ids))]
+ dummy_vectors = [np.zeros(1) for _ in range(len(ids))]
+ collection = self._offset2id_collection
+ collection.insert([offsets, ids, dummy_vectors])
+
+ def _get_docs_by_ids(self, ids: 'Iterable[str]', **kwargs) -> 'DocumentArray':
+ if not ids:
+ return DocumentArray()
+ ids = list(ids)
+ kwargs = self._update_kwargs_from_config('consistency_level', **kwargs)
+ kwargs = self._update_kwargs_from_config('batch_size', **kwargs)
+ with self.loaded_collection():
+ docs = DocumentArray()
+ for id_batch in _batch_list(ids, kwargs['batch_size']):
+ res = self._collection.query(
+ expr=f'document_id in {_ids_to_milvus_expr(id_batch)}',
+ output_fields=['serialized'],
+ **kwargs,
+ )
+ if not res:
+ raise KeyError(f'No documents found for ids {ids}')
+ docs.extend(self._docs_from_query_response(res))
+ # sort output docs according to input id sorting
+ id_to_index = {id_: i for i, id_ in enumerate(ids)}
+ return DocumentArray(sorted(docs, key=lambda d: id_to_index[d.id]))
+
+ def _del_docs_by_ids(self, ids: 'Iterable[str]', **kwargs) -> 'DocumentArray':
+ kwargs = self._update_kwargs_from_config('consistency_level', **kwargs)
+ kwargs = self._update_kwargs_from_config('batch_size', **kwargs)
+ for id_batch in _batch_list(list(ids), kwargs['batch_size']):
+ self._collection.delete(
+ expr=f'document_id in {_ids_to_milvus_expr(id_batch)}', **kwargs
+ )
+
+ def _set_docs_by_ids(
+ self, ids, docs: 'Iterable[Document]', mismatch_ids: 'Dict', **kwargs
+ ):
+ kwargs = self._update_kwargs_from_config('consistency_level', **kwargs)
+ kwargs = self._update_kwargs_from_config('batch_size', **kwargs)
+ # delete old entries
+ for id_batch in _batch_list(list(ids), kwargs['batch_size']):
+ self._collection.delete(
+ expr=f'document_id in {_ids_to_milvus_expr(id_batch)}',
+ **kwargs,
+ )
+ for docs_batch in _batch_list(list(docs), kwargs['batch_size']):
+ # insert new entries
+ payload = self._docs_to_milvus_payload(docs_batch)
+ self._collection.insert(payload, **kwargs)
+
+ def _clear_storage(self):
+ self._collection.drop()
+ self._create_or_reuse_collection()
+ self._clear_offset2ids_milvus()
+
+ def _clear_offset2ids_milvus(self):
+ self._offset2id_collection.drop()
+ self._create_or_reuse_offset2id_collection()
diff --git a/docarray/array/storage/milvus/seqlike.py b/docarray/array/storage/milvus/seqlike.py
new file mode 100644
index 00000000000..1711c5b8080
--- /dev/null
+++ b/docarray/array/storage/milvus/seqlike.py
@@ -0,0 +1,58 @@
+from typing import Iterable, Iterator, Union, TYPE_CHECKING
+from docarray.array.storage.base.seqlike import BaseSequenceLikeMixin
+from docarray.array.storage.milvus.backend import _batch_list
+from docarray import Document
+
+
+class SequenceLikeMixin(BaseSequenceLikeMixin):
+ def __eq__(self, other):
+ """Compare this object to the other, returns True if and only if other
+    has the same type as self and other have the same Milvus Collections for data and offset2id
+
+ :param other: the other object to check for equality
+ :return: `True` if other is equal to self
+ """
+ return (
+ type(self) is type(other)
+ and self._collection.name == other._collection.name
+ and self._offset2id_collection.name == other._offset2id_collection.name
+ and self._config == other._config
+ )
+
+ def __contains__(self, x: Union[str, 'Document']):
+ if isinstance(x, Document):
+ x = x.id
+ try:
+ self._get_doc_by_id(x)
+ return True
+ except:
+ return False
+
+ def __repr__(self):
+        return f'<DocumentArray[Milvus] (length={len(self)}) at {id(self)}>'
+
+ def __add__(self, other: Union['Document', Iterable['Document']]):
+ if isinstance(other, Document):
+ self.append(other)
+ else:
+ self.extend(other)
+ return self
+
+ def insert(self, index: int, value: 'Document', **kwargs):
+ self._set_doc_by_id(value.id, value, **kwargs)
+ self._offset2ids.insert(index, value.id)
+
+ def _append(self, value: 'Document', **kwargs):
+ self._set_doc_by_id(value.id, value, **kwargs)
+ self._offset2ids.append(value.id)
+
+ def _extend(self, values: Iterable['Document'], **kwargs):
+ docs = list(values)
+ if not docs:
+ return
+ kwargs = self._update_kwargs_from_config('consistency_level', **kwargs)
+ kwargs = self._update_kwargs_from_config('batch_size', **kwargs)
+ for docs_batch in _batch_list(list(docs), kwargs['batch_size']):
+ payload = self._docs_to_milvus_payload(docs_batch)
+ self._collection.insert(payload, **kwargs)
+ self._offset2ids.extend([doc.id for doc in docs_batch])
diff --git a/docs/advanced/document-store/index.md b/docs/advanced/document-store/index.md
index ee71868fbab..c3d49770e0a 100644
--- a/docs/advanced/document-store/index.md
+++ b/docs/advanced/document-store/index.md
@@ -10,6 +10,7 @@ qdrant
elasticsearch
weaviate
redis
+milvus
extend
benchmark
```
@@ -168,6 +169,7 @@ DocArray supports multiple storage backends with different search features. The
| [`AnnLite`](./annlite.md) | `DocumentArray(storage='annlite')` | ✅ | ✅ | ✅ |
| [`ElasticSearch`](./elasticsearch.md) | `DocumentArray(storage='elasticsearch')` | ✅ | ✅ | ✅ |
| [`Redis`](./redis.md) | `DocumentArray(storage='redis')` | ✅ | ✅ | ✅ |
+| [`Milvus`](./milvus.md) | `DocumentArray(storage='milvus')` | ✅ | ✅ | ✅ |
The right backend choice depends on the scale of your data, the required performance and the desired ease of setup. For most use cases we recommend starting with [`AnnLite`](./annlite.md).
[**Check our One Million Scale Benchmark for more details**](./benchmark#conclusion).
@@ -354,6 +356,7 @@ array([[7., 7., 7.],
[4., 4., 4.]])
```
+(backend-context-mngr)=
## Persistence, mutations and context manager
Having DocumentArrays that are backed by a document store introduces an extra consideration into the way you think about DocumentArrays.
diff --git a/docs/advanced/document-store/milvus.md b/docs/advanced/document-store/milvus.md
new file mode 100644
index 00000000000..a93c9ccaed5
--- /dev/null
+++ b/docs/advanced/document-store/milvus.md
@@ -0,0 +1,496 @@
+(milvus)=
+# Milvus
+
+One can use [Milvus](https://milvus.io/) as the Document store for DocumentArray. It is useful when one wants to have faster Document retrieval on embeddings, i.e. `.match()`, `.find()`.
+
+````{tip}
+This feature requires `pymilvus`. You can install it via `pip install "docarray[milvus]"`.
+````
+
+## Usage
+
+### Start Milvus service
+
+To use Milvus as the storage backend, you need a running Milvus server. You can use the following `docker-compose.yml`
+to start a Milvus server:
+
+`````{dropdown} docker-compose.yml
+
+```yaml
+version: '3.5'
+
+services:
+ etcd:
+ container_name: milvus-etcd
+ image: quay.io/coreos/etcd:v3.5.0
+ environment:
+ - ETCD_AUTO_COMPACTION_MODE=revision
+ - ETCD_AUTO_COMPACTION_RETENTION=1000
+ - ETCD_QUOTA_BACKEND_BYTES=4294967296
+ - ETCD_SNAPSHOT_COUNT=50000
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
+ command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
+
+ minio:
+ container_name: milvus-minio
+ image: minio/minio:RELEASE.2022-03-17T06-34-49Z
+ environment:
+ MINIO_ACCESS_KEY: minioadmin
+ MINIO_SECRET_KEY: minioadmin
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
+ command: minio server /minio_data
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
+ interval: 30s
+ timeout: 20s
+ retries: 3
+
+ standalone:
+ container_name: milvus-standalone
+ image: milvusdb/milvus:v2.1.4
+ command: ["milvus", "run", "standalone"]
+ environment:
+ ETCD_ENDPOINTS: etcd:2379
+ MINIO_ADDRESS: minio:9000
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
+ ports:
+ - "19530:19530"
+ - "9091:9091"
+ depends_on:
+ - "etcd"
+ - "minio"
+
+networks:
+ default:
+ name: milvus
+```
+
+`````
+
+Then
+
+```bash
+docker-compose up
+```
+
+You can find more installation guidance in the [Milvus documentation](https://milvus.io/docs/v2.1.x/install_standalone-docker.md).
+
+### Create DocumentArray with Milvus backend
+
+Assuming the service is started using the default configuration (i.e. the server's gRPC address is `http://localhost:19530`), you can
+instantiate a DocumentArray with Milvus storage like so:
+
+```python
+from docarray import DocumentArray
+
+da = DocumentArray(storage='milvus', config={'n_dim': 10})
+```
+
+Here, `config` is configuration for the new Milvus collection,
+and `n_dim` is a mandatory field that specifies the dimensionality of stored embeddings.
+For more information about the Milvus `config`, refer to the {ref}`config <milvus-config>` section.
+
+To access a previously persisted DocumentArray, specify the `collection_name`, the `host`, and the `port`.
+
+
+```python
+from docarray import DocumentArray
+
+da = DocumentArray(
+ storage='milvus',
+ config={
+ 'collection_name': 'persisted',
+ 'host': 'localhost',
+ 'port': '19530',
+ 'n_dim': 10,
+ },
+)
+
+da.summary()
+```
+
+(milvus-config)=
+## Config
+
+The following configs can be set:
+
+| Name | Description | Default |
+|---------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------|
+| `n_dim` | Number of dimensions of embeddings to be stored and retrieved | **This is always required** |
+| `collection_name`   | Name of the Milvus collection                                                                                                                                                                                                     | **Random collection name generated** |
+| `host` | Hostname of the Milvus server | 'localhost' |
+| `port`              | Port of the Milvus server                                                                                                                                                                                                         | 19530                                                |
+| `distance` | [Distance metric](https://milvus.io/docs/v2.1.x/metric.md) to be used during search. Can be 'IP', 'L2', 'JACCARD', 'TANIMOTO', 'HAMMING', 'SUPERSTRUCTURE' or 'SUBSTRUCTURE'. | 'IP' (inner product) |
+| `index_type` | Type of the (ANN) search index. Can be 'HNSW', 'FLAT', 'ANNOY', or one of multiple variants of IVF and RHNSW. Refer to the [list of supported index types](https://milvus.io/docs/v2.1.x/build_index.md#Prepare-index-parameter). | 'HNSW' |
+| `index_params` | A dictionary of parameters used for index building. The [allowed parameters](https://milvus.io/docs/v2.1.x/index.md) depend on the index type. | {'M': 4, 'efConstruction': 200} (assumes HNSW index) |
+| `collection_config` | Configuration for the Milvus collection. Passed as `**kwargs` during collection creation (`Collection(...)`). | {} |
+| `serialize_config` | [Serialization config of each Document](../../../fundamentals/document/serialization.md) | {} |
+| `consistency_level` | [Consistency level](https://milvus.io/docs/v2.1.x/consistency.md#Consistency-levels) for Milvus database operations. Can be 'Session', 'Strong', 'Bounded' or 'Eventually'. | 'Session' |
+| `batch_size` | Default batch size for CRUD operations. | -1 (no batching) |
+| `columns`           | Additional columns to be stored in the database, taken from Document `tags`. | None |
+| `list_like` | Controls if ordering of Documents is persisted in the Database. Disabling this breaks list-like features, but can improve performance. | True |
+
+## Minimal example
+
+Download `docker-compose.yml`:
+
+
+```text
+wget https://github.com/milvus-io/milvus/releases/download/v2.1.4/milvus-standalone-docker-compose.yml -O docker-compose.yml
+```
+
+Install DocArray with Milvus and launch the Milvus server:
+
+
+```bash
+pip install -U docarray[milvus]
+docker-compose up
+```
+
+Create a DocumentArray with some random data:
+
+```python
+import numpy as np
+
+from docarray import DocumentArray
+
+N, D = 5, 128
+
+da = DocumentArray.empty(
+ N, storage='milvus', config={'n_dim': D, 'distance': 'IP'}
+) # init
+with da:
+ da.embeddings = np.random.random([N, D])
+```
+
+Perform an approximate nearest neighbor search:
+
+```
+print(da.find(np.random.random(D), limit=10))
+```
+Output:
+
+```bash
+
+```
+
+(milvus-filter)=
+## Vector search with filter
+
+Search with `.find` can be restricted by user-defined filters.
+
+Such filters can be constructed using the [filter expression language defined by Milvus](https://milvus.io/docs/v2.1.x/boolean.md).
+Filters operate on the `tags` of a Document, which are stored as `columns` in the Milvus database.
+
+
+### Example of `.find` with filtered vector search
+
+
+Consider Documents with embeddings `[0,0,0]` up to `[9,9,9]` where the Document with embedding `[i,i,i]`
+has a tag `price` with value `i`. You can create such an example with the following code:
+
+```python
+from docarray import Document, DocumentArray
+import numpy as np
+
+n_dim = 3
+distance = 'L2'
+
+da = DocumentArray(
+ storage='milvus',
+ config={'n_dim': n_dim, 'columns': {'price': 'float'}, 'distance': distance},
+)
+
+print(f'\nDocumentArray distance: {distance}')
+
+with da:
+ da.extend(
+ [
+ Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i})
+ for i in range(10)
+ ]
+ )
+
+print('\nIndexed Prices:\n')
+for embedding, price in zip(da.embeddings, da[:, 'tags__price']):
+ print(f'\tembedding={embedding},\t price={price}')
+```
+
+Consider you want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that
+prices must follow a filter. As an example, retrieved Documents must have `price` value lower than
+or equal to `max_price`. You can express this information in Milvus using `filter = f'price <= {max_price}'`.
+
+Then you can implement and use the search with the proposed filter:
+
+```python
+max_price = 7
+n_limit = 4
+
+np_query = np.ones(n_dim) * 8
+print(f'\nQuery vector: \t{np_query}')
+
+filter = f'price <= {max_price}'
+results = da.find(np_query, filter=filter, limit=n_limit)
+
+print('\nEmbeddings Nearest Neighbours with "price" at most 7:\n')
+for embedding, price in zip(results.embeddings, results[:, 'tags__price']):
+ print(f'\tembedding={embedding},\t price={price}')
+```
+
+This will print:
+
+```
+Query vector: [8. 8. 8.]
+
+Embeddings Nearest Neighbours with "price" at most 7:
+
+ embedding=[7. 7. 7.], price=7
+ embedding=[6. 6. 6.], price=6
+ embedding=[5. 5. 5.], price=5
+ embedding=[4. 4. 4.], price=4
+```
+### Example of `.find` with only a filter
+
+The following example shows how to use DocArray with Milvus Document Store in order to filter text documents.
+Consider Documents have the tag `price` with a value of `i`. You can create these with the following code:
+
+```python
+from docarray import Document, DocumentArray
+import numpy as np
+
+n_dim = 3
+
+da = DocumentArray(
+ storage='milvus',
+ config={'n_dim': n_dim, 'columns': {'price': 'float'}},
+)
+
+with da:
+ da.extend(
+ [
+ Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i})
+ for i in range(10)
+ ]
+ )
+
+print('\nIndexed Prices:\n')
+for embedding, price in zip(da.embeddings, da[:, 'tags__price']):
+ print(f'\tembedding={embedding},\t price={price}')
+```
+
+Suppose you want to filter results such that
+retrieved Documents must have a `price` value less than or equal to `max_price`. You can express
+this information in Milvus using `filter = f'price <= {max_price}'`.
+
+Then you can implement and use the search with the proposed filter:
+```python
+max_price = 7
+n_limit = 4
+
+filter = f'price <= {max_price}'
+results = da.find(filter=filter, limit=n_limit)
+
+print('\nPoints with "price" at most 7:\n')
+for embedding, price in zip(results.embeddings, results[:, 'tags__price']):
+ print(f'\tembedding={embedding},\t price={price}')
+```
+This prints:
+
+```text
+
+Points with "price" at most 7:
+
+ embedding=[6. 6. 6.], price=6
+ embedding=[7. 7. 7.], price=7
+ embedding=[1. 1. 1.], price=1
+ embedding=[2. 2. 2.], price=2
+```
+
+## Advanced options
+
+The Milvus Document Store allows the user to pass additional parameters to the Milvus server for all main operations.
+
+Currently, the main use cases for this are dynamic setting of a consistency level, and passing of search parameters.
+
+### Setting a consistency level
+
+By default, every operation on the Milvus Document Store is performed with a consistency level passed during initialization
+as part of the {ref}`config `.
+
+When performing a specific operation, you can override this default consistency level by passing a `consistency_level` parameter:
+
+```python
+from docarray import DocumentArray, Document
+import numpy as np
+
+da = DocumentArray(
+ storage='milvus',
+ config={'consistency_level': 'Session', 'n_dim': 3},
+)
+
+da.append(Document(tensor=np.random.rand(3))) # consistency level is 'Session'
+da.append(
+ Document(tensor=np.random.rand(3)), consistency_level='Strong'
+) # consistency level is 'Strong'
+```
+
+Currently, dynamically setting a consistency level is supported for the following operations:
+`.append()`, `.extend()`, `.find()`, and `.insert()`.
+
+### Setting a batch size
+
+You can configure your DocumentArray to, on every relevant operation, send Documents to the Milvus database in batches.
+This default `batch_size` can be specified in the DocumentArray {ref}`config `.
+
+If you do not specify a default batch size, no batching will be performed.
+
+
+When performing a specific operation, you can override this default batch size by passing a `batch_size` parameter:
+
+```python
+from docarray import DocumentArray, Document
+import numpy as np
+
+da = DocumentArray(
+ storage='milvus',
+ config={'batch_size': 100, 'n_dim': 3},
+)
+
+da.append(Document(tensor=np.random.rand(3))) # batch size is 100
+da.append(Document(tensor=np.random.rand(3)), batch_size=5) # batch size is 5
+```
+
+Currently, dynamically setting a batch size is supported for the following operations:
+`.append()`, `.extend()`, and `.insert()`.
+
+### Passing search parameters
+
+In Milvus you can [pass parameters to the search operation](https://milvus.io/docs/v2.1.x/search.md#Conduct-a-vector-search) which [depend on the used index type](https://milvus.io/docs/v2.1.x/index.md).
+
+In DocumentArray, this ability is exposed through the `param` argument in the `~docarray.array.mixins.find` method:
+
+```python
+import numpy as np
+
+from docarray import DocumentArray
+
+N, D = 5, 128
+
+da = DocumentArray.empty(
+ N, storage='milvus', config={'n_dim': D, 'distance': 'IP'}
+) # init
+with da:
+ da.embeddings = np.random.random([N, D])
+
+da.find(
+ np.random.random(D), limit=10, param={"metric_type": "L2", "params": {"nprobe": 10}}
+)
+```
+
+
+(milvus-limitations)=
+## Known limitations of the Milvus Document Store
+
+The Milvus Document Store implements the entire DocumentArray API, but there are some limitations that you should be aware of.
+
+(milvus-collection-loading)=
+### Collection loading
+
+In Milvus, every search or query operation requires the index to be [loaded into memory](https://milvus.io/api-reference/pymilvus/v2.1.3/Collection/load().md).
+This includes simple Document access through DocArray.
+
+This loading operation can be costly, especially when performing multiple search or query operations in a row.
+
+To mitigate this, you should use the `with da:` context manager whenever you perform multiple reads, searches or queries
+on a Milvus DocumentArray.
+This context manager loads the index into memory only once, and releases it when the context is exited.
+
+```python
+from docarray import Document, DocumentArray
+import numpy as np
+
+da = DocumentArray(
+ [Document(id=f'r{i}', embedding=i * np.ones(3)) for i in range(10)],
+ storage='milvus',
+ config={'n_dim': 3},
+)
+
+with da:
+ # index is loaded into memory
+ for d in da:
+ pass
+# index is released from memory
+
+with da:
+ # index is loaded into memory
+ embs, texts = da.embeddings, da.texts
+# index is released from memory
+```
+
+The `with da:` context manager also {ref}`manages persistence of the list-like interface ` of a DocumentArray,
+which can introduce a small overhead when leaving the context.
+
+If you want to _only_ manage the loading and releasing behavior of your DocumentArray, you can use the `with da.loaded_collection()`
+context manager instead.
+In the example above it can be used as a drop-in replacement.
+
+Not using the `with da:` or `with da.loaded_collection()` context manager will return the same results for the same operations, but will incur significant performance penalties:
+
+````{dropdown} ⚠️ Bad code
+
+```python
+from docarray import Document, DocumentArray
+import numpy as np
+
+da = DocumentArray(
+ [Document(id=f'r{i}', embedding=i * np.ones(3)) for i in range(10)],
+ storage='milvus',
+ config={'n_dim': 3},
+)
+
+for d in da: # index is loaded and released at every iteration
+ pass
+
+embs, texts = (
+ da.embeddings,
+ da.texts,
+) # index is loaded and released for every Document in `da`
+```
+
+````
+
+### Storing large tensors outside of `embedding` field
+
+It is currently not possible to persist Documents with a large `.tensor` field.
+
+A suitable workaround for this is to remove a Document's tensor after computing its embedding and before adding it to the
+Document Store:
+
+```python
+from docarray import Document, DocumentArray
+
+da = DocumentArray(storage='milvus', config={'n_dim': 128})
+
+doc = Document(tensor=np.random.rand(224, 224))
+doc.embed(...)
+doc.tensor = None
+
+da.append(doc)
+```
+
+````{dropdown} Why does this limitation exist?
+By default, DocArray stores three columns in any Document Store: The Document ids, the Document embeddings and
+a serialized (Base64 encoded) representation of the Document itself.
+
+In Milvus, the serialized Documents are stored in a column of type 'VARCHAR', which imposes a limit of allowed length
+per entry.
+If the Base64 encoded Document exceeds this limit - which is usually the case for Documents with large tensors - the
+Document cannot be stored.
+
+The Milvus team is currently working on a 'STRING' column type that could solve this issue in the future.
+````
+
+
diff --git a/setup.py b/setup.py
index ae61f6a8a49..0835c266ff8 100644
--- a/setup.py
+++ b/setup.py
@@ -81,6 +81,9 @@
'redis': [
'redis>=4.3.0',
],
+ 'milvus': [
+ 'pymilvus>=2.1.0',
+ ],
'benchmark': [
'pandas',
'matplotlib',
@@ -111,7 +114,9 @@
'annlite',
'elasticsearch>=8.2.0',
'redis>=4.3.0',
+ 'pymilvus>=2.1.0',
'jina',
+ 'pytest-mock',
],
},
classifiers=[
diff --git a/tests/conftest.py b/tests/conftest.py
index 77686a21570..729117fa329 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,6 +9,9 @@
compose_yml = os.path.abspath(
os.path.join(cur_dir, 'unit', 'array', 'docker-compose.yml')
)
+milvus_compose_yml = os.path.abspath(
+ os.path.join(cur_dir, 'unit', 'array', 'milvus-docker-compose.yml')
+)
@pytest.fixture(autouse=True)
@@ -23,17 +26,61 @@ def start_storage():
f"docker-compose -f {compose_yml} --project-directory . up --build -d "
f"--remove-orphans"
)
- from elasticsearch import Elasticsearch
+ os.system(
+ f"docker-compose -f {milvus_compose_yml} --project-directory . up --build -d"
+ )
- es = Elasticsearch(hosts='http://localhost:9200/')
- while not es.ping():
- time.sleep(0.5)
+ _wait_for_es()
+ _wait_for_milvus()
yield
os.system(
f"docker-compose -f {compose_yml} --project-directory . down "
f"--remove-orphans"
)
+ os.system(
+ f"docker-compose -f {milvus_compose_yml} --project-directory . down "
+ f"--remove-orphans"
+ )
+
+
+def restart_milvus():
+ os.system(f"docker-compose -f {milvus_compose_yml} --project-directory . down")
+ os.system(
+ f"docker-compose -f {milvus_compose_yml} --project-directory . up --build -d"
+ )
+ _wait_for_milvus(restart_on_failure=False)
+
+
+def _wait_for_es():
+ from elasticsearch import Elasticsearch
+
+ es = Elasticsearch(hosts='http://localhost:9200/')
+ while not es.ping():
+ time.sleep(0.5)
+
+
+def _wait_for_milvus(restart_on_failure=True):
+ from pymilvus import connections, has_collection
+ from pymilvus.exceptions import MilvusUnavailableException, MilvusException
+
+ milvus_conn_alias = f'pytest_localhost_19530'
+ try:
+ connections.connect(alias=milvus_conn_alias, host='localhost', port=19530)
+ milvus_ready = False
+ while not milvus_ready:
+ try:
+ has_collection('ping', using=milvus_conn_alias)
+ milvus_ready = True
+ except MilvusUnavailableException:
+ # Milvus is not ready yet, just wait
+ time.sleep(0.5)
+ except MilvusException as e:
+ if e.code == 1 and restart_on_failure:
+ # something went wrong with the docker container, restart and retry once
+ restart_milvus()
+ else:
+ raise e
@pytest.fixture(scope='session')
diff --git a/tests/unit/array/docker-compose.yml b/tests/unit/array/docker-compose.yml
index 4e5cd5d30cd..0c384989043 100644
--- a/tests/unit/array/docker-compose.yml
+++ b/tests/unit/array/docker-compose.yml
@@ -31,6 +31,7 @@ services:
image: redislabs/redisearch:2.6.0
ports:
- "6379:6379"
+
networks:
elastic:
diff --git a/tests/unit/array/milvus-docker-compose.yml b/tests/unit/array/milvus-docker-compose.yml
new file mode 100644
index 00000000000..0e2fecfbb84
--- /dev/null
+++ b/tests/unit/array/milvus-docker-compose.yml
@@ -0,0 +1,44 @@
+version: "3.3"
+services:
+ etcd:
+ container_name: milvus-etcd
+ image: quay.io/coreos/etcd:v3.5.0
+ environment:
+ - ETCD_AUTO_COMPACTION_MODE=revision
+ - ETCD_AUTO_COMPACTION_RETENTION=1000
+ - ETCD_QUOTA_BACKEND_BYTES=4294967296
+ - ETCD_SNAPSHOT_COUNT=50000
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
+ command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
+
+ minio:
+ container_name: milvus-minio
+ image: minio/minio:RELEASE.2022-03-17T06-34-49Z
+ environment:
+ MINIO_ACCESS_KEY: minioadmin
+ MINIO_SECRET_KEY: minioadmin
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
+ command: minio server /minio_data
+ healthcheck:
+ test: [ "CMD", "curl", "-f", "http://localhost:9000/minio/health/live" ]
+ interval: 30s
+ timeout: 20s
+ retries: 3
+
+ standalone:
+ container_name: milvus-standalone
+ image: milvusdb/milvus:v2.1.4
+ command: [ "milvus", "run", "standalone" ]
+ environment:
+ ETCD_ENDPOINTS: etcd:2379
+ MINIO_ADDRESS: minio:9000
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
+ ports:
+ - "19530:19530"
+ - "9091:9091"
+ depends_on:
+ - "etcd"
+ - "minio"
\ No newline at end of file
diff --git a/tests/unit/array/mixins/oldproto/test_embed.py b/tests/unit/array/mixins/oldproto/test_embed.py
index e5c1762e925..9e42e622908 100644
--- a/tests/unit/array/mixins/oldproto/test_embed.py
+++ b/tests/unit/array/mixins/oldproto/test_embed.py
@@ -23,6 +23,7 @@
from docarray.array.weaviate import DocumentArrayWeaviate
from docarray.array.elastic import DocumentArrayElastic
from docarray.array.redis import DocumentArrayRedis
+from docarray.array.milvus import DocumentArrayMilvus
random_embed_models = {
'keras': lambda: tf.keras.Sequential(
@@ -76,6 +77,7 @@
# DocumentArrayWeaviate, TODO: enable this
DocumentArrayElastic,
DocumentArrayRedis,
+ DocumentArrayMilvus,
],
)
@pytest.mark.parametrize('N', [2, 10])
@@ -97,36 +99,50 @@ def test_embedding_on_random_network(
DocumentArrayQdrant,
DocumentArrayElastic,
DocumentArrayRedis,
+ DocumentArrayMilvus,
]:
da = da_cls.empty(N, config={'n_dim': embedding_shape})
else:
da = da_cls.empty(N, config=None)
- da.tensors = np.random.random([N, *input_shape]).astype(np.float32)
- embed_model = random_embed_models[framework]()
- da.embed(embed_model, batch_size=batch_size, to_numpy=to_numpy)
-
- r = da.embeddings
- if hasattr(r, 'numpy'):
- r = r.numpy()
- embed1 = r.copy()
+ embed_model = random_embed_models[framework]()
+ if da_cls == DocumentArrayMilvus and len(input_shape) == 3:
+ input_shape = (3, 12, 12) # Milvus can't handle large tensors
+ if framework.startswith(
+ 'transformers'
+ ): # transformer model expects input shape (3, 224, 224), can't test with Milvus
+ return
+
+ with da: # to speed up milvus by loading the collection
+ da.tensors = np.random.random([N, *input_shape]).astype(np.float32)
+ da.embed(embed_model, batch_size=batch_size, to_numpy=to_numpy)
- # reset
- da.embeddings = np.random.random([N, embedding_shape]).astype(np.float32)
+ r = da.embeddings
+ if hasattr(r, 'numpy'):
+ r = r.numpy()
- # docs[a: b].embed is only supported for DocumentArrayInMemory
- if isinstance(da, DocumentArrayInMemory):
- # try it again, it should yield the same result
- da.embed(embed_model, batch_size=batch_size, to_numpy=to_numpy)
- np.testing.assert_array_almost_equal(da.embeddings, embed1)
+ embed1 = r.copy()
# reset
da.embeddings = np.random.random([N, embedding_shape]).astype(np.float32)
- # now do this one by one
- da[: int(N / 2)].embed(embed_model, batch_size=batch_size, to_numpy=to_numpy)
- da[-int(N / 2) :].embed(embed_model, batch_size=batch_size, to_numpy=to_numpy)
- np.testing.assert_array_almost_equal(da.embeddings, embed1)
+ # docs[a: b].embed is only supported for DocumentArrayInMemory
+ if isinstance(da, DocumentArrayInMemory):
+ # try it again, it should yield the same result
+ da.embed(embed_model, batch_size=batch_size, to_numpy=to_numpy)
+ np.testing.assert_array_almost_equal(da.embeddings, embed1)
+
+ # reset
+ da.embeddings = np.random.random([N, embedding_shape]).astype(np.float32)
+
+ # now do this one by one
+ da[: int(N / 2)].embed(
+ embed_model, batch_size=batch_size, to_numpy=to_numpy
+ )
+ da[-int(N / 2) :].embed(
+ embed_model, batch_size=batch_size, to_numpy=to_numpy
+ )
+ np.testing.assert_array_almost_equal(da.embeddings, embed1)
@pytest.fixture
diff --git a/tests/unit/array/mixins/oldproto/test_eval_class.py b/tests/unit/array/mixins/oldproto/test_eval_class.py
index 3eb0b79a3a2..32b6bc2ceef 100644
--- a/tests/unit/array/mixins/oldproto/test_eval_class.py
+++ b/tests/unit/array/mixins/oldproto/test_eval_class.py
@@ -22,6 +22,7 @@
('qdrant', {'n_dim': 256}),
('elasticsearch', {'n_dim': 256}),
('redis', {'n_dim': 256}),
+ ('milvus', {'n_dim': 256}),
],
)
@pytest.mark.parametrize(
@@ -61,6 +62,7 @@ def test_eval_mixin_perfect_match(metric_fn, kwargs, storage, config, start_stor
('qdrant', {'n_dim': 256}),
('elasticsearch', {'n_dim': 256}),
('redis', {'n_dim': 256}),
+ ('milvus', {'n_dim': 256}),
],
)
def test_eval_mixin_perfect_match_multiple_metrics(storage, config, start_storage):
@@ -139,6 +141,7 @@ def test_eval_mixin_perfect_match_labeled(
('qdrant', {'n_dim': 256}),
('elasticsearch', {'n_dim': 256}),
('redis', {'n_dim': 256}),
+ ('milvus', {'n_dim': 256}),
],
)
@pytest.mark.parametrize(
@@ -203,6 +206,7 @@ def test_eval_mixin_one_of_n_labeled(metric_fn, metric_score, label_tag):
('qdrant', {'n_dim': 256}),
('elasticsearch', {'n_dim': 256}),
('redis', {'n_dim': 256}),
+ ('milvus', {'n_dim': 256}),
],
)
@pytest.mark.parametrize(
@@ -247,6 +251,7 @@ def test_eval_mixin_zero_match(storage, config, metric_fn, start_storage, kwargs
('qdrant', {'n_dim': 256}),
('elasticsearch', {'n_dim': 256}),
('redis', {'n_dim': 256}),
+ ('milvus', {'n_dim': 256}),
],
)
def test_diff_len_should_raise(storage, config, start_storage):
@@ -269,6 +274,7 @@ def test_diff_len_should_raise(storage, config, start_storage):
('qdrant', {'n_dim': 256}),
('elasticsearch', {'n_dim': 256}),
('redis', {'n_dim': 256}),
+ ('milvus', {'n_dim': 256}),
],
)
def test_diff_hash_fun_should_raise(storage, config, start_storage):
@@ -291,6 +297,7 @@ def test_diff_hash_fun_should_raise(storage, config, start_storage):
('qdrant', {'n_dim': 3}),
('elasticsearch', {'n_dim': 3}),
('redis', {'n_dim': 3}),
+ ('milvus', {'n_dim': 3}),
],
)
def test_same_hash_same_len_fun_should_work(storage, config, start_storage):
@@ -320,6 +327,7 @@ def test_same_hash_same_len_fun_should_work(storage, config, start_storage):
('qdrant', {'n_dim': 3}),
('elasticsearch', {'n_dim': 3}),
('redis', {'n_dim': 3}),
+ ('milvus', {'n_dim': 3}),
],
)
def test_adding_noise(storage, config, start_storage):
@@ -356,6 +364,7 @@ def test_adding_noise(storage, config, start_storage):
('qdrant', {'n_dim': 128}),
('elasticsearch', {'n_dim': 128}),
('redis', {'n_dim': 128}),
+ ('milvus', {'n_dim': 128}),
],
)
@pytest.mark.parametrize(
@@ -397,6 +406,7 @@ def test_diff_match_len_in_gd(storage, config, metric_fn, start_storage, kwargs)
('qdrant', {'n_dim': 256}),
('elasticsearch', {'n_dim': 256}),
('redis', {'n_dim': 256}),
+ ('milvus', {'n_dim': 256}),
],
)
def test_empty_da_should_raise(storage, config, start_storage):
@@ -415,6 +425,7 @@ def test_empty_da_should_raise(storage, config, start_storage):
('qdrant', {'n_dim': 256}),
('elasticsearch', {'n_dim': 256}),
('redis', {'n_dim': 256}),
+ ('milvus', {'n_dim': 256}),
],
)
def test_missing_groundtruth_should_raise(storage, config, start_storage):
@@ -433,6 +444,7 @@ def test_missing_groundtruth_should_raise(storage, config, start_storage):
('qdrant', {'n_dim': 256}),
('elasticsearch', {'n_dim': 256}),
('redis', {'n_dim': 256}),
+ ('milvus', {'n_dim': 256}),
],
)
def test_useless_groundtruth_warning_should_raise(storage, config, start_storage):
diff --git a/tests/unit/array/mixins/oldproto/test_getset.py b/tests/unit/array/mixins/oldproto/test_getset.py
index 5cc8ef9cbc5..c7d9d1455a0 100644
--- a/tests/unit/array/mixins/oldproto/test_getset.py
+++ b/tests/unit/array/mixins/oldproto/test_getset.py
@@ -14,6 +14,7 @@
from docarray.array.weaviate import DocumentArrayWeaviate
from docarray.array.elastic import DocumentArrayElastic, ElasticConfig
from docarray.array.redis import DocumentArrayRedis, RedisConfig
+from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
from tests import random_docs
rand_array = np.random.random([10, 3])
@@ -44,6 +45,7 @@ def nested_docs():
('qdrant', {'n_dim': 3}),
('elasticsearch', {'n_dim': 3}),
('redis', {'n_dim': 3}),
+ ('milvus', {'n_dim': 3}),
],
)
@pytest.mark.parametrize(
@@ -70,6 +72,7 @@ def test_set_embeddings_multi_kind(array, storage, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_da_get_embeddings(docs, config, da_cls, start_storage):
@@ -78,8 +81,9 @@ def test_da_get_embeddings(docs, config, da_cls, start_storage):
else:
da = da_cls()
da.extend(docs)
- np.testing.assert_almost_equal(da._get_attributes('embedding'), da.embeddings)
- np.testing.assert_almost_equal(da[:, 'embedding'], da.embeddings)
+ with da:
+ np.testing.assert_almost_equal(da._get_attributes('embedding'), da.embeddings)
+ np.testing.assert_almost_equal(da[:, 'embedding'], da.embeddings)
@pytest.mark.parametrize(
@@ -92,6 +96,7 @@ def test_da_get_embeddings(docs, config, da_cls, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_embeddings_setter_da(docs, config, da_cls, start_storage):
@@ -102,7 +107,8 @@ def test_embeddings_setter_da(docs, config, da_cls, start_storage):
da.extend(docs)
emb = np.random.random((100, 10))
da[:, 'embedding'] = emb
- np.testing.assert_almost_equal(da.embeddings, emb)
+ with da:
+ np.testing.assert_almost_equal(da.embeddings, emb)
for x, doc in zip(emb, da):
np.testing.assert_almost_equal(x, doc.embedding)
@@ -110,7 +116,8 @@ def test_embeddings_setter_da(docs, config, da_cls, start_storage):
da[:, 'embedding'] = None
if hasattr(da, 'flush'):
da.flush()
- assert da.embeddings is None or not np.any(da.embeddings)
+ with da:
+ assert da.embeddings is None or not np.any(da.embeddings)
@pytest.mark.parametrize(
@@ -123,6 +130,7 @@ def test_embeddings_setter_da(docs, config, da_cls, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_embeddings_wrong_len(docs, config, da_cls, start_storage):
@@ -134,7 +142,8 @@ def test_embeddings_wrong_len(docs, config, da_cls, start_storage):
embeddings = np.ones((2, 10))
with pytest.raises(ValueError):
- da.embeddings = embeddings
+ with da:
+ da.embeddings = embeddings
@pytest.mark.parametrize(
@@ -147,6 +156,7 @@ def test_embeddings_wrong_len(docs, config, da_cls, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_tensors_getter_da(docs, config, da_cls, start_storage):
@@ -156,12 +166,13 @@ def test_tensors_getter_da(docs, config, da_cls, start_storage):
da = da_cls()
da.extend(docs)
tensors = np.random.random((100, 10, 10))
- da.tensors = tensors
- assert len(da) == 100
- np.testing.assert_almost_equal(da.tensors, tensors)
+ with da: # speed up milvus by loading collection
+ da.tensors = tensors
+ assert len(da) == 100
+ np.testing.assert_almost_equal(da.tensors, tensors)
- da.tensors = None
- assert da.tensors is None
+ da.tensors = None
+ assert da.tensors is None
@pytest.mark.parametrize(
@@ -174,6 +185,7 @@ def test_tensors_getter_da(docs, config, da_cls, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_texts_getter_da(docs, config, da_cls, start_storage):
@@ -182,22 +194,23 @@ def test_texts_getter_da(docs, config, da_cls, start_storage):
else:
da = da_cls()
da.extend(docs)
- assert len(da.texts) == 100
- assert da.texts == da[:, 'text']
- texts = ['text' for _ in range(100)]
- da.texts = texts
- assert da.texts == texts
+ with da: # speed up milvus by loading collection
+ assert len(da.texts) == 100
+ assert da.texts == da[:, 'text']
+ texts = ['text' for _ in range(100)]
+ da.texts = texts
+ assert da.texts == texts
- for x, doc in zip(texts, da):
- assert x == doc.text
+ for x, doc in zip(texts, da):
+ assert x == doc.text
- da.texts = None
- if hasattr(da, 'flush'):
- da.flush()
+ da.texts = None
+ if hasattr(da, 'flush'):
+ da.flush()
- # unfortunately protobuf does not distinguish None and '' on string
- # so non-set str field in Pb is ''
- assert set(da.texts) == set([''])
+ # unfortunately protobuf does not distinguish None and '' on string
+ # so non-set str field in Pb is ''
+ assert set(da.texts) == set([''])
@pytest.mark.parametrize(
@@ -210,6 +223,7 @@ def test_texts_getter_da(docs, config, da_cls, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_setter_by_sequences_in_selected_docs_da(docs, config, da_cls, start_storage):
@@ -248,6 +262,7 @@ def test_setter_by_sequences_in_selected_docs_da(docs, config, da_cls, start_sto
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_texts_wrong_len(docs, config, da_cls, start_storage):
@@ -259,7 +274,8 @@ def test_texts_wrong_len(docs, config, da_cls, start_storage):
texts = ['hello']
with pytest.raises(ValueError):
- da.texts = texts
+ with da:
+ da.texts = texts
@pytest.mark.parametrize(
@@ -272,6 +288,7 @@ def test_texts_wrong_len(docs, config, da_cls, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_tensors_wrong_len(docs, config, da_cls, start_storage):
@@ -283,7 +300,8 @@ def test_tensors_wrong_len(docs, config, da_cls, start_storage):
tensors = np.ones((2, 10, 10))
with pytest.raises(ValueError):
- da.tensors = tensors
+ with da: # speed up milvus by loading collection
+ da.tensors = tensors
@pytest.mark.parametrize(
@@ -296,6 +314,7 @@ def test_tensors_wrong_len(docs, config, da_cls, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_blobs_getter_setter(docs, da_cls, config, start_storage):
@@ -304,15 +323,16 @@ def test_blobs_getter_setter(docs, da_cls, config, start_storage):
else:
da = da_cls()
da.extend(docs)
- with pytest.raises(ValueError):
- da.blobs = [b'cc', b'bb', b'aa', b'dd']
+ with da: # speed up milvus by loading collection
+ with pytest.raises(ValueError):
+ da.blobs = [b'cc', b'bb', b'aa', b'dd']
- da.blobs = [b'aa'] * len(da)
- assert da.blobs == [b'aa'] * len(da)
+ da.blobs = [b'aa'] * len(da)
+ assert da.blobs == [b'aa'] * len(da)
- da.blobs = None
- if hasattr(da, 'flush'):
- da.flush()
+ da.blobs = None
+ if hasattr(da, 'flush'):
+ da.flush()
# unfortunately protobuf does not distinguish None and '' on string
# so non-set str field in Pb is ''
@@ -329,6 +349,7 @@ def test_blobs_getter_setter(docs, da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_ellipsis_getter(nested_docs, da_cls, config, start_storage):
@@ -353,6 +374,7 @@ def test_ellipsis_getter(nested_docs, da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_ellipsis_attribute_setter(nested_docs, da_cls, config, start_storage):
@@ -374,6 +396,7 @@ def test_ellipsis_attribute_setter(nested_docs, da_cls, config, start_storage):
(DocumentArrayWeaviate, WeaviateConfig(n_dim=6)),
(DocumentArrayElastic, ElasticConfig(n_dim=6)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=6)),
],
)
def test_zero_embeddings(da_cls, config, start_storage):
@@ -383,29 +406,30 @@ def test_zero_embeddings(da_cls, config, start_storage):
else:
da = da_cls.empty(10)
- # all zero, dense
- da[:, 'embedding'] = a
- np.testing.assert_almost_equal(da.embeddings, a)
- for d in da:
- assert d.embedding.shape == (6,)
-
- # all zero, sparse
- sp_a = scipy.sparse.coo_matrix(a)
- da[:, 'embedding'] = sp_a
- np.testing.assert_almost_equal(da.embeddings.todense(), sp_a.todense())
- for d in da:
- # scipy sparse row-vector can only be a (1, m) not squeezible
- assert d.embedding.shape == (1, 6)
-
- # near zero, sparse
- a = np.random.random([10, 6])
- a[a > 0.1] = 0
- sp_a = scipy.sparse.coo_matrix(a)
- da[:, 'embedding'] = sp_a
- np.testing.assert_almost_equal(da.embeddings.todense(), sp_a.todense())
- for d in da:
- # scipy sparse row-vector can only be a (1, m) not squeezible
- assert d.embedding.shape == (1, 6)
+ with da: # speed up milvus by loading collection
+ # all zero, dense
+ da[:, 'embedding'] = a
+ np.testing.assert_almost_equal(da.embeddings, a)
+ for d in da:
+ assert d.embedding.shape == (6,)
+
+ # all zero, sparse
+ sp_a = scipy.sparse.coo_matrix(a)
+ da[:, 'embedding'] = sp_a
+ np.testing.assert_almost_equal(da.embeddings.todense(), sp_a.todense())
+ for d in da:
+ # scipy sparse row-vector can only be a (1, m) not squeezible
+ assert d.embedding.shape == (1, 6)
+
+ # near zero, sparse
+ a = np.random.random([10, 6])
+ a[a > 0.1] = 0
+ sp_a = scipy.sparse.coo_matrix(a)
+ da[:, 'embedding'] = sp_a
+ np.testing.assert_almost_equal(da.embeddings.todense(), sp_a.todense())
+ for d in da:
+ # scipy sparse row-vector can only be a (1, m) not squeezible
+ assert d.embedding.shape == (1, 6)
def embeddings_eq(emb1, emb2):
@@ -426,6 +450,7 @@ def embeddings_eq(emb1, emb2):
('elasticsearch', {'n_dim': 3, 'distance': 'l2_norm'}),
('sqlite', dict()),
('redis', {'n_dim': 3, 'distance': 'L2'}),
+ ('milvus', {'n_dim': 3, 'distance': 'L2'}),
],
)
def test_getset_subindex(storage, config):
@@ -509,6 +534,7 @@ def test_getset_subindex(storage, config):
('elasticsearch', {'n_dim': 3, 'distance': 'l2_norm'}),
('sqlite', dict()),
('redis', {'n_dim': 3, 'distance': 'L2'}),
+ ('milvus', {'n_dim': 3, 'distance': 'L2'}),
],
)
def test_init_subindex(storage, config):
@@ -549,6 +575,7 @@ def test_init_subindex(storage, config):
('elasticsearch', {'n_dim': 3, 'distance': 'l2_norm'}),
('sqlite', dict()),
('redis', {'n_dim': 3, 'distance': 'L2'}),
+ ('milvus', {'n_dim': 3, 'distance': 'L2'}),
],
)
def test_set_on_subindex(storage, config):
@@ -566,13 +593,15 @@ def test_set_on_subindex(storage, config):
embeddings_to_assign = np.random.random((5 * 3, 2))
with da:
da['@c'].embeddings = embeddings_to_assign
- assert (da['@c'].embeddings == embeddings_to_assign).all()
- assert (da._subindices['@c'].embeddings == embeddings_to_assign).all()
+ with da:
+ assert (da['@c'].embeddings == embeddings_to_assign).all()
+ assert (da._subindices['@c'].embeddings == embeddings_to_assign).all()
with da:
da['@c'].texts = ['hello' for _ in range(5 * 3)]
- assert da['@c'].texts == ['hello' for _ in range(5 * 3)]
- assert da._subindices['@c'].texts == ['hello' for _ in range(5 * 3)]
+ with da:
+ assert da['@c'].texts == ['hello' for _ in range(5 * 3)]
+ assert da._subindices['@c'].texts == ['hello' for _ in range(5 * 3)]
matches = da.find(query=np.random.random(2), on='@c')
assert matches
diff --git a/tests/unit/array/mixins/oldproto/test_match.py b/tests/unit/array/mixins/oldproto/test_match.py
index d7b3811a26a..0fe517da20b 100644
--- a/tests/unit/array/mixins/oldproto/test_match.py
+++ b/tests/unit/array/mixins/oldproto/test_match.py
@@ -77,6 +77,7 @@ def doc_lists_to_doc_arrays(doc_lists, *args, **kwargs):
('qdrant', {'n_dim': 3}),
('weaviate', {'n_dim': 3}),
('redis', {'n_dim': 3}),
+ ('milvus', {'n_dim': 3}),
],
)
@pytest.mark.parametrize('limit', [1, 2, 3])
@@ -777,6 +778,7 @@ def embeddings_eq(emb1, emb2):
('elasticsearch', {'n_dim': 3, 'distance': 'l2_norm'}),
('sqlite', dict()),
('redis', {'n_dim': 3, 'distance': 'L2'}),
+ ('milvus', {'n_dim': 3, 'distance': 'L2'}),
],
)
def test_match_subindex(storage, config):
diff --git a/tests/unit/array/mixins/test_content.py b/tests/unit/array/mixins/test_content.py
index ea4535c9d00..362d0c488e9 100644
--- a/tests/unit/array/mixins/test_content.py
+++ b/tests/unit/array/mixins/test_content.py
@@ -10,6 +10,7 @@
from docarray.array.weaviate import DocumentArrayWeaviate
from docarray.array.elastic import DocumentArrayElastic, ElasticConfig
from docarray.array.redis import DocumentArrayRedis, RedisConfig
+from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
@pytest.mark.parametrize(
@@ -22,6 +23,7 @@
DocumentArrayQdrant,
DocumentArrayElastic,
DocumentArrayRedis,
+ DocumentArrayMilvus,
],
)
@pytest.mark.parametrize(
@@ -34,6 +36,7 @@ def test_content_empty_getter_return_none(cls, content_attr, start_storage):
DocumentArrayQdrant,
DocumentArrayElastic,
DocumentArrayRedis,
+ DocumentArrayMilvus,
]:
da = cls(config={'n_dim': 3})
else:
@@ -51,6 +54,7 @@ def test_content_empty_getter_return_none(cls, content_attr, start_storage):
DocumentArrayQdrant,
DocumentArrayElastic,
DocumentArrayRedis,
+ DocumentArrayMilvus,
],
)
@pytest.mark.parametrize(
@@ -70,6 +74,7 @@ def test_content_empty_setter(cls, content_attr, start_storage):
DocumentArrayQdrant,
DocumentArrayElastic,
DocumentArrayRedis,
+ DocumentArrayMilvus,
]:
da = cls(config={'n_dim': 3})
else:
@@ -88,6 +93,7 @@ def test_content_empty_setter(cls, content_attr, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
@pytest.mark.parametrize(
@@ -123,6 +129,7 @@ def test_content_getter_setter(cls, content_attr, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_content_empty(da_len, da_cls, config, start_storage):
@@ -161,6 +168,7 @@ def test_content_empty(da_len, da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=5)),
(DocumentArrayElastic, ElasticConfig(n_dim=5)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=5)),
],
)
def test_embeddings_setter(da_len, da_cls, config, start_storage):
diff --git a/tests/unit/array/mixins/test_del.py b/tests/unit/array/mixins/test_del.py
index 610ca99140b..aee552af9b4 100644
--- a/tests/unit/array/mixins/test_del.py
+++ b/tests/unit/array/mixins/test_del.py
@@ -119,9 +119,10 @@ def test_del_da_attribute():
('elasticsearch', {'n_dim': 3, 'distance': 'l2_norm'}),
('sqlite', dict()),
('redis', {'n_dim': 3, 'distance': 'L2'}),
+ ('milvus', {'n_dim': 3, 'distance': 'L2'}),
],
)
-def test_del_subindex(storage, config):
+def test_del_subindex(storage, config, start_storage):
n_dim = 3
subindex_configs = (
diff --git a/tests/unit/array/mixins/test_empty.py b/tests/unit/array/mixins/test_empty.py
index 0ba3da06e93..c92937cf80d 100644
--- a/tests/unit/array/mixins/test_empty.py
+++ b/tests/unit/array/mixins/test_empty.py
@@ -9,6 +9,7 @@
from docarray.array.weaviate import DocumentArrayWeaviate
from docarray.array.elastic import DocumentArrayElastic, ElasticConfig
from docarray.array.redis import DocumentArrayRedis, RedisConfig
+from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
@pytest.mark.parametrize(
@@ -21,6 +22,7 @@
(DocumentArrayQdrant, QdrantConfig(n_dim=5)),
(DocumentArrayElastic, ElasticConfig(n_dim=5)),
(DocumentArrayRedis, RedisConfig(n_dim=5)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=5)),
],
)
def test_empty_non_zero(da_cls, config, start_storage):
diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py
index 0c3dd165923..47a66b44d76 100644
--- a/tests/unit/array/mixins/test_find.py
+++ b/tests/unit/array/mixins/test_find.py
@@ -34,6 +34,7 @@ def inv_cosine(*args):
('qdrant', {'n_dim': 32}),
('elasticsearch', {'n_dim': 32}),
('redis', {'n_dim': 32}),
+ ('milvus', {'n_dim': 32}),
],
)
@pytest.mark.parametrize('limit', [1, 5, 10])
@@ -340,6 +341,16 @@ def test_find_by_tag(storage, config, start_storage):
}
+numeric_operators_milvus = {
+ '>=': operator.ge,
+ '>': operator.gt,
+ '<=': operator.le,
+ '<': operator.lt,
+ '==': operator.eq,
+ '!=': operator.ne,
+}
+
+
@pytest.mark.parametrize(
'storage,filter_gen,numeric_operators,operator',
[
@@ -457,6 +468,15 @@ def test_find_by_tag(storage, config, start_storage):
'ne',
),
],
+ *[
+ (
+ 'milvus',
+ lambda operator, threshold: f'price {operator} {threshold}',
+ numeric_operators_milvus,
+ operator,
+ )
+ for operator in numeric_operators_milvus.keys()
+ ],
],
)
@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}])
@@ -584,6 +604,15 @@ def test_search_pre_filtering(
'ne',
),
],
+ *[
+ (
+ 'milvus',
+ lambda operator, threshold: f'price {operator} {threshold}',
+ numeric_operators_milvus,
+ operator,
+ )
+ for operator in numeric_operators_milvus.keys()
+ ],
],
)
@pytest.mark.parametrize('columns', [[('price', 'float')], {'price': 'float'}])
@@ -823,14 +852,22 @@ def test_elastic_id_filter(storage, config, limit):
('elasticsearch', {'n_dim': 3, 'distance': 'l2_norm'}),
('sqlite', dict()),
('redis', {'n_dim': 3, 'distance': 'L2'}),
+ ('milvus', {'n_dim': 3, 'distance': 'L2'}),
],
)
-def test_find_subindex(storage, config):
+def test_find_subindex(storage, config, start_storage):
n_dim = 3
subindex_configs = {'@c': None}
if storage == 'sqlite':
subindex_configs['@c'] = dict()
- elif storage in ['weaviate', 'annlite', 'qdrant', 'elasticsearch', 'redis']:
+ elif storage in [
+ 'weaviate',
+ 'annlite',
+ 'qdrant',
+ 'elasticsearch',
+ 'redis',
+ 'milvus',
+ ]:
subindex_configs['@c'] = {'n_dim': 2}
da = DocumentArray(
@@ -878,9 +915,10 @@ def test_find_subindex(storage, config):
('elasticsearch', {'n_dim': 3, 'distance': 'l2_norm'}),
('sqlite', dict()),
('redis', {'n_dim': 3, 'distance': 'L2'}),
+ ('milvus', {'n_dim': 3, 'distance': 'L2'}),
],
)
-def test_find_subindex_multimodal(storage, config):
+def test_find_subindex_multimodal(storage, config, start_storage):
from docarray import dataclass
from docarray.typing import Text
diff --git a/tests/unit/array/mixins/test_io.py b/tests/unit/array/mixins/test_io.py
index 153237deb4f..f8a9bfc32d2 100644
--- a/tests/unit/array/mixins/test_io.py
+++ b/tests/unit/array/mixins/test_io.py
@@ -14,6 +14,7 @@
from docarray.array.weaviate import DocumentArrayWeaviate
from docarray.array.elastic import DocumentArrayElastic, ElasticConfig
from docarray.array.redis import DocumentArrayRedis, RedisConfig
+from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
from docarray.helper import random_identity
from tests import random_docs
@@ -36,6 +37,7 @@ def docs():
(DocumentArrayQdrant, lambda: QdrantConfig(n_dim=10)),
(DocumentArrayElastic, lambda: ElasticConfig(n_dim=10)),
(DocumentArrayRedis, lambda: RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, lambda: MilvusConfig(n_dim=10)),
],
)
def test_document_save_load(
@@ -79,6 +81,7 @@ def test_document_save_load(
(DocumentArrayQdrant, lambda: QdrantConfig(n_dim=10)),
(DocumentArrayElastic, lambda: ElasticConfig(n_dim=10)),
(DocumentArrayRedis, lambda: RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, lambda: MilvusConfig(n_dim=10)),
],
)
def test_da_csv_write(docs, flatten_tags, tmp_path, da_cls, config, start_storage):
@@ -99,6 +102,7 @@ def test_da_csv_write(docs, flatten_tags, tmp_path, da_cls, config, start_storag
(DocumentArrayQdrant, lambda: QdrantConfig(n_dim=256)),
(DocumentArrayElastic, lambda: ElasticConfig(n_dim=256)),
(DocumentArrayRedis, lambda: RedisConfig(n_dim=256)),
+ (DocumentArrayMilvus, lambda: MilvusConfig(n_dim=256)),
],
)
def test_from_ndarray(da_cls, config, start_storage):
@@ -117,6 +121,7 @@ def test_from_ndarray(da_cls, config, start_storage):
(DocumentArrayQdrant, lambda: QdrantConfig(n_dim=256)),
(DocumentArrayElastic, lambda: ElasticConfig(n_dim=256)),
(DocumentArrayRedis, lambda: RedisConfig(n_dim=256)),
+ (DocumentArrayMilvus, lambda: MilvusConfig(n_dim=256)),
],
)
def test_from_files(da_cls, config, start_storage):
@@ -158,6 +163,7 @@ def test_from_files_exclude():
(DocumentArrayQdrant, lambda: QdrantConfig(n_dim=256)),
(DocumentArrayElastic, lambda: ElasticConfig(n_dim=256)),
(DocumentArrayRedis, lambda: RedisConfig(n_dim=256)),
+ (DocumentArrayMilvus, lambda: MilvusConfig(n_dim=256)),
],
)
def test_from_ndjson(da_cls, config, start_storage):
@@ -176,6 +182,7 @@ def test_from_ndjson(da_cls, config, start_storage):
(DocumentArrayQdrant, lambda: QdrantConfig(n_dim=3)),
(DocumentArrayElastic, lambda: ElasticConfig(n_dim=3)),
(DocumentArrayRedis, lambda: RedisConfig(n_dim=3)),
+ (DocumentArrayMilvus, lambda: MilvusConfig(n_dim=3)),
],
)
def test_from_to_pd_dataframe(da_cls, config, start_storage):
@@ -205,6 +212,7 @@ def test_from_to_pd_dataframe(da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=3)),
(DocumentArrayElastic, ElasticConfig(n_dim=3)),
(DocumentArrayRedis, RedisConfig(n_dim=3)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=3)),
],
)
def test_from_to_bytes(da_cls, config, start_storage):
@@ -231,7 +239,7 @@ def test_from_to_bytes(da_cls, config, start_storage):
assert da2.tensors == [[1, 2], [2, 1]]
import numpy as np
- np.testing.assert_array_equal(da2.embeddings, [[1, 2, 3], [4, 5, 6]])
+ np.testing.assert_array_equal(da2[:, 'embedding'], [[1, 2, 3], [4, 5, 6]])
# assert da2.embeddings == [[1, 2, 3], [4, 5, 6]]
assert da2[0].tags == {'hello': 'world'}
assert da2[1].tags == {}
@@ -248,6 +256,7 @@ def test_from_to_bytes(da_cls, config, start_storage):
(DocumentArrayQdrant, lambda: QdrantConfig(n_dim=256)),
(DocumentArrayElastic, lambda: ElasticConfig(n_dim=256)),
(DocumentArrayRedis, lambda: RedisConfig(n_dim=256)),
+ (DocumentArrayMilvus, lambda: MilvusConfig(n_dim=256)),
],
)
def test_push_pull_io(da_cls, config, show_progress, start_storage):
@@ -264,7 +273,7 @@ def test_push_pull_io(da_cls, config, show_progress, start_storage):
da2 = da_cls.pull(name, show_progress=show_progress, config=config())
assert len(da1) == len(da2) == 10
- assert da1.texts == da2.texts == random_texts
+ assert da1[:, 'text'] == da2[:, 'text'] == random_texts
all_names = DocumentArray.cloud_list()
@@ -291,6 +300,7 @@ def test_push_pull_io(da_cls, config, show_progress, start_storage):
# (DocumentArrayQdrant, QdrantConfig(n_dim=3)),
# (DocumentArrayElastic, ElasticConfig(n_dim=3)), # Elastic needs config
# (DocumentArrayRedis, RedisConfig(n_dim=3)), # Redis needs config
+ # (DocumentArrayMilvus, lambda: MilvusConfig(n_dim=3)),
],
)
def test_from_to_base64(protocol, compress, da_cls, config):
diff --git a/tests/unit/array/mixins/test_magic.py b/tests/unit/array/mixins/test_magic.py
index 104c3139b27..bd8f813813e 100644
--- a/tests/unit/array/mixins/test_magic.py
+++ b/tests/unit/array/mixins/test_magic.py
@@ -9,6 +9,7 @@
from docarray.array.weaviate import DocumentArrayWeaviate
from docarray.array.elastic import DocumentArrayElastic, ElasticConfig
from docarray.array.redis import DocumentArrayRedis, RedisConfig
+from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
N = 100
@@ -34,6 +35,7 @@ def docs():
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=1)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_iter_len_bool(da_cls, config, start_storage):
@@ -61,6 +63,7 @@ def test_iter_len_bool(da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_repr(da_cls, config, start_storage):
@@ -81,6 +84,7 @@ def test_repr(da_cls, config, start_storage):
('qdrant', QdrantConfig(n_dim=128)),
('elasticsearch', ElasticConfig(n_dim=128)),
('redis', RedisConfig(n_dim=128)),
+ ('milvus', MilvusConfig(n_dim=128)),
],
)
def test_repr_str(docs, storage, config, start_storage):
@@ -105,6 +109,7 @@ def test_repr_str(docs, storage, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_iadd(da_cls, config, start_storage):
diff --git a/tests/unit/array/mixins/test_parallel.py b/tests/unit/array/mixins/test_parallel.py
index 22ce0a78e3a..f919e0af70e 100644
--- a/tests/unit/array/mixins/test_parallel.py
+++ b/tests/unit/array/mixins/test_parallel.py
@@ -13,6 +13,7 @@
from docarray.array.weaviate import DocumentArrayWeaviate
from docarray.array.elastic import DocumentArrayElastic, ElasticConfig
from docarray.array.redis import DocumentArrayRedis, RedisConfig
+from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
def foo(d: Document):
@@ -54,6 +55,7 @@ def test_parallel_map_apply_external_pool(pytestconfig, pool):
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
@pytest.mark.parametrize('backend', ['process', 'thread'])
@@ -111,6 +113,7 @@ def test_parallel_map(
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
@pytest.mark.parametrize('backend', ['thread'])
@@ -183,6 +186,7 @@ def test_parallel_map_batch(
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_map_lambda(pytestconfig, da_cls, config, start_storage):
@@ -212,6 +216,7 @@ def test_map_lambda(pytestconfig, da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=10)),
(DocumentArrayElastic, ElasticConfig(n_dim=10)),
(DocumentArrayRedis, RedisConfig(n_dim=10)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=10)),
],
)
def test_apply_partial(pytestconfig, da_cls, config, start_storage):
@@ -242,6 +247,7 @@ def test_apply_partial(pytestconfig, da_cls, config, start_storage):
('qdrant', QdrantConfig(n_dim=256)),
('elasticsearch', ElasticConfig(n_dim=256)),
('redis', RedisConfig(n_dim=256)),
+ ('milvus', MilvusConfig(n_dim=256)),
],
)
@pytest.mark.parametrize('backend', ['thread', 'process'])
diff --git a/tests/unit/array/mixins/test_plot.py b/tests/unit/array/mixins/test_plot.py
index 818e8f7f9d5..366c33bb0b3 100644
--- a/tests/unit/array/mixins/test_plot.py
+++ b/tests/unit/array/mixins/test_plot.py
@@ -15,6 +15,7 @@
from docarray.array.storage.annlite import AnnliteConfig
from docarray.array.elastic import DocumentArrayElastic, ElasticConfig
from docarray.array.redis import DocumentArrayRedis, RedisConfig
+from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
@pytest.mark.parametrize('keep_aspect_ratio', [True, False])
@@ -29,6 +30,7 @@
(DocumentArrayQdrant, QdrantConfig(n_dim=128, scroll_batch_size=8)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ # (DocumentArrayMilvus, MilvusConfig(n_dim=128)), # tensor is too large to handle
],
)
def test_sprite_fail_tensor_success_uri(
@@ -68,6 +70,7 @@ def test_sprite_fail_tensor_success_uri(
(DocumentArrayQdrant, lambda: QdrantConfig(n_dim=128, scroll_batch_size=8)),
(DocumentArrayElastic, lambda: ElasticConfig(n_dim=128)),
(DocumentArrayRedis, lambda: RedisConfig(n_dim=128)),
+ # (DocumentArrayMilvus, lambda: MilvusConfig(n_dim=128)),
],
)
@pytest.mark.parametrize('canvas_size', [50, 512])
@@ -118,6 +121,7 @@ def da_and_dam(start_storage):
(DocumentArrayAnnlite, {'config': {'n_dim': 3}}),
(DocumentArrayQdrant, {'config': {'n_dim': 3}}),
(DocumentArrayRedis, {'config': {'n_dim': 3}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 3}}),
]
]
@@ -135,7 +139,8 @@ def test_plot_sprites(tmpdir):
def _test_plot_embeddings(da):
- p = da.plot_embeddings(start_server=False)
+ with da:
+ p = da.plot_embeddings(start_server=False)
assert os.path.exists(p)
assert os.path.exists(os.path.join(p, 'config.json'))
with open(os.path.join(p, 'config.json')) as fp:
@@ -154,6 +159,7 @@ def _test_plot_embeddings(da):
(DocumentArrayQdrant, lambda: QdrantConfig(n_dim=5)),
(DocumentArrayElastic, lambda: ElasticConfig(n_dim=5)),
(DocumentArrayRedis, lambda: RedisConfig(n_dim=5)),
+ (DocumentArrayMilvus, lambda: MilvusConfig(n_dim=5)),
],
)
def test_plot_embeddings_same_path(tmpdir, da_cls, config_gen, start_storage):
@@ -163,10 +169,12 @@ def test_plot_embeddings_same_path(tmpdir, da_cls, config_gen, start_storage):
else:
da1 = da_cls.empty(100)
da2 = da_cls.empty(768)
- da1.embeddings = np.random.random([100, 5])
- p1 = da1.plot_embeddings(start_server=False, path=tmpdir)
- da2.embeddings = np.random.random([768, 5])
- p2 = da2.plot_embeddings(start_server=False, path=tmpdir)
+ with da1:
+ da1.embeddings = np.random.random([100, 5])
+ p1 = da1.plot_embeddings(start_server=False, path=tmpdir)
+ with da2:
+ da2.embeddings = np.random.random([768, 5])
+ p2 = da2.plot_embeddings(start_server=False, path=tmpdir)
assert p1 == p2
assert os.path.exists(p1)
with open(os.path.join(p1, 'config.json')) as fp:
@@ -184,6 +192,7 @@ def test_plot_embeddings_same_path(tmpdir, da_cls, config_gen, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_summary_homo_hetero(da_cls, config, start_storage):
@@ -191,14 +200,16 @@ def test_summary_homo_hetero(da_cls, config, start_storage):
da = da_cls.empty(100, config=config)
else:
da = da_cls.empty(100)
- da._get_attributes()
- da.summary()
- da._get_raw_summary()
+ with da:
+ da._get_attributes()
+ da.summary()
+ da._get_raw_summary()
da[0].pop('id')
- da.summary()
+ with da:
+ da.summary()
- da._get_raw_summary()
+ da._get_raw_summary()
@pytest.mark.parametrize(
@@ -211,6 +222,7 @@ def test_summary_homo_hetero(da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_empty_get_attributes(da_cls, config, start_storage):
diff --git a/tests/unit/array/mixins/test_sample.py b/tests/unit/array/mixins/test_sample.py
index b4d1b6b4d17..fa75064fdb7 100644
--- a/tests/unit/array/mixins/test_sample.py
+++ b/tests/unit/array/mixins/test_sample.py
@@ -9,6 +9,7 @@
from docarray.array.weaviate import DocumentArrayWeaviate
from docarray.array.elastic import DocumentArrayElastic, ElasticConfig
from docarray.array.redis import DocumentArrayRedis, RedisConfig
+from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
@pytest.mark.parametrize(
@@ -21,6 +22,7 @@
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_sample(da_cls, config, start_storage):
@@ -47,6 +49,7 @@ def test_sample(da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_sample_with_seed(da_cls, config, start_storage):
@@ -72,6 +75,7 @@ def test_sample_with_seed(da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_shuffle(da_cls, config, start_storage):
@@ -98,6 +102,7 @@ def test_shuffle(da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_shuffle_with_seed(da_cls, config, start_storage):
diff --git a/tests/unit/array/mixins/test_text.py b/tests/unit/array/mixins/test_text.py
index 0f7481a7e0d..9d5e42ac2b3 100644
--- a/tests/unit/array/mixins/test_text.py
+++ b/tests/unit/array/mixins/test_text.py
@@ -10,6 +10,7 @@
from docarray.array.weaviate import DocumentArrayWeaviate
from docarray.array.elastic import DocumentArrayElastic, ElasticConfig
from docarray.array.redis import DocumentArrayRedis, RedisConfig
+from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
@pytest.fixture(scope='function')
@@ -32,6 +33,7 @@ def docs():
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_da_vocabulary(da_cls, config, docs, min_freq, start_storage):
@@ -61,6 +63,7 @@ def test_da_vocabulary(da_cls, config, docs, min_freq, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_da_text_to_tensor_non_max_len(docs, da_cls, config, start_storage):
@@ -90,6 +93,7 @@ def test_da_text_to_tensor_non_max_len(docs, da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_da_text_to_tensor_max_len_3(docs, da_cls, config, start_storage):
@@ -121,6 +125,7 @@ def test_da_text_to_tensor_max_len_3(docs, da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_da_text_to_tensor_max_len_1(docs, da_cls, config, start_storage):
@@ -152,6 +157,7 @@ def test_da_text_to_tensor_max_len_1(docs, da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_convert_text_tensor_random_text(da_cls, docs, config, start_storage):
diff --git a/tests/unit/array/mixins/test_traverse.py b/tests/unit/array/mixins/test_traverse.py
index 9dad5475bcc..c68c000e02c 100644
--- a/tests/unit/array/mixins/test_traverse.py
+++ b/tests/unit/array/mixins/test_traverse.py
@@ -10,6 +10,7 @@
from docarray.array.annlite import DocumentArrayAnnlite
from docarray.array.elastic import DocumentArrayElastic
from docarray.array.redis import DocumentArrayRedis
+from docarray.array.milvus import DocumentArrayMilvus
from tests import random_docs
# some random prime number for sanity check
@@ -44,11 +45,13 @@ def doc_req():
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_type(doc_req, filter_fn, da_cls, kwargs, start_storage):
doc_req = da_cls(doc_req, **kwargs)
- ds = doc_req.traverse('r', filter_fn=filter_fn)
+ with doc_req: # speed up milvus by loading collection
+ ds = doc_req.traverse('r', filter_fn=filter_fn)
assert isinstance(ds, types.GeneratorType)
assert isinstance(list(ds)[0], DocumentArray)
@@ -64,11 +67,13 @@ def test_traverse_type(doc_req, filter_fn, da_cls, kwargs, start_storage):
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_root(doc_req, filter_fn, da_cls, kwargs, start_storage):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse('r', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse('r', filter_fn=filter_fn))
assert len(ds) == 1
assert len(ds[0]) == num_docs
@@ -84,11 +89,13 @@ def test_traverse_root(doc_req, filter_fn, da_cls, kwargs, start_storage):
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_chunk(doc_req, filter_fn, da_cls, kwargs, start_storage):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse('c', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse('c', filter_fn=filter_fn))
assert len(ds) == num_docs
assert len(ds[0]) == num_chunks_per_doc
@@ -104,11 +111,13 @@ def test_traverse_chunk(doc_req, filter_fn, da_cls, kwargs, start_storage):
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_root_plus_chunk(doc_req, filter_fn, da_cls, kwargs, start_storage):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse('c,r', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse('c,r', filter_fn=filter_fn))
assert len(ds) == num_docs + 1
assert len(ds[0]) == num_chunks_per_doc
assert len(ds[-1]) == num_docs
@@ -125,11 +134,13 @@ def test_traverse_root_plus_chunk(doc_req, filter_fn, da_cls, kwargs, start_stor
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_chunk_plus_root(doc_req, filter_fn, da_cls, kwargs, start_storage):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse('r,c', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse('r,c', filter_fn=filter_fn))
assert len(ds) == 1 + num_docs
assert len(ds[-1]) == num_chunks_per_doc
assert len(ds[0]) == num_docs
@@ -146,11 +157,13 @@ def test_traverse_chunk_plus_root(doc_req, filter_fn, da_cls, kwargs, start_stor
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_match(doc_req, filter_fn, da_cls, kwargs, start_storage):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse('m', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse('m', filter_fn=filter_fn))
assert len(ds) == num_docs
assert len(ds[0]) == num_matches_per_doc
@@ -166,11 +179,13 @@ def test_traverse_match(doc_req, filter_fn, da_cls, kwargs, start_storage):
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_match_chunk(doc_req, filter_fn, da_cls, kwargs, start_storage):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse('cm', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse('cm', filter_fn=filter_fn))
assert len(ds) == num_docs * num_chunks_per_doc
assert len(ds[0]) == num_matches_per_chunk
@@ -186,11 +201,13 @@ def test_traverse_match_chunk(doc_req, filter_fn, da_cls, kwargs, start_storage)
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_root_match_chunk(doc_req, filter_fn, da_cls, kwargs, start_storage):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse('r,c,m,cm', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse('r,c,m,cm', filter_fn=filter_fn))
assert len(ds) == 1 + num_docs + num_docs + num_docs * num_chunks_per_doc
@@ -205,11 +222,13 @@ def test_traverse_root_match_chunk(doc_req, filter_fn, da_cls, kwargs, start_sto
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flatten_embedding(doc_req, filter_fn, da_cls, kwargs, start_storage):
doc_req = da_cls(doc_req, **kwargs)
- flattened_results = doc_req.traverse_flat('r,c', filter_fn=filter_fn)
+ with doc_req: # speed up milvus by loading collection
+ flattened_results = doc_req.traverse_flat('r,c', filter_fn=filter_fn)
ds = flattened_results.embeddings
assert ds.shape == (num_docs + num_chunks_per_doc * num_docs, 10)
@@ -225,11 +244,13 @@ def test_traverse_flatten_embedding(doc_req, filter_fn, da_cls, kwargs, start_st
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flatten_root(doc_req, filter_fn, da_cls, kwargs, start_storage):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse_flat('r', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse_flat('r', filter_fn=filter_fn))
assert len(ds) == num_docs
@@ -244,11 +265,13 @@ def test_traverse_flatten_root(doc_req, filter_fn, da_cls, kwargs, start_storage
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flatten_chunk(doc_req, filter_fn, da_cls, kwargs, start_storage):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse_flat('c', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse_flat('c', filter_fn=filter_fn))
assert len(ds) == num_docs * num_chunks_per_doc
@@ -263,13 +286,15 @@ def test_traverse_flatten_chunk(doc_req, filter_fn, da_cls, kwargs, start_storag
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flatten_root_plus_chunk(
doc_req, filter_fn, da_cls, kwargs, start_storage
):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse_flat('c,r', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse_flat('c,r', filter_fn=filter_fn))
assert len(ds) == num_docs + num_docs * num_chunks_per_doc
@@ -284,11 +309,13 @@ def test_traverse_flatten_root_plus_chunk(
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flatten_match(doc_req, filter_fn, da_cls, kwargs, start_storage):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse_flat('m', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse_flat('m', filter_fn=filter_fn))
assert len(ds) == num_docs * num_matches_per_doc
@@ -303,13 +330,15 @@ def test_traverse_flatten_match(doc_req, filter_fn, da_cls, kwargs, start_storag
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flatten_match_chunk(
doc_req, filter_fn, da_cls, kwargs, start_storage
):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse_flat('cm', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse_flat('cm', filter_fn=filter_fn))
assert len(ds) == num_docs * num_chunks_per_doc * num_matches_per_chunk
@@ -324,13 +353,15 @@ def test_traverse_flatten_match_chunk(
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flatten_root_match_chunk(
doc_req, filter_fn, da_cls, kwargs, start_storage
):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse_flat('r,c,m,cm', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse_flat('r,c,m,cm', filter_fn=filter_fn))
assert (
len(ds)
== num_docs
@@ -351,13 +382,17 @@ def test_traverse_flatten_root_match_chunk(
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flattened_per_path_embedding(
doc_req, filter_fn, da_cls, kwargs, start_storage
):
doc_req = da_cls(doc_req, **kwargs)
- flattened_results = list(doc_req.traverse_flat_per_path('r,c', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ flattened_results = list(
+ doc_req.traverse_flat_per_path('r,c', filter_fn=filter_fn)
+ )
ds = flattened_results[0].embeddings
assert ds.shape == (num_docs, 10)
@@ -376,13 +411,15 @@ def test_traverse_flattened_per_path_embedding(
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flattened_per_path_root(
doc_req, filter_fn, da_cls, kwargs, start_storage
):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse_flat_per_path('r', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse_flat_per_path('r', filter_fn=filter_fn))
assert len(ds[0]) == num_docs
@@ -397,13 +434,15 @@ def test_traverse_flattened_per_path_root(
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flattened_per_path_chunk(
doc_req, filter_fn, da_cls, kwargs, start_storage
):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse_flat_per_path('c', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse_flat_per_path('c', filter_fn=filter_fn))
assert len(ds[0]) == num_docs * num_chunks_per_doc
@@ -418,13 +457,15 @@ def test_traverse_flattened_per_path_chunk(
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flattened_per_path_root_plus_chunk(
doc_req, filter_fn, da_cls, kwargs, start_storage
):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse_flat_per_path('c,r', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse_flat_per_path('c,r', filter_fn=filter_fn))
assert len(ds[0]) == num_docs * num_chunks_per_doc
assert len(ds[1]) == num_docs
@@ -440,13 +481,15 @@ def test_traverse_flattened_per_path_root_plus_chunk(
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flattened_per_path_match(
doc_req, filter_fn, da_cls, kwargs, start_storage
):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse_flat_per_path('m', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse_flat_per_path('m', filter_fn=filter_fn))
assert len(ds[0]) == num_docs * num_matches_per_doc
@@ -461,13 +504,15 @@ def test_traverse_flattened_per_path_match(
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flattened_per_path_root_match_chunk(
doc_req, filter_fn, da_cls, kwargs, start_storage
):
doc_req = da_cls(doc_req, **kwargs)
- ds = list(doc_req.traverse_flat_per_path('r,c,m,cm', filter_fn=filter_fn))
+ with doc_req: # speed up milvus by loading collection
+ ds = list(doc_req.traverse_flat_per_path('r,c,m,cm', filter_fn=filter_fn))
assert len(ds[0]) == num_docs
assert len(ds[1]) == num_chunks_per_doc * num_docs
assert len(ds[2]) == num_matches_per_doc * num_docs
@@ -485,18 +530,20 @@ def test_traverse_flattened_per_path_root_match_chunk(
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_docuset_traverse_over_iterator_HACKY(da_cls, kwargs, filter_fn):
# HACKY USAGE DO NOT RECOMMEND: can also traverse over "runtime"-documentarray
da = da_cls(random_docs(num_docs, num_chunks_per_doc), **kwargs)
-
- ds = da.traverse('r', filter_fn=filter_fn)
+ with da: # speed up milvus by loading collection
+ ds = da.traverse('r', filter_fn=filter_fn)
assert len(list(list(ds)[0])) == num_docs
- ds = da_cls(random_docs(num_docs, num_chunks_per_doc), **kwargs).traverse(
- 'c', filter_fn=filter_fn
- )
+ ds = da_cls(random_docs(num_docs, num_chunks_per_doc), **kwargs)
+
+ with ds: # speed up milvus by loading collection
+ ds = ds.traverse('c', filter_fn=filter_fn)
ds = list(ds)
assert len(ds) == num_docs
assert len(ds[0]) == num_chunks_per_doc
@@ -513,20 +560,21 @@ def test_docuset_traverse_over_iterator_HACKY(da_cls, kwargs, filter_fn):
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_docuset_traverse_over_iterator_CAVEAT(da_cls, kwargs, filter_fn):
# HACKY USAGE's CAVEAT: but it can not iterate over an iterator twice
- ds = da_cls(random_docs(num_docs, num_chunks_per_doc), **kwargs).traverse(
- 'r,c', filter_fn=filter_fn
- )
+ ds = da_cls(random_docs(num_docs, num_chunks_per_doc), **kwargs)
+ with ds:
+ ds = ds.traverse('r,c', filter_fn=filter_fn)
# note that random_docs is a generator and can be only used once,
# therefore whoever comes first wil get iterated, and then it becomes empty
assert len(list(ds)) == 1 + num_docs
- ds = da_cls(random_docs(num_docs, num_chunks_per_doc), **kwargs).traverse(
- 'c,r', filter_fn=filter_fn
- )
+ ds = da_cls(random_docs(num_docs, num_chunks_per_doc), **kwargs)
+ with ds:
+ ds = ds.traverse('c,r', filter_fn=filter_fn)
assert len(list(ds)) == num_docs + 1
@@ -580,6 +628,7 @@ def test_traverse_chunkarray(filter_fn):
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
@pytest.mark.parametrize(
@@ -611,7 +660,8 @@ def test_filter_fn_traverse_flat(
filter_fn, docs_len, doc_req, da_cls, kwargs, tmp_path
):
docs = da_cls(doc_req, **kwargs)
- ds = list(docs.traverse_flat('r,c,m,cm', filter_fn=filter_fn))
+ with docs:
+ ds = list(docs.traverse_flat('r,c,m,cm', filter_fn=filter_fn))
assert len(ds) == docs_len
assert all(isinstance(d, Document) for d in ds)
@@ -626,6 +676,7 @@ def test_filter_fn_traverse_flat(
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
@pytest.mark.parametrize(
@@ -661,7 +712,8 @@ def test_filter_fn_traverse_flat_per_path(
filter_fn, doc_req, docs_len, da_cls, kwargs, tmp_path
):
docs = da_cls(doc_req, **kwargs)
- ds = list(docs.traverse_flat_per_path('r,c,m,cm', filter_fn=filter_fn))
+ with docs:
+ ds = list(docs.traverse_flat_per_path('r,c,m,cm', filter_fn=filter_fn))
assert len(ds) == 4
for seq, length in zip(ds, docs_len):
assert isinstance(seq, DocumentArray)
@@ -678,13 +730,15 @@ def test_filter_fn_traverse_flat_per_path(
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traversal_path(da_cls, kwargs):
da = da_cls([Document() for _ in range(6)], **kwargs)
assert len(da) == 6
- da.traverse_flat('r')
+ with da:
+ da.traverse_flat('r')
@pytest.mark.parametrize(
@@ -697,11 +751,13 @@ def test_traversal_path(da_cls, kwargs):
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_traverse_flat_root_itself(da_cls, kwargs):
da = da_cls([Document() for _ in range(100)], **kwargs)
- res = da.traverse_flat('r')
+ with da:
+ res = da.traverse_flat('r')
assert id(res) == id(da)
@@ -720,11 +776,13 @@ def da_and_dam(N):
(DocumentArrayQdrant, {'config': {'n_dim': 10}}),
(DocumentArrayElastic, {'config': {'n_dim': 10}}),
(DocumentArrayRedis, {'config': {'n_dim': 10}}),
+ (DocumentArrayMilvus, {'config': {'n_dim': 10}}),
],
)
def test_flatten(da_cls, kwargs):
da = da_cls(random_docs(100), **kwargs)
- daf = da.flatten()
+ with da:
+ daf = da.flatten()
assert len(daf) == 600
assert isinstance(daf, DocumentArray)
assert len(set(d.id for d in daf)) == 600
diff --git a/tests/unit/array/storage/milvus/__init__.py b/tests/unit/array/storage/milvus/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/unit/array/storage/milvus/test_milvus.py b/tests/unit/array/storage/milvus/test_milvus.py
new file mode 100644
index 00000000000..42ac17173ea
--- /dev/null
+++ b/tests/unit/array/storage/milvus/test_milvus.py
@@ -0,0 +1,169 @@
+import pytest
+from docarray import Document
+from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
+from pymilvus import loading_progress
+import numpy as np
+
+
+def _is_fully_loaded(da):
+ collections = da._collection, da._offset2id_collection
+ fully_loaded = True
+ for coll in collections:
+ coll_loaded = (
+ loading_progress(coll.name, using=da._connection_alias)['loading_progress']
+ == '100%'
+ )
+ fully_loaded = fully_loaded and coll_loaded
+ return fully_loaded
+
+
+def _is_fully_released(da):
+ collections = da._collection, da._offset2id_collection
+ fully_released = True
+ for coll in collections:
+ coll_released = (
+ loading_progress(coll.name, using=da._connection_alias)['loading_progress']
+ == '0%'
+ )
+ fully_released = fully_released and coll_released
+ return fully_released
+
+
+def test_memory_release(start_storage):
+ da = DocumentArrayMilvus(
+ config={
+ 'n_dim': 10,
+ },
+ )
+ da.extend([Document(embedding=np.random.random([10])) for _ in range(10)])
+ da.find(Document(embedding=np.random.random([10])))
+ assert _is_fully_released(da)
+
+
+def test_memory_cntxt_mngr(start_storage):
+ da = DocumentArrayMilvus(
+ config={
+ 'n_dim': 10,
+ },
+ )
+
+ # `with da` context manager
+ assert _is_fully_released(da)
+ with da:
+ assert _is_fully_loaded(da)
+ pass
+ assert _is_fully_released(da)
+
+ # `da.loaded_collection` context manager
+ with da.loaded_collection(), da.loaded_collection(da._offset2id_collection):
+ assert _is_fully_loaded(da)
+ pass
+ assert _is_fully_released(da)
+
+ # both combined
+ with da:
+ assert _is_fully_loaded(da)
+ with da.loaded_collection(), da.loaded_collection(da._offset2id_collection):
+ assert _is_fully_loaded(da)
+ pass
+ assert _is_fully_loaded(da)
+ assert _is_fully_released(da)
+
+
+@pytest.fixture()
+def mock_response():
+ class MockHit:
+ @property
+ def entity(self):
+ return {'serialized': Document().to_base64()}
+
+ return [[MockHit()]]
+
+
+@pytest.mark.parametrize(
+ 'method,meth_input',
+ [
+ ('append', [Document(embedding=np.random.random([10]))]),
+ ('extend', [[Document(embedding=np.random.random([10]))]]),
+ ('find', [Document(embedding=np.random.random([10]))]),
+ ('insert', [0, Document(embedding=np.random.random([10]))]),
+ ],
+)
+def test_consistency_level(start_storage, mocker, method, meth_input, mock_response):
+ init_consistency = 'Session'
+ da = DocumentArrayMilvus(
+ config={
+ 'n_dim': 10,
+ 'consistency_level': init_consistency,
+ },
+ )
+
+ # patch Milvus collection
+ patch_methods = ['insert', 'search', 'delete', 'query']
+ for m in patch_methods:
+ setattr(da._collection, m, mocker.Mock(return_value=mock_response))
+
+ # test consistency level set in config
+ getattr(da, method)(*meth_input)
+ for m in patch_methods:
+ mock_meth = getattr(da._collection, m)
+ for args, kwargs in mock_meth.call_args_list:
+ if 'consistency_level' in kwargs:
+ assert kwargs['consistency_level'] == init_consistency
+
+ # reset the mocks
+ for m in patch_methods:
+ setattr(da._collection, m, mocker.Mock(return_value=mock_response))
+
+ # test dynamic consistency level
+ new_consistency = 'Strong'
+ getattr(da, method)(*meth_input, consistency_level=new_consistency)
+ for m in patch_methods:
+ mock_meth = getattr(da._collection, m)
+ for args, kwargs in mock_meth.call_args_list:
+ if 'consistency_level' in kwargs:
+ assert kwargs['consistency_level'] == new_consistency
+
+
+@pytest.mark.parametrize(
+ 'method,meth_input',
+ [
+ ('append', [Document(embedding=np.random.random([10]))]),
+ ('extend', [[Document(embedding=np.random.random([10]))]]),
+ ('insert', [0, Document(embedding=np.random.random([10]))]),
+ ],
+)
+def test_batching(start_storage, mocker, method, meth_input, mock_response):
+ init_batch_size = 5
+ da = DocumentArrayMilvus(
+ config={
+ 'n_dim': 10,
+ 'batch_size': init_batch_size,
+ },
+ )
+
+ # patch Milvus collection
+ patch_methods = ['insert', 'search', 'delete', 'query']
+ for m in patch_methods:
+ setattr(da._collection, m, mocker.Mock(return_value=mock_response))
+
+ # test batch_size set in config
+ getattr(da, method)(*meth_input)
+ for m in patch_methods:
+ mock_meth = getattr(da._collection, m)
+ for args, kwargs in mock_meth.call_args_list:
+ if 'batch_size' in kwargs:
+ assert kwargs['batch_size'] == init_batch_size
+
+ # reset the mocks
+ for m in patch_methods:
+ setattr(da._collection, m, mocker.Mock(return_value=mock_response))
+
+    # test dynamic batch size
+ new_batch_size = 100
+ getattr(da, method)(*meth_input, batch_size=new_batch_size)
+ for m in patch_methods:
+ mock_meth = getattr(da._collection, m)
+ for args, kwargs in mock_meth.call_args_list:
+ if 'batch_size' in kwargs:
+ assert kwargs['batch_size'] == new_batch_size
diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py
index df8e005ddc0..ac27fa8bf6f 100644
--- a/tests/unit/array/test_advance_indexing.py
+++ b/tests/unit/array/test_advance_indexing.py
@@ -7,6 +7,7 @@
from docarray.array.qdrant import QdrantConfig
from docarray.array.elastic import ElasticConfig
from docarray.array.redis import RedisConfig
+from docarray.array.milvus import MilvusConfig
@pytest.fixture
@@ -30,6 +31,7 @@ def indices():
('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)),
('elasticsearch', ElasticConfig(n_dim=123)),
('redis', RedisConfig(n_dim=123)),
+ ('milvus', MilvusConfig(n_dim=123)),
],
)
def test_getter_int_str(docs, storage, config, start_storage):
@@ -64,6 +66,7 @@ def test_getter_int_str(docs, storage, config, start_storage):
('qdrant', QdrantConfig(n_dim=123)),
('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)),
('redis', RedisConfig(n_dim=123)),
+ ('milvus', MilvusConfig(n_dim=123)),
],
)
def test_setter_int_str(docs, storage, config, start_storage):
@@ -95,6 +98,7 @@ def test_setter_int_str(docs, storage, config, start_storage):
('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)),
('elasticsearch', ElasticConfig(n_dim=123)),
('redis', RedisConfig(n_dim=123)),
+ ('milvus', MilvusConfig(n_dim=123)),
],
)
def test_del_int_str(docs, storage, config, start_storage, indices):
@@ -131,6 +135,7 @@ def test_del_int_str(docs, storage, config, start_storage, indices):
('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)),
('elasticsearch', ElasticConfig(n_dim=123)),
('redis', RedisConfig(n_dim=123)),
+ ('milvus', MilvusConfig(n_dim=123)),
],
)
def test_slice(docs, storage, config, start_storage):
@@ -171,6 +176,7 @@ def test_slice(docs, storage, config, start_storage):
('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)),
('elasticsearch', ElasticConfig(n_dim=123)),
('redis', RedisConfig(n_dim=123)),
+ ('milvus', MilvusConfig(n_dim=123)),
],
)
def test_sequence_bool_index(docs, storage, config, start_storage):
@@ -219,6 +225,7 @@ def test_sequence_bool_index(docs, storage, config, start_storage):
('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)),
('elasticsearch', ElasticConfig(n_dim=123)),
('redis', RedisConfig(n_dim=123)),
+ ('milvus', MilvusConfig(n_dim=123)),
],
)
def test_sequence_int(docs, nparray, storage, config, start_storage):
@@ -257,6 +264,7 @@ def test_sequence_int(docs, nparray, storage, config, start_storage):
('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)),
('elasticsearch', ElasticConfig(n_dim=123)),
('redis', RedisConfig(n_dim=123)),
+ ('milvus', MilvusConfig(n_dim=123)),
],
)
def test_sequence_str(docs, storage, config, start_storage):
@@ -293,6 +301,7 @@ def test_sequence_str(docs, storage, config, start_storage):
('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)),
('elasticsearch', ElasticConfig(n_dim=123)),
('redis', RedisConfig(n_dim=123)),
+ ('milvus', MilvusConfig(n_dim=123)),
],
)
def test_docarray_list_tuple(docs, storage, config, start_storage):
@@ -315,6 +324,7 @@ def test_docarray_list_tuple(docs, storage, config, start_storage):
('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)),
('elasticsearch', ElasticConfig(n_dim=123)),
('redis', RedisConfig(n_dim=123)),
+ ('milvus', MilvusConfig(n_dim=123)),
],
)
def test_path_syntax_indexing(storage, config, start_storage):
@@ -330,19 +340,19 @@ def test_path_syntax_indexing(storage, config, start_storage):
da = DocumentArray(da, storage=storage, config=config)
else:
da = DocumentArray(da, storage=storage)
-
- assert len(da['@c']) == 3 * 5
- assert len(da['@c:1']) == 3
- assert len(da['@c-1:']) == 3
- assert len(da['@c1']) == 3
- assert len(da['@c-2:']) == 3 * 2
- assert len(da['@c1:3']) == 3 * 2
- assert len(da['@c1:3c']) == (3 * 2) * 3
- assert len(da['@c1:3,c1:3c']) == (3 * 2) + (3 * 2) * 3
- assert len(da['@c 1:3 , c 1:3 c']) == (3 * 2) + (3 * 2) * 3
- assert len(da['@cc']) == 3 * 5 * 3
- assert len(da['@cc,m']) == 3 * 5 * 3 + 3 * 7
- assert len(da['@r:1cc,m']) == 1 * 5 * 3 + 3 * 7
+ with da:
+ assert len(da['@c']) == 3 * 5
+ assert len(da['@c:1']) == 3
+ assert len(da['@c-1:']) == 3
+ assert len(da['@c1']) == 3
+ assert len(da['@c-2:']) == 3 * 2
+ assert len(da['@c1:3']) == 3 * 2
+ assert len(da['@c1:3c']) == (3 * 2) * 3
+ assert len(da['@c1:3,c1:3c']) == (3 * 2) + (3 * 2) * 3
+ assert len(da['@c 1:3 , c 1:3 c']) == (3 * 2) + (3 * 2) * 3
+ assert len(da['@cc']) == 3 * 5 * 3
+ assert len(da['@cc,m']) == 3 * 5 * 3 + 3 * 7
+ assert len(da['@r:1cc,m']) == 1 * 5 * 3 + 3 * 7
@pytest.mark.parametrize(
@@ -356,6 +366,7 @@ def test_path_syntax_indexing(storage, config, start_storage):
('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)),
('elasticsearch', ElasticConfig(n_dim=123)),
('redis', RedisConfig(n_dim=123)),
+ ('milvus', MilvusConfig(n_dim=123)),
],
)
@pytest.mark.parametrize('use_subindex', [False, True])
@@ -382,44 +393,48 @@ def test_path_syntax_indexing_set(storage, config, use_subindex, start_storage):
da, storage=storage, subindex_configs={'@c': None} if use_subindex else None
)
- assert da['@c'].texts == repeat('a', 3 * 5)
- assert da['@c', 'text'] == repeat('a', 3 * 5)
- if use_subindex:
- assert da._subindices['@c'].texts == repeat('a', 3 * 5)
- assert da['@c:1', 'text'] == repeat('a', 3)
- assert da['@c-1:', 'text'] == repeat('a', 3)
- assert da['@c1', 'text'] == repeat('a', 3)
- assert da['@c-2:', 'text'] == repeat('a', 3 * 2)
- assert da['@c1:3', 'text'] == repeat('a', 3 * 2)
- assert da['@c1:3c', 'text'] == repeat('a', (3 * 2) * 3)
- assert da['@c1:3,c1:3c', 'text'] == repeat('a', (3 * 2) + (3 * 2) * 3)
- assert da['@c 1:3 , c 1:3 c', 'text'] == repeat('a', (3 * 2) + (3 * 2) * 3)
- assert da['@cc', 'text'] == repeat('a', 3 * 5 * 3)
- assert da['@cc,m', 'text'] == repeat('a', 3 * 5 * 3 + 3 * 7)
- assert da['@r:1cc,m', 'text'] == repeat('a', 1 * 5 * 3 + 3 * 7)
- assert da[0, 'text'] == 'a'
- assert da[[True for _ in da], 'text'] == repeat('a', 3)
+ with da:
+ assert da['@c'].texts == repeat('a', 3 * 5)
+ assert da['@c', 'text'] == repeat('a', 3 * 5)
+ if use_subindex:
+ assert da._subindices['@c'].texts == repeat('a', 3 * 5)
+ assert da['@c:1', 'text'] == repeat('a', 3)
+ assert da['@c-1:', 'text'] == repeat('a', 3)
+ assert da['@c1', 'text'] == repeat('a', 3)
+ assert da['@c-2:', 'text'] == repeat('a', 3 * 2)
+ assert da['@c1:3', 'text'] == repeat('a', 3 * 2)
+ assert da['@c1:3c', 'text'] == repeat('a', (3 * 2) * 3)
+ assert da['@c1:3,c1:3c', 'text'] == repeat('a', (3 * 2) + (3 * 2) * 3)
+ assert da['@c 1:3 , c 1:3 c', 'text'] == repeat('a', (3 * 2) + (3 * 2) * 3)
+ assert da['@cc', 'text'] == repeat('a', 3 * 5 * 3)
+ assert da['@cc,m', 'text'] == repeat('a', 3 * 5 * 3 + 3 * 7)
+ assert da['@r:1cc,m', 'text'] == repeat('a', 1 * 5 * 3 + 3 * 7)
+ assert da[0, 'text'] == 'a'
+ assert da[[True for _ in da], 'text'] == repeat('a', 3)
da['@m,cc', 'text'] = repeat('b', 3 + 5 * 3 + 7 * 3 + 3 * 5 * 3)
- assert da['@c', 'text'] == repeat('a', 3 * 5)
- if use_subindex:
- assert da._subindices['@c'].texts == repeat('a', 3 * 5)
- assert da['@c:1', 'text'] == repeat('a', 3)
- assert da['@c-1:', 'text'] == repeat('a', 3)
- assert da['@c1', 'text'] == repeat('a', 3)
- assert da['@c-2:', 'text'] == repeat('a', 3 * 2)
- assert da['@c1:3', 'text'] == repeat('a', 3 * 2)
- assert da['@c1:3c', 'text'] == repeat('b', (3 * 2) * 3)
- assert da['@c1:3,c1:3c', 'text'] == repeat('a', (3 * 2)) + repeat('b', (3 * 2) * 3)
- assert da['@c 1:3 , c 1:3 c', 'text'] == repeat('a', (3 * 2)) + repeat(
- 'b', (3 * 2) * 3
- )
- assert da['@cc', 'text'] == repeat('b', 3 * 5 * 3)
- assert da['@cc,m', 'text'] == repeat('b', 3 * 5 * 3 + 3 * 7)
- assert da['@r:1cc,m', 'text'] == repeat('b', 1 * 5 * 3 + 3 * 7)
- assert da[0, 'text'] == 'a'
- assert da[[True for _ in da], 'text'] == repeat('a', 3)
+ with da:
+ assert da['@c', 'text'] == repeat('a', 3 * 5)
+ if use_subindex:
+ assert da._subindices['@c'].texts == repeat('a', 3 * 5)
+ assert da['@c:1', 'text'] == repeat('a', 3)
+ assert da['@c-1:', 'text'] == repeat('a', 3)
+ assert da['@c1', 'text'] == repeat('a', 3)
+ assert da['@c-2:', 'text'] == repeat('a', 3 * 2)
+ assert da['@c1:3', 'text'] == repeat('a', 3 * 2)
+ assert da['@c1:3c', 'text'] == repeat('b', (3 * 2) * 3)
+ assert da['@c1:3,c1:3c', 'text'] == repeat('a', (3 * 2)) + repeat(
+ 'b', (3 * 2) * 3
+ )
+ assert da['@c 1:3 , c 1:3 c', 'text'] == repeat('a', (3 * 2)) + repeat(
+ 'b', (3 * 2) * 3
+ )
+ assert da['@cc', 'text'] == repeat('b', 3 * 5 * 3)
+ assert da['@cc,m', 'text'] == repeat('b', 3 * 5 * 3 + 3 * 7)
+ assert da['@r:1cc,m', 'text'] == repeat('b', 1 * 5 * 3 + 3 * 7)
+ assert da[0, 'text'] == 'a'
+ assert da[[True for _ in da], 'text'] == repeat('a', 3)
da[1, 'text'] = 'd'
assert da[1, 'text'] == 'd'
@@ -431,12 +446,13 @@ def test_path_syntax_indexing_set(storage, config, use_subindex, start_storage):
assert da[doc_id].text == 'e'
# setting matches is only possible if the IDs are the same
- da['@m'] = [Document(id=f'm{i}', text='c') for i in range(3 * 7)]
- assert da['@m', 'text'] == repeat('c', 3 * 7)
+ with da:
+ da['@m'] = [Document(id=f'm{i}', text='c') for i in range(3 * 7)]
+ assert da['@m', 'text'] == repeat('c', 3 * 7)
- # setting by traversal paths with different IDs is not supported
- with pytest.raises(ValueError):
- da['@m'] = [Document() for _ in range(3 * 7)]
+ # setting by traversal paths with different IDs is not supported
+ with pytest.raises(ValueError):
+ da['@m'] = [Document() for _ in range(3 * 7)]
da[2, ['text', 'id']] = ['new_text', 'new_id']
assert da[2].text == 'new_text'
@@ -454,6 +470,7 @@ def test_path_syntax_indexing_set(storage, config, use_subindex, start_storage):
('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)),
('elasticsearch', ElasticConfig(n_dim=123)),
('redis', RedisConfig(n_dim=123)),
+ ('milvus', MilvusConfig(n_dim=123)),
],
)
def test_getset_subindex(storage, config, start_storage):
@@ -462,31 +479,34 @@ def test_getset_subindex(storage, config, start_storage):
config=config,
subindex_configs={'@c': {'n_dim': 123}} if config else {'@c': None},
)
- assert len(da['@c']) == 15
- assert len(da._subindices['@c']) == 15
- # set entire subindex
- chunks_ids = [c.id for c in da['@c']]
- new_chunks = [Document(id=cid, text=f'{i}') for i, cid in enumerate(chunks_ids)]
- da['@c'] = new_chunks
- new_chunks = DocumentArray(new_chunks)
- assert da['@c'] == new_chunks
- assert da._subindices['@c'] == new_chunks
- collected_chunks = DocumentArray.empty(0)
- for d in da:
- collected_chunks.extend(d.chunks)
- assert collected_chunks == new_chunks
- # set part of a subindex
- chunks_ids = [c.id for c in da['@c:3']]
- new_chunks = [Document(id=cid, text=f'{2*i}') for i, cid in enumerate(chunks_ids)]
- da['@c:3'] = new_chunks
- new_chunks = DocumentArray(new_chunks)
- assert da['@c:3'] == new_chunks
- for d in new_chunks:
- assert d in da._subindices['@c']
- collected_chunks = DocumentArray.empty(0)
- for d in da:
- collected_chunks.extend(d.chunks[:3])
- assert collected_chunks == new_chunks
+ with da:
+ assert len(da['@c']) == 15
+ assert len(da._subindices['@c']) == 15
+ # set entire subindex
+ chunks_ids = [c.id for c in da['@c']]
+ new_chunks = [Document(id=cid, text=f'{i}') for i, cid in enumerate(chunks_ids)]
+ da['@c'] = new_chunks
+ new_chunks = DocumentArray(new_chunks)
+ assert da['@c'] == new_chunks
+ assert da._subindices['@c'] == new_chunks
+ collected_chunks = DocumentArray.empty(0)
+ for d in da:
+ collected_chunks.extend(d.chunks)
+ assert collected_chunks == new_chunks
+ # set part of a subindex
+ chunks_ids = [c.id for c in da['@c:3']]
+ new_chunks = [
+ Document(id=cid, text=f'{2*i}') for i, cid in enumerate(chunks_ids)
+ ]
+ da['@c:3'] = new_chunks
+ new_chunks = DocumentArray(new_chunks)
+ assert da['@c:3'] == new_chunks
+ for d in new_chunks:
+ assert d in da._subindices['@c']
+ collected_chunks = DocumentArray.empty(0)
+ for d in da:
+ collected_chunks.extend(d.chunks[:3])
+ assert collected_chunks == new_chunks
@pytest.mark.parametrize('size', [1, 5])
@@ -501,6 +521,7 @@ def test_getset_subindex(storage, config, start_storage):
('qdrant', lambda: QdrantConfig(n_dim=123, prefer_grpc=True)),
('elasticsearch', lambda: ElasticConfig(n_dim=123)),
('redis', lambda: RedisConfig(n_dim=123)),
+ ('milvus', lambda: MilvusConfig(n_dim=123)),
],
)
def test_attribute_indexing(storage, config_gen, start_storage, size):
@@ -541,6 +562,7 @@ def test_attribute_indexing(storage, config_gen, start_storage, size):
('qdrant', lambda: QdrantConfig(n_dim=10, prefer_grpc=True)),
('elasticsearch', lambda: ElasticConfig(n_dim=10)),
('redis', lambda: RedisConfig(n_dim=10)),
+ ('milvus', lambda: MilvusConfig(n_dim=10)),
],
)
def test_tensor_attribute_selector(storage, config_gen, start_storage):
@@ -603,6 +625,7 @@ def test_advance_selector_mixed(storage):
('qdrant', lambda: QdrantConfig(n_dim=10, prefer_grpc=True)),
('elasticsearch', lambda: ElasticConfig(n_dim=10)),
('redis', lambda: RedisConfig(n_dim=10)),
+ ('milvus', lambda: MilvusConfig(n_dim=10)),
],
)
def test_single_boolean_and_padding(storage, config_gen, start_storage):
@@ -637,6 +660,7 @@ def test_single_boolean_and_padding(storage, config_gen, start_storage):
('qdrant', lambda: QdrantConfig(n_dim=123, prefer_grpc=True)),
('elasticsearch', lambda: ElasticConfig(n_dim=123)),
('redis', lambda: RedisConfig(n_dim=123)),
+ ('milvus', lambda: MilvusConfig(n_dim=123)),
],
)
def test_edge_case_two_strings(storage, config_gen, start_storage):
@@ -710,6 +734,10 @@ def test_edge_case_two_strings(storage, config_gen, start_storage):
'storage,config',
[
('annlite', AnnliteConfig(n_dim=123)),
+ ('qdrant', QdrantConfig(n_dim=123)),
+ ('elasticsearch', ElasticConfig(n_dim=123)),
+ ('redis', RedisConfig(n_dim=123)),
+ ('milvus', MilvusConfig(n_dim=123)),
],
)
def test_offset2ids_persistence(storage, config):
diff --git a/tests/unit/array/test_backend_configuration.py b/tests/unit/array/test_backend_configuration.py
index e10b080d9da..9ef5e2130c5 100644
--- a/tests/unit/array/test_backend_configuration.py
+++ b/tests/unit/array/test_backend_configuration.py
@@ -155,6 +155,30 @@ def test_cast_columns_qdrant(start_storage, type_da, type_column, prefer_grpc, r
assert len(index) == N
+@pytest.mark.parametrize('type_da', [int, float, str, bool])
+@pytest.mark.parametrize('type_column', ['int', 'str', 'float', 'double', 'bool'])
+def test_cast_columns_milvus(start_storage, type_da, type_column, request):
+ test_id = request.node.callspec.id.replace(
+ '-', ''
+    ) # remove '-' from the test id to form a valid Milvus collection name
+ N = 10
+
+ index = DocumentArray(
+ storage='milvus',
+ config={
+ 'collection_name': f'test{test_id}',
+ 'n_dim': 3,
+ 'columns': {'price': type_column},
+ },
+ )
+
+ docs = DocumentArray([Document(tags={'price': type_da(i)}) for i in range(N)])
+
+ index.extend(docs)
+
+ assert len(index) == N
+
+
def test_random_subindices_config():
database_index = random.randint(0, 100)
database_name = "jina" + str(database_index) + ".db"
diff --git a/tests/unit/array/test_construct.py b/tests/unit/array/test_construct.py
index 251e8459b16..e0e68d4b834 100644
--- a/tests/unit/array/test_construct.py
+++ b/tests/unit/array/test_construct.py
@@ -10,6 +10,7 @@
from docarray.array.weaviate import DocumentArrayWeaviate, WeaviateConfig
from docarray.array.elastic import DocumentArrayElastic, ElasticConfig
from docarray.array.redis import DocumentArrayRedis, RedisConfig
+from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
@pytest.mark.parametrize(
@@ -22,6 +23,7 @@
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_construct_docarray(da_cls, config, start_storage):
@@ -71,6 +73,7 @@ def test_construct_docarray(da_cls, config, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
@pytest.mark.parametrize('is_copy', [True, False])
@@ -101,6 +104,7 @@ def test_docarray_copy_singleton(da_cls, config, is_copy, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=128)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
@pytest.mark.parametrize('is_copy', [True, False])
@@ -130,6 +134,7 @@ def test_docarray_copy_da(da_cls, config, is_copy, start_storage):
(DocumentArrayQdrant, QdrantConfig(n_dim=1)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
(DocumentArrayRedis, RedisConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
@pytest.mark.parametrize('is_copy', [True, False])
diff --git a/tests/unit/array/test_pull_out.py b/tests/unit/array/test_pull_out.py
index e487c94214e..c36dde294b0 100644
--- a/tests/unit/array/test_pull_out.py
+++ b/tests/unit/array/test_pull_out.py
@@ -23,6 +23,7 @@ def docs():
('qdrant', {'n_dim': 2}),
('elasticsearch', {'n_dim': 2}),
('redis', {'n_dim': 2}),
+ ('milvus', {'n_dim': 2}),
],
)
def test_update_embedding(docs, storage, config, start_storage):
@@ -58,6 +59,7 @@ def test_update_embedding(docs, storage, config, start_storage):
('qdrant', {'n_dim': 2}),
('elasticsearch', {'n_dim': 2}),
('redis', {'n_dim': 2}),
+ ('milvus', {'n_dim': 2}),
],
)
def test_update_doc_embedding(docs, storage, config, start_storage):
@@ -93,6 +95,7 @@ def test_update_doc_embedding(docs, storage, config, start_storage):
('qdrant', {'n_dim': 2}),
('elasticsearch', {'n_dim': 2}),
('redis', {'n_dim': 2}),
+ ('milvus', {'n_dim': 2}),
],
)
def test_batch_update_embedding(docs, storage, config, start_storage):
@@ -126,6 +129,7 @@ def test_batch_update_embedding(docs, storage, config, start_storage):
('qdrant', {'n_dim': 2}),
('elasticsearch', {'n_dim': 2}),
('redis', {'n_dim': 2}),
+ ('milvus', {'n_dim': 2}),
],
)
def test_batch_update_doc_embedding(docs, storage, config, start_storage):
@@ -161,6 +165,7 @@ def test_batch_update_doc_embedding(docs, storage, config, start_storage):
('qdrant', {'n_dim': 2}),
('elasticsearch', {'n_dim': 2}),
('redis', {'n_dim': 2}),
+ ('milvus', {'n_dim': 2}),
],
)
def test_update_id(docs, storage, config, start_storage):
@@ -183,6 +188,7 @@ def test_update_id(docs, storage, config, start_storage):
('qdrant', {'n_dim': 2}),
('elasticsearch', {'n_dim': 2}),
('redis', {'n_dim': 2}),
+ ('milvus', {'n_dim': 2}),
],
)
def test_update_doc_id(docs, storage, config, start_storage):
@@ -204,6 +210,7 @@ def test_update_doc_id(docs, storage, config, start_storage):
('qdrant', {'n_dim': 2}),
('elasticsearch', {'n_dim': 2}),
('redis', {'n_dim': 2}),
+ ('milvus', {'n_dim': 2}),
],
)
def test_batch_update_id(docs, storage, config, start_storage):
@@ -228,6 +235,7 @@ def test_batch_update_id(docs, storage, config, start_storage):
('qdrant', {'n_dim': 2}),
('elasticsearch', {'n_dim': 2}),
('redis', {'n_dim': 2}),
+ ('milvus', {'n_dim': 2}),
],
)
def test_batch_update_doc_id(docs, storage, config, start_storage):
diff --git a/tests/unit/array/test_sequence.py b/tests/unit/array/test_sequence.py
index 92b04995357..5c86237e049 100644
--- a/tests/unit/array/test_sequence.py
+++ b/tests/unit/array/test_sequence.py
@@ -15,6 +15,7 @@
from docarray.array.storage.sqlite import SqliteConfig
from docarray.array.storage.weaviate import WeaviateConfig
from docarray.array.weaviate import DocumentArrayWeaviate
+from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
from tests.conftest import tmpfile
@@ -27,6 +28,7 @@
(DocumentArrayQdrant, lambda: QdrantConfig(n_dim=1)),
(DocumentArrayElastic, lambda: ElasticConfig(n_dim=1)),
(DocumentArrayRedis, lambda: RedisConfig(n_dim=1)),
+ (DocumentArrayMilvus, lambda: MilvusConfig(n_dim=128)),
],
)
def test_insert(da_cls, config, start_storage):
@@ -50,6 +52,7 @@ def test_insert(da_cls, config, start_storage):
(DocumentArrayQdrant, lambda: QdrantConfig(n_dim=1)),
(DocumentArrayElastic, lambda: ElasticConfig(n_dim=1)),
(DocumentArrayRedis, lambda: RedisConfig(n_dim=1)),
+ (DocumentArrayMilvus, lambda: MilvusConfig(n_dim=128)),
],
)
def test_append_extend(da_cls, config, start_storage):
@@ -84,6 +87,7 @@ def update_config_inplace(config, tmpdir, tmpfile):
('qdrant', {'n_dim': 3, 'collection_name': 'qdrant'}),
('elasticsearch', {'n_dim': 3, 'index_name': 'elasticsearch'}),
('redis', {'n_dim': 3, 'index_name': 'redis'}),
+ ('milvus', {'n_dim': 3, 'collection_name': 'milvus'}),
],
)
def test_context_manager_from_disk(storage, config, start_storage, tmpdir, tmpfile):
@@ -120,9 +124,10 @@ def test_context_manager_from_disk(storage, config, start_storage, tmpdir, tmpfi
('elasticsearch', {'n_dim': 3, 'distance': 'l2_norm'}),
('sqlite', dict()),
('redis', {'n_dim': 3, 'distance': 'L2'}),
+ ('milvus', {'n_dim': 3, 'distance': 'L2'}),
],
)
-def test_extend_subindex(storage, config):
+def test_extend_subindex(storage, config, start_storage):
n_dim = 3
subindex_configs = (
@@ -166,9 +171,10 @@ def test_extend_subindex(storage, config):
('elasticsearch', {'n_dim': 3, 'distance': 'l2_norm'}),
('sqlite', dict()),
('redis', {'n_dim': 3, 'distance': 'L2'}),
+ ('milvus', {'n_dim': 3, 'distance': 'L2'}),
],
)
-def test_append_subindex(storage, config):
+def test_append_subindex(storage, config, start_storage):
n_dim = 3
subindex_configs = (
@@ -216,12 +222,13 @@ def embeddings_eq(emb1, emb2):
('elasticsearch', {'n_dim': 3, 'distance': 'l2_norm'}),
('sqlite', dict()),
('redis', {'n_dim': 3, 'distance': 'L2'}),
+ ('milvus', {'n_dim': 3, 'distance': 'L2'}),
],
)
@pytest.mark.parametrize(
'index', [1, '1', slice(1, 2), [1], [False, True, False, False, False]]
)
-def test_del_and_append(index, storage, config):
+def test_del_and_append(index, storage, config, start_storage):
da = DocumentArray(storage=storage, config=config)
with da:
@@ -243,12 +250,13 @@ def test_del_and_append(index, storage, config):
('elasticsearch', {'n_dim': 3, 'distance': 'l2_norm'}),
('sqlite', dict()),
('redis', {'n_dim': 3, 'distance': 'L2'}),
+ ('milvus', {'n_dim': 3, 'distance': 'L2'}),
],
)
@pytest.mark.parametrize(
'index', [1, '1', slice(1, 2), [1], [False, True, False, False, False]]
)
-def test_set_and_append(index, storage, config):
+def test_set_and_append(index, storage, config, start_storage):
da = DocumentArray(storage=storage, config=config)
with da:
diff --git a/tests/unit/document/test_plot.py b/tests/unit/document/test_plot.py
index c14d7bbc51c..3e4cb18b2ae 100644
--- a/tests/unit/document/test_plot.py
+++ b/tests/unit/document/test_plot.py
@@ -12,6 +12,7 @@
from docarray.array.storage.qdrant import QdrantConfig
from docarray.array.storage.weaviate import WeaviateConfig
from docarray.array.weaviate import DocumentArrayWeaviate
+from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
@pytest.fixture()
@@ -58,6 +59,7 @@ def test_empty_doc(embed_docs):
(DocumentArrayWeaviate, WeaviateConfig(n_dim=128)),
(DocumentArrayQdrant, QdrantConfig(n_dim=128, scroll_batch_size=8)),
(DocumentArrayElastic, ElasticConfig(n_dim=128)),
+ (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
],
)
def test_matches_sprites(
@@ -83,6 +85,7 @@ def test_matches_sprites(
(DocumentArrayWeaviate, lambda: WeaviateConfig(n_dim=128)),
(DocumentArrayQdrant, lambda: QdrantConfig(n_dim=128, scroll_batch_size=8)),
(DocumentArrayElastic, lambda: ElasticConfig(n_dim=128)),
+ (DocumentArrayMilvus, lambda: MilvusConfig(n_dim=128)),
],
)
def test_matches_sprite_image_generator(
@@ -95,7 +98,9 @@ def test_matches_sprite_image_generator(
start_storage,
):
da, das = embed_docs
- if image_source == 'tensor':
+ if (
+ image_source == 'tensor' and da_cls != DocumentArrayMilvus
+ ): # Milvus can't handle large tensors
da.apply(lambda d: d.load_uri_to_image_tensor())
das.apply(lambda d: d.load_uri_to_image_tensor())