From d883a45113225effe9ebdbefb6c221bc497bebff Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Sun, 16 Jan 2022 11:54:05 +0100 Subject: [PATCH 01/23] refactor(array): use extend in add improve docs --- docarray/array/persistence/__init__.py | 0 docarray/array/persistence/sqlite/__init__.py | 0 docarray/array/persistence/sqlite/base.py | 271 ++++++++++++++++++ docarray/array/persistence/sqlite/dict.py | 140 +++++++++ 4 files changed, 411 insertions(+) create mode 100644 docarray/array/persistence/__init__.py create mode 100644 docarray/array/persistence/sqlite/__init__.py create mode 100644 docarray/array/persistence/sqlite/base.py create mode 100644 docarray/array/persistence/sqlite/dict.py diff --git a/docarray/array/persistence/__init__.py b/docarray/array/persistence/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/array/persistence/sqlite/__init__.py b/docarray/array/persistence/sqlite/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/array/persistence/sqlite/base.py b/docarray/array/persistence/sqlite/base.py new file mode 100644 index 00000000000..9c85b461a32 --- /dev/null +++ b/docarray/array/persistence/sqlite/base.py @@ -0,0 +1,271 @@ +import sqlite3 +import warnings +from abc import ABCMeta, abstractmethod +from collections.abc import Hashable +from enum import Enum +from pickle import dumps, loads +from tempfile import NamedTemporaryFile +from typing import Callable, Generic, Optional, TypeVar, Union, cast +from uuid import uuid4 + +T = TypeVar('T') +KT = TypeVar('KT') +VT = TypeVar('VT') +_T = TypeVar('_T') +_S = TypeVar('_S') + + +class RebuildStrategy(Enum): + CHECK_WITH_FIRST_ELEMENT = 1 + ALWAYS = 2 + SKIP = 3 + + +def sanitize_table_name(table_name: str) -> str: + ret = ''.join(c for c in table_name if c.isalnum() or c == '_') + if ret != table_name: + warnings.warn(f'The table name is changed to {ret} due to illegal characters') + return ret + + +def create_random_name(suffix: str) -> str: + return f"{suffix}_{str(uuid4()).replace('-', '')}" + + +def is_hashable(x: object) -> bool: + return isinstance(x, Hashable) + + +class _SqliteCollectionBaseDatabaseDriver(metaclass=ABCMeta): + @classmethod + def initialize_metadata_table(cls, cur: sqlite3.Cursor) -> None: + if not cls.is_metadata_table_initialized(cur): + cls.do_initialize_metadata_table(cur) + + @classmethod + def is_metadata_table_initialized(cls, cur: sqlite3.Cursor) -> bool: + try: + cur.execute('SELECT 1 FROM metadata LIMIT 1') + _ = list(cur) + return True + except sqlite3.OperationalError as _: + pass + return False + + @classmethod + def do_initialize_metadata_table(cls, cur: sqlite3.Cursor) -> None: + cur.execute( + ''' + CREATE TABLE metadata ( + table_name TEXT PRIMARY KEY, + schema_version TEXT NOT NULL, + container_type TEXT NOT NULL, + UNIQUE (table_name, container_type) + ) + ''' + ) + + @classmethod + def initialize_table( + cls, + table_name: str, + container_type_name: str, + schema_version: str, + cur: sqlite3.Cursor, + ) -> None: + if not cls.is_table_initialized( + table_name, container_type_name, schema_version, cur + ): + cls.do_create_table(table_name, container_type_name, schema_version, cur) + cls.do_tidy_table_metadata( + table_name, container_type_name, schema_version, cur + ) + + @classmethod + def is_table_initialized( + self, + table_name: str, + container_type_name: str, + schema_version: str, + cur: sqlite3.Cursor, + ) -> bool: + try: + cur.execute( + 'SELECT schema_version FROM metadata WHERE table_name=? AND container_type=?', + (table_name, container_type_name), + ) + buf = cur.fetchone() + if buf is None: + return False + version = buf[0] + if version != schema_version: + return False + cur.execute(f'SELECT 1 FROM {table_name} LIMIT 1') + _ = list(cur) + return True + except sqlite3.OperationalError as _: + pass + return False + + @classmethod + def do_tidy_table_metadata( + cls, + table_name: str, + container_type_name: str, + schema_version: str, + cur: sqlite3.Cursor, + ) -> None: + cur.execute( + 'INSERT INTO metadata (table_name, schema_version, container_type) VALUES (?, ?, ?)', + (table_name, schema_version, container_type_name), + ) + + @classmethod + @abstractmethod + def do_create_table( + cls, + table_name: str, + container_type_name: str, + schema_version: str, + cur: sqlite3.Cursor, + ) -> None: + ... + + @classmethod + def drop_table( + cls, table_name: str, container_type_name: str, cur: sqlite3.Cursor + ) -> None: + cur.execute( + 'DELETE FROM metadata WHERE table_name=? AND container_type=?', + (table_name, container_type_name), + ) + cur.execute(f'DROP TABLE {table_name}') + + @classmethod + def alter_table_name( + cls, table_name: str, new_table_name: str, cur: sqlite3.Cursor + ) -> None: + cur.execute( + 'UPDATE metadata SET table_name=? WHERE table_name=?', + (new_table_name, table_name), + ) + cur.execute(f'ALTER TABLE {table_name} RENAME TO {new_table_name}') + + +class SqliteCollectionBase(Generic[T], metaclass=ABCMeta): + _driver_class = _SqliteCollectionBaseDatabaseDriver + + def __init__( + self, + connection: Optional[Union[str, sqlite3.Connection]] = None, + table_name: Optional[str] = None, + serializer: Optional[Callable[[T], bytes]] = None, + deserializer: Optional[Callable[[bytes], T]] = None, + persist: bool = True, + rebuild_strategy: RebuildStrategy = RebuildStrategy.CHECK_WITH_FIRST_ELEMENT, + ): + super(SqliteCollectionBase, self).__init__() + self._serializer = ( + cast(Callable[[T], bytes], dumps) if serializer is None else serializer + ) + self._deserializer = ( + cast(Callable[[bytes], T], loads) if deserializer is None else deserializer + ) + self._persist = persist + if connection is None: + self._connection = sqlite3.connect(NamedTemporaryFile().name) + elif isinstance(connection, str): + self._connection = sqlite3.connect(connection) + elif isinstance(connection, sqlite3.Connection): + self._connection = connection + else: + raise TypeError( + f'connection argument must be None or a string or a sqlite3.Connection, not `{type(connection)}`' + ) + self._table_name = ( + sanitize_table_name(create_random_name(self.container_type_name)) + if table_name is None + else sanitize_table_name(table_name) + ) + self._initialize(rebuild_strategy=rebuild_strategy) + + def __del__(self) -> None: + if not self.persist: + cur = self.connection.cursor() + self._driver_class.drop_table( + self.table_name, self.container_type_name, cur + ) + self.connection.commit() + + def _initialize(self, rebuild_strategy: RebuildStrategy) -> None: + cur = self.connection.cursor() + self._driver_class.initialize_metadata_table(cur) + self._driver_class.initialize_table( + self.table_name, self.container_type_name, self.schema_version, cur + ) + if self._should_rebuild(rebuild_strategy): + self._do_rebuild() + self.connection.commit() + + def _should_rebuild(self, rebuild_strategy: RebuildStrategy) -> bool: + if rebuild_strategy == RebuildStrategy.ALWAYS: + return True + if rebuild_strategy == RebuildStrategy.SKIP: + return False + return self._rebuild_check_with_first_element() + + @abstractmethod + def _rebuild_check_with_first_element(self) -> bool: + ... + + @abstractmethod + def _do_rebuild(self) -> None: + ... + + @property + def persist(self) -> bool: + return self._persist + + def set_persist(self, persist: bool) -> None: + self._persist = persist + + @property + def serializer(self) -> Callable[[T], bytes]: + return self._serializer + + def serialize(self, x: T) -> bytes: + return self.serializer(x) + + @property + def deserializer(self) -> Callable[[bytes], T]: + return self._deserializer + + def deserialize(self, blob: bytes) -> T: + return self.deserializer(blob) + + @property + def table_name(self) -> str: + return self._table_name + + @table_name.setter + def table_name(self, table_name: str) -> None: + cur = self.connection.cursor() + new_table_name = sanitize_table_name(table_name) + try: + self._driver_class.alter_table_name(self.table_name, new_table_name, cur) + except sqlite3.IntegrityError as e: + raise ValueError(table_name) + self._table_name = new_table_name + + @property + def connection(self) -> sqlite3.Connection: + return self._connection + + @property + def container_type_name(self) -> str: + return self.__class__.__name__ + + @property + @abstractmethod + def schema_version(self) -> str: + ... diff --git a/docarray/array/persistence/sqlite/dict.py b/docarray/array/persistence/sqlite/dict.py new file mode 100644 index 00000000000..4c82bfb709f --- /dev/null +++ b/docarray/array/persistence/sqlite/dict.py @@ -0,0 +1,140 @@ +from typing import Tuple, Union, cast, Iterable, TYPE_CHECKING + +from .base import _SqliteCollectionBaseDatabaseDriver + +if TYPE_CHECKING: + import sqlite3 + + +class _DictDatabaseDriver(_SqliteCollectionBaseDatabaseDriver): + @classmethod + def do_create_table( + cls, + table_name: str, + container_type_nam: str, + schema_version: str, + cur: 'sqlite3.Cursor', + ) -> None: + cur.execute( + f'CREATE TABLE {table_name} (' + 'serialized_key BLOB NOT NULL UNIQUE, ' + 'serialized_value BLOB NOT NULL, ' + 'item_order INTEGER PRIMARY KEY)' + ) + + @classmethod + def delete_single_record_by_serialized_key( + cls, table_name: str, cur: 'sqlite3.Cursor', serialized_key: bytes + ) -> None: + cur.execute( + f'DELETE FROM {table_name} WHERE serialized_key=?', (serialized_key,) + ) + + @classmethod + def delete_all_records(cls, table_name: str, cur: 'sqlite3.Cursor') -> None: + cur.execute(f'DELETE FROM {table_name}') + + @classmethod + def is_serialized_key_in( + cls, table_name: str, cur: 'sqlite3.Cursor', serialized_key: bytes + ) -> bool: + cur.execute( + f'SELECT 1 FROM {table_name} WHERE serialized_key=?', (serialized_key,) + ) + return len(list(cur)) > 0 + + @classmethod + def get_serialized_value_by_serialized_key( + cls, table_name: str, cur: 'sqlite3.Cursor', serialized_key: bytes + ) -> Union[None, bytes]: + cur.execute( + f'SELECT serialized_value FROM {table_name} WHERE serialized_key=?', + (serialized_key,), + ) + res = cur.fetchone() + if res is None: + return None + return cast(bytes, res[0]) + + @classmethod + def get_next_order(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: + cur.execute(f'SELECT MAX(item_order) FROM {table_name}') + res = cur.fetchone()[0] + if res is None: + return 0 + return cast(int, res) + 1 + + @classmethod + def get_count(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: + cur.execute(f'SELECT COUNT(*) FROM {table_name}') + res = cur.fetchone() + return cast(int, res[0]) + + @classmethod + def get_serialized_keys( + cls, table_name: str, cur: 'sqlite3.Cursor' + ) -> Iterable[bytes]: + cur.execute(f'SELECT serialized_key FROM {table_name} ORDER BY item_order') + for res in cur: + yield cast(bytes, res[0]) + + @classmethod + def insert_serialized_value_by_serialized_key( + cls, + table_name: str, + cur: 'sqlite3.Cursor', + serialized_key: bytes, + serialized_value: bytes, + ) -> None: + item_order = cls.get_next_order(table_name, cur) + cur.execute( + f'INSERT INTO {table_name} (serialized_key, serialized_value, item_order) VALUES (?, ?, ?)', + (serialized_key, serialized_value, item_order), + ) + + @classmethod + def update_serialized_value_by_serialized_key( + cls, + table_name: str, + cur: 'sqlite3.Cursor', + serialized_key: bytes, + serialized_value: bytes, + ) -> None: + cur.execute( + f'UPDATE {table_name} SET serialized_value=? WHERE serialized_key=?', + (serialized_value, serialized_key), + ) + + @classmethod + def upsert( + cls, + table_name: str, + cur: 'sqlite3.Cursor', + serialized_key: bytes, + serialized_value: bytes, + ) -> None: + if cls.is_serialized_key_in(table_name, cur, serialized_key): + cls.update_serialized_value_by_serialized_key( + table_name, cur, serialized_key, serialized_value + ) + else: + cls.insert_serialized_value_by_serialized_key( + table_name, cur, serialized_key, serialized_value + ) + + @classmethod + def get_last_serialized_item( + cls, table_name: str, cur: 'sqlite3.Cursor' + ) -> Tuple[bytes, bytes]: + cur.execute( + f'SELECT serialized_key, serialized_value FROM {table_name} ORDER BY item_order DESC LIMIT 1' + ) + return cast(Tuple[bytes, bytes], cur.fetchone()) + + @classmethod + def get_reversed_serialized_keys( + cls, table_name: str, cur: 'sqlite3.Cursor' + ) -> Iterable[bytes]: + cur.execute(f'SELECT serialized_key FROM {table_name} ORDER BY item_order DESC') + for res in cur: + yield cast(bytes, res[0]) From d3706302cca7055c20ac69df59e74703f99bd4ee Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Mon, 17 Jan 2022 15:29:36 +0100 Subject: [PATCH 02/23] chore: bump to 0.2 --- docarray/array/persistence/sqlite/dict.py | 72 +++--- docarray/array/persistence/sqlite/mixin.py | 251 +++++++++++++++++++++ 2 files changed, 287 insertions(+), 36 deletions(-) create mode 100644 docarray/array/persistence/sqlite/mixin.py diff --git a/docarray/array/persistence/sqlite/dict.py b/docarray/array/persistence/sqlite/dict.py index 4c82bfb709f..a0ea8737708 100644 --- a/docarray/array/persistence/sqlite/dict.py +++ b/docarray/array/persistence/sqlite/dict.py @@ -17,17 +17,17 @@ def do_create_table( ) -> None: cur.execute( f'CREATE TABLE {table_name} (' - 'serialized_key BLOB NOT NULL UNIQUE, ' + 'doc_id TEXT NOT NULL UNIQUE, ' 'serialized_value BLOB NOT NULL, ' 'item_order INTEGER PRIMARY KEY)' ) @classmethod - def delete_single_record_by_serialized_key( - cls, table_name: str, cur: 'sqlite3.Cursor', serialized_key: bytes + def delete_single_record_by_doc_id( + cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str ) -> None: cur.execute( - f'DELETE FROM {table_name} WHERE serialized_key=?', (serialized_key,) + f'DELETE FROM {table_name} WHERE doc_id=?', (doc_id,) ) @classmethod @@ -35,21 +35,21 @@ def delete_all_records(cls, table_name: str, cur: 'sqlite3.Cursor') -> None: cur.execute(f'DELETE FROM {table_name}') @classmethod - def is_serialized_key_in( - cls, table_name: str, cur: 'sqlite3.Cursor', serialized_key: bytes + def is_doc_id_in( + cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str ) -> bool: cur.execute( - f'SELECT 1 FROM {table_name} WHERE serialized_key=?', (serialized_key,) + f'SELECT 1 FROM {table_name} WHERE doc_id=?', (doc_id,) ) return len(list(cur)) > 0 @classmethod - def get_serialized_value_by_serialized_key( - cls, table_name: str, cur: 'sqlite3.Cursor', serialized_key: bytes + def get_serialized_value_by_doc_id( + cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str ) -> Union[None, bytes]: cur.execute( - f'SELECT serialized_value FROM {table_name} WHERE serialized_key=?', - (serialized_key,), + f'SELECT serialized_value FROM {table_name} WHERE doc_id=?', + (doc_id,), ) res = cur.fetchone() if res is None: @@ -71,38 +71,38 @@ def get_count(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: return cast(int, res[0]) @classmethod - def get_serialized_keys( + def get_doc_ids( cls, table_name: str, cur: 'sqlite3.Cursor' - ) -> Iterable[bytes]: - cur.execute(f'SELECT serialized_key FROM {table_name} ORDER BY item_order') + ) -> Iterable[str]: + cur.execute(f'SELECT doc_id FROM {table_name} ORDER BY item_order') for res in cur: - yield cast(bytes, res[0]) + yield cast(str, res[0]) @classmethod - def insert_serialized_value_by_serialized_key( + def insert_serialized_value_by_doc_id( cls, table_name: str, cur: 'sqlite3.Cursor', - serialized_key: bytes, + doc_id: str, serialized_value: bytes, ) -> None: item_order = cls.get_next_order(table_name, cur) cur.execute( - f'INSERT INTO {table_name} (serialized_key, serialized_value, item_order) VALUES (?, ?, ?)', - (serialized_key, serialized_value, item_order), + f'INSERT INTO {table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', + (doc_id, serialized_value, item_order), ) @classmethod - def update_serialized_value_by_serialized_key( + def update_serialized_value_by_doc_id( cls, table_name: str, cur: 'sqlite3.Cursor', - serialized_key: bytes, + doc_id: str, serialized_value: bytes, ) -> None: cur.execute( - f'UPDATE {table_name} SET serialized_value=? WHERE serialized_key=?', - (serialized_value, serialized_key), + f'UPDATE {table_name} SET serialized_value=? WHERE doc_id=?', + (serialized_value, doc_id), ) @classmethod @@ -110,31 +110,31 @@ def upsert( cls, table_name: str, cur: 'sqlite3.Cursor', - serialized_key: bytes, + doc_id: str, serialized_value: bytes, ) -> None: - if cls.is_serialized_key_in(table_name, cur, serialized_key): - cls.update_serialized_value_by_serialized_key( - table_name, cur, serialized_key, serialized_value + if cls.is_doc_id_in(table_name, cur, doc_id): + cls.update_serialized_value_by_doc_id( + table_name, cur, doc_id, serialized_value ) else: - cls.insert_serialized_value_by_serialized_key( - table_name, cur, serialized_key, serialized_value + cls.insert_serialized_value_by_doc_id( + table_name, cur, doc_id, serialized_value ) @classmethod def get_last_serialized_item( cls, table_name: str, cur: 'sqlite3.Cursor' - ) -> Tuple[bytes, bytes]: + ) -> Tuple[str, bytes]: cur.execute( - f'SELECT serialized_key, serialized_value FROM {table_name} ORDER BY item_order DESC LIMIT 1' + f'SELECT doc_id, serialized_value FROM {table_name} ORDER BY item_order DESC LIMIT 1' ) - return cast(Tuple[bytes, bytes], cur.fetchone()) + return cast(Tuple[str, bytes], cur.fetchone()) @classmethod - def get_reversed_serialized_keys( + def get_reversed_doc_ids( cls, table_name: str, cur: 'sqlite3.Cursor' - ) -> Iterable[bytes]: - cur.execute(f'SELECT serialized_key FROM {table_name} ORDER BY item_order DESC') + ) -> Iterable[str]: + cur.execute(f'SELECT doc_id FROM {table_name} ORDER BY item_order DESC') for res in cur: - yield cast(bytes, res[0]) + yield cast(str, res[0]) diff --git a/docarray/array/persistence/sqlite/mixin.py b/docarray/array/persistence/sqlite/mixin.py new file mode 100644 index 00000000000..b58d7889d8e --- /dev/null +++ b/docarray/array/persistence/sqlite/mixin.py @@ -0,0 +1,251 @@ +from dataclasses import dataclass +import dataclasses +from typing import Optional, TYPE_CHECKING, Callable, Union, cast + +from .base import SqliteCollectionBase, RebuildStrategy +from .dict import _DictDatabaseDriver + +if TYPE_CHECKING: + import sqlite3 + + from ....types import ( + T, + Document, + DocumentArraySourceType, + DocumentArrayIndexType, + DocumentArraySingletonIndexType, + DocumentArrayMultipleIndexType, + DocumentArrayMultipleAttributeType, + DocumentArraySingleAttributeType, + ) + +@dataclass +class SqliteConfig: + connection: Optional[Union[str, 'sqlite3.Connection']] = None + table_name: Optional[str] = None + serializer: Optional[Callable[['Document'], bytes]] = None + deserializer: Optional[Callable[[bytes], 'Document']] = None + persist: bool = True + rebuild_strategy: RebuildStrategy = RebuildStrategy.CHECK_WITH_FIRST_ELEMENT + + +class SqliteMixin(SqliteCollectionBase): + """Enable SQLite persistence backend for DocumentArray. + + .. note:: + This has to be put in the first position when use it for subclassing + i.e. `class SqliteDA(SqliteMixin, DA)` not the other way around. + + """ + _driver_class = _DictDatabaseDriver + + def __init__(self, docs: Optional['DocumentArraySourceType'] = None, config: Optional[SqliteConfig] = None): + super().__init__(**(dataclasses.asdict(config) if config else {})) + if docs is not None: + self.clear() + self.update(docs) + + def clear(self) -> None: + cur = self.connection.cursor() + self._driver_class.delete_all_records(self.table_name, cur) + self.connection.commit() + + @property + def schema_version(self) -> str: + return "0" + + def _rebuild_check_with_first_element(self) -> bool: + cur = self.connection.cursor() + cur.execute(f"SELECT doc_id FROM {self.table_name} ORDER BY item_order LIMIT 1") + res = cur.fetchone() + return res is None + + def _do_rebuild(self) -> None: + cur = self.connection.cursor() + last_order = -1 + while last_order is not None: + cur.execute( + f"SELECT item_order FROM {self.table_name} WHERE item_order > ? ORDER BY item_order LIMIT 1", + (last_order,), + ) + res = cur.fetchone() + if res is None: + break + i = res[0] + cur.execute( + f"SELECT doc_id, serialized_value FROM {self.table_name} WHERE item_order=?", + (i,), + ) + doc_id, serialized_value = cur.fetchone() + cur.execute( + f"UPDATE {self.table_name} SET doc_id=?, serialized_value=? WHERE item_order=?", + ( + doc_id, + serialized_value, + i, + ), + ) + last_order = i + + def serialize_value(self, value: VT) -> bytes: + return self.value_serializer(value) + + def deserialize_value(self, value: bytes) -> VT: + return self.value_deserializer(value) + + def __delitem__(self, key: KT) -> None: + doc_id = self.serialize_key(key) + cur = self.connection.cursor() + if not self._driver_class.is_doc_id_in(self.table_name, cur, doc_id): + raise KeyError(key) + self._driver_class.delete_single_record_by_doc_id(self.table_name, cur, doc_id) + self.connection.commit() + + def __getitem__(self, key: KT) -> VT: + doc_id = self.serialize_key(key) + cur = self.connection.cursor() + serialized_value = self._driver_class.get_serialized_value_by_doc_id( + self.table_name, cur, doc_id + ) + if serialized_value is None: + raise KeyError(key) + return self.deserialize_value(serialized_value) + + def __iter__(self) -> Iterator[KT]: + cur = self.connection.cursor() + for doc_id in self._driver_class.get_doc_ids(self.table_name, cur): + yield self.deserialize_key(doc_id) + + def __len__(self) -> int: + cur = self.connection.cursor() + return self._driver_class.get_count(self.table_name, cur) + + def __setitem__(self, key: KT, value: VT) -> None: + doc_id = self.serialize_key(key) + cur = self.connection.cursor() + serialized_value = self.serialize_value(value) + self._driver_class.upsert(self.table_name, cur, doc_id, serialized_value) + self.connection.commit() + + def _create_volatile_copy( + self, + data: Optional[Mapping[KT, VT]] = None, + ) -> "Dict[KT, VT]": + + return Dict[KT, VT]( + connection=self.connection, + key_serializer=self.key_serializer, + key_deserializer=self.key_deserializer, + value_serializer=self.value_serializer, + value_deserializer=self.value_deserializer, + rebuild_strategy=RebuildStrategy.SKIP, + persist=False, + data=(self if data is None else data), + ) + + def copy(self) -> "Dict[KT, VT]": + return self._create_volatile_copy() + + @classmethod + def fromkeys(cls, iterable: Iterable[KT], value: Optional[VT]) -> "Dict[KT, VT]": + raise NotImplementedError + + @overload + def pop(self, k: KT) -> VT: + ... + + @overload + def pop(self, k: KT, default: Union[VT, T] = ...) -> Union[VT, T]: + ... + + def pop(self, k: KT, default: Optional[Union[VT, object]] = None) -> Union[VT, object]: + cur = self.connection.cursor() + doc_id = self.serialize_key(k) + serialized_value = self._driver_class.get_serialized_value_by_doc_id( + self.table_name, cur, doc_id + ) + if serialized_value is None: + if default is None: + raise KeyError(k) + return default + self._driver_class.delete_single_record_by_doc_id(self.table_name, cur, doc_id) + self.connection.commit() + return self.deserialize_value(serialized_value) + + def popitem(self) -> Tuple[KT, VT]: + cur = self.connection.cursor() + serialized_item = self._driver_class.get_last_serialized_item(self.table_name, cur) + if serialized_item is None: + raise KeyError("popitem(): dictionary is empty") + self._driver_class.delete_single_record_by_doc_id(self.table_name, cur, serialized_item[0]) + self.connection.commit() + return ( + self.deserialize_key(serialized_item[0]), + self.deserialize_value(serialized_item[1]), + ) + + @overload + def update(self, __other: Mapping[KT, VT], **kwargs: VT) -> None: + ... + + @overload + def update(self, __other: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: + ... + + @overload + def update(self, **kwargs: VT) -> None: + ... + + def update(self, __other: Optional[Union[Iterable[Tuple[KT, VT]], Mapping[KT, VT]]] = None, **kwargs: VT) -> None: + cur = self.connection.cursor() + for k, v in chain( + tuple() if __other is None else __other.items() if isinstance(__other, Mapping) else __other, + cast(Mapping[KT, VT], kwargs).items(), + ): + self._driver_class.upsert(self.table_name, cur, self.serialize_key(k), self.serialize_value(v)) + self.connection.commit() + + def clear(self) -> None: + cur = self.connection.cursor() + self._driver_class.delete_all_records(self.table_name, cur) + self.connection.commit() + + def __contains__(self, o: object) -> bool: + return self._driver_class.is_doc_id_in( + self.table_name, self.connection.cursor(), self.serialize_key(cast(KT, o)) + ) + + @overload + def get(self, key: KT) -> Union[VT, None]: + ... + + @overload + def get(self, key: KT, default_value: Union[VT, T]) -> Union[VT, T]: + ... + + def get(self, key: KT, default_value: Optional[Union[VT, object]] = None) -> Union[VT, None, object]: + doc_id = self.serialize_key(key) + cur = self.connection.cursor() + serialized_value = self._driver_class.get_serialized_value_by_doc_id( + self.table_name, cur, doc_id + ) + if serialized_value is None: + return default_value + return self.deserialize_value(serialized_value) + + def setdefault(self, key: KT, default: VT = None) -> VT: # type: ignore + doc_id = self.serialize_key(key) + cur = self.connection.cursor() + serialized_value = self._driver_class.get_serialized_value_by_doc_id( + self.table_name, cur, doc_id + ) + if serialized_value is None: + self._driver_class.insert_serialized_value_by_doc_id( + self.table_name, cur, doc_id, self.serialize_value(default) + ) + return default + return self.deserialize_value(serialized_value) + + @property + def schema_version(self) -> str: + return '0' \ No newline at end of file From 109b1d29b91f036281d247f39d6cb38409ae0d19 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 18 Jan 2022 07:54:53 +0100 Subject: [PATCH 03/23] chore: fix typo --- docarray/array/persistence/sqlite/base.py | 68 ++---- docarray/array/persistence/sqlite/dict.py | 42 ++-- docarray/array/persistence/sqlite/mixin.py | 241 ++++----------------- 3 files changed, 87 insertions(+), 264 deletions(-) diff --git a/docarray/array/persistence/sqlite/base.py b/docarray/array/persistence/sqlite/base.py index 9c85b461a32..05c1d52b5ad 100644 --- a/docarray/array/persistence/sqlite/base.py +++ b/docarray/array/persistence/sqlite/base.py @@ -1,11 +1,8 @@ import sqlite3 import warnings from abc import ABCMeta, abstractmethod -from collections.abc import Hashable -from enum import Enum -from pickle import dumps, loads from tempfile import NamedTemporaryFile -from typing import Callable, Generic, Optional, TypeVar, Union, cast +from typing import Callable, Generic, Optional, TypeVar, Union from uuid import uuid4 T = TypeVar('T') @@ -15,12 +12,6 @@ _S = TypeVar('_S') -class RebuildStrategy(Enum): - CHECK_WITH_FIRST_ELEMENT = 1 - ALWAYS = 2 - SKIP = 3 - - def sanitize_table_name(table_name: str) -> str: ret = ''.join(c for c in table_name if c.isalnum() or c == '_') if ret != table_name: @@ -32,10 +23,6 @@ def create_random_name(suffix: str) -> str: return f"{suffix}_{str(uuid4()).replace('-', '')}" -def is_hashable(x: object) -> bool: - return isinstance(x, Hashable) - - class _SqliteCollectionBaseDatabaseDriver(metaclass=ABCMeta): @classmethod def initialize_metadata_table(cls, cur: sqlite3.Cursor) -> None: @@ -162,15 +149,12 @@ def __init__( serializer: Optional[Callable[[T], bytes]] = None, deserializer: Optional[Callable[[bytes], T]] = None, persist: bool = True, - rebuild_strategy: RebuildStrategy = RebuildStrategy.CHECK_WITH_FIRST_ELEMENT, ): super(SqliteCollectionBase, self).__init__() - self._serializer = ( - cast(Callable[[T], bytes], dumps) if serializer is None else serializer - ) - self._deserializer = ( - cast(Callable[[bytes], T], loads) if deserializer is None else deserializer - ) + from ....document import Document + + self._serializer = serializer or (lambda d: d.to_bytes()) + self._deserializer = deserializer or (lambda x: Document.from_bytes(x)) self._persist = persist if connection is None: self._connection = sqlite3.connect(NamedTemporaryFile().name) @@ -187,7 +171,12 @@ def __init__( if table_name is None else sanitize_table_name(table_name) ) - self._initialize(rebuild_strategy=rebuild_strategy) + cur = self.connection.cursor() + self._driver_class.initialize_metadata_table(cur) + self._driver_class.initialize_table( + self.table_name, self.container_type_name, self.schema_version, cur + ) + self.connection.commit() def __del__(self) -> None: if not self.persist: @@ -197,31 +186,6 @@ def __del__(self) -> None: ) self.connection.commit() - def _initialize(self, rebuild_strategy: RebuildStrategy) -> None: - cur = self.connection.cursor() - self._driver_class.initialize_metadata_table(cur) - self._driver_class.initialize_table( - self.table_name, self.container_type_name, self.schema_version, cur - ) - if self._should_rebuild(rebuild_strategy): - self._do_rebuild() - self.connection.commit() - - def _should_rebuild(self, rebuild_strategy: RebuildStrategy) -> bool: - if rebuild_strategy == RebuildStrategy.ALWAYS: - return True - if rebuild_strategy == RebuildStrategy.SKIP: - return False - return self._rebuild_check_with_first_element() - - @abstractmethod - def _rebuild_check_with_first_element(self) -> bool: - ... - - @abstractmethod - def _do_rebuild(self) -> None: - ... - @property def persist(self) -> bool: return self._persist @@ -253,14 +217,18 @@ def table_name(self, table_name: str) -> None: new_table_name = sanitize_table_name(table_name) try: self._driver_class.alter_table_name(self.table_name, new_table_name, cur) - except sqlite3.IntegrityError as e: - raise ValueError(table_name) - self._table_name = new_table_name + self._table_name = new_table_name + except sqlite3.IntegrityError as ex: + raise ValueError(table_name) from ex @property def connection(self) -> sqlite3.Connection: return self._connection + @property + def cursor(self) -> sqlite3.Cursor: + return self.connection.cursor() + @property def container_type_name(self) -> str: return self.__class__.__name__ diff --git a/docarray/array/persistence/sqlite/dict.py b/docarray/array/persistence/sqlite/dict.py index a0ea8737708..365df0d00d0 100644 --- a/docarray/array/persistence/sqlite/dict.py +++ b/docarray/array/persistence/sqlite/dict.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union, cast, Iterable, TYPE_CHECKING +from typing import Tuple, Union, cast, Iterable, TYPE_CHECKING, Optional from .base import _SqliteCollectionBaseDatabaseDriver @@ -22,25 +22,39 @@ def do_create_table( 'item_order INTEGER PRIMARY KEY)' ) + @classmethod + def get_max_index_plus_one(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: + cur.execute(f'SELECT MAX(item_order) FROM {table_name}') + res = cur.fetchone() + if res[0] is None: + return 0 + return cast(int, res[0]) + 1 + + @classmethod + def increment_indices( + cls, table_name: str, cur: 'sqlite3.Cursor', start: int + ) -> None: + idx = cls.get_max_index_plus_one(table_name, cur) - 1 + while idx >= start: + cur.execute( + f"UPDATE {table_name} SET item_index = ? WHERE item_index = ?", + (idx + 1, idx), + ) + idx -= 1 + @classmethod def delete_single_record_by_doc_id( cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str ) -> None: - cur.execute( - f'DELETE FROM {table_name} WHERE doc_id=?', (doc_id,) - ) + cur.execute(f'DELETE FROM {table_name} WHERE doc_id=?', (doc_id,)) @classmethod def delete_all_records(cls, table_name: str, cur: 'sqlite3.Cursor') -> None: cur.execute(f'DELETE FROM {table_name}') @classmethod - def is_doc_id_in( - cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str - ) -> bool: - cur.execute( - f'SELECT 1 FROM {table_name} WHERE doc_id=?', (doc_id,) - ) + def is_doc_id_in(cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str) -> bool: + cur.execute(f'SELECT 1 FROM {table_name} WHERE doc_id=?', (doc_id,)) return len(list(cur)) > 0 @classmethod @@ -71,9 +85,7 @@ def get_count(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: return cast(int, res[0]) @classmethod - def get_doc_ids( - cls, table_name: str, cur: 'sqlite3.Cursor' - ) -> Iterable[str]: + def get_doc_ids(cls, table_name: str, cur: 'sqlite3.Cursor') -> Iterable[str]: cur.execute(f'SELECT doc_id FROM {table_name} ORDER BY item_order') for res in cur: yield cast(str, res[0]) @@ -85,8 +97,10 @@ def insert_serialized_value_by_doc_id( cur: 'sqlite3.Cursor', doc_id: str, serialized_value: bytes, + item_order: Optional[int], ) -> None: - item_order = cls.get_next_order(table_name, cur) + if item_order is None: + item_order = cls.get_next_order(table_name, cur) cur.execute( f'INSERT INTO {table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', (doc_id, serialized_value, item_order), diff --git a/docarray/array/persistence/sqlite/mixin.py b/docarray/array/persistence/sqlite/mixin.py index b58d7889d8e..e0f98ab0b5e 100644 --- a/docarray/array/persistence/sqlite/mixin.py +++ b/docarray/array/persistence/sqlite/mixin.py @@ -1,16 +1,16 @@ from dataclasses import dataclass import dataclasses -from typing import Optional, TYPE_CHECKING, Callable, Union, cast +from typing import Optional, TYPE_CHECKING, Callable, Union, cast, Iterable -from .base import SqliteCollectionBase, RebuildStrategy +from .base import SqliteCollectionBase from .dict import _DictDatabaseDriver if TYPE_CHECKING: import sqlite3 from ....types import ( - T, - Document, + T, + Document, DocumentArraySourceType, DocumentArrayIndexType, DocumentArraySingletonIndexType, @@ -19,6 +19,7 @@ DocumentArraySingleAttributeType, ) + @dataclass class SqliteConfig: connection: Optional[Union[str, 'sqlite3.Connection']] = None @@ -26,7 +27,6 @@ class SqliteConfig: serializer: Optional[Callable[['Document'], bytes]] = None deserializer: Optional[Callable[[bytes], 'Document']] = None persist: bool = True - rebuild_strategy: RebuildStrategy = RebuildStrategy.CHECK_WITH_FIRST_ELEMENT class SqliteMixin(SqliteCollectionBase): @@ -37,215 +37,56 @@ class SqliteMixin(SqliteCollectionBase): i.e. `class SqliteDA(SqliteMixin, DA)` not the other way around. """ + _driver_class = _DictDatabaseDriver - def __init__(self, docs: Optional['DocumentArraySourceType'] = None, config: Optional[SqliteConfig] = None): + def __init__( + self, + docs: Optional['DocumentArraySourceType'] = None, + config: Optional[SqliteConfig] = None, + ): super().__init__(**(dataclasses.asdict(config) if config else {})) if docs is not None: self.clear() - self.update(docs) - - def clear(self) -> None: - cur = self.connection.cursor() - self._driver_class.delete_all_records(self.table_name, cur) - self.connection.commit() - - @property - def schema_version(self) -> str: - return "0" - - def _rebuild_check_with_first_element(self) -> bool: - cur = self.connection.cursor() - cur.execute(f"SELECT doc_id FROM {self.table_name} ORDER BY item_order LIMIT 1") - res = cur.fetchone() - return res is None - - def _do_rebuild(self) -> None: - cur = self.connection.cursor() - last_order = -1 - while last_order is not None: - cur.execute( - f"SELECT item_order FROM {self.table_name} WHERE item_order > ? ORDER BY item_order LIMIT 1", - (last_order,), - ) - res = cur.fetchone() - if res is None: - break - i = res[0] - cur.execute( - f"SELECT doc_id, serialized_value FROM {self.table_name} WHERE item_order=?", - (i,), - ) - doc_id, serialized_value = cur.fetchone() - cur.execute( - f"UPDATE {self.table_name} SET doc_id=?, serialized_value=? WHERE item_order=?", - ( - doc_id, - serialized_value, - i, - ), - ) - last_order = i - - def serialize_value(self, value: VT) -> bytes: - return self.value_serializer(value) - - def deserialize_value(self, value: bytes) -> VT: - return self.value_deserializer(value) - - def __delitem__(self, key: KT) -> None: - doc_id = self.serialize_key(key) - cur = self.connection.cursor() - if not self._driver_class.is_doc_id_in(self.table_name, cur, doc_id): - raise KeyError(key) - self._driver_class.delete_single_record_by_doc_id(self.table_name, cur, doc_id) - self.connection.commit() - - def __getitem__(self, key: KT) -> VT: - doc_id = self.serialize_key(key) - cur = self.connection.cursor() - serialized_value = self._driver_class.get_serialized_value_by_doc_id( - self.table_name, cur, doc_id + self.extend(docs) + + def insert(self, index: int, value: 'Document'): + """Insert `doc` at `index`. + + :param index: Position of the insertion. + :param value: The doc needs to be inserted. + """ + length = self._driver_class.get_max_index_plus_one(self.table_name, self.cursor) + if index < 0: + index = length + index + index = max(0, min(length, index)) + self._driver_class.increment_indices(self.table_name, self.cursor, index) + self._driver_class.insert_serialized_value_by_doc_id( + self.table_name, self.cursor, value.id, self.serialize(value), index ) - if serialized_value is None: - raise KeyError(key) - return self.deserialize_value(serialized_value) - - def __iter__(self) -> Iterator[KT]: - cur = self.connection.cursor() - for doc_id in self._driver_class.get_doc_ids(self.table_name, cur): - yield self.deserialize_key(doc_id) - - def __len__(self) -> int: - cur = self.connection.cursor() - return self._driver_class.get_count(self.table_name, cur) - - def __setitem__(self, key: KT, value: VT) -> None: - doc_id = self.serialize_key(key) - cur = self.connection.cursor() - serialized_value = self.serialize_value(value) - self._driver_class.upsert(self.table_name, cur, doc_id, serialized_value) self.connection.commit() - def _create_volatile_copy( - self, - data: Optional[Mapping[KT, VT]] = None, - ) -> "Dict[KT, VT]": - - return Dict[KT, VT]( - connection=self.connection, - key_serializer=self.key_serializer, - key_deserializer=self.key_deserializer, - value_serializer=self.value_serializer, - value_deserializer=self.value_deserializer, - rebuild_strategy=RebuildStrategy.SKIP, - persist=False, - data=(self if data is None else data), - ) - - def copy(self) -> "Dict[KT, VT]": - return self._create_volatile_copy() - - @classmethod - def fromkeys(cls, iterable: Iterable[KT], value: Optional[VT]) -> "Dict[KT, VT]": - raise NotImplementedError - - @overload - def pop(self, k: KT) -> VT: - ... - - @overload - def pop(self, k: KT, default: Union[VT, T] = ...) -> Union[VT, T]: - ... - - def pop(self, k: KT, default: Optional[Union[VT, object]] = None) -> Union[VT, object]: - cur = self.connection.cursor() - doc_id = self.serialize_key(k) - serialized_value = self._driver_class.get_serialized_value_by_doc_id( - self.table_name, cur, doc_id - ) - if serialized_value is None: - if default is None: - raise KeyError(k) - return default - self._driver_class.delete_single_record_by_doc_id(self.table_name, cur, doc_id) - self.connection.commit() - return self.deserialize_value(serialized_value) - - def popitem(self) -> Tuple[KT, VT]: - cur = self.connection.cursor() - serialized_item = self._driver_class.get_last_serialized_item(self.table_name, cur) - if serialized_item is None: - raise KeyError("popitem(): dictionary is empty") - self._driver_class.delete_single_record_by_doc_id(self.table_name, cur, serialized_item[0]) - self.connection.commit() - return ( - self.deserialize_key(serialized_item[0]), - self.deserialize_value(serialized_item[1]), - ) + def extend(self, values: Iterable['Document']) -> None: - @overload - def update(self, __other: Mapping[KT, VT], **kwargs: VT) -> None: - ... - - @overload - def update(self, __other: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: - ... - - @overload - def update(self, **kwargs: VT) -> None: - ... - - def update(self, __other: Optional[Union[Iterable[Tuple[KT, VT]], Mapping[KT, VT]]] = None, **kwargs: VT) -> None: - cur = self.connection.cursor() - for k, v in chain( - tuple() if __other is None else __other.items() if isinstance(__other, Mapping) else __other, - cast(Mapping[KT, VT], kwargs).items(), - ): - self._driver_class.upsert(self.table_name, cur, self.serialize_key(k), self.serialize_value(v)) + idx = self._driver_class.get_max_index_plus_one(self.table_name, self.cursor) + for v in values: + self._driver_class.insert_serialized_value_by_doc_id( + self.table_name, + cur=self.cursor, + doc_id=v.id, + serialized_value=self.serialize(v), + item_order=idx, + ) + idx += 1 self.connection.commit() def clear(self) -> None: - cur = self.connection.cursor() - self._driver_class.delete_all_records(self.table_name, cur) + self._driver_class.delete_all_records(self.table_name, self.cursor) self.connection.commit() - def __contains__(self, o: object) -> bool: - return self._driver_class.is_doc_id_in( - self.table_name, self.connection.cursor(), self.serialize_key(cast(KT, o)) - ) - - @overload - def get(self, key: KT) -> Union[VT, None]: - ... - - @overload - def get(self, key: KT, default_value: Union[VT, T]) -> Union[VT, T]: - ... - - def get(self, key: KT, default_value: Optional[Union[VT, object]] = None) -> Union[VT, None, object]: - doc_id = self.serialize_key(key) - cur = self.connection.cursor() - serialized_value = self._driver_class.get_serialized_value_by_doc_id( - self.table_name, cur, doc_id - ) - if serialized_value is None: - return default_value - return self.deserialize_value(serialized_value) - - def setdefault(self, key: KT, default: VT = None) -> VT: # type: ignore - doc_id = self.serialize_key(key) - cur = self.connection.cursor() - serialized_value = self._driver_class.get_serialized_value_by_doc_id( - self.table_name, cur, doc_id - ) - if serialized_value is None: - self._driver_class.insert_serialized_value_by_doc_id( - self.table_name, cur, doc_id, self.serialize_value(default) - ) - return default - return self.deserialize_value(serialized_value) + def __len__(self) -> int: + return self._driver_class.get_count(self.table_name, self.cursor) @property def schema_version(self) -> str: - return '0' \ No newline at end of file + return '0' From 2b34f0f437441a20d784753d40a38d19470a725b Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 18 Jan 2022 14:29:00 +0100 Subject: [PATCH 04/23] chore: fix typo --- docarray/array/persistence/sqlite/base.py | 52 +++---- docarray/array/persistence/sqlite/dict.py | 70 +--------- docarray/array/persistence/sqlite/mixin.py | 154 ++++++++++++++++++--- 3 files changed, 157 insertions(+), 119 deletions(-) diff --git a/docarray/array/persistence/sqlite/base.py b/docarray/array/persistence/sqlite/base.py index 05c1d52b5ad..80471318de0 100644 --- a/docarray/array/persistence/sqlite/base.py +++ b/docarray/array/persistence/sqlite/base.py @@ -2,7 +2,7 @@ import warnings from abc import ABCMeta, abstractmethod from tempfile import NamedTemporaryFile -from typing import Callable, Generic, Optional, TypeVar, Union +from typing import Callable, Generic, Optional, TypeVar, Union, Dict from uuid import uuid4 T = TypeVar('T') @@ -146,26 +146,35 @@ def __init__( self, connection: Optional[Union[str, sqlite3.Connection]] = None, table_name: Optional[str] = None, - serializer: Optional[Callable[[T], bytes]] = None, - deserializer: Optional[Callable[[bytes], T]] = None, - persist: bool = True, + serialize_config: Optional[Dict] = None, ): super(SqliteCollectionBase, self).__init__() - from ....document import Document + self._serialize_config = serialize_config or {} + self._persist = not table_name - self._serializer = serializer or (lambda d: d.to_bytes()) - self._deserializer = deserializer or (lambda x: Document.from_bytes(x)) - self._persist = persist + from docarray import Document + + sqlite3.register_adapter( + Document, lambda d: d.to_bytes(**self._serialize_config) + ) + sqlite3.register_converter( + 'Document', lambda x: Document.from_bytes(x, **self._serialize_config) + ) + + _conn_kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES) if connection is None: - self._connection = sqlite3.connect(NamedTemporaryFile().name) + self._connection = sqlite3.connect( + NamedTemporaryFile().name, **_conn_kwargs + ) elif isinstance(connection, str): - self._connection = sqlite3.connect(connection) + self._connection = sqlite3.connect(connection, **_conn_kwargs) elif isinstance(connection, sqlite3.Connection): self._connection = connection else: raise TypeError( f'connection argument must be None or a string or a sqlite3.Connection, not `{type(connection)}`' ) + self._table_name = ( sanitize_table_name(create_random_name(self.container_type_name)) if table_name is None @@ -179,34 +188,13 @@ def __init__( self.connection.commit() def __del__(self) -> None: - if not self.persist: + if not self._persist: cur = self.connection.cursor() self._driver_class.drop_table( self.table_name, self.container_type_name, cur ) self.connection.commit() - @property - def persist(self) -> bool: - return self._persist - - def set_persist(self, persist: bool) -> None: - self._persist = persist - - @property - def serializer(self) -> Callable[[T], bytes]: - return self._serializer - - def serialize(self, x: T) -> bytes: - return self.serializer(x) - - @property - def deserializer(self) -> Callable[[bytes], T]: - return self._deserializer - - def deserialize(self, blob: bytes) -> T: - return self.deserializer(blob) - @property def table_name(self) -> str: return self._table_name diff --git a/docarray/array/persistence/sqlite/dict.py b/docarray/array/persistence/sqlite/dict.py index 365df0d00d0..b63fd778b60 100644 --- a/docarray/array/persistence/sqlite/dict.py +++ b/docarray/array/persistence/sqlite/dict.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union, cast, Iterable, TYPE_CHECKING, Optional +from typing import Tuple, Union, cast, Iterable, TYPE_CHECKING from .base import _SqliteCollectionBaseDatabaseDriver @@ -18,58 +18,16 @@ def do_create_table( cur.execute( f'CREATE TABLE {table_name} (' 'doc_id TEXT NOT NULL UNIQUE, ' - 'serialized_value BLOB NOT NULL, ' + 'serialized_value Document NOT NULL, ' 'item_order INTEGER PRIMARY KEY)' ) - @classmethod - def get_max_index_plus_one(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: - cur.execute(f'SELECT MAX(item_order) FROM {table_name}') - res = cur.fetchone() - if res[0] is None: - return 0 - return cast(int, res[0]) + 1 - - @classmethod - def increment_indices( - cls, table_name: str, cur: 'sqlite3.Cursor', start: int - ) -> None: - idx = cls.get_max_index_plus_one(table_name, cur) - 1 - while idx >= start: - cur.execute( - f"UPDATE {table_name} SET item_index = ? WHERE item_index = ?", - (idx + 1, idx), - ) - idx -= 1 - @classmethod def delete_single_record_by_doc_id( cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str ) -> None: cur.execute(f'DELETE FROM {table_name} WHERE doc_id=?', (doc_id,)) - @classmethod - def delete_all_records(cls, table_name: str, cur: 'sqlite3.Cursor') -> None: - cur.execute(f'DELETE FROM {table_name}') - - @classmethod - def is_doc_id_in(cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str) -> bool: - cur.execute(f'SELECT 1 FROM {table_name} WHERE doc_id=?', (doc_id,)) - return len(list(cur)) > 0 - - @classmethod - def get_serialized_value_by_doc_id( - cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str - ) -> Union[None, bytes]: - cur.execute( - f'SELECT serialized_value FROM {table_name} WHERE doc_id=?', - (doc_id,), - ) - res = cur.fetchone() - if res is None: - return None - return cast(bytes, res[0]) - @classmethod def get_next_order(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: cur.execute(f'SELECT MAX(item_order) FROM {table_name}') @@ -78,33 +36,11 @@ def get_next_order(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: return 0 return cast(int, res) + 1 - @classmethod - def get_count(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: - cur.execute(f'SELECT COUNT(*) FROM {table_name}') - res = cur.fetchone() - return cast(int, res[0]) - @classmethod def get_doc_ids(cls, table_name: str, cur: 'sqlite3.Cursor') -> Iterable[str]: cur.execute(f'SELECT doc_id FROM {table_name} ORDER BY item_order') for res in cur: - yield cast(str, res[0]) - - @classmethod - def insert_serialized_value_by_doc_id( - cls, - table_name: str, - cur: 'sqlite3.Cursor', - doc_id: str, - serialized_value: bytes, - item_order: Optional[int], - ) -> None: - if item_order is None: - item_order = cls.get_next_order(table_name, cur) - cur.execute( - f'INSERT INTO {table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', - (doc_id, serialized_value, item_order), - ) + yield res[0] @classmethod def update_serialized_value_by_doc_id( diff --git a/docarray/array/persistence/sqlite/mixin.py b/docarray/array/persistence/sqlite/mixin.py index e0f98ab0b5e..43156d184a7 100644 --- a/docarray/array/persistence/sqlite/mixin.py +++ b/docarray/array/persistence/sqlite/mixin.py @@ -1,9 +1,22 @@ +import itertools from dataclasses import dataclass import dataclasses -from typing import Optional, TYPE_CHECKING, Callable, Union, cast, Iterable +from typing import ( + Optional, + TYPE_CHECKING, + Callable, + Union, + cast, + Iterable, + Dict, + Iterator, + Sequence, +) from .base import SqliteCollectionBase from .dict import _DictDatabaseDriver +from ....helper import typename +import numpy as np if TYPE_CHECKING: import sqlite3 @@ -19,14 +32,14 @@ DocumentArraySingleAttributeType, ) + from docarray import DocumentArray + @dataclass class SqliteConfig: connection: Optional[Union[str, 'sqlite3.Connection']] = None table_name: Optional[str] = None - serializer: Optional[Callable[['Document'], bytes]] = None - deserializer: Optional[Callable[[bytes], 'Document']] = None - persist: bool = True + serialize_config: Optional[Dict] = None class SqliteMixin(SqliteCollectionBase): @@ -56,37 +69,138 @@ def insert(self, index: int, value: 'Document'): :param index: Position of the insertion. :param value: The doc needs to be inserted. """ - length = self._driver_class.get_max_index_plus_one(self.table_name, self.cursor) + length = len(self) if index < 0: index = length + index index = max(0, min(length, index)) - self._driver_class.increment_indices(self.table_name, self.cursor, index) - self._driver_class.insert_serialized_value_by_doc_id( - self.table_name, self.cursor, value.id, self.serialize(value), index - ) + self._shift_index_right_backward(index) + self._insert_doc_at_idx(doc=value, idx=index) self.connection.commit() - def extend(self, values: Iterable['Document']) -> None: + def append(self, value: 'Document') -> None: + self._insert_doc_at_idx(value) + self.connection.commit() - idx = self._driver_class.get_max_index_plus_one(self.table_name, self.cursor) + def extend(self, values: Iterable['Document']) -> None: + idx = len(self) for v in values: - self._driver_class.insert_serialized_value_by_doc_id( - self.table_name, - cur=self.cursor, - doc_id=v.id, - serialized_value=self.serialize(v), - item_order=idx, - ) + self._insert_doc_at_idx(v, idx) idx += 1 self.connection.commit() def clear(self) -> None: - self._driver_class.delete_all_records(self.table_name, self.cursor) + self._sql(f'DELETE FROM {self.table_name}') self.connection.commit() + def __contains__(self, item: Union[str, 'Document']): + if isinstance(item, str): + r = self._sql(f"SELECT 1 FROM {self.table_name} WHERE doc_id=?", (item,)) + return len(list(r)) > 0 + elif isinstance(item, Document): + return item.id in self # fall back to str check + else: + return False + def __len__(self) -> int: - return self._driver_class.get_count(self.table_name, self.cursor) + r = self._sql(f'SELECT COUNT(*) FROM {self.table_name}') + return r.fetchone()[0] + + def __iter__(self) -> Iterator['Document']: + r = self._sql( + f'SELECT serialized_value FROM {self.table_name} ORDER BY item_order' + ) + for res in r: + yield res[0] + + def __getitem__( + self, index: 'DocumentArrayIndexType' + ) -> Union['Document', 'DocumentArray']: + if isinstance(index, (int, np.generic)) and not isinstance(index, bool): + return self._get_doc_by_offset(int(index)) + elif isinstance(index, str): + if index.startswith('@'): + return self.traverse_flat(index[1:]) + else: + return self._get_doc_by_id(index) + elif isinstance(index, slice): + from docarray import DocumentArray + + return DocumentArray(self[j] for j in range(len(self))[index]) + elif index is Ellipsis: + return self.flatten() + elif isinstance(index, Sequence): + from docarray import DocumentArray + + if ( + isinstance(index, tuple) + and len(index) == 2 + and isinstance(index[0], (slice, Sequence)) + ): + if isinstance(index[0], str) and isinstance(index[1], str): + # ambiguity only comes from the second string + if index[1] in self: + return DocumentArray([self[index[0]], self[index[1]]]) + else: + return getattr(self[index[0]], index[1]) + elif isinstance(index[0], (slice, Sequence)): + _docs = self[index[0]] + _attrs = index[1] + if isinstance(_attrs, str): + _attrs = (index[1],) + return _docs._get_attributes(*_attrs) + elif isinstance(index[0], bool): + return DocumentArray(itertools.compress(self, index)) + elif isinstance(index[0], (int, str)): + return DocumentArray(self[t] for t in index) + elif isinstance(index, np.ndarray): + index = index.squeeze() + if index.ndim == 1: + return self[index.tolist()] + else: + raise IndexError( + f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}' + ) + raise IndexError(f'Unsupported index type {typename(index)}: {index}') + + def _get_doc_by_offset(self, index: int) -> 'Document': + r = self._sql( + f"SELECT serialized_value FROM {self.table_name} WHERE item_order = ?", + (index + (len(self) if index < 0 else 0),), + ) + res = r.fetchone() + if res is None: + raise IndexError('index out of range') + return res[0] + + def _get_doc_by_id(self, id: str) -> 'Document': + r = self._sql( + f"SELECT serialized_value FROM {self.table_name} WHERE doc_id = ?", (id,) + ) + res = r.fetchone() + if res is None: + raise KeyError(f'Can not find Document with id=`{id}`') + return res[0] @property def schema_version(self) -> str: return '0' + + def _sql(self, *arg, **kwargs) -> 'sqlite3.Cursor': + return self.connection.cursor().execute(*arg, **kwargs) + + def _insert_doc_at_idx(self, doc, idx: Optional[int] = None): + if idx is None: + idx = len(self) + self._sql( + f'INSERT INTO {self.table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', + (doc.id, doc, idx), + ) + + def _shift_index_right_backward(self, start: int): + idx = len(self) - 1 + while idx >= start: + self._sql( + f"UPDATE {self.table_name} SET item_order = ? WHERE item_order = ?", + (idx + 1, idx), + ) + idx -= 1 From 1313bb0f82f8072e66c0cafbdd4dde942542abaa Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 18 Jan 2022 16:26:43 +0100 Subject: [PATCH 05/23] fix(document): serialize tag value in the correct priority --- docarray/array/persistence/sqlite/mixin.py | 61 +++++----------------- 1 file changed, 12 insertions(+), 49 deletions(-) diff --git a/docarray/array/persistence/sqlite/mixin.py b/docarray/array/persistence/sqlite/mixin.py index 43156d184a7..53738d96f80 100644 --- a/docarray/array/persistence/sqlite/mixin.py +++ b/docarray/array/persistence/sqlite/mixin.py @@ -112,55 +112,18 @@ def __iter__(self) -> Iterator['Document']: for res in r: yield res[0] - def __getitem__( - self, index: 'DocumentArrayIndexType' - ) -> Union['Document', 'DocumentArray']: - if isinstance(index, (int, np.generic)) and not isinstance(index, bool): - return self._get_doc_by_offset(int(index)) - elif isinstance(index, str): - if index.startswith('@'): - return self.traverse_flat(index[1:]) - else: - return self._get_doc_by_id(index) - elif isinstance(index, slice): - from docarray import DocumentArray - - return DocumentArray(self[j] for j in range(len(self))[index]) - elif index is Ellipsis: - return self.flatten() - elif isinstance(index, Sequence): - from docarray import DocumentArray - - if ( - isinstance(index, tuple) - and len(index) == 2 - and isinstance(index[0], (slice, Sequence)) - ): - if isinstance(index[0], str) and isinstance(index[1], str): - # ambiguity only comes from the second string - if index[1] in self: - return DocumentArray([self[index[0]], self[index[1]]]) - else: - return getattr(self[index[0]], index[1]) - elif isinstance(index[0], (slice, Sequence)): - _docs = self[index[0]] - _attrs = index[1] - if isinstance(_attrs, str): - _attrs = (index[1],) - return _docs._get_attributes(*_attrs) - elif isinstance(index[0], bool): - return DocumentArray(itertools.compress(self, index)) - elif isinstance(index[0], (int, str)): - return DocumentArray(self[t] for t in index) - elif isinstance(index, np.ndarray): - index = index.squeeze() - if index.ndim == 1: - return self[index.tolist()] - else: - raise IndexError( - f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}' - ) - raise IndexError(f'Unsupported index type {typename(index)}: {index}') + def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: + return self._get_docs_by_offsets(range(len(self))[_slice]) + + def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: + l = len(self) + offsets = [o + (l if o < 0 else 0) for o in offsets] + r = self._sql( + f"SELECT serialized_value FROM {self.table_name} WHERE item_order in ({','.join(['?']*len(offsets))})", + offsets, + ) + for rr in r: + yield rr[0] def _get_doc_by_offset(self, index: int) -> 'Document': r = self._sql( From 5a9aa86b8a82a99269059779b1c0a7fb19a93054 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Wed, 19 Jan 2022 11:44:42 +0100 Subject: [PATCH 06/23] feat(array): add storage backend --- docarray/array/persistence/__init__.py | 0 docarray/array/persistence/sqlite/__init__.py | 0 docarray/array/{persistence => storage}/sqlite/base.py | 0 docarray/array/{persistence => storage}/sqlite/dict.py | 0 docarray/array/{persistence => storage}/sqlite/mixin.py | 0 5 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 docarray/array/persistence/__init__.py delete mode 100644 docarray/array/persistence/sqlite/__init__.py rename docarray/array/{persistence => storage}/sqlite/base.py (100%) rename docarray/array/{persistence => storage}/sqlite/dict.py (100%) rename docarray/array/{persistence => storage}/sqlite/mixin.py (100%) diff --git a/docarray/array/persistence/__init__.py b/docarray/array/persistence/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/docarray/array/persistence/sqlite/__init__.py b/docarray/array/persistence/sqlite/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/docarray/array/persistence/sqlite/base.py b/docarray/array/storage/sqlite/base.py similarity index 100% rename from docarray/array/persistence/sqlite/base.py rename to docarray/array/storage/sqlite/base.py diff --git a/docarray/array/persistence/sqlite/dict.py b/docarray/array/storage/sqlite/dict.py similarity index 100% rename from docarray/array/persistence/sqlite/dict.py rename to docarray/array/storage/sqlite/dict.py diff --git a/docarray/array/persistence/sqlite/mixin.py b/docarray/array/storage/sqlite/mixin.py similarity index 100% rename from docarray/array/persistence/sqlite/mixin.py rename to docarray/array/storage/sqlite/mixin.py From 5b353636f10bc54b07e452309c03db631cdfb6b7 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Wed, 19 Jan 2022 12:18:39 +0100 Subject: [PATCH 07/23] feat(array): add storage backend --- docarray/array/storage/sqlite/mixin.py | 169 ------------------------- 1 file changed, 169 deletions(-) delete mode 100644 docarray/array/storage/sqlite/mixin.py diff --git a/docarray/array/storage/sqlite/mixin.py b/docarray/array/storage/sqlite/mixin.py deleted file mode 100644 index 53738d96f80..00000000000 --- a/docarray/array/storage/sqlite/mixin.py +++ /dev/null @@ -1,169 +0,0 @@ -import itertools -from dataclasses import dataclass -import dataclasses -from typing import ( - Optional, - TYPE_CHECKING, - Callable, - Union, - cast, - Iterable, - Dict, - Iterator, - Sequence, -) - -from .base import SqliteCollectionBase -from .dict import _DictDatabaseDriver -from ....helper import typename -import numpy as np - -if TYPE_CHECKING: - import sqlite3 - - from ....types import ( - T, - Document, - DocumentArraySourceType, - DocumentArrayIndexType, - DocumentArraySingletonIndexType, - DocumentArrayMultipleIndexType, - DocumentArrayMultipleAttributeType, - DocumentArraySingleAttributeType, - ) - - from docarray import DocumentArray - - -@dataclass -class SqliteConfig: - connection: Optional[Union[str, 'sqlite3.Connection']] = None - table_name: Optional[str] = None - serialize_config: Optional[Dict] = None - - -class SqliteMixin(SqliteCollectionBase): - """Enable SQLite persistence backend for DocumentArray. - - .. note:: - This has to be put in the first position when use it for subclassing - i.e. `class SqliteDA(SqliteMixin, DA)` not the other way around. - - """ - - _driver_class = _DictDatabaseDriver - - def __init__( - self, - docs: Optional['DocumentArraySourceType'] = None, - config: Optional[SqliteConfig] = None, - ): - super().__init__(**(dataclasses.asdict(config) if config else {})) - if docs is not None: - self.clear() - self.extend(docs) - - def insert(self, index: int, value: 'Document'): - """Insert `doc` at `index`. - - :param index: Position of the insertion. - :param value: The doc needs to be inserted. - """ - length = len(self) - if index < 0: - index = length + index - index = max(0, min(length, index)) - self._shift_index_right_backward(index) - self._insert_doc_at_idx(doc=value, idx=index) - self.connection.commit() - - def append(self, value: 'Document') -> None: - self._insert_doc_at_idx(value) - self.connection.commit() - - def extend(self, values: Iterable['Document']) -> None: - idx = len(self) - for v in values: - self._insert_doc_at_idx(v, idx) - idx += 1 - self.connection.commit() - - def clear(self) -> None: - self._sql(f'DELETE FROM {self.table_name}') - self.connection.commit() - - def __contains__(self, item: Union[str, 'Document']): - if isinstance(item, str): - r = self._sql(f"SELECT 1 FROM {self.table_name} WHERE doc_id=?", (item,)) - return len(list(r)) > 0 - elif isinstance(item, Document): - return item.id in self # fall back to str check - else: - return False - - def __len__(self) -> int: - r = self._sql(f'SELECT COUNT(*) FROM {self.table_name}') - return r.fetchone()[0] - - def __iter__(self) -> Iterator['Document']: - r = self._sql( - f'SELECT serialized_value FROM {self.table_name} ORDER BY item_order' - ) - for res in r: - yield res[0] - - def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: - return self._get_docs_by_offsets(range(len(self))[_slice]) - - def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: - l = len(self) - offsets = [o + (l if o < 0 else 0) for o in offsets] - r = self._sql( - f"SELECT serialized_value FROM {self.table_name} WHERE item_order in ({','.join(['?']*len(offsets))})", - offsets, - ) - for rr in r: - yield rr[0] - - def _get_doc_by_offset(self, index: int) -> 'Document': - r = self._sql( - f"SELECT serialized_value FROM {self.table_name} WHERE item_order = ?", - (index + (len(self) if index < 0 else 0),), - ) - res = r.fetchone() - if res is None: - raise IndexError('index out of range') - return res[0] - - def _get_doc_by_id(self, id: str) -> 'Document': - r = self._sql( - f"SELECT serialized_value FROM {self.table_name} WHERE doc_id = ?", (id,) - ) - res = r.fetchone() - if res is None: - raise KeyError(f'Can not find Document with id=`{id}`') - return res[0] - - @property - def schema_version(self) -> str: - return '0' - - def _sql(self, *arg, **kwargs) -> 'sqlite3.Cursor': - return self.connection.cursor().execute(*arg, **kwargs) - - def _insert_doc_at_idx(self, doc, idx: Optional[int] = None): - if idx is None: - idx = len(self) - self._sql( - f'INSERT INTO {self.table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', - (doc.id, doc, idx), - ) - - def _shift_index_right_backward(self, start: int): - idx = len(self) - 1 - while idx >= start: - self._sql( - f"UPDATE {self.table_name} SET item_order = ? WHERE item_order = ?", - (idx + 1, idx), - ) - idx -= 1 From 6a2789cbe146ea03d1cc8a878cb9af5a0b3773bb Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Thu, 20 Jan 2022 16:37:04 +0100 Subject: [PATCH 08/23] feat(array): add storage backend --- docarray/array/storage/sqlite/base.py | 227 -------------------------- docarray/array/storage/sqlite/dict.py | 90 ---------- 2 files changed, 317 deletions(-) delete mode 100644 docarray/array/storage/sqlite/base.py delete mode 100644 docarray/array/storage/sqlite/dict.py diff --git a/docarray/array/storage/sqlite/base.py b/docarray/array/storage/sqlite/base.py deleted file mode 100644 index 80471318de0..00000000000 --- a/docarray/array/storage/sqlite/base.py +++ /dev/null @@ -1,227 +0,0 @@ -import sqlite3 -import warnings -from abc import ABCMeta, abstractmethod -from tempfile import NamedTemporaryFile -from typing import Callable, Generic, Optional, TypeVar, Union, Dict -from uuid import uuid4 - -T = TypeVar('T') -KT = TypeVar('KT') -VT = TypeVar('VT') -_T = TypeVar('_T') -_S = TypeVar('_S') - - -def sanitize_table_name(table_name: str) -> str: - ret = ''.join(c for c in table_name if c.isalnum() or c == '_') - if ret != table_name: - warnings.warn(f'The table name is changed to {ret} due to illegal characters') - return ret - - -def create_random_name(suffix: str) -> str: - return f"{suffix}_{str(uuid4()).replace('-', '')}" - - -class _SqliteCollectionBaseDatabaseDriver(metaclass=ABCMeta): - @classmethod - def initialize_metadata_table(cls, cur: sqlite3.Cursor) -> None: - if not cls.is_metadata_table_initialized(cur): - cls.do_initialize_metadata_table(cur) - - @classmethod - def is_metadata_table_initialized(cls, cur: sqlite3.Cursor) -> bool: - try: - cur.execute('SELECT 1 FROM metadata LIMIT 1') - _ = list(cur) - return True - except sqlite3.OperationalError as _: - pass - return False - - @classmethod - def do_initialize_metadata_table(cls, cur: sqlite3.Cursor) -> None: - cur.execute( - ''' - CREATE TABLE metadata ( - table_name TEXT PRIMARY KEY, - schema_version TEXT NOT NULL, - container_type TEXT NOT NULL, - UNIQUE (table_name, container_type) - ) - ''' - ) - - @classmethod - def initialize_table( - cls, - table_name: str, - container_type_name: str, - schema_version: str, - cur: sqlite3.Cursor, - ) -> None: - if not cls.is_table_initialized( - table_name, container_type_name, schema_version, cur - ): - cls.do_create_table(table_name, container_type_name, schema_version, cur) - cls.do_tidy_table_metadata( - table_name, container_type_name, schema_version, cur - ) - - @classmethod - def is_table_initialized( - self, - table_name: str, - container_type_name: str, - schema_version: str, - cur: sqlite3.Cursor, - ) -> bool: - try: - cur.execute( - 'SELECT schema_version FROM metadata WHERE table_name=? AND container_type=?', - (table_name, container_type_name), - ) - buf = cur.fetchone() - if buf is None: - return False - version = buf[0] - if version != schema_version: - return False - cur.execute(f'SELECT 1 FROM {table_name} LIMIT 1') - _ = list(cur) - return True - except sqlite3.OperationalError as _: - pass - return False - - @classmethod - def do_tidy_table_metadata( - cls, - table_name: str, - container_type_name: str, - schema_version: str, - cur: sqlite3.Cursor, - ) -> None: - cur.execute( - 'INSERT INTO metadata (table_name, schema_version, container_type) VALUES (?, ?, ?)', - (table_name, schema_version, container_type_name), - ) - - @classmethod - @abstractmethod - def do_create_table( - cls, - table_name: str, - container_type_name: str, - schema_version: str, - cur: sqlite3.Cursor, - ) -> None: - ... - - @classmethod - def drop_table( - cls, table_name: str, container_type_name: str, cur: sqlite3.Cursor - ) -> None: - cur.execute( - 'DELETE FROM metadata WHERE table_name=? AND container_type=?', - (table_name, container_type_name), - ) - cur.execute(f'DROP TABLE {table_name}') - - @classmethod - def alter_table_name( - cls, table_name: str, new_table_name: str, cur: sqlite3.Cursor - ) -> None: - cur.execute( - 'UPDATE metadata SET table_name=? WHERE table_name=?', - (new_table_name, table_name), - ) - cur.execute(f'ALTER TABLE {table_name} RENAME TO {new_table_name}') - - -class SqliteCollectionBase(Generic[T], metaclass=ABCMeta): - _driver_class = _SqliteCollectionBaseDatabaseDriver - - def __init__( - self, - connection: Optional[Union[str, sqlite3.Connection]] = None, - table_name: Optional[str] = None, - serialize_config: Optional[Dict] = None, - ): - super(SqliteCollectionBase, self).__init__() - self._serialize_config = serialize_config or {} - self._persist = not table_name - - from docarray import Document - - sqlite3.register_adapter( - Document, lambda d: d.to_bytes(**self._serialize_config) - ) - sqlite3.register_converter( - 'Document', lambda x: Document.from_bytes(x, **self._serialize_config) - ) - - _conn_kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES) - if connection is None: - self._connection = sqlite3.connect( - NamedTemporaryFile().name, **_conn_kwargs - ) - elif isinstance(connection, str): - self._connection = sqlite3.connect(connection, **_conn_kwargs) - elif isinstance(connection, sqlite3.Connection): - self._connection = connection - else: - raise TypeError( - f'connection argument must be None or a string or a sqlite3.Connection, not `{type(connection)}`' - ) - - self._table_name = ( - sanitize_table_name(create_random_name(self.container_type_name)) - if table_name is None - else sanitize_table_name(table_name) - ) - cur = self.connection.cursor() - self._driver_class.initialize_metadata_table(cur) - self._driver_class.initialize_table( - self.table_name, self.container_type_name, self.schema_version, cur - ) - self.connection.commit() - - def __del__(self) -> None: - if not self._persist: - cur = self.connection.cursor() - self._driver_class.drop_table( - self.table_name, self.container_type_name, cur - ) - self.connection.commit() - - @property - def table_name(self) -> str: - return self._table_name - - @table_name.setter - def table_name(self, table_name: str) -> None: - cur = self.connection.cursor() - new_table_name = sanitize_table_name(table_name) - try: - self._driver_class.alter_table_name(self.table_name, new_table_name, cur) - self._table_name = new_table_name - except sqlite3.IntegrityError as ex: - raise ValueError(table_name) from ex - - @property - def connection(self) -> sqlite3.Connection: - return self._connection - - @property - def cursor(self) -> sqlite3.Cursor: - return self.connection.cursor() - - @property - def container_type_name(self) -> str: - return self.__class__.__name__ - - @property - @abstractmethod - def schema_version(self) -> str: - ... diff --git a/docarray/array/storage/sqlite/dict.py b/docarray/array/storage/sqlite/dict.py deleted file mode 100644 index b63fd778b60..00000000000 --- a/docarray/array/storage/sqlite/dict.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Tuple, Union, cast, Iterable, TYPE_CHECKING - -from .base import _SqliteCollectionBaseDatabaseDriver - -if TYPE_CHECKING: - import sqlite3 - - -class _DictDatabaseDriver(_SqliteCollectionBaseDatabaseDriver): - @classmethod - def do_create_table( - cls, - table_name: str, - container_type_nam: str, - schema_version: str, - cur: 'sqlite3.Cursor', - ) -> None: - cur.execute( - f'CREATE TABLE {table_name} (' - 'doc_id TEXT NOT NULL UNIQUE, ' - 'serialized_value Document NOT NULL, ' - 'item_order INTEGER PRIMARY KEY)' - ) - - @classmethod - def delete_single_record_by_doc_id( - cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str - ) -> None: - cur.execute(f'DELETE FROM {table_name} WHERE doc_id=?', (doc_id,)) - - @classmethod - def get_next_order(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: - cur.execute(f'SELECT MAX(item_order) FROM {table_name}') - res = cur.fetchone()[0] - if res is None: - return 0 - return cast(int, res) + 1 - - @classmethod - def get_doc_ids(cls, table_name: str, cur: 'sqlite3.Cursor') -> Iterable[str]: - cur.execute(f'SELECT doc_id FROM {table_name} ORDER BY item_order') - for res in cur: - yield res[0] - - @classmethod - def update_serialized_value_by_doc_id( - cls, - table_name: str, - cur: 'sqlite3.Cursor', - doc_id: str, - serialized_value: bytes, - ) -> None: - cur.execute( - f'UPDATE {table_name} SET serialized_value=? WHERE doc_id=?', - (serialized_value, doc_id), - ) - - @classmethod - def upsert( - cls, - table_name: str, - cur: 'sqlite3.Cursor', - doc_id: str, - serialized_value: bytes, - ) -> None: - if cls.is_doc_id_in(table_name, cur, doc_id): - cls.update_serialized_value_by_doc_id( - table_name, cur, doc_id, serialized_value - ) - else: - cls.insert_serialized_value_by_doc_id( - table_name, cur, doc_id, serialized_value - ) - - @classmethod - def get_last_serialized_item( - cls, table_name: str, cur: 'sqlite3.Cursor' - ) -> Tuple[str, bytes]: - cur.execute( - f'SELECT doc_id, serialized_value FROM {table_name} ORDER BY item_order DESC LIMIT 1' - ) - return cast(Tuple[str, bytes], cur.fetchone()) - - @classmethod - def get_reversed_doc_ids( - cls, table_name: str, cur: 'sqlite3.Cursor' - ) -> Iterable[str]: - cur.execute(f'SELECT doc_id FROM {table_name} ORDER BY item_order DESC') - for res in cur: - yield cast(str, res[0]) From ad9deb5004cf8e13039559536102c15073aada9d Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Thu, 20 Jan 2022 15:15:28 +0800 Subject: [PATCH 09/23] feat: init commit pq storage mixins --- docarray/array/storage/pqlite/__init__.py | 10 +++++ docarray/array/storage/pqlite/backend.py | 40 ++++++++++++++++++ docarray/array/storage/pqlite/getsetdel.py | 23 +++++++++++ docarray/array/storage/pqlite/seqlike.py | 47 ++++++++++++++++++++++ 4 files changed, 120 insertions(+) create mode 100644 docarray/array/storage/pqlite/__init__.py create mode 100644 docarray/array/storage/pqlite/backend.py create mode 100644 docarray/array/storage/pqlite/getsetdel.py create mode 100644 docarray/array/storage/pqlite/seqlike.py diff --git a/docarray/array/storage/pqlite/__init__.py b/docarray/array/storage/pqlite/__init__.py new file mode 100644 index 00000000000..c8c79eaaace --- /dev/null +++ b/docarray/array/storage/pqlite/__init__.py @@ -0,0 +1,10 @@ +from .backend import PqliteBackendMixin +from .getsetdel import GetSetDelMixin +from .seqlike import SequenceLikeMixin +from abc import ABC + +__all__ = ['PqliteStorageMixins'] + + +class PqliteStorageMixins(PqliteBackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC): + ... diff --git a/docarray/array/storage/pqlite/backend.py b/docarray/array/storage/pqlite/backend.py new file mode 100644 index 00000000000..037bf5c3be3 --- /dev/null +++ b/docarray/array/storage/pqlite/backend.py @@ -0,0 +1,40 @@ +import dataclasses +from dataclasses import dataclass +from typing import ( + Optional, + TYPE_CHECKING, +) + +from ..base.backend import BaseBackendMixin + +if TYPE_CHECKING: + from ....types import ( + DocumentArraySourceType, + ) + + +@dataclass +class PqliteConfig: + dim: int = 256 + metric: str = 'cosine' + data_path: str = 'data' + + +class PqliteBackendMixin(BaseBackendMixin): + """Provide necessary functions to enable this storage backend. """ + + def _insert_doc_at_idx(self, doc, idx: Optional[int] = None): + raise NotImplementedError + + def _shift_index_right_backward(self, start: int): + raise NotImplementedError + + def _init_storage( + self, + docs: Optional['DocumentArraySourceType'] = None, + config: Optional[PqliteConfig] = None, + ): + super().__init__(**(dataclasses.asdict(config) if config else {})) + if docs is not None: + self.clear() + self.extend(docs) diff --git a/docarray/array/storage/pqlite/getsetdel.py b/docarray/array/storage/pqlite/getsetdel.py new file mode 100644 index 00000000000..d257b5c52c3 --- /dev/null +++ b/docarray/array/storage/pqlite/getsetdel.py @@ -0,0 +1,23 @@ +from typing import ( + Sequence, + Iterable, +) + +from ..base.getsetdel import BaseGetSetDelMixin +from .... import Document + + +class GetSetDelMixin(BaseGetSetDelMixin): + """Implement required and derived functions that power `getitem`, `setitem`, `delitem`""" + + def _get_doc_by_offset(self, index: int) -> 'Document': + ... + + def _get_doc_by_id(self, id: str) -> 'Document': + ... + + def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: + ... + + def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: + return self._get_docs_by_offsets(range(len(self))[_slice]) diff --git a/docarray/array/storage/pqlite/seqlike.py b/docarray/array/storage/pqlite/seqlike.py new file mode 100644 index 00000000000..b78da7fb43c --- /dev/null +++ b/docarray/array/storage/pqlite/seqlike.py @@ -0,0 +1,47 @@ +from typing import Iterator, Union, Iterable, MutableSequence + +from .... import Document + + +class SequenceLikeMixin(MutableSequence[Document]): + """Implement sequence-like methods""" + + def insert(self, index: int, value: 'Document'): + """Insert `doc` at `index`. + + :param index: Position of the insertion. + :param value: The doc needs to be inserted. + """ + length = len(self) + if index < 0: + index = length + index + index = max(0, min(length, index)) + self._shift_index_right_backward(index) + self._insert_doc_at_idx(doc=value, idx=index) + + def append(self, value: 'Document') -> None: + self._insert_doc_at_idx(value) + + def extend(self, values: Iterable['Document']) -> None: + idx = len(self) + for v in values: + self._insert_doc_at_idx(v, idx) + idx += 1 + + def clear(self) -> None: + raise NotImplementedError + + def __contains__(self, item: Union[str, 'Document']): + if isinstance(item, str): + raise NotImplementedError + return len(list(r)) > 0 + elif isinstance(item, Document): + return item.id in self # fall back to str check + else: + return False + + def __len__(self) -> int: + ... + + def __iter__(self) -> Iterator['Document']: + ... From 1b95344ed1acf9e0d85f8544c2f31f667b8842c1 Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Tue, 25 Jan 2022 15:57:27 +0800 Subject: [PATCH 10/23] fix: draft commit --- docarray/array/pqlite.py | 7 +++++++ docarray/array/storage/pqlite/__init__.py | 6 +++--- docarray/array/storage/pqlite/backend.py | 10 ++++++++-- docarray/array/storage/pqlite/seqlike.py | 2 +- 4 files changed, 19 insertions(+), 6 deletions(-) create mode 100644 docarray/array/pqlite.py diff --git a/docarray/array/pqlite.py b/docarray/array/pqlite.py new file mode 100644 index 00000000000..1f1a8135914 --- /dev/null +++ b/docarray/array/pqlite.py @@ -0,0 +1,7 @@ +from .base import DocumentArray +from .storage.pqlite import StorageMixins + + +class DocumentArrayPqlite(StorageMixins, DocumentArray): + def __new__(cls, *args, **kwargs): + return super().__new__(cls) diff --git a/docarray/array/storage/pqlite/__init__.py b/docarray/array/storage/pqlite/__init__.py index c8c79eaaace..f07b096a031 100644 --- a/docarray/array/storage/pqlite/__init__.py +++ b/docarray/array/storage/pqlite/__init__.py @@ -1,10 +1,10 @@ -from .backend import PqliteBackendMixin +from .backend import BackendMixin from .getsetdel import GetSetDelMixin from .seqlike import SequenceLikeMixin from abc import ABC -__all__ = ['PqliteStorageMixins'] +__all__ = ['StorageMixins'] -class PqliteStorageMixins(PqliteBackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC): +class StorageMixins(BackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC): ... diff --git a/docarray/array/storage/pqlite/backend.py b/docarray/array/storage/pqlite/backend.py index 037bf5c3be3..49af78c9556 100644 --- a/docarray/array/storage/pqlite/backend.py +++ b/docarray/array/storage/pqlite/backend.py @@ -20,7 +20,7 @@ class PqliteConfig: data_path: str = 'data' -class PqliteBackendMixin(BaseBackendMixin): +class BackendMixin(BaseBackendMixin): """Provide necessary functions to enable this storage backend. """ def _insert_doc_at_idx(self, doc, idx: Optional[int] = None): @@ -34,7 +34,13 @@ def _init_storage( docs: Optional['DocumentArraySourceType'] = None, config: Optional[PqliteConfig] = None, ): - super().__init__(**(dataclasses.asdict(config) if config else {})) + if not config: + config = PqliteConfig() + + from pqlite import PQLite + + self._pqlite = PQLite(**config) + if docs is not None: self.clear() self.extend(docs) diff --git a/docarray/array/storage/pqlite/seqlike.py b/docarray/array/storage/pqlite/seqlike.py index b78da7fb43c..3aab7057ec0 100644 --- a/docarray/array/storage/pqlite/seqlike.py +++ b/docarray/array/storage/pqlite/seqlike.py @@ -41,7 +41,7 @@ def __contains__(self, item: Union[str, 'Document']): return False def __len__(self) -> int: - ... + return self._pqlite.stat['doc_num'] def __iter__(self) -> Iterator['Document']: ... From 4c40b48f05245703a7f5ba65ba2182c221304655 Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Thu, 27 Jan 2022 23:11:27 +0800 Subject: [PATCH 11/23] feat(pqlite): init commit --- docarray/array/pqlite.py | 2 +- docarray/array/storage/pqlite/backend.py | 25 +++--- docarray/array/storage/pqlite/getsetdel.py | 35 ++++++++- docarray/array/storage/pqlite/helper.py | 91 ++++++++++++++++++++++ docarray/array/storage/pqlite/seqlike.py | 88 ++++++++++++++------- 5 files changed, 198 insertions(+), 43 deletions(-) create mode 100644 docarray/array/storage/pqlite/helper.py diff --git a/docarray/array/pqlite.py b/docarray/array/pqlite.py index 1f1a8135914..1cd3ef5c2da 100644 --- a/docarray/array/pqlite.py +++ b/docarray/array/pqlite.py @@ -1,4 +1,4 @@ -from .base import DocumentArray +from .document import DocumentArray from .storage.pqlite import StorageMixins diff --git a/docarray/array/storage/pqlite/backend.py b/docarray/array/storage/pqlite/backend.py index 49af78c9556..af83feecd9c 100644 --- a/docarray/array/storage/pqlite/backend.py +++ b/docarray/array/storage/pqlite/backend.py @@ -1,6 +1,7 @@ -import dataclasses -from dataclasses import dataclass +from dataclasses import dataclass, asdict from typing import ( + Union, + Dict, Optional, TYPE_CHECKING, ) @@ -15,7 +16,7 @@ @dataclass class PqliteConfig: - dim: int = 256 + n_dim: int = 1 metric: str = 'cosine' data_path: str = 'data' @@ -23,23 +24,25 @@ class PqliteConfig: class BackendMixin(BaseBackendMixin): """Provide necessary functions to enable this storage backend. """ - def _insert_doc_at_idx(self, doc, idx: Optional[int] = None): - raise NotImplementedError - - def _shift_index_right_backward(self, start: int): - raise NotImplementedError - def _init_storage( self, docs: Optional['DocumentArraySourceType'] = None, - config: Optional[PqliteConfig] = None, + config: Optional[Union[PqliteConfig, Dict]] = None, ): if not config: config = PqliteConfig() + self._config = config from pqlite import PQLite + from .helper import OffsetMapping + + config = asdict(config) + n_dim = config.pop('n_dim') - self._pqlite = PQLite(**config) + self._pqlite = PQLite(n_dim, **config) + self._offset2ids = OffsetMapping( + name='offset2ids', data_path=config['data_path'], in_memory=True + ) if docs is not None: self.clear() diff --git a/docarray/array/storage/pqlite/getsetdel.py b/docarray/array/storage/pqlite/getsetdel.py index d257b5c52c3..760da40fd6f 100644 --- a/docarray/array/storage/pqlite/getsetdel.py +++ b/docarray/array/storage/pqlite/getsetdel.py @@ -2,7 +2,8 @@ Sequence, Iterable, ) - +import numpy as np +from ...memory import DocumentArrayInMemory from ..base.getsetdel import BaseGetSetDelMixin from .... import Document @@ -10,14 +11,40 @@ class GetSetDelMixin(BaseGetSetDelMixin): """Implement required and derived functions that power `getitem`, `setitem`, `delitem`""" + # essential methods start + + def _del_doc_by_id(self, _id: str): + offset = self._offset2ids.get_offset_by_id(_id) + self._offset2ids.del_at_offset(offset, commit=True) + self._pqlite.delete([_id]) + + def _del_doc_by_offset(self, offset: int): + _id = self._offset2ids.get_id_by_offset(offset) + self._pqlite.delete([_id]) + + def _set_doc_by_offset(self, offset: int, value: 'Document'): + self._offset2ids.set_at_offset(offset, value.id) + docs = DocumentArrayInMemory([value]) + if docs.embeddings is None: + docs.embeddings = np.zeros((1, self._pqlite.dim)) + self._pqlite.update(docs) + + def _set_doc_by_id(self, _id: str, value: 'Document'): + docs = DocumentArrayInMemory([value]) + if docs.embeddings is None: + docs.embeddings = np.zeros((1, self._pqlite.dim)) + self._pqlite.update(docs) + def _get_doc_by_offset(self, index: int) -> 'Document': - ... + doc_id = self._offset2ids.get_id_by_offset(index) + if doc_id is not None: + return self._pqlite.get_doc_by_offset(index) def _get_doc_by_id(self, id: str) -> 'Document': - ... + return self._pqlite.get_doc_by_id(id) def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: - ... + return [self._get_doc_by_offset(offset) for offset in offsets] def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: return self._get_docs_by_offsets(range(len(self))[_slice]) diff --git a/docarray/array/storage/pqlite/helper.py b/docarray/array/storage/pqlite/helper.py new file mode 100644 index 00000000000..77b5bad2b55 --- /dev/null +++ b/docarray/array/storage/pqlite/helper.py @@ -0,0 +1,91 @@ +from typing import Optional, List, Tuple + +from pqlite.storage.table import Table + + +class OffsetMapping(Table): + def __init__( + self, + name: str = 'offset2ids', + data_path: Optional[str] = None, + in_memory: bool = True, + ): + super().__init__(name, data_path, in_memory) + self.create_table() + self._size = None + + def create_table(self): + sql = f'''CREATE TABLE {self.name} + (offset INTEGER NOT NULL PRIMARY KEY, + doc_id INTEGER TEXT NOT NULL)''' + + self.execute(sql, commit=True) + + def clear(self): + super().clear() + self._size = None + + @property + def size(self): + if self._size is None: + sql = f'SELECT MAX(offset) from {self.name} LIMIT 1;' + result = self._conn.execute(sql).fetchone() + self._size = result[0] + 1 if result[0] else 0 + + return self._size + + def extend_doc_ids(self, doc_ids: List[str], commit: bool = True): + offsets = [self.size + i for i in range(len(doc_ids))] + self._insert(list(zip(offsets, doc_ids)), commit=commit) + + def _insert(self, offset_ids: List[Tuple[int, str]], commit: bool = True): + sql = f'INSERT INTO {self.name}(offset, doc_id) VALUES (?, ?);' + self.execute_many(sql, offset_ids, commit=commit) + self._size += len(offset_ids) + + def get_id_by_offset(self, offset: int): + sql = f'SELECT doc_id FROM {self.name} WHERE offset = ? LIMIT 1;' + result = self._conn.execute(sql, (offset,)).fetchone() + return result[0] + + def get_offset_by_id(self, doc_id: str): + sql = f'SELECT offset FROM {self.name} WHERE doc_id = ? LIMIT 1;' + result = self._conn.execute(sql, (doc_id,)).fetchone() + return result[0] + + def del_at_offset(self, offset: int, commit: bool = True): + sql = f'DELETE FROM {self.name} WHERE offset=?' + self._conn.execute(sql, (offset,)) + self.shift_offset(offset, shift_step=1, direction='left', commit=False) + if commit: + self.commit() + + self._size -= 1 + + def insert_at_offset(self, offset: int, doc_id: str, commit: bool = True): + self.shift_offset(offset - 1, shift_step=1, direction='right', commit=False) + self._insert([(offset, doc_id)], commit=commit) + + def set_at_offset(self, offset: int, doc_id: str, commit: bool = True): + sql = f'UPDATE {self.name} SET doc_id={doc_id} WHERE offset = ?' + self._conn.execute(sql, (offset,)) + if commit: + self.commit() + + def shift_offset( + self, + shift_from: int, + shift_step: int = 1, + direction: str = 'left', + commit: bool = True, + ): + if direction == 'left': + sql = f'UPDATE {self.name} SET offset=offset-{shift_step} WHERE offset > ?' + elif direction == 'right': + sql = f'UPDATE {self.name} SET offset=offset+{shift_step} WHERE offset > ?' + else: + raise ValueError(f'The shit_offset directory {direction} is not supported!') + + self._conn.execute(sql, (shift_from,)) + if commit: + self._conn.commit() diff --git a/docarray/array/storage/pqlite/seqlike.py b/docarray/array/storage/pqlite/seqlike.py index 3aab7057ec0..fc6017d61c8 100644 --- a/docarray/array/storage/pqlite/seqlike.py +++ b/docarray/array/storage/pqlite/seqlike.py @@ -1,47 +1,81 @@ -from typing import Iterator, Union, Iterable, MutableSequence - +from typing import Iterator, Union, Iterable, Sequence, MutableSequence +import numpy as np from .... import Document +from ...memory import DocumentArrayInMemory class SequenceLikeMixin(MutableSequence[Document]): """Implement sequence-like methods""" + """Implement sequence-like methods""" + def insert(self, index: int, value: 'Document'): """Insert `doc` at `index`. :param index: Position of the insertion. :param value: The doc needs to be inserted. """ - length = len(self) - if index < 0: - index = length + index - index = max(0, min(length, index)) - self._shift_index_right_backward(index) - self._insert_doc_at_idx(doc=value, idx=index) + if value.embedding is None: + value.embedding = np.zeros(self._pqlite.dim, dtype=np.float32) + + self._pqlite.index(DocumentArrayInMemory([value])) + self._offset2ids.insert_at_offset(index, value.id) def append(self, value: 'Document') -> None: - self._insert_doc_at_idx(value) + self._pqlite.index(DocumentArrayInMemory([value])) + self._offset2ids.extend_doc_ids([value.id]) def extend(self, values: Iterable['Document']) -> None: - idx = len(self) - for v in values: - self._insert_doc_at_idx(v, idx) - idx += 1 - - def clear(self) -> None: - raise NotImplementedError - - def __contains__(self, item: Union[str, 'Document']): - if isinstance(item, str): - raise NotImplementedError - return len(list(r)) > 0 - elif isinstance(item, Document): - return item.id in self # fall back to str check + docs = DocumentArrayInMemory(values) + for doc in docs: + if doc.embedding is None: + doc.embedding = np.zeros(self._pqlite.dim, dtype=np.float32) + self._pqlite.index(docs) + self._offset2ids.extend_doc_ids([value.id for value in values]) + + def __del__(self) -> None: + del self._offset2ids + del self._pqlite + + def __eq__(self, other): + """In pqlite backend, data are considered as identical if configs point to the same database source""" + return ( + type(self) is type(other) + and type(self._config) is type(other._config) + and self._config == other._config + ) + + def __len__(self): + return self._offset2ids.size + + def __iter__(self) -> Iterator['Document']: + for i in range(len(self)): + yield self[i] + + def __contains__(self, x: Union[str, 'Document']): + if isinstance(x, str): + return self._offset2id.get_offset_by_id(x) is not None + elif isinstance(x, Document): + return self._offset2id.get_offset_by_id(x.id) is not None else: return False - def __len__(self) -> int: - return self._pqlite.stat['doc_num'] + def clear(self): + """Clear the data of :class:`DocumentArray`""" + self._offset2ids.clear() + self._pqlite.clear() - def __iter__(self) -> Iterator['Document']: - ... + def __bool__(self): + """To simulate ```l = []; if l: ...``` + + :return: returns true if the length of the array is larger than 0 + """ + return len(self) > 0 + + def __repr__(self): + return f'' + + def __add__(self, other: Union['Document', Sequence['Document']]): + v = type(self)(self) + v.extend(other) + return v From 2d84be250d6f6805dd4c5ef0690319babdcc89da Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Thu, 27 Jan 2022 23:15:14 +0800 Subject: [PATCH 12/23] fix: in_memory false --- docarray/array/storage/pqlite/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/storage/pqlite/backend.py b/docarray/array/storage/pqlite/backend.py index af83feecd9c..e8839826759 100644 --- a/docarray/array/storage/pqlite/backend.py +++ b/docarray/array/storage/pqlite/backend.py @@ -41,7 +41,7 @@ def _init_storage( self._pqlite = PQLite(n_dim, **config) self._offset2ids = OffsetMapping( - name='offset2ids', data_path=config['data_path'], in_memory=True + name='offset2ids', data_path=config['data_path'] ) if docs is not None: From 4ada8e9c2baca1e2304d86d91d2137e146f40993 Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Thu, 27 Jan 2022 23:54:16 +0800 Subject: [PATCH 13/23] fix: del and clear --- docarray/array/storage/pqlite/backend.py | 12 +++++++++++- docarray/array/storage/pqlite/getsetdel.py | 2 +- docarray/array/storage/pqlite/helper.py | 2 +- docarray/array/storage/pqlite/seqlike.py | 7 +++---- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/docarray/array/storage/pqlite/backend.py b/docarray/array/storage/pqlite/backend.py index e8839826759..dd783604458 100644 --- a/docarray/array/storage/pqlite/backend.py +++ b/docarray/array/storage/pqlite/backend.py @@ -7,6 +7,7 @@ ) from ..base.backend import BaseBackendMixin +from ....helper import dataclass_from_dict if TYPE_CHECKING: from ....types import ( @@ -19,6 +20,7 @@ class PqliteConfig: n_dim: int = 1 metric: str = 'cosine' data_path: str = 'data' + table_name: Optional[str] = None class BackendMixin(BaseBackendMixin): @@ -31,8 +33,14 @@ def _init_storage( ): if not config: config = PqliteConfig() + if isinstance(config, dict): + config = dataclass_from_dict(PqliteConfig, config) + self._config = config + table_name = config.table_name + self._persist = bool(table_name) + from pqlite import PQLite from .helper import OffsetMapping @@ -41,7 +49,9 @@ def _init_storage( self._pqlite = PQLite(n_dim, **config) self._offset2ids = OffsetMapping( - name='offset2ids', data_path=config['data_path'] + name=table_name or 'docarray', + data_path=config['data_path'], + in_memory=False, ) if docs is not None: diff --git a/docarray/array/storage/pqlite/getsetdel.py b/docarray/array/storage/pqlite/getsetdel.py index 760da40fd6f..cdd31302c2e 100644 --- a/docarray/array/storage/pqlite/getsetdel.py +++ b/docarray/array/storage/pqlite/getsetdel.py @@ -38,7 +38,7 @@ def _set_doc_by_id(self, _id: str, value: 'Document'): def _get_doc_by_offset(self, index: int) -> 'Document': doc_id = self._offset2ids.get_id_by_offset(index) if doc_id is not None: - return self._pqlite.get_doc_by_offset(index) + return self._pqlite.get_doc_by_id(doc_id) def _get_doc_by_id(self, id: str) -> 'Document': return self._pqlite.get_doc_by_id(id) diff --git a/docarray/array/storage/pqlite/helper.py b/docarray/array/storage/pqlite/helper.py index 77b5bad2b55..34f4efe2990 100644 --- a/docarray/array/storage/pqlite/helper.py +++ b/docarray/array/storage/pqlite/helper.py @@ -15,7 +15,7 @@ def __init__( self._size = None def create_table(self): - sql = f'''CREATE TABLE {self.name} + sql = f'''CREATE TABLE IF NOT EXISTS {self.name} (offset INTEGER NOT NULL PRIMARY KEY, doc_id INTEGER TEXT NOT NULL)''' diff --git a/docarray/array/storage/pqlite/seqlike.py b/docarray/array/storage/pqlite/seqlike.py index fc6017d61c8..0ebcf7efe57 100644 --- a/docarray/array/storage/pqlite/seqlike.py +++ b/docarray/array/storage/pqlite/seqlike.py @@ -7,8 +7,6 @@ class SequenceLikeMixin(MutableSequence[Document]): """Implement sequence-like methods""" - """Implement sequence-like methods""" - def insert(self, index: int, value: 'Document'): """Insert `doc` at `index`. @@ -34,8 +32,9 @@ def extend(self, values: Iterable['Document']) -> None: self._offset2ids.extend_doc_ids([value.id for value in values]) def __del__(self) -> None: - del self._offset2ids - del self._pqlite + if not self._persist: + self._offset2ids.clear() + self._pqlite.clear() def __eq__(self, other): """In pqlite backend, data are considered as identical if configs point to the same database source""" From 5b40d7b444e2ef79fffe03498bbd376b71196051 Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Fri, 28 Jan 2022 10:59:03 +0800 Subject: [PATCH 14/23] fix: impl mixin --- docarray/array/storage/pqlite/getsetdel.py | 67 +++++++++++++++------- docarray/array/storage/pqlite/helper.py | 6 ++ docarray/array/storage/pqlite/seqlike.py | 10 ++-- 3 files changed, 57 insertions(+), 26 deletions(-) diff --git a/docarray/array/storage/pqlite/getsetdel.py b/docarray/array/storage/pqlite/getsetdel.py index cdd31302c2e..b722b9f4867 100644 --- a/docarray/array/storage/pqlite/getsetdel.py +++ b/docarray/array/storage/pqlite/getsetdel.py @@ -13,14 +13,28 @@ class GetSetDelMixin(BaseGetSetDelMixin): # essential methods start - def _del_doc_by_id(self, _id: str): - offset = self._offset2ids.get_offset_by_id(_id) - self._offset2ids.del_at_offset(offset, commit=True) - self._pqlite.delete([_id]) + def _get_doc_by_offset(self, index: int) -> 'Document': + doc_id = self._offset2ids.get_id_by_offset(index) + if doc_id is not None: + return self._pqlite.get_doc_by_id(doc_id) + return None - def _del_doc_by_offset(self, offset: int): - _id = self._offset2ids.get_id_by_offset(offset) - self._pqlite.delete([_id]) + def _get_doc_by_id(self, id: str) -> 'Document': + return self._pqlite.get_doc_by_id(id) + + def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: + ids = self._offset2ids.get_ids_by_offsets(offsets) + self._get_docs_by_ids(ids) + + def _get_docs_by_ids(self, ids: str) -> Iterable['Document']: + return [self._get_doc_by_id(k) for k in ids] + + def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: + return self._get_docs_by_offsets(range(len(self))[_slice]) + + def _get_docs_by_mask(self, mask: Sequence[bool]): + offsets = [i for i, m in enumerate(mask) if m is True] + return self._get_docs_by_offsets(offsets) def _set_doc_by_offset(self, offset: int, value: 'Document'): self._offset2ids.set_at_offset(offset, value.id) @@ -30,21 +44,32 @@ def _set_doc_by_offset(self, offset: int, value: 'Document'): self._pqlite.update(docs) def _set_doc_by_id(self, _id: str, value: 'Document'): - docs = DocumentArrayInMemory([value]) - if docs.embeddings is None: - docs.embeddings = np.zeros((1, self._pqlite.dim)) - self._pqlite.update(docs) + offset = self._offset2ids.get_offset_by_id(_id) + self._set_doc_by_offset(offset, value) - def _get_doc_by_offset(self, index: int) -> 'Document': - doc_id = self._offset2ids.get_id_by_offset(index) - if doc_id is not None: - return self._pqlite.get_doc_by_id(doc_id) + def _set_doc_value_pairs( + self, docs: Iterable['Document'], values: Iterable['Document'] + ): + for _d, _v in zip(docs, values): + self._set_doc_by_id(_d.id, _v) - def _get_doc_by_id(self, id: str) -> 'Document': - return self._pqlite.get_doc_by_id(id) + def _del_doc_by_id(self, _id: str): + offset = self._offset2ids.get_offset_by_id(_id) + self._offset2ids.del_at_offset(offset, commit=True) + self._pqlite.delete([_id]) - def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: - return [self._get_doc_by_offset(offset) for offset in offsets] + def _del_doc_by_offset(self, offset: int): + _id = self._offset2ids.get_id_by_offset(offset) + self._pqlite.delete([_id]) - def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: - return self._get_docs_by_offsets(range(len(self))[_slice]) + def _del_doc_by_offsets(self, offsets: Sequence[int]): + ids = [self._offset2ids.get_id_by_offset(offset) for offset in offsets] + self._pqlite.delete(ids) + + def _del_docs_by_slice(self, _slice: slice): + offsets = range(len(self))[_slice] + self._del_doc_by_offsets(offsets) + + def _del_docs_by_mask(self, mask: Sequence[bool]): + offsets = [i for i, m in enumerate(mask) if m is True] + self._del_doc_by_offsets(offsets) diff --git a/docarray/array/storage/pqlite/helper.py b/docarray/array/storage/pqlite/helper.py index 34f4efe2990..88dfc535a73 100644 --- a/docarray/array/storage/pqlite/helper.py +++ b/docarray/array/storage/pqlite/helper.py @@ -48,6 +48,12 @@ def get_id_by_offset(self, offset: int): result = self._conn.execute(sql, (offset,)).fetchone() return result[0] + def get_ids_by_offsets(self, offsets: List[int]) -> List[str]: + return [self.get_id_by_offset(offset) for offset in offsets] + + def get_offsets_by_ids(self, ids: List[str]) -> List[int]: + return [self.get_offset_by_id(k) for k in ids] + def get_offset_by_id(self, doc_id: str): sql = f'SELECT offset FROM {self.name} WHERE doc_id = ? LIMIT 1;' result = self._conn.execute(sql, (doc_id,)).fetchone() diff --git a/docarray/array/storage/pqlite/seqlike.py b/docarray/array/storage/pqlite/seqlike.py index 0ebcf7efe57..6d50411178f 100644 --- a/docarray/array/storage/pqlite/seqlike.py +++ b/docarray/array/storage/pqlite/seqlike.py @@ -31,6 +31,11 @@ def extend(self, values: Iterable['Document']) -> None: self._pqlite.index(docs) self._offset2ids.extend_doc_ids([value.id for value in values]) + def clear(self): + """Clear the data of :class:`DocumentArray`""" + self._offset2ids.clear() + self._pqlite.clear() + def __del__(self) -> None: if not self._persist: self._offset2ids.clear() @@ -59,11 +64,6 @@ def __contains__(self, x: Union[str, 'Document']): else: return False - def clear(self): - """Clear the data of :class:`DocumentArray`""" - self._offset2ids.clear() - self._pqlite.clear() - def __bool__(self): """To simulate ```l = []; if l: ...``` From f2b0bad695ea516d1506910efd3f2038f88c1aa6 Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Fri, 28 Jan 2022 15:45:49 +0800 Subject: [PATCH 15/23] fix: bugs --- docarray/array/storage/pqlite/getsetdel.py | 35 +++++++++++++++------- docarray/array/storage/pqlite/helper.py | 23 ++++++++------ docarray/array/storage/pqlite/seqlike.py | 6 ++-- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/docarray/array/storage/pqlite/getsetdel.py b/docarray/array/storage/pqlite/getsetdel.py index b722b9f4867..67efca32331 100644 --- a/docarray/array/storage/pqlite/getsetdel.py +++ b/docarray/array/storage/pqlite/getsetdel.py @@ -13,21 +13,27 @@ class GetSetDelMixin(BaseGetSetDelMixin): # essential methods start - def _get_doc_by_offset(self, index: int) -> 'Document': - doc_id = self._offset2ids.get_id_by_offset(index) - if doc_id is not None: - return self._pqlite.get_doc_by_id(doc_id) - return None - - def _get_doc_by_id(self, id: str) -> 'Document': - return self._pqlite.get_doc_by_id(id) + def _get_doc_by_offset(self, offset: int) -> 'Document': + offset = len(self) + offset if offset < 0 else offset + doc_id = self._offset2ids.get_id_by_offset(offset) + doc = self._pqlite.get_doc_by_id(doc_id) if doc_id else None + if doc is None: + raise IndexError('index out of range') + return doc + + def _get_doc_by_id(self, _id: str) -> 'Document': + doc = self._pqlite.get_doc_by_id(_id) + if doc is None: + raise KeyError(f'Can not find Document with id=`{_id}`') + return doc def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: ids = self._offset2ids.get_ids_by_offsets(offsets) - self._get_docs_by_ids(ids) + return self._get_docs_by_ids(ids) def _get_docs_by_ids(self, ids: str) -> Iterable['Document']: - return [self._get_doc_by_id(k) for k in ids] + for _id in ids: + yield self._get_doc_by_id(_id) def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: return self._get_docs_by_offsets(range(len(self))[_slice]) @@ -37,6 +43,7 @@ def _get_docs_by_mask(self, mask: Sequence[bool]): return self._get_docs_by_offsets(offsets) def _set_doc_by_offset(self, offset: int, value: 'Document'): + offset = len(self) + offset if offset < 0 else offset self._offset2ids.set_at_offset(offset, value.id) docs = DocumentArrayInMemory([value]) if docs.embeddings is None: @@ -59,11 +66,17 @@ def _del_doc_by_id(self, _id: str): self._pqlite.delete([_id]) def _del_doc_by_offset(self, offset: int): + offset = len(self) + offset if offset < 0 else offset _id = self._offset2ids.get_id_by_offset(offset) + self._offset2ids.del_at_offset(offset) self._pqlite.delete([_id]) def _del_doc_by_offsets(self, offsets: Sequence[int]): - ids = [self._offset2ids.get_id_by_offset(offset) for offset in offsets] + ids = [] + for offset in offsets: + _id = self._offset2ids.get_id_by_offset(offset) + ids.append(_id) + self._offset2ids.del_at_offset(offset) self._pqlite.delete(ids) def _del_docs_by_slice(self, _slice: slice): diff --git a/docarray/array/storage/pqlite/helper.py b/docarray/array/storage/pqlite/helper.py index 88dfc535a73..05aa4fe47fb 100644 --- a/docarray/array/storage/pqlite/helper.py +++ b/docarray/array/storage/pqlite/helper.py @@ -36,17 +36,18 @@ def size(self): def extend_doc_ids(self, doc_ids: List[str], commit: bool = True): offsets = [self.size + i for i in range(len(doc_ids))] - self._insert(list(zip(offsets, doc_ids)), commit=commit) + offset_ids = list(zip(offsets, doc_ids)) + self._insert(offset_ids, commit=commit) def _insert(self, offset_ids: List[Tuple[int, str]], commit: bool = True): sql = f'INSERT INTO {self.name}(offset, doc_id) VALUES (?, ?);' self.execute_many(sql, offset_ids, commit=commit) - self._size += len(offset_ids) + self._size = self.size + len(offset_ids) def get_id_by_offset(self, offset: int): sql = f'SELECT doc_id FROM {self.name} WHERE offset = ? LIMIT 1;' result = self._conn.execute(sql, (offset,)).fetchone() - return result[0] + return result[0] if result else None def get_ids_by_offsets(self, offsets: List[int]) -> List[str]: return [self.get_id_by_offset(offset) for offset in offsets] @@ -57,14 +58,12 @@ def get_offsets_by_ids(self, ids: List[str]) -> List[int]: def get_offset_by_id(self, doc_id: str): sql = f'SELECT offset FROM {self.name} WHERE doc_id = ? LIMIT 1;' result = self._conn.execute(sql, (doc_id,)).fetchone() - return result[0] + return result[0] if result else None def del_at_offset(self, offset: int, commit: bool = True): sql = f'DELETE FROM {self.name} WHERE offset=?' self._conn.execute(sql, (offset,)) - self.shift_offset(offset, shift_step=1, direction='left', commit=False) - if commit: - self.commit() + self.shift_offset(offset, shift_step=1, direction='left', commit=commit) self._size -= 1 @@ -73,8 +72,14 @@ def insert_at_offset(self, offset: int, doc_id: str, commit: bool = True): self._insert([(offset, doc_id)], commit=commit) def set_at_offset(self, offset: int, doc_id: str, commit: bool = True): - sql = f'UPDATE {self.name} SET doc_id={doc_id} WHERE offset = ?' - self._conn.execute(sql, (offset,)) + sql = f'UPDATE {self.name} SET doc_id=? WHERE offset = ?' + self._conn.execute( + sql, + ( + doc_id, + offset, + ), + ) if commit: self.commit() diff --git a/docarray/array/storage/pqlite/seqlike.py b/docarray/array/storage/pqlite/seqlike.py index 6d50411178f..17f80d13606 100644 --- a/docarray/array/storage/pqlite/seqlike.py +++ b/docarray/array/storage/pqlite/seqlike.py @@ -29,7 +29,7 @@ def extend(self, values: Iterable['Document']) -> None: if doc.embedding is None: doc.embedding = np.zeros(self._pqlite.dim, dtype=np.float32) self._pqlite.index(docs) - self._offset2ids.extend_doc_ids([value.id for value in values]) + self._offset2ids.extend_doc_ids([doc.id for doc in docs]) def clear(self): """Clear the data of :class:`DocumentArray`""" @@ -58,9 +58,9 @@ def __iter__(self) -> Iterator['Document']: def __contains__(self, x: Union[str, 'Document']): if isinstance(x, str): - return self._offset2id.get_offset_by_id(x) is not None + return self._offset2ids.get_offset_by_id(x) is not None elif isinstance(x, Document): - return self._offset2id.get_offset_by_id(x.id) is not None + return self._offset2ids.get_offset_by_id(x.id) is not None else: return False From 1b9e4c8f187f2866957c101d2bd309927592d706 Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Fri, 28 Jan 2022 16:53:10 +0800 Subject: [PATCH 16/23] fix: bugs --- docarray/array/storage/pqlite/getsetdel.py | 6 +++--- docarray/array/storage/pqlite/helper.py | 15 ++++++++++++++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/docarray/array/storage/pqlite/getsetdel.py b/docarray/array/storage/pqlite/getsetdel.py index 67efca32331..bed10a3e56d 100644 --- a/docarray/array/storage/pqlite/getsetdel.py +++ b/docarray/array/storage/pqlite/getsetdel.py @@ -74,9 +74,9 @@ def _del_doc_by_offset(self, offset: int): def _del_doc_by_offsets(self, offsets: Sequence[int]): ids = [] for offset in offsets: - _id = self._offset2ids.get_id_by_offset(offset) - ids.append(_id) - self._offset2ids.del_at_offset(offset) + ids.append(self._offset2ids.get_id_by_offset(offset)) + + self._offset2ids.del_at_offsets(offsets) self._pqlite.delete(ids) def _del_docs_by_slice(self, _slice: slice): diff --git a/docarray/array/storage/pqlite/helper.py b/docarray/array/storage/pqlite/helper.py index 05aa4fe47fb..50ec12981da 100644 --- a/docarray/array/storage/pqlite/helper.py +++ b/docarray/array/storage/pqlite/helper.py @@ -25,6 +25,9 @@ def clear(self): super().clear() self._size = None + def __len__(self): + return self.size + @property def size(self): if self._size is None: @@ -45,6 +48,7 @@ def _insert(self, offset_ids: List[Tuple[int, str]], commit: bool = True): self._size = self.size + len(offset_ids) def get_id_by_offset(self, offset: int): + offset = len(self) + offset if offset < 0 else offset sql = f'SELECT doc_id FROM {self.name} WHERE offset = ? LIMIT 1;' result = self._conn.execute(sql, (offset,)).fetchone() return result[0] if result else None @@ -56,22 +60,31 @@ def get_offsets_by_ids(self, ids: List[str]) -> List[int]: return [self.get_offset_by_id(k) for k in ids] def get_offset_by_id(self, doc_id: str): - sql = f'SELECT offset FROM {self.name} WHERE doc_id = ? LIMIT 1;' + sql = f'SELECT offset FROM {self.name} WHERE doc_id=? LIMIT 1;' result = self._conn.execute(sql, (doc_id,)).fetchone() return result[0] if result else None def del_at_offset(self, offset: int, commit: bool = True): + offset = len(self) + offset if offset < 0 else offset sql = f'DELETE FROM {self.name} WHERE offset=?' self._conn.execute(sql, (offset,)) self.shift_offset(offset, shift_step=1, direction='left', commit=commit) self._size -= 1 + def del_at_offsets(self, offsets: List[int], commit: bool = True): + for offset in sorted(offsets, reverse=True): + self.del_at_offset(offset, commit=False) + if commit: + self.commit() + def insert_at_offset(self, offset: int, doc_id: str, commit: bool = True): + offset = len(self) + offset if offset < 0 else offset self.shift_offset(offset - 1, shift_step=1, direction='right', commit=False) self._insert([(offset, doc_id)], commit=commit) def set_at_offset(self, offset: int, doc_id: str, commit: bool = True): + offset = len(self) + offset if offset < 0 else offset sql = f'UPDATE {self.name} SET doc_id=? WHERE offset = ?' self._conn.execute( sql, From 8fbacfcaf8787bf9ff80e7951bdb9bfc9868b0a5 Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Fri, 28 Jan 2022 17:02:53 +0800 Subject: [PATCH 17/23] fix: add draft unittest for pqlite --- tests/unit/array/test_pqlite_indexing.py | 326 +++++++++++++++++++++++ 1 file changed, 326 insertions(+) create mode 100644 tests/unit/array/test_pqlite_indexing.py diff --git a/tests/unit/array/test_pqlite_indexing.py b/tests/unit/array/test_pqlite_indexing.py new file mode 100644 index 00000000000..4a71614a123 --- /dev/null +++ b/tests/unit/array/test_pqlite_indexing.py @@ -0,0 +1,326 @@ +import numpy as np +import pytest + +from docarray import DocumentArray, Document + + +@pytest.fixture +def docs(): + yield (Document(text=f'{j}') for j in range(100)) + + +@pytest.fixture +def indices(): + yield (i for i in [-2, 0, 2]) + + +@pytest.mark.parametrize('storage', ['pqlite']) +def test_getter_int_str(docs, storage): + docs = DocumentArray(docs, storage=storage) + # getter + assert docs[99].text == '99' + assert docs[np.int(99)].text == '99' + assert docs[-1].text == '99' + assert docs[0].text == '0' + # string index + assert docs[docs[0].id].text == '0' + assert docs[docs[99].id].text == '99' + assert docs[docs[-1].id].text == '99' + + with pytest.raises(IndexError): + r = docs[100] + print(r) + + with pytest.raises(KeyError): + docs['adsad'] + + +@pytest.mark.parametrize('storage', ['pqlite']) +def test_setter_int_str(docs, storage): + docs = DocumentArray(docs, storage=storage) + # setter + docs[99] = Document(text='hello') + docs[0] = Document(text='world') + + assert docs[99].text == 'hello' + assert docs[-1].text == 'hello' + assert docs[0].text == 'world' + + docs[docs[2].id] = Document(text='doc2') + # string index + assert docs[docs[2].id].text == 'doc2' + + +@pytest.mark.parametrize('storage', ['pqlite']) +def test_del_int_str(docs, storage, indices): + docs = DocumentArray(docs, storage=storage) + initial_len = len(docs) + deleted_elements = 0 + for pos in indices: + pos_id = docs[pos].id + del docs[pos] + deleted_elements += 1 + assert pos_id not in docs + assert len(docs) == initial_len - deleted_elements + + new_pos_id = docs[pos].id + new_doc_zero = docs[pos] + del docs[new_pos_id] + deleted_elements += 1 + assert len(docs) == initial_len - deleted_elements + assert pos_id not in docs + assert new_doc_zero not in docs + + +@pytest.mark.parametrize('storage', ['pqlite']) +def test_slice(docs, storage): + docs = DocumentArray(docs, storage=storage) + # getter + assert len(docs[1:5]) == 4 + assert len(docs[1:100:5]) == 20 # 1 to 100, sep with 5 + + # setter + with pytest.raises(TypeError, match='an iterable'): + docs[1:5] = Document(text='repl') + + docs[1:5] = [Document(text=f'repl{j}') for j in range(4)] + for d in docs[1:5]: + assert d.text.startswith('repl') + assert len(docs) == 100 + + # del + zero_doc = docs[0] + twenty_doc = docs[20] + del docs[0:20] + assert len(docs) == 80 + assert zero_doc not in docs + assert twenty_doc in docs + + +@pytest.mark.parametrize('storage', ['pqlite']) +def test_sequence_bool_index(docs, storage): + docs = DocumentArray(docs, storage=storage) + # getter + mask = [True, False] * 50 + assert len(docs[mask]) == 50 + assert len(docs[[True, False]]) == 1 + + # setter + mask = [True, False] * 50 + # docs[mask] = [Document(text=f'repl{j}') for j in range(50)] + docs[mask, 'text'] = [f'repl{j}' for j in range(50)] + + for idx, d in enumerate(docs): + if idx % 2 == 0: + # got replaced + assert d.text.startswith('repl') + else: + assert isinstance(d.text, str) + + # del + del docs[mask] + assert len(docs) == 50 + + +@pytest.mark.parametrize('nparray', [lambda x: x, np.array, tuple]) +@pytest.mark.parametrize('storage', ['pqlite']) +def test_sequence_int(docs, nparray, storage): + docs = DocumentArray(docs, storage=storage) + # getter + idx = nparray([1, 3, 5, 7, -1, -2]) + assert len(docs[idx]) == len(idx) + + # setter + docs[idx] = [Document(text='repl') for _ in range(len(idx))] + for _id in idx: + assert docs[_id].text == 'repl' + + # del + idx = [-3, -4, -5, 9, 10, 11] + del docs[idx] + assert len(docs) == 100 - len(idx) + + +@pytest.mark.parametrize('storage', ['pqlite']) +def test_sequence_str(docs, storage): + docs = DocumentArray(docs, storage=storage) + # getter + idx = [d.id for d in docs[1, 3, 5, 7, -1, -2]] + + assert len(docs[idx]) == len(idx) + assert len(docs[tuple(idx)]) == len(idx) + + # setter + docs[idx] = [Document(text='repl') for _ in range(len(idx))] + idx = [d.id for d in docs[1, 3, 5, 7, -1, -2]] + for _id in idx: + assert docs[_id].text == 'repl' + + # del + idx = [d.id for d in docs[-3, -4, -5, 9, 10, 11]] + del docs[idx] + assert len(docs) == 100 - len(idx) + + +@pytest.mark.parametrize('storage', ['pqlite']) +def test_docarray_list_tuple(docs, storage): + docs = DocumentArray(docs, storage=storage) + assert isinstance(docs[99, 98], DocumentArray) + assert len(docs[99, 98]) == 2 + + +@pytest.mark.parametrize('storage', ['pqlite']) +def test_path_syntax_indexing(storage): + da = DocumentArray.empty(3) + for d in da: + d.chunks = DocumentArray.empty(5) + d.matches = DocumentArray.empty(7) + for c in d.chunks: + c.chunks = DocumentArray.empty(3) + + if storage == 'sqlite': + da = DocumentArray(da, storage=storage) + assert len(da['@c']) == 3 * 5 + assert len(da['@c:1']) == 3 + assert len(da['@c-1:']) == 3 + assert len(da['@c1']) == 3 + assert len(da['@c-2:']) == 3 * 2 + assert len(da['@c1:3']) == 3 * 2 + assert len(da['@c1:3c']) == (3 * 2) * 3 + assert len(da['@c1:3,c1:3c']) == (3 * 2) + (3 * 2) * 3 + assert len(da['@c 1:3 , c 1:3 c']) == (3 * 2) + (3 * 2) * 3 + assert len(da['@cc']) == 3 * 5 * 3 + assert len(da['@cc,m']) == 3 * 5 * 3 + 3 * 7 + assert len(da['@r:1cc,m']) == 1 * 5 * 3 + 3 * 7 + + +@pytest.mark.parametrize('storage', ['pqlite']) +def test_attribute_indexing(storage): + da = DocumentArray(storage=storage) + da.extend(DocumentArray.empty(10)) + + for v in da[:, 'id']: + assert v + da[:, 'mime_type'] = [f'type {j}' for j in range(10)] + for v in da[:, 'mime_type']: + assert v + del da[:, 'mime_type'] + for v in da[:, 'mime_type']: + assert not v + + da[:, ['text', 'mime_type']] = [ + [f'hello {j}' for j in range(10)], + [f'type {j}' for j in range(10)], + ] + da.summary() + + for v in da[:, ['mime_type', 'text']]: + for vv in v: + assert vv + + +# TODO: enable weaviate storage test +@pytest.mark.parametrize('storage', ['pqlite']) +def test_tensor_attribute_selector(storage): + import scipy.sparse + + sp_embed = np.random.random([3, 10]) + # sp_embed[sp_embed > 0.1] = 0 + # sp_embed = scipy.sparse.coo_matrix(sp_embed) + + da = DocumentArray(storage=storage, config={'n_dim': 10}) + da.extend(DocumentArray.empty(3)) + + assert len(da) == 3 + + da[:, 'embedding'] = sp_embed + + assert da[:, 'embedding'].shape == (3, 10) + + for d in da: + assert d.embedding.shape == (10,) + + v1, v2 = da[:, ['embedding', 'id']] + # assert isinstance(v1, scipy.sparse.coo_matrix) + assert isinstance(v2, list) + + v1, v2 = da[:, ['id', 'embedding']] + # assert isinstance(v2, scipy.sparse.coo_matrix) + assert isinstance(v1, list) + + +@pytest.mark.parametrize('storage', ['pqlite']) +def test_advance_selector_mixed(storage): + da = DocumentArray(storage=storage, config={'n_dim': 3}) + da.extend(DocumentArray.empty(10)) + da.embeddings = np.random.random([10, 3]) + + da.match(da, exclude_self=True) + + assert len(da[:, ('id', 'embedding', 'matches')]) == 3 + assert len(da[:, ('id', 'embedding', 'matches')][0]) == 10 + + +@pytest.mark.parametrize('storage', ['pqlite']) +def test_single_boolean_and_padding(storage): + da = DocumentArray(storage=storage) + da.extend(DocumentArray.empty(3)) + + with pytest.raises(IndexError): + da[True] + + with pytest.raises(IndexError): + da[True] = Document() + + with pytest.raises(IndexError): + del da[True] + + assert len(da[True, False]) == 1 + assert len(da[False, False]) == 0 + + +@pytest.mark.parametrize('storage', ['pqlite']) +def test_edge_case_two_strings(storage): + # getitem + da = DocumentArray( + [Document(id='1'), Document(id='2'), Document(id='3')], storage=storage + ) + assert da['1', 'id'] == '1' + assert len(da['1', '2']) == 2 + assert isinstance(da['1', '2'], DocumentArray) + with pytest.raises(KeyError): + da['hello', '2'] + with pytest.raises(AttributeError): + da['1', 'hello'] + assert len(da['1', '2', '3']) == 3 + assert isinstance(da['1', '2', '3'], DocumentArray) + + # delitem + del da['1', '2'] + assert len(da) == 1 + + da = DocumentArray( + [Document(id=str(i), text='hey') for i in range(3)], storage=storage + ) + del da['1', 'text'] + assert len(da) == 3 + assert not da[1].text + + del da['2', 'hello'] + + # setitem + da = DocumentArray( + [Document(id='1'), Document(id='2'), Document(id='3')], storage=storage + ) + da['1', '2'] = DocumentArray.empty(2) + assert da[0].id != '1' + assert da[1].id != '2' + + da = DocumentArray( + [Document(id='1'), Document(id='2'), Document(id='3')], storage=storage + ) + da['1', 'text'] = 'hello' + assert da['1'].text == 'hello' + + with pytest.raises(ValueError): + da['1', 'hellohello'] = 'hello' From ef983a420fa5fe87fca871a18d4a87a1ac8968f2 Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Fri, 28 Jan 2022 21:17:50 +0800 Subject: [PATCH 18/23] fix: conflict --- docarray/array/document.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docarray/array/document.py b/docarray/array/document.py index 826e8ec2bf8..fd920fbb5f8 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -41,6 +41,10 @@ def __new__(cls, *args, storage: str = 'memory', **kwargs) -> 'DocumentArrayLike from .sqlite import DocumentArraySqlite instance = super().__new__(DocumentArraySqlite) + elif storage == 'pqlite': + from .pqlite import DocumentArrayPqlite + + instance = super().__new__(DocumentArrayPqlite) elif storage == 'weaviate': from .weaviate import DocumentArrayWeaviate From 9975227df1876fb799f5155934c128d81dadfc5c Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Fri, 28 Jan 2022 21:52:33 +0800 Subject: [PATCH 19/23] fix: unittest --- docarray/array/storage/pqlite/backend.py | 15 +++++++++------ docarray/array/storage/pqlite/helper.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/docarray/array/storage/pqlite/backend.py b/docarray/array/storage/pqlite/backend.py index dd783604458..d8901eb2be1 100644 --- a/docarray/array/storage/pqlite/backend.py +++ b/docarray/array/storage/pqlite/backend.py @@ -19,8 +19,7 @@ class PqliteConfig: n_dim: int = 1 metric: str = 'cosine' - data_path: str = 'data' - table_name: Optional[str] = None + data_path: Optional[str] = None class BackendMixin(BaseBackendMixin): @@ -36,10 +35,14 @@ def _init_storage( if isinstance(config, dict): config = dataclass_from_dict(PqliteConfig, config) - self._config = config + self._persist = bool(config.data_path) + + if not self._persist: + from tempfile import TemporaryDirectory - table_name = config.table_name - self._persist = bool(table_name) + config.data_path = TemporaryDirectory().name + + self._config = config from pqlite import PQLite from .helper import OffsetMapping @@ -49,7 +52,7 @@ def _init_storage( self._pqlite = PQLite(n_dim, **config) self._offset2ids = OffsetMapping( - name=table_name or 'docarray', + name='docarray', data_path=config['data_path'], in_memory=False, ) diff --git a/docarray/array/storage/pqlite/helper.py b/docarray/array/storage/pqlite/helper.py index 50ec12981da..e3658eea246 100644 --- a/docarray/array/storage/pqlite/helper.py +++ b/docarray/array/storage/pqlite/helper.py @@ -51,7 +51,7 @@ def get_id_by_offset(self, offset: int): offset = len(self) + offset if offset < 0 else offset sql = f'SELECT doc_id FROM {self.name} WHERE offset = ? LIMIT 1;' result = self._conn.execute(sql, (offset,)).fetchone() - return result[0] if result else None + return str(result[0]) if result is not None else None def get_ids_by_offsets(self, offsets: List[int]) -> List[str]: return [self.get_id_by_offset(offset) for offset in offsets] From 513b73df1a7b8192641660d98e5c0ebf684ce608 Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Fri, 28 Jan 2022 22:08:37 +0800 Subject: [PATCH 20/23] fix: revert unittest --- docarray/array/storage/pqlite/backend.py | 1 + tests/unit/array/test_pqlite_indexing.py | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/docarray/array/storage/pqlite/backend.py b/docarray/array/storage/pqlite/backend.py index d8901eb2be1..461032b1c7d 100644 --- a/docarray/array/storage/pqlite/backend.py +++ b/docarray/array/storage/pqlite/backend.py @@ -19,6 +19,7 @@ class PqliteConfig: n_dim: int = 1 metric: str = 'cosine' + serialize_protocol: str = 'pickle' data_path: Optional[str] = None diff --git a/tests/unit/array/test_pqlite_indexing.py b/tests/unit/array/test_pqlite_indexing.py index 4a71614a123..fa2ce4f7ee6 100644 --- a/tests/unit/array/test_pqlite_indexing.py +++ b/tests/unit/array/test_pqlite_indexing.py @@ -6,7 +6,7 @@ @pytest.fixture def docs(): - yield (Document(text=f'{j}') for j in range(100)) + yield (Document(text=j) for j in range(100)) @pytest.fixture @@ -18,14 +18,14 @@ def indices(): def test_getter_int_str(docs, storage): docs = DocumentArray(docs, storage=storage) # getter - assert docs[99].text == '99' - assert docs[np.int(99)].text == '99' - assert docs[-1].text == '99' - assert docs[0].text == '0' + assert docs[99].text == 99 + assert docs[np.int(99)].text == 99 + assert docs[-1].text == 99 + assert docs[0].text == 0 # string index - assert docs[docs[0].id].text == '0' - assert docs[docs[99].id].text == '99' - assert docs[docs[-1].id].text == '99' + assert docs[docs[0].id].text == 0 + assert docs[docs[99].id].text == 99 + assert docs[docs[-1].id].text == 99 with pytest.raises(IndexError): r = docs[100] @@ -115,7 +115,7 @@ def test_sequence_bool_index(docs, storage): # got replaced assert d.text.startswith('repl') else: - assert isinstance(d.text, str) + assert isinstance(d.text, int) # del del docs[mask] @@ -225,8 +225,8 @@ def test_tensor_attribute_selector(storage): import scipy.sparse sp_embed = np.random.random([3, 10]) - # sp_embed[sp_embed > 0.1] = 0 - # sp_embed = scipy.sparse.coo_matrix(sp_embed) + sp_embed[sp_embed > 0.1] = 0 + sp_embed = scipy.sparse.coo_matrix(sp_embed) da = DocumentArray(storage=storage, config={'n_dim': 10}) da.extend(DocumentArray.empty(3)) @@ -238,7 +238,7 @@ def test_tensor_attribute_selector(storage): assert da[:, 'embedding'].shape == (3, 10) for d in da: - assert d.embedding.shape == (10,) + assert d.embedding.shape == (1, 10) v1, v2 = da[:, ['embedding', 'id']] # assert isinstance(v1, scipy.sparse.coo_matrix) From 355b66bf57ad2c53fa65dd3c8c2db6454f075f03 Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Fri, 28 Jan 2022 22:21:32 +0800 Subject: [PATCH 21/23] fix: add pqlite dependence --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index b645f8bd90c..7083bad4da0 100644 --- a/setup.py +++ b/setup.py @@ -54,6 +54,7 @@ 'fastapi', 'uvicorn', 'weaviate-client~=3.3.0', + 'pqlite>=0.2.1', ], 'test': [ 'pytest', From 4bdc35c8cbdf778c04cc5698d9de9fdacfc55c78 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Wed, 2 Feb 2022 07:34:11 +0100 Subject: [PATCH 22/23] test: merge pqlite indexing with advance ind --- tests/unit/array/test_advance_indexing.py | 38 ++++++++++++++--------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index e12c5080faf..100f51bf35f 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -15,7 +15,7 @@ def indices(): yield (i for i in [-2, 0, 2]) -@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate', 'pqlite']) def test_getter_int_str(docs, storage, start_weaviate): docs = DocumentArray(docs, storage=storage) # getter @@ -35,7 +35,7 @@ def test_getter_int_str(docs, storage, start_weaviate): docs['adsad'] -@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate', 'pqlite']) def test_setter_int_str(docs, storage, start_weaviate): docs = DocumentArray(docs, storage=storage) # setter @@ -51,7 +51,7 @@ def test_setter_int_str(docs, storage, start_weaviate): assert docs[docs[2].id].text == 'doc2' -@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate', 'pqlite']) def test_del_int_str(docs, storage, indices): docs = DocumentArray(docs, storage=storage) initial_len = len(docs) @@ -72,7 +72,7 @@ def test_del_int_str(docs, storage, indices): assert new_doc_zero not in docs -@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate', 'pqlite']) def test_slice(docs, storage, start_weaviate): docs = DocumentArray(docs, storage=storage) # getter @@ -97,7 +97,7 @@ def test_slice(docs, storage, start_weaviate): assert twenty_doc in docs -@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate', 'pqlite']) def test_sequence_bool_index(docs, storage, start_weaviate): docs = DocumentArray(docs, storage=storage) # getter @@ -134,7 +134,7 @@ def test_sequence_bool_index(docs, storage, start_weaviate): @pytest.mark.parametrize('nparray', [lambda x: x, np.array, tuple]) -@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate', 'pqlite']) def test_sequence_int(docs, nparray, storage, start_weaviate): docs = DocumentArray(docs, storage=storage) # getter @@ -157,7 +157,7 @@ def test_sequence_int(docs, nparray, storage, start_weaviate): assert docs[9].text == 'new' -@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate', 'pqlite']) def test_sequence_str(docs, storage, start_weaviate): docs = DocumentArray(docs, storage=storage) # getter @@ -178,14 +178,14 @@ def test_sequence_str(docs, storage, start_weaviate): assert len(docs) == 100 - len(idx) -@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate', 'pqlite']) def test_docarray_list_tuple(docs, storage, start_weaviate): docs = DocumentArray(docs, storage=storage) assert isinstance(docs[99, 98], DocumentArray) assert len(docs[99, 98]) == 2 -@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate', 'pqlite']) def test_path_syntax_indexing(storage, start_weaviate): da = DocumentArray.empty(3) for d in da: @@ -282,7 +282,7 @@ def test_path_syntax_indexing_set(storage, start_weaviate): @pytest.mark.parametrize('size', [1, 5]) -@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate', 'pqlite']) def test_attribute_indexing(storage, start_weaviate, size): da = DocumentArray(storage=storage) da.extend(DocumentArray.empty(size)) @@ -307,7 +307,7 @@ def test_attribute_indexing(storage, start_weaviate, size): assert vv -@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate', 'pqlite']) def test_tensor_attribute_selector(storage): import scipy.sparse @@ -315,7 +315,11 @@ def test_tensor_attribute_selector(storage): sp_embed[sp_embed > 0.1] = 0 sp_embed = scipy.sparse.coo_matrix(sp_embed) - da = DocumentArray(storage=storage) + if storage == 'pqlite': + da = DocumentArray(storage=storage, config={'n_dim': 10}) + else: + da = DocumentArray(storage=storage) + da.extend(DocumentArray.empty(3)) da[:, 'embedding'] = sp_embed @@ -337,9 +341,13 @@ def test_tensor_attribute_selector(storage): # TODO: since match function is not implemented, this test will # not work with weaviate storage atm, will be addressed in # next version -@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'pqlite']) def test_advance_selector_mixed(storage): + da = DocumentArray(storage=storage) + if storage == 'pqlite': + da = DocumentArray(storage=storage, config={'n_dim': 3}) + da.extend(DocumentArray.empty(10)) da.embeddings = np.random.random([10, 3]) @@ -349,7 +357,7 @@ def test_advance_selector_mixed(storage): assert len(da[:, ('id', 'embedding', 'matches')][0]) == 10 -@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate', 'pqlite']) def test_single_boolean_and_padding(storage, start_weaviate): da = DocumentArray(storage=storage) da.extend(DocumentArray.empty(3)) @@ -368,7 +376,7 @@ def test_single_boolean_and_padding(storage, start_weaviate): assert len(da[True, False, False]) == 1 -@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) +@pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate', 'pqlite']) def test_edge_case_two_strings(storage, start_weaviate): # getitem da = DocumentArray( From 748542450195f0c28bc902dd9aed848d1bb0593d Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Wed, 2 Feb 2022 12:04:51 +0100 Subject: [PATCH 23/23] test: revert value error --- tests/unit/array/test_pqlite_indexing.py | 312 ----------------------- 1 file changed, 312 deletions(-) diff --git a/tests/unit/array/test_pqlite_indexing.py b/tests/unit/array/test_pqlite_indexing.py index fa2ce4f7ee6..bdd01798ef1 100644 --- a/tests/unit/array/test_pqlite_indexing.py +++ b/tests/unit/array/test_pqlite_indexing.py @@ -12,315 +12,3 @@ def docs(): @pytest.fixture def indices(): yield (i for i in [-2, 0, 2]) - - -@pytest.mark.parametrize('storage', ['pqlite']) -def test_getter_int_str(docs, storage): - docs = DocumentArray(docs, storage=storage) - # getter - assert docs[99].text == 99 - assert docs[np.int(99)].text == 99 - assert docs[-1].text == 99 - assert docs[0].text == 0 - # string index - assert docs[docs[0].id].text == 0 - assert docs[docs[99].id].text == 99 - assert docs[docs[-1].id].text == 99 - - with pytest.raises(IndexError): - r = docs[100] - print(r) - - with pytest.raises(KeyError): - docs['adsad'] - - -@pytest.mark.parametrize('storage', ['pqlite']) -def test_setter_int_str(docs, storage): - docs = DocumentArray(docs, storage=storage) - # setter - docs[99] = Document(text='hello') - docs[0] = Document(text='world') - - assert docs[99].text == 'hello' - assert docs[-1].text == 'hello' - assert docs[0].text == 'world' - - docs[docs[2].id] = Document(text='doc2') - # string index - assert docs[docs[2].id].text == 'doc2' - - -@pytest.mark.parametrize('storage', ['pqlite']) -def test_del_int_str(docs, storage, indices): - docs = DocumentArray(docs, storage=storage) - initial_len = len(docs) - deleted_elements = 0 - for pos in indices: - pos_id = docs[pos].id - del docs[pos] - deleted_elements += 1 - assert pos_id not in docs - assert len(docs) == initial_len - deleted_elements - - new_pos_id = docs[pos].id - new_doc_zero = docs[pos] - del docs[new_pos_id] - deleted_elements += 1 - assert len(docs) == initial_len - deleted_elements - assert pos_id not in docs - assert new_doc_zero not in docs - - -@pytest.mark.parametrize('storage', ['pqlite']) -def test_slice(docs, storage): - docs = DocumentArray(docs, storage=storage) - # getter - assert len(docs[1:5]) == 4 - assert len(docs[1:100:5]) == 20 # 1 to 100, sep with 5 - - # setter - with pytest.raises(TypeError, match='an iterable'): - docs[1:5] = Document(text='repl') - - docs[1:5] = [Document(text=f'repl{j}') for j in range(4)] - for d in docs[1:5]: - assert d.text.startswith('repl') - assert len(docs) == 100 - - # del - zero_doc = docs[0] - twenty_doc = docs[20] - del docs[0:20] - assert len(docs) == 80 - assert zero_doc not in docs - assert twenty_doc in docs - - -@pytest.mark.parametrize('storage', ['pqlite']) -def test_sequence_bool_index(docs, storage): - docs = DocumentArray(docs, storage=storage) - # getter - mask = [True, False] * 50 - assert len(docs[mask]) == 50 - assert len(docs[[True, False]]) == 1 - - # setter - mask = [True, False] * 50 - # docs[mask] = [Document(text=f'repl{j}') for j in range(50)] - docs[mask, 'text'] = [f'repl{j}' for j in range(50)] - - for idx, d in enumerate(docs): - if idx % 2 == 0: - # got replaced - assert d.text.startswith('repl') - else: - assert isinstance(d.text, int) - - # del - del docs[mask] - assert len(docs) == 50 - - -@pytest.mark.parametrize('nparray', [lambda x: x, np.array, tuple]) -@pytest.mark.parametrize('storage', ['pqlite']) -def test_sequence_int(docs, nparray, storage): - docs = DocumentArray(docs, storage=storage) - # getter - idx = nparray([1, 3, 5, 7, -1, -2]) - assert len(docs[idx]) == len(idx) - - # setter - docs[idx] = [Document(text='repl') for _ in range(len(idx))] - for _id in idx: - assert docs[_id].text == 'repl' - - # del - idx = [-3, -4, -5, 9, 10, 11] - del docs[idx] - assert len(docs) == 100 - len(idx) - - -@pytest.mark.parametrize('storage', ['pqlite']) -def test_sequence_str(docs, storage): - docs = DocumentArray(docs, storage=storage) - # getter - idx = [d.id for d in docs[1, 3, 5, 7, -1, -2]] - - assert len(docs[idx]) == len(idx) - assert len(docs[tuple(idx)]) == len(idx) - - # setter - docs[idx] = [Document(text='repl') for _ in range(len(idx))] - idx = [d.id for d in docs[1, 3, 5, 7, -1, -2]] - for _id in idx: - assert docs[_id].text == 'repl' - - # del - idx = [d.id for d in docs[-3, -4, -5, 9, 10, 11]] - del docs[idx] - assert len(docs) == 100 - len(idx) - - -@pytest.mark.parametrize('storage', ['pqlite']) -def test_docarray_list_tuple(docs, storage): - docs = DocumentArray(docs, storage=storage) - assert isinstance(docs[99, 98], DocumentArray) - assert len(docs[99, 98]) == 2 - - -@pytest.mark.parametrize('storage', ['pqlite']) -def test_path_syntax_indexing(storage): - da = DocumentArray.empty(3) - for d in da: - d.chunks = DocumentArray.empty(5) - d.matches = DocumentArray.empty(7) - for c in d.chunks: - c.chunks = DocumentArray.empty(3) - - if storage == 'sqlite': - da = DocumentArray(da, storage=storage) - assert len(da['@c']) == 3 * 5 - assert len(da['@c:1']) == 3 - assert len(da['@c-1:']) == 3 - assert len(da['@c1']) == 3 - assert len(da['@c-2:']) == 3 * 2 - assert len(da['@c1:3']) == 3 * 2 - assert len(da['@c1:3c']) == (3 * 2) * 3 - assert len(da['@c1:3,c1:3c']) == (3 * 2) + (3 * 2) * 3 - assert len(da['@c 1:3 , c 1:3 c']) == (3 * 2) + (3 * 2) * 3 - assert len(da['@cc']) == 3 * 5 * 3 - assert len(da['@cc,m']) == 3 * 5 * 3 + 3 * 7 - assert len(da['@r:1cc,m']) == 1 * 5 * 3 + 3 * 7 - - -@pytest.mark.parametrize('storage', ['pqlite']) -def test_attribute_indexing(storage): - da = DocumentArray(storage=storage) - da.extend(DocumentArray.empty(10)) - - for v in da[:, 'id']: - assert v - da[:, 'mime_type'] = [f'type {j}' for j in range(10)] - for v in da[:, 'mime_type']: - assert v - del da[:, 'mime_type'] - for v in da[:, 'mime_type']: - assert not v - - da[:, ['text', 'mime_type']] = [ - [f'hello {j}' for j in range(10)], - [f'type {j}' for j in range(10)], - ] - da.summary() - - for v in da[:, ['mime_type', 'text']]: - for vv in v: - assert vv - - -# TODO: enable weaviate storage test -@pytest.mark.parametrize('storage', ['pqlite']) -def test_tensor_attribute_selector(storage): - import scipy.sparse - - sp_embed = np.random.random([3, 10]) - sp_embed[sp_embed > 0.1] = 0 - sp_embed = scipy.sparse.coo_matrix(sp_embed) - - da = DocumentArray(storage=storage, config={'n_dim': 10}) - da.extend(DocumentArray.empty(3)) - - assert len(da) == 3 - - da[:, 'embedding'] = sp_embed - - assert da[:, 'embedding'].shape == (3, 10) - - for d in da: - assert d.embedding.shape == (1, 10) - - v1, v2 = da[:, ['embedding', 'id']] - # assert isinstance(v1, scipy.sparse.coo_matrix) - assert isinstance(v2, list) - - v1, v2 = da[:, ['id', 'embedding']] - # assert isinstance(v2, scipy.sparse.coo_matrix) - assert isinstance(v1, list) - - -@pytest.mark.parametrize('storage', ['pqlite']) -def test_advance_selector_mixed(storage): - da = DocumentArray(storage=storage, config={'n_dim': 3}) - da.extend(DocumentArray.empty(10)) - da.embeddings = np.random.random([10, 3]) - - da.match(da, exclude_self=True) - - assert len(da[:, ('id', 'embedding', 'matches')]) == 3 - assert len(da[:, ('id', 'embedding', 'matches')][0]) == 10 - - -@pytest.mark.parametrize('storage', ['pqlite']) -def test_single_boolean_and_padding(storage): - da = DocumentArray(storage=storage) - da.extend(DocumentArray.empty(3)) - - with pytest.raises(IndexError): - da[True] - - with pytest.raises(IndexError): - da[True] = Document() - - with pytest.raises(IndexError): - del da[True] - - assert len(da[True, False]) == 1 - assert len(da[False, False]) == 0 - - -@pytest.mark.parametrize('storage', ['pqlite']) -def test_edge_case_two_strings(storage): - # getitem - da = DocumentArray( - [Document(id='1'), Document(id='2'), Document(id='3')], storage=storage - ) - assert da['1', 'id'] == '1' - assert len(da['1', '2']) == 2 - assert isinstance(da['1', '2'], DocumentArray) - with pytest.raises(KeyError): - da['hello', '2'] - with pytest.raises(AttributeError): - da['1', 'hello'] - assert len(da['1', '2', '3']) == 3 - assert isinstance(da['1', '2', '3'], DocumentArray) - - # delitem - del da['1', '2'] - assert len(da) == 1 - - da = DocumentArray( - [Document(id=str(i), text='hey') for i in range(3)], storage=storage - ) - del da['1', 'text'] - assert len(da) == 3 - assert not da[1].text - - del da['2', 'hello'] - - # setitem - da = DocumentArray( - [Document(id='1'), Document(id='2'), Document(id='3')], storage=storage - ) - da['1', '2'] = DocumentArray.empty(2) - assert da[0].id != '1' - assert da[1].id != '2' - - da = DocumentArray( - [Document(id='1'), Document(id='2'), Document(id='3')], storage=storage - ) - da['1', 'text'] = 'hello' - assert da['1'].text == 'hello' - - with pytest.raises(ValueError): - da['1', 'hellohello'] = 'hello'