From 2736dded063ec165325bdbd4ebd7eedcc70071c3 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Sun, 16 Jan 2022 11:54:05 +0100 Subject: [PATCH 01/55] refactor(array): use extend in add improve docs --- docarray/array/persistence/__init__.py | 0 docarray/array/persistence/sqlite/__init__.py | 0 docarray/array/persistence/sqlite/base.py | 271 ++++++++++++++++++ docarray/array/persistence/sqlite/dict.py | 140 +++++++++ 4 files changed, 411 insertions(+) create mode 100644 docarray/array/persistence/__init__.py create mode 100644 docarray/array/persistence/sqlite/__init__.py create mode 100644 docarray/array/persistence/sqlite/base.py create mode 100644 docarray/array/persistence/sqlite/dict.py diff --git a/docarray/array/persistence/__init__.py b/docarray/array/persistence/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/array/persistence/sqlite/__init__.py b/docarray/array/persistence/sqlite/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/array/persistence/sqlite/base.py b/docarray/array/persistence/sqlite/base.py new file mode 100644 index 00000000000..9c85b461a32 --- /dev/null +++ b/docarray/array/persistence/sqlite/base.py @@ -0,0 +1,271 @@ +import sqlite3 +import warnings +from abc import ABCMeta, abstractmethod +from collections.abc import Hashable +from enum import Enum +from pickle import dumps, loads +from tempfile import NamedTemporaryFile +from typing import Callable, Generic, Optional, TypeVar, Union, cast +from uuid import uuid4 + +T = TypeVar('T') +KT = TypeVar('KT') +VT = TypeVar('VT') +_T = TypeVar('_T') +_S = TypeVar('_S') + + +class RebuildStrategy(Enum): + CHECK_WITH_FIRST_ELEMENT = 1 + ALWAYS = 2 + SKIP = 3 + + +def sanitize_table_name(table_name: str) -> str: + ret = ''.join(c for c in table_name if c.isalnum() or c == '_') + if ret != table_name: + warnings.warn(f'The table name is changed to {ret} due to illegal characters') + return ret + + +def create_random_name(suffix: str) -> str: + return f"{suffix}_{str(uuid4()).replace('-', '')}" + + +def is_hashable(x: object) -> bool: + return isinstance(x, Hashable) + + +class _SqliteCollectionBaseDatabaseDriver(metaclass=ABCMeta): + @classmethod + def initialize_metadata_table(cls, cur: sqlite3.Cursor) -> None: + if not cls.is_metadata_table_initialized(cur): + cls.do_initialize_metadata_table(cur) + + @classmethod + def is_metadata_table_initialized(cls, cur: sqlite3.Cursor) -> bool: + try: + cur.execute('SELECT 1 FROM metadata LIMIT 1') + _ = list(cur) + return True + except sqlite3.OperationalError as _: + pass + return False + + @classmethod + def do_initialize_metadata_table(cls, cur: sqlite3.Cursor) -> None: + cur.execute( + ''' + CREATE TABLE metadata ( + table_name TEXT PRIMARY KEY, + schema_version TEXT NOT NULL, + container_type TEXT NOT NULL, + UNIQUE (table_name, container_type) + ) + ''' + ) + + @classmethod + def initialize_table( + cls, + table_name: str, + container_type_name: str, + schema_version: str, + cur: sqlite3.Cursor, + ) -> None: + if not cls.is_table_initialized( + table_name, container_type_name, schema_version, cur + ): + cls.do_create_table(table_name, container_type_name, schema_version, cur) + cls.do_tidy_table_metadata( + table_name, container_type_name, schema_version, cur + ) + + @classmethod + def is_table_initialized( + self, + table_name: str, + container_type_name: str, + schema_version: str, + cur: sqlite3.Cursor, + ) -> bool: + try: + cur.execute( + 'SELECT schema_version FROM metadata WHERE table_name=? AND container_type=?', + (table_name, container_type_name), + ) + buf = cur.fetchone() + if buf is None: + return False + version = buf[0] + if version != schema_version: + return False + cur.execute(f'SELECT 1 FROM {table_name} LIMIT 1') + _ = list(cur) + return True + except sqlite3.OperationalError as _: + pass + return False + + @classmethod + def do_tidy_table_metadata( + cls, + table_name: str, + container_type_name: str, + schema_version: str, + cur: sqlite3.Cursor, + ) -> None: + cur.execute( + 'INSERT INTO metadata (table_name, schema_version, container_type) VALUES (?, ?, ?)', + (table_name, schema_version, container_type_name), + ) + + @classmethod + @abstractmethod + def do_create_table( + cls, + table_name: str, + container_type_name: str, + schema_version: str, + cur: sqlite3.Cursor, + ) -> None: + ... + + @classmethod + def drop_table( + cls, table_name: str, container_type_name: str, cur: sqlite3.Cursor + ) -> None: + cur.execute( + 'DELETE FROM metadata WHERE table_name=? AND container_type=?', + (table_name, container_type_name), + ) + cur.execute(f'DROP TABLE {table_name}') + + @classmethod + def alter_table_name( + cls, table_name: str, new_table_name: str, cur: sqlite3.Cursor + ) -> None: + cur.execute( + 'UPDATE metadata SET table_name=? WHERE table_name=?', + (new_table_name, table_name), + ) + cur.execute(f'ALTER TABLE {table_name} RENAME TO {new_table_name}') + + +class SqliteCollectionBase(Generic[T], metaclass=ABCMeta): + _driver_class = _SqliteCollectionBaseDatabaseDriver + + def __init__( + self, + connection: Optional[Union[str, sqlite3.Connection]] = None, + table_name: Optional[str] = None, + serializer: Optional[Callable[[T], bytes]] = None, + deserializer: Optional[Callable[[bytes], T]] = None, + persist: bool = True, + rebuild_strategy: RebuildStrategy = RebuildStrategy.CHECK_WITH_FIRST_ELEMENT, + ): + super(SqliteCollectionBase, self).__init__() + self._serializer = ( + cast(Callable[[T], bytes], dumps) if serializer is None else serializer + ) + self._deserializer = ( + cast(Callable[[bytes], T], loads) if deserializer is None else deserializer + ) + self._persist = persist + if connection is None: + self._connection = sqlite3.connect(NamedTemporaryFile().name) + elif isinstance(connection, str): + self._connection = sqlite3.connect(connection) + elif isinstance(connection, sqlite3.Connection): + self._connection = connection + else: + raise TypeError( + f'connection argument must be None or a string or a sqlite3.Connection, not `{type(connection)}`' + ) + self._table_name = ( + sanitize_table_name(create_random_name(self.container_type_name)) + if table_name is None + else sanitize_table_name(table_name) + ) + self._initialize(rebuild_strategy=rebuild_strategy) + + def __del__(self) -> None: + if not self.persist: + cur = self.connection.cursor() + self._driver_class.drop_table( + self.table_name, self.container_type_name, cur + ) + self.connection.commit() + + def _initialize(self, rebuild_strategy: RebuildStrategy) -> None: + cur = self.connection.cursor() + self._driver_class.initialize_metadata_table(cur) + self._driver_class.initialize_table( + self.table_name, self.container_type_name, self.schema_version, cur + ) + if self._should_rebuild(rebuild_strategy): + self._do_rebuild() + self.connection.commit() + + def _should_rebuild(self, rebuild_strategy: RebuildStrategy) -> bool: + if rebuild_strategy == RebuildStrategy.ALWAYS: + return True + if rebuild_strategy == RebuildStrategy.SKIP: + return False + return self._rebuild_check_with_first_element() + + @abstractmethod + def _rebuild_check_with_first_element(self) -> bool: + ... + + @abstractmethod + def _do_rebuild(self) -> None: + ... + + @property + def persist(self) -> bool: + return self._persist + + def set_persist(self, persist: bool) -> None: + self._persist = persist + + @property + def serializer(self) -> Callable[[T], bytes]: + return self._serializer + + def serialize(self, x: T) -> bytes: + return self.serializer(x) + + @property + def deserializer(self) -> Callable[[bytes], T]: + return self._deserializer + + def deserialize(self, blob: bytes) -> T: + return self.deserializer(blob) + + @property + def table_name(self) -> str: + return self._table_name + + @table_name.setter + def table_name(self, table_name: str) -> None: + cur = self.connection.cursor() + new_table_name = sanitize_table_name(table_name) + try: + self._driver_class.alter_table_name(self.table_name, new_table_name, cur) + except sqlite3.IntegrityError as e: + raise ValueError(table_name) + self._table_name = new_table_name + + @property + def connection(self) -> sqlite3.Connection: + return self._connection + + @property + def container_type_name(self) -> str: + return self.__class__.__name__ + + @property + @abstractmethod + def schema_version(self) -> str: + ... diff --git a/docarray/array/persistence/sqlite/dict.py b/docarray/array/persistence/sqlite/dict.py new file mode 100644 index 00000000000..4c82bfb709f --- /dev/null +++ b/docarray/array/persistence/sqlite/dict.py @@ -0,0 +1,140 @@ +from typing import Tuple, Union, cast, Iterable, TYPE_CHECKING + +from .base import _SqliteCollectionBaseDatabaseDriver + +if TYPE_CHECKING: + import sqlite3 + + +class _DictDatabaseDriver(_SqliteCollectionBaseDatabaseDriver): + @classmethod + def do_create_table( + cls, + table_name: str, + container_type_nam: str, + schema_version: str, + cur: 'sqlite3.Cursor', + ) -> None: + cur.execute( + f'CREATE TABLE {table_name} (' + 'serialized_key BLOB NOT NULL UNIQUE, ' + 'serialized_value BLOB NOT NULL, ' + 'item_order INTEGER PRIMARY KEY)' + ) + + @classmethod + def delete_single_record_by_serialized_key( + cls, table_name: str, cur: 'sqlite3.Cursor', serialized_key: bytes + ) -> None: + cur.execute( + f'DELETE FROM {table_name} WHERE serialized_key=?', (serialized_key,) + ) + + @classmethod + def delete_all_records(cls, table_name: str, cur: 'sqlite3.Cursor') -> None: + cur.execute(f'DELETE FROM {table_name}') + + @classmethod + def is_serialized_key_in( + cls, table_name: str, cur: 'sqlite3.Cursor', serialized_key: bytes + ) -> bool: + cur.execute( + f'SELECT 1 FROM {table_name} WHERE serialized_key=?', (serialized_key,) + ) + return len(list(cur)) > 0 + + @classmethod + def get_serialized_value_by_serialized_key( + cls, table_name: str, cur: 'sqlite3.Cursor', serialized_key: bytes + ) -> Union[None, bytes]: + cur.execute( + f'SELECT serialized_value FROM {table_name} WHERE serialized_key=?', + (serialized_key,), + ) + res = cur.fetchone() + if res is None: + return None + return cast(bytes, res[0]) + + @classmethod + def get_next_order(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: + cur.execute(f'SELECT MAX(item_order) FROM {table_name}') + res = cur.fetchone()[0] + if res is None: + return 0 + return cast(int, res) + 1 + + @classmethod + def get_count(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: + cur.execute(f'SELECT COUNT(*) FROM {table_name}') + res = cur.fetchone() + return cast(int, res[0]) + + @classmethod + def get_serialized_keys( + cls, table_name: str, cur: 'sqlite3.Cursor' + ) -> Iterable[bytes]: + cur.execute(f'SELECT serialized_key FROM {table_name} ORDER BY item_order') + for res in cur: + yield cast(bytes, res[0]) + + @classmethod + def insert_serialized_value_by_serialized_key( + cls, + table_name: str, + cur: 'sqlite3.Cursor', + serialized_key: bytes, + serialized_value: bytes, + ) -> None: + item_order = cls.get_next_order(table_name, cur) + cur.execute( + f'INSERT INTO {table_name} (serialized_key, serialized_value, item_order) VALUES (?, ?, ?)', + (serialized_key, serialized_value, item_order), + ) + + @classmethod + def update_serialized_value_by_serialized_key( + cls, + table_name: str, + cur: 'sqlite3.Cursor', + serialized_key: bytes, + serialized_value: bytes, + ) -> None: + cur.execute( + f'UPDATE {table_name} SET serialized_value=? WHERE serialized_key=?', + (serialized_value, serialized_key), + ) + + @classmethod + def upsert( + cls, + table_name: str, + cur: 'sqlite3.Cursor', + serialized_key: bytes, + serialized_value: bytes, + ) -> None: + if cls.is_serialized_key_in(table_name, cur, serialized_key): + cls.update_serialized_value_by_serialized_key( + table_name, cur, serialized_key, serialized_value + ) + else: + cls.insert_serialized_value_by_serialized_key( + table_name, cur, serialized_key, serialized_value + ) + + @classmethod + def get_last_serialized_item( + cls, table_name: str, cur: 'sqlite3.Cursor' + ) -> Tuple[bytes, bytes]: + cur.execute( + f'SELECT serialized_key, serialized_value FROM {table_name} ORDER BY item_order DESC LIMIT 1' + ) + return cast(Tuple[bytes, bytes], cur.fetchone()) + + @classmethod + def get_reversed_serialized_keys( + cls, table_name: str, cur: 'sqlite3.Cursor' + ) -> Iterable[bytes]: + cur.execute(f'SELECT serialized_key FROM {table_name} ORDER BY item_order DESC') + for res in cur: + yield cast(bytes, res[0]) From 590961f133a99ab34d3fc9d8d63c23d7ac5460e3 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Mon, 17 Jan 2022 15:29:36 +0100 Subject: [PATCH 02/55] chore: bump to 0.2 --- docarray/array/persistence/sqlite/dict.py | 72 +++--- docarray/array/persistence/sqlite/mixin.py | 251 +++++++++++++++++++++ 2 files changed, 287 insertions(+), 36 deletions(-) create mode 100644 docarray/array/persistence/sqlite/mixin.py diff --git a/docarray/array/persistence/sqlite/dict.py b/docarray/array/persistence/sqlite/dict.py index 4c82bfb709f..a0ea8737708 100644 --- a/docarray/array/persistence/sqlite/dict.py +++ b/docarray/array/persistence/sqlite/dict.py @@ -17,17 +17,17 @@ def do_create_table( ) -> None: cur.execute( f'CREATE TABLE {table_name} (' - 'serialized_key BLOB NOT NULL UNIQUE, ' + 'doc_id TEXT NOT NULL UNIQUE, ' 'serialized_value BLOB NOT NULL, ' 'item_order INTEGER PRIMARY KEY)' ) @classmethod - def delete_single_record_by_serialized_key( - cls, table_name: str, cur: 'sqlite3.Cursor', serialized_key: bytes + def delete_single_record_by_doc_id( + cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str ) -> None: cur.execute( - f'DELETE FROM {table_name} WHERE serialized_key=?', (serialized_key,) + f'DELETE FROM {table_name} WHERE doc_id=?', (doc_id,) ) @classmethod @@ -35,21 +35,21 @@ def delete_all_records(cls, table_name: str, cur: 'sqlite3.Cursor') -> None: cur.execute(f'DELETE FROM {table_name}') @classmethod - def is_serialized_key_in( - cls, table_name: str, cur: 'sqlite3.Cursor', serialized_key: bytes + def is_doc_id_in( + cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str ) -> bool: cur.execute( - f'SELECT 1 FROM {table_name} WHERE serialized_key=?', (serialized_key,) + f'SELECT 1 FROM {table_name} WHERE doc_id=?', (doc_id,) ) return len(list(cur)) > 0 @classmethod - def get_serialized_value_by_serialized_key( - cls, table_name: str, cur: 'sqlite3.Cursor', serialized_key: bytes + def get_serialized_value_by_doc_id( + cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str ) -> Union[None, bytes]: cur.execute( - f'SELECT serialized_value FROM {table_name} WHERE serialized_key=?', - (serialized_key,), + f'SELECT serialized_value FROM {table_name} WHERE doc_id=?', + (doc_id,), ) res = cur.fetchone() if res is None: @@ -71,38 +71,38 @@ def get_count(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: return cast(int, res[0]) @classmethod - def get_serialized_keys( + def get_doc_ids( cls, table_name: str, cur: 'sqlite3.Cursor' - ) -> Iterable[bytes]: - cur.execute(f'SELECT serialized_key FROM {table_name} ORDER BY item_order') + ) -> Iterable[str]: + cur.execute(f'SELECT doc_id FROM {table_name} ORDER BY item_order') for res in cur: - yield cast(bytes, res[0]) + yield cast(str, res[0]) @classmethod - def insert_serialized_value_by_serialized_key( + def insert_serialized_value_by_doc_id( cls, table_name: str, cur: 'sqlite3.Cursor', - serialized_key: bytes, + doc_id: str, serialized_value: bytes, ) -> None: item_order = cls.get_next_order(table_name, cur) cur.execute( - f'INSERT INTO {table_name} (serialized_key, serialized_value, item_order) VALUES (?, ?, ?)', - (serialized_key, serialized_value, item_order), + f'INSERT INTO {table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', + (doc_id, serialized_value, item_order), ) @classmethod - def update_serialized_value_by_serialized_key( + def update_serialized_value_by_doc_id( cls, table_name: str, cur: 'sqlite3.Cursor', - serialized_key: bytes, + doc_id: str, serialized_value: bytes, ) -> None: cur.execute( - f'UPDATE {table_name} SET serialized_value=? WHERE serialized_key=?', - (serialized_value, serialized_key), + f'UPDATE {table_name} SET serialized_value=? WHERE doc_id=?', + (serialized_value, doc_id), ) @classmethod @@ -110,31 +110,31 @@ def upsert( cls, table_name: str, cur: 'sqlite3.Cursor', - serialized_key: bytes, + doc_id: str, serialized_value: bytes, ) -> None: - if cls.is_serialized_key_in(table_name, cur, serialized_key): - cls.update_serialized_value_by_serialized_key( - table_name, cur, serialized_key, serialized_value + if cls.is_doc_id_in(table_name, cur, doc_id): + cls.update_serialized_value_by_doc_id( + table_name, cur, doc_id, serialized_value ) else: - cls.insert_serialized_value_by_serialized_key( - table_name, cur, serialized_key, serialized_value + cls.insert_serialized_value_by_doc_id( + table_name, cur, doc_id, serialized_value ) @classmethod def get_last_serialized_item( cls, table_name: str, cur: 'sqlite3.Cursor' - ) -> Tuple[bytes, bytes]: + ) -> Tuple[str, bytes]: cur.execute( - f'SELECT serialized_key, serialized_value FROM {table_name} ORDER BY item_order DESC LIMIT 1' + f'SELECT doc_id, serialized_value FROM {table_name} ORDER BY item_order DESC LIMIT 1' ) - return cast(Tuple[bytes, bytes], cur.fetchone()) + return cast(Tuple[str, bytes], cur.fetchone()) @classmethod - def get_reversed_serialized_keys( + def get_reversed_doc_ids( cls, table_name: str, cur: 'sqlite3.Cursor' - ) -> Iterable[bytes]: - cur.execute(f'SELECT serialized_key FROM {table_name} ORDER BY item_order DESC') + ) -> Iterable[str]: + cur.execute(f'SELECT doc_id FROM {table_name} ORDER BY item_order DESC') for res in cur: - yield cast(bytes, res[0]) + yield cast(str, res[0]) diff --git a/docarray/array/persistence/sqlite/mixin.py b/docarray/array/persistence/sqlite/mixin.py new file mode 100644 index 00000000000..b58d7889d8e --- /dev/null +++ b/docarray/array/persistence/sqlite/mixin.py @@ -0,0 +1,251 @@ +from dataclasses import dataclass +import dataclasses +from typing import Optional, TYPE_CHECKING, Callable, Union, cast + +from .base import SqliteCollectionBase, RebuildStrategy +from .dict import _DictDatabaseDriver + +if TYPE_CHECKING: + import sqlite3 + + from ....types import ( + T, + Document, + DocumentArraySourceType, + DocumentArrayIndexType, + DocumentArraySingletonIndexType, + DocumentArrayMultipleIndexType, + DocumentArrayMultipleAttributeType, + DocumentArraySingleAttributeType, + ) + +@dataclass +class SqliteConfig: + connection: Optional[Union[str, 'sqlite3.Connection']] = None + table_name: Optional[str] = None + serializer: Optional[Callable[['Document'], bytes]] = None + deserializer: Optional[Callable[[bytes], 'Document']] = None + persist: bool = True + rebuild_strategy: RebuildStrategy = RebuildStrategy.CHECK_WITH_FIRST_ELEMENT + + +class SqliteMixin(SqliteCollectionBase): + """Enable SQLite persistence backend for DocumentArray. + + .. note:: + This has to be put in the first position when use it for subclassing + i.e. `class SqliteDA(SqliteMixin, DA)` not the other way around. + + """ + _driver_class = _DictDatabaseDriver + + def __init__(self, docs: Optional['DocumentArraySourceType'] = None, config: Optional[SqliteConfig] = None): + super().__init__(**(dataclasses.asdict(config) if config else {})) + if docs is not None: + self.clear() + self.update(docs) + + def clear(self) -> None: + cur = self.connection.cursor() + self._driver_class.delete_all_records(self.table_name, cur) + self.connection.commit() + + @property + def schema_version(self) -> str: + return "0" + + def _rebuild_check_with_first_element(self) -> bool: + cur = self.connection.cursor() + cur.execute(f"SELECT doc_id FROM {self.table_name} ORDER BY item_order LIMIT 1") + res = cur.fetchone() + return res is None + + def _do_rebuild(self) -> None: + cur = self.connection.cursor() + last_order = -1 + while last_order is not None: + cur.execute( + f"SELECT item_order FROM {self.table_name} WHERE item_order > ? ORDER BY item_order LIMIT 1", + (last_order,), + ) + res = cur.fetchone() + if res is None: + break + i = res[0] + cur.execute( + f"SELECT doc_id, serialized_value FROM {self.table_name} WHERE item_order=?", + (i,), + ) + doc_id, serialized_value = cur.fetchone() + cur.execute( + f"UPDATE {self.table_name} SET doc_id=?, serialized_value=? WHERE item_order=?", + ( + doc_id, + serialized_value, + i, + ), + ) + last_order = i + + def serialize_value(self, value: VT) -> bytes: + return self.value_serializer(value) + + def deserialize_value(self, value: bytes) -> VT: + return self.value_deserializer(value) + + def __delitem__(self, key: KT) -> None: + doc_id = self.serialize_key(key) + cur = self.connection.cursor() + if not self._driver_class.is_doc_id_in(self.table_name, cur, doc_id): + raise KeyError(key) + self._driver_class.delete_single_record_by_doc_id(self.table_name, cur, doc_id) + self.connection.commit() + + def __getitem__(self, key: KT) -> VT: + doc_id = self.serialize_key(key) + cur = self.connection.cursor() + serialized_value = self._driver_class.get_serialized_value_by_doc_id( + self.table_name, cur, doc_id + ) + if serialized_value is None: + raise KeyError(key) + return self.deserialize_value(serialized_value) + + def __iter__(self) -> Iterator[KT]: + cur = self.connection.cursor() + for doc_id in self._driver_class.get_doc_ids(self.table_name, cur): + yield self.deserialize_key(doc_id) + + def __len__(self) -> int: + cur = self.connection.cursor() + return self._driver_class.get_count(self.table_name, cur) + + def __setitem__(self, key: KT, value: VT) -> None: + doc_id = self.serialize_key(key) + cur = self.connection.cursor() + serialized_value = self.serialize_value(value) + self._driver_class.upsert(self.table_name, cur, doc_id, serialized_value) + self.connection.commit() + + def _create_volatile_copy( + self, + data: Optional[Mapping[KT, VT]] = None, + ) -> "Dict[KT, VT]": + + return Dict[KT, VT]( + connection=self.connection, + key_serializer=self.key_serializer, + key_deserializer=self.key_deserializer, + value_serializer=self.value_serializer, + value_deserializer=self.value_deserializer, + rebuild_strategy=RebuildStrategy.SKIP, + persist=False, + data=(self if data is None else data), + ) + + def copy(self) -> "Dict[KT, VT]": + return self._create_volatile_copy() + + @classmethod + def fromkeys(cls, iterable: Iterable[KT], value: Optional[VT]) -> "Dict[KT, VT]": + raise NotImplementedError + + @overload + def pop(self, k: KT) -> VT: + ... + + @overload + def pop(self, k: KT, default: Union[VT, T] = ...) -> Union[VT, T]: + ... + + def pop(self, k: KT, default: Optional[Union[VT, object]] = None) -> Union[VT, object]: + cur = self.connection.cursor() + doc_id = self.serialize_key(k) + serialized_value = self._driver_class.get_serialized_value_by_doc_id( + self.table_name, cur, doc_id + ) + if serialized_value is None: + if default is None: + raise KeyError(k) + return default + self._driver_class.delete_single_record_by_doc_id(self.table_name, cur, doc_id) + self.connection.commit() + return self.deserialize_value(serialized_value) + + def popitem(self) -> Tuple[KT, VT]: + cur = self.connection.cursor() + serialized_item = self._driver_class.get_last_serialized_item(self.table_name, cur) + if serialized_item is None: + raise KeyError("popitem(): dictionary is empty") + self._driver_class.delete_single_record_by_doc_id(self.table_name, cur, serialized_item[0]) + self.connection.commit() + return ( + self.deserialize_key(serialized_item[0]), + self.deserialize_value(serialized_item[1]), + ) + + @overload + def update(self, __other: Mapping[KT, VT], **kwargs: VT) -> None: + ... + + @overload + def update(self, __other: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: + ... + + @overload + def update(self, **kwargs: VT) -> None: + ... + + def update(self, __other: Optional[Union[Iterable[Tuple[KT, VT]], Mapping[KT, VT]]] = None, **kwargs: VT) -> None: + cur = self.connection.cursor() + for k, v in chain( + tuple() if __other is None else __other.items() if isinstance(__other, Mapping) else __other, + cast(Mapping[KT, VT], kwargs).items(), + ): + self._driver_class.upsert(self.table_name, cur, self.serialize_key(k), self.serialize_value(v)) + self.connection.commit() + + def clear(self) -> None: + cur = self.connection.cursor() + self._driver_class.delete_all_records(self.table_name, cur) + self.connection.commit() + + def __contains__(self, o: object) -> bool: + return self._driver_class.is_doc_id_in( + self.table_name, self.connection.cursor(), self.serialize_key(cast(KT, o)) + ) + + @overload + def get(self, key: KT) -> Union[VT, None]: + ... + + @overload + def get(self, key: KT, default_value: Union[VT, T]) -> Union[VT, T]: + ... + + def get(self, key: KT, default_value: Optional[Union[VT, object]] = None) -> Union[VT, None, object]: + doc_id = self.serialize_key(key) + cur = self.connection.cursor() + serialized_value = self._driver_class.get_serialized_value_by_doc_id( + self.table_name, cur, doc_id + ) + if serialized_value is None: + return default_value + return self.deserialize_value(serialized_value) + + def setdefault(self, key: KT, default: VT = None) -> VT: # type: ignore + doc_id = self.serialize_key(key) + cur = self.connection.cursor() + serialized_value = self._driver_class.get_serialized_value_by_doc_id( + self.table_name, cur, doc_id + ) + if serialized_value is None: + self._driver_class.insert_serialized_value_by_doc_id( + self.table_name, cur, doc_id, self.serialize_value(default) + ) + return default + return self.deserialize_value(serialized_value) + + @property + def schema_version(self) -> str: + return '0' \ No newline at end of file From 7369901a9e7be7fca2c0a7ad87b621fc97a6e84a Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 18 Jan 2022 07:54:53 +0100 Subject: [PATCH 03/55] chore: fix typo --- docarray/array/persistence/sqlite/base.py | 68 ++---- docarray/array/persistence/sqlite/dict.py | 42 ++-- docarray/array/persistence/sqlite/mixin.py | 241 ++++----------------- 3 files changed, 87 insertions(+), 264 deletions(-) diff --git a/docarray/array/persistence/sqlite/base.py b/docarray/array/persistence/sqlite/base.py index 9c85b461a32..05c1d52b5ad 100644 --- a/docarray/array/persistence/sqlite/base.py +++ b/docarray/array/persistence/sqlite/base.py @@ -1,11 +1,8 @@ import sqlite3 import warnings from abc import ABCMeta, abstractmethod -from collections.abc import Hashable -from enum import Enum -from pickle import dumps, loads from tempfile import NamedTemporaryFile -from typing import Callable, Generic, Optional, TypeVar, Union, cast +from typing import Callable, Generic, Optional, TypeVar, Union from uuid import uuid4 T = TypeVar('T') @@ -15,12 +12,6 @@ _S = TypeVar('_S') -class RebuildStrategy(Enum): - CHECK_WITH_FIRST_ELEMENT = 1 - ALWAYS = 2 - SKIP = 3 - - def sanitize_table_name(table_name: str) -> str: ret = ''.join(c for c in table_name if c.isalnum() or c == '_') if ret != table_name: @@ -32,10 +23,6 @@ def create_random_name(suffix: str) -> str: return f"{suffix}_{str(uuid4()).replace('-', '')}" -def is_hashable(x: object) -> bool: - return isinstance(x, Hashable) - - class _SqliteCollectionBaseDatabaseDriver(metaclass=ABCMeta): @classmethod def initialize_metadata_table(cls, cur: sqlite3.Cursor) -> None: @@ -162,15 +149,12 @@ def __init__( serializer: Optional[Callable[[T], bytes]] = None, deserializer: Optional[Callable[[bytes], T]] = None, persist: bool = True, - rebuild_strategy: RebuildStrategy = RebuildStrategy.CHECK_WITH_FIRST_ELEMENT, ): super(SqliteCollectionBase, self).__init__() - self._serializer = ( - cast(Callable[[T], bytes], dumps) if serializer is None else serializer - ) - self._deserializer = ( - cast(Callable[[bytes], T], loads) if deserializer is None else deserializer - ) + from ....document import Document + + self._serializer = serializer or (lambda d: d.to_bytes()) + self._deserializer = deserializer or (lambda x: Document.from_bytes(x)) self._persist = persist if connection is None: self._connection = sqlite3.connect(NamedTemporaryFile().name) @@ -187,7 +171,12 @@ def __init__( if table_name is None else sanitize_table_name(table_name) ) - self._initialize(rebuild_strategy=rebuild_strategy) + cur = self.connection.cursor() + self._driver_class.initialize_metadata_table(cur) + self._driver_class.initialize_table( + self.table_name, self.container_type_name, self.schema_version, cur + ) + self.connection.commit() def __del__(self) -> None: if not self.persist: @@ -197,31 +186,6 @@ def __del__(self) -> None: ) self.connection.commit() - def _initialize(self, rebuild_strategy: RebuildStrategy) -> None: - cur = self.connection.cursor() - self._driver_class.initialize_metadata_table(cur) - self._driver_class.initialize_table( - self.table_name, self.container_type_name, self.schema_version, cur - ) - if self._should_rebuild(rebuild_strategy): - self._do_rebuild() - self.connection.commit() - - def _should_rebuild(self, rebuild_strategy: RebuildStrategy) -> bool: - if rebuild_strategy == RebuildStrategy.ALWAYS: - return True - if rebuild_strategy == RebuildStrategy.SKIP: - return False - return self._rebuild_check_with_first_element() - - @abstractmethod - def _rebuild_check_with_first_element(self) -> bool: - ... - - @abstractmethod - def _do_rebuild(self) -> None: - ... - @property def persist(self) -> bool: return self._persist @@ -253,14 +217,18 @@ def table_name(self, table_name: str) -> None: new_table_name = sanitize_table_name(table_name) try: self._driver_class.alter_table_name(self.table_name, new_table_name, cur) - except sqlite3.IntegrityError as e: - raise ValueError(table_name) - self._table_name = new_table_name + self._table_name = new_table_name + except sqlite3.IntegrityError as ex: + raise ValueError(table_name) from ex @property def connection(self) -> sqlite3.Connection: return self._connection + @property + def cursor(self) -> sqlite3.Cursor: + return self.connection.cursor() + @property def container_type_name(self) -> str: return self.__class__.__name__ diff --git a/docarray/array/persistence/sqlite/dict.py b/docarray/array/persistence/sqlite/dict.py index a0ea8737708..365df0d00d0 100644 --- a/docarray/array/persistence/sqlite/dict.py +++ b/docarray/array/persistence/sqlite/dict.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union, cast, Iterable, TYPE_CHECKING +from typing import Tuple, Union, cast, Iterable, TYPE_CHECKING, Optional from .base import _SqliteCollectionBaseDatabaseDriver @@ -22,25 +22,39 @@ def do_create_table( 'item_order INTEGER PRIMARY KEY)' ) + @classmethod + def get_max_index_plus_one(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: + cur.execute(f'SELECT MAX(item_order) FROM {table_name}') + res = cur.fetchone() + if res[0] is None: + return 0 + return cast(int, res[0]) + 1 + + @classmethod + def increment_indices( + cls, table_name: str, cur: 'sqlite3.Cursor', start: int + ) -> None: + idx = cls.get_max_index_plus_one(table_name, cur) - 1 + while idx >= start: + cur.execute( + f"UPDATE {table_name} SET item_index = ? WHERE item_index = ?", + (idx + 1, idx), + ) + idx -= 1 + @classmethod def delete_single_record_by_doc_id( cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str ) -> None: - cur.execute( - f'DELETE FROM {table_name} WHERE doc_id=?', (doc_id,) - ) + cur.execute(f'DELETE FROM {table_name} WHERE doc_id=?', (doc_id,)) @classmethod def delete_all_records(cls, table_name: str, cur: 'sqlite3.Cursor') -> None: cur.execute(f'DELETE FROM {table_name}') @classmethod - def is_doc_id_in( - cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str - ) -> bool: - cur.execute( - f'SELECT 1 FROM {table_name} WHERE doc_id=?', (doc_id,) - ) + def is_doc_id_in(cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str) -> bool: + cur.execute(f'SELECT 1 FROM {table_name} WHERE doc_id=?', (doc_id,)) return len(list(cur)) > 0 @classmethod @@ -71,9 +85,7 @@ def get_count(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: return cast(int, res[0]) @classmethod - def get_doc_ids( - cls, table_name: str, cur: 'sqlite3.Cursor' - ) -> Iterable[str]: + def get_doc_ids(cls, table_name: str, cur: 'sqlite3.Cursor') -> Iterable[str]: cur.execute(f'SELECT doc_id FROM {table_name} ORDER BY item_order') for res in cur: yield cast(str, res[0]) @@ -85,8 +97,10 @@ def insert_serialized_value_by_doc_id( cur: 'sqlite3.Cursor', doc_id: str, serialized_value: bytes, + item_order: Optional[int], ) -> None: - item_order = cls.get_next_order(table_name, cur) + if item_order is None: + item_order = cls.get_next_order(table_name, cur) cur.execute( f'INSERT INTO {table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', (doc_id, serialized_value, item_order), diff --git a/docarray/array/persistence/sqlite/mixin.py b/docarray/array/persistence/sqlite/mixin.py index b58d7889d8e..e0f98ab0b5e 100644 --- a/docarray/array/persistence/sqlite/mixin.py +++ b/docarray/array/persistence/sqlite/mixin.py @@ -1,16 +1,16 @@ from dataclasses import dataclass import dataclasses -from typing import Optional, TYPE_CHECKING, Callable, Union, cast +from typing import Optional, TYPE_CHECKING, Callable, Union, cast, Iterable -from .base import SqliteCollectionBase, RebuildStrategy +from .base import SqliteCollectionBase from .dict import _DictDatabaseDriver if TYPE_CHECKING: import sqlite3 from ....types import ( - T, - Document, + T, + Document, DocumentArraySourceType, DocumentArrayIndexType, DocumentArraySingletonIndexType, @@ -19,6 +19,7 @@ DocumentArraySingleAttributeType, ) + @dataclass class SqliteConfig: connection: Optional[Union[str, 'sqlite3.Connection']] = None @@ -26,7 +27,6 @@ class SqliteConfig: serializer: Optional[Callable[['Document'], bytes]] = None deserializer: Optional[Callable[[bytes], 'Document']] = None persist: bool = True - rebuild_strategy: RebuildStrategy = RebuildStrategy.CHECK_WITH_FIRST_ELEMENT class SqliteMixin(SqliteCollectionBase): @@ -37,215 +37,56 @@ class SqliteMixin(SqliteCollectionBase): i.e. `class SqliteDA(SqliteMixin, DA)` not the other way around. """ + _driver_class = _DictDatabaseDriver - def __init__(self, docs: Optional['DocumentArraySourceType'] = None, config: Optional[SqliteConfig] = None): + def __init__( + self, + docs: Optional['DocumentArraySourceType'] = None, + config: Optional[SqliteConfig] = None, + ): super().__init__(**(dataclasses.asdict(config) if config else {})) if docs is not None: self.clear() - self.update(docs) - - def clear(self) -> None: - cur = self.connection.cursor() - self._driver_class.delete_all_records(self.table_name, cur) - self.connection.commit() - - @property - def schema_version(self) -> str: - return "0" - - def _rebuild_check_with_first_element(self) -> bool: - cur = self.connection.cursor() - cur.execute(f"SELECT doc_id FROM {self.table_name} ORDER BY item_order LIMIT 1") - res = cur.fetchone() - return res is None - - def _do_rebuild(self) -> None: - cur = self.connection.cursor() - last_order = -1 - while last_order is not None: - cur.execute( - f"SELECT item_order FROM {self.table_name} WHERE item_order > ? ORDER BY item_order LIMIT 1", - (last_order,), - ) - res = cur.fetchone() - if res is None: - break - i = res[0] - cur.execute( - f"SELECT doc_id, serialized_value FROM {self.table_name} WHERE item_order=?", - (i,), - ) - doc_id, serialized_value = cur.fetchone() - cur.execute( - f"UPDATE {self.table_name} SET doc_id=?, serialized_value=? WHERE item_order=?", - ( - doc_id, - serialized_value, - i, - ), - ) - last_order = i - - def serialize_value(self, value: VT) -> bytes: - return self.value_serializer(value) - - def deserialize_value(self, value: bytes) -> VT: - return self.value_deserializer(value) - - def __delitem__(self, key: KT) -> None: - doc_id = self.serialize_key(key) - cur = self.connection.cursor() - if not self._driver_class.is_doc_id_in(self.table_name, cur, doc_id): - raise KeyError(key) - self._driver_class.delete_single_record_by_doc_id(self.table_name, cur, doc_id) - self.connection.commit() - - def __getitem__(self, key: KT) -> VT: - doc_id = self.serialize_key(key) - cur = self.connection.cursor() - serialized_value = self._driver_class.get_serialized_value_by_doc_id( - self.table_name, cur, doc_id + self.extend(docs) + + def insert(self, index: int, value: 'Document'): + """Insert `doc` at `index`. + + :param index: Position of the insertion. + :param value: The doc needs to be inserted. + """ + length = self._driver_class.get_max_index_plus_one(self.table_name, self.cursor) + if index < 0: + index = length + index + index = max(0, min(length, index)) + self._driver_class.increment_indices(self.table_name, self.cursor, index) + self._driver_class.insert_serialized_value_by_doc_id( + self.table_name, self.cursor, value.id, self.serialize(value), index ) - if serialized_value is None: - raise KeyError(key) - return self.deserialize_value(serialized_value) - - def __iter__(self) -> Iterator[KT]: - cur = self.connection.cursor() - for doc_id in self._driver_class.get_doc_ids(self.table_name, cur): - yield self.deserialize_key(doc_id) - - def __len__(self) -> int: - cur = self.connection.cursor() - return self._driver_class.get_count(self.table_name, cur) - - def __setitem__(self, key: KT, value: VT) -> None: - doc_id = self.serialize_key(key) - cur = self.connection.cursor() - serialized_value = self.serialize_value(value) - self._driver_class.upsert(self.table_name, cur, doc_id, serialized_value) self.connection.commit() - def _create_volatile_copy( - self, - data: Optional[Mapping[KT, VT]] = None, - ) -> "Dict[KT, VT]": - - return Dict[KT, VT]( - connection=self.connection, - key_serializer=self.key_serializer, - key_deserializer=self.key_deserializer, - value_serializer=self.value_serializer, - value_deserializer=self.value_deserializer, - rebuild_strategy=RebuildStrategy.SKIP, - persist=False, - data=(self if data is None else data), - ) - - def copy(self) -> "Dict[KT, VT]": - return self._create_volatile_copy() - - @classmethod - def fromkeys(cls, iterable: Iterable[KT], value: Optional[VT]) -> "Dict[KT, VT]": - raise NotImplementedError - - @overload - def pop(self, k: KT) -> VT: - ... - - @overload - def pop(self, k: KT, default: Union[VT, T] = ...) -> Union[VT, T]: - ... - - def pop(self, k: KT, default: Optional[Union[VT, object]] = None) -> Union[VT, object]: - cur = self.connection.cursor() - doc_id = self.serialize_key(k) - serialized_value = self._driver_class.get_serialized_value_by_doc_id( - self.table_name, cur, doc_id - ) - if serialized_value is None: - if default is None: - raise KeyError(k) - return default - self._driver_class.delete_single_record_by_doc_id(self.table_name, cur, doc_id) - self.connection.commit() - return self.deserialize_value(serialized_value) - - def popitem(self) -> Tuple[KT, VT]: - cur = self.connection.cursor() - serialized_item = self._driver_class.get_last_serialized_item(self.table_name, cur) - if serialized_item is None: - raise KeyError("popitem(): dictionary is empty") - self._driver_class.delete_single_record_by_doc_id(self.table_name, cur, serialized_item[0]) - self.connection.commit() - return ( - self.deserialize_key(serialized_item[0]), - self.deserialize_value(serialized_item[1]), - ) + def extend(self, values: Iterable['Document']) -> None: - @overload - def update(self, __other: Mapping[KT, VT], **kwargs: VT) -> None: - ... - - @overload - def update(self, __other: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: - ... - - @overload - def update(self, **kwargs: VT) -> None: - ... - - def update(self, __other: Optional[Union[Iterable[Tuple[KT, VT]], Mapping[KT, VT]]] = None, **kwargs: VT) -> None: - cur = self.connection.cursor() - for k, v in chain( - tuple() if __other is None else __other.items() if isinstance(__other, Mapping) else __other, - cast(Mapping[KT, VT], kwargs).items(), - ): - self._driver_class.upsert(self.table_name, cur, self.serialize_key(k), self.serialize_value(v)) + idx = self._driver_class.get_max_index_plus_one(self.table_name, self.cursor) + for v in values: + self._driver_class.insert_serialized_value_by_doc_id( + self.table_name, + cur=self.cursor, + doc_id=v.id, + serialized_value=self.serialize(v), + item_order=idx, + ) + idx += 1 self.connection.commit() def clear(self) -> None: - cur = self.connection.cursor() - self._driver_class.delete_all_records(self.table_name, cur) + self._driver_class.delete_all_records(self.table_name, self.cursor) self.connection.commit() - def __contains__(self, o: object) -> bool: - return self._driver_class.is_doc_id_in( - self.table_name, self.connection.cursor(), self.serialize_key(cast(KT, o)) - ) - - @overload - def get(self, key: KT) -> Union[VT, None]: - ... - - @overload - def get(self, key: KT, default_value: Union[VT, T]) -> Union[VT, T]: - ... - - def get(self, key: KT, default_value: Optional[Union[VT, object]] = None) -> Union[VT, None, object]: - doc_id = self.serialize_key(key) - cur = self.connection.cursor() - serialized_value = self._driver_class.get_serialized_value_by_doc_id( - self.table_name, cur, doc_id - ) - if serialized_value is None: - return default_value - return self.deserialize_value(serialized_value) - - def setdefault(self, key: KT, default: VT = None) -> VT: # type: ignore - doc_id = self.serialize_key(key) - cur = self.connection.cursor() - serialized_value = self._driver_class.get_serialized_value_by_doc_id( - self.table_name, cur, doc_id - ) - if serialized_value is None: - self._driver_class.insert_serialized_value_by_doc_id( - self.table_name, cur, doc_id, self.serialize_value(default) - ) - return default - return self.deserialize_value(serialized_value) + def __len__(self) -> int: + return self._driver_class.get_count(self.table_name, self.cursor) @property def schema_version(self) -> str: - return '0' \ No newline at end of file + return '0' From 8cb7baa4870e8d182cd1a5d616b1b7cd8c424b8a Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 18 Jan 2022 14:29:00 +0100 Subject: [PATCH 04/55] chore: fix typo --- docarray/array/document.py | 6 +- docarray/array/persistence/sqlite/base.py | 52 +++---- docarray/array/persistence/sqlite/dict.py | 70 +--------- docarray/array/persistence/sqlite/mixin.py | 154 ++++++++++++++++++--- 4 files changed, 160 insertions(+), 122 deletions(-) diff --git a/docarray/array/document.py b/docarray/array/document.py index 951a86c9c37..1caf5b76c9c 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -148,7 +148,7 @@ def __getitem__( ): if isinstance(index[0], str) and isinstance(index[1], str): # ambiguity only comes from the second string - if index[1] in self._id2offset: + if index[1] in self: return DocumentArray([self[index[0]], self[index[1]]]) else: return getattr(self[index[0]], index[1]) @@ -161,9 +161,9 @@ def __getitem__( elif isinstance(index[0], bool): return DocumentArray(itertools.compress(self._data, index)) elif isinstance(index[0], int): - return DocumentArray(self._data[t] for t in index) + return DocumentArray(self[t] for t in index) elif isinstance(index[0], str): - return DocumentArray(self._data[self._id2offset[t]] for t in index) + return DocumentArray(self[t] for t in index) elif isinstance(index, np.ndarray): index = index.squeeze() if index.ndim == 1: diff --git a/docarray/array/persistence/sqlite/base.py b/docarray/array/persistence/sqlite/base.py index 05c1d52b5ad..80471318de0 100644 --- a/docarray/array/persistence/sqlite/base.py +++ b/docarray/array/persistence/sqlite/base.py @@ -2,7 +2,7 @@ import warnings from abc import ABCMeta, abstractmethod from tempfile import NamedTemporaryFile -from typing import Callable, Generic, Optional, TypeVar, Union +from typing import Callable, Generic, Optional, TypeVar, Union, Dict from uuid import uuid4 T = TypeVar('T') @@ -146,26 +146,35 @@ def __init__( self, connection: Optional[Union[str, sqlite3.Connection]] = None, table_name: Optional[str] = None, - serializer: Optional[Callable[[T], bytes]] = None, - deserializer: Optional[Callable[[bytes], T]] = None, - persist: bool = True, + serialize_config: Optional[Dict] = None, ): super(SqliteCollectionBase, self).__init__() - from ....document import Document + self._serialize_config = serialize_config or {} + self._persist = not table_name - self._serializer = serializer or (lambda d: d.to_bytes()) - self._deserializer = deserializer or (lambda x: Document.from_bytes(x)) - self._persist = persist + from docarray import Document + + sqlite3.register_adapter( + Document, lambda d: d.to_bytes(**self._serialize_config) + ) + sqlite3.register_converter( + 'Document', lambda x: Document.from_bytes(x, **self._serialize_config) + ) + + _conn_kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES) if connection is None: - self._connection = sqlite3.connect(NamedTemporaryFile().name) + self._connection = sqlite3.connect( + NamedTemporaryFile().name, **_conn_kwargs + ) elif isinstance(connection, str): - self._connection = sqlite3.connect(connection) + self._connection = sqlite3.connect(connection, **_conn_kwargs) elif isinstance(connection, sqlite3.Connection): self._connection = connection else: raise TypeError( f'connection argument must be None or a string or a sqlite3.Connection, not `{type(connection)}`' ) + self._table_name = ( sanitize_table_name(create_random_name(self.container_type_name)) if table_name is None @@ -179,34 +188,13 @@ def __init__( self.connection.commit() def __del__(self) -> None: - if not self.persist: + if not self._persist: cur = self.connection.cursor() self._driver_class.drop_table( self.table_name, self.container_type_name, cur ) self.connection.commit() - @property - def persist(self) -> bool: - return self._persist - - def set_persist(self, persist: bool) -> None: - self._persist = persist - - @property - def serializer(self) -> Callable[[T], bytes]: - return self._serializer - - def serialize(self, x: T) -> bytes: - return self.serializer(x) - - @property - def deserializer(self) -> Callable[[bytes], T]: - return self._deserializer - - def deserialize(self, blob: bytes) -> T: - return self.deserializer(blob) - @property def table_name(self) -> str: return self._table_name diff --git a/docarray/array/persistence/sqlite/dict.py b/docarray/array/persistence/sqlite/dict.py index 365df0d00d0..b63fd778b60 100644 --- a/docarray/array/persistence/sqlite/dict.py +++ b/docarray/array/persistence/sqlite/dict.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union, cast, Iterable, TYPE_CHECKING, Optional +from typing import Tuple, Union, cast, Iterable, TYPE_CHECKING from .base import _SqliteCollectionBaseDatabaseDriver @@ -18,58 +18,16 @@ def do_create_table( cur.execute( f'CREATE TABLE {table_name} (' 'doc_id TEXT NOT NULL UNIQUE, ' - 'serialized_value BLOB NOT NULL, ' + 'serialized_value Document NOT NULL, ' 'item_order INTEGER PRIMARY KEY)' ) - @classmethod - def get_max_index_plus_one(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: - cur.execute(f'SELECT MAX(item_order) FROM {table_name}') - res = cur.fetchone() - if res[0] is None: - return 0 - return cast(int, res[0]) + 1 - - @classmethod - def increment_indices( - cls, table_name: str, cur: 'sqlite3.Cursor', start: int - ) -> None: - idx = cls.get_max_index_plus_one(table_name, cur) - 1 - while idx >= start: - cur.execute( - f"UPDATE {table_name} SET item_index = ? WHERE item_index = ?", - (idx + 1, idx), - ) - idx -= 1 - @classmethod def delete_single_record_by_doc_id( cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str ) -> None: cur.execute(f'DELETE FROM {table_name} WHERE doc_id=?', (doc_id,)) - @classmethod - def delete_all_records(cls, table_name: str, cur: 'sqlite3.Cursor') -> None: - cur.execute(f'DELETE FROM {table_name}') - - @classmethod - def is_doc_id_in(cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str) -> bool: - cur.execute(f'SELECT 1 FROM {table_name} WHERE doc_id=?', (doc_id,)) - return len(list(cur)) > 0 - - @classmethod - def get_serialized_value_by_doc_id( - cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str - ) -> Union[None, bytes]: - cur.execute( - f'SELECT serialized_value FROM {table_name} WHERE doc_id=?', - (doc_id,), - ) - res = cur.fetchone() - if res is None: - return None - return cast(bytes, res[0]) - @classmethod def get_next_order(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: cur.execute(f'SELECT MAX(item_order) FROM {table_name}') @@ -78,33 +36,11 @@ def get_next_order(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: return 0 return cast(int, res) + 1 - @classmethod - def get_count(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: - cur.execute(f'SELECT COUNT(*) FROM {table_name}') - res = cur.fetchone() - return cast(int, res[0]) - @classmethod def get_doc_ids(cls, table_name: str, cur: 'sqlite3.Cursor') -> Iterable[str]: cur.execute(f'SELECT doc_id FROM {table_name} ORDER BY item_order') for res in cur: - yield cast(str, res[0]) - - @classmethod - def insert_serialized_value_by_doc_id( - cls, - table_name: str, - cur: 'sqlite3.Cursor', - doc_id: str, - serialized_value: bytes, - item_order: Optional[int], - ) -> None: - if item_order is None: - item_order = cls.get_next_order(table_name, cur) - cur.execute( - f'INSERT INTO {table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', - (doc_id, serialized_value, item_order), - ) + yield res[0] @classmethod def update_serialized_value_by_doc_id( diff --git a/docarray/array/persistence/sqlite/mixin.py b/docarray/array/persistence/sqlite/mixin.py index e0f98ab0b5e..43156d184a7 100644 --- a/docarray/array/persistence/sqlite/mixin.py +++ b/docarray/array/persistence/sqlite/mixin.py @@ -1,9 +1,22 @@ +import itertools from dataclasses import dataclass import dataclasses -from typing import Optional, TYPE_CHECKING, Callable, Union, cast, Iterable +from typing import ( + Optional, + TYPE_CHECKING, + Callable, + Union, + cast, + Iterable, + Dict, + Iterator, + Sequence, +) from .base import SqliteCollectionBase from .dict import _DictDatabaseDriver +from ....helper import typename +import numpy as np if TYPE_CHECKING: import sqlite3 @@ -19,14 +32,14 @@ DocumentArraySingleAttributeType, ) + from docarray import DocumentArray + @dataclass class SqliteConfig: connection: Optional[Union[str, 'sqlite3.Connection']] = None table_name: Optional[str] = None - serializer: Optional[Callable[['Document'], bytes]] = None - deserializer: Optional[Callable[[bytes], 'Document']] = None - persist: bool = True + serialize_config: Optional[Dict] = None class SqliteMixin(SqliteCollectionBase): @@ -56,37 +69,138 @@ def insert(self, index: int, value: 'Document'): :param index: Position of the insertion. :param value: The doc needs to be inserted. """ - length = self._driver_class.get_max_index_plus_one(self.table_name, self.cursor) + length = len(self) if index < 0: index = length + index index = max(0, min(length, index)) - self._driver_class.increment_indices(self.table_name, self.cursor, index) - self._driver_class.insert_serialized_value_by_doc_id( - self.table_name, self.cursor, value.id, self.serialize(value), index - ) + self._shift_index_right_backward(index) + self._insert_doc_at_idx(doc=value, idx=index) self.connection.commit() - def extend(self, values: Iterable['Document']) -> None: + def append(self, value: 'Document') -> None: + self._insert_doc_at_idx(value) + self.connection.commit() - idx = self._driver_class.get_max_index_plus_one(self.table_name, self.cursor) + def extend(self, values: Iterable['Document']) -> None: + idx = len(self) for v in values: - self._driver_class.insert_serialized_value_by_doc_id( - self.table_name, - cur=self.cursor, - doc_id=v.id, - serialized_value=self.serialize(v), - item_order=idx, - ) + self._insert_doc_at_idx(v, idx) idx += 1 self.connection.commit() def clear(self) -> None: - self._driver_class.delete_all_records(self.table_name, self.cursor) + self._sql(f'DELETE FROM {self.table_name}') self.connection.commit() + def __contains__(self, item: Union[str, 'Document']): + if isinstance(item, str): + r = self._sql(f"SELECT 1 FROM {self.table_name} WHERE doc_id=?", (item,)) + return len(list(r)) > 0 + elif isinstance(item, Document): + return item.id in self # fall back to str check + else: + return False + def __len__(self) -> int: - return self._driver_class.get_count(self.table_name, self.cursor) + r = self._sql(f'SELECT COUNT(*) FROM {self.table_name}') + return r.fetchone()[0] + + def __iter__(self) -> Iterator['Document']: + r = self._sql( + f'SELECT serialized_value FROM {self.table_name} ORDER BY item_order' + ) + for res in r: + yield res[0] + + def __getitem__( + self, index: 'DocumentArrayIndexType' + ) -> Union['Document', 'DocumentArray']: + if isinstance(index, (int, np.generic)) and not isinstance(index, bool): + return self._get_doc_by_offset(int(index)) + elif isinstance(index, str): + if index.startswith('@'): + return self.traverse_flat(index[1:]) + else: + return self._get_doc_by_id(index) + elif isinstance(index, slice): + from docarray import DocumentArray + + return DocumentArray(self[j] for j in range(len(self))[index]) + elif index is Ellipsis: + return self.flatten() + elif isinstance(index, Sequence): + from docarray import DocumentArray + + if ( + isinstance(index, tuple) + and len(index) == 2 + and isinstance(index[0], (slice, Sequence)) + ): + if isinstance(index[0], str) and isinstance(index[1], str): + # ambiguity only comes from the second string + if index[1] in self: + return DocumentArray([self[index[0]], self[index[1]]]) + else: + return getattr(self[index[0]], index[1]) + elif isinstance(index[0], (slice, Sequence)): + _docs = self[index[0]] + _attrs = index[1] + if isinstance(_attrs, str): + _attrs = (index[1],) + return _docs._get_attributes(*_attrs) + elif isinstance(index[0], bool): + return DocumentArray(itertools.compress(self, index)) + elif isinstance(index[0], (int, str)): + return DocumentArray(self[t] for t in index) + elif isinstance(index, np.ndarray): + index = index.squeeze() + if index.ndim == 1: + return self[index.tolist()] + else: + raise IndexError( + f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}' + ) + raise IndexError(f'Unsupported index type {typename(index)}: {index}') + + def _get_doc_by_offset(self, index: int) -> 'Document': + r = self._sql( + f"SELECT serialized_value FROM {self.table_name} WHERE item_order = ?", + (index + (len(self) if index < 0 else 0),), + ) + res = r.fetchone() + if res is None: + raise IndexError('index out of range') + return res[0] + + def _get_doc_by_id(self, id: str) -> 'Document': + r = self._sql( + f"SELECT serialized_value FROM {self.table_name} WHERE doc_id = ?", (id,) + ) + res = r.fetchone() + if res is None: + raise KeyError(f'Can not find Document with id=`{id}`') + return res[0] @property def schema_version(self) -> str: return '0' + + def _sql(self, *arg, **kwargs) -> 'sqlite3.Cursor': + return self.connection.cursor().execute(*arg, **kwargs) + + def _insert_doc_at_idx(self, doc, idx: Optional[int] = None): + if idx is None: + idx = len(self) + self._sql( + f'INSERT INTO {self.table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', + (doc.id, doc, idx), + ) + + def _shift_index_right_backward(self, start: int): + idx = len(self) - 1 + while idx >= start: + self._sql( + f"UPDATE {self.table_name} SET item_order = ? WHERE item_order = ?", + (idx + 1, idx), + ) + idx -= 1 From 95b4a37fe204101be32de9545e689030202ed3a1 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 18 Jan 2022 16:26:43 +0100 Subject: [PATCH 05/55] fix(document): serialize tag value in the correct priority --- docarray/array/document.py | 81 +++---------- docarray/array/mixins/__init__.py | 4 +- docarray/array/mixins/getitem.py | 127 +++++++++++++++++++++ docarray/array/persistence/sqlite/mixin.py | 61 ++-------- 4 files changed, 157 insertions(+), 116 deletions(-) create mode 100644 docarray/array/mixins/getitem.py diff --git a/docarray/array/document.py b/docarray/array/document.py index 1caf5b76c9c..4dbe4c3714f 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -32,6 +32,21 @@ class DocumentArray(AllMixins, MutableSequence[Document]): + def _get_doc_by_offset(self, offset: int) -> 'Document': + return self._data[offset] + + def _get_doc_by_id(self, _id: str) -> 'Document': + return self._data[self._id2offset[_id]] + + def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: + return self._data[_slice] + + def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: + return (self._data[t] for t in offsets) + + def _get_docs_by_ids(self, ids: Sequence[str]) -> Iterable['Document']: + return (self._data[self._id2offset[t]] for t in ids) + def __init__( self, docs: Optional['DocumentArraySourceType'] = None, copy: bool = False ): @@ -108,72 +123,6 @@ def __contains__(self, x: Union[str, 'Document']): else: return False - @overload - def __getitem__(self, index: 'DocumentArraySingletonIndexType') -> 'Document': - ... - - @overload - def __getitem__(self, index: 'DocumentArrayMultipleIndexType') -> 'DocumentArray': - ... - - @overload - def __getitem__(self, index: 'DocumentArraySingleAttributeType') -> List[Any]: - ... - - @overload - def __getitem__( - self, index: 'DocumentArrayMultipleAttributeType' - ) -> List[List[Any]]: - ... - - def __getitem__( - self, index: 'DocumentArrayIndexType' - ) -> Union['Document', 'DocumentArray']: - if isinstance(index, (int, np.generic)) and not isinstance(index, bool): - return self._data[int(index)] - elif isinstance(index, str): - if index.startswith('@'): - return self.traverse_flat(index[1:]) - else: - return self._data[self._id2offset[index]] - elif isinstance(index, slice): - return DocumentArray(self._data[index]) - elif index is Ellipsis: - return self.flatten() - elif isinstance(index, Sequence): - if ( - isinstance(index, tuple) - and len(index) == 2 - and isinstance(index[0], (slice, Sequence)) - ): - if isinstance(index[0], str) and isinstance(index[1], str): - # ambiguity only comes from the second string - if index[1] in self: - return DocumentArray([self[index[0]], self[index[1]]]) - else: - return getattr(self[index[0]], index[1]) - elif isinstance(index[0], (slice, Sequence)): - _docs = self[index[0]] - _attrs = index[1] - if isinstance(_attrs, str): - _attrs = (index[1],) - return _docs._get_attributes(*_attrs) - elif isinstance(index[0], bool): - return DocumentArray(itertools.compress(self._data, index)) - elif isinstance(index[0], int): - return DocumentArray(self[t] for t in index) - elif isinstance(index[0], str): - return DocumentArray(self[t] for t in index) - elif isinstance(index, np.ndarray): - index = index.squeeze() - if index.ndim == 1: - return self[index.tolist()] - else: - raise IndexError( - f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}' - ) - raise IndexError(f'Unsupported index type {typename(index)}: {index}') - @overload def __setitem__( self, diff --git a/docarray/array/mixins/__init__.py b/docarray/array/mixins/__init__.py index 61b24a546a0..a50ca69a0e3 100644 --- a/docarray/array/mixins/__init__.py +++ b/docarray/array/mixins/__init__.py @@ -5,6 +5,7 @@ from .empty import EmptyMixin from .evaluation import EvaluationMixin from .getattr import GetAttributeMixin +from .getitem import GetItemMixin from .group import GroupMixin from .io.binary import BinaryIOMixin from .io.common import CommonIOMixin @@ -16,15 +17,16 @@ from .match import MatchMixin from .parallel import ParallelMixin from .plot import PlotMixin +from .pydantic import PydanticMixin from .reduce import ReduceMixin from .sample import SampleMixin from .text import TextToolsMixin from .traverse import TraverseMixin -from .pydantic import PydanticMixin class AllMixins( GetAttributeMixin, + GetItemMixin, ContentPropertyMixin, PydanticMixin, GroupMixin, diff --git a/docarray/array/mixins/getitem.py b/docarray/array/mixins/getitem.py new file mode 100644 index 00000000000..f893d0b996b --- /dev/null +++ b/docarray/array/mixins/getitem.py @@ -0,0 +1,127 @@ +import itertools +from abc import abstractmethod +from typing import ( + TYPE_CHECKING, + Union, + Sequence, + overload, + Any, + List, + Iterable, +) + +import numpy as np + +from ... import Document +from ...helper import typename + +if TYPE_CHECKING: + from ...types import ( + DocumentArrayIndexType, + DocumentArraySingletonIndexType, + DocumentArrayMultipleIndexType, + DocumentArrayMultipleAttributeType, + DocumentArraySingleAttributeType, + ) + from ... import DocumentArray + + +class GetItemMixin: + """Provide helper functions to enable advance indexing in `__getitem__`""" + + @abstractmethod + def _get_doc_by_offset(self, offset: int) -> 'Document': + ... + + @abstractmethod + def _get_doc_by_id(self, _id: str) -> 'Document': + ... + + @abstractmethod + def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: + """This function is derived from :meth:`_get_doc_by_offset` + + Override this function if there is a more efficient logic""" + return (self._get_doc_by_offset(j) for j in range(len(self))[_slice]) + + def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: + """This function is derived from :meth:`_get_doc_by_offset` + + Override this function if there is a more efficient logic""" + return (self._get_doc_by_offset(d) for d in offsets) + + def _get_docs_by_ids(self, ids: Sequence[str]) -> Iterable['Document']: + """This function is derived from :meth:`_get_doc_by_id` + + Override this function if there is a more efficient logic""" + return (self._get_doc_by_id(d) for d in ids) + + @overload + def __getitem__(self, index: 'DocumentArraySingletonIndexType') -> 'Document': + ... + + @overload + def __getitem__(self, index: 'DocumentArrayMultipleIndexType') -> 'DocumentArray': + ... + + @overload + def __getitem__(self, index: 'DocumentArraySingleAttributeType') -> List[Any]: + ... + + @overload + def __getitem__( + self, index: 'DocumentArrayMultipleAttributeType' + ) -> List[List[Any]]: + ... + + def __getitem__( + self, index: 'DocumentArrayIndexType' + ) -> Union['Document', 'DocumentArray']: + if isinstance(index, (int, np.generic)) and not isinstance(index, bool): + return self._get_doc_by_offset(int(index)) + elif isinstance(index, str): + if index.startswith('@'): + return self.traverse_flat(index[1:]) + else: + return self._get_doc_by_id(index) + elif isinstance(index, slice): + from ... import DocumentArray + + return DocumentArray(self._get_docs_by_slice(index)) + elif index is Ellipsis: + return self.flatten() + elif isinstance(index, Sequence): + from ... import DocumentArray + + if ( + isinstance(index, tuple) + and len(index) == 2 + and isinstance(index[0], (slice, Sequence)) + ): + if isinstance(index[0], str) and isinstance(index[1], str): + # ambiguity only comes from the second string + if index[1] in self: + return DocumentArray([self[index[0]], self[index[1]]]) + else: + return getattr(self[index[0]], index[1]) + elif isinstance(index[0], (slice, Sequence)): + _docs = self[index[0]] + _attrs = index[1] + if isinstance(_attrs, str): + _attrs = (index[1],) + return _docs._get_attributes(*_attrs) + elif isinstance(index[0], bool): + return DocumentArray(itertools.compress(self, index)) + elif isinstance(index[0], int): + return DocumentArray(self._get_docs_by_offsets(index)) + elif isinstance(index[0], str): + return DocumentArray(self._get_docs_by_ids(index)) + elif isinstance(index, np.ndarray): + index = index.squeeze() + if index.ndim == 1: + return self[index.tolist()] + else: + raise IndexError( + f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}' + ) + raise IndexError(f'Unsupported index type {typename(index)}: {index}') diff --git a/docarray/array/persistence/sqlite/mixin.py b/docarray/array/persistence/sqlite/mixin.py index 43156d184a7..53738d96f80 100644 --- a/docarray/array/persistence/sqlite/mixin.py +++ b/docarray/array/persistence/sqlite/mixin.py @@ -112,55 +112,18 @@ def __iter__(self) -> Iterator['Document']: for res in r: yield res[0] - def __getitem__( - self, index: 'DocumentArrayIndexType' - ) -> Union['Document', 'DocumentArray']: - if isinstance(index, (int, np.generic)) and not isinstance(index, bool): - return self._get_doc_by_offset(int(index)) - elif isinstance(index, str): - if index.startswith('@'): - return self.traverse_flat(index[1:]) - else: - return self._get_doc_by_id(index) - elif isinstance(index, slice): - from docarray import DocumentArray - - return DocumentArray(self[j] for j in range(len(self))[index]) - elif index is Ellipsis: - return self.flatten() - elif isinstance(index, Sequence): - from docarray import DocumentArray - - if ( - isinstance(index, tuple) - and len(index) == 2 - and isinstance(index[0], (slice, Sequence)) - ): - if isinstance(index[0], str) and isinstance(index[1], str): - # ambiguity only comes from the second string - if index[1] in self: - return DocumentArray([self[index[0]], self[index[1]]]) - else: - return getattr(self[index[0]], index[1]) - elif isinstance(index[0], (slice, Sequence)): - _docs = self[index[0]] - _attrs = index[1] - if isinstance(_attrs, str): - _attrs = (index[1],) - return _docs._get_attributes(*_attrs) - elif isinstance(index[0], bool): - return DocumentArray(itertools.compress(self, index)) - elif isinstance(index[0], (int, str)): - return DocumentArray(self[t] for t in index) - elif isinstance(index, np.ndarray): - index = index.squeeze() - if index.ndim == 1: - return self[index.tolist()] - else: - raise IndexError( - f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}' - ) - raise IndexError(f'Unsupported index type {typename(index)}: {index}') + def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: + return self._get_docs_by_offsets(range(len(self))[_slice]) + + def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: + l = len(self) + offsets = [o + (l if o < 0 else 0) for o in offsets] + r = self._sql( + f"SELECT serialized_value FROM {self.table_name} WHERE item_order in ({','.join(['?']*len(offsets))})", + offsets, + ) + for rr in r: + yield rr[0] def _get_doc_by_offset(self, index: int) -> 'Document': r = self._sql( From b3b378fb55b267385bb848bba884d549117d7470 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 18 Jan 2022 22:43:31 +0100 Subject: [PATCH 06/55] fix(document): complete the schema for namedscore --- docarray/array/document.py | 257 ++++++------------------------ docarray/array/mixins/__init__.py | 4 + docarray/array/mixins/delitem.py | 92 +++++++++++ docarray/array/mixins/setitem.py | 186 +++++++++++++++++++++ 4 files changed, 332 insertions(+), 207 deletions(-) create mode 100644 docarray/array/mixins/delitem.py create mode 100644 docarray/array/mixins/setitem.py diff --git a/docarray/array/document.py b/docarray/array/document.py index 4dbe4c3714f..3efcfbe9c37 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -9,29 +9,71 @@ MutableSequence, Sequence, Iterable, - overload, Any, - List, ) -import numpy as np - from .mixins import AllMixins from .. import Document -from ..helper import typename if TYPE_CHECKING: from ..types import ( DocumentArraySourceType, - DocumentArrayIndexType, DocumentArraySingletonIndexType, DocumentArrayMultipleIndexType, - DocumentArrayMultipleAttributeType, - DocumentArraySingleAttributeType, ) class DocumentArray(AllMixins, MutableSequence[Document]): + def _del_docs_by_mask(self, mask: Sequence[bool]): + self._data = list(itertools.compress(self._data, (not _i for _i in mask))) + self._rebuild_id2offset() + + def _del_all_docs(self): + self._data.clear() + self._id2offset.clear() + + def _del_docs_by_slice(self, _slice: slice): + del self._data[_slice] + self._rebuild_id2offset() + + def _del_doc_by_id(self, _id: str): + del self._data[self._id2offset[_id]] + self._id2offset.pop(_id) + + def _del_doc_by_offset(self, offset: int): + self._id2offset.pop(self._data[offset].id) + del self._data[offset] + + def _set_doc_by_offset(self, offset: int, value: 'Document'): + self._data[offset] = value + self._id2offset[value.id] = offset + + def _set_doc_by_id(self, _id: str, value: 'Document'): + old_idx = self._id2offset.pop(_id) + self._data[old_idx] = value + self._id2offset[value.id] = old_idx + + def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): + self._data[_slice] = value + self._rebuild_id2offset() + + def _set_doc_value_pairs( + self, docs: Iterable['Document'], values: Iterable['Document'] + ): + for _d, _v in zip(docs, values): + _d._data = _v._data + self._rebuild_id2offset() + + def _set_doc_attr_by_index( + self, + _index: Union[ + 'DocumentArraySingletonIndexType', 'DocumentArrayMultipleIndexType' + ], + attr: str, + value: Any, + ): + setattr(self[_index], attr, value) + def _get_doc_by_offset(self, offset: int) -> 'Document': return self._data[offset] @@ -123,205 +165,6 @@ def __contains__(self, x: Union[str, 'Document']): else: return False - @overload - def __setitem__( - self, - index: 'DocumentArrayMultipleAttributeType', - value: List[List['Any']], - ): - ... - - @overload - def __setitem__( - self, - index: 'DocumentArraySingleAttributeType', - value: List['Any'], - ): - ... - - @overload - def __setitem__( - self, - index: 'DocumentArraySingletonIndexType', - value: 'Document', - ): - ... - - @overload - def __setitem__( - self, - index: 'DocumentArrayMultipleIndexType', - value: Sequence['Document'], - ): - ... - - def __setitem__( - self, - index: 'DocumentArrayIndexType', - value: Union['Document', Sequence['Document']], - ): - - if isinstance(index, (int, np.generic)) and not isinstance(index, bool): - index = int(index) - self._data[index] = value - self._id2offset[value.id] = index - elif isinstance(index, str): - if index.startswith('@'): - for _d, _v in zip(self.traverse_flat(index[1:]), value): - _d._data = _v._data - self._rebuild_id2offset() - else: - old_idx = self._id2offset.pop(index) - self._data[old_idx] = value - self._id2offset[value.id] = old_idx - elif isinstance(index, slice): - self._data[index] = value - self._rebuild_id2offset() - elif index is Ellipsis: - for _d, _v in zip(self.flatten(), value): - _d._data = _v._data - self._rebuild_id2offset() - elif isinstance(index, Sequence): - if ( - isinstance(index, tuple) - and len(index) == 2 - and isinstance(index[0], (slice, Sequence)) - ): - if isinstance(index[0], str) and isinstance(index[1], str): - # ambiguity only comes from the second string - if index[1] in self._id2offset: - for _d, _v in zip((self[index[0]], self[index[1]]), value): - _d._data = _v._data - self._rebuild_id2offset() - elif hasattr(self[index[0]], index[1]): - setattr(self[index[0]], index[1], value) - else: - # to avoid accidentally add new unsupport attribute - raise ValueError( - f'`{index[1]}` is neither a valid id nor attribute name' - ) - elif isinstance(index[0], (slice, Sequence)): - _docs = self[index[0]] - _attrs = index[1] - - if isinstance(_attrs, str): - # a -> [a] - # [a, a] -> [a, a] - _attrs = (index[1],) - if isinstance(value, (list, tuple)) and not any( - isinstance(el, (tuple, list)) for el in value - ): - # [x] -> [[x]] - # [[x], [y]] -> [[x], [y]] - value = (value,) - if not isinstance(value, (list, tuple)): - # x -> [x] - value = (value,) - - for _a, _v in zip(_attrs, value): - if _a == 'tensor': - _docs.tensors = _v - elif _a == 'embedding': - _docs.embeddings = _v - else: - if len(_docs) == 1: - setattr(_docs[0], _a, _v) - else: - for _d, _vv in zip(_docs, _v): - setattr(_d, _a, _vv) - elif isinstance(index[0], bool): - if len(index) != len(self._data): - raise IndexError( - f'Boolean mask index is required to have the same length as {len(self._data)}, ' - f'but receiving {len(index)}' - ) - _selected = itertools.compress(self._data, index) - for _idx, _val in zip(_selected, value): - self[_idx.id] = _val - elif isinstance(index[0], (int, str)): - if not isinstance(value, Sequence) or len(index) != len(value): - raise ValueError( - f'Number of elements for assigning must be ' - f'the same as the index length: {len(index)}' - ) - if isinstance(value, Document): - for si in index: - self[si] = value - else: - for si, _val in zip(index, value): - self[si] = _val - elif isinstance(index, np.ndarray): - index = index.squeeze() - if index.ndim == 1: - self[index.tolist()] = value - else: - raise IndexError( - f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}' - ) - else: - raise IndexError(f'Unsupported index type {typename(index)}: {index}') - - def __delitem__(self, index: 'DocumentArrayIndexType'): - if isinstance(index, (int, np.generic)) and not isinstance(index, bool): - index = int(index) - self._id2offset.pop(self._data[index].id) - del self._data[index] - elif isinstance(index, str): - if index.startswith('@'): - raise NotImplementedError( - 'Delete elements along traversal paths is not implemented' - ) - else: - del self._data[self._id2offset[index]] - self._id2offset.pop(index) - elif isinstance(index, slice): - del self._data[index] - self._rebuild_id2offset() - elif index is Ellipsis: - self._data.clear() - self._id2offset.clear() - elif isinstance(index, Sequence): - if ( - isinstance(index, tuple) - and len(index) == 2 - and isinstance(index[0], (slice, Sequence)) - ): - if isinstance(index[0], str) and isinstance(index[1], str): - # ambiguity only comes from the second string - if index[1] in self._id2offset: - del self[index[0]] - del self[index[1]] - else: - self[index[0]].pop(index[1]) - elif isinstance(index[0], (slice, Sequence)): - _docs = self[index[0]] - _attrs = index[1] - if isinstance(_attrs, str): - _attrs = (index[1],) - for _d in _docs: - _d.pop(*_attrs) - elif isinstance(index[0], bool): - self._data = list( - itertools.compress(self._data, (not _i for _i in index)) - ) - self._rebuild_id2offset() - elif isinstance(index[0], int): - for t in sorted(index, reverse=True): - del self[t] - elif isinstance(index[0], str): - for t in index: - del self[t] - elif isinstance(index, np.ndarray): - index = index.squeeze() - if index.ndim == 1: - del self[index.tolist()] - else: - raise IndexError( - f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}' - ) - else: - raise IndexError(f'Unsupported index type {typename(index)}: {index}') - def clear(self): """Clear the data of :class:`DocumentArray`""" self._data.clear() diff --git a/docarray/array/mixins/__init__.py b/docarray/array/mixins/__init__.py index a50ca69a0e3..639dc59263b 100644 --- a/docarray/array/mixins/__init__.py +++ b/docarray/array/mixins/__init__.py @@ -20,6 +20,8 @@ from .pydantic import PydanticMixin from .reduce import ReduceMixin from .sample import SampleMixin +from .setitem import SetItemMixin +from .delitem import DelItemMixin from .text import TextToolsMixin from .traverse import TraverseMixin @@ -27,6 +29,8 @@ class AllMixins( GetAttributeMixin, GetItemMixin, + SetItemMixin, + DelItemMixin, ContentPropertyMixin, PydanticMixin, GroupMixin, diff --git a/docarray/array/mixins/delitem.py b/docarray/array/mixins/delitem.py new file mode 100644 index 00000000000..ee0cf73b5e0 --- /dev/null +++ b/docarray/array/mixins/delitem.py @@ -0,0 +1,92 @@ +from abc import abstractmethod +from typing import ( + TYPE_CHECKING, + Sequence, +) + +import numpy as np + +from ...helper import typename + +if TYPE_CHECKING: + from ...types import ( + DocumentArrayIndexType, + ) + + +class DelItemMixin: + """Provide help function to enable advanced indexing in `__delitem__`""" + + @abstractmethod + def _del_docs_by_mask(self, mask: Sequence[bool]): + ... + + @abstractmethod + def _del_doc_by_offset(self, offset: int): + ... + + @abstractmethod + def _del_doc_by_id(self, _id: str): + ... + + @abstractmethod + def _del_docs_by_slice(self, _slice: slice): + ... + + @abstractmethod + def _del_all_docs(self): + ... + + def __delitem__(self, index: 'DocumentArrayIndexType'): + if isinstance(index, (int, np.generic)) and not isinstance(index, bool): + self._del_doc_by_offset(int(index)) + + elif isinstance(index, str): + if index.startswith('@'): + raise NotImplementedError( + 'Delete elements along traversal paths is not implemented' + ) + else: + self._del_doc_by_id(index) + elif isinstance(index, slice): + self._del_docs_by_slice(index) + elif index is Ellipsis: + self._del_all_docs() + elif isinstance(index, Sequence): + if ( + isinstance(index, tuple) + and len(index) == 2 + and isinstance(index[0], (slice, Sequence)) + ): + if isinstance(index[0], str) and isinstance(index[1], str): + # ambiguity only comes from the second string + if index[1] in self._id2offset: + del self[index[0]] + del self[index[1]] + else: + self._set_doc_attr_by_index(index[0], index[1], None) + elif isinstance(index[0], (slice, Sequence)): + _attrs = index[1] + if isinstance(_attrs, str): + _attrs = (index[1],) + for _d in self[index[0]]: + for _aa in _attrs: + self._set_doc_attr_by_index(_d.id, _aa, None) + elif isinstance(index[0], bool): + self._del_docs_by_mask(index) + elif isinstance(index[0], int): + for t in sorted(index, reverse=True): + del self[t] + elif isinstance(index[0], str): + for t in index: + del self[t] + elif isinstance(index, np.ndarray): + index = index.squeeze() + if index.ndim == 1: + del self[index.tolist()] + else: + raise IndexError( + f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}' + ) + else: + raise IndexError(f'Unsupported index type {typename(index)}: {index}') diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py new file mode 100644 index 00000000000..89359355c49 --- /dev/null +++ b/docarray/array/mixins/setitem.py @@ -0,0 +1,186 @@ +import itertools +from abc import abstractmethod +from typing import ( + TYPE_CHECKING, + Union, + Sequence, + overload, + Any, + List, + Iterable, +) + +import numpy as np + +from ... import Document +from ...helper import typename + +if TYPE_CHECKING: + from ...types import ( + DocumentArrayIndexType, + DocumentArraySingletonIndexType, + DocumentArrayMultipleIndexType, + DocumentArrayMultipleAttributeType, + DocumentArraySingleAttributeType, + ) + + +class SetItemMixin: + """Provides helper function to allow advanced indexing for `__setitem__`""" + + @abstractmethod + def _set_doc_by_offset(self, offset: int, value: 'Document'): + ... + + @abstractmethod + def _set_doc_by_id(self, _id: str, value: 'Document'): + ... + + @abstractmethod + def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): + ... + + @abstractmethod + def _set_doc_value_pairs( + self, docs: Iterable['Document'], values: Iterable['Document'] + ): + ... + + @abstractmethod + def _set_doc_attr_by_index( + self, + _index: Union[ + 'DocumentArraySingletonIndexType', 'DocumentArrayMultipleIndexType' + ], + attr: str, + value: Any, + ): + ... + + @overload + def __setitem__( + self, + index: 'DocumentArrayMultipleAttributeType', + value: List[List['Any']], + ): + ... + + @overload + def __setitem__( + self, + index: 'DocumentArraySingleAttributeType', + value: List['Any'], + ): + ... + + @overload + def __setitem__( + self, + index: 'DocumentArraySingletonIndexType', + value: 'Document', + ): + ... + + @overload + def __setitem__( + self, + index: 'DocumentArrayMultipleIndexType', + value: Sequence['Document'], + ): + ... + + def __setitem__( + self, + index: 'DocumentArrayIndexType', + value: Union['Document', Sequence['Document']], + ): + + if isinstance(index, (int, np.generic)) and not isinstance(index, bool): + self._set_doc_by_offset(int(index), value) + elif isinstance(index, str): + if index.startswith('@'): + self._set_doc_value_pairs(self.traverse_flat(index[1:]), value) + else: + self._set_doc_by_id(index, value) + elif isinstance(index, slice): + self._set_docs_by_slice(index, value) + elif index is Ellipsis: + self._set_doc_value_pairs(self.flatten(), value) + elif isinstance(index, Sequence): + if ( + isinstance(index, tuple) + and len(index) == 2 + and isinstance(index[0], (slice, Sequence)) + ): + if isinstance(index[0], str) and isinstance(index[1], str): + # ambiguity only comes from the second string + if index[1] in self: + self._set_doc_value_pairs( + (self[index[0]], self[index[1]]), value + ) + elif hasattr(self[index[0]], index[1]): + self._set_doc_attr_by_index(index[0], index[1], value) + else: + # to avoid accidentally add new unsupport attribute + raise ValueError( + f'`{index[1]}` is neither a valid id nor attribute name' + ) + elif isinstance(index[0], (slice, Sequence)): + _attrs = index[1] + + if isinstance(_attrs, str): + # a -> [a] + # [a, a] -> [a, a] + _attrs = (index[1],) + if isinstance(value, (list, tuple)) and not any( + isinstance(el, (tuple, list)) for el in value + ): + # [x] -> [[x]] + # [[x], [y]] -> [[x], [y]] + value = (value,) + if not isinstance(value, (list, tuple)): + # x -> [x] + value = (value,) + + _docs = self[index[0]] + for _a, _v in zip(_attrs, value): + if _a == 'tensor': + _docs.tensors = _v + elif _a == 'embedding': + _docs.embeddings = _v + else: + if len(_docs) == 1: + self._set_doc_attr_by_index(_docs[0].id, _a, _v) + else: + for _d, _vv in zip(_docs, _v): + self._set_doc_attr_by_index(_d.id, _a, _vv) + elif isinstance(index[0], bool): + if len(index) != len(self): + raise IndexError( + f'Boolean mask index is required to have the same length as {len(self._data)}, ' + f'but receiving {len(index)}' + ) + _selected = itertools.compress(self, index) + self._set_doc_value_pairs(_selected, value) + elif isinstance(index[0], (int, str)): + if not isinstance(value, Sequence) or len(index) != len(value): + raise ValueError( + f'Number of elements for assigning must be ' + f'the same as the index length: {len(index)}' + ) + if isinstance(value, Document): + for si in index: + self[si] = value # leverage existing setter + else: + for si, _val in zip(index, value): + self[si] = _val # leverage existing setter + elif isinstance(index, np.ndarray): + index = index.squeeze() + if index.ndim == 1: + self[index.tolist()] = value # leverage existing setter + else: + raise IndexError( + f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}' + ) + else: + raise IndexError(f'Unsupported index type {typename(index)}: {index}') From c1d00841e10e03c5db8d5894ee9a114c85273723 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Wed, 19 Jan 2022 11:44:42 +0100 Subject: [PATCH 07/55] feat(array): add storage backend --- docarray/array/document.py | 194 ++---------------- docarray/array/mixins/__init__.py | 4 +- docarray/array/mixins/delitem.py | 25 +-- docarray/array/mixins/getitem.py | 29 --- docarray/array/mixins/setitem.py | 37 +--- .../{persistence => storage}/__init__.py | 0 .../sqlite => storage/base}/__init__.py | 0 docarray/array/storage/base/getsetdel.py | 133 ++++++++++++ docarray/array/storage/memory/__init__.py | 7 + docarray/array/storage/memory/backend.py | 65 ++++++ docarray/array/storage/memory/getsetdel.py | 74 +++++++ docarray/array/storage/memory/seqlike.py | 62 ++++++ docarray/array/storage/sqlite/__init__.py | 0 .../{persistence => storage}/sqlite/base.py | 0 .../{persistence => storage}/sqlite/dict.py | 0 .../{persistence => storage}/sqlite/mixin.py | 0 16 files changed, 360 insertions(+), 270 deletions(-) rename docarray/array/{persistence => storage}/__init__.py (100%) rename docarray/array/{persistence/sqlite => storage/base}/__init__.py (100%) create mode 100644 docarray/array/storage/base/getsetdel.py create mode 100644 docarray/array/storage/memory/__init__.py create mode 100644 docarray/array/storage/memory/backend.py create mode 100644 docarray/array/storage/memory/getsetdel.py create mode 100644 docarray/array/storage/memory/seqlike.py create mode 100644 docarray/array/storage/sqlite/__init__.py rename docarray/array/{persistence => storage}/sqlite/base.py (100%) rename docarray/array/{persistence => storage}/sqlite/dict.py (100%) rename docarray/array/{persistence => storage}/sqlite/mixin.py (100%) diff --git a/docarray/array/document.py b/docarray/array/document.py index 3efcfbe9c37..ac3c9b25f83 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -1,191 +1,21 @@ -import itertools -from typing import ( - Optional, - TYPE_CHECKING, - Generator, - Iterator, - Dict, - Union, - MutableSequence, - Sequence, - Iterable, - Any, -) - from .mixins import AllMixins -from .. import Document - -if TYPE_CHECKING: - from ..types import ( - DocumentArraySourceType, - DocumentArraySingletonIndexType, - DocumentArrayMultipleIndexType, - ) - - -class DocumentArray(AllMixins, MutableSequence[Document]): - def _del_docs_by_mask(self, mask: Sequence[bool]): - self._data = list(itertools.compress(self._data, (not _i for _i in mask))) - self._rebuild_id2offset() - - def _del_all_docs(self): - self._data.clear() - self._id2offset.clear() - - def _del_docs_by_slice(self, _slice: slice): - del self._data[_slice] - self._rebuild_id2offset() - - def _del_doc_by_id(self, _id: str): - del self._data[self._id2offset[_id]] - self._id2offset.pop(_id) - - def _del_doc_by_offset(self, offset: int): - self._id2offset.pop(self._data[offset].id) - del self._data[offset] - def _set_doc_by_offset(self, offset: int, value: 'Document'): - self._data[offset] = value - self._id2offset[value.id] = offset - def _set_doc_by_id(self, _id: str, value: 'Document'): - old_idx = self._id2offset.pop(_id) - self._data[old_idx] = value - self._id2offset[value.id] = old_idx +def _extend_instance(obj, cls): + """Apply mixins to a class instance after creation""" + base_cls = obj.__class__ + base_cls_name = obj.__class__.__name__ + obj.__class__ = type(base_cls_name, (base_cls, cls), {}) - def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): - self._data[_slice] = value - self._rebuild_id2offset() - def _set_doc_value_pairs( - self, docs: Iterable['Document'], values: Iterable['Document'] - ): - for _d, _v in zip(docs, values): - _d._data = _v._data - self._rebuild_id2offset() - - def _set_doc_attr_by_index( - self, - _index: Union[ - 'DocumentArraySingletonIndexType', 'DocumentArrayMultipleIndexType' - ], - attr: str, - value: Any, - ): - setattr(self[_index], attr, value) - - def _get_doc_by_offset(self, offset: int) -> 'Document': - return self._data[offset] - - def _get_doc_by_id(self, _id: str) -> 'Document': - return self._data[self._id2offset[_id]] - - def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: - return self._data[_slice] - - def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: - return (self._data[t] for t in offsets) - - def _get_docs_by_ids(self, ids: Sequence[str]) -> Iterable['Document']: - return (self._data[self._id2offset[t]] for t in ids) - - def __init__( - self, docs: Optional['DocumentArraySourceType'] = None, copy: bool = False - ): +class DocumentArray(AllMixins): + def __init__(self, *args, storage: str = 'memory', **kwargs): super().__init__() - self._data = [] - if docs is None: - return - elif isinstance( - docs, (DocumentArray, Sequence, Generator, Iterator, itertools.chain) - ): - if copy: - self._data = [Document(d, copy=True) for d in docs] - self._rebuild_id2offset() - elif isinstance(docs, DocumentArray): - self._data = docs._data - self._id_to_index = docs._id2offset - else: - self._data = list(docs) - self._rebuild_id2offset() - else: - if isinstance(docs, Document): - if copy: - self.append(Document(docs, copy=True)) - else: - self.append(docs) + if storage == 'memory': + from .storage.memory import MemoryStorageMixins - @property - def _id2offset(self) -> Dict[str, int]: - """Return the `_id_to_index` map - - :return: a Python dict. - """ - if not hasattr(self, '_id_to_index'): - self._rebuild_id2offset() - return self._id_to_index - - def _rebuild_id2offset(self) -> None: - """Update the id_to_index map by enumerating all Documents in self._data. - - Very costy! Only use this function when self._data is dramtically changed. - """ - - self._id_to_index = { - d.id: i for i, d in enumerate(self._data) - } # type: Dict[str, int] - - def insert(self, index: int, value: 'Document'): - """Insert `doc` at `index`. - - :param index: Position of the insertion. - :param value: The doc needs to be inserted. - """ - self._data.insert(index, value) - self._id2offset[value.id] = index - - def __eq__(self, other): - return ( - type(self) is type(other) - and type(self._data) is type(other._data) - and self._data == other._data - ) - - def __len__(self): - return len(self._data) - - def __iter__(self) -> Iterator['Document']: - yield from self._data - - def __contains__(self, x: Union[str, 'Document']): - if isinstance(x, str): - return x in self._id2offset - elif isinstance(x, Document): - return x.id in self._id2offset + _extend_instance(self, MemoryStorageMixins) else: - return False - - def clear(self): - """Clear the data of :class:`DocumentArray`""" - self._data.clear() - self._id2offset.clear() - - def __bool__(self): - """To simulate ```l = []; if l: ...``` - - :return: returns true if the length of the array is larger than 0 - """ - return len(self) > 0 - - def __repr__(self): - return f'<{self.__class__.__name__} (length={len(self)}) at {id(self)}>' - - def __add__(self, other: Union['Document', Sequence['Document']]): - v = type(self)() - v.extend(self) - v.extend(other) - return v + raise ValueError(f'storage=`{storage}` is not supported.') - def extend(self, values: Iterable['Document']) -> None: - self._data.extend(values) - self._rebuild_id2offset() + self._init_storage(*args, **kwargs) diff --git a/docarray/array/mixins/__init__.py b/docarray/array/mixins/__init__.py index 639dc59263b..f5dbd8df4d6 100644 --- a/docarray/array/mixins/__init__.py +++ b/docarray/array/mixins/__init__.py @@ -1,6 +1,7 @@ from abc import ABC from .content import ContentPropertyMixin +from .delitem import DelItemMixin from .embed import EmbedMixin from .empty import EmptyMixin from .evaluation import EvaluationMixin @@ -21,7 +22,6 @@ from .reduce import ReduceMixin from .sample import SampleMixin from .setitem import SetItemMixin -from .delitem import DelItemMixin from .text import TextToolsMixin from .traverse import TraverseMixin @@ -53,6 +53,6 @@ class AllMixins( DataframeIOMixin, ABC, ): - """All plugins that can be used in :class:`DocumentArray`. """ + """All plugins that can be used in :class:`DocumentArray`.""" ... diff --git a/docarray/array/mixins/delitem.py b/docarray/array/mixins/delitem.py index ee0cf73b5e0..6b9a15fb10a 100644 --- a/docarray/array/mixins/delitem.py +++ b/docarray/array/mixins/delitem.py @@ -1,4 +1,3 @@ -from abc import abstractmethod from typing import ( TYPE_CHECKING, Sequence, @@ -17,26 +16,6 @@ class DelItemMixin: """Provide help function to enable advanced indexing in `__delitem__`""" - @abstractmethod - def _del_docs_by_mask(self, mask: Sequence[bool]): - ... - - @abstractmethod - def _del_doc_by_offset(self, offset: int): - ... - - @abstractmethod - def _del_doc_by_id(self, _id: str): - ... - - @abstractmethod - def _del_docs_by_slice(self, _slice: slice): - ... - - @abstractmethod - def _del_all_docs(self): - ... - def __delitem__(self, index: 'DocumentArrayIndexType'): if isinstance(index, (int, np.generic)) and not isinstance(index, bool): self._del_doc_by_offset(int(index)) @@ -64,14 +43,14 @@ def __delitem__(self, index: 'DocumentArrayIndexType'): del self[index[0]] del self[index[1]] else: - self._set_doc_attr_by_index(index[0], index[1], None) + self._set_doc_attr_by_id(index[0], index[1], None) elif isinstance(index[0], (slice, Sequence)): _attrs = index[1] if isinstance(_attrs, str): _attrs = (index[1],) for _d in self[index[0]]: for _aa in _attrs: - self._set_doc_attr_by_index(_d.id, _aa, None) + self._set_doc_attr_by_id(_d.id, _aa, None) elif isinstance(index[0], bool): self._del_docs_by_mask(index) elif isinstance(index[0], int): diff --git a/docarray/array/mixins/getitem.py b/docarray/array/mixins/getitem.py index f893d0b996b..308d1731f48 100644 --- a/docarray/array/mixins/getitem.py +++ b/docarray/array/mixins/getitem.py @@ -1,5 +1,4 @@ import itertools -from abc import abstractmethod from typing import ( TYPE_CHECKING, Union, @@ -7,7 +6,6 @@ overload, Any, List, - Iterable, ) import numpy as np @@ -29,33 +27,6 @@ class GetItemMixin: """Provide helper functions to enable advance indexing in `__getitem__`""" - @abstractmethod - def _get_doc_by_offset(self, offset: int) -> 'Document': - ... - - @abstractmethod - def _get_doc_by_id(self, _id: str) -> 'Document': - ... - - @abstractmethod - def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: - """This function is derived from :meth:`_get_doc_by_offset` - - Override this function if there is a more efficient logic""" - return (self._get_doc_by_offset(j) for j in range(len(self))[_slice]) - - def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: - """This function is derived from :meth:`_get_doc_by_offset` - - Override this function if there is a more efficient logic""" - return (self._get_doc_by_offset(d) for d in offsets) - - def _get_docs_by_ids(self, ids: Sequence[str]) -> Iterable['Document']: - """This function is derived from :meth:`_get_doc_by_id` - - Override this function if there is a more efficient logic""" - return (self._get_doc_by_id(d) for d in ids) - @overload def __getitem__(self, index: 'DocumentArraySingletonIndexType') -> 'Document': ... diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index 89359355c49..44bb8291131 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -1,5 +1,4 @@ import itertools -from abc import abstractmethod from typing import ( TYPE_CHECKING, Union, @@ -7,7 +6,6 @@ overload, Any, List, - Iterable, ) import numpy as np @@ -28,35 +26,6 @@ class SetItemMixin: """Provides helper function to allow advanced indexing for `__setitem__`""" - @abstractmethod - def _set_doc_by_offset(self, offset: int, value: 'Document'): - ... - - @abstractmethod - def _set_doc_by_id(self, _id: str, value: 'Document'): - ... - - @abstractmethod - def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): - ... - - @abstractmethod - def _set_doc_value_pairs( - self, docs: Iterable['Document'], values: Iterable['Document'] - ): - ... - - @abstractmethod - def _set_doc_attr_by_index( - self, - _index: Union[ - 'DocumentArraySingletonIndexType', 'DocumentArrayMultipleIndexType' - ], - attr: str, - value: Any, - ): - ... - @overload def __setitem__( self, @@ -119,7 +88,7 @@ def __setitem__( (self[index[0]], self[index[1]]), value ) elif hasattr(self[index[0]], index[1]): - self._set_doc_attr_by_index(index[0], index[1], value) + self._set_doc_attr_by_id(index[0], index[1], value) else: # to avoid accidentally add new unsupport attribute raise ValueError( @@ -150,10 +119,10 @@ def __setitem__( _docs.embeddings = _v else: if len(_docs) == 1: - self._set_doc_attr_by_index(_docs[0].id, _a, _v) + self._set_doc_attr_by_id(_docs[0].id, _a, _v) else: for _d, _vv in zip(_docs, _v): - self._set_doc_attr_by_index(_d.id, _a, _vv) + self._set_doc_attr_by_id(_d.id, _a, _vv) elif isinstance(index[0], bool): if len(index) != len(self): raise IndexError( diff --git a/docarray/array/persistence/__init__.py b/docarray/array/storage/__init__.py similarity index 100% rename from docarray/array/persistence/__init__.py rename to docarray/array/storage/__init__.py diff --git a/docarray/array/persistence/sqlite/__init__.py b/docarray/array/storage/base/__init__.py similarity index 100% rename from docarray/array/persistence/sqlite/__init__.py rename to docarray/array/storage/base/__init__.py diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py new file mode 100644 index 00000000000..22bb8b849d9 --- /dev/null +++ b/docarray/array/storage/base/getsetdel.py @@ -0,0 +1,133 @@ +from abc import abstractmethod, ABC +from typing import ( + Sequence, + Any, + Iterable, +) + +from .... import Document + + +class BaseGetSetDelMixin(ABC): + """Provide abstract methods and derived methods for ``__getitem__``, ``__setitem__`` and ``__delitem__`` + + .. note:: + The following methods must be implemented: + - :meth:`._get_doc_by_offset` + - :meth:`._get_doc_by_id` + - :meth:`._set_doc_by_offset` + - :meth:`._set_doc_by_id` + - :meth:`._del_doc_by_offset` + - :meth:`._del_doc_by_id` + + Other methods implemented a generic-but-slow version that leverage the methods above. + Please override those methods in the subclass whenever a more efficient implementation is available. + """ + + # Getitem APIs + + @abstractmethod + def _get_doc_by_offset(self, offset: int) -> 'Document': + ... + + @abstractmethod + def _get_doc_by_id(self, _id: str) -> 'Document': + ... + + def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: + """This function is derived from :meth:`_get_doc_by_offset` + + Override this function if there is a more efficient logic""" + return (self._get_doc_by_offset(j) for j in range(len(self))[_slice]) + + def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: + """This function is derived from :meth:`_get_doc_by_offset` + + Override this function if there is a more efficient logic""" + return (self._get_doc_by_offset(d) for d in offsets) + + def _get_docs_by_ids(self, ids: Sequence[str]) -> Iterable['Document']: + """This function is derived from :meth:`_get_doc_by_id` + + Override this function if there is a more efficient logic""" + return (self._get_doc_by_id(d) for d in ids) + + # Delitem APIs + + @abstractmethod + def _del_doc_by_offset(self, offset: int): + ... + + @abstractmethod + def _del_doc_by_id(self, _id: str): + ... + + def _del_docs_by_slice(self, _slice: slice): + """This function is derived and may not have the most efficient implementation. + + Override this function if there is a more efficient logic""" + for j in range(len(self))[_slice]: + self._del_doc_by_offset(j) + + def _del_docs_by_mask(self, mask: Sequence[bool]): + """This function is derived and may not have the most efficient implementation. + + Override this function if there is a more efficient logic""" + for idx, m in enumerate(mask): + if not m: + self._del_doc_by_offset(idx) + + def _del_all_docs(self): + """This function is derived and may not have the most efficient implementation. + + Override this function if there is a more efficient logic""" + for j in range(len(self)): + self._del_doc_by_offset(j) + + # Setitem API + + @abstractmethod + def _set_doc_by_offset(self, offset: int, value: 'Document'): + ... + + @abstractmethod + def _set_doc_by_id(self, _id: str, value: 'Document'): + ... + + def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): + """This function is derived and may not have the most efficient implementation. + + Override this function if there is a more efficient logic + """ + for _offset, val in zip(range(len(self))[_slice], value): + self._set_doc_by_offset(_offset, val) + + def _set_doc_value_pairs( + self, docs: Iterable['Document'], values: Iterable['Document'] + ): + """This function is derived and may not have the most efficient implementation. + + Override this function if there is a more efficient logic + """ + for _d, _v in zip(docs, values): + self._set_doc_by_id(_d.id, _v) + + def _set_doc_attr_by_offset(self, offset: int, attr: str, value: Any): + """This function is derived and may not have the most efficient implementation. + + Override this function if there is a more efficient logic + """ + d = self._get_doc_by_offset(offset) + if hasattr(d, attr): + setattr(d, attr, value) + self._set_doc_by_offset(offset, d) + + def _set_doc_attr_by_id(self, _id: str, attr: str, value: Any): + """This function is derived and may not have the most efficient implementation. + + Override this function if there is a more efficient logic + """ + d = self._get_doc_by_id(_id) + if hasattr(d, attr): + setattr(d, attr, value) + self._set_doc_by_id(_id, d) diff --git a/docarray/array/storage/memory/__init__.py b/docarray/array/storage/memory/__init__.py new file mode 100644 index 00000000000..e26a79127d7 --- /dev/null +++ b/docarray/array/storage/memory/__init__.py @@ -0,0 +1,7 @@ +from .backend import MemoryBackendMixin +from .getsetdel import GetSetDelMixin +from .seqlike import SequenceLikeMixin + + +class MemoryStorageMixins(MemoryBackendMixin, GetSetDelMixin, SequenceLikeMixin): + ... diff --git a/docarray/array/storage/memory/backend.py b/docarray/array/storage/memory/backend.py new file mode 100644 index 00000000000..45b44afd778 --- /dev/null +++ b/docarray/array/storage/memory/backend.py @@ -0,0 +1,65 @@ +import itertools +from typing import ( + Generator, + Iterator, + Dict, + Sequence, + Optional, + TYPE_CHECKING, +) + +from .... import Document + +if TYPE_CHECKING: + from ....types import ( + DocumentArraySourceType, + ) + + +class MemoryBackendMixin: + @property + def _id2offset(self) -> Dict[str, int]: + """Return the `_id_to_index` map + + :return: a Python dict. + """ + if not hasattr(self, '_id_to_index'): + self._rebuild_id2offset() + return self._id_to_index + + def _rebuild_id2offset(self) -> None: + """Update the id_to_index map by enumerating all Documents in self._data. + + Very costy! Only use this function when self._data is dramtically changed. + """ + + self._id_to_index = { + d.id: i for i, d in enumerate(self._data) + } # type: Dict[str, int] + + def _init_storage( + self, docs: Optional['DocumentArraySourceType'] = None, copy: bool = False + ): + from ... import DocumentArray + + self._data = [] + if docs is None: + return + elif isinstance( + docs, (DocumentArray, Sequence, Generator, Iterator, itertools.chain) + ): + if copy: + self._data = [Document(d, copy=True) for d in docs] + self._rebuild_id2offset() + elif isinstance(docs, DocumentArray): + self._data = docs._data + self._id_to_index = docs._id2offset + else: + self._data = list(docs) + self._rebuild_id2offset() + else: + if isinstance(docs, Document): + if copy: + self.append(Document(docs, copy=True)) + else: + self.append(docs) diff --git a/docarray/array/storage/memory/getsetdel.py b/docarray/array/storage/memory/getsetdel.py new file mode 100644 index 00000000000..5adf2cefb67 --- /dev/null +++ b/docarray/array/storage/memory/getsetdel.py @@ -0,0 +1,74 @@ +import itertools +from typing import ( + Sequence, + Iterable, + Any, +) + +from ..base.getsetdel import BaseGetSetDelMixin +from .... import Document + + +class GetSetDelMixin(BaseGetSetDelMixin): + """Implement `getitem`, `setitem`, `delitem`""" + + def _del_docs_by_mask(self, mask: Sequence[bool]): + self._data = list(itertools.compress(self._data, (not _i for _i in mask))) + self._rebuild_id2offset() + + def _del_all_docs(self): + self._data.clear() + self._id2offset.clear() + + def _del_docs_by_slice(self, _slice: slice): + del self._data[_slice] + self._rebuild_id2offset() + + def _del_doc_by_id(self, _id: str): + del self._data[self._id2offset[_id]] + self._id2offset.pop(_id) + + def _del_doc_by_offset(self, offset: int): + self._id2offset.pop(self._data[offset].id) + del self._data[offset] + + def _set_doc_by_offset(self, offset: int, value: 'Document'): + self._data[offset] = value + self._id2offset[value.id] = offset + + def _set_doc_by_id(self, _id: str, value: 'Document'): + old_idx = self._id2offset.pop(_id) + self._data[old_idx] = value + self._id2offset[value.id] = old_idx + + def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): + self._data[_slice] = value + self._rebuild_id2offset() + + def _set_doc_value_pairs( + self, docs: Iterable['Document'], values: Iterable['Document'] + ): + for _d, _v in zip(docs, values): + _d._data = _v._data + self._rebuild_id2offset() + + def _set_doc_attr_by_offset(self, offset: int, attr: str, value: Any): + setattr(self._data[offset], attr, value) + + def _set_doc_attr_by_id(self, _id: str, attr: str, value: Any): + setattr(self._data[self._id2offset[_id]], attr, value) + + def _get_doc_by_offset(self, offset: int) -> 'Document': + return self._data[offset] + + def _get_doc_by_id(self, _id: str) -> 'Document': + return self._data[self._id2offset[_id]] + + def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: + return self._data[_slice] + + def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: + return (self._data[t] for t in offsets) + + def _get_docs_by_ids(self, ids: Sequence[str]) -> Iterable['Document']: + return (self._data[self._id2offset[t]] for t in ids) diff --git a/docarray/array/storage/memory/seqlike.py b/docarray/array/storage/memory/seqlike.py new file mode 100644 index 00000000000..3429ec7caae --- /dev/null +++ b/docarray/array/storage/memory/seqlike.py @@ -0,0 +1,62 @@ +from typing import Iterator, Union, Sequence, Iterable, MutableSequence + +from .... import Document + + +class SequenceLikeMixin(MutableSequence[Document]): + """Implement sequence-like methods""" + + def insert(self, index: int, value: 'Document'): + """Insert `doc` at `index`. + + :param index: Position of the insertion. + :param value: The doc needs to be inserted. + """ + self._data.insert(index, value) + self._id2offset[value.id] = index + + def __eq__(self, other): + return ( + type(self) is type(other) + and type(self._data) is type(other._data) + and self._data == other._data + ) + + def __len__(self): + return len(self._data) + + def __iter__(self) -> Iterator['Document']: + yield from self._data + + def __contains__(self, x: Union[str, 'Document']): + if isinstance(x, str): + return x in self._id2offset + elif isinstance(x, Document): + return x.id in self._id2offset + else: + return False + + def clear(self): + """Clear the data of :class:`DocumentArray`""" + self._data.clear() + self._id2offset.clear() + + def __bool__(self): + """To simulate ```l = []; if l: ...``` + + :return: returns true if the length of the array is larger than 0 + """ + return len(self) > 0 + + def __repr__(self): + return f'<{self.__class__.__name__} (length={len(self)}) at {id(self)}>' + + def __add__(self, other: Union['Document', Sequence['Document']]): + v = type(self)() + v.extend(self) + v.extend(other) + return v + + def extend(self, values: Iterable['Document']) -> None: + self._data.extend(values) + self._rebuild_id2offset() diff --git a/docarray/array/storage/sqlite/__init__.py b/docarray/array/storage/sqlite/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/array/persistence/sqlite/base.py b/docarray/array/storage/sqlite/base.py similarity index 100% rename from docarray/array/persistence/sqlite/base.py rename to docarray/array/storage/sqlite/base.py diff --git a/docarray/array/persistence/sqlite/dict.py b/docarray/array/storage/sqlite/dict.py similarity index 100% rename from docarray/array/persistence/sqlite/dict.py rename to docarray/array/storage/sqlite/dict.py diff --git a/docarray/array/persistence/sqlite/mixin.py b/docarray/array/storage/sqlite/mixin.py similarity index 100% rename from docarray/array/persistence/sqlite/mixin.py rename to docarray/array/storage/sqlite/mixin.py From dee84574104b29a36f2074f915ba4bb816199fee Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Wed, 19 Jan 2022 12:18:39 +0100 Subject: [PATCH 08/55] feat(array): add storage backend --- docarray/array/document.py | 4 + docarray/array/storage/base/backend.py | 7 + docarray/array/storage/base/getsetdel.py | 6 +- docarray/array/storage/memory/__init__.py | 2 + docarray/array/storage/memory/backend.py | 5 +- docarray/array/storage/memory/getsetdel.py | 2 +- docarray/array/storage/sqlite/__init__.py | 9 ++ docarray/array/storage/sqlite/backend.py | 62 ++++++++ docarray/array/storage/sqlite/getsetdel.py | 45 ++++++ docarray/array/storage/sqlite/mixin.py | 169 --------------------- docarray/array/storage/sqlite/seqlike.py | 56 +++++++ 11 files changed, 193 insertions(+), 174 deletions(-) create mode 100644 docarray/array/storage/base/backend.py create mode 100644 docarray/array/storage/sqlite/backend.py create mode 100644 docarray/array/storage/sqlite/getsetdel.py delete mode 100644 docarray/array/storage/sqlite/mixin.py create mode 100644 docarray/array/storage/sqlite/seqlike.py diff --git a/docarray/array/document.py b/docarray/array/document.py index ac3c9b25f83..19b24fd66b9 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -15,6 +15,10 @@ def __init__(self, *args, storage: str = 'memory', **kwargs): from .storage.memory import MemoryStorageMixins _extend_instance(self, MemoryStorageMixins) + elif storage == 'sqlite': + from .storage.sqlite import SqliteStorageMixins + + _extend_instance(self, SqliteStorageMixins) else: raise ValueError(f'storage=`{storage}` is not supported.') diff --git a/docarray/array/storage/base/backend.py b/docarray/array/storage/base/backend.py new file mode 100644 index 00000000000..06ece2b12e6 --- /dev/null +++ b/docarray/array/storage/base/backend.py @@ -0,0 +1,7 @@ +from abc import ABC, abstractmethod + + +class BaseBackendMixin(ABC): + @abstractmethod + def _init_storage(self, *args, **kwargs): + ... diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index 22bb8b849d9..66f74f0abec 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -38,19 +38,19 @@ def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: """This function is derived from :meth:`_get_doc_by_offset` Override this function if there is a more efficient logic""" - return (self._get_doc_by_offset(j) for j in range(len(self))[_slice]) + return (self._get_doc_by_offset(o) for o in range(len(self))[_slice]) def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: """This function is derived from :meth:`_get_doc_by_offset` Override this function if there is a more efficient logic""" - return (self._get_doc_by_offset(d) for d in offsets) + return (self._get_doc_by_offset(o) for o in offsets) def _get_docs_by_ids(self, ids: Sequence[str]) -> Iterable['Document']: """This function is derived from :meth:`_get_doc_by_id` Override this function if there is a more efficient logic""" - return (self._get_doc_by_id(d) for d in ids) + return (self._get_doc_by_id(_id) for _id in ids) # Delitem APIs diff --git a/docarray/array/storage/memory/__init__.py b/docarray/array/storage/memory/__init__.py index e26a79127d7..772b4b04881 100644 --- a/docarray/array/storage/memory/__init__.py +++ b/docarray/array/storage/memory/__init__.py @@ -2,6 +2,8 @@ from .getsetdel import GetSetDelMixin from .seqlike import SequenceLikeMixin +__all__ = ['MemoryStorageMixins'] + class MemoryStorageMixins(MemoryBackendMixin, GetSetDelMixin, SequenceLikeMixin): ... diff --git a/docarray/array/storage/memory/backend.py b/docarray/array/storage/memory/backend.py index 45b44afd778..b1db1904e26 100644 --- a/docarray/array/storage/memory/backend.py +++ b/docarray/array/storage/memory/backend.py @@ -9,6 +9,7 @@ ) from .... import Document +from ..base.backend import BaseBackendMixin if TYPE_CHECKING: from ....types import ( @@ -16,7 +17,9 @@ ) -class MemoryBackendMixin: +class MemoryBackendMixin(BaseBackendMixin): + """Provide necessary functions to enable this storage backend. """ + @property def _id2offset(self) -> Dict[str, int]: """Return the `_id_to_index` map diff --git a/docarray/array/storage/memory/getsetdel.py b/docarray/array/storage/memory/getsetdel.py index 5adf2cefb67..8ea62471cce 100644 --- a/docarray/array/storage/memory/getsetdel.py +++ b/docarray/array/storage/memory/getsetdel.py @@ -10,7 +10,7 @@ class GetSetDelMixin(BaseGetSetDelMixin): - """Implement `getitem`, `setitem`, `delitem`""" + """Implement required and derived functions that power `getitem`, `setitem`, `delitem`""" def _del_docs_by_mask(self, mask: Sequence[bool]): self._data = list(itertools.compress(self._data, (not _i for _i in mask))) diff --git a/docarray/array/storage/sqlite/__init__.py b/docarray/array/storage/sqlite/__init__.py index e69de29bb2d..48e846a7e3f 100644 --- a/docarray/array/storage/sqlite/__init__.py +++ b/docarray/array/storage/sqlite/__init__.py @@ -0,0 +1,9 @@ +from .backend import SqliteBackendMixin +from .getsetdel import GetSetDelMixin +from .seqlike import SequenceLikeMixin + +__all__ = ['SqliteStorageMixins'] + + +class SqliteStorageMixins(SqliteBackendMixin, GetSetDelMixin, SequenceLikeMixin): + ... diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py new file mode 100644 index 00000000000..91b89250bd3 --- /dev/null +++ b/docarray/array/storage/sqlite/backend.py @@ -0,0 +1,62 @@ +import dataclasses +from dataclasses import dataclass +from typing import ( + Optional, + TYPE_CHECKING, + Union, + Dict, +) + +from ..base.backend import BaseBackendMixin + +if TYPE_CHECKING: + import sqlite3 + + from ....types import ( + DocumentArraySourceType, + ) + + +@dataclass +class SqliteConfig: + connection: Optional[Union[str, 'sqlite3.Connection']] = None + table_name: Optional[str] = None + serialize_config: Optional[Dict] = None + + +class SqliteBackendMixin(BaseBackendMixin): + """Provide necessary functions to enable this storage backend.""" + + @property + def schema_version(self) -> str: + return '0' + + def _sql(self, *arg, **kwargs) -> 'sqlite3.Cursor': + return self.connection.cursor().execute(*arg, **kwargs) + + def _insert_doc_at_idx(self, doc, idx: Optional[int] = None): + if idx is None: + idx = len(self) + self._sql( + f'INSERT INTO {self.table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', + (doc.id, doc, idx), + ) + + def _shift_index_right_backward(self, start: int): + idx = len(self) - 1 + while idx >= start: + self._sql( + f"UPDATE {self.table_name} SET item_order = ? WHERE item_order = ?", + (idx + 1, idx), + ) + idx -= 1 + + def _init_storage( + self, + docs: Optional['DocumentArraySourceType'] = None, + config: Optional[SqliteConfig] = None, + ): + super().__init__(**(dataclasses.asdict(config) if config else {})) + if docs is not None: + self.clear() + self.extend(docs) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py new file mode 100644 index 00000000000..80b936ecc4e --- /dev/null +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -0,0 +1,45 @@ +import itertools +from typing import ( + Sequence, + Iterable, + Any, +) + +from ..base.getsetdel import BaseGetSetDelMixin +from .... import Document + + +class GetSetDelMixin(BaseGetSetDelMixin): + """Implement required and derived functions that power `getitem`, `setitem`, `delitem`""" + + def _get_doc_by_offset(self, index: int) -> 'Document': + r = self._sql( + f"SELECT serialized_value FROM {self.table_name} WHERE item_order = ?", + (index + (len(self) if index < 0 else 0),), + ) + res = r.fetchone() + if res is None: + raise IndexError('index out of range') + return res[0] + + def _get_doc_by_id(self, id: str) -> 'Document': + r = self._sql( + f"SELECT serialized_value FROM {self.table_name} WHERE doc_id = ?", (id,) + ) + res = r.fetchone() + if res is None: + raise KeyError(f'Can not find Document with id=`{id}`') + return res[0] + + def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: + l = len(self) + offsets = [o + (l if o < 0 else 0) for o in offsets] + r = self._sql( + f"SELECT serialized_value FROM {self.table_name} WHERE item_order in ({','.join(['?']*len(offsets))})", + offsets, + ) + for rr in r: + yield rr[0] + + def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: + return self._get_docs_by_offsets(range(len(self))[_slice]) diff --git a/docarray/array/storage/sqlite/mixin.py b/docarray/array/storage/sqlite/mixin.py deleted file mode 100644 index 53738d96f80..00000000000 --- a/docarray/array/storage/sqlite/mixin.py +++ /dev/null @@ -1,169 +0,0 @@ -import itertools -from dataclasses import dataclass -import dataclasses -from typing import ( - Optional, - TYPE_CHECKING, - Callable, - Union, - cast, - Iterable, - Dict, - Iterator, - Sequence, -) - -from .base import SqliteCollectionBase -from .dict import _DictDatabaseDriver -from ....helper import typename -import numpy as np - -if TYPE_CHECKING: - import sqlite3 - - from ....types import ( - T, - Document, - DocumentArraySourceType, - DocumentArrayIndexType, - DocumentArraySingletonIndexType, - DocumentArrayMultipleIndexType, - DocumentArrayMultipleAttributeType, - DocumentArraySingleAttributeType, - ) - - from docarray import DocumentArray - - -@dataclass -class SqliteConfig: - connection: Optional[Union[str, 'sqlite3.Connection']] = None - table_name: Optional[str] = None - serialize_config: Optional[Dict] = None - - -class SqliteMixin(SqliteCollectionBase): - """Enable SQLite persistence backend for DocumentArray. - - .. note:: - This has to be put in the first position when use it for subclassing - i.e. `class SqliteDA(SqliteMixin, DA)` not the other way around. - - """ - - _driver_class = _DictDatabaseDriver - - def __init__( - self, - docs: Optional['DocumentArraySourceType'] = None, - config: Optional[SqliteConfig] = None, - ): - super().__init__(**(dataclasses.asdict(config) if config else {})) - if docs is not None: - self.clear() - self.extend(docs) - - def insert(self, index: int, value: 'Document'): - """Insert `doc` at `index`. - - :param index: Position of the insertion. - :param value: The doc needs to be inserted. - """ - length = len(self) - if index < 0: - index = length + index - index = max(0, min(length, index)) - self._shift_index_right_backward(index) - self._insert_doc_at_idx(doc=value, idx=index) - self.connection.commit() - - def append(self, value: 'Document') -> None: - self._insert_doc_at_idx(value) - self.connection.commit() - - def extend(self, values: Iterable['Document']) -> None: - idx = len(self) - for v in values: - self._insert_doc_at_idx(v, idx) - idx += 1 - self.connection.commit() - - def clear(self) -> None: - self._sql(f'DELETE FROM {self.table_name}') - self.connection.commit() - - def __contains__(self, item: Union[str, 'Document']): - if isinstance(item, str): - r = self._sql(f"SELECT 1 FROM {self.table_name} WHERE doc_id=?", (item,)) - return len(list(r)) > 0 - elif isinstance(item, Document): - return item.id in self # fall back to str check - else: - return False - - def __len__(self) -> int: - r = self._sql(f'SELECT COUNT(*) FROM {self.table_name}') - return r.fetchone()[0] - - def __iter__(self) -> Iterator['Document']: - r = self._sql( - f'SELECT serialized_value FROM {self.table_name} ORDER BY item_order' - ) - for res in r: - yield res[0] - - def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: - return self._get_docs_by_offsets(range(len(self))[_slice]) - - def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: - l = len(self) - offsets = [o + (l if o < 0 else 0) for o in offsets] - r = self._sql( - f"SELECT serialized_value FROM {self.table_name} WHERE item_order in ({','.join(['?']*len(offsets))})", - offsets, - ) - for rr in r: - yield rr[0] - - def _get_doc_by_offset(self, index: int) -> 'Document': - r = self._sql( - f"SELECT serialized_value FROM {self.table_name} WHERE item_order = ?", - (index + (len(self) if index < 0 else 0),), - ) - res = r.fetchone() - if res is None: - raise IndexError('index out of range') - return res[0] - - def _get_doc_by_id(self, id: str) -> 'Document': - r = self._sql( - f"SELECT serialized_value FROM {self.table_name} WHERE doc_id = ?", (id,) - ) - res = r.fetchone() - if res is None: - raise KeyError(f'Can not find Document with id=`{id}`') - return res[0] - - @property - def schema_version(self) -> str: - return '0' - - def _sql(self, *arg, **kwargs) -> 'sqlite3.Cursor': - return self.connection.cursor().execute(*arg, **kwargs) - - def _insert_doc_at_idx(self, doc, idx: Optional[int] = None): - if idx is None: - idx = len(self) - self._sql( - f'INSERT INTO {self.table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', - (doc.id, doc, idx), - ) - - def _shift_index_right_backward(self, start: int): - idx = len(self) - 1 - while idx >= start: - self._sql( - f"UPDATE {self.table_name} SET item_order = ? WHERE item_order = ?", - (idx + 1, idx), - ) - idx -= 1 diff --git a/docarray/array/storage/sqlite/seqlike.py b/docarray/array/storage/sqlite/seqlike.py new file mode 100644 index 00000000000..17d0954967a --- /dev/null +++ b/docarray/array/storage/sqlite/seqlike.py @@ -0,0 +1,56 @@ +from typing import Iterator, Union, Sequence, Iterable, MutableSequence + +from .... import Document + + +class SequenceLikeMixin(MutableSequence[Document]): + """Implement sequence-like methods""" + + def insert(self, index: int, value: 'Document'): + """Insert `doc` at `index`. + + :param index: Position of the insertion. + :param value: The doc needs to be inserted. + """ + length = len(self) + if index < 0: + index = length + index + index = max(0, min(length, index)) + self._shift_index_right_backward(index) + self._insert_doc_at_idx(doc=value, idx=index) + self.connection.commit() + + def append(self, value: 'Document') -> None: + self._insert_doc_at_idx(value) + self.connection.commit() + + def extend(self, values: Iterable['Document']) -> None: + idx = len(self) + for v in values: + self._insert_doc_at_idx(v, idx) + idx += 1 + self.connection.commit() + + def clear(self) -> None: + self._sql(f'DELETE FROM {self.table_name}') + self.connection.commit() + + def __contains__(self, item: Union[str, 'Document']): + if isinstance(item, str): + r = self._sql(f"SELECT 1 FROM {self.table_name} WHERE doc_id=?", (item,)) + return len(list(r)) > 0 + elif isinstance(item, Document): + return item.id in self # fall back to str check + else: + return False + + def __len__(self) -> int: + r = self._sql(f'SELECT COUNT(*) FROM {self.table_name}') + return r.fetchone()[0] + + def __iter__(self) -> Iterator['Document']: + r = self._sql( + f'SELECT serialized_value FROM {self.table_name} ORDER BY item_order' + ) + for res in r: + yield res[0] From 34ad9e150bfa19a3a469cbc74c99de453edebe01 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Wed, 19 Jan 2022 12:20:19 +0100 Subject: [PATCH 09/55] feat(array): add storage backend --- docarray/array/storage/memory/__init__.py | 3 ++- docarray/array/storage/sqlite/__init__.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docarray/array/storage/memory/__init__.py b/docarray/array/storage/memory/__init__.py index 772b4b04881..418a1dc5719 100644 --- a/docarray/array/storage/memory/__init__.py +++ b/docarray/array/storage/memory/__init__.py @@ -1,9 +1,10 @@ from .backend import MemoryBackendMixin from .getsetdel import GetSetDelMixin from .seqlike import SequenceLikeMixin +from abc import ABC __all__ = ['MemoryStorageMixins'] -class MemoryStorageMixins(MemoryBackendMixin, GetSetDelMixin, SequenceLikeMixin): +class MemoryStorageMixins(MemoryBackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC): ... diff --git a/docarray/array/storage/sqlite/__init__.py b/docarray/array/storage/sqlite/__init__.py index 48e846a7e3f..37a98289f39 100644 --- a/docarray/array/storage/sqlite/__init__.py +++ b/docarray/array/storage/sqlite/__init__.py @@ -1,9 +1,10 @@ from .backend import SqliteBackendMixin from .getsetdel import GetSetDelMixin from .seqlike import SequenceLikeMixin +from abc import ABC __all__ = ['SqliteStorageMixins'] -class SqliteStorageMixins(SqliteBackendMixin, GetSetDelMixin, SequenceLikeMixin): +class SqliteStorageMixins(SqliteBackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC): ... From 1621b94b357e879bf02cd9079019f3292b57e1d5 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Thu, 20 Jan 2022 08:47:54 +0100 Subject: [PATCH 10/55] feat(array): add storage backend --- docarray/array/base.py | 7 ++++++ docarray/array/document.py | 26 ++++------------------- docarray/array/sqlite.py | 7 ++++++ docarray/array/storage/memory/__init__.py | 6 +++--- docarray/array/storage/memory/backend.py | 2 +- docarray/array/storage/sqlite/__init__.py | 6 +++--- docarray/array/storage/sqlite/backend.py | 2 +- 7 files changed, 26 insertions(+), 30 deletions(-) create mode 100644 docarray/array/base.py create mode 100644 docarray/array/sqlite.py diff --git a/docarray/array/base.py b/docarray/array/base.py new file mode 100644 index 00000000000..cf0e85f3564 --- /dev/null +++ b/docarray/array/base.py @@ -0,0 +1,7 @@ +from abc import ABC + + +class BaseDocumentArray(ABC): + def __init__(self, *args, **kwargs): + super().__init__() + self._init_storage(*args, **kwargs) diff --git a/docarray/array/document.py b/docarray/array/document.py index 19b24fd66b9..02dbd6b36a7 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -1,25 +1,7 @@ +from .base import BaseDocumentArray from .mixins import AllMixins +from .storage.memory import StorageMixins -def _extend_instance(obj, cls): - """Apply mixins to a class instance after creation""" - base_cls = obj.__class__ - base_cls_name = obj.__class__.__name__ - obj.__class__ = type(base_cls_name, (base_cls, cls), {}) - - -class DocumentArray(AllMixins): - def __init__(self, *args, storage: str = 'memory', **kwargs): - super().__init__() - if storage == 'memory': - from .storage.memory import MemoryStorageMixins - - _extend_instance(self, MemoryStorageMixins) - elif storage == 'sqlite': - from .storage.sqlite import SqliteStorageMixins - - _extend_instance(self, SqliteStorageMixins) - else: - raise ValueError(f'storage=`{storage}` is not supported.') - - self._init_storage(*args, **kwargs) +class DocumentArray(StorageMixins, AllMixins, BaseDocumentArray): + ... diff --git a/docarray/array/sqlite.py b/docarray/array/sqlite.py new file mode 100644 index 00000000000..2c6c9512951 --- /dev/null +++ b/docarray/array/sqlite.py @@ -0,0 +1,7 @@ +from .base import BaseDocumentArray +from .mixins import AllMixins +from .storage.sqlite import StorageMixins + + +class DocumentArraySqlite(StorageMixins, AllMixins, BaseDocumentArray): + ... diff --git a/docarray/array/storage/memory/__init__.py b/docarray/array/storage/memory/__init__.py index 418a1dc5719..f07b096a031 100644 --- a/docarray/array/storage/memory/__init__.py +++ b/docarray/array/storage/memory/__init__.py @@ -1,10 +1,10 @@ -from .backend import MemoryBackendMixin +from .backend import BackendMixin from .getsetdel import GetSetDelMixin from .seqlike import SequenceLikeMixin from abc import ABC -__all__ = ['MemoryStorageMixins'] +__all__ = ['StorageMixins'] -class MemoryStorageMixins(MemoryBackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC): +class StorageMixins(BackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC): ... diff --git a/docarray/array/storage/memory/backend.py b/docarray/array/storage/memory/backend.py index b1db1904e26..30252ccb658 100644 --- a/docarray/array/storage/memory/backend.py +++ b/docarray/array/storage/memory/backend.py @@ -17,7 +17,7 @@ ) -class MemoryBackendMixin(BaseBackendMixin): +class BackendMixin(BaseBackendMixin): """Provide necessary functions to enable this storage backend. """ @property diff --git a/docarray/array/storage/sqlite/__init__.py b/docarray/array/storage/sqlite/__init__.py index 37a98289f39..f07b096a031 100644 --- a/docarray/array/storage/sqlite/__init__.py +++ b/docarray/array/storage/sqlite/__init__.py @@ -1,10 +1,10 @@ -from .backend import SqliteBackendMixin +from .backend import BackendMixin from .getsetdel import GetSetDelMixin from .seqlike import SequenceLikeMixin from abc import ABC -__all__ = ['SqliteStorageMixins'] +__all__ = ['StorageMixins'] -class SqliteStorageMixins(SqliteBackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC): +class StorageMixins(BackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC): ... diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index 91b89250bd3..6a719181675 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -24,7 +24,7 @@ class SqliteConfig: serialize_config: Optional[Dict] = None -class SqliteBackendMixin(BaseBackendMixin): +class BackendMixin(BaseBackendMixin): """Provide necessary functions to enable this storage backend.""" @property From c8800b479433b92a3a6a9925e28f38839b5c0690 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Thu, 20 Jan 2022 16:37:04 +0100 Subject: [PATCH 11/55] feat(array): add storage backend --- docarray/array/storage/sqlite/backend.py | 88 +++++--- docarray/array/storage/sqlite/base.py | 227 --------------------- docarray/array/storage/sqlite/dict.py | 90 -------- docarray/array/storage/sqlite/getsetdel.py | 35 +++- docarray/array/storage/sqlite/helper.py | 80 ++++++++ docarray/array/storage/sqlite/seqlike.py | 43 +++- 6 files changed, 199 insertions(+), 364 deletions(-) delete mode 100644 docarray/array/storage/sqlite/base.py delete mode 100644 docarray/array/storage/sqlite/dict.py create mode 100644 docarray/array/storage/sqlite/helper.py diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index 6a719181675..3e41558081a 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -1,5 +1,7 @@ -import dataclasses -from dataclasses import dataclass +import sqlite3 +import warnings +from dataclasses import dataclass, field +from tempfile import NamedTemporaryFile from typing import ( Optional, TYPE_CHECKING, @@ -7,56 +9,84 @@ Dict, ) +from .helper import initialize_table from ..base.backend import BaseBackendMixin +from ....helper import random_identity if TYPE_CHECKING: - import sqlite3 - from ....types import ( DocumentArraySourceType, ) +def _sanitize_table_name(table_name: str) -> str: + ret = ''.join(c for c in table_name if c.isalnum() or c == '_') + if ret != table_name: + warnings.warn(f'The table name is changed to {ret} due to illegal characters') + return ret + + @dataclass class SqliteConfig: connection: Optional[Union[str, 'sqlite3.Connection']] = None table_name: Optional[str] = None - serialize_config: Optional[Dict] = None + serialize_config: Dict = field(default_factory=dict) class BackendMixin(BaseBackendMixin): """Provide necessary functions to enable this storage backend.""" - @property - def schema_version(self) -> str: - return '0' + schema_version = '0' + + def _sql(self, *args, **kwargs) -> 'sqlite3.Cursor': + return self._cursor.execute(*args, **kwargs) + + def _commit(self): + self._connection.commit() + + def _init_storage( + self, + docs: Optional['DocumentArraySourceType'] = None, + config: Optional[SqliteConfig] = None, + ): + if not config: + config = SqliteConfig() - def _sql(self, *arg, **kwargs) -> 'sqlite3.Cursor': - return self.connection.cursor().execute(*arg, **kwargs) + from docarray import Document - def _insert_doc_at_idx(self, doc, idx: Optional[int] = None): - if idx is None: - idx = len(self) - self._sql( - f'INSERT INTO {self.table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', - (doc.id, doc, idx), + sqlite3.register_adapter( + Document, lambda d: d.to_bytes(**config.serialize_config) + ) + sqlite3.register_converter( + 'Document', lambda x: Document.from_bytes(x, **config.serialize_config) ) - def _shift_index_right_backward(self, start: int): - idx = len(self) - 1 - while idx >= start: - self._sql( - f"UPDATE {self.table_name} SET item_order = ? WHERE item_order = ?", - (idx + 1, idx), + _conn_kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES) + if config.connection is None: + self._connection = sqlite3.connect( + NamedTemporaryFile().name, **_conn_kwargs + ) + elif isinstance(config.connection, str): + self._connection = sqlite3.connect(config.connection, **_conn_kwargs) + elif isinstance(config.connection, sqlite3.Connection): + self._connection = config.connection + else: + raise TypeError( + f'connection argument must be None or a string or a sqlite3.Connection, not `{type(connection)}`' ) - idx -= 1 - def _init_storage( - self, - docs: Optional['DocumentArraySourceType'] = None, - config: Optional[SqliteConfig] = None, - ): - super().__init__(**(dataclasses.asdict(config) if config else {})) + self._table_name = ( + _sanitize_table_name(self.__class__.__name__ + random_identity()) + if config.table_name is None + else _sanitize_table_name(config.table_name) + ) + self._cursor = self._connection.cursor() + self._persist = not config.table_name + initialize_table( + self._table_name, self.__class__.__name__, self.schema_version, self._cursor + ) + self._connection.commit() + if docs is not None: self.clear() self.extend(docs) diff --git a/docarray/array/storage/sqlite/base.py b/docarray/array/storage/sqlite/base.py deleted file mode 100644 index 80471318de0..00000000000 --- a/docarray/array/storage/sqlite/base.py +++ /dev/null @@ -1,227 +0,0 @@ -import sqlite3 -import warnings -from abc import ABCMeta, abstractmethod -from tempfile import NamedTemporaryFile -from typing import Callable, Generic, Optional, TypeVar, Union, Dict -from uuid import uuid4 - -T = TypeVar('T') -KT = TypeVar('KT') -VT = TypeVar('VT') -_T = TypeVar('_T') -_S = TypeVar('_S') - - -def sanitize_table_name(table_name: str) -> str: - ret = ''.join(c for c in table_name if c.isalnum() or c == '_') - if ret != table_name: - warnings.warn(f'The table name is changed to {ret} due to illegal characters') - return ret - - -def create_random_name(suffix: str) -> str: - return f"{suffix}_{str(uuid4()).replace('-', '')}" - - -class _SqliteCollectionBaseDatabaseDriver(metaclass=ABCMeta): - @classmethod - def initialize_metadata_table(cls, cur: sqlite3.Cursor) -> None: - if not cls.is_metadata_table_initialized(cur): - cls.do_initialize_metadata_table(cur) - - @classmethod - def is_metadata_table_initialized(cls, cur: sqlite3.Cursor) -> bool: - try: - cur.execute('SELECT 1 FROM metadata LIMIT 1') - _ = list(cur) - return True - except sqlite3.OperationalError as _: - pass - return False - - @classmethod - def do_initialize_metadata_table(cls, cur: sqlite3.Cursor) -> None: - cur.execute( - ''' - CREATE TABLE metadata ( - table_name TEXT PRIMARY KEY, - schema_version TEXT NOT NULL, - container_type TEXT NOT NULL, - UNIQUE (table_name, container_type) - ) - ''' - ) - - @classmethod - def initialize_table( - cls, - table_name: str, - container_type_name: str, - schema_version: str, - cur: sqlite3.Cursor, - ) -> None: - if not cls.is_table_initialized( - table_name, container_type_name, schema_version, cur - ): - cls.do_create_table(table_name, container_type_name, schema_version, cur) - cls.do_tidy_table_metadata( - table_name, container_type_name, schema_version, cur - ) - - @classmethod - def is_table_initialized( - self, - table_name: str, - container_type_name: str, - schema_version: str, - cur: sqlite3.Cursor, - ) -> bool: - try: - cur.execute( - 'SELECT schema_version FROM metadata WHERE table_name=? AND container_type=?', - (table_name, container_type_name), - ) - buf = cur.fetchone() - if buf is None: - return False - version = buf[0] - if version != schema_version: - return False - cur.execute(f'SELECT 1 FROM {table_name} LIMIT 1') - _ = list(cur) - return True - except sqlite3.OperationalError as _: - pass - return False - - @classmethod - def do_tidy_table_metadata( - cls, - table_name: str, - container_type_name: str, - schema_version: str, - cur: sqlite3.Cursor, - ) -> None: - cur.execute( - 'INSERT INTO metadata (table_name, schema_version, container_type) VALUES (?, ?, ?)', - (table_name, schema_version, container_type_name), - ) - - @classmethod - @abstractmethod - def do_create_table( - cls, - table_name: str, - container_type_name: str, - schema_version: str, - cur: sqlite3.Cursor, - ) -> None: - ... - - @classmethod - def drop_table( - cls, table_name: str, container_type_name: str, cur: sqlite3.Cursor - ) -> None: - cur.execute( - 'DELETE FROM metadata WHERE table_name=? AND container_type=?', - (table_name, container_type_name), - ) - cur.execute(f'DROP TABLE {table_name}') - - @classmethod - def alter_table_name( - cls, table_name: str, new_table_name: str, cur: sqlite3.Cursor - ) -> None: - cur.execute( - 'UPDATE metadata SET table_name=? WHERE table_name=?', - (new_table_name, table_name), - ) - cur.execute(f'ALTER TABLE {table_name} RENAME TO {new_table_name}') - - -class SqliteCollectionBase(Generic[T], metaclass=ABCMeta): - _driver_class = _SqliteCollectionBaseDatabaseDriver - - def __init__( - self, - connection: Optional[Union[str, sqlite3.Connection]] = None, - table_name: Optional[str] = None, - serialize_config: Optional[Dict] = None, - ): - super(SqliteCollectionBase, self).__init__() - self._serialize_config = serialize_config or {} - self._persist = not table_name - - from docarray import Document - - sqlite3.register_adapter( - Document, lambda d: d.to_bytes(**self._serialize_config) - ) - sqlite3.register_converter( - 'Document', lambda x: Document.from_bytes(x, **self._serialize_config) - ) - - _conn_kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES) - if connection is None: - self._connection = sqlite3.connect( - NamedTemporaryFile().name, **_conn_kwargs - ) - elif isinstance(connection, str): - self._connection = sqlite3.connect(connection, **_conn_kwargs) - elif isinstance(connection, sqlite3.Connection): - self._connection = connection - else: - raise TypeError( - f'connection argument must be None or a string or a sqlite3.Connection, not `{type(connection)}`' - ) - - self._table_name = ( - sanitize_table_name(create_random_name(self.container_type_name)) - if table_name is None - else sanitize_table_name(table_name) - ) - cur = self.connection.cursor() - self._driver_class.initialize_metadata_table(cur) - self._driver_class.initialize_table( - self.table_name, self.container_type_name, self.schema_version, cur - ) - self.connection.commit() - - def __del__(self) -> None: - if not self._persist: - cur = self.connection.cursor() - self._driver_class.drop_table( - self.table_name, self.container_type_name, cur - ) - self.connection.commit() - - @property - def table_name(self) -> str: - return self._table_name - - @table_name.setter - def table_name(self, table_name: str) -> None: - cur = self.connection.cursor() - new_table_name = sanitize_table_name(table_name) - try: - self._driver_class.alter_table_name(self.table_name, new_table_name, cur) - self._table_name = new_table_name - except sqlite3.IntegrityError as ex: - raise ValueError(table_name) from ex - - @property - def connection(self) -> sqlite3.Connection: - return self._connection - - @property - def cursor(self) -> sqlite3.Cursor: - return self.connection.cursor() - - @property - def container_type_name(self) -> str: - return self.__class__.__name__ - - @property - @abstractmethod - def schema_version(self) -> str: - ... diff --git a/docarray/array/storage/sqlite/dict.py b/docarray/array/storage/sqlite/dict.py deleted file mode 100644 index b63fd778b60..00000000000 --- a/docarray/array/storage/sqlite/dict.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Tuple, Union, cast, Iterable, TYPE_CHECKING - -from .base import _SqliteCollectionBaseDatabaseDriver - -if TYPE_CHECKING: - import sqlite3 - - -class _DictDatabaseDriver(_SqliteCollectionBaseDatabaseDriver): - @classmethod - def do_create_table( - cls, - table_name: str, - container_type_nam: str, - schema_version: str, - cur: 'sqlite3.Cursor', - ) -> None: - cur.execute( - f'CREATE TABLE {table_name} (' - 'doc_id TEXT NOT NULL UNIQUE, ' - 'serialized_value Document NOT NULL, ' - 'item_order INTEGER PRIMARY KEY)' - ) - - @classmethod - def delete_single_record_by_doc_id( - cls, table_name: str, cur: 'sqlite3.Cursor', doc_id: str - ) -> None: - cur.execute(f'DELETE FROM {table_name} WHERE doc_id=?', (doc_id,)) - - @classmethod - def get_next_order(cls, table_name: str, cur: 'sqlite3.Cursor') -> int: - cur.execute(f'SELECT MAX(item_order) FROM {table_name}') - res = cur.fetchone()[0] - if res is None: - return 0 - return cast(int, res) + 1 - - @classmethod - def get_doc_ids(cls, table_name: str, cur: 'sqlite3.Cursor') -> Iterable[str]: - cur.execute(f'SELECT doc_id FROM {table_name} ORDER BY item_order') - for res in cur: - yield res[0] - - @classmethod - def update_serialized_value_by_doc_id( - cls, - table_name: str, - cur: 'sqlite3.Cursor', - doc_id: str, - serialized_value: bytes, - ) -> None: - cur.execute( - f'UPDATE {table_name} SET serialized_value=? WHERE doc_id=?', - (serialized_value, doc_id), - ) - - @classmethod - def upsert( - cls, - table_name: str, - cur: 'sqlite3.Cursor', - doc_id: str, - serialized_value: bytes, - ) -> None: - if cls.is_doc_id_in(table_name, cur, doc_id): - cls.update_serialized_value_by_doc_id( - table_name, cur, doc_id, serialized_value - ) - else: - cls.insert_serialized_value_by_doc_id( - table_name, cur, doc_id, serialized_value - ) - - @classmethod - def get_last_serialized_item( - cls, table_name: str, cur: 'sqlite3.Cursor' - ) -> Tuple[str, bytes]: - cur.execute( - f'SELECT doc_id, serialized_value FROM {table_name} ORDER BY item_order DESC LIMIT 1' - ) - return cast(Tuple[str, bytes], cur.fetchone()) - - @classmethod - def get_reversed_doc_ids( - cls, table_name: str, cur: 'sqlite3.Cursor' - ) -> Iterable[str]: - cur.execute(f'SELECT doc_id FROM {table_name} ORDER BY item_order DESC') - for res in cur: - yield cast(str, res[0]) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 80b936ecc4e..c4a5dacd0f2 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -1,9 +1,4 @@ -import itertools -from typing import ( - Sequence, - Iterable, - Any, -) +from typing import Sequence, Iterable from ..base.getsetdel import BaseGetSetDelMixin from .... import Document @@ -12,9 +7,31 @@ class GetSetDelMixin(BaseGetSetDelMixin): """Implement required and derived functions that power `getitem`, `setitem`, `delitem`""" + def _del_doc_by_id(self, _id: str): + self._sql(f'DELETE FROM {self._table_name} WHERE doc_id=?', (_id,)) + self._commit() + + def _del_doc_by_offset(self, offset: int): + self._sql(f'DELETE FROM {self._table_name} WHERE item_order=?', (offset,)) + self._commit() + + def _set_doc_by_offset(self, offset: int, value: 'Document'): + self._sql( + f'UPDATE {self._table_name} SET serialized_value=? WHERE item_order=?', + (offset, value), + ) + self._commit() + + def _set_doc_by_id(self, _id: str, value: 'Document'): + self._sql( + f'UPDATE {self._table_name} SET serialized_value=? WHERE doc_id=?', + (_id, value), + ) + self._commit() + def _get_doc_by_offset(self, index: int) -> 'Document': r = self._sql( - f"SELECT serialized_value FROM {self.table_name} WHERE item_order = ?", + f'SELECT serialized_value FROM {self._table_name} WHERE item_order = ?', (index + (len(self) if index < 0 else 0),), ) res = r.fetchone() @@ -24,7 +41,7 @@ def _get_doc_by_offset(self, index: int) -> 'Document': def _get_doc_by_id(self, id: str) -> 'Document': r = self._sql( - f"SELECT serialized_value FROM {self.table_name} WHERE doc_id = ?", (id,) + f'SELECT serialized_value FROM {self._table_name} WHERE doc_id = ?', (id,) ) res = r.fetchone() if res is None: @@ -35,7 +52,7 @@ def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: l = len(self) offsets = [o + (l if o < 0 else 0) for o in offsets] r = self._sql( - f"SELECT serialized_value FROM {self.table_name} WHERE item_order in ({','.join(['?']*len(offsets))})", + f"SELECT serialized_value FROM {self._table_name} WHERE item_order in ({','.join(['?'] * len(offsets))})", offsets, ) for rr in r: diff --git a/docarray/array/storage/sqlite/helper.py b/docarray/array/storage/sqlite/helper.py new file mode 100644 index 00000000000..5eb05a09c47 --- /dev/null +++ b/docarray/array/storage/sqlite/helper.py @@ -0,0 +1,80 @@ +import sqlite3 + + +def initialize_table( + table_name: str, container_type_name: str, schema_version: str, cur: sqlite3.Cursor +) -> None: + if not _is_metadata_table_initialized(cur): + _do_initialize_metadata_table(cur) + + if not _is_table_initialized(table_name, container_type_name, schema_version, cur): + _do_create_table(table_name, cur) + _do_tidy_table_metadata(table_name, container_type_name, schema_version, cur) + + +def _is_metadata_table_initialized(cur: sqlite3.Cursor) -> bool: + try: + cur.execute('SELECT 1 FROM metadata LIMIT 1') + _ = list(cur) + return True + except sqlite3.OperationalError as _: + pass + return False + + +def _do_initialize_metadata_table(cur: sqlite3.Cursor) -> None: + cur.execute( + ''' + CREATE TABLE metadata ( + table_name TEXT PRIMARY KEY, + schema_version TEXT NOT NULL, + container_type TEXT NOT NULL, + UNIQUE (table_name, container_type) + ) + ''' + ) + + +def _do_create_table( + table_name: str, + cur: 'sqlite3.Cursor', +) -> None: + cur.execute( + f''' + CREATE TABLE {table_name} ( + doc_id TEXT NOT NULL UNIQUE, + serialized_value Document NOT NULL, + item_order INTEGER PRIMARY KEY) + ''' + ) + + +def _is_table_initialized( + table_name: str, container_type_name: str, schema_version: str, cur: sqlite3.Cursor +) -> bool: + try: + cur.execute( + 'SELECT schema_version FROM metadata WHERE table_name=? AND container_type=?', + (table_name, container_type_name), + ) + buf = cur.fetchone() + if buf is None: + return False + version = buf[0] + if version != schema_version: + return False + cur.execute(f'SELECT 1 FROM {table_name} LIMIT 1') + _ = list(cur) + return True + except sqlite3.OperationalError as _: + pass + return False + + +def _do_tidy_table_metadata( + table_name: str, container_type_name: str, schema_version: str, cur: sqlite3.Cursor +) -> None: + cur.execute( + 'INSERT INTO metadata (table_name, schema_version, container_type) VALUES (?, ?, ?)', + (table_name, schema_version, container_type_name), + ) diff --git a/docarray/array/storage/sqlite/seqlike.py b/docarray/array/storage/sqlite/seqlike.py index 17d0954967a..4edf1e523c2 100644 --- a/docarray/array/storage/sqlite/seqlike.py +++ b/docarray/array/storage/sqlite/seqlike.py @@ -1,4 +1,4 @@ -from typing import Iterator, Union, Sequence, Iterable, MutableSequence +from typing import Iterator, Union, Iterable, MutableSequence, Optional from .... import Document @@ -6,6 +6,23 @@ class SequenceLikeMixin(MutableSequence[Document]): """Implement sequence-like methods""" + def _insert_doc_at_idx(self, doc, idx: Optional[int] = None): + if idx is None: + idx = len(self) + self._sql( + f'INSERT INTO {self._table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', + (doc.id, doc, idx), + ) + + def _shift_index_right_backward(self, start: int): + idx = len(self) - 1 + while idx >= start: + self._sql( + f'UPDATE {self._table_name} SET item_order = ? WHERE item_order = ?', + (idx + 1, idx), + ) + idx -= 1 + def insert(self, index: int, value: 'Document'): """Insert `doc` at `index`. @@ -18,26 +35,34 @@ def insert(self, index: int, value: 'Document'): index = max(0, min(length, index)) self._shift_index_right_backward(index) self._insert_doc_at_idx(doc=value, idx=index) - self.connection.commit() + self._commit() def append(self, value: 'Document') -> None: self._insert_doc_at_idx(value) - self.connection.commit() + self._commit() def extend(self, values: Iterable['Document']) -> None: idx = len(self) for v in values: self._insert_doc_at_idx(v, idx) idx += 1 - self.connection.commit() + self._commit() def clear(self) -> None: - self._sql(f'DELETE FROM {self.table_name}') - self.connection.commit() + self._sql(f'DELETE FROM {self._table_name}') + self._commit() + + def __del__(self) -> None: + if not self._persist: + self._sql( + 'DELETE FROM metadata WHERE table_name=? AND container_type=?', + (self._table_name, self.__class__.__name__), + ) + self._sql(f'DROP TABLE {self._table_name}') def __contains__(self, item: Union[str, 'Document']): if isinstance(item, str): - r = self._sql(f"SELECT 1 FROM {self.table_name} WHERE doc_id=?", (item,)) + r = self._sql(f'SELECT 1 FROM {self._table_name} WHERE doc_id=?', (item,)) return len(list(r)) > 0 elif isinstance(item, Document): return item.id in self # fall back to str check @@ -45,12 +70,12 @@ def __contains__(self, item: Union[str, 'Document']): return False def __len__(self) -> int: - r = self._sql(f'SELECT COUNT(*) FROM {self.table_name}') + r = self._sql(f'SELECT COUNT(*) FROM {self._table_name}') return r.fetchone()[0] def __iter__(self) -> Iterator['Document']: r = self._sql( - f'SELECT serialized_value FROM {self.table_name} ORDER BY item_order' + f'SELECT serialized_value FROM {self._table_name} ORDER BY item_order' ) for res in r: yield res[0] From 6029a1436ecaf6b975f2e82656c93eea3bd50f31 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Thu, 20 Jan 2022 16:37:37 +0100 Subject: [PATCH 12/55] feat(array): add storage backend --- docarray/array/storage/sqlite/backend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index 3e41558081a..2773a0d81e8 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -45,9 +45,9 @@ def _commit(self): self._connection.commit() def _init_storage( - self, - docs: Optional['DocumentArraySourceType'] = None, - config: Optional[SqliteConfig] = None, + self, + docs: Optional['DocumentArraySourceType'] = None, + config: Optional[SqliteConfig] = None, ): if not config: config = SqliteConfig() From 8936d776eafbd71c5dfb99521b170da215a7ad19 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Thu, 20 Jan 2022 16:58:53 +0100 Subject: [PATCH 13/55] feat(array): add storage backend --- docarray/array/document.py | 8 ++------ docarray/array/memory.py | 7 +++++++ docarray/array/storage/memory/seqlike.py | 3 +-- docarray/array/storage/sqlite/getsetdel.py | 18 ++++++++++++++++++ docarray/array/storage/sqlite/seqlike.py | 3 +-- 5 files changed, 29 insertions(+), 10 deletions(-) create mode 100644 docarray/array/memory.py diff --git a/docarray/array/document.py b/docarray/array/document.py index 02dbd6b36a7..f39e9ad4099 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -1,7 +1,3 @@ -from .base import BaseDocumentArray -from .mixins import AllMixins -from .storage.memory import StorageMixins +from .memory import DocumentArrayMemory as DocumentArray - -class DocumentArray(StorageMixins, AllMixins, BaseDocumentArray): - ... +__all__ = ['DocumentArray'] diff --git a/docarray/array/memory.py b/docarray/array/memory.py new file mode 100644 index 00000000000..96d2d0d788a --- /dev/null +++ b/docarray/array/memory.py @@ -0,0 +1,7 @@ +from .base import BaseDocumentArray +from .mixins import AllMixins +from .storage.memory import StorageMixins + + +class DocumentArrayMemory(StorageMixins, AllMixins, BaseDocumentArray): + ... diff --git a/docarray/array/storage/memory/seqlike.py b/docarray/array/storage/memory/seqlike.py index 3429ec7caae..bc625d1b201 100644 --- a/docarray/array/storage/memory/seqlike.py +++ b/docarray/array/storage/memory/seqlike.py @@ -38,8 +38,7 @@ def __contains__(self, x: Union[str, 'Document']): def clear(self): """Clear the data of :class:`DocumentArray`""" - self._data.clear() - self._id2offset.clear() + self._del_all_docs() def __bool__(self): """To simulate ```l = []; if l: ...``` diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index c4a5dacd0f2..5dfd55bb929 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -7,6 +7,8 @@ class GetSetDelMixin(BaseGetSetDelMixin): """Implement required and derived functions that power `getitem`, `setitem`, `delitem`""" + # essential methods start + def _del_doc_by_id(self, _id: str): self._sql(f'DELETE FROM {self._table_name} WHERE doc_id=?', (_id,)) self._commit() @@ -48,6 +50,10 @@ def _get_doc_by_id(self, id: str) -> 'Document': raise KeyError(f'Can not find Document with id=`{id}`') return res[0] + # essentials end here + + # now start the optimized bulk methods + def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: l = len(self) offsets = [o + (l if o < 0 else 0) for o in offsets] @@ -60,3 +66,15 @@ def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: return self._get_docs_by_offsets(range(len(self))[_slice]) + + def _del_all_docs(self): + self._sql(f'DELETE FROM {self._table_name}') + self._commit() + + def _del_docs_by_slice(self, _slice: slice): + offsets = range(len(self))[_slice] + self._sql( + f"DELETE FROM {self._table_name} WHERE item_order in ({','.join(['?'] * len(offsets))})", + offsets, + ) + self._commit() diff --git a/docarray/array/storage/sqlite/seqlike.py b/docarray/array/storage/sqlite/seqlike.py index 4edf1e523c2..9558cd369f3 100644 --- a/docarray/array/storage/sqlite/seqlike.py +++ b/docarray/array/storage/sqlite/seqlike.py @@ -49,8 +49,7 @@ def extend(self, values: Iterable['Document']) -> None: self._commit() def clear(self) -> None: - self._sql(f'DELETE FROM {self._table_name}') - self._commit() + self._del_all_docs() def __del__(self) -> None: if not self._persist: From 58da76e56067e5552f9113d68c7e4df19e52a52c Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Thu, 20 Jan 2022 17:16:26 +0100 Subject: [PATCH 14/55] feat(array): add storage backend --- docarray/array/document.py | 2 +- docarray/array/memory.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/array/document.py b/docarray/array/document.py index f39e9ad4099..616a1ef1524 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -1,3 +1,3 @@ -from .memory import DocumentArrayMemory as DocumentArray +from .memory import DocumentArrayInMemory as DocumentArray __all__ = ['DocumentArray'] diff --git a/docarray/array/memory.py b/docarray/array/memory.py index 96d2d0d788a..349fb7093fe 100644 --- a/docarray/array/memory.py +++ b/docarray/array/memory.py @@ -3,5 +3,5 @@ from .storage.memory import StorageMixins -class DocumentArrayMemory(StorageMixins, AllMixins, BaseDocumentArray): +class DocumentArrayInMemory(StorageMixins, AllMixins, BaseDocumentArray): ... From 80193cdcc17b35e90863cab77340c8c64eda27eb Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 21 Jan 2022 10:36:05 +0100 Subject: [PATCH 15/55] feat: implement sugar API --- docarray/array/base.py | 19 ++++++++++++++++++- docarray/array/chunk.py | 3 ++- docarray/array/document.py | 2 +- docarray/array/match.py | 6 ++++-- docarray/array/memory.py | 16 ++++++++++++---- docarray/array/sqlite.py | 8 ++++---- 6 files changed, 41 insertions(+), 13 deletions(-) diff --git a/docarray/array/base.py b/docarray/array/base.py index cf0e85f3564..37340caee62 100644 --- a/docarray/array/base.py +++ b/docarray/array/base.py @@ -1,7 +1,24 @@ from abc import ABC +from docarray.array.mixins import AllMixins class BaseDocumentArray(ABC): - def __init__(self, *args, **kwargs): + def __init__(self, *args, storage: str = 'memory', **kwargs): super().__init__() self._init_storage(*args, **kwargs) + + +class DocumentArray(AllMixins, BaseDocumentArray): + def __new__(cls, *args, storage: str = 'memory', **kwargs): + if cls is DocumentArray: + if storage == 'memory': + from docarray.array.memory import DocumentArrayInMemory + instance = super().__new__(DocumentArrayInMemory) + elif storage == 'sqlite': + from docarray.array.sqlite import DocumentArraySqlite + instance = super().__new__(DocumentArraySqlite) + else: + raise ValueError(f'storage=`{storage}` is not supported.') + else: + instance = super().__new__(cls) + return instance diff --git a/docarray/array/chunk.py b/docarray/array/chunk.py index ea445702f62..246b162152d 100644 --- a/docarray/array/chunk.py +++ b/docarray/array/chunk.py @@ -7,12 +7,13 @@ ) from .document import DocumentArray +from .memory import DocumentArrayInMemory if TYPE_CHECKING: from ..document import Document -class ChunkArray(DocumentArray): +class ChunkArray(DocumentArrayInMemory): """ :class:`ChunkArray` inherits from :class:`DocumentArray`. It's a subset of Documents. diff --git a/docarray/array/document.py b/docarray/array/document.py index 616a1ef1524..e275c50d71a 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -1,3 +1,3 @@ -from .memory import DocumentArrayInMemory as DocumentArray +from .base import DocumentArray __all__ = ['DocumentArray'] diff --git a/docarray/array/match.py b/docarray/array/match.py index 2b33828f3fb..fca154d66c6 100644 --- a/docarray/array/match.py +++ b/docarray/array/match.py @@ -6,13 +6,15 @@ Sequence, ) -from .. import DocumentArray +from docarray import DocumentArray + +from .memory import DocumentArrayInMemory if TYPE_CHECKING: from ..document import Document -class MatchArray(DocumentArray): +class MatchArray(DocumentArrayInMemory): """ :class:`MatchArray` inherits from :class:`DocumentArray`. It's a subset of Documents that represents the matches diff --git a/docarray/array/memory.py b/docarray/array/memory.py index 349fb7093fe..7c02c8cc244 100644 --- a/docarray/array/memory.py +++ b/docarray/array/memory.py @@ -1,7 +1,15 @@ -from .base import BaseDocumentArray -from .mixins import AllMixins +from .base import DocumentArray from .storage.memory import StorageMixins -class DocumentArrayInMemory(StorageMixins, AllMixins, BaseDocumentArray): - ... +class DocumentArrayInMemory(StorageMixins, DocumentArray): + def __new__(cls, *args, **kwargs): + return super().__new__(cls) +""" + def __delitem__(self): + pass + def __getitem__(self, item): + print('getting item') + pass + def __setitem__(self): + pass""" \ No newline at end of file diff --git a/docarray/array/sqlite.py b/docarray/array/sqlite.py index 2c6c9512951..0ca77ec5ba1 100644 --- a/docarray/array/sqlite.py +++ b/docarray/array/sqlite.py @@ -1,7 +1,7 @@ -from .base import BaseDocumentArray -from .mixins import AllMixins +from .base import DocumentArray from .storage.sqlite import StorageMixins -class DocumentArraySqlite(StorageMixins, AllMixins, BaseDocumentArray): - ... +class DocumentArraySqlite(StorageMixins, DocumentArray): + def __new__(cls, *args, **kwargs): + return super().__new__(cls) From 8ea8935b9afaa9b6152252afe147a30fdb03d1fd Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 21 Jan 2022 11:32:35 +0100 Subject: [PATCH 16/55] test: add sqlite tests --- tests/unit/array/test_advance_indexing.py | 208 +++++++++++++--------- 1 file changed, 121 insertions(+), 87 deletions(-) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 77c72a8a9fe..93efda18900 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -1,3 +1,6 @@ +import os +import time + import numpy as np import pytest @@ -5,90 +8,100 @@ @pytest.fixture -def docarray100(): - yield DocumentArray(Document(text=j) for j in range(100)) +def docs(): + yield (Document(text=j) for j in range(100)) -def test_getter_int_str(docarray100): +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_getter_int_str(docs, storage): + docs = DocumentArray(docs, storage=storage) # getter - assert docarray100[99].text == 99 - assert docarray100[np.int(99)].text == 99 - assert docarray100[-1].text == 99 - assert docarray100[0].text == 0 + assert docs[99].text == 99 + assert docs[np.int(99)].text == 99 + assert docs[-1].text == 99 + assert docs[0].text == 0 # string index - assert docarray100[docarray100[0].id].text == 0 - assert docarray100[docarray100[99].id].text == 99 - assert docarray100[docarray100[-1].id].text == 99 + assert docs[docs[0].id].text == 0 + assert docs[docs[99].id].text == 99 + assert docs[docs[-1].id].text == 99 with pytest.raises(IndexError): - docarray100[100] + docs[100] with pytest.raises(KeyError): - docarray100['adsad'] + docs['adsad'] -def test_setter_int_str(docarray100): +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_setter_int_str(docs, storage): + docs = DocumentArray(docs, storage=storage) # setter - docarray100[99] = Document(text='hello') - docarray100[0] = Document(text='world') + docs[99] = Document(text='hello') + docs[0] = Document(text='world') - assert docarray100[99].text == 'hello' - assert docarray100[-1].text == 'hello' - assert docarray100[0].text == 'world' + assert docs[99].text == 'hello' + assert docs[-1].text == 'hello' + assert docs[0].text == 'world' - docarray100[docarray100[2].id] = Document(text='doc2') + docs[docs[2].id] = Document(text='doc2') # string index - assert docarray100[docarray100[2].id].text == 'doc2' + assert docs[docs[2].id].text == 'doc2' -def test_del_int_str(docarray100): - zero_id = docarray100[0].id - del docarray100[0] - assert len(docarray100) == 99 - assert zero_id not in docarray100 +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_del_int_str(docs, storage): + docs = DocumentArray(docs, storage=storage) + zero_id = docs[0].id + del docs[0] + assert len(docs) == 99 + assert zero_id not in docs - new_zero_id = docarray100[0].id - new_doc_zero = docarray100[0] - del docarray100[new_zero_id] - assert len(docarray100) == 98 - assert zero_id not in docarray100 - assert new_doc_zero not in docarray100 + new_zero_id = docs[0].id + new_doc_zero = docs[0] + del docs[new_zero_id] + assert len(docs) == 98 + assert zero_id not in docs + assert new_doc_zero not in docs -def test_slice(docarray100): +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_slice(docs, storage): + docs = DocumentArray(docs, storage=storage) # getter - assert len(docarray100[1:5]) == 4 - assert len(docarray100[1:100:5]) == 20 # 1 to 100, sep with 5 + assert len(docs[1:5]) == 4 + assert len(docs[1:100:5]) == 20 # 1 to 100, sep with 5 # setter with pytest.raises(TypeError, match='can only assign an iterable'): - docarray100[1:5] = Document(text='repl') + docs[1:5] = Document(text='repl') - docarray100[1:5] = [Document(text=f'repl{j}') for j in range(4)] - for d in docarray100[1:5]: + docs[1:5] = [Document(text=f'repl{j}') for j in range(4)] + for d in docs[1:5]: assert d.text.startswith('repl') - assert len(docarray100) == 100 + assert len(docs) == 100 # del - zero_doc = docarray100[0] - twenty_doc = docarray100[20] - del docarray100[0:20] - assert len(docarray100) == 80 - assert zero_doc not in docarray100 - assert twenty_doc in docarray100 + zero_doc = docs[0] + twenty_doc = docs[20] + del docs[0:20] + assert len(docs) == 80 + assert zero_doc not in docs + assert twenty_doc in docs -def test_sequence_bool_index(docarray100): +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_sequence_bool_index(docs, storage): + docs = DocumentArray(docs, storage=storage) # getter mask = [True, False] * 50 - assert len(docarray100[mask]) == 50 - assert len(docarray100[[True, False]]) == 1 + assert len(docs[mask]) == 50 + assert len(docs[[True, False]]) == 1 # setter mask = [True, False] * 50 - docarray100[mask] = [Document(text=f'repl{j}') for j in range(50)] + docs[mask] = [Document(text=f'repl{j}') for j in range(50)] - for idx, d in enumerate(docarray100): + for idx, d in enumerate(docs): if idx % 2 == 0: # got replaced assert d.text.startswith('repl') @@ -96,61 +109,71 @@ def test_sequence_bool_index(docarray100): assert isinstance(d.text, int) # del - del docarray100[mask] - assert len(docarray100) == 50 + del docs[mask] + assert len(docs) == 50 - del docarray100[mask] - assert len(docarray100) == 25 + del docs[mask] + assert len(docs) == 25 @pytest.mark.parametrize('nparray', [lambda x: x, np.array, tuple]) -def test_sequence_int(docarray100, nparray): +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_sequence_int(docs, nparray, storage): + docs = DocumentArray(docs, storage=storage) # getter idx = nparray([1, 3, 5, 7, -1, -2]) - assert len(docarray100[idx]) == len(idx) + assert len(docs[idx]) == len(idx) # setter - docarray100[idx] = [Document(text='repl') for _ in range(len(idx))] + docs[idx] = [Document(text='repl') for _ in range(len(idx))] for _id in idx: - assert docarray100[_id].text == 'repl' + assert docs[_id].text == 'repl' # del idx = [-3, -4, -5, 9, 10, 11] - del docarray100[idx] - assert len(docarray100) == 100 - len(idx) + del docs[idx] + assert len(docs) == 100 - len(idx) -def test_sequence_str(docarray100): +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_sequence_str(docs, storage): + docs = DocumentArray(docs, storage=storage) # getter - idx = [d.id for d in docarray100[1, 3, 5, 7, -1, -2]] + idx = [d.id for d in docs[1, 3, 5, 7, -1, -2]] - assert len(docarray100[idx]) == len(idx) - assert len(docarray100[tuple(idx)]) == len(idx) + assert len(docs[idx]) == len(idx) + assert len(docs[tuple(idx)]) == len(idx) # setter - docarray100[idx] = [Document(text='repl') for _ in range(len(idx))] - idx = [d.id for d in docarray100[1, 3, 5, 7, -1, -2]] + docs[idx] = [Document(text='repl') for _ in range(len(idx))] + idx = [d.id for d in docs[1, 3, 5, 7, -1, -2]] for _id in idx: - assert docarray100[_id].text == 'repl' + assert docs[_id].text == 'repl' # del - idx = [d.id for d in docarray100[-3, -4, -5, 9, 10, 11]] - del docarray100[idx] - assert len(docarray100) == 100 - len(idx) + idx = [d.id for d in docs[-3, -4, -5, 9, 10, 11]] + del docs[idx] + assert len(docs) == 100 - len(idx) -def test_docarray_list_tuple(docarray100): - assert isinstance(docarray100[99, 98], DocumentArray) - assert len(docarray100[99, 98]) == 2 +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_docarray_list_tuple(docs, storage): + docs = DocumentArray(docs, storage=storage) + assert isinstance(docs[99, 98], DocumentArray) + assert len(docs[99, 98]) == 2 -def test_path_syntax_indexing(): - da = DocumentArray().empty(3) +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_path_syntax_indexing(storage): + da = DocumentArray.empty(3) for d in da: d.chunks = DocumentArray.empty(5) d.matches = DocumentArray.empty(7) for c in d.chunks: c.chunks = DocumentArray.empty(3) + + if storage == 'sqlite': + da = DocumentArray(da, storage=storage) assert len(da['@c']) == 3 * 5 assert len(da['@c:1']) == 3 assert len(da['@c-1:']) == 3 @@ -165,8 +188,11 @@ def test_path_syntax_indexing(): assert len(da['@r:1cc,m']) == 1 * 5 * 3 + 3 * 7 -def test_attribute_indexing(): - da = DocumentArray.empty(10) +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_attribute_indexing(storage): + da = DocumentArray(storage=storage) + da.extend(DocumentArray.empty(10)) + for v in da[:, 'id']: assert v da[:, 'mime_type'] = [f'type {j}' for j in range(10)] @@ -187,14 +213,17 @@ def test_attribute_indexing(): assert vv -def test_tensor_attribute_selector(): +# TODO: enable weaviate storage test +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_tensor_attribute_selector(storage): import scipy.sparse sp_embed = np.random.random([3, 10]) sp_embed[sp_embed > 0.1] = 0 sp_embed = scipy.sparse.coo_matrix(sp_embed) - da = DocumentArray.empty(3) + da = DocumentArray(storage=storage) + da.extend(DocumentArray.empty(3)) da[:, 'embedding'] = sp_embed @@ -212,8 +241,10 @@ def test_tensor_attribute_selector(): assert isinstance(v1, list) -def test_advance_selector_mixed(): - da = DocumentArray.empty(10) +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_advance_selector_mixed(storage): + da = DocumentArray(storage=storage) + da.extend(DocumentArray.empty(10)) da.embeddings = np.random.random([10, 3]) da.match(da, exclude_self=True) @@ -221,8 +252,10 @@ def test_advance_selector_mixed(): assert len(da[:, ('id', 'embedding', 'matches')][0]) == 10 -def test_single_boolean_and_padding(): - da = DocumentArray.empty(3) +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_single_boolean_and_padding(storage): + da = DocumentArray(storage=storage) + da.extend(DocumentArray.empty(3)) with pytest.raises(IndexError): da[True] @@ -237,9 +270,10 @@ def test_single_boolean_and_padding(): assert len(da[False, False]) == 0 -def test_edge_case_two_strings(): +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_edge_case_two_strings(storage): # getitem - da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')]) + da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')], storage=storage) assert da['1', 'id'] == '1' assert len(da['1', '2']) == 2 assert isinstance(da['1', '2'], DocumentArray) @@ -254,7 +288,7 @@ def test_edge_case_two_strings(): del da['1', '2'] assert len(da) == 1 - da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')]) + da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')], storage=storage) del da['1', 'id'] assert len(da) == 3 assert not da[0].id @@ -262,12 +296,12 @@ def test_edge_case_two_strings(): del da['2', 'hello'] # setitem - da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')]) + da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')], storage=storage) da['1', '2'] = DocumentArray.empty(2) assert da[0].id != '1' assert da[1].id != '2' - da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')]) + da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')], storage=storage) da['1', 'text'] = 'hello' assert da['1'].text == 'hello' From aa7e4caa7a8b33042568e1e234fe4e65b6ce5831 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 21 Jan 2022 11:43:26 +0100 Subject: [PATCH 17/55] fix: fix sqlite set_doc_by_id --- docarray/array/storage/sqlite/getsetdel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 5dfd55bb929..f5727eee5ae 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -20,7 +20,7 @@ def _del_doc_by_offset(self, offset: int): def _set_doc_by_offset(self, offset: int, value: 'Document'): self._sql( f'UPDATE {self._table_name} SET serialized_value=? WHERE item_order=?', - (offset, value), + (value, offset), ) self._commit() From c6a6d6f20fcc2993779dd7d827770e580acbb7e6 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 21 Jan 2022 12:09:36 +0100 Subject: [PATCH 18/55] fix: fix set_doc_by_id --- docarray/array/memory.py | 8 -------- docarray/array/storage/sqlite/getsetdel.py | 2 +- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/docarray/array/memory.py b/docarray/array/memory.py index 7c02c8cc244..823f281ebe0 100644 --- a/docarray/array/memory.py +++ b/docarray/array/memory.py @@ -5,11 +5,3 @@ class DocumentArrayInMemory(StorageMixins, DocumentArray): def __new__(cls, *args, **kwargs): return super().__new__(cls) -""" - def __delitem__(self): - pass - def __getitem__(self, item): - print('getting item') - pass - def __setitem__(self): - pass""" \ No newline at end of file diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index f5727eee5ae..6cddd6d7094 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -27,7 +27,7 @@ def _set_doc_by_offset(self, offset: int, value: 'Document'): def _set_doc_by_id(self, _id: str, value: 'Document'): self._sql( f'UPDATE {self._table_name} SET serialized_value=? WHERE doc_id=?', - (_id, value), + (value, _id), ) self._commit() From c1423b71009c9c360b09990e634faf00d74afd7d Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 21 Jan 2022 14:13:32 +0100 Subject: [PATCH 19/55] fix: DocumentArraySqlit._set_doc_by_id should behave like DocumentArrayInMemory --- docarray/array/storage/sqlite/getsetdel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 6cddd6d7094..3d1895b25cd 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -26,8 +26,8 @@ def _set_doc_by_offset(self, offset: int, value: 'Document'): def _set_doc_by_id(self, _id: str, value: 'Document'): self._sql( - f'UPDATE {self._table_name} SET serialized_value=? WHERE doc_id=?', - (value, _id), + f'UPDATE {self._table_name} SET serialized_value=?, doc_id=? WHERE doc_id=?', + (value, value.id, _id), ) self._commit() From 6397a467c29fe30099e6abdf485a5b64d1bcebd5 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 21 Jan 2022 14:49:32 +0100 Subject: [PATCH 20/55] fix: raise TypeError when assigning non iterable with slice --- docarray/array/storage/base/getsetdel.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index 66f74f0abec..3a9e90abb3f 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -99,6 +99,8 @@ def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): Override this function if there is a more efficient logic """ + if not isinstance(value, Iterable): + raise TypeError('You can only assign an iterable') for _offset, val in zip(range(len(self))[_slice], value): self._set_doc_by_offset(_offset, val) From f15fd77f03a16b0df36c2906f4c35c48b7d2c85c Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 24 Jan 2022 13:09:05 +0100 Subject: [PATCH 21/55] fix: fix del attribute --- docarray/array/mixins/delitem.py | 2 +- tests/unit/array/test_advance_indexing.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docarray/array/mixins/delitem.py b/docarray/array/mixins/delitem.py index 6b9a15fb10a..b7f796aff98 100644 --- a/docarray/array/mixins/delitem.py +++ b/docarray/array/mixins/delitem.py @@ -39,7 +39,7 @@ def __delitem__(self, index: 'DocumentArrayIndexType'): ): if isinstance(index[0], str) and isinstance(index[1], str): # ambiguity only comes from the second string - if index[1] in self._id2offset: + if index[1] in self: del self[index[0]] del self[index[1]] else: diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 93efda18900..2ca9189f30e 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -288,10 +288,11 @@ def test_edge_case_two_strings(storage): del da['1', '2'] assert len(da) == 1 - da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')], storage=storage) - del da['1', 'id'] + da = DocumentArray( + [Document(id=str(i), text='hey') for i in range(3)], storage=storage) + del da['1', 'text'] assert len(da) == 3 - assert not da[0].id + assert not da[1].text del da['2', 'hello'] From 5eb3881c72ca359d45c90417dbfaf4ab33d0513f Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 24 Jan 2022 14:04:44 +0100 Subject: [PATCH 22/55] test: fix test_tensor_attribute_selector --- docarray/array/mixins/setitem.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index 44bb8291131..245631a1ca3 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -113,10 +113,13 @@ def __setitem__( _docs = self[index[0]] for _a, _v in zip(_attrs, value): - if _a == 'tensor': - _docs.tensors = _v - elif _a == 'embedding': - _docs.embeddings = _v + if _a in ('tensor', 'embedding'): + if _a == 'tensor': + _docs.tensors = _v + elif _a == 'embedding': + _docs.embeddings = _v + for _d in _docs: + self._set_doc_by_id(_d.id, _d) else: if len(_docs) == 1: self._set_doc_attr_by_id(_docs[0].id, _a, _v) From 8deaacd968c01d3359397eca2e12c3013a3dacac Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Mon, 24 Jan 2022 15:44:48 +0100 Subject: [PATCH 23/55] fix: delete by offset was missing shift of ids --- docarray/array/storage/sqlite/getsetdel.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 3d1895b25cd..02084019242 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -15,13 +15,22 @@ def _del_doc_by_id(self, _id: str): def _del_doc_by_offset(self, offset: int): self._sql(f'DELETE FROM {self._table_name} WHERE item_order=?', (offset,)) + # shift the offset of every value on the right position of the deleted item + for i in range(offset, len(self) + 1): + # doc_id values should be also changed + self._sql( + f'UPDATE {self._table_name} SET item_order=? WHERE item_order=?', + (i - 1, i), + ) self._commit() + def _set_doc_by_offset(self, offset: int, value: 'Document'): self._sql( f'UPDATE {self._table_name} SET serialized_value=? WHERE item_order=?', (value, offset), ) + self._commit() def _set_doc_by_id(self, _id: str, value: 'Document'): @@ -37,6 +46,7 @@ def _get_doc_by_offset(self, index: int) -> 'Document': (index + (len(self) if index < 0 else 0),), ) res = r.fetchone() + #import pdb;pdb.set_trace() if res is None: raise IndexError('index out of range') return res[0] From 3aa538bf4dd07558957100317be19278a297559d Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 21 Jan 2022 14:13:32 +0100 Subject: [PATCH 24/55] fix: _set_doc_by_id should behave like DocumentArrayInMemory --- docarray/array/storage/sqlite/getsetdel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 6cddd6d7094..3d1895b25cd 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -26,8 +26,8 @@ def _set_doc_by_offset(self, offset: int, value: 'Document'): def _set_doc_by_id(self, _id: str, value: 'Document'): self._sql( - f'UPDATE {self._table_name} SET serialized_value=? WHERE doc_id=?', - (value, _id), + f'UPDATE {self._table_name} SET serialized_value=?, doc_id=? WHERE doc_id=?', + (value, value.id, _id), ) self._commit() From 6b0b567ddd6dcb3bedf5de08ee91822375197c73 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 21 Jan 2022 14:49:32 +0100 Subject: [PATCH 25/55] fix: raise TypeError when assigning non iterable with slice --- docarray/array/storage/base/getsetdel.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index 66f74f0abec..3a9e90abb3f 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -99,6 +99,8 @@ def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): Override this function if there is a more efficient logic """ + if not isinstance(value, Iterable): + raise TypeError('You can only assign an iterable') for _offset, val in zip(range(len(self))[_slice], value): self._set_doc_by_offset(_offset, val) From 27c34df4c390449349c82db940296f6a8e0d0b54 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 24 Jan 2022 13:09:05 +0100 Subject: [PATCH 26/55] fix: fix del attribute --- docarray/array/mixins/delitem.py | 2 +- tests/unit/array/test_advance_indexing.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docarray/array/mixins/delitem.py b/docarray/array/mixins/delitem.py index 6b9a15fb10a..b7f796aff98 100644 --- a/docarray/array/mixins/delitem.py +++ b/docarray/array/mixins/delitem.py @@ -39,7 +39,7 @@ def __delitem__(self, index: 'DocumentArrayIndexType'): ): if isinstance(index[0], str) and isinstance(index[1], str): # ambiguity only comes from the second string - if index[1] in self._id2offset: + if index[1] in self: del self[index[0]] del self[index[1]] else: diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 93efda18900..2ca9189f30e 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -288,10 +288,11 @@ def test_edge_case_two_strings(storage): del da['1', '2'] assert len(da) == 1 - da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')], storage=storage) - del da['1', 'id'] + da = DocumentArray( + [Document(id=str(i), text='hey') for i in range(3)], storage=storage) + del da['1', 'text'] assert len(da) == 3 - assert not da[0].id + assert not da[1].text del da['2', 'hello'] From 9b9c2d5e2f3a3fbcfa82ed9d470393ecdf16dd60 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 24 Jan 2022 14:04:44 +0100 Subject: [PATCH 27/55] test: fix test_tensor_attribute_selector --- docarray/array/mixins/setitem.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index 44bb8291131..245631a1ca3 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -113,10 +113,13 @@ def __setitem__( _docs = self[index[0]] for _a, _v in zip(_attrs, value): - if _a == 'tensor': - _docs.tensors = _v - elif _a == 'embedding': - _docs.embeddings = _v + if _a in ('tensor', 'embedding'): + if _a == 'tensor': + _docs.tensors = _v + elif _a == 'embedding': + _docs.embeddings = _v + for _d in _docs: + self._set_doc_by_id(_d.id, _d) else: if len(_docs) == 1: self._set_doc_attr_by_id(_docs[0].id, _a, _v) From f6f60b91c1faae174b9b1365fa8fc64f147f34b0 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Mon, 24 Jan 2022 15:44:48 +0100 Subject: [PATCH 28/55] fix: delete by offset was missing shift of ids --- docarray/array/storage/sqlite/getsetdel.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 3d1895b25cd..02084019242 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -15,13 +15,22 @@ def _del_doc_by_id(self, _id: str): def _del_doc_by_offset(self, offset: int): self._sql(f'DELETE FROM {self._table_name} WHERE item_order=?', (offset,)) + # shift the offset of every value on the right position of the deleted item + for i in range(offset, len(self) + 1): + # doc_id values should be also changed + self._sql( + f'UPDATE {self._table_name} SET item_order=? WHERE item_order=?', + (i - 1, i), + ) self._commit() + def _set_doc_by_offset(self, offset: int, value: 'Document'): self._sql( f'UPDATE {self._table_name} SET serialized_value=? WHERE item_order=?', (value, offset), ) + self._commit() def _set_doc_by_id(self, _id: str, value: 'Document'): @@ -37,6 +46,7 @@ def _get_doc_by_offset(self, index: int) -> 'Document': (index + (len(self) if index < 0 else 0),), ) res = r.fetchone() + #import pdb;pdb.set_trace() if res is None: raise IndexError('index out of range') return res[0] From 160118f2e3ceff99a0850f8d677cbcceca03a689 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 24 Jan 2022 15:58:46 +0100 Subject: [PATCH 29/55] fix: fix linting --- docarray/array/storage/sqlite/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index 2773a0d81e8..745961d4fdf 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -72,7 +72,7 @@ def _init_storage( self._connection = config.connection else: raise TypeError( - f'connection argument must be None or a string or a sqlite3.Connection, not `{type(connection)}`' + f'connection argument must be None or a string or a sqlite3.Connection, not `{type(config.connection)}`' ) self._table_name = ( From 51401d4022adc4c7fe906b7eb0b1af0d9f37f03b Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 24 Jan 2022 16:14:12 +0100 Subject: [PATCH 30/55] fix: black --- docarray/array/base.py | 2 ++ docarray/array/storage/sqlite/getsetdel.py | 3 +-- tests/unit/array/test_advance_indexing.py | 15 +++++++++++---- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/docarray/array/base.py b/docarray/array/base.py index 37340caee62..14dce727ddf 100644 --- a/docarray/array/base.py +++ b/docarray/array/base.py @@ -13,9 +13,11 @@ def __new__(cls, *args, storage: str = 'memory', **kwargs): if cls is DocumentArray: if storage == 'memory': from docarray.array.memory import DocumentArrayInMemory + instance = super().__new__(DocumentArrayInMemory) elif storage == 'sqlite': from docarray.array.sqlite import DocumentArraySqlite + instance = super().__new__(DocumentArraySqlite) else: raise ValueError(f'storage=`{storage}` is not supported.') diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 02084019242..152eb30a2c4 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -24,7 +24,6 @@ def _del_doc_by_offset(self, offset: int): ) self._commit() - def _set_doc_by_offset(self, offset: int, value: 'Document'): self._sql( f'UPDATE {self._table_name} SET serialized_value=? WHERE item_order=?', @@ -46,7 +45,7 @@ def _get_doc_by_offset(self, index: int) -> 'Document': (index + (len(self) if index < 0 else 0),), ) res = r.fetchone() - #import pdb;pdb.set_trace() + # import pdb;pdb.set_trace() if res is None: raise IndexError('index out of range') return res[0] diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 2ca9189f30e..9370f1eddbe 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -273,7 +273,9 @@ def test_single_boolean_and_padding(storage): @pytest.mark.parametrize('storage', ['memory', 'sqlite']) def test_edge_case_two_strings(storage): # getitem - da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')], storage=storage) + da = DocumentArray( + [Document(id='1'), Document(id='2'), Document(id='3')], storage=storage + ) assert da['1', 'id'] == '1' assert len(da['1', '2']) == 2 assert isinstance(da['1', '2'], DocumentArray) @@ -289,7 +291,8 @@ def test_edge_case_two_strings(storage): assert len(da) == 1 da = DocumentArray( - [Document(id=str(i), text='hey') for i in range(3)], storage=storage) + [Document(id=str(i), text='hey') for i in range(3)], storage=storage + ) del da['1', 'text'] assert len(da) == 3 assert not da[1].text @@ -297,12 +300,16 @@ def test_edge_case_two_strings(storage): del da['2', 'hello'] # setitem - da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')], storage=storage) + da = DocumentArray( + [Document(id='1'), Document(id='2'), Document(id='3')], storage=storage + ) da['1', '2'] = DocumentArray.empty(2) assert da[0].id != '1' assert da[1].id != '2' - da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')], storage=storage) + da = DocumentArray( + [Document(id='1'), Document(id='2'), Document(id='3')], storage=storage + ) da['1', 'text'] = 'hello' assert da['1'].text == 'hello' From 8a893962baae396fcdf8f46f05af8d91cefda1f6 Mon Sep 17 00:00:00 2001 From: David Buchaca Date: Mon, 24 Jan 2022 16:34:06 +0100 Subject: [PATCH 31/55] fix: cover negative offset --- docarray/array/storage/sqlite/getsetdel.py | 5 ++++ tests/unit/array/test_advance_indexing.py | 34 ++++++++++++++-------- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 02084019242..87b00a05861 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -14,6 +14,11 @@ def _del_doc_by_id(self, _id: str): self._commit() def _del_doc_by_offset(self, offset: int): + + # if offset = -2 and len(self)= 100 use offset = 98 + if offset < 0: + offset = len(self) + offset + self._sql(f'DELETE FROM {self._table_name} WHERE item_order=?', (offset,)) # shift the offset of every value on the right position of the deleted item for i in range(offset, len(self) + 1): diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 2ca9189f30e..ec3ac38501a 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -11,6 +11,10 @@ def docs(): yield (Document(text=j) for j in range(100)) +@pytest.fixture +def indices(): + yield (i for i in [-2,0,2]) + @pytest.mark.parametrize('storage', ['memory', 'sqlite']) def test_getter_int_str(docs, storage): @@ -49,19 +53,25 @@ def test_setter_int_str(docs, storage): @pytest.mark.parametrize('storage', ['memory', 'sqlite']) -def test_del_int_str(docs, storage): +def test_del_int_str(docs, storage, indices): + docs = DocumentArray(docs, storage=storage) - zero_id = docs[0].id - del docs[0] - assert len(docs) == 99 - assert zero_id not in docs - - new_zero_id = docs[0].id - new_doc_zero = docs[0] - del docs[new_zero_id] - assert len(docs) == 98 - assert zero_id not in docs - assert new_doc_zero not in docs + initial_len = len(docs) + deleted_elements = 0 + for pos in indices: + pos_id = docs[pos].id + del docs[pos] + deleted_elements += 1 + assert pos_id not in docs + assert len(docs) == initial_len - deleted_elements + + new_pos_id = docs[pos].id + new_doc_zero = docs[pos] + del docs[new_pos_id] + deleted_elements += 1 + assert len(docs) == initial_len - deleted_elements + assert pos_id not in docs + assert new_doc_zero not in docs @pytest.mark.parametrize('storage', ['memory', 'sqlite']) From 2215d81f6bf6f68298f4d43fcc94a5aa4bd5a446 Mon Sep 17 00:00:00 2001 From: David Buchaca Date: Mon, 24 Jan 2022 16:43:28 +0100 Subject: [PATCH 32/55] test: cover negative indices --- tests/unit/array/test_advance_indexing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index f3b410de531..107c61724f7 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -11,9 +11,10 @@ def docs(): yield (Document(text=j) for j in range(100)) + @pytest.fixture def indices(): - yield (i for i in [-2,0,2]) + yield (i for i in [-2, 0, 2]) @pytest.mark.parametrize('storage', ['memory', 'sqlite']) From 183c0791cbc94de8f11736115820073ae6019098 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Mon, 24 Jan 2022 17:17:29 +0100 Subject: [PATCH 33/55] fix(storage): fix doc id usage in _set_doc_value_pairs --- docarray/array/storage/base/getsetdel.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index 3a9e90abb3f..45fe6decedb 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -100,7 +100,9 @@ def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): Override this function if there is a more efficient logic """ if not isinstance(value, Iterable): - raise TypeError('You can only assign an iterable') + raise TypeError( + f'You right-hand assignment must be an iterable, receiving {type(value)}' + ) for _offset, val in zip(range(len(self))[_slice], value): self._set_doc_by_offset(_offset, val) @@ -132,4 +134,4 @@ def _set_doc_attr_by_id(self, _id: str, attr: str, value: Any): d = self._get_doc_by_id(_id) if hasattr(d, attr): setattr(d, attr, value) - self._set_doc_by_id(_id, d) + self._set_doc_by_id(d.id, d) From 829c5b45d2c53eca259e9e68fc876160b55cd94d Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Mon, 24 Jan 2022 17:32:17 +0100 Subject: [PATCH 34/55] fix(storage): fix doc id usage in _set_doc_value_pairs --- docarray/array/storage/sqlite/getsetdel.py | 4 ++-- tests/unit/array/test_advance_indexing.py | 4 ---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index a6f7c8efeb5..b648d4f757b 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -31,8 +31,8 @@ def _del_doc_by_offset(self, offset: int): def _set_doc_by_offset(self, offset: int, value: 'Document'): self._sql( - f'UPDATE {self._table_name} SET serialized_value=? WHERE item_order=?', - (value, offset), + f'UPDATE {self._table_name} SET serialized_value=?, doc_id=? WHERE item_order=?', + (value, value.id, offset), ) self._commit() diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 107c61724f7..dd11adac2d3 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -1,6 +1,3 @@ -import os -import time - import numpy as np import pytest @@ -55,7 +52,6 @@ def test_setter_int_str(docs, storage): @pytest.mark.parametrize('storage', ['memory', 'sqlite']) def test_del_int_str(docs, storage, indices): - docs = DocumentArray(docs, storage=storage) initial_len = len(docs) deleted_elements = 0 From 23cd905aace1e114ba5505ee7c4cf19eff856c35 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Mon, 24 Jan 2022 18:02:36 +0100 Subject: [PATCH 35/55] fix(storage): fix _set_doc_value_pairs --- docarray/array/storage/base/getsetdel.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index 45fe6decedb..7bf75cea1b1 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -114,7 +114,17 @@ def _set_doc_value_pairs( Override this function if there is a more efficient logic """ for _d, _v in zip(docs, values): - self._set_doc_by_id(_d.id, _v) + _d._data = _v._data + + for _d in docs: + if _d not in docs: + root_d = self._find_root_doc(_d) + else: + # _d is already on the root-level + root_d = _d + + if root_d: + self._set_doc_by_id(root_d.id, root_d) def _set_doc_attr_by_offset(self, offset: int, attr: str, value: Any): """This function is derived and may not have the most efficient implementation. @@ -135,3 +145,12 @@ def _set_doc_attr_by_id(self, _id: str, attr: str, value: Any): if hasattr(d, attr): setattr(d, attr, value) self._set_doc_by_id(d.id, d) + + def _find_root_doc(self, d: Document): + """Find `d`'s root Document in an exhaustive manner """ + from docarray import DocumentArray + + for _d in self: + _all_ids = set(DocumentArray(d)[...][:, 'id']) + if d.id in _all_ids: + return _d From e1f30df6e0d20f09d2f6c6c779576a7063388dd6 Mon Sep 17 00:00:00 2001 From: David Buchaca Date: Mon, 24 Jan 2022 18:17:32 +0100 Subject: [PATCH 36/55] fix: offset set by id --- docarray/array/storage/sqlite/getsetdel.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index a6f7c8efeb5..382ace8bbd7 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -16,13 +16,11 @@ def _del_doc_by_id(self, _id: str): def _del_doc_by_offset(self, offset: int): # if offset = -2 and len(self)= 100 use offset = 98 - if offset < 0: - offset = len(self) + offset + offset = len(self) + offset if offset < 0 else offset self._sql(f'DELETE FROM {self._table_name} WHERE item_order=?', (offset,)) # shift the offset of every value on the right position of the deleted item for i in range(offset, len(self) + 1): - # doc_id values should be also changed self._sql( f'UPDATE {self._table_name} SET item_order=? WHERE item_order=?', (i - 1, i), @@ -30,6 +28,10 @@ def _del_doc_by_offset(self, offset: int): self._commit() def _set_doc_by_offset(self, offset: int, value: 'Document'): + + # if offset = -2 and len(self)= 100 use offset = 98 + offset = len(self) + offset if offset < 0 else offset + self._sql( f'UPDATE {self._table_name} SET serialized_value=? WHERE item_order=?', (value, offset), @@ -50,7 +52,6 @@ def _get_doc_by_offset(self, index: int) -> 'Document': (index + (len(self) if index < 0 else 0),), ) res = r.fetchone() - # import pdb;pdb.set_trace() if res is None: raise IndexError('index out of range') return res[0] From 901e0d0e1eab821ba9e100641cc17d657e226d7f Mon Sep 17 00:00:00 2001 From: David Buchaca Date: Mon, 24 Jan 2022 22:05:46 +0100 Subject: [PATCH 37/55] refactor: add boolean slicing in sqlite --- docarray/array/storage/sqlite/getsetdel.py | 10 +++++++++- tests/unit/array/test_advance_indexing.py | 7 +++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 019c1a451f3..c4124c54603 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -68,7 +68,6 @@ def _get_doc_by_id(self, id: str) -> 'Document': # essentials end here # now start the optimized bulk methods - def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: l = len(self) offsets = [o + (l if o < 0 else 0) for o in offsets] @@ -93,3 +92,12 @@ def _del_docs_by_slice(self, _slice: slice): offsets, ) self._commit() + + def _del_docs_by_mask(self, mask: Sequence[bool]): + + offsets = [i for i,m in enumerate(mask) if m==True] + self._sql( + f"DELETE FROM {self._table_name} WHERE item_order in ({','.join(['?'] * len(offsets))})", + offsets, + ) + self._commit() diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index dd11adac2d3..4950f0523e6 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -79,7 +79,7 @@ def test_slice(docs, storage): assert len(docs[1:100:5]) == 20 # 1 to 100, sep with 5 # setter - with pytest.raises(TypeError, match='can only assign an iterable'): + with pytest.raises(TypeError, match='an iterable'): docs[1:5] = Document(text='repl') docs[1:5] = [Document(text=f'repl{j}') for j in range(4)] @@ -106,7 +106,8 @@ def test_sequence_bool_index(docs, storage): # setter mask = [True, False] * 50 - docs[mask] = [Document(text=f'repl{j}') for j in range(50)] + #docs[mask] = [Document(text=f'repl{j}') for j in range(50)] + docs[mask,'text'] = [f'repl{j}' for j in range(50)] for idx, d in enumerate(docs): if idx % 2 == 0: @@ -119,8 +120,6 @@ def test_sequence_bool_index(docs, storage): del docs[mask] assert len(docs) == 50 - del docs[mask] - assert len(docs) == 25 @pytest.mark.parametrize('nparray', [lambda x: x, np.array, tuple]) From 906ff8b23188b9402378c9bf0acfb302f7b55547 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 25 Jan 2022 12:13:06 +0100 Subject: [PATCH 38/55] test: fix test_advance_selector_mixed --- docarray/array/storage/sqlite/seqlike.py | 9 ++++++--- docarray/math/ndarray.py | 6 +++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/docarray/array/storage/sqlite/seqlike.py b/docarray/array/storage/sqlite/seqlike.py index 9558cd369f3..6c4c68dfba8 100644 --- a/docarray/array/storage/sqlite/seqlike.py +++ b/docarray/array/storage/sqlite/seqlike.py @@ -73,8 +73,11 @@ def __len__(self) -> int: return r.fetchone()[0] def __iter__(self) -> Iterator['Document']: - r = self._sql( + # TODO: make this memory efficient + self._sql( f'SELECT serialized_value FROM {self._table_name} ORDER BY item_order' ) - for res in r: - yield res[0] + result = self._cursor.fetchall() + for r in result: + yield r[0] + diff --git a/docarray/math/ndarray.py b/docarray/math/ndarray.py index 1dc354f61fc..a4ef24a9f03 100644 --- a/docarray/math/ndarray.py +++ b/docarray/math/ndarray.py @@ -55,6 +55,7 @@ def ravel(value: 'ArrayType', docs: Sequence['Document'], field: str) -> None: :param field: the field of the doc to set :param value: the value to be set on ``doc.field`` """ + from .. import DocumentArray use_get_row = False if hasattr(value, 'getformat'): @@ -76,9 +77,12 @@ def ravel(value: 'ArrayType', docs: Sequence['Document'], field: str) -> None: for d, j in zip(docs, value): setattr(d, field, j) else: + emb_shape0 = value.shape[0] - for d, j in zip(docs, range(emb_shape0)): + for i, (d, j) in enumerate(zip(docs, range(emb_shape0))): setattr(d, field, value[j, ...]) + if isinstance(docs, DocumentArray): + docs._set_doc_by_id(d.id, d) def get_array_type(array: 'ArrayType') -> Tuple[str, bool]: From f99766bad7a6275f268156f5827e6cf566ceaea4 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 25 Jan 2022 12:26:47 +0100 Subject: [PATCH 39/55] style: fix black style --- .pre-commit-config.yaml | 19 ------------------- docarray/array/storage/base/getsetdel.py | 2 +- docarray/array/storage/memory/__init__.py | 3 ++- docarray/array/storage/memory/backend.py | 4 ++-- docarray/array/storage/sqlite/__init__.py | 3 ++- docarray/array/storage/sqlite/getsetdel.py | 14 +++++++------- docarray/array/storage/sqlite/seqlike.py | 1 - tests/unit/array/test_advance_indexing.py | 10 +++++----- 8 files changed, 19 insertions(+), 37 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 20c8fcdc750..e7c5cd3687c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,23 +1,4 @@ repos: -- repo: https://github.com/terrencepreilly/darglint - rev: v1.5.8 - hooks: - - id: darglint - files: docarray/ - exclude: ^(docarray/proto/docarray_pb2.py|docs/|docarray/resources/) - args: - - --message-template={path}:{line} {msg_id} {msg} - - -s=sphinx - - -z=full - - -v=2 -- repo: https://github.com/pycqa/pydocstyle - rev: 5.1.1 # pick a git hash / tag to point to - hooks: - - id: pydocstyle - files: docarray/ - exclude: ^(docarray/proto/docarray_pb2.py|docs/|docarray/resources/) - args: - - --select=D101,D102,D103 - repo: https://github.com/ambv/black rev: 20.8b1 hooks: diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index 7bf75cea1b1..72ee83e200c 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -147,7 +147,7 @@ def _set_doc_attr_by_id(self, _id: str, attr: str, value: Any): self._set_doc_by_id(d.id, d) def _find_root_doc(self, d: Document): - """Find `d`'s root Document in an exhaustive manner """ + """Find `d`'s root Document in an exhaustive manner""" from docarray import DocumentArray for _d in self: diff --git a/docarray/array/storage/memory/__init__.py b/docarray/array/storage/memory/__init__.py index f07b096a031..cf9d9e7f97d 100644 --- a/docarray/array/storage/memory/__init__.py +++ b/docarray/array/storage/memory/__init__.py @@ -1,7 +1,8 @@ +from abc import ABC + from .backend import BackendMixin from .getsetdel import GetSetDelMixin from .seqlike import SequenceLikeMixin -from abc import ABC __all__ = ['StorageMixins'] diff --git a/docarray/array/storage/memory/backend.py b/docarray/array/storage/memory/backend.py index 30252ccb658..6146a3136fa 100644 --- a/docarray/array/storage/memory/backend.py +++ b/docarray/array/storage/memory/backend.py @@ -8,8 +8,8 @@ TYPE_CHECKING, ) -from .... import Document from ..base.backend import BaseBackendMixin +from .... import Document if TYPE_CHECKING: from ....types import ( @@ -18,7 +18,7 @@ class BackendMixin(BaseBackendMixin): - """Provide necessary functions to enable this storage backend. """ + """Provide necessary functions to enable this storage backend.""" @property def _id2offset(self) -> Dict[str, int]: diff --git a/docarray/array/storage/sqlite/__init__.py b/docarray/array/storage/sqlite/__init__.py index f07b096a031..cf9d9e7f97d 100644 --- a/docarray/array/storage/sqlite/__init__.py +++ b/docarray/array/storage/sqlite/__init__.py @@ -1,7 +1,8 @@ +from abc import ABC + from .backend import BackendMixin from .getsetdel import GetSetDelMixin from .seqlike import SequenceLikeMixin -from abc import ABC __all__ = ['StorageMixins'] diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 38755186a87..0b33987c495 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -13,7 +13,6 @@ def _del_doc_by_id(self, _id: str): self._sql(f'DELETE FROM {self._table_name} WHERE doc_id=?', (_id,)) self._commit() - def _del_doc_by_offset(self, offset: int): # if offset = -2 and len(self)= 100 use offset = 98 @@ -22,7 +21,9 @@ def _del_doc_by_offset(self, offset: int): self._sql(f'DELETE FROM {self._table_name} WHERE item_order=?', (offset,)) # shift the offset of every value on the right position of the deleted item - self._sql(f'UPDATE {self._table_name} SET item_order=item_order-1 WHERE item_order>={offset}') + self._sql( + f'UPDATE {self._table_name} SET item_order=item_order-1 WHERE item_order>={offset}' + ) # Code above line is equivalent to """ @@ -57,7 +58,7 @@ def _get_doc_by_offset(self, index: int) -> 'Document': (index + (len(self) if index < 0 else 0),), ) res = r.fetchone() - #import pdb;pdb.set_trace() + # import pdb;pdb.set_trace() if res is None: raise IndexError('index out of range') return res[0] @@ -89,7 +90,8 @@ def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: def _get_doc_by_ids(self, ids: str) -> 'Document': r = self._sql( - f"SELECT serialized_value FROM {self._table_name} WHERE doc_id in ({','.join(['?'] * len(ids))})", ids + f"SELECT serialized_value FROM {self._table_name} WHERE doc_id in ({','.join(['?'] * len(ids))})", + ids, ) res = r.fetchall() if not res: @@ -110,17 +112,15 @@ def _del_docs_by_slice(self, _slice: slice): def _del_docs_by_mask(self, mask: Sequence[bool]): - offsets = [i for i,m in enumerate(mask) if m==True] + offsets = [i for i, m in enumerate(mask) if m == True] self._sql( f"DELETE FROM {self._table_name} WHERE item_order in ({','.join(['?'] * len(offsets))})", offsets, ) self._commit() - def _set_doc_value_pairs( self, docs: Iterable['Document'], values: Iterable['Document'] ): for _d, _v in zip(docs, values): self._set_doc_by_id(_d.id, _v) - diff --git a/docarray/array/storage/sqlite/seqlike.py b/docarray/array/storage/sqlite/seqlike.py index 6c4c68dfba8..c770761ed52 100644 --- a/docarray/array/storage/sqlite/seqlike.py +++ b/docarray/array/storage/sqlite/seqlike.py @@ -80,4 +80,3 @@ def __iter__(self) -> Iterator['Document']: result = self._cursor.fetchall() for r in result: yield r[0] - diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index a72a5baea9b..741d3da3ab7 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -1,4 +1,3 @@ - import numpy as np import pytest @@ -107,8 +106,8 @@ def test_sequence_bool_index(docs, storage): # setter mask = [True, False] * 50 - #docs[mask] = [Document(text=f'repl{j}') for j in range(50)] - docs[mask,'text'] = [f'repl{j}' for j in range(50)] + # docs[mask] = [Document(text=f'repl{j}') for j in range(50)] + docs[mask, 'text'] = [f'repl{j}' for j in range(50)] for idx, d in enumerate(docs): if idx % 2 == 0: @@ -122,7 +121,6 @@ def test_sequence_bool_index(docs, storage): assert len(docs) == 50 - @pytest.mark.parametrize('nparray', [lambda x: x, np.array, tuple]) @pytest.mark.parametrize('storage', ['memory', 'sqlite']) def test_sequence_int(docs, nparray, storage): @@ -298,7 +296,9 @@ def test_edge_case_two_strings(storage): del da['1', '2'] assert len(da) == 1 - da = DocumentArray([Document(id=str(i), text='hey') for i in range(3)], storage=storage) + da = DocumentArray( + [Document(id=str(i), text='hey') for i in range(3)], storage=storage + ) del da['1', 'text'] assert len(da) == 3 assert not da[1].text From d9f1c6fdd13c2ee6468aa028fed1005afd9fdce5 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 25 Jan 2022 12:55:52 +0100 Subject: [PATCH 40/55] style: fix black style --- docarray/array/base.py | 21 +-------------------- docarray/array/document.py | 21 +++++++++++++++++++-- docarray/array/match.py | 3 +-- docarray/array/memory.py | 2 +- docarray/array/sqlite.py | 2 +- 5 files changed, 23 insertions(+), 26 deletions(-) diff --git a/docarray/array/base.py b/docarray/array/base.py index 14dce727ddf..cf0e85f3564 100644 --- a/docarray/array/base.py +++ b/docarray/array/base.py @@ -1,26 +1,7 @@ from abc import ABC -from docarray.array.mixins import AllMixins class BaseDocumentArray(ABC): - def __init__(self, *args, storage: str = 'memory', **kwargs): + def __init__(self, *args, **kwargs): super().__init__() self._init_storage(*args, **kwargs) - - -class DocumentArray(AllMixins, BaseDocumentArray): - def __new__(cls, *args, storage: str = 'memory', **kwargs): - if cls is DocumentArray: - if storage == 'memory': - from docarray.array.memory import DocumentArrayInMemory - - instance = super().__new__(DocumentArrayInMemory) - elif storage == 'sqlite': - from docarray.array.sqlite import DocumentArraySqlite - - instance = super().__new__(DocumentArraySqlite) - else: - raise ValueError(f'storage=`{storage}` is not supported.') - else: - instance = super().__new__(cls) - return instance diff --git a/docarray/array/document.py b/docarray/array/document.py index e275c50d71a..817372110ae 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -1,3 +1,20 @@ -from .base import DocumentArray +from .base import BaseDocumentArray +from .mixins import AllMixins -__all__ = ['DocumentArray'] + +class DocumentArray(AllMixins, BaseDocumentArray): + def __new__(cls, *args, storage: str = 'memory', **kwargs): + if cls is DocumentArray: + if storage == 'memory': + from .memory import DocumentArrayInMemory + + instance = super().__new__(DocumentArrayInMemory) + elif storage == 'sqlite': + from .sqlite import DocumentArraySqlite + + instance = super().__new__(DocumentArraySqlite) + else: + raise ValueError(f'storage=`{storage}` is not supported.') + else: + instance = super().__new__(cls) + return instance diff --git a/docarray/array/match.py b/docarray/array/match.py index fca154d66c6..07d2e8b2ced 100644 --- a/docarray/array/match.py +++ b/docarray/array/match.py @@ -6,8 +6,7 @@ Sequence, ) -from docarray import DocumentArray - +from .document import DocumentArray from .memory import DocumentArrayInMemory if TYPE_CHECKING: diff --git a/docarray/array/memory.py b/docarray/array/memory.py index 823f281ebe0..d97ff54656f 100644 --- a/docarray/array/memory.py +++ b/docarray/array/memory.py @@ -1,4 +1,4 @@ -from .base import DocumentArray +from .document import DocumentArray from .storage.memory import StorageMixins diff --git a/docarray/array/sqlite.py b/docarray/array/sqlite.py index 0ca77ec5ba1..c203ef5dc53 100644 --- a/docarray/array/sqlite.py +++ b/docarray/array/sqlite.py @@ -1,4 +1,4 @@ -from .base import DocumentArray +from .document import DocumentArray from .storage.sqlite import StorageMixins From e7b5a893916f8573cdd837cbf9df31886b252ebd Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 25 Jan 2022 13:08:14 +0100 Subject: [PATCH 41/55] fix(sqlite): fix iter method in sqlite --- docarray/array/base.py | 2 +- docarray/array/storage/sqlite/seqlike.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/docarray/array/base.py b/docarray/array/base.py index cf0e85f3564..87d849206aa 100644 --- a/docarray/array/base.py +++ b/docarray/array/base.py @@ -2,6 +2,6 @@ class BaseDocumentArray(ABC): - def __init__(self, *args, **kwargs): + def __init__(self, *args, storage: str, **kwargs): super().__init__() self._init_storage(*args, **kwargs) diff --git a/docarray/array/storage/sqlite/seqlike.py b/docarray/array/storage/sqlite/seqlike.py index c770761ed52..05cdeb6a7b6 100644 --- a/docarray/array/storage/sqlite/seqlike.py +++ b/docarray/array/storage/sqlite/seqlike.py @@ -73,10 +73,8 @@ def __len__(self) -> int: return r.fetchone()[0] def __iter__(self) -> Iterator['Document']: - # TODO: make this memory efficient self._sql( f'SELECT serialized_value FROM {self._table_name} ORDER BY item_order' ) - result = self._cursor.fetchall() - for r in result: + for r in self._cursor: yield r[0] From 1c9cb0cb26507bc3711174c946d8eeceea1d79b9 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 25 Jan 2022 13:32:23 +0100 Subject: [PATCH 42/55] fix(sqlite): fix iter method in sqlite --- docarray/array/base.py | 2 +- docarray/array/storage/sqlite/seqlike.py | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/docarray/array/base.py b/docarray/array/base.py index 87d849206aa..a64fa37602e 100644 --- a/docarray/array/base.py +++ b/docarray/array/base.py @@ -2,6 +2,6 @@ class BaseDocumentArray(ABC): - def __init__(self, *args, storage: str, **kwargs): + def __init__(self, *args, storage: str = 'memory', **kwargs): super().__init__() self._init_storage(*args, **kwargs) diff --git a/docarray/array/storage/sqlite/seqlike.py b/docarray/array/storage/sqlite/seqlike.py index 05cdeb6a7b6..e1e41b7e782 100644 --- a/docarray/array/storage/sqlite/seqlike.py +++ b/docarray/array/storage/sqlite/seqlike.py @@ -15,13 +15,9 @@ def _insert_doc_at_idx(self, doc, idx: Optional[int] = None): ) def _shift_index_right_backward(self, start: int): - idx = len(self) - 1 - while idx >= start: - self._sql( - f'UPDATE {self._table_name} SET item_order = ? WHERE item_order = ?', - (idx + 1, idx), - ) - idx -= 1 + self._sql( + f'UPDATE {self._table_name} SET item_order=item_order+1 WHERE item_order>={start}' + ) def insert(self, index: int, value: 'Document'): """Insert `doc` at `index`. @@ -58,6 +54,7 @@ def __del__(self) -> None: (self._table_name, self.__class__.__name__), ) self._sql(f'DROP TABLE {self._table_name}') + self._commit() def __contains__(self, item: Union[str, 'Document']): if isinstance(item, str): From 3c3808499b40e445ada28b8b1f2b77304d0ba3a8 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 25 Jan 2022 14:03:37 +0100 Subject: [PATCH 43/55] fix(sqlite): revert shift index right backward --- docarray/array/storage/sqlite/getsetdel.py | 3 ++- docarray/array/storage/sqlite/seqlike.py | 10 +++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 0b33987c495..db2aea09ba9 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -22,7 +22,8 @@ def _del_doc_by_offset(self, offset: int): # shift the offset of every value on the right position of the deleted item self._sql( - f'UPDATE {self._table_name} SET item_order=item_order-1 WHERE item_order>={offset}' + f'UPDATE {self._table_name} SET item_order=item_order-1 WHERE item_order>?', + (offset,), ) # Code above line is equivalent to diff --git a/docarray/array/storage/sqlite/seqlike.py b/docarray/array/storage/sqlite/seqlike.py index e1e41b7e782..6bcc9dec670 100644 --- a/docarray/array/storage/sqlite/seqlike.py +++ b/docarray/array/storage/sqlite/seqlike.py @@ -15,9 +15,13 @@ def _insert_doc_at_idx(self, doc, idx: Optional[int] = None): ) def _shift_index_right_backward(self, start: int): - self._sql( - f'UPDATE {self._table_name} SET item_order=item_order+1 WHERE item_order>={start}' - ) + idx = len(self) - 1 + while idx >= start: + self._sql( + f'UPDATE {self._table_name} SET item_order = ? WHERE item_order = ?', + (idx + 1, idx), + ) + idx -= 1 def insert(self, index: int, value: 'Document'): """Insert `doc` at `index`. From 30e05cf802cc3ff9fe6a7278e23ccbc8b1031036 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 25 Jan 2022 14:09:47 +0100 Subject: [PATCH 44/55] fix: use separate cursor for iterator --- docarray/array/storage/sqlite/seqlike.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/docarray/array/storage/sqlite/seqlike.py b/docarray/array/storage/sqlite/seqlike.py index 6c4c68dfba8..c30ac2a22b1 100644 --- a/docarray/array/storage/sqlite/seqlike.py +++ b/docarray/array/storage/sqlite/seqlike.py @@ -73,11 +73,9 @@ def __len__(self) -> int: return r.fetchone()[0] def __iter__(self) -> Iterator['Document']: - # TODO: make this memory efficient - self._sql( + cursor = self._connection.cursor() + r = cursor.execute( f'SELECT serialized_value FROM {self._table_name} ORDER BY item_order' ) - result = self._cursor.fetchall() - for r in result: - yield r[0] - + for res in r: + yield res[0] From baa25f34c89bdba60c2aadb5bc06255b08e13872 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 25 Jan 2022 14:49:46 +0100 Subject: [PATCH 45/55] fix: use separate cursor --- docarray/array/storage/sqlite/getsetdel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 38755186a87..c0bcaf2b61e 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -57,7 +57,6 @@ def _get_doc_by_offset(self, index: int) -> 'Document': (index + (len(self) if index < 0 else 0),), ) res = r.fetchone() - #import pdb;pdb.set_trace() if res is None: raise IndexError('index out of range') return res[0] @@ -77,7 +76,8 @@ def _get_doc_by_id(self, id: str) -> 'Document': def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: l = len(self) offsets = [o + (l if o < 0 else 0) for o in offsets] - r = self._sql( + cursor = self._connection.cursor() + r = cursor.execute( f"SELECT serialized_value FROM {self._table_name} WHERE item_order in ({','.join(['?'] * len(offsets))})", offsets, ) From 63cb55ec5d512ca9c9668dcd47e510f90434d28b Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 25 Jan 2022 16:22:29 +0100 Subject: [PATCH 46/55] fix: persist when table_name is specified --- docarray/array/storage/sqlite/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index 745961d4fdf..1762917019f 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -81,7 +81,7 @@ def _init_storage( else _sanitize_table_name(config.table_name) ) self._cursor = self._connection.cursor() - self._persist = not config.table_name + self._persist = bool(config.table_name) initialize_table( self._table_name, self.__class__.__name__, self.schema_version, self._cursor ) From 681aa8465bf0e51b365ba6b24dadd5b988731f6f Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 25 Jan 2022 17:08:39 +0100 Subject: [PATCH 47/55] test: add base tests --- tests/unit/array/test_base_getsetdel.py | 72 +++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/unit/array/test_base_getsetdel.py diff --git a/tests/unit/array/test_base_getsetdel.py b/tests/unit/array/test_base_getsetdel.py new file mode 100644 index 00000000000..552be342388 --- /dev/null +++ b/tests/unit/array/test_base_getsetdel.py @@ -0,0 +1,72 @@ +from abc import ABC + +import pytest + +from docarray import DocumentArray, Document +from docarray.array.mixins import GetItemMixin +from docarray.array.storage.base.getsetdel import BaseGetSetDelMixin +from docarray.array.storage.memory import BackendMixin, SequenceLikeMixin + + +class DummyGetSetDelMixin(BaseGetSetDelMixin): + """Implement required and derived functions that power `getitem`, `setitem`, `delitem`""" + + # essentials + + def _del_doc_by_id(self, _id: str): + del self._data[self._id2offset[_id]] + self._id2offset.pop(_id) + + def _del_doc_by_offset(self, offset: int): + self._id2offset.pop(self._data[offset].id) + del self._data[offset] + + def _set_doc_by_id(self, _id: str, value: 'Document'): + old_idx = self._id2offset.pop(_id) + self._data[old_idx] = value + self._id2offset[value.id] = old_idx + + + def _get_doc_by_offset(self, offset: int) -> 'Document': + return self._data[offset] + + def _get_doc_by_id(self, _id: str) -> 'Document': + return self._data[self._id2offset[_id]] + + + def _set_doc_by_offset(self, offset: int, value: 'Document'): + self._data[offset] = value + self._id2offset[value.id] = offset + + +class StorageMixins(BackendMixin, DummyGetSetDelMixin, SequenceLikeMixin, ABC): + ... + + +class DocumentArrayDummy(StorageMixins, DocumentArray): + def __new__(cls, *args, **kwargs): + return super().__new__(cls) + + +@pytest.fixture(scope='function') +def docs(): + return DocumentArrayDummy([Document(text=j) for j in range(100)]) + + +def test_index_by_int_str(docs): + # getter + assert len(docs[[1]]) == 1 + assert len(docs[1, 2]) == 2 + assert len(docs[1, 2, 3]) == 3 + assert len(docs[1:5]) == 4 + assert len(docs[1:100:5]) == 20 # 1 to 100, sep with 5 + + # setter + with pytest.raises(TypeError, match='an iterable'): + docs[1:5] = Document(text='repl') + + docs[1:5] = [Document(text=f'repl{j}') for j in range(4)] + for d in docs[1:5]: + assert d.text.startswith('repl') + assert len(docs) == 100 + From 41f38adf74673241864155aa34d7609dc6cea767 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 25 Jan 2022 17:14:33 +0100 Subject: [PATCH 48/55] fix(sqlite): fix _get_docs_by_offsets --- docarray/array/storage/sqlite/backend.py | 6 +++++- docarray/array/storage/sqlite/getsetdel.py | 11 ++++------- docarray/array/storage/sqlite/seqlike.py | 3 +-- tests/unit/array/test_base_getsetdel.py | 3 --- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index 1762917019f..be55ee40fb5 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -44,6 +44,10 @@ def _sql(self, *args, **kwargs) -> 'sqlite3.Cursor': def _commit(self): self._connection.commit() + @property + def _cursor(self) -> 'sqlite3.Cursor': + return self._connection.cursor() + def _init_storage( self, docs: Optional['DocumentArraySourceType'] = None, @@ -80,7 +84,7 @@ def _init_storage( if config.table_name is None else _sanitize_table_name(config.table_name) ) - self._cursor = self._connection.cursor() + self._persist = bool(config.table_name) initialize_table( self._table_name, self.__class__.__name__, self.schema_version, self._cursor diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 6b3f90e6aaf..89318b114df 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -78,8 +78,7 @@ def _get_doc_by_id(self, id: str) -> 'Document': def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: l = len(self) offsets = [o + (l if o < 0 else 0) for o in offsets] - cursor = self._connection.cursor() - r = cursor.execute( + r = self._sql( f"SELECT serialized_value FROM {self._table_name} WHERE item_order in ({','.join(['?'] * len(offsets))})", offsets, ) @@ -89,15 +88,13 @@ def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: return self._get_docs_by_offsets(range(len(self))[_slice]) - def _get_doc_by_ids(self, ids: str) -> 'Document': + def _get_docs_by_ids(self, ids: str) -> Iterable['Document']: r = self._sql( f"SELECT serialized_value FROM {self._table_name} WHERE doc_id in ({','.join(['?'] * len(ids))})", ids, ) - res = r.fetchall() - if not res: - raise KeyError(f'Cannot find any Documents with ids from {ids}') - return res + for rr in r: + yield rr[0] def _del_all_docs(self): self._sql(f'DELETE FROM {self._table_name}') diff --git a/docarray/array/storage/sqlite/seqlike.py b/docarray/array/storage/sqlite/seqlike.py index f75122cb69e..17c5b718fd7 100644 --- a/docarray/array/storage/sqlite/seqlike.py +++ b/docarray/array/storage/sqlite/seqlike.py @@ -74,8 +74,7 @@ def __len__(self) -> int: return r.fetchone()[0] def __iter__(self) -> Iterator['Document']: - cursor = self._connection.cursor() - r = cursor.execute( + r = self._sql( f'SELECT serialized_value FROM {self._table_name} ORDER BY item_order' ) for res in r: diff --git a/tests/unit/array/test_base_getsetdel.py b/tests/unit/array/test_base_getsetdel.py index 552be342388..eda3ac070d1 100644 --- a/tests/unit/array/test_base_getsetdel.py +++ b/tests/unit/array/test_base_getsetdel.py @@ -26,14 +26,12 @@ def _set_doc_by_id(self, _id: str, value: 'Document'): self._data[old_idx] = value self._id2offset[value.id] = old_idx - def _get_doc_by_offset(self, offset: int) -> 'Document': return self._data[offset] def _get_doc_by_id(self, _id: str) -> 'Document': return self._data[self._id2offset[_id]] - def _set_doc_by_offset(self, offset: int, value: 'Document'): self._data[offset] = value self._id2offset[value.id] = offset @@ -69,4 +67,3 @@ def test_index_by_int_str(docs): for d in docs[1:5]: assert d.text.startswith('repl') assert len(docs) == 100 - From cb6ce426c1edc57f71bfb46e5b707617207c3ec3 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 25 Jan 2022 17:15:08 +0100 Subject: [PATCH 49/55] fix: linting --- tests/unit/array/test_base_getsetdel.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/unit/array/test_base_getsetdel.py b/tests/unit/array/test_base_getsetdel.py index 552be342388..b7911047315 100644 --- a/tests/unit/array/test_base_getsetdel.py +++ b/tests/unit/array/test_base_getsetdel.py @@ -3,7 +3,6 @@ import pytest from docarray import DocumentArray, Document -from docarray.array.mixins import GetItemMixin from docarray.array.storage.base.getsetdel import BaseGetSetDelMixin from docarray.array.storage.memory import BackendMixin, SequenceLikeMixin @@ -26,14 +25,12 @@ def _set_doc_by_id(self, _id: str, value: 'Document'): self._data[old_idx] = value self._id2offset[value.id] = old_idx - def _get_doc_by_offset(self, offset: int) -> 'Document': return self._data[offset] def _get_doc_by_id(self, _id: str) -> 'Document': return self._data[self._id2offset[_id]] - def _set_doc_by_offset(self, offset: int, value: 'Document'): self._data[offset] = value self._id2offset[value.id] = offset @@ -69,4 +66,3 @@ def test_index_by_int_str(docs): for d in docs[1:5]: assert d.text.startswith('repl') assert len(docs) == 100 - From be1c7c36052eab25a4237104298706b2fdd767ae Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 25 Jan 2022 17:16:03 +0100 Subject: [PATCH 50/55] fix(sqlite): fix _get_docs_by_offsets --- tests/unit/array/test_base_getsetdel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/array/test_base_getsetdel.py b/tests/unit/array/test_base_getsetdel.py index eda3ac070d1..b7911047315 100644 --- a/tests/unit/array/test_base_getsetdel.py +++ b/tests/unit/array/test_base_getsetdel.py @@ -3,7 +3,6 @@ import pytest from docarray import DocumentArray, Document -from docarray.array.mixins import GetItemMixin from docarray.array.storage.base.getsetdel import BaseGetSetDelMixin from docarray.array.storage.memory import BackendMixin, SequenceLikeMixin From fc54273505c8529ba90c590055a4516160874d7e Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 25 Jan 2022 17:38:15 +0100 Subject: [PATCH 51/55] test: test base --- tests/unit/array/test_base_getsetdel.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/unit/array/test_base_getsetdel.py b/tests/unit/array/test_base_getsetdel.py index b7911047315..064ffd2209e 100644 --- a/tests/unit/array/test_base_getsetdel.py +++ b/tests/unit/array/test_base_getsetdel.py @@ -1,5 +1,6 @@ from abc import ABC +import numpy as np import pytest from docarray import DocumentArray, Document @@ -47,7 +48,7 @@ def __new__(cls, *args, **kwargs): @pytest.fixture(scope='function') def docs(): - return DocumentArrayDummy([Document(text=j) for j in range(100)]) + return DocumentArrayDummy([Document(id=str(j), text=j) for j in range(100)]) def test_index_by_int_str(docs): @@ -66,3 +67,20 @@ def test_index_by_int_str(docs): for d in docs[1:5]: assert d.text.startswith('repl') assert len(docs) == 100 + + +def test_getter_int_str(docs): + # getter + assert docs[99].text == 99 + assert docs[-1].text == 99 + assert docs[0].text == 0 + + # string index + assert docs['0'].text == 0 + assert docs['99'].text == 99 + + with pytest.raises(IndexError): + docs[100] + + with pytest.raises(KeyError): + docs['adsad'] \ No newline at end of file From c83a755ecb96a675c5fce085bd5ef7f2d4e98a55 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 25 Jan 2022 17:40:43 +0100 Subject: [PATCH 52/55] fix: linting --- tests/unit/array/test_base_getsetdel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/array/test_base_getsetdel.py b/tests/unit/array/test_base_getsetdel.py index 064ffd2209e..a772c8e44e9 100644 --- a/tests/unit/array/test_base_getsetdel.py +++ b/tests/unit/array/test_base_getsetdel.py @@ -83,4 +83,4 @@ def test_getter_int_str(docs): docs[100] with pytest.raises(KeyError): - docs['adsad'] \ No newline at end of file + docs['adsad'] From 6a8db70316a80b18813afebfd00e3ba4f332a232 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 25 Jan 2022 18:13:06 +0100 Subject: [PATCH 53/55] feat(sqlite): improve type hint --- docarray/array/document.py | 30 ++++++++++++++++++++++- docarray/array/storage/memory/backend.py | 22 ++++++++--------- docarray/array/storage/memory/seqlike.py | 5 ++-- docarray/array/storage/sqlite/__init__.py | 4 +-- docarray/array/storage/sqlite/backend.py | 18 ++++++++------ docarray/array/storage/sqlite/seqlike.py | 25 ++++++++++++++++++- docarray/helper.py | 10 ++++++++ docarray/types.py | 5 ++++ tests/unit/array/mixins/test_magic.py | 18 +++++++++++++- tests/unit/array/test_advance_indexing.py | 2 +- 10 files changed, 112 insertions(+), 27 deletions(-) diff --git a/docarray/array/document.py b/docarray/array/document.py index 817372110ae..f21bab69a33 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -1,9 +1,37 @@ +from typing import Optional, overload, TYPE_CHECKING, Dict, Union + from .base import BaseDocumentArray from .mixins import AllMixins +if TYPE_CHECKING: + from ..types import ( + DocumentArraySourceType, + DocumentArrayLike, + DocumentArraySqlite, + DocumentArrayInMemory, + ) + from .storage.sqlite import SqliteConfig + class DocumentArray(AllMixins, BaseDocumentArray): - def __new__(cls, *args, storage: str = 'memory', **kwargs): + @overload + def __new__( + cls, _docs: Optional['DocumentArraySourceType'] = None, copy: bool = False + ) -> 'DocumentArrayInMemory': + """Create an in-memory DocumentArray object.""" + ... + + @overload + def __new__( + cls, + _docs: Optional['DocumentArraySourceType'] = None, + storage: str = 'sqlite', + config: Optional[Union['SqliteConfig', Dict]] = None, + ) -> 'DocumentArraySqlite': + """Create a SQLite-powered DocumentArray object.""" + ... + + def __new__(cls, *args, storage: str = 'memory', **kwargs) -> 'DocumentArrayLike': if cls is DocumentArray: if storage == 'memory': from .memory import DocumentArrayInMemory diff --git a/docarray/array/storage/memory/backend.py b/docarray/array/storage/memory/backend.py index 6146a3136fa..ea596702c85 100644 --- a/docarray/array/storage/memory/backend.py +++ b/docarray/array/storage/memory/backend.py @@ -41,28 +41,28 @@ def _rebuild_id2offset(self) -> None: } # type: Dict[str, int] def _init_storage( - self, docs: Optional['DocumentArraySourceType'] = None, copy: bool = False + self, _docs: Optional['DocumentArraySourceType'] = None, copy: bool = False ): from ... import DocumentArray self._data = [] - if docs is None: + if _docs is None: return elif isinstance( - docs, (DocumentArray, Sequence, Generator, Iterator, itertools.chain) + _docs, (DocumentArray, Sequence, Generator, Iterator, itertools.chain) ): if copy: - self._data = [Document(d, copy=True) for d in docs] + self._data = [Document(d, copy=True) for d in _docs] self._rebuild_id2offset() - elif isinstance(docs, DocumentArray): - self._data = docs._data - self._id_to_index = docs._id2offset + elif isinstance(_docs, DocumentArray): + self._data = _docs._data + self._id_to_index = _docs._id2offset else: - self._data = list(docs) + self._data = list(_docs) self._rebuild_id2offset() else: - if isinstance(docs, Document): + if isinstance(_docs, Document): if copy: - self.append(Document(docs, copy=True)) + self.append(Document(_docs, copy=True)) else: - self.append(docs) + self.append(_docs) diff --git a/docarray/array/storage/memory/seqlike.py b/docarray/array/storage/memory/seqlike.py index bc625d1b201..0121780fadf 100644 --- a/docarray/array/storage/memory/seqlike.py +++ b/docarray/array/storage/memory/seqlike.py @@ -48,11 +48,10 @@ def __bool__(self): return len(self) > 0 def __repr__(self): - return f'<{self.__class__.__name__} (length={len(self)}) at {id(self)}>' + return f'' def __add__(self, other: Union['Document', Sequence['Document']]): - v = type(self)() - v.extend(self) + v = type(self)(self) v.extend(other) return v diff --git a/docarray/array/storage/sqlite/__init__.py b/docarray/array/storage/sqlite/__init__.py index cf9d9e7f97d..4d9bf2b291b 100644 --- a/docarray/array/storage/sqlite/__init__.py +++ b/docarray/array/storage/sqlite/__init__.py @@ -1,10 +1,10 @@ from abc import ABC -from .backend import BackendMixin +from .backend import BackendMixin, SqliteConfig from .getsetdel import GetSetDelMixin from .seqlike import SequenceLikeMixin -__all__ = ['StorageMixins'] +__all__ = ['StorageMixins', 'SqliteConfig'] class StorageMixins(BackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC): diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index be55ee40fb5..d1605292143 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -11,7 +11,7 @@ from .helper import initialize_table from ..base.backend import BaseBackendMixin -from ....helper import random_identity +from ....helper import random_identity, dataclass_from_dict if TYPE_CHECKING: from ....types import ( @@ -31,6 +31,7 @@ class SqliteConfig: connection: Optional[Union[str, 'sqlite3.Connection']] = None table_name: Optional[str] = None serialize_config: Dict = field(default_factory=dict) + conn_config: Dict = field(default_factory=dict) class BackendMixin(BaseBackendMixin): @@ -50,12 +51,15 @@ def _cursor(self) -> 'sqlite3.Cursor': def _init_storage( self, - docs: Optional['DocumentArraySourceType'] = None, - config: Optional[SqliteConfig] = None, + _docs: Optional['DocumentArraySourceType'] = None, + config: Optional[Union[SqliteConfig, Dict]] = None, ): if not config: config = SqliteConfig() + if isinstance(config, dict): + config = dataclass_from_dict(SqliteConfig, config) + from docarray import Document sqlite3.register_adapter( @@ -66,6 +70,7 @@ def _init_storage( ) _conn_kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES) + _conn_kwargs.update(config.conn_config) if config.connection is None: self._connection = sqlite3.connect( NamedTemporaryFile().name, **_conn_kwargs @@ -84,13 +89,12 @@ def _init_storage( if config.table_name is None else _sanitize_table_name(config.table_name) ) - self._persist = bool(config.table_name) initialize_table( self._table_name, self.__class__.__name__, self.schema_version, self._cursor ) self._connection.commit() - - if docs is not None: + self._config = config + if _docs is not None: self.clear() - self.extend(docs) + self.extend(_docs) diff --git a/docarray/array/storage/sqlite/seqlike.py b/docarray/array/storage/sqlite/seqlike.py index 17c5b718fd7..86647531fc8 100644 --- a/docarray/array/storage/sqlite/seqlike.py +++ b/docarray/array/storage/sqlite/seqlike.py @@ -1,4 +1,4 @@ -from typing import Iterator, Union, Iterable, MutableSequence, Optional +from typing import Iterator, Union, Iterable, MutableSequence, Optional, Sequence from .... import Document @@ -79,3 +79,26 @@ def __iter__(self) -> Iterator['Document']: ) for res in r: yield res[0] + + def __repr__(self): + return f'' + + def __bool__(self): + """To simulate ```l = []; if l: ...``` + + :return: returns true if the length of the array is larger than 0 + """ + return len(self) > 0 + + def __eq__(self, other): + """In sqlite backend, data are considered as identical if configs point to the same database source""" + return ( + type(self) is type(other) + and type(self._config) is type(other._config) + and self._config == other._config + ) + + def __add__(self, other: Union['Document', Sequence['Document']]): + v = type(self)(self, storage='sqlite') + v.extend(other) + return v diff --git a/docarray/helper.py b/docarray/helper.py index 507c1d5f118..811241e2da2 100644 --- a/docarray/helper.py +++ b/docarray/helper.py @@ -311,3 +311,13 @@ def get_compress_ctx(algorithm: Optional[str] = None, mode: str = 'wb'): else: compress_ctx = None return compress_ctx + + +def dataclass_from_dict(klass, dikt): + try: + fieldtypes = klass.__annotations__ + return klass(**{f: dataclass_from_dict(fieldtypes[f], dikt[f]) for f in dikt}) + except AttributeError: + if isinstance(dikt, (tuple, list)): + return [dataclass_from_dict(klass.__args__[0], f) for f in dikt] + return dikt diff --git a/docarray/types.py b/docarray/types.py index 7a4253bf7be..18ee8361c08 100644 --- a/docarray/types.py +++ b/docarray/types.py @@ -61,3 +61,8 @@ DocumentArraySingleAttributeType, DocumentArrayMultipleAttributeType, ] + + from .array.sqlite import DocumentArraySqlite + from .array.memory import DocumentArrayInMemory + + DocumentArrayLike = Union[DocumentArrayInMemory, DocumentArraySqlite] diff --git a/tests/unit/array/mixins/test_magic.py b/tests/unit/array/mixins/test_magic.py index 76ade193d7c..5ceb186ffcd 100644 --- a/tests/unit/array/mixins/test_magic.py +++ b/tests/unit/array/mixins/test_magic.py @@ -1,6 +1,6 @@ import pytest -from docarray import DocumentArray +from docarray import DocumentArray, Document N = 100 @@ -10,6 +10,11 @@ def da_and_dam(): return (da,) +@pytest.fixture +def docs(): + yield (Document(text=str(j)) for j in range(100)) + + @pytest.mark.parametrize('da', da_and_dam()) def test_iter_len_bool(da): j = 0 @@ -27,6 +32,17 @@ def test_repr(da): assert f'length={N}' in repr(da) +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_repr_str(docs, storage): + da = DocumentArray(docs, storage=storage) + print(da) + da.summary() + assert da + da.clear() + assert not da + print(da) + + @pytest.mark.parametrize('da', da_and_dam()) def test_iadd(da): oid = id(da) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 741d3da3ab7..5e489e6cbd1 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -6,7 +6,7 @@ @pytest.fixture def docs(): - yield (Document(text=j) for j in range(100)) + yield (Document(text=str(j)) for j in range(100)) @pytest.fixture From 62aa9da5ab7cc85454880470acd1b6a5fff8d7a8 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 25 Jan 2022 18:23:07 +0100 Subject: [PATCH 54/55] feat(sqlite): improve type hint --- docarray/array/base.py | 6 ++++-- tests/unit/array/test_advance_indexing.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docarray/array/base.py b/docarray/array/base.py index a64fa37602e..dd1721c362b 100644 --- a/docarray/array/base.py +++ b/docarray/array/base.py @@ -1,7 +1,9 @@ -from abc import ABC +from typing import MutableSequence +from .. import Document -class BaseDocumentArray(ABC): + +class BaseDocumentArray(MutableSequence[Document]): def __init__(self, *args, storage: str = 'memory', **kwargs): super().__init__() self._init_storage(*args, **kwargs) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 5e489e6cbd1..741d3da3ab7 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -6,7 +6,7 @@ @pytest.fixture def docs(): - yield (Document(text=str(j)) for j in range(100)) + yield (Document(text=j) for j in range(100)) @pytest.fixture From 778aae4d6727d9f352437e84ac2d296a2154f87b Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 25 Jan 2022 18:25:00 +0100 Subject: [PATCH 55/55] chore: bump version to 0.5 --- docarray/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/__init__.py b/docarray/__init__.py index 8bad5be9b57..ba5f8fa1b7b 100644 --- a/docarray/__init__.py +++ b/docarray/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.3.4' +__version__ = '0.5.0' from .document import Document from .array import DocumentArray