Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions docarray/array/storage/base/backend.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
from abc import ABC
from typing import Dict, Optional, TYPE_CHECKING

from docarray.array.storage.base.helper import Offset2ID

if TYPE_CHECKING:
from ....types import (
DocumentArraySourceType,
)
from ....types import DocumentArraySourceType, ArrayType


class BaseBackendMixin(ABC):
Expand All @@ -21,3 +17,11 @@ def _init_storage(

def _get_storage_infos(self) -> Optional[Dict]:
    """Return backend-specific storage details for display.

    This base version is a placeholder that provides no details.

    :return: ``None``
    """
    return None

def _map_id(self, _id: str) -> str:
    """Translate a document id into the id used by the storage backend.

    The base implementation is the identity mapping.

    :param _id: the document id
    :return: ``_id`` unchanged
    """
    backend_id = _id
    return backend_id

def _map_embedding(self, embedding: 'ArrayType') -> 'ArrayType':
    """Convert an embedding into the representation kept by the backend.

    The base implementation hands the value to ``to_numpy_array``.

    :param embedding: the embedding to convert
    :return: the result of ``to_numpy_array(embedding)``
    """
    # Imported inside the method rather than at module level.
    from ....math.ndarray import to_numpy_array as _as_numpy

    converted = _as_numpy(embedding)
    return converted
5 changes: 3 additions & 2 deletions docarray/array/storage/base/getsetdel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Sequence,
Any,
Iterable,
Dict,
)

from .helper import Offset2ID
Expand Down Expand Up @@ -151,7 +152,7 @@ def _set_doc(self, _id: str, value: 'Document'):
def _set_doc_by_id(self, _id: str, value: 'Document'):
...

def _set_docs_by_ids(self, ids, docs: Iterable['Document']):
def _set_docs_by_ids(self, ids, docs: Iterable['Document'], mismatch_ids: Dict):
"""This function is derived from :meth:`_set_doc_by_id`
Override this function if there is a more efficient logic

Expand All @@ -162,8 +163,8 @@ def _set_docs_by_ids(self, ids, docs: Iterable['Document']):

def _set_docs(self, ids, docs: Iterable['Document']):
    """Store ``docs`` under ``ids`` and keep the offset-to-id table in sync.

    :param ids: the ids currently held at the target positions
    :param docs: the documents to store, paired positionally with ``ids``
    """
    # Materialise both inputs: they are walked here to detect id changes and
    # then handed on to ``_set_docs_by_ids``, so a one-shot iterator must not
    # arrive there already exhausted.
    ids = list(ids)
    docs = list(docs)
    # old id -> new id, for every position whose document carries another id
    mismatch_ids = {_id: doc.id for _id, doc in zip(ids, docs) if _id != doc.id}
    self._set_docs_by_ids(ids, docs, mismatch_ids)
    self._offset2ids.update_ids(mismatch_ids)

def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']):
Expand Down
6 changes: 0 additions & 6 deletions docarray/array/storage/pqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,3 @@

class StorageMixins(FindMixin, BackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC):
...

def _to_numpy_embedding(self, doc: 'Document'):
    """Normalise ``doc.embedding`` in place to a ``float32`` numpy array.

    A missing embedding becomes a zero vector of length ``self._pqlite.dim``;
    a list is converted to an array; any other value is left untouched.

    :param doc: the document whose embedding is normalised
    """
    current = doc.embedding
    if current is None:
        doc.embedding = np.zeros(self._pqlite.dim, dtype=np.float32)
        return
    if isinstance(current, list):
        doc.embedding = np.array(current, dtype=np.float32)
13 changes: 10 additions & 3 deletions docarray/array/storage/pqlite/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
Generator,
Iterator,
)

import numpy as np
from pqlite import PQLite

from ..base.backend import BaseBackendMixin
from ....helper import dataclass_from_dict

if TYPE_CHECKING:
from ....types import (
DocumentArraySourceType,
)
from ....types import DocumentArraySourceType, ArrayType


@dataclass
Expand Down Expand Up @@ -96,3 +96,10 @@ def _get_storage_infos(self) -> Dict:
'Data Path': self._config.data_path,
'Serialization Protocol': self._config.serialize_config.get('protocol'),
}

def _map_embedding(self, embedding: 'ArrayType') -> 'ArrayType':
    """Prepare an embedding for storage in the PQLite backend.

    :param embedding: the document embedding; may be ``None``
    :return: a ``float32`` zero vector of length ``self._pqlite.dim`` when
        ``embedding`` is ``None``; a ``float32`` numpy array when it is a
        list or a tuple; otherwise ``embedding`` unchanged
    """
    if embedding is None:
        # Documents without an embedding get an all-zero placeholder vector.
        return np.zeros(self._pqlite.dim, dtype=np.float32)
    if isinstance(embedding, (list, tuple)):
        # Plain Python sequences are converted; tuples are treated like lists.
        return np.array(embedding, dtype=np.float32)
    return embedding
11 changes: 7 additions & 4 deletions docarray/array/storage/pqlite/getsetdel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable
from typing import Iterable, Dict

from .helper import OffsetMapping
from ..base.getsetdel import BaseGetSetDelMixin
Expand All @@ -19,7 +19,9 @@ def _get_doc_by_id(self, _id: str) -> 'Document':
return doc

def _set_doc_by_id(self, _id: str, value: 'Document'):
    """Store ``value`` in PQLite in place of the record held under ``_id``.

    :param _id: the id of the record being replaced
    :param value: the document to store; its embedding is normalised in
        place through ``_map_embedding`` before the write
    """
    if _id != value.id:
        # The replacement carries a different id, so the record stored under
        # the old id is deleted before the new document is written.
        self._pqlite.delete([_id])
    value.embedding = self._map_embedding(value.embedding)
    docs = DocumentArrayInMemory([value])
    self._pqlite.update(docs)

Expand All @@ -29,10 +31,11 @@ def _del_doc_by_id(self, _id: str):
def _clear_storage(self):
    """Empty the backing ``self._pqlite`` store by calling its ``clear`` method."""
    store = self._pqlite
    store.clear()

def _set_docs_by_ids(self, ids, docs: Iterable['Document'], mismatch_ids: Dict):
    """Write ``docs`` to PQLite in one batch.

    :param ids: the ids being replaced (not read by this implementation)
    :param docs: the documents to store
    :param mismatch_ids: mapping whose keys are the old ids of documents
        that are stored under a different id
    """
    # Remove the records held under the old ids before writing the batch.
    self._pqlite.delete(list(mismatch_ids))
    docs = DocumentArrayInMemory(docs)
    for doc in docs:
        doc.embedding = self._map_embedding(doc.embedding)
    self._pqlite.update(docs)

def _del_docs_by_ids(self, ids):
Expand Down
2 changes: 1 addition & 1 deletion docarray/array/storage/pqlite/seqlike.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def extend(self, values: Iterable['Document']) -> None:
return

for doc in docs:
self._to_numpy_embedding(doc)
doc.embedding = self._map_embedding(doc.embedding)

self._pqlite.index(docs)
self._offset2ids.extend([doc.id for doc in docs])
Expand Down
10 changes: 2 additions & 8 deletions docarray/array/storage/qdrant/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from typing import Iterable, TYPE_CHECKING
from typing import TYPE_CHECKING

from .backend import BackendMixin, QdrantConfig
from .helper import DISTANCES
from .find import FindMixin
from .getsetdel import GetSetDelMixin
from .helper import DISTANCES
from .seqlike import SequenceLikeMixin

__all__ = ['StorageMixins', 'QdrantConfig']

if TYPE_CHECKING:
from qdrant_client import QdrantClient
from qdrant_openapi_client.models.models import Distance
from docarray import Document


class StorageMixins(FindMixin, BackendMixin, GetSetDelMixin, SequenceLikeMixin):
Expand All @@ -23,11 +22,6 @@ def serialize_config(self) -> dict:
def distance(self) -> 'Distance':
return DISTANCES[self._config.distance]

def extend(self, docs: Iterable['Document']):
    """Append ``docs`` to the array.

    The documents are uploaded in one ``_upload_batch`` call and their ids
    are then appended to ``self._offset2ids``.

    :param docs: the documents to append
    """
    batch = [*docs]
    self._upload_batch(batch)
    appended_ids = [item.id for item in batch]
    self._offset2ids.extend(appended_ids)

@property
def serialization_config(self) -> dict:
    """The serialization settings held in ``self._serialize_config``."""
    stored = self._serialize_config
    return stored
Expand Down
23 changes: 19 additions & 4 deletions docarray/array/storage/qdrant/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
List,
)

import numpy as np
from qdrant_client import QdrantClient
from qdrant_openapi_client.models.models import (
Distance,
Expand All @@ -24,11 +25,10 @@
from docarray.array.storage.base.backend import BaseBackendMixin
from docarray.array.storage.qdrant.helper import DISTANCES
from docarray.helper import dataclass_from_dict, random_identity
from docarray.math.helper import EPSILON

if TYPE_CHECKING:
from docarray.types import (
DocumentArraySourceType,
)
from docarray.types import DocumentArraySourceType, ArrayType


@dataclass
Expand Down Expand Up @@ -125,7 +125,7 @@ def _collection_exists(self, collection_name):
return collection_name in collections

@staticmethod
def _qmap(doc_id: str):
def _map_id(doc_id: str):
# if doc_id is a random ID in hex format, just translate back to UUID str
# otherwise, create UUID5 from doc_id
try:
Expand Down Expand Up @@ -181,3 +181,18 @@ def _get_storage_infos(self) -> Dict:
'Distance': self._config.distance,
'Serialization Protocol': self._config.serialize_config.get('protocol'),
}

def _map_embedding(self, embedding: 'ArrayType') -> List[float]:
    """Convert an embedding into the list form handed to Qdrant.

    :param embedding: the embedding to convert; ``None`` is replaced by a
        random vector of length ``self.n_dim``
    :return: the embedding as a Python list
    """
    if embedding is None:
        embedding = np.random.rand(self.n_dim)
    else:
        from ....math.ndarray import to_numpy_array

        embedding = to_numpy_array(embedding)

    if embedding.ndim > 1:
        # Drop singleton axes, e.g. a (1, n) row vector becomes (n,).
        embedding = np.asarray(embedding).squeeze()

    # ``squeeze`` turns a (1, 1) array into a 0-d array, for which
    # ``tolist()`` would return a bare float; keep at least one axis so that
    # a list always comes back.
    embedding = np.atleast_1d(embedding)

    if np.all(embedding == 0):
        # NOTE(review): all-zero vectors are shifted by EPSILON, presumably
        # because they are a problem for the distance metric - confirm.
        embedding = embedding + EPSILON
    return embedding.tolist()
3 changes: 1 addition & 2 deletions docarray/array/storage/qdrant/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
List,
)

from .helper import QdrantStorageHelper
from .... import Document, DocumentArray
from ....math import ndarray
from ....score import NamedScore
Expand Down Expand Up @@ -50,7 +49,7 @@ def distance(self) -> 'Distance':
raise NotImplementedError()

def _find_similar_vectors(self, q: 'QdrantArrayType', limit=10):
query_vector = QdrantStorageHelper.embedding_to_array(q, default_dim=0)
query_vector = self._map_embedding(q)

search_result = self.client.search(
self.collection_name,
Expand Down
9 changes: 4 additions & 5 deletions docarray/array/storage/qdrant/getsetdel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from docarray import Document
from docarray.array.storage.base.getsetdel import BaseGetSetDelMixin
from docarray.array.storage.base.helper import Offset2ID
from docarray.array.storage.qdrant.helper import QdrantStorageHelper


class GetSetDelMixin(BaseGetSetDelMixin):
Expand Down Expand Up @@ -67,15 +66,15 @@ def _qdrant_to_document(self, qdrant_record: dict) -> 'Document':

def _document_to_qdrant(self, doc: 'Document') -> 'PointStruct':
    """Build the Qdrant point that stores ``doc``.

    :param doc: the document to convert
    :return: a ``PointStruct`` whose id is ``_map_id(doc.id)``, whose payload
        holds the base64-serialized document under ``_serialized``, and
        whose vector is ``_map_embedding(doc.embedding)``
    """
    return PointStruct(
        id=self._map_id(doc.id),
        payload=dict(_serialized=doc.to_base64(**self.serialization_config)),
        vector=self._map_embedding(doc.embedding),
    )

def _get_doc_by_id(self, _id: str) -> 'Document':
try:
resp = self.client.http.points_api.get_point(
name=self.collection_name, id=self._qmap(_id)
name=self.collection_name, id=self._map_id(_id)
)
return self._qdrant_to_document(resp.result.payload)
except UnexpectedResponse as response_error:
Expand All @@ -86,7 +85,7 @@ def _del_doc_by_id(self, _id: str):
self.client.http.points_api.delete_points(
name=self.collection_name,
wait=True,
points_selector=PointIdsList(points=[self._qmap(_id)]),
points_selector=PointIdsList(points=[self._map_id(_id)]),
)

def _set_doc_by_id(self, _id: str, value: 'Document'):
Expand Down
30 changes: 0 additions & 30 deletions docarray/array/storage/qdrant/helper.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,5 @@
from typing import List, TYPE_CHECKING

import numpy as np
import scipy.sparse
from qdrant_openapi_client.models.models import Distance

from docarray.math.helper import EPSILON

if TYPE_CHECKING:
from docarray.types import ArrayType


class QdrantStorageHelper:
    """Helpers for converting embeddings into the form handed to Qdrant."""

    @classmethod
    def embedding_to_array(
        cls, embedding: 'ArrayType', default_dim: int
    ) -> List[float]:
        """Convert an embedding into a Python list.

        :param embedding: the embedding to convert; ``None`` is replaced by
            a random vector of length ``default_dim``
        :param default_dim: length of the random vector used for ``None``
        :return: the embedding as a Python list
        """
        if embedding is None:
            embedding = np.random.rand(default_dim)
        else:
            from ....math.ndarray import to_numpy_array

            embedding = to_numpy_array(embedding)

        if embedding.ndim > 1:
            # Drop singleton axes, e.g. a (1, n) row vector becomes (n,).
            embedding = np.asarray(embedding).squeeze()

        # ``squeeze`` turns a (1, 1) array into a 0-d array, for which
        # ``tolist()`` would return a bare float; keep at least one axis so
        # that a list always comes back.
        embedding = np.atleast_1d(embedding)

        if np.all(embedding == 0):
            # NOTE(review): all-zero vectors are shifted by EPSILON,
            # presumably because they are a problem for the distance
            # metric - confirm.
            embedding = embedding + EPSILON
        return embedding.tolist()


DISTANCES = {
'cosine': Distance.COSINE,
'euclidean': Distance.EUCLID,
Expand Down
11 changes: 10 additions & 1 deletion docarray/array/storage/qdrant/seqlike.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import MutableSequence, Iterable, Iterator, Union
from typing import Iterable, Union
from docarray import Document

from qdrant_client import QdrantClient
Expand All @@ -23,6 +23,10 @@ def collection_name(self) -> str:
def config(self):
raise NotImplementedError()

@abstractmethod
def _upload_batch(self, docs: Iterable['Document']):
    """Upload ``docs`` to the storage backend.

    This is an abstract declaration only.

    :param docs: the documents to upload
    :raises NotImplementedError: always
    """
    raise NotImplementedError

def __eq__(self, other):
"""Compare this object to the other, returns True if and only if other
has the same type as self and other has the same meta information
Expand Down Expand Up @@ -67,3 +71,8 @@ def __bool__(self):
:return: returns true if the length of the array is larger than 0
"""
return len(self) > 0

def extend(self, docs: Iterable['Document']):
    """Append ``docs`` to the array.

    The documents are sent to the backend with a single ``_upload_batch``
    call; afterwards their ids are appended to ``self._offset2ids``.

    :param docs: the documents to append
    """
    pending = [*docs]
    self._upload_batch(pending)
    self._offset2ids.extend([entry.id for entry in pending])
Loading