From ae41e6dbab1461ceac61bfea9e269ff48cc7526e Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Wed, 26 Jan 2022 14:47:08 +0100 Subject: [PATCH 01/41] test(sqlite): add more test to cover sqlite backend --- docarray/array/storage/sqlite/backend.py | 18 +++++++++-- tests/unit/array/mixins/test_content.py | 12 +++++--- tests/unit/array/mixins/test_embed.py | 3 +- tests/unit/array/mixins/test_empty.py | 6 +++- tests/unit/array/mixins/test_getset.py | 8 ++--- tests/unit/array/mixins/test_io.py | 33 +++++++++++++++----- tests/unit/array/mixins/test_magic.py | 4 ++- tests/unit/array/mixins/test_parallel.py | 13 +++----- tests/unit/array/mixins/test_plot.py | 37 ++++++++++++++-------- tests/unit/array/mixins/test_sample.py | 4 ++- tests/unit/array/mixins/test_text.py | 39 +++++++++++++----------- tests/unit/array/test_construct.py | 36 ++++++++++++++-------- tests/unit/array/test_ravel_unravel.py | 11 +++++-- tests/unit/array/test_sequence.py | 5 +-- tests/unit/test_pydantic.py | 6 ++-- 15 files changed, 155 insertions(+), 80 deletions(-) diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index f69200efdea..690277a450d 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -1,12 +1,16 @@ +import itertools import sqlite3 import warnings from dataclasses import dataclass, field from tempfile import NamedTemporaryFile from typing import ( + Generator, + Iterator, + Dict, + Sequence, Optional, TYPE_CHECKING, Union, - Dict, ) from .helper import initialize_table @@ -55,6 +59,7 @@ def _init_storage( self, _docs: Optional['DocumentArraySourceType'] = None, config: Optional[Union[SqliteConfig, Dict]] = None, + **kwargs, ): if not config: config = SqliteConfig() @@ -101,6 +106,15 @@ def _init_storage( ) self._connection.commit() self._config = config - if _docs is not None: + from ... import DocumentArray + + if _docs is None: + return + elif isinstance( + _docs, (DocumentArray, Sequence, Generator, Iterator, itertools.chain) + ): self.clear() self.extend(_docs) + else: + if isinstance(_docs, Document): + self.append(_docs) diff --git a/tests/unit/array/mixins/test_content.py b/tests/unit/array/mixins/test_content.py index 4059508b2b0..45686809b2e 100644 --- a/tests/unit/array/mixins/test_content.py +++ b/tests/unit/array/mixins/test_content.py @@ -2,9 +2,10 @@ import pytest from docarray import DocumentArray +from docarray.array.sqlite import DocumentArraySqlite -@pytest.mark.parametrize('cls', [DocumentArray]) +@pytest.mark.parametrize('cls', [DocumentArray, DocumentArraySqlite]) @pytest.mark.parametrize( 'content_attr', ['texts', 'embeddings', 'tensors', 'blobs', 'contents'] ) @@ -13,7 +14,7 @@ def test_content_empty_getter_return_none(cls, content_attr): assert getattr(da, content_attr) is None -@pytest.mark.parametrize('cls', [DocumentArray]) +@pytest.mark.parametrize('cls', [DocumentArray, DocumentArraySqlite]) @pytest.mark.parametrize( 'content_attr', [ @@ -30,7 +31,7 @@ def test_content_empty_setter(cls, content_attr): assert getattr(da, content_attr[0]) is None -@pytest.mark.parametrize('cls', [DocumentArray]) +@pytest.mark.parametrize('cls', [DocumentArray, DocumentArraySqlite]) @pytest.mark.parametrize( 'content_attr', [ @@ -51,8 +52,9 @@ def test_content_getter_setter(cls, content_attr): @pytest.mark.parametrize('da_len', [0, 1, 2]) -def test_content_empty(da_len): - da = DocumentArray.empty(da_len) +@pytest.mark.parametrize('cls', [DocumentArray, DocumentArraySqlite]) +def test_content_empty(da_len, cls): + da = cls.empty(da_len) assert not da.texts assert not da.contents assert not da.tensors diff --git a/tests/unit/array/mixins/test_embed.py b/tests/unit/array/mixins/test_embed.py index c3146a3675f..f13976be4e3 100644 --- a/tests/unit/array/mixins/test_embed.py +++ b/tests/unit/array/mixins/test_embed.py @@ -8,6 +8,7 @@ import torch from docarray import DocumentArray +from docarray.array.sqlite import DocumentArraySqlite random_embed_models = { 'keras': lambda: tf.keras.Sequential( @@ -40,7 +41,7 @@ @pytest.mark.parametrize('framework', ['onnx', 'keras', 'pytorch', 'paddle']) -@pytest.mark.parametrize('da', [DocumentArray]) +@pytest.mark.parametrize('da', [DocumentArray, DocumentArraySqlite]) @pytest.mark.parametrize('N', [2, 1000]) @pytest.mark.parametrize('batch_size', [1, 256]) @pytest.mark.parametrize('to_numpy', [True, False]) diff --git a/tests/unit/array/mixins/test_empty.py b/tests/unit/array/mixins/test_empty.py index 480eec09f65..12da4462916 100644 --- a/tests/unit/array/mixins/test_empty.py +++ b/tests/unit/array/mixins/test_empty.py @@ -1,7 +1,11 @@ +import pytest + from docarray import DocumentArray +from docarray.array.sqlite import DocumentArraySqlite -def test_empty_non_zero(): +@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArraySqlite]) +def test_empty_non_zero(da_cls): da = DocumentArray.empty(10) assert len(da) == 10 da = DocumentArray.empty() diff --git a/tests/unit/array/mixins/test_getset.py b/tests/unit/array/mixins/test_getset.py index cb8e4d0450a..d455bb352d2 100644 --- a/tests/unit/array/mixins/test_getset.py +++ b/tests/unit/array/mixins/test_getset.py @@ -6,6 +6,7 @@ from scipy.sparse import csr_matrix from docarray import DocumentArray, Document +from docarray.array.sqlite import DocumentArraySqlite from tests import random_docs rand_array = np.random.random([10, 3]) @@ -15,7 +16,8 @@ def da_and_dam(): rand_docs = random_docs(100) da = DocumentArray() da.extend(rand_docs) - return (da,) + das = DocumentArraySqlite(rand_docs) + return (da, das) @pytest.mark.parametrize( @@ -69,9 +71,7 @@ def test_tensors_getter_da(da): np.testing.assert_almost_equal(da.tensors, tensors) da.tensors = None - if hasattr(da, 'flush'): - da.flush() - assert not da.tensors + assert da.tensors is None @pytest.mark.parametrize('da', da_and_dam()) diff --git a/tests/unit/array/mixins/test_io.py b/tests/unit/array/mixins/test_io.py index 934c5585ce4..003d4754601 100644 --- a/tests/unit/array/mixins/test_io.py +++ b/tests/unit/array/mixins/test_io.py @@ -5,17 +5,19 @@ import pytest from docarray import DocumentArray +from docarray.array.sqlite import DocumentArraySqlite from tests import random_docs def da_and_dam(): da = DocumentArray(random_docs(100)) - return (da,) + das = DocumentArraySqlite(random_docs(100)) + return (da, das) @pytest.mark.slow @pytest.mark.parametrize('method', ['json', 'binary']) -@pytest.mark.parametrize('da', da_and_dam()) +@pytest.mark.parametrize('da', (da_and_dam()[0],)) def test_document_save_load(method, tmp_path, da): tmp_file = os.path.join(tmp_path, 'test') da.save(tmp_file, file_format=method) @@ -38,13 +40,13 @@ def test_da_csv_write(flatten_tags, tmp_path, da): assert len([v for v in fp]) == len(da) + 1 -@pytest.mark.parametrize('da', [DocumentArray]) +@pytest.mark.parametrize('da', [DocumentArray, DocumentArraySqlite]) def test_from_ndarray(da): _da = da.from_ndarray(np.random.random([10, 256])) assert len(_da) == 10 -@pytest.mark.parametrize('da', [DocumentArray]) +@pytest.mark.parametrize('da', [DocumentArray, DocumentArraySqlite]) def test_from_files(da): assert len(da.from_files(patterns='*.*', to_dataturi=True, size=1)) == 1 @@ -52,14 +54,19 @@ def test_from_files(da): cur_dir = os.path.dirname(os.path.abspath(__file__)) -@pytest.mark.parametrize('da', [DocumentArray]) +@pytest.mark.parametrize('da', [DocumentArray, DocumentArraySqlite]) def test_from_ndjson(da): with open(os.path.join(cur_dir, 'docs.jsonlines')) as fp: _da = da.from_ndjson(fp) assert len(_da) == 2 -@pytest.mark.parametrize('da_cls', [DocumentArray]) +@pytest.mark.parametrize( + 'da_cls', + [ + DocumentArray, + ], +) def test_from_to_pd_dataframe(da_cls): # simple assert len(da_cls.from_dataframe(da_cls.empty(2).to_dataframe())) == 2 @@ -74,7 +81,12 @@ def test_from_to_pd_dataframe(da_cls): assert da2[1].tags == {} -@pytest.mark.parametrize('da_cls', [DocumentArray]) +@pytest.mark.parametrize( + 'da_cls', + [ + DocumentArray, + ], +) def test_from_to_bytes(da_cls): # simple assert len(da_cls.load_binary(bytes(da_cls.empty(2)))) == 2 @@ -91,7 +103,12 @@ def test_from_to_bytes(da_cls): assert da2[1].tags == {} -@pytest.mark.parametrize('da_cls', [DocumentArray]) +@pytest.mark.parametrize( + 'da_cls', + [ + DocumentArray, + ], +) @pytest.mark.parametrize('show_progress', [True, False]) def test_push_pull_io(da_cls, show_progress): da1 = da_cls.empty(10) diff --git a/tests/unit/array/mixins/test_magic.py b/tests/unit/array/mixins/test_magic.py index 5ceb186ffcd..93f821010fb 100644 --- a/tests/unit/array/mixins/test_magic.py +++ b/tests/unit/array/mixins/test_magic.py @@ -1,13 +1,15 @@ import pytest from docarray import DocumentArray, Document +from docarray.array.sqlite import DocumentArraySqlite N = 100 def da_and_dam(): da = DocumentArray.empty(N) - return (da,) + dasq = DocumentArraySqlite.empty(N) + return (da, dasq) @pytest.fixture diff --git a/tests/unit/array/mixins/test_parallel.py b/tests/unit/array/mixins/test_parallel.py index 396a472b81d..ea54de3e548 100644 --- a/tests/unit/array/mixins/test_parallel.py +++ b/tests/unit/array/mixins/test_parallel.py @@ -2,6 +2,7 @@ import pytest from docarray import DocumentArray, Document +from docarray.array.sqlite import DocumentArraySqlite def foo(d: Document): @@ -25,9 +26,7 @@ def foo_batch(da: DocumentArray): ) @pytest.mark.parametrize( 'da_cls', - [ - DocumentArray, - ], + [DocumentArray, DocumentArraySqlite], ) @pytest.mark.parametrize('backend', ['process', 'thread']) @pytest.mark.parametrize('num_worker', [1, 2, None]) @@ -58,9 +57,7 @@ def test_parallel_map(pytestconfig, da_cls, backend, num_worker): ) @pytest.mark.parametrize( 'da_cls', - [ - DocumentArray, - ], + [DocumentArray, DocumentArraySqlite], ) @pytest.mark.parametrize('backend', ['thread']) @pytest.mark.parametrize('num_worker', [1, 2, None]) @@ -98,9 +95,7 @@ def test_parallel_map_batch(pytestconfig, da_cls, backend, num_worker, b_size): ) @pytest.mark.parametrize( 'da_cls', - [ - DocumentArray, - ], + [DocumentArray, DocumentArraySqlite], ) def test_map_lambda(pytestconfig, da_cls): da = da_cls.from_files(f'{pytestconfig.rootdir}/**/*.jpeg')[:10] diff --git a/tests/unit/array/mixins/test_plot.py b/tests/unit/array/mixins/test_plot.py index cdc3853ba4e..b49594396ca 100644 --- a/tests/unit/array/mixins/test_plot.py +++ b/tests/unit/array/mixins/test_plot.py @@ -6,10 +6,12 @@ import pytest from docarray import DocumentArray, Document +from docarray.array.sqlite import DocumentArraySqlite -def test_sprite_fail_tensor_success_uri(pytestconfig, tmpdir): - da = DocumentArray.from_files( +@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArraySqlite]) +def test_sprite_fail_tensor_success_uri(pytestconfig, tmpdir, da_cls): + da = da_cls.from_files( [ f'{pytestconfig.rootdir}/**/*.png', f'{pytestconfig.rootdir}/**/*.jpg', @@ -26,8 +28,9 @@ def test_sprite_fail_tensor_success_uri(pytestconfig, tmpdir): @pytest.mark.parametrize('image_source', ['tensor', 'uri']) -def test_sprite_image_generator(pytestconfig, tmpdir, image_source): - da = DocumentArray.from_files( +@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArraySqlite]) +def test_sprite_image_generator(pytestconfig, tmpdir, image_source, da_cls): + da = da_cls.from_files( [ f'{pytestconfig.rootdir}/**/*.png', f'{pytestconfig.rootdir}/**/*.jpg', @@ -48,7 +51,14 @@ def da_and_dam(): ] ) - return (doc_array,) + doc_arraysq = DocumentArraySqlite( + [ + Document(embedding=x, tags={'label': random.randint(0, 5)}) + for x in embeddings + ] + ) + + return (doc_array, doc_arraysq) @pytest.mark.parametrize('da', da_and_dam()) @@ -62,11 +72,12 @@ def test_plot_embeddings(da): assert config['embeddings'][0]['tensorShape'] == list(da.embeddings.shape) -def test_plot_embeddings_same_path(tmpdir): - da1 = DocumentArray.empty(100) +@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArraySqlite]) +def test_plot_embeddings_same_path(tmpdir, da_cls): + da1 = da_cls.empty(100) da1.embeddings = np.random.random([100, 5]) p1 = da1.plot_embeddings(start_server=False, path=tmpdir) - da2 = DocumentArray.empty(768) + da2 = da_cls.empty(768) da2.embeddings = np.random.random([768, 5]) p2 = da2.plot_embeddings(start_server=False, path=tmpdir) assert p1 == p2 @@ -76,8 +87,9 @@ def test_plot_embeddings_same_path(tmpdir): assert len(config['embeddings']) == 2 -def test_summary_homo_hetero(): - da = DocumentArray.empty(100) +@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArraySqlite]) +def test_summary_homo_hetero(da_cls): + da = da_cls.empty(100) da._get_attributes() da.summary() @@ -85,7 +97,8 @@ def test_summary_homo_hetero(): da.summary() -def test_empty_get_attributes(): - da = DocumentArray.empty(10) +@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArraySqlite]) +def test_empty_get_attributes(da_cls): + da = da_cls.empty(10) da[0].pop('id') print(da[:, 'id']) diff --git a/tests/unit/array/mixins/test_sample.py b/tests/unit/array/mixins/test_sample.py index ddb15dc4fe1..dabbe79155f 100644 --- a/tests/unit/array/mixins/test_sample.py +++ b/tests/unit/array/mixins/test_sample.py @@ -1,11 +1,13 @@ import pytest from docarray import DocumentArray +from docarray.array.sqlite import DocumentArraySqlite def da_and_dam(N): da = DocumentArray.empty(N) - return (da,) + dam = DocumentArraySqlite.empty(N) + return (da, dam) @pytest.mark.parametrize('da', da_and_dam(100)) diff --git a/tests/unit/array/mixins/test_text.py b/tests/unit/array/mixins/test_text.py index 3be8342b7c6..f9470f7e27f 100644 --- a/tests/unit/array/mixins/test_text.py +++ b/tests/unit/array/mixins/test_text.py @@ -2,6 +2,7 @@ import pytest from docarray import DocumentArray, Document +from docarray.array.sqlite import DocumentArraySqlite def da_and_dam(): @@ -13,7 +14,15 @@ def da_and_dam(): ] ) - return (da,) + das = DocumentArraySqlite( + [ + Document(text='hello'), + Document(text='hello world'), + Document(text='goodbye world!'), + ] + ) + + return (da, das) @pytest.mark.parametrize('min_freq', [1, 2, 3]) @@ -34,13 +43,11 @@ def test_da_vocabulary(da, min_freq): @pytest.mark.parametrize('test_docs', da_and_dam()) def test_da_text_to_tensor_non_max_len(test_docs): vocab = test_docs.get_vocabulary() - for d in test_docs: - d.convert_text_to_tensor(vocab) + test_docs.apply(lambda d: d.convert_text_to_tensor(vocab)) np.testing.assert_array_equal(test_docs[0].tensor, [2]) np.testing.assert_array_equal(test_docs[1].tensor, [2, 3]) np.testing.assert_array_equal(test_docs[2].tensor, [4, 3]) - for d in test_docs: - d.convert_tensor_to_text(vocab) + test_docs.apply(lambda d: d.convert_tensor_to_text(vocab)) assert test_docs[0].text == 'hello' assert test_docs[1].text == 'hello world' @@ -50,13 +57,13 @@ def test_da_text_to_tensor_non_max_len(test_docs): @pytest.mark.parametrize('test_docs', da_and_dam()) def test_da_text_to_tensor_max_len_3(test_docs): vocab = test_docs.get_vocabulary() - for d in test_docs: - d.convert_text_to_tensor(vocab, max_length=3) + test_docs.apply(lambda d: d.convert_text_to_tensor(vocab, max_length=3)) + np.testing.assert_array_equal(test_docs[0].tensor, [0, 0, 2]) np.testing.assert_array_equal(test_docs[1].tensor, [0, 2, 3]) np.testing.assert_array_equal(test_docs[2].tensor, [0, 4, 3]) - for d in test_docs: - d.convert_tensor_to_text(vocab) + + test_docs.apply(lambda d: d.convert_tensor_to_text(vocab)) assert test_docs[0].text == 'hello' assert test_docs[1].text == 'hello world' @@ -66,13 +73,13 @@ def test_da_text_to_tensor_max_len_3(test_docs): @pytest.mark.parametrize('test_docs', da_and_dam()) def test_da_text_to_tensor_max_len_1(test_docs): vocab = test_docs.get_vocabulary() - for d in test_docs: - d.convert_text_to_tensor(vocab, max_length=1) + test_docs.apply(lambda d: d.convert_text_to_tensor(vocab, max_length=1)) + np.testing.assert_array_equal(test_docs[0].tensor, [2]) np.testing.assert_array_equal(test_docs[1].tensor, [3]) np.testing.assert_array_equal(test_docs[2].tensor, [3]) - for d in test_docs: - d.convert_tensor_to_text(vocab) + + test_docs.apply(lambda d: d.convert_tensor_to_text(vocab)) assert test_docs[0].text == 'hello' assert test_docs[1].text == 'world' @@ -87,12 +94,10 @@ def test_convert_text_tensor_random_text(da): vocab = da.get_vocabulary() # encoding - for d in da: - d.convert_text_to_tensor(vocab, max_length=10) + da.apply(lambda d: d.convert_text_to_tensor(vocab, max_length=10)) # decoding - for d in da: - d.convert_tensor_to_text(vocab) + da.apply(lambda d: d.convert_tensor_to_text(vocab)) assert texts assert da.texts == texts diff --git a/tests/unit/array/test_construct.py b/tests/unit/array/test_construct.py index 302a6359c2c..058c474ef3d 100644 --- a/tests/unit/array/test_construct.py +++ b/tests/unit/array/test_construct.py @@ -1,9 +1,10 @@ import pytest from docarray import Document, DocumentArray +from docarray.array.sqlite import DocumentArraySqlite -@pytest.mark.parametrize('da_cls', [DocumentArray]) +@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArraySqlite]) def test_construct_docarray(da_cls): da = da_cls() assert len(da) == 0 @@ -24,39 +25,48 @@ def test_construct_docarray(da_cls): assert len(da1) == 10 -@pytest.mark.parametrize('da_cls', [DocumentArray]) +@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArraySqlite]) @pytest.mark.parametrize('is_copy', [True, False]) def test_docarray_copy_singleton(da_cls, is_copy): d = Document() da = da_cls(d, copy=is_copy) d.id = 'hello' - if is_copy: - assert da[0].id != 'hello' + if da_cls == DocumentArray: + if is_copy: + assert da[0].id != 'hello' + else: + assert da[0].id == 'hello' else: - assert da[0].id == 'hello' + assert da[0].id != 'hello' -@pytest.mark.parametrize('da_cls', [DocumentArray]) +@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArraySqlite]) @pytest.mark.parametrize('is_copy', [True, False]) def test_docarray_copy_da(da_cls, is_copy): d1 = Document() d2 = Document() da = da_cls([d1, d2], copy=is_copy) d1.id = 'hello' - if is_copy: - assert da[0].id != 'hello' + if da_cls == DocumentArray: + if is_copy: + assert da[0].id != 'hello' + else: + assert da[0].id == 'hello' else: - assert da[0].id == 'hello' + assert da[0] != 'hello' -@pytest.mark.parametrize('da_cls', [DocumentArray]) +@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArraySqlite]) @pytest.mark.parametrize('is_copy', [True, False]) def test_docarray_copy_list(da_cls, is_copy): d1 = Document() d2 = Document() da = da_cls([d1, d2], copy=is_copy) d1.id = 'hello' - if is_copy: - assert da[0].id != 'hello' + if da_cls == DocumentArray: + if is_copy: + assert da[0].id != 'hello' + else: + assert da[0].id == 'hello' else: - assert da[0].id == 'hello' + assert da[0] != 'hello' diff --git a/tests/unit/array/test_ravel_unravel.py b/tests/unit/array/test_ravel_unravel.py index 3563f7caf2c..f6b70613bf8 100644 --- a/tests/unit/array/test_ravel_unravel.py +++ b/tests/unit/array/test_ravel_unravel.py @@ -8,6 +8,7 @@ from scipy.sparse import csr_matrix, coo_matrix, bsr_matrix, csc_matrix from docarray import DocumentArray, Document +from docarray.array.sqlite import DocumentArraySqlite def get_ndarrays_for_ravel(): @@ -29,8 +30,14 @@ def get_ndarrays_for_ravel(): @pytest.mark.parametrize('ndarray_val, is_sparse', get_ndarrays_for_ravel()) @pytest.mark.parametrize('attr', ['embeddings', 'tensors']) -def test_ravel_embeddings_tensors(ndarray_val, attr, is_sparse): - da = DocumentArray.empty(10) +@pytest.mark.parametrize( + 'da_cls', + [ + DocumentArray, + ], +) +def test_ravel_embeddings_tensors(ndarray_val, attr, is_sparse, da_cls): + da = da_cls.empty(10) setattr(da, attr, ndarray_val) ndav = getattr(da, attr) diff --git a/tests/unit/array/test_sequence.py b/tests/unit/array/test_sequence.py index 6ba6679f936..4055eb38319 100644 --- a/tests/unit/array/test_sequence.py +++ b/tests/unit/array/test_sequence.py @@ -1,9 +1,10 @@ import pytest from docarray import Document, DocumentArray +from docarray.array.sqlite import DocumentArraySqlite -@pytest.mark.parametrize('da_cls', [DocumentArray]) +@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArraySqlite]) def test_insert(da_cls): da = da_cls() assert not len(da) @@ -14,7 +15,7 @@ def test_insert(da_cls): assert da[1].text == 'hello' -@pytest.mark.parametrize('da_cls', [DocumentArray]) +@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArraySqlite]) def test_append_extend(da_cls): da = da_cls() da.append(Document()) diff --git a/tests/unit/test_pydantic.py b/tests/unit/test_pydantic.py index e1eb9ae3ce6..f04c1f2485b 100644 --- a/tests/unit/test_pydantic.py +++ b/tests/unit/test_pydantic.py @@ -11,10 +11,12 @@ from docarray import DocumentArray, Document from docarray.document.pydantic_model import PydanticDocument, PydanticDocumentArray from docarray.score import NamedScore +from docarray.array.sqlite import DocumentArraySqlite -def test_pydantic_doc_da(pytestconfig): - da = DocumentArray.from_files( +@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArraySqlite]) +def test_pydantic_doc_da(pytestconfig, da_cls): + da = da_cls.from_files( [ f'{pytestconfig.rootdir}/**/*.png', f'{pytestconfig.rootdir}/**/*.jpg', From b208bc4dcf29346431c82cbfddf15ed5897392a5 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 26 Jan 2022 15:25:38 +0100 Subject: [PATCH 02/41] fix: adapt embedding setters for storage backends --- docarray/array/mixins/content.py | 2 +- docarray/math/ndarray.py | 16 ++++++---------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/docarray/array/mixins/content.py b/docarray/array/mixins/content.py index 73be9cf53c8..595215fbf8a 100644 --- a/docarray/array/mixins/content.py +++ b/docarray/array/mixins/content.py @@ -39,7 +39,7 @@ def embeddings(self, value: 'ArrayType'): if value is None: for d in self: - d.embedding = None + self[d.id, 'embedding'] = None else: emb_shape0 = _get_len(value) self._check_length(emb_shape0) diff --git a/docarray/math/ndarray.py b/docarray/math/ndarray.py index a4ef24a9f03..af4b7887272 100644 --- a/docarray/math/ndarray.py +++ b/docarray/math/ndarray.py @@ -4,7 +4,7 @@ if TYPE_CHECKING: from ..types import ArrayType - from .. import Document + from .. import Document, DocumentArray def unravel(docs: Sequence['Document'], field: str) -> Optional['ArrayType']: @@ -48,15 +48,13 @@ def unravel(docs: Sequence['Document'], field: str) -> Optional['ArrayType']: return cls_type(scipy.sparse.vstack(all_fields)) -def ravel(value: 'ArrayType', docs: Sequence['Document'], field: str) -> None: +def ravel(value: 'ArrayType', docs: 'DocumentArray', field: str) -> None: """Ravel :attr:`value` into ``doc.field`` of each documents :param docs: the docs to set :param field: the field of the doc to set :param value: the value to be set on ``doc.field`` """ - from .. import DocumentArray - use_get_row = False if hasattr(value, 'getformat'): # for scipy only @@ -70,19 +68,17 @@ def ravel(value: 'ArrayType', docs: Sequence['Document'], field: str) -> None: if use_get_row: emb_shape0 = value.shape[0] - for d, j in zip(docs, range(emb_shape0)): + for i, (d, j) in enumerate(zip(docs, range(emb_shape0))): row = getattr(value.getrow(j), f'to{sp_format}')() - setattr(d, field, row) + docs[d.id, field] = row elif isinstance(value, (list, tuple)): for d, j in zip(docs, value): - setattr(d, field, j) + docs[d.id, field] = j else: emb_shape0 = value.shape[0] for i, (d, j) in enumerate(zip(docs, range(emb_shape0))): - setattr(d, field, value[j, ...]) - if isinstance(docs, DocumentArray): - docs._set_doc_by_id(d.id, d) + docs[d.id, field] = value[j, ...] def get_array_type(array: 'ArrayType') -> Tuple[str, bool]: From 863d864f6138365a557fc2b09faad73233a24f0b Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 26 Jan 2022 15:30:08 +0100 Subject: [PATCH 03/41] test: cover embeddings setter --- tests/unit/array/mixins/test_content.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/unit/array/mixins/test_content.py b/tests/unit/array/mixins/test_content.py index 45686809b2e..21f4825e05f 100644 --- a/tests/unit/array/mixins/test_content.py +++ b/tests/unit/array/mixins/test_content.py @@ -68,3 +68,12 @@ def test_content_empty(da_len, cls): assert da.texts == ['hello'] * da_len assert not da.tensors assert not da.blobs + + +@pytest.mark.parametrize('da_len', [0, 1, 2]) +@pytest.mark.parametrize('cls', [DocumentArray, DocumentArraySqlite]) +def test_embeddings_setter(da_len, cls): + da = cls.empty(da_len) + da.embeddings = np.random.rand(da_len, 5) + for doc in da: + assert doc.embedding.shape == (5,) From c8c0c9964d96d7a85c2a3be6a1af295a37e5b464 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 26 Jan 2022 15:33:34 +0100 Subject: [PATCH 04/41] fix: texts setter --- docarray/array/mixins/content.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/array/mixins/content.py b/docarray/array/mixins/content.py index 595215fbf8a..9ad91f6b161 100644 --- a/docarray/array/mixins/content.py +++ b/docarray/array/mixins/content.py @@ -97,12 +97,12 @@ def texts(self, value: Sequence[str]): """ if value is None: for d in self: - d.text = None + self[d.id, 'text'] = None else: self._check_length(len(value)) for doc, text in zip(self, value): - doc.text = text + self[doc.id, 'text'] = text @property def blobs(self) -> Optional[List[bytes]]: From 26b5edc6852a3e23cee5f61ce9d3a886a290bd6e Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 26 Jan 2022 15:36:09 +0100 Subject: [PATCH 05/41] fix: tensors and blob setters --- docarray/array/mixins/content.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docarray/array/mixins/content.py b/docarray/array/mixins/content.py index 9ad91f6b161..9a3b3fd99d4 100644 --- a/docarray/array/mixins/content.py +++ b/docarray/array/mixins/content.py @@ -71,7 +71,7 @@ def tensors(self, value: 'ArrayType'): if value is None: for d in self: - d.tensor = None + self[d.id, 'tensor'] = None else: tensors_shape0 = _get_len(value) self._check_length(tensors_shape0) @@ -124,12 +124,12 @@ def blobs(self, value: List[bytes]): if value is None: for d in self: - d.blob = None + self[d.id, 'blob'] = None else: self._check_length(len(value)) for doc, blob in zip(self, value): - doc.blob = blob + self[doc.id, 'blob'] = blob @property def contents(self) -> Optional[Union[Sequence['DocumentContentType'], 'ArrayType']]: From 86839012a351b050e8f4a5b782819c29154a1749 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 26 Jan 2022 15:38:34 +0100 Subject: [PATCH 06/41] fix: linting --- docarray/array/mixins/content.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/mixins/content.py b/docarray/array/mixins/content.py index 9a3b3fd99d4..f072b3759ba 100644 --- a/docarray/array/mixins/content.py +++ b/docarray/array/mixins/content.py @@ -39,7 +39,7 @@ def embeddings(self, value: 'ArrayType'): if value is None: for d in self: - self[d.id, 'embedding'] = None + self[d.id, 'embedding'] = None else: emb_shape0 = _get_len(value) self._check_length(emb_shape0) From a3ca29f5bf91d203449e223f03fcf6283f1a48e9 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 26 Jan 2022 17:00:49 +0100 Subject: [PATCH 07/41] fix: embed for sqlite backend --- docarray/array/mixins/embed.py | 8 ++++++-- tests/unit/array/mixins/test_embed.py | 24 ++++++++++++------------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/docarray/array/mixins/embed.py b/docarray/array/mixins/embed.py index 1eb4594bc92..f74dc60d734 100644 --- a/docarray/array/mixins/embed.py +++ b/docarray/array/mixins/embed.py @@ -58,11 +58,15 @@ def _set_embeddings_torch( embed_model = embed_model.to(device) is_training_before = embed_model.training embed_model.eval() + length = len(self) with torch.inference_mode(): - for b in self.batch(batch_size): + for i, b in enumerate(self.batch(batch_size)): batch_inputs = torch.tensor(b.tensors, device=device) r = embed_model(batch_inputs).cpu().detach() - b.embeddings = r.numpy() if to_numpy else r + if to_numpy: + r = r.numpy() + self[i*batch_size: min((i + 1)*batch_size, length), 'embedding'] = r + if is_training_before: embed_model.train() diff --git a/tests/unit/array/mixins/test_embed.py b/tests/unit/array/mixins/test_embed.py index f13976be4e3..b54bbe9988f 100644 --- a/tests/unit/array/mixins/test_embed.py +++ b/tests/unit/array/mixins/test_embed.py @@ -1,13 +1,11 @@ import os import numpy as np -import onnxruntime -import paddle import pytest -import tensorflow as tf import torch from docarray import DocumentArray +from docarray.array.memory import DocumentArrayInMemory from docarray.array.sqlite import DocumentArraySqlite random_embed_models = { @@ -59,14 +57,16 @@ def test_embedding_on_random_network(framework, da, N, batch_size, to_numpy): # reset docs.embeddings = np.random.random([N, 128]).astype(np.float32) - # try it again, it should yield the same result - docs.embed(embed_model, batch_size=batch_size, to_numpy=to_numpy) - np.testing.assert_array_almost_equal(docs.embeddings, embed1) + # docs[a: b].embed is only supported for DocumentArrayInMemory + if isinstance(da, DocumentArrayInMemory): + # try it again, it should yield the same result + docs.embed(embed_model, batch_size=batch_size, to_numpy=to_numpy) + np.testing.assert_array_almost_equal(docs.embeddings, embed1) - # reset - docs.embeddings = np.random.random([N, 128]).astype(np.float32) + # reset + docs.embeddings = np.random.random([N, 128]).astype(np.float32) - # now do this one by one - docs[: int(N / 2)].embed(embed_model, batch_size=batch_size, to_numpy=to_numpy) - docs[-int(N / 2) :].embed(embed_model, batch_size=batch_size, to_numpy=to_numpy) - np.testing.assert_array_almost_equal(docs.embeddings, embed1) + # now do this one by one + docs[: int(N / 2)].embed(embed_model, batch_size=batch_size, to_numpy=to_numpy) + docs[-int(N / 2) :].embed(embed_model, batch_size=batch_size, to_numpy=to_numpy) + np.testing.assert_array_almost_equal(docs.embeddings, embed1) From aef897985e957dadd450a9afaed71ef707ec9f82 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 26 Jan 2022 17:26:20 +0100 Subject: [PATCH 08/41] refactor: delegate to __setitem__ in content setters --- docarray/array/mixins/content.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/docarray/array/mixins/content.py b/docarray/array/mixins/content.py index f072b3759ba..952798e011a 100644 --- a/docarray/array/mixins/content.py +++ b/docarray/array/mixins/content.py @@ -38,8 +38,7 @@ def embeddings(self, value: 'ArrayType'): """ if value is None: - for d in self: - self[d.id, 'embedding'] = None + self[:, 'embedding'] = [None] * len(self) else: emb_shape0 = _get_len(value) self._check_length(emb_shape0) @@ -70,8 +69,7 @@ def tensors(self, value: 'ArrayType'): """ if value is None: - for d in self: - self[d.id, 'tensor'] = None + self[:, 'tensor'] = [None] * len(self) else: tensors_shape0 = _get_len(value) self._check_length(tensors_shape0) @@ -96,13 +94,11 @@ def texts(self, value: Sequence[str]): number of Documents """ if value is None: - for d in self: - self[d.id, 'text'] = None + self[:, 'text'] = [None] * len(self) else: self._check_length(len(value)) - for doc, text in zip(self, value): - self[doc.id, 'text'] = text + self[:, 'text'] = value @property def blobs(self) -> Optional[List[bytes]]: @@ -123,8 +119,7 @@ def blobs(self, value: List[bytes]): """ if value is None: - for d in self: - self[d.id, 'blob'] = None + self[:, 'blob'] = [None] * len(self) else: self._check_length(len(value)) From 8bae8d5fbaa600ab726cda541576fd24a86e83b5 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 26 Jan 2022 17:26:36 +0100 Subject: [PATCH 09/41] fix: linting --- docarray/array/mixins/embed.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docarray/array/mixins/embed.py b/docarray/array/mixins/embed.py index f74dc60d734..9ece337c628 100644 --- a/docarray/array/mixins/embed.py +++ b/docarray/array/mixins/embed.py @@ -65,7 +65,9 @@ def _set_embeddings_torch( r = embed_model(batch_inputs).cpu().detach() if to_numpy: r = r.numpy() - self[i*batch_size: min((i + 1)*batch_size, length), 'embedding'] = r + self[ + i * batch_size : min((i + 1) * batch_size, length), 'embedding' + ] = r if is_training_before: embed_model.train() From be13b62204ddcd2328394d07542e0d55f681aaea Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 26 Jan 2022 17:36:05 +0100 Subject: [PATCH 10/41] test: cover set attributes with size 1 --- tests/unit/array/test_advance_indexing.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 741d3da3ab7..241cefca559 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -193,14 +193,15 @@ def test_path_syntax_indexing(storage): assert len(da['@r:1cc,m']) == 1 * 5 * 3 + 3 * 7 +@pytest.mark.parametrize('size', [1, 5]) @pytest.mark.parametrize('storage', ['memory', 'sqlite']) -def test_attribute_indexing(storage): +def test_attribute_indexing(storage, size): da = DocumentArray(storage=storage) - da.extend(DocumentArray.empty(10)) + da.extend(DocumentArray.empty(size)) for v in da[:, 'id']: assert v - da[:, 'mime_type'] = [f'type {j}' for j in range(10)] + da[:, 'mime_type'] = [f'type {j}' for j in range(size)] for v in da[:, 'mime_type']: assert v del da[:, 'mime_type'] @@ -208,8 +209,8 @@ def test_attribute_indexing(storage): assert not v da[:, ['text', 'mime_type']] = [ - [f'hello {j}' for j in range(10)], - [f'type {j}' for j in range(10)], + [f'hello {j}' for j in range(size)], + [f'type {j}' for j in range(size)], ] da.summary() From bb97b2ab98c177cbf19e3575a8592ab33477c297 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Wed, 26 Jan 2022 17:36:22 +0100 Subject: [PATCH 11/41] fix: fix set attributes with size 1 --- docarray/array/mixins/setitem.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index 245631a1ca3..03e0fedcf71 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -121,11 +121,8 @@ def __setitem__( for _d in _docs: self._set_doc_by_id(_d.id, _d) else: - if len(_docs) == 1: - self._set_doc_attr_by_id(_docs[0].id, _a, _v) - else: - for _d, _vv in zip(_docs, _v): - self._set_doc_attr_by_id(_d.id, _a, _vv) + for _d, _vv in zip(_docs, _v): + self._set_doc_attr_by_id(_d.id, _a, _vv) elif isinstance(index[0], bool): if len(index) != len(self): raise IndexError( From caa2b9d1a541ddaf3e9b27e04a0f0c8bd7db2de5 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Thu, 27 Jan 2022 09:17:21 +0100 Subject: [PATCH 12/41] feat: add batching by id --- docarray/array/mixins/embed.py | 30 ++++++-------- docarray/array/mixins/group.py | 57 +++++++++++++++++++++++++++ tests/unit/array/mixins/test_embed.py | 7 +++- 3 files changed, 76 insertions(+), 18 deletions(-) diff --git a/docarray/array/mixins/embed.py b/docarray/array/mixins/embed.py index 9ece337c628..a27b9d99667 100644 --- a/docarray/array/mixins/embed.py +++ b/docarray/array/mixins/embed.py @@ -42,9 +42,9 @@ def _set_embeddings_keras( device = tf.device('/GPU:0') if device == 'cuda' else tf.device('/CPU:0') with device: - for b in self.batch(batch_size): - r = embed_model(b.tensors, training=False) - b.embeddings = r.numpy() if to_numpy else r + for b_ids in self.batch_ids(batch_size): + r = embed_model(self[b_ids,'tensor'], training=False) + self[b_ids,'embedding'] = r.numpy() if to_numpy else r def _set_embeddings_torch( self: 'T', @@ -60,14 +60,10 @@ def _set_embeddings_torch( embed_model.eval() length = len(self) with torch.inference_mode(): - for i, b in enumerate(self.batch(batch_size)): - batch_inputs = torch.tensor(b.tensors, device=device) + for b_ids in self.batch_ids(batch_size): + batch_inputs = torch.tensor(self[b_ids,'tensor'], device=device) r = embed_model(batch_inputs).cpu().detach() - if to_numpy: - r = r.numpy() - self[ - i * batch_size : min((i + 1) * batch_size, length), 'embedding' - ] = r + self[b_ids, 'embedding'] = r.numpy() if to_numpy else r if is_training_before: embed_model.train() @@ -84,10 +80,11 @@ def _set_embeddings_paddle( is_training_before = embed_model.training embed_model.to(device=device) embed_model.eval() - for b in self.batch(batch_size): - batch_inputs = paddle.to_tensor(b.tensors, place=device) + for b_ids in self.batch_ids(batch_size): + batch_inputs = paddle.to_tensor(self[b_ids,'tensor'], place=device) r = embed_model(batch_inputs) - b.embeddings = r.numpy() if to_numpy else r + self[b_ids, 'embedding'] = r.numpy() if to_numpy else r + if is_training_before: embed_model.train() @@ -109,12 +106,11 @@ def _set_embeddings_onnx( f'Your installed `onnxruntime` supports `{support_device}`, but you give {device}' ) - for b in self.batch(batch_size): - b.embeddings = embed_model.run( - None, {embed_model.get_inputs()[0].name: b.tensors} + for b_ids in self.batch_ids(batch_size): + self[b_ids, 'embedding'] = embed_model.run( + None, {embed_model.get_inputs()[0].name: self[b_ids,'tensor']} )[0] - def get_framework(dnn_model) -> str: """Return the framework that powers a DNN model. diff --git a/docarray/array/mixins/group.py b/docarray/array/mixins/group.py index 17cb179ab33..74bdaa3b7ef 100644 --- a/docarray/array/mixins/group.py +++ b/docarray/array/mixins/group.py @@ -64,3 +64,60 @@ def batch( for i in range(n_batches): yield self[ix[i * batch_size : (i + 1) * batch_size]] + + + def batch_indices( + self, + batch_size: int, + shuffle: bool = False, + ) -> Generator[list, None, None]: + """ + Creates a `Generator` that yields `list` of size `batch_size` until `docs` is fully traversed along + the `traversal_path`. Note, that the last batch might be smaller than `batch_size`. + + :param batch_size: Size of each generated batch (except the last one, which might be smaller, default: 32) + :param shuffle: If set, shuffle the Documents before dividing into minibatches. + :yield: a Generator of `np.ndarray`, each in the length of `batch_size` + """ + + if not (isinstance(batch_size, int) and batch_size > 0): + raise ValueError('`batch_size` should be a positive integer') + + N = len(self) + ix = list(range(N)) + n_batches = int(np.ceil(N / batch_size)) + + if shuffle: + random.shuffle(ix) + + for i in range(n_batches): + yield ix[i * batch_size : (i + 1) * batch_size] + + + def batch_ids( + self, + batch_size: int, + shuffle: bool = False, + ) -> Generator[list, None, None]: + """ + Creates a `Generator` that yields `lists of ids` of size `batch_size` until `docs` is fully traversed along + the `traversal_path`. Note, that the last batch might be smaller than `batch_size`. + + :param batch_size: Size of each generated batch (except the last one, which might be smaller, default: 32) + :param shuffle: If set, shuffle the Documents before dividing into minibatches. + :yield: a Generator of `np.ndarray`, each in the length of `batch_size` + """ + + if not (isinstance(batch_size, int) and batch_size > 0): + raise ValueError('`batch_size` should be a positive integer') + + N = len(self) + ix = self[:,'id'] + n_batches = int(np.ceil(N / batch_size)) + + if shuffle: + random.shuffle(ix) + + for i in range(n_batches): + yield ix[i * batch_size : (i + 1) * batch_size] + diff --git a/tests/unit/array/mixins/test_embed.py b/tests/unit/array/mixins/test_embed.py index b54bbe9988f..f4109360e07 100644 --- a/tests/unit/array/mixins/test_embed.py +++ b/tests/unit/array/mixins/test_embed.py @@ -1,8 +1,12 @@ import os import numpy as np +import tensorflow as tf import pytest import torch +import paddle +import onnx +import onnxruntime from docarray import DocumentArray from docarray.array.memory import DocumentArrayInMemory @@ -42,7 +46,7 @@ @pytest.mark.parametrize('da', [DocumentArray, DocumentArraySqlite]) @pytest.mark.parametrize('N', [2, 1000]) @pytest.mark.parametrize('batch_size', [1, 256]) -@pytest.mark.parametrize('to_numpy', [True, False]) +@pytest.mark.parametrize('to_numpy', [True]) def test_embedding_on_random_network(framework, da, N, batch_size, to_numpy): docs = da.empty(N) docs.tensors = np.random.random([N, 128]).astype(np.float32) @@ -52,6 +56,7 @@ def test_embedding_on_random_network(framework, da, N, batch_size, to_numpy): r = docs.embeddings if hasattr(r, 'numpy'): r = r.numpy() + embed1 = r.copy() # reset From bf0e769219a20ddd2e42871a03b781ea9cede6c0 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Thu, 27 Jan 2022 09:18:02 +0100 Subject: [PATCH 13/41] feat: add batching by id --- docarray/array/mixins/embed.py | 11 ++++++----- docarray/array/mixins/group.py | 5 +---- tests/unit/array/mixins/test_embed.py | 3 ++- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/docarray/array/mixins/embed.py b/docarray/array/mixins/embed.py index a27b9d99667..c5f353f9db1 100644 --- a/docarray/array/mixins/embed.py +++ b/docarray/array/mixins/embed.py @@ -43,8 +43,8 @@ def _set_embeddings_keras( device = tf.device('/GPU:0') if device == 'cuda' else tf.device('/CPU:0') with device: for b_ids in self.batch_ids(batch_size): - r = embed_model(self[b_ids,'tensor'], training=False) - self[b_ids,'embedding'] = r.numpy() if to_numpy else r + r = embed_model(self[b_ids, 'tensor'], training=False) + self[b_ids, 'embedding'] = r.numpy() if to_numpy else r def _set_embeddings_torch( self: 'T', @@ -61,7 +61,7 @@ def _set_embeddings_torch( length = len(self) with torch.inference_mode(): for b_ids in self.batch_ids(batch_size): - batch_inputs = torch.tensor(self[b_ids,'tensor'], device=device) + batch_inputs = torch.tensor(self[b_ids, 'tensor'], device=device) r = embed_model(batch_inputs).cpu().detach() self[b_ids, 'embedding'] = r.numpy() if to_numpy else r @@ -81,7 +81,7 @@ def _set_embeddings_paddle( embed_model.to(device=device) embed_model.eval() for b_ids in self.batch_ids(batch_size): - batch_inputs = paddle.to_tensor(self[b_ids,'tensor'], place=device) + batch_inputs = paddle.to_tensor(self[b_ids, 'tensor'], place=device) r = embed_model(batch_inputs) self[b_ids, 'embedding'] = r.numpy() if to_numpy else r @@ -108,9 +108,10 @@ def _set_embeddings_onnx( for b_ids in self.batch_ids(batch_size): self[b_ids, 'embedding'] = embed_model.run( - None, {embed_model.get_inputs()[0].name: self[b_ids,'tensor']} + None, {embed_model.get_inputs()[0].name: self[b_ids, 'tensor']} )[0] + def get_framework(dnn_model) -> str: """Return the framework that powers a DNN model. diff --git a/docarray/array/mixins/group.py b/docarray/array/mixins/group.py index 74bdaa3b7ef..fd98231c107 100644 --- a/docarray/array/mixins/group.py +++ b/docarray/array/mixins/group.py @@ -65,7 +65,6 @@ def batch( for i in range(n_batches): yield self[ix[i * batch_size : (i + 1) * batch_size]] - def batch_indices( self, batch_size: int, @@ -93,7 +92,6 @@ def batch_indices( for i in range(n_batches): yield ix[i * batch_size : (i + 1) * batch_size] - def batch_ids( self, batch_size: int, @@ -112,7 +110,7 @@ def batch_ids( raise ValueError('`batch_size` should be a positive integer') N = len(self) - ix = self[:,'id'] + ix = self[:, 'id'] n_batches = int(np.ceil(N / batch_size)) if shuffle: @@ -120,4 +118,3 @@ def batch_ids( for i in range(n_batches): yield ix[i * batch_size : (i + 1) * batch_size] - diff --git a/tests/unit/array/mixins/test_embed.py b/tests/unit/array/mixins/test_embed.py index f4109360e07..5374b79ca46 100644 --- a/tests/unit/array/mixins/test_embed.py +++ b/tests/unit/array/mixins/test_embed.py @@ -43,7 +43,8 @@ @pytest.mark.parametrize('framework', ['onnx', 'keras', 'pytorch', 'paddle']) -@pytest.mark.parametrize('da', [DocumentArray, DocumentArraySqlite]) +# @pytest.mark.parametrize('da', [DocumentArray, DocumentArraySqlite]) +@pytest.mark.parametrize('da', [DocumentArraySqlite]) @pytest.mark.parametrize('N', [2, 1000]) @pytest.mark.parametrize('batch_size', [1, 256]) @pytest.mark.parametrize('to_numpy', [True]) From bccd9aba2a249be11af8a4b5e9e75b32369687a4 Mon Sep 17 00:00:00 2001 From: David Buchaca Prats Date: Thu, 27 Jan 2022 11:41:59 +0100 Subject: [PATCH 14/41] fix: change protocol to protobuf in sqlite --- docarray/array/mixins/embed.py | 1 - docarray/array/storage/sqlite/backend.py | 2 +- tests/unit/array/mixins/test_embed.py | 5 ++--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docarray/array/mixins/embed.py b/docarray/array/mixins/embed.py index c5f353f9db1..e91e2c2f7bd 100644 --- a/docarray/array/mixins/embed.py +++ b/docarray/array/mixins/embed.py @@ -58,7 +58,6 @@ def _set_embeddings_torch( embed_model = embed_model.to(device) is_training_before = embed_model.training embed_model.eval() - length = len(self) with torch.inference_mode(): for b_ids in self.batch_ids(batch_size): batch_inputs = torch.tensor(self[b_ids, 'tensor'], device=device) diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index 690277a450d..2de14f16d4b 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -34,7 +34,7 @@ def _sanitize_table_name(table_name: str) -> str: class SqliteConfig: connection: Optional[Union[str, 'sqlite3.Connection']] = None table_name: Optional[str] = None - serialize_config: Dict = field(default_factory=dict) + serialize_config: Dict = field(default_factory=lambda: {'protocol': 'protobuf'}) conn_config: Dict = field(default_factory=dict) journal_mode: str = 'DELETE' synchronous: str = 'OFF' diff --git a/tests/unit/array/mixins/test_embed.py b/tests/unit/array/mixins/test_embed.py index 5374b79ca46..442cdb1ea8c 100644 --- a/tests/unit/array/mixins/test_embed.py +++ b/tests/unit/array/mixins/test_embed.py @@ -43,11 +43,10 @@ @pytest.mark.parametrize('framework', ['onnx', 'keras', 'pytorch', 'paddle']) -# @pytest.mark.parametrize('da', [DocumentArray, DocumentArraySqlite]) -@pytest.mark.parametrize('da', [DocumentArraySqlite]) +@pytest.mark.parametrize('da', [DocumentArray, DocumentArraySqlite]) @pytest.mark.parametrize('N', [2, 1000]) @pytest.mark.parametrize('batch_size', [1, 256]) -@pytest.mark.parametrize('to_numpy', [True]) +@pytest.mark.parametrize('to_numpy', [True, False]) def test_embedding_on_random_network(framework, da, N, batch_size, to_numpy): docs = da.empty(N) docs.tensors = np.random.random([N, 128]).astype(np.float32) From 99da09f2962aea552b7a91e3315b466f79082c63 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Thu, 27 Jan 2022 13:12:30 +0100 Subject: [PATCH 15/41] fix: fix setter by sequences --- docarray/array/mixins/setitem.py | 8 ++++++-- docarray/array/storage/base/getsetdel.py | 2 +- docarray/array/storage/memory/getsetdel.py | 7 ++++++- tests/unit/array/mixins/test_getset.py | 22 ++++++++++++++-------- 4 files changed, 27 insertions(+), 12 deletions(-) diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index 03e0fedcf71..cf50648104e 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -121,8 +121,12 @@ def __setitem__( for _d in _docs: self._set_doc_by_id(_d.id, _d) else: - for _d, _vv in zip(_docs, _v): - self._set_doc_attr_by_id(_d.id, _a, _vv) + if not isinstance(_v, (list, tuple)): + for _d in _docs: + self._set_doc_attr_by_id(_d.id, _a, _v) + else: + for _d, _vv in zip(_docs, _v): + self._set_doc_attr_by_id(_d.id, _a, _vv) elif isinstance(index[0], bool): if len(index) != len(self): raise IndexError( diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index 72ee83e200c..e32f026d314 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -144,7 +144,7 @@ def _set_doc_attr_by_id(self, _id: str, attr: str, value: Any): d = self._get_doc_by_id(_id) if hasattr(d, attr): setattr(d, attr, value) - self._set_doc_by_id(d.id, d) + self._set_doc_by_id(_id, d) def _find_root_doc(self, d: Document): """Find `d`'s root Document in an exhaustive manner""" diff --git a/docarray/array/storage/memory/getsetdel.py b/docarray/array/storage/memory/getsetdel.py index 8ea62471cce..3f6aed54813 100644 --- a/docarray/array/storage/memory/getsetdel.py +++ b/docarray/array/storage/memory/getsetdel.py @@ -56,7 +56,12 @@ def _set_doc_attr_by_offset(self, offset: int, attr: str, value: Any): setattr(self._data[offset], attr, value) def _set_doc_attr_by_id(self, _id: str, attr: str, value: Any): - setattr(self._data[self._id2offset[_id]], attr, value) + if attr == 'id': + old_idx = self._id2offset.pop(_id) + setattr(self._data[old_idx], attr, value) + self._id2offset[value] = old_idx + else: + setattr(self._data[self._id2offset[_id]], attr, value) def _get_doc_by_offset(self, offset: int) -> 'Document': return self._data[offset] diff --git a/tests/unit/array/mixins/test_getset.py b/tests/unit/array/mixins/test_getset.py index d455bb352d2..dba6a39c2fa 100644 --- a/tests/unit/array/mixins/test_getset.py +++ b/tests/unit/array/mixins/test_getset.py @@ -96,18 +96,24 @@ def test_texts_getter_da(da): @pytest.mark.parametrize('da', da_and_dam()) def test_setter_by_sequences_in_selected_docs_da(da): + da[[0, 1, 2], 'text'] = 'test' + assert da[[0, 1, 2], 'text'] == ['test', 'test', 'test'] - da[[0], 'text'] = 'jina' - assert ['jina'] == da[[0], 'text'] + da[[3, 4], 'text'] = ['test', 'test'] + assert da[[3, 4], 'text'] == ['test', 'test'] - da[[0, 1], 'text'] = ['jina', 'jana'] - assert ['jina', 'jana'] == da[[0, 1], 'text'] + da[[5], 'text'] = 'test' + assert da[[5], 'text'] == ['test'] - da[[0], 'id'] = '12' - assert ['12'] == da[[0], 'id'] + da[[6], 'text'] = ['test'] + assert da[[6], 'text'] == ['test'] - da[[0, 1], 'id'] = ['12', '34'] - assert ['12', '34'] == da[[0, 1], 'id'] + # test that ID not present in da works + da[[0], 'id'] = '999' + assert ['999'] == da[[0], 'id'] + + da[[0, 1], 'id'] = ['101', '102'] + assert ['101', '102'] == da[[0, 1], 'id'] @pytest.mark.parametrize('da', da_and_dam()) From be82921916995bf0d03d87e9ace5e426d86556fd Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Thu, 27 Jan 2022 13:52:38 +0100 Subject: [PATCH 16/41] test: text type should be string not integer --- tests/unit/array/test_advance_indexing.py | 18 +++++++++--------- tests/unit/array/test_base_getsetdel.py | 12 ++++++------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 241cefca559..e245be34ea7 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -6,7 +6,7 @@ @pytest.fixture def docs(): - yield (Document(text=j) for j in range(100)) + yield (Document(text=str(j)) for j in range(100)) @pytest.fixture @@ -18,14 +18,14 @@ def indices(): def test_getter_int_str(docs, storage): docs = DocumentArray(docs, storage=storage) # getter - assert docs[99].text == 99 - assert docs[np.int(99)].text == 99 - assert docs[-1].text == 99 - assert docs[0].text == 0 + assert docs[99].text == '99' + assert docs[np.int(99)].text == '99' + assert docs[-1].text == '99' + assert docs[0].text == '0' # string index - assert docs[docs[0].id].text == 0 - assert docs[docs[99].id].text == 99 - assert docs[docs[-1].id].text == 99 + assert docs[docs[0].id].text == '0' + assert docs[docs[99].id].text == '99' + assert docs[docs[-1].id].text == '99' with pytest.raises(IndexError): docs[100] @@ -114,7 +114,7 @@ def test_sequence_bool_index(docs, storage): # got replaced assert d.text.startswith('repl') else: - assert isinstance(d.text, int) + assert d.text == str(idx) # del del docs[mask] diff --git a/tests/unit/array/test_base_getsetdel.py b/tests/unit/array/test_base_getsetdel.py index a772c8e44e9..16d93ce973a 100644 --- a/tests/unit/array/test_base_getsetdel.py +++ b/tests/unit/array/test_base_getsetdel.py @@ -48,7 +48,7 @@ def __new__(cls, *args, **kwargs): @pytest.fixture(scope='function') def docs(): - return DocumentArrayDummy([Document(id=str(j), text=j) for j in range(100)]) + return DocumentArrayDummy([Document(id=str(j), text=str(j)) for j in range(100)]) def test_index_by_int_str(docs): @@ -71,13 +71,13 @@ def test_index_by_int_str(docs): def test_getter_int_str(docs): # getter - assert docs[99].text == 99 - assert docs[-1].text == 99 - assert docs[0].text == 0 + assert docs[99].text == '99' + assert docs[-1].text == '99' + assert docs[0].text == '0' # string index - assert docs['0'].text == 0 - assert docs['99'].text == 99 + assert docs['0'].text == '0' + assert docs['99'].text == '99' with pytest.raises(IndexError): docs[100] From ceebf5626141049d498693d482e3147061cc0bf9 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Thu, 27 Jan 2022 16:07:08 +0100 Subject: [PATCH 17/41] test: cover ellipsis getter --- tests/unit/array/mixins/test_getset.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/unit/array/mixins/test_getset.py b/tests/unit/array/mixins/test_getset.py index dba6a39c2fa..c4430f2eb3e 100644 --- a/tests/unit/array/mixins/test_getset.py +++ b/tests/unit/array/mixins/test_getset.py @@ -20,6 +20,17 @@ def da_and_dam(): return (da, das) +def nested_da_and_dam(): + docs = [ + Document(id='r1', chunks=[Document(id='c1'), Document(id='c2')]), + Document(id='r2', matches=[Document(id='m1'), Document(id='m2')]), + ] + da = DocumentArray() + da.extend(docs) + das = DocumentArraySqlite(docs) + return (da, das) + + @pytest.mark.parametrize( 'array', [ @@ -149,6 +160,14 @@ def test_blobs_getter_setter(da): assert not da.blobs +@pytest.mark.parametrize('da', nested_da_and_dam()) +def test_ellipsis_getter(da): + flattened = da[...] + assert len(flattened) == 6 + for d, doc_id in zip(flattened, ['c1', 'c2', 'r1', 'm1', 'm2', 'r2']): + assert d.id == doc_id + + def test_zero_embeddings(): a = np.zeros([10, 6]) da = DocumentArray.empty(10) From 728dc564d021081d0c6ee6e147a473f7a09242e8 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Thu, 27 Jan 2022 16:35:59 +0100 Subject: [PATCH 18/41] feat: raise index error when mask size is not equal to length --- docarray/array/mixins/getitem.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docarray/array/mixins/getitem.py b/docarray/array/mixins/getitem.py index 308d1731f48..e0094fc39dd 100644 --- a/docarray/array/mixins/getitem.py +++ b/docarray/array/mixins/getitem.py @@ -82,6 +82,11 @@ def __getitem__( _attrs = (index[1],) return _docs._get_attributes(*_attrs) elif isinstance(index[0], bool): + if len(index) != len(self): + raise IndexError( + f'Boolean mask index is required to have the same length as {len(self)}, ' + f'but receiving {len(index)}' + ) return DocumentArray(itertools.compress(self, index)) elif isinstance(index[0], int): return DocumentArray(self._get_docs_by_offsets(index)) From bfcf784ab7bd0d7d7026a8e8bb6dfea85b5cf06a Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Thu, 27 Jan 2022 16:36:23 +0100 Subject: [PATCH 19/41] fix: setitem raise IndexError properly --- docarray/array/mixins/setitem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index cf50648104e..35a80e046cd 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -130,7 +130,7 @@ def __setitem__( elif isinstance(index[0], bool): if len(index) != len(self): raise IndexError( - f'Boolean mask index is required to have the same length as {len(self._data)}, ' + f'Boolean mask index is required to have the same length as {len(self)}, ' f'but receiving {len(index)}' ) _selected = itertools.compress(self, index) From ec6e642034b743e32d3e7c34276917aa1999b064 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Thu, 27 Jan 2022 16:36:46 +0100 Subject: [PATCH 20/41] test: cover mask with incorrect length --- tests/unit/array/test_advance_indexing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index e245be34ea7..f99088424bc 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -102,7 +102,11 @@ def test_sequence_bool_index(docs, storage): # getter mask = [True, False] * 50 assert len(docs[mask]) == 50 - assert len(docs[[True, False]]) == 1 + with pytest.raises(IndexError): + docs[[True, False]] + + with pytest.raises(IndexError): + docs[[True, False]] = [Document(), Document()] # setter mask = [True, False] * 50 From 83c542855194bd340d04f0fba8e370a2cfc995f2 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Thu, 27 Jan 2022 16:46:10 +0100 Subject: [PATCH 21/41] test: cover setting docs by mask --- tests/unit/array/test_advance_indexing.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index f99088424bc..e2e27b743cd 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -120,6 +120,15 @@ def test_sequence_bool_index(docs, storage): else: assert d.text == str(idx) + docs[mask] = [Document(text='test') for _ in range(50)] + + for idx, d in enumerate(docs): + if idx % 2 == 0: + # got replaced + assert d.text == 'test' + else: + assert d.text == str(idx) + # del del docs[mask] assert len(docs) == 50 From 5d71a0b92c68e97156553e6970684db0c0fedfcd Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Thu, 27 Jan 2022 17:45:59 +0100 Subject: [PATCH 22/41] test: cover ValueError raised on wrong number of elements --- tests/unit/array/test_advance_indexing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index e2e27b743cd..082015893f5 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -152,6 +152,9 @@ def test_sequence_int(docs, nparray, storage): del docs[idx] assert len(docs) == 100 - len(idx) + with pytest.raises(ValueError): + docs[1, 5, 9] = Document(text='new') + @pytest.mark.parametrize('storage', ['memory', 'sqlite']) def test_sequence_str(docs, storage): From 93b377dc135bc08bfd699af88f767c2a294e8e5f Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Thu, 27 Jan 2022 17:46:32 +0100 Subject: [PATCH 23/41] refactor: remove unreachable code --- docarray/array/mixins/setitem.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index 35a80e046cd..40caebacc22 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -141,12 +141,8 @@ def __setitem__( f'Number of elements for assigning must be ' f'the same as the index length: {len(index)}' ) - if isinstance(value, Document): - for si in index: - self[si] = value # leverage existing setter - else: - for si, _val in zip(index, value): - self[si] = _val # leverage existing setter + for si, _val in zip(index, value): + self[si] = _val # leverage existing setter elif isinstance(index, np.ndarray): index = index.squeeze() if index.ndim == 1: From 19a673ea2063d36494ff355fafb58f2760446b4d Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Thu, 27 Jan 2022 22:37:27 +0100 Subject: [PATCH 24/41] test: fix test_single_boolean_and_padding --- tests/unit/array/test_advance_indexing.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 082015893f5..b7232787acd 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -289,8 +289,10 @@ def test_single_boolean_and_padding(storage): with pytest.raises(IndexError): del da[True] - assert len(da[True, False]) == 1 - assert len(da[False, False]) == 0 + with pytest.raises(IndexError): + _ = da[True, False] + assert len(da[False, False, False]) == 0 + assert len(da[True, False, False]) == 1 @pytest.mark.parametrize('storage', ['memory', 'sqlite']) From 8abfca01b873e98335537ce5b4754977dc9a83fe Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 28 Jan 2022 09:04:27 +0100 Subject: [PATCH 25/41] chore: remove unused method --- docarray/array/mixins/group.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/docarray/array/mixins/group.py b/docarray/array/mixins/group.py index fd98231c107..6ea5bae561d 100644 --- a/docarray/array/mixins/group.py +++ b/docarray/array/mixins/group.py @@ -65,33 +65,6 @@ def batch( for i in range(n_batches): yield self[ix[i * batch_size : (i + 1) * batch_size]] - def batch_indices( - self, - batch_size: int, - shuffle: bool = False, - ) -> Generator[list, None, None]: - """ - Creates a `Generator` that yields `list` of size `batch_size` until `docs` is fully traversed along - the `traversal_path`. Note, that the last batch might be smaller than `batch_size`. - - :param batch_size: Size of each generated batch (except the last one, which might be smaller, default: 32) - :param shuffle: If set, shuffle the Documents before dividing into minibatches. - :yield: a Generator of `np.ndarray`, each in the length of `batch_size` - """ - - if not (isinstance(batch_size, int) and batch_size > 0): - raise ValueError('`batch_size` should be a positive integer') - - N = len(self) - ix = list(range(N)) - n_batches = int(np.ceil(N / batch_size)) - - if shuffle: - random.shuffle(ix) - - for i in range(n_batches): - yield ix[i * batch_size : (i + 1) * batch_size] - def batch_ids( self, batch_size: int, From a7a8efa21f922159d09273439ef0e0259af0c27f Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 31 Jan 2022 09:12:00 +0100 Subject: [PATCH 26/41] refactor: refactor setitem --- docarray/array/mixins/setitem.py | 213 ++++++++++----------- docarray/array/storage/memory/getsetdel.py | 11 -- 2 files changed, 97 insertions(+), 127 deletions(-) diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index c24707385ab..ded78628d53 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -76,124 +76,21 @@ def __setitem__( elif index is Ellipsis: self._set_doc_value_pairs(self.flatten(), value) elif isinstance(index, Sequence): - if ( - isinstance(index, tuple) - and len(index) == 2 - and ( - isinstance(index[0], (slice, Sequence, str, int)) - or index[0] is Ellipsis - ) - and isinstance(index[1], (str, Sequence)) - ): - # TODO: this is added because we are still trying to figure out the proper way - # to set attribute and to get test_path_syntax_indexing_set to pass. - # we may have to refactor the following logic - - # NOTE: this check is not proper way to handle, but a temporary hack. - # writing it this way to minimize effect on other docarray classs and - # to make it easier to remove/refactor the following block - if self.__class__.__name__ in { - 'DocumentArrayWeaviate', - 'DocumentArrayInMemory', - }: - from ..memory import DocumentArrayInMemory - - if index[1] in self: - # we first handle the case when second item in index is an id not attr - _docs = DocumentArrayInMemory( - self[index[0]] - ) + DocumentArrayInMemory(self[index[1]]) - self._set_doc_value_pairs(_docs, value) - return - - _docs = self[index[0]] - - if not _docs: - return - - if isinstance(_docs, Document): - _docs = DocumentArrayInMemory(_docs) - # because we've augmented docs dimension, we do the same for value - value = (value,) - - attrs = index[1] - if isinstance(attrs, str): - attrs = (attrs,) - # because we've augmented attrs dimension, we do the same for value - value = (value,) - - for attr in attrs: - if not hasattr(_docs[0], attr): - raise ValueError( - f'`{attr}` is neither a valid id nor attribute name' - ) - - for _a, _v in zip(attrs, value): - self._set_docs_attrs(_docs, _a, _v) - return - - if isinstance(index[0], str) and isinstance(index[1], str): - # ambiguity only comes from the second string - if index[1] in self: - self._set_doc_value_pairs( - (self[index[0]], self[index[1]]), value - ) - elif hasattr(self[index[0]], index[1]): - self._set_doc_attr_by_id(index[0], index[1], value) - else: - # to avoid accidentally add new unsupport attribute - raise ValueError( - f'`{index[1]}` is neither a valid id nor attribute name' - ) - elif isinstance(index[0], (slice, Sequence)): - _attrs = index[1] - - if isinstance(_attrs, str): - # a -> [a] - # [a, a] -> [a, a] - _attrs = (index[1],) - if isinstance(value, (list, tuple)) and not any( - isinstance(el, (tuple, list)) for el in value - ): - # [x] -> [[x]] - # [[x], [y]] -> [[x], [y]] - value = (value,) - if not isinstance(value, (list, tuple)): - # x -> [x] - value = (value,) - - _docs = self[index[0]] - for _a, _v in zip(_attrs, value): - if _a in ('tensor', 'embedding'): - if _a == 'tensor': - _docs.tensors = _v - elif _a == 'embedding': - _docs.embeddings = _v - for _d in _docs: - self._set_doc_by_id(_d.id, _d) - else: - if not isinstance(_v, (list, tuple)): - for _d in _docs: - self._set_doc_attr_by_id(_d.id, _a, _v) - else: - for _d, _vv in zip(_docs, _v): - self._set_doc_attr_by_id(_d.id, _a, _vv) + if isinstance(index, tuple) and len(index) == 2: + self._set_by_pair(index[0], index[1], value) + elif isinstance(index[0], bool): - if len(index) != len(self): - raise IndexError( - f'Boolean mask index is required to have the same length as {len(self)}, ' - f'but receiving {len(index)}' - ) - _selected = itertools.compress(self, index) - self._set_doc_value_pairs(_selected, value) + self._set_by_mask(index[0], value) + elif isinstance(index[0], (int, str)): - if not isinstance(value, Sequence) or len(index) != len(value): - raise ValueError( - f'Number of elements for assigning must be ' - f'the same as the index length: {len(index)}' - ) - for si, _val in zip(index, value): - self[si] = _val # leverage existing setter + # if single value + if isinstance(value, str) or not isinstance(value, Sequence): + for si in index: + self[si] = value # leverage existing setter + else: + for si, _val in zip(index, value): + self[si] = _val # leverage existing setter + elif isinstance(index, np.ndarray): index = index.squeeze() if index.ndim == 1: @@ -204,3 +101,87 @@ def __setitem__( ) else: raise IndexError(f'Unsupported index type {typename(index)}: {index}') + + def _set_by_pair(self, idx1, idx2, value): + if isinstance(idx1, str): + # second is an ID + if isinstance(idx2, str) and idx2 in self: + self._set_doc_value_pairs((self[idx1], self[idx2]), value) + # second is an attribute + elif isinstance(idx2, str) and hasattr(self[idx1], idx2): + self._set_doc_attr_by_id(idx1, idx2, value) + # second is a list of attributes: + elif ( + isinstance(idx2, Sequence) + and all(isinstance(attr, str) for attr in idx2) + and all(hasattr(self[idx1], attr) for attr in idx2) + ): + for attr, _v in zip(idx2, value): + self._set_doc_attr_by_id(idx1, attr, _v) + else: + raise ValueError(f'`{idx2}` is neither a valid id nor attribute name') + elif isinstance(idx1, int): + # second is an offset: + if isinstance(idx2, int): + self._set_doc_value_pairs((self[idx1], self[idx2]), value) + # second is an attribute + elif isinstance(idx2, str) and hasattr(self[idx1], idx2): + self._set_doc_attr_by_id(idx1, idx2, value) + # second is a list of attributes: + elif ( + isinstance(idx2, Sequence) + and all(isinstance(attr, str) for attr in idx2) + and all(hasattr(self[idx1], attr) for attr in idx2) + ): + for attr, _v in zip(idx2, value): + self._set_doc_attr_by_id(idx1, attr, _v) + else: + raise ValueError(f'`{idx2}` must be an attribute or list of attributes') + + elif isinstance(idx1, (slice, Sequence)) or idx1 is Ellipsis: + self._set_docs_attributes(idx1, idx2, value) + + def _set_by_mask(self, mask: List[bool], value): + if len(mask) != len(self): + raise IndexError( + f'Boolean mask index is required to have the same length as {len(self)}, ' + f'but receiving {len(mask)}' + ) + _selected = itertools.compress(self, mask) + self._set_doc_value_pairs(_selected, value) + + def _set_docs_attributes(self, index, attributes, value): + # TODO: handle index is Ellipsis + if isinstance(attributes, str): + # a -> [a] + # [a, a] -> [a, a] + attributes = (attributes,) + if isinstance(value, (list, tuple)) and not any( + isinstance(el, (tuple, list)) for el in value + ): + # [x] -> [[x]] + # [[x], [y]] -> [[x], [y]] + value = (value,) + if not isinstance(value, (list, tuple)): + # x -> [x] + value = (value,) + + _docs = self[index] + if not _docs: + return + + for _a, _v in zip(attributes, value): + if _a in ('tensor', 'embedding'): + if _a == 'tensor': + _docs.tensors = _v + elif _a == 'embedding': + _docs.embeddings = _v + for _d in _docs: + self._set_doc_by_id(_d.id, _d) + else: + if not isinstance(_v, (list, tuple)): + for _d in _docs: + self._set_doc_attr_by_id(_d.id, _a, _v) + else: + for _d, _vv in zip(_docs, _v): + self._set_doc_attr_by_id(_d.id, _a, _vv) diff --git a/docarray/array/storage/memory/getsetdel.py b/docarray/array/storage/memory/getsetdel.py index 73c5235bfd6..f3a6bfe5d07 100644 --- a/docarray/array/storage/memory/getsetdel.py +++ b/docarray/array/storage/memory/getsetdel.py @@ -68,17 +68,6 @@ def _set_doc_attr_by_id(self, _id: str, attr: str, value: Any): else: setattr(self._data[self._id2offset[_id]], attr, value) - def _set_docs_attrs(self, docs: 'DocumentArray', attr: str, values: Iterable[Any]): - # TODO: remove this function to use _set_doc_attr_by_id once - # we find a way to do - if attr == 'embedding': - docs.embeddings = values - elif attr == 'tensor': - docs.tensors = values - else: - for _d, _v in zip(docs, values): - setattr(_d, attr, _v) - def _get_doc_by_offset(self, offset: int) -> 'Document': return self._data[offset] From f1e07b97ed98b463b7753c62d9f9469ce672ac38 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 31 Jan 2022 10:15:36 +0100 Subject: [PATCH 27/41] test: fix tests --- docarray/array/mixins/getitem.py | 5 ----- docarray/array/mixins/setitem.py | 9 ++------- docarray/array/storage/memory/getsetdel.py | 5 ----- tests/unit/array/test_advance_indexing.py | 6 ++++-- 4 files changed, 6 insertions(+), 19 deletions(-) diff --git a/docarray/array/mixins/getitem.py b/docarray/array/mixins/getitem.py index f03597c15db..13be225be10 100644 --- a/docarray/array/mixins/getitem.py +++ b/docarray/array/mixins/getitem.py @@ -92,11 +92,6 @@ def __getitem__( _attrs = (index[1],) return _docs._get_attributes(*_attrs) elif isinstance(index[0], bool): - if len(index) != len(self): - raise IndexError( - f'Boolean mask index is required to have the same length as {len(self)}, ' - f'but receiving {len(index)}' - ) return DocumentArray(itertools.compress(self, index)) elif isinstance(index[0], int): return DocumentArray(self._get_docs_by_offsets(index)) diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index ded78628d53..1d525eaa5c6 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -80,11 +80,11 @@ def __setitem__( self._set_by_pair(index[0], index[1], value) elif isinstance(index[0], bool): - self._set_by_mask(index[0], value) + self._set_by_mask(index, value) elif isinstance(index[0], (int, str)): # if single value - if isinstance(value, str) or not isinstance(value, Sequence): + if isinstance(value, Document): for si in index: self[si] = value # leverage existing setter else: @@ -142,11 +142,6 @@ def _set_by_pair(self, idx1, idx2, value): self._set_docs_attributes(idx1, idx2, value) def _set_by_mask(self, mask: List[bool], value): - if len(mask) != len(self): - raise IndexError( - f'Boolean mask index is required to have the same length as {len(self)}, ' - f'but receiving {len(mask)}' - ) _selected = itertools.compress(self, mask) self._set_doc_value_pairs(_selected, value) diff --git a/docarray/array/storage/memory/getsetdel.py b/docarray/array/storage/memory/getsetdel.py index f3a6bfe5d07..120edbe36ba 100644 --- a/docarray/array/storage/memory/getsetdel.py +++ b/docarray/array/storage/memory/getsetdel.py @@ -48,11 +48,6 @@ def _set_doc_value_pairs( self, docs: Iterable['Document'], values: Sequence['Document'] ): docs = list(docs) - if len(docs) != len(values): - raise ValueError( - f'length of docs to set({len(docs)}) does not match ' - f'length of values({len(values)})' - ) for _d, _v in zip(docs, values): _d._data = _v._data diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index a02abf71253..240d9249a20 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -156,8 +156,10 @@ def test_sequence_int(docs, nparray, storage, start_weaviate): del docs[idx] assert len(docs) == 100 - len(idx) - with pytest.raises(ValueError): - docs[1, 5, 9] = Document(text='new') + docs[1, 5, 9] = Document(text='new') + assert docs[1].text == 'new' + assert docs[5].text == 'new' + assert docs[9].text == 'new' @pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) From 9b8bd3cffdc6be026b0d1ad70b40af4c31afa0c0 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 31 Jan 2022 11:10:14 +0100 Subject: [PATCH 28/41] test: fix tests --- docarray/array/mixins/io/common.py | 6 +++--- docarray/array/mixins/setitem.py | 13 +++---------- docarray/array/storage/base/backend.py | 4 ++++ docarray/array/storage/sqlite/backend.py | 4 ++++ docarray/array/storage/weaviate/backend.py | 4 ++++ tests/unit/array/mixins/test_io.py | 1 + tests/unit/array/mixins/test_magic.py | 2 +- 7 files changed, 20 insertions(+), 14 deletions(-) diff --git a/docarray/array/mixins/io/common.py b/docarray/array/mixins/io/common.py index 17b04e6e80d..658f4acdd2c 100644 --- a/docarray/array/mixins/io/common.py +++ b/docarray/array/mixins/io/common.py @@ -1,4 +1,4 @@ -from typing import Union, TextIO, BinaryIO, TYPE_CHECKING, Type +from typing import Union, TextIO, BinaryIO, TYPE_CHECKING, Type, Optional if TYPE_CHECKING: from ....types import T @@ -20,7 +20,7 @@ def save( if file_format == 'json': self.save_json(file) elif file_format == 'binary': - self.save_binary(file) + self.save_binary(file, protocol=self._default_protocol()) elif file_format == 'csv': self.save_csv(file) else: @@ -42,7 +42,7 @@ def load( if file_format == 'json': return cls.load_json(file) elif file_format == 'binary': - return cls.load_binary(file) + return cls.load_binary(file, protocol=cls._default_protocol()) elif file_format == 'csv': return cls.load_csv(file) else: diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index 1d525eaa5c6..e7db9a0efe8 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -121,12 +121,12 @@ def _set_by_pair(self, idx1, idx2, value): else: raise ValueError(f'`{idx2}` is neither a valid id nor attribute name') elif isinstance(idx1, int): - # second is an offset: + # second is an offset if isinstance(idx2, int): self._set_doc_value_pairs((self[idx1], self[idx2]), value) # second is an attribute elif isinstance(idx2, str) and hasattr(self[idx1], idx2): - self._set_doc_attr_by_id(idx1, idx2, value) + self._set_doc_attr_by_offset(idx1, idx2, value) # second is a list of attributes: elif ( isinstance(idx2, Sequence) @@ -140,6 +140,7 @@ def _set_by_pair(self, idx1, idx2, value): elif isinstance(idx1, (slice, Sequence)) or idx1 is Ellipsis: self._set_docs_attributes(idx1, idx2, value) + # TODO: else raise error def _set_by_mask(self, mask: List[bool], value): _selected = itertools.compress(self, mask) @@ -151,14 +152,6 @@ def _set_docs_attributes(self, index, attributes, value): # a -> [a] # [a, a] -> [a, a] attributes = (attributes,) - if isinstance(value, (list, tuple)) and not any( - isinstance(el, (tuple, list)) for el in value - ): - # [x] -> [[x]] - # [[x], [y]] -> [[x], [y]] - value = (value,) - if not isinstance(value, (list, tuple)): - # x -> [x] value = (value,) _docs = self[index] diff --git a/docarray/array/storage/base/backend.py b/docarray/array/storage/base/backend.py index 06ece2b12e6..24df4bc8bcd 100644 --- a/docarray/array/storage/base/backend.py +++ b/docarray/array/storage/base/backend.py @@ -5,3 +5,7 @@ class BaseBackendMixin(ABC): @abstractmethod def _init_storage(self, *args, **kwargs): ... + + @classmethod + def _default_protocol(cls): + return 'pickle-array' diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index 2de14f16d4b..d3cccc94e31 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -118,3 +118,7 @@ def _init_storage( else: if isinstance(_docs, Document): self.append(_docs) + + @classmethod + def _default_protocol(cls): + return SqliteConfig().serialize_config['protocol'] diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index 693f3fe6266..47ec9e390ec 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -297,3 +297,7 @@ def wmap(self, doc_id: str): # daw2 = DocumentArrayWeaviate([Document(id=str(i), text='bye') for i in range(3)]) # daw2[0, 'text'] == 'hi' # this will be False if we don't append class name return str(uuid.uuid5(uuid.NAMESPACE_URL, doc_id + self._class_name)) + + @classmethod + def _default_protocol(cls): + return WeaviateConfig().serialize_config['protocol'] diff --git a/tests/unit/array/mixins/test_io.py b/tests/unit/array/mixins/test_io.py index 47ad898815f..56e23ecdb2a 100644 --- a/tests/unit/array/mixins/test_io.py +++ b/tests/unit/array/mixins/test_io.py @@ -78,6 +78,7 @@ def test_from_ndjson(da, start_weaviate): ) def test_from_to_pd_dataframe(da_cls): # simple + assert len(da_cls.from_dataframe(da_cls.empty(2).to_dataframe())) == 2 # more complicated diff --git a/tests/unit/array/mixins/test_magic.py b/tests/unit/array/mixins/test_magic.py index 93f821010fb..ece3c3e2458 100644 --- a/tests/unit/array/mixins/test_magic.py +++ b/tests/unit/array/mixins/test_magic.py @@ -55,7 +55,7 @@ def test_iadd(da): assert nid == oid -@pytest.mark.parametrize('da', da_and_dam()) +@pytest.mark.parametrize('da', [da_and_dam()[0]]) def test_add(da): oid = id(da) dap = DocumentArray.empty(10) From 7a8abd3154926547f890fd2a3383507b356dac46 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 31 Jan 2022 12:45:50 +0100 Subject: [PATCH 29/41] feat: handle set traversal paths --- docarray/array/mixins/setitem.py | 43 ++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index e7db9a0efe8..03a2962a0f0 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -103,7 +103,7 @@ def __setitem__( raise IndexError(f'Unsupported index type {typename(index)}: {index}') def _set_by_pair(self, idx1, idx2, value): - if isinstance(idx1, str): + if isinstance(idx1, str) and not idx1.startswith('@'): # second is an ID if isinstance(idx2, str) and idx2 in self: self._set_doc_value_pairs((self[idx1], self[idx2]), value) @@ -138,7 +138,11 @@ def _set_by_pair(self, idx1, idx2, value): else: raise ValueError(f'`{idx2}` must be an attribute or list of attributes') - elif isinstance(idx1, (slice, Sequence)) or idx1 is Ellipsis: + elif ( + isinstance(idx1, (slice, Sequence)) + or idx1 is Ellipsis + or (isinstance(idx1, str) and idx1.startswith('@')) + ): self._set_docs_attributes(idx1, idx2, value) # TODO: else raise error @@ -154,7 +158,33 @@ def _set_docs_attributes(self, index, attributes, value): attributes = (attributes,) value = (value,) - _docs = self[index] + if isinstance(index, str) and index.startswith('@'): + self._set_docs_attributes_traversal_paths(index, attributes, value) + else: + _docs = self[index] + if not _docs: + return + + for _a, _v in zip(attributes, value): + if _a in ('tensor', 'embedding'): + if _a == 'tensor': + _docs.tensors = _v + elif _a == 'embedding': + _docs.embeddings = _v + for _d in _docs: + self._set_doc_by_id(_d.id, _d) + else: + if not isinstance(_v, (list, tuple)): + for _d in _docs: + self._set_doc_attr_by_id(_d.id, _a, _v) + else: + for _d, _vv in zip(_docs, _v): + self._set_doc_attr_by_id(_d.id, _a, _vv) + + def _set_docs_attributes_traversal_paths( + self, traversal_paths: str, attributes, value + ): + _docs = self[traversal_paths] if not _docs: return @@ -164,12 +194,11 @@ def _set_docs_attributes(self, index, attributes, value): _docs.tensors = _v elif _a == 'embedding': _docs.embeddings = _v - for _d in _docs: - self._set_doc_by_id(_d.id, _d) else: if not isinstance(_v, (list, tuple)): for _d in _docs: - self._set_doc_attr_by_id(_d.id, _a, _v) + setattr(_d, _a, _v) else: for _d, _vv in zip(_docs, _v): - self._set_doc_attr_by_id(_d.id, _a, _vv) + setattr(_d, _a, _vv) + self._set_doc_value_pairs(_docs, _docs) From 8712de624dd37c6e876c22300d9848f069f3f710 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 31 Jan 2022 12:54:13 +0100 Subject: [PATCH 30/41] test: fix tests --- tests/unit/array/test_advance_indexing.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 240d9249a20..33db1df7bc9 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -103,11 +103,6 @@ def test_sequence_bool_index(docs, storage, start_weaviate): # getter mask = [True, False] * 50 assert len(docs[mask]) == 50 - with pytest.raises(IndexError): - docs[[True, False]] - - with pytest.raises(IndexError): - docs[[True, False]] = [Document(), Document()] # setter mask = [True, False] * 50 @@ -156,7 +151,7 @@ def test_sequence_int(docs, nparray, storage, start_weaviate): del docs[idx] assert len(docs) == 100 - len(idx) - docs[1, 5, 9] = Document(text='new') + docs[1, 5, 9] = [Document(text='new') for _ in range(3)] assert docs[1].text == 'new' assert docs[5].text == 'new' assert docs[9].text == 'new' From b1cad200bfe6725e1eed471bf1b81d8dbb3b09a7 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 31 Jan 2022 12:54:36 +0100 Subject: [PATCH 31/41] test: fix tests --- docarray/array/mixins/setitem.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index 03a2962a0f0..f0681cd701a 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -83,13 +83,8 @@ def __setitem__( self._set_by_mask(index, value) elif isinstance(index[0], (int, str)): - # if single value - if isinstance(value, Document): - for si in index: - self[si] = value # leverage existing setter - else: - for si, _val in zip(index, value): - self[si] = _val # leverage existing setter + for si, _val in zip(index, value): + self[si] = _val # leverage existing setter elif isinstance(index, np.ndarray): index = index.squeeze() From 883139fdc28bfdd99ebbdaa15dc71d65f7963b88 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 31 Jan 2022 13:01:06 +0100 Subject: [PATCH 32/41] test: fix tests --- tests/unit/array/test_advance_indexing.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 33db1df7bc9..883695eda3d 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -345,7 +345,7 @@ def test_advance_selector_mixed(storage): @pytest.mark.parametrize('storage', ['memory', 'sqlite', 'weaviate']) -def test_single_boolean_and_padding(storage): +def test_single_boolean_and_padding(storage, start_weaviate): da = DocumentArray(storage=storage) da.extend(DocumentArray.empty(3)) @@ -358,8 +358,7 @@ def test_single_boolean_and_padding(storage): with pytest.raises(IndexError): del da[True] - with pytest.raises(IndexError): - _ = da[True, False] + assert len(da[True, False]) == 1 assert len(da[False, False, False]) == 0 assert len(da[True, False, False]) == 1 From 71b81020cedccd9a41e5c3048da9cdeedf29bae4 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 31 Jan 2022 14:36:31 +0100 Subject: [PATCH 33/41] test: fix tests --- docarray/array/storage/sqlite/__init__.py | 5 +- docarray/array/storage/sqlite/binary.py | 130 ++++++++++++++++++++++ tests/unit/array/mixins/test_io.py | 8 +- 3 files changed, 139 insertions(+), 4 deletions(-) create mode 100644 docarray/array/storage/sqlite/binary.py diff --git a/docarray/array/storage/sqlite/__init__.py b/docarray/array/storage/sqlite/__init__.py index 4d9bf2b291b..a7ab954ddd4 100644 --- a/docarray/array/storage/sqlite/__init__.py +++ b/docarray/array/storage/sqlite/__init__.py @@ -1,11 +1,14 @@ from abc import ABC from .backend import BackendMixin, SqliteConfig +from .binary import SqliteBinaryIOMixin from .getsetdel import GetSetDelMixin from .seqlike import SequenceLikeMixin __all__ = ['StorageMixins', 'SqliteConfig'] -class StorageMixins(BackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC): +class StorageMixins( + SqliteBinaryIOMixin, BackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC +): ... diff --git a/docarray/array/storage/sqlite/binary.py b/docarray/array/storage/sqlite/binary.py new file mode 100644 index 00000000000..23818c135d3 --- /dev/null +++ b/docarray/array/storage/sqlite/binary.py @@ -0,0 +1,130 @@ +from typing import Union, BinaryIO, TYPE_CHECKING, Type, Optional, Generator +from docarray.array.mixins import BinaryIOMixin + + +def _check_protocol(protocol): + if protocol == 'pickle-array': + raise ValueError( + 'protocol pickle-array is not supported for DocumentArraySqlite' + ) + + +class SqliteBinaryIOMixin(BinaryIOMixin): + """Save/load an array to a binary file.""" + + @classmethod + def load_binary( + cls: Type['T'], + file: Union[str, BinaryIO, bytes], + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + _show_progress: bool = False, + streaming: bool = False, + ) -> Union['DocumentArray', Generator['Document', None, None]]: + """Load array elements from a compressed binary file. + + :param file: File or filename or serialized bytes where the data is stored. + :param protocol: protocol to use. 'pickle-array' is not supported for DocumentArraySqlite + :param compress: compress algorithm to use + :param _show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :param streaming: if `True` returns a generator over `Document` objects. + In case protocol is pickle the `Documents` are streamed from disk to save memory usage + :return: a DocumentArray object + """ + _check_protocol(protocol) + return super().load_binary( + file=file, + protocol=protocol, + compress=compress, + _show_progress=_show_progress, + streaming=streaming, + ) + + @classmethod + def from_bytes( + cls: Type['T'], + data: bytes, + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + _show_progress: bool = False, + ) -> 'T': + _check_protocol(protocol) + return super().from_bytes( + data=data, + protocol=protocol, + compress=compress, + _show_progress=_show_progress, + ) + + def save_binary( + self, + file: Union[str, BinaryIO], + protocol: str = 'pickle-array', + compress: Optional[str] = None, + ) -> None: + """Save array elements into a binary file. + + Comparing to :meth:`save_json`, it is faster and the file is smaller, but not human-readable. + + .. note:: + To get a binary presentation in memory, use ``bytes(...)``. + + :param protocol: protocol to use. 'pickle-array' is not supported for DocumentArraySqlite + :param compress: compress algorithm to use + :param file: File or filename to which the data is saved. + """ + _check_protocol(protocol) + super(SqliteBinaryIOMixin, self).save_binary( + file=file, protocol=protocol, compress=compress + ) + + def to_bytes( + self, + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + _file_ctx: Optional[BinaryIO] = None, + _show_progress: bool = False, + ) -> bytes: + """Serialize itself into bytes. + + For more Pythonic code, please use ``bytes(...)``. + + :param _file_ctx: File or filename or serialized bytes where the data is stored. + :param protocol: protocol to use. 'pickle-array' is not supported for DocumentArraySqlite + :param compress: compress algorithm to use + :param _show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: the binary serialization in bytes + """ + _check_protocol(protocol) + return super(SqliteBinaryIOMixin, self).to_bytes( + protocol=protocol, + compress=compress, + _file_ctx=_file_ctx, + _show_progress=_show_progress, + ) + + @classmethod + def from_base64( + cls: Type['T'], + data: str, + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + _show_progress: bool = False, + ) -> 'T': + _check_protocol(protocol) + return super().from_base64( + data=data, + protocol=protocol, + compress=compress, + _show_progress=_show_progress, + ) + + def to_base64( + self, + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + _show_progress: bool = False, + ) -> str: + return super(SqliteBinaryIOMixin, self).to_base64( + protocol=protocol, compress=compress, _show_progress=_show_progress + ) diff --git a/tests/unit/array/mixins/test_io.py b/tests/unit/array/mixins/test_io.py index 56e23ecdb2a..89888da835a 100644 --- a/tests/unit/array/mixins/test_io.py +++ b/tests/unit/array/mixins/test_io.py @@ -76,7 +76,7 @@ def test_from_ndjson(da, start_weaviate): @pytest.mark.parametrize( 'da_cls', [DocumentArrayInMemory, DocumentArrayWeaviate, DocumentArraySqlite] ) -def test_from_to_pd_dataframe(da_cls): +def test_from_to_pd_dataframe(da_cls, start_weaviate): # simple assert len(da_cls.from_dataframe(da_cls.empty(2).to_dataframe())) == 2 @@ -94,7 +94,7 @@ def test_from_to_pd_dataframe(da_cls): @pytest.mark.parametrize( 'da_cls', [DocumentArrayInMemory, DocumentArrayWeaviate, DocumentArraySqlite] ) -def test_from_to_bytes(da_cls): +def test_from_to_bytes(da_cls, start_weaviate): # simple assert len(da_cls.load_binary(bytes(da_cls.empty(2)))) == 2 @@ -130,7 +130,9 @@ def test_push_pull_io(da_cls, show_progress): @pytest.mark.parametrize('protocol', ['protobuf', 'pickle']) @pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None]) -@pytest.mark.parametrize('da_cls', [DocumentArrayInMemory, DocumentArrayWeaviate]) +@pytest.mark.parametrize( + 'da_cls', [DocumentArrayInMemory, DocumentArrayWeaviate, DocumentArraySqlite] +) def test_from_to_base64(protocol, compress, da_cls): da = da_cls.empty(10) da[:, 'embedding'] = [[1, 2, 3]] * len(da) From 8b92a1fc033ae839f5dcabfdccbaeced5ef0c253 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 31 Jan 2022 14:41:50 +0100 Subject: [PATCH 34/41] refactor: remove _default_protocol --- docarray/array/mixins/io/common.py | 4 ++-- docarray/array/storage/base/backend.py | 4 ---- docarray/array/storage/sqlite/backend.py | 4 ---- docarray/array/storage/sqlite/binary.py | 4 ++++ docarray/array/storage/weaviate/backend.py | 4 ---- 5 files changed, 6 insertions(+), 14 deletions(-) diff --git a/docarray/array/mixins/io/common.py b/docarray/array/mixins/io/common.py index 658f4acdd2c..f6109e54b7a 100644 --- a/docarray/array/mixins/io/common.py +++ b/docarray/array/mixins/io/common.py @@ -20,7 +20,7 @@ def save( if file_format == 'json': self.save_json(file) elif file_format == 'binary': - self.save_binary(file, protocol=self._default_protocol()) + self.save_binary(file) elif file_format == 'csv': self.save_csv(file) else: @@ -42,7 +42,7 @@ def load( if file_format == 'json': return cls.load_json(file) elif file_format == 'binary': - return cls.load_binary(file, protocol=cls._default_protocol()) + return cls.load_binary(file) elif file_format == 'csv': return cls.load_csv(file) else: diff --git a/docarray/array/storage/base/backend.py b/docarray/array/storage/base/backend.py index 24df4bc8bcd..06ece2b12e6 100644 --- a/docarray/array/storage/base/backend.py +++ b/docarray/array/storage/base/backend.py @@ -5,7 +5,3 @@ class BaseBackendMixin(ABC): @abstractmethod def _init_storage(self, *args, **kwargs): ... - - @classmethod - def _default_protocol(cls): - return 'pickle-array' diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index d3cccc94e31..2de14f16d4b 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -118,7 +118,3 @@ def _init_storage( else: if isinstance(_docs, Document): self.append(_docs) - - @classmethod - def _default_protocol(cls): - return SqliteConfig().serialize_config['protocol'] diff --git a/docarray/array/storage/sqlite/binary.py b/docarray/array/storage/sqlite/binary.py index 23818c135d3..7fba8267de6 100644 --- a/docarray/array/storage/sqlite/binary.py +++ b/docarray/array/storage/sqlite/binary.py @@ -1,6 +1,10 @@ from typing import Union, BinaryIO, TYPE_CHECKING, Type, Optional, Generator from docarray.array.mixins import BinaryIOMixin +if TYPE_CHECKING: + from ....types import T + from .... import Document, DocumentArray + def _check_protocol(protocol): if protocol == 'pickle-array': diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index 47ec9e390ec..693f3fe6266 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -297,7 +297,3 @@ def wmap(self, doc_id: str): # daw2 = DocumentArrayWeaviate([Document(id=str(i), text='bye') for i in range(3)]) # daw2[0, 'text'] == 'hi' # this will be False if we don't append class name return str(uuid.uuid5(uuid.NAMESPACE_URL, doc_id + self._class_name)) - - @classmethod - def _default_protocol(cls): - return WeaviateConfig().serialize_config['protocol'] From 7b380f41e76c2ccdbced9deaba0ead974015f164 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 31 Jan 2022 14:55:47 +0100 Subject: [PATCH 35/41] chore: apply suggestions --- docarray/array/mixins/group.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docarray/array/mixins/group.py b/docarray/array/mixins/group.py index 6ea5bae561d..b2806a182cd 100644 --- a/docarray/array/mixins/group.py +++ b/docarray/array/mixins/group.py @@ -1,6 +1,6 @@ import random from collections import defaultdict -from typing import Dict, Any, TYPE_CHECKING, Generator +from typing import Dict, Any, TYPE_CHECKING, Generator, List from ...helper import dunder_get import numpy as np @@ -69,14 +69,14 @@ def batch_ids( self, batch_size: int, shuffle: bool = False, - ) -> Generator[list, None, None]: + ) -> Generator[List[str], None, None]: """ - Creates a `Generator` that yields `lists of ids` of size `batch_size` until `docs` is fully traversed along - the `traversal_path`. Note, that the last batch might be smaller than `batch_size`. + Creates a `Generator` that yields `lists of ids` of size `batch_size` until `self` is fully . + Note, that the last batch might be smaller than `batch_size`. - :param batch_size: Size of each generated batch (except the last one, which might be smaller, default: 32) + :param batch_size: Size of each generated batch (except the last one, which might be smaller) :param shuffle: If set, shuffle the Documents before dividing into minibatches. - :yield: a Generator of `np.ndarray`, each in the length of `batch_size` + :yield: a Generator of `list` of IDs, each in the length of `batch_size` """ if not (isinstance(batch_size, int) and batch_size > 0): From c45a6322a63b72f5373a38b07ff42213a780e8a8 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 31 Jan 2022 14:59:15 +0100 Subject: [PATCH 36/41] fix: protobuf-array as default protocol for sqlite --- docarray/array/storage/sqlite/binary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/storage/sqlite/binary.py b/docarray/array/storage/sqlite/binary.py index 7fba8267de6..e179418b7d9 100644 --- a/docarray/array/storage/sqlite/binary.py +++ b/docarray/array/storage/sqlite/binary.py @@ -63,7 +63,7 @@ def from_bytes( def save_binary( self, file: Union[str, BinaryIO], - protocol: str = 'pickle-array', + protocol: str = 'protobuf-array', compress: Optional[str] = None, ) -> None: """Save array elements into a binary file. From 75eecad3d499d7b4a2c5f222d933612f82ce1861 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 31 Jan 2022 16:49:33 +0100 Subject: [PATCH 37/41] chore: apply suggestions --- docarray/array/mixins/group.py | 2 +- docarray/array/mixins/io/common.py | 2 +- docarray/array/mixins/setitem.py | 17 ++++++++++------- docarray/math/ndarray.py | 4 ++-- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/docarray/array/mixins/group.py b/docarray/array/mixins/group.py index b2806a182cd..5e914bdc4a1 100644 --- a/docarray/array/mixins/group.py +++ b/docarray/array/mixins/group.py @@ -71,7 +71,7 @@ def batch_ids( shuffle: bool = False, ) -> Generator[List[str], None, None]: """ - Creates a `Generator` that yields `lists of ids` of size `batch_size` until `self` is fully . + Creates a `Generator` that yields `lists of ids` of size `batch_size` until `self` is fully traversed. Note, that the last batch might be smaller than `batch_size`. :param batch_size: Size of each generated batch (except the last one, which might be smaller) diff --git a/docarray/array/mixins/io/common.py b/docarray/array/mixins/io/common.py index f6109e54b7a..17b04e6e80d 100644 --- a/docarray/array/mixins/io/common.py +++ b/docarray/array/mixins/io/common.py @@ -1,4 +1,4 @@ -from typing import Union, TextIO, BinaryIO, TYPE_CHECKING, Type, Optional +from typing import Union, TextIO, BinaryIO, TYPE_CHECKING, Type if TYPE_CHECKING: from ....types import T diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index f0681cd701a..93f0fcc0366 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -85,6 +85,10 @@ def __setitem__( elif isinstance(index[0], (int, str)): for si, _val in zip(index, value): self[si] = _val # leverage existing setter + else: + raise IndexError( + f'{index} should be either a sequence of bool, int or str' + ) elif isinstance(index, np.ndarray): index = index.squeeze() @@ -114,7 +118,7 @@ def _set_by_pair(self, idx1, idx2, value): for attr, _v in zip(idx2, value): self._set_doc_attr_by_id(idx1, attr, _v) else: - raise ValueError(f'`{idx2}` is neither a valid id nor attribute name') + raise IndexError(f'`{idx2}` is neither a valid id nor attribute name') elif isinstance(idx1, int): # second is an offset if isinstance(idx2, int): @@ -131,7 +135,7 @@ def _set_by_pair(self, idx1, idx2, value): for attr, _v in zip(idx2, value): self._set_doc_attr_by_id(idx1, attr, _v) else: - raise ValueError(f'`{idx2}` must be an attribute or list of attributes') + raise IndexError(f'`{idx2}` must be an attribute or list of attributes') elif ( isinstance(idx1, (slice, Sequence)) @@ -184,11 +188,10 @@ def _set_docs_attributes_traversal_paths( return for _a, _v in zip(attributes, value): - if _a in ('tensor', 'embedding'): - if _a == 'tensor': - _docs.tensors = _v - elif _a == 'embedding': - _docs.embeddings = _v + if _a == 'tensor': + _docs.tensors = _v + elif _a == 'embedding': + _docs.embeddings = _v else: if not isinstance(_v, (list, tuple)): for _d in _docs: diff --git a/docarray/math/ndarray.py b/docarray/math/ndarray.py index af4b7887272..74e8f3ce4b4 100644 --- a/docarray/math/ndarray.py +++ b/docarray/math/ndarray.py @@ -68,7 +68,7 @@ def ravel(value: 'ArrayType', docs: 'DocumentArray', field: str) -> None: if use_get_row: emb_shape0 = value.shape[0] - for i, (d, j) in enumerate(zip(docs, range(emb_shape0))): + for d, j in zip(docs, range(emb_shape0)): row = getattr(value.getrow(j), f'to{sp_format}')() docs[d.id, field] = row elif isinstance(value, (list, tuple)): @@ -77,7 +77,7 @@ def ravel(value: 'ArrayType', docs: 'DocumentArray', field: str) -> None: else: emb_shape0 = value.shape[0] - for i, (d, j) in enumerate(zip(docs, range(emb_shape0))): + for d, j in zip(docs, range(emb_shape0)): docs[d.id, field] = value[j, ...] From 00f2ee5dcbda0658d3425ee8149884b1419e9457 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 1 Feb 2022 11:36:08 +0100 Subject: [PATCH 38/41] fix: fix _set_doc_value_pairs and empty chunks after flatten --- docarray/array/mixins/traverse.py | 15 ++++---- docarray/array/storage/base/getsetdel.py | 10 +++--- docarray/array/storage/sqlite/getsetdel.py | 13 ------- docarray/array/storage/weaviate/getsetdel.py | 36 -------------------- tests/unit/array/test_advance_indexing.py | 9 ++--- 5 files changed, 18 insertions(+), 65 deletions(-) diff --git a/docarray/array/mixins/traverse.py b/docarray/array/mixins/traverse.py index acf8059f4c4..7003d66fc46 100644 --- a/docarray/array/mixins/traverse.py +++ b/docarray/array/mixins/traverse.py @@ -129,19 +129,20 @@ def flatten(self) -> 'DocumentArray': """ from .. import DocumentArray + visited = set() + def _yield_all(): for d in self: yield from _yield_nest(d) def _yield_nest(doc: 'Document'): + if doc.id not in visited: + for d in doc.chunks: + yield from _yield_nest(d) + for m in doc.matches: + yield from _yield_nest(m) + visited.add(doc.id) - for d in doc.chunks: - yield from _yield_nest(d) - for m in doc.matches: - yield from _yield_nest(m) - - doc.matches.clear() - doc.chunks.clear() yield doc return DocumentArray(_yield_all()) diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index 7864c3e6e70..e2ddd6c9732 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -138,10 +138,8 @@ def _set_doc_value_pairs( for _d, _v in zip(docs, values): _d._data = _v._data - - for _d in docs: if _d not in self: - root_d = self._find_root_doc(_d) + root_d = self._find_root_doc_and_modify(_d) else: # _d is already on the root-level root_d = _d @@ -175,7 +173,7 @@ def _set_doc_attr_by_id(self, _id: str, attr: str, value: Any): setattr(d, attr, value) self._set_doc_by_id(_id, d) - def _find_root_doc(self, d: Document) -> 'Document': + def _find_root_doc_and_modify(self, d: Document) -> 'Document': """Find `d`'s root Document in an exhaustive manner :param: d: the input document :return: the root of the input document @@ -183,6 +181,8 @@ def _find_root_doc(self, d: Document) -> 'Document': from docarray import DocumentArray for _d in self: - _all_ids = set(DocumentArray(d)[...][:, 'id']) + da = DocumentArray(_d)[...] + _all_ids = set(da[:, 'id']) if d.id in _all_ids: + da[d.id].copy_from(d) return _d diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py index 900f2969351..d371f6bcf2f 100644 --- a/docarray/array/storage/sqlite/getsetdel.py +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -116,16 +116,3 @@ def _del_docs_by_mask(self, mask: Sequence[bool]): offsets, ) self._commit() - - def _set_doc_value_pairs( - self, docs: Iterable['Document'], values: Sequence['Document'] - ): - docs = list(docs) - if len(docs) != len(values): - raise ValueError( - f'length of docs to set({len(docs)}) does not match ' - f'length of values({len(values)})' - ) - - for _d, _v in zip(docs, values): - self._set_doc_by_id(_d.id, _v) diff --git a/docarray/array/storage/weaviate/getsetdel.py b/docarray/array/storage/weaviate/getsetdel.py index 31f9cae4d7b..bebca98dc7b 100644 --- a/docarray/array/storage/weaviate/getsetdel.py +++ b/docarray/array/storage/weaviate/getsetdel.py @@ -86,42 +86,6 @@ def _set_doc_by_offset(self, offset: int, value: 'Document'): # update weaviate id self._offset2ids[offset] = self.wmap(value.id) - def _set_doc_value_pairs( - self, docs: Iterable['Document'], values: Sequence['Document'] - ): - """Concrete implementation of base class' ``_set_doc_value_pairs`` - - :param docs: the array of docs to update - :param values: the values docs should be set to - :raises ValueError: raise error when there's a mismatch between len of docs and values - """ - # TODO: optimize/use base _set_doc_value_pairs - docs = list(docs) - if len(docs) != len(values): - raise ValueError( - f'length of docs to set({len(docs)}) does not match ' - f'length of values({len(values)})' - ) - - map_doc_id_to_offset = {doc.id: offset for offset, doc in enumerate(docs)} - map_new_id_to_old_id = {new.id: old.id for old, new in zip(docs, values)} - - def _set_doc_value_pairs_util(_docs: DocumentArray): - for d in _docs: - if d.id in map_doc_id_to_offset: - d._data = values[map_doc_id_to_offset[d.id]]._data - _set_doc_value_pairs_util(d.chunks) - _set_doc_value_pairs_util(d.matches) - - res = DocumentArray(d for d in self) - _set_doc_value_pairs_util(res) - - for r in res: - old_id = ( - r.id if r.id not in map_new_id_to_old_id else map_new_id_to_old_id[r.id] - ) - self._setitem(self.wmap(old_id), r) - def _set_doc_by_id(self, _id: str, value: 'Document'): """Concrete implementation of base class' ``_set_doc_by_id`` diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 883695eda3d..8d75601991c 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -214,9 +214,9 @@ def test_path_syntax_indexing(storage, start_weaviate): @pytest.mark.parametrize('storage', ['memory', 'weaviate', 'sqlite']) def test_path_syntax_indexing_set(storage, start_weaviate): da = DocumentArray.empty(3) - for d in da: + for i, d in enumerate(da): d.chunks = DocumentArray.empty(5) - d.matches = DocumentArray.empty(7) + d.matches = DocumentArray([Document(id=f'm{j + (i*7)}') for j in range(7)]) for c in d.chunks: c.chunks = DocumentArray.empty(3) @@ -269,7 +269,8 @@ def test_path_syntax_indexing_set(storage, start_weaviate): assert da[doc_id, 'text'] == 'e' assert da[doc_id].text == 'e' - da['@m'] = [Document(text='c')] * (3 * 7) + # setting matches is only possible if the IDs are the same + da['@m'] = [Document(id=f'm{i}', text='c') for i in range(3 * 7)] assert da['@m', 'text'] == repeat('c', 3 * 7) # TODO also test cases like da[1, ['text', 'id']], @@ -413,5 +414,5 @@ def test_edge_case_two_strings(storage, start_weaviate): assert da['1', 'text'] == 'hello' assert da['1'].text == 'hello' - with pytest.raises(ValueError): + with pytest.raises(IndexError): da['1', 'hellohello'] = 'hello' From 7be25b75706a1c5b10c8bab2330fe79a0fb0a5db Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 1 Feb 2022 11:44:12 +0100 Subject: [PATCH 39/41] fix: use separate method _set_doc_value_pairs_nested --- docarray/array/mixins/setitem.py | 4 ++-- docarray/array/storage/base/getsetdel.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index 93f0fcc0366..4c8d8440f50 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -68,7 +68,7 @@ def __setitem__( self._set_doc_by_offset(int(index), value) elif isinstance(index, str): if index.startswith('@'): - self._set_doc_value_pairs(self.traverse_flat(index[1:]), value) + self._set_doc_value_pairs_nested(self.traverse_flat(index[1:]), value) else: self._set_doc_by_id(index, value) elif isinstance(index, slice): @@ -199,4 +199,4 @@ def _set_docs_attributes_traversal_paths( else: for _d, _vv in zip(_docs, _v): setattr(_d, _a, _vv) - self._set_doc_value_pairs(_docs, _docs) + self._set_doc_value_pairs_nested(_docs, _docs) diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index e2ddd6c9732..85ed140fb56 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -122,6 +122,19 @@ def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): def _set_doc_value_pairs( self, docs: Iterable['Document'], values: Sequence['Document'] + ): + docs = list(docs) + if len(docs) != len(values): + raise ValueError( + f'length of docs to set({len(docs)}) does not match ' + f'length of values({len(values)})' + ) + + for _d, _v in zip(docs, values): + self._set_doc_by_id(_d.id, _v) + + def _set_doc_value_pairs_nested( + self, docs: Iterable['Document'], values: Sequence['Document'] ): """This function is derived and may not have the most efficient implementation. From 30ee9a84bb531b34a21eade00ee1c8bb34662d59 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 1 Feb 2022 12:06:43 +0100 Subject: [PATCH 40/41] fix: add flattened flag to DocumentArray after flattening --- docarray/array/mixins/traverse.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docarray/array/mixins/traverse.py b/docarray/array/mixins/traverse.py index 7003d66fc46..5c70ccb7068 100644 --- a/docarray/array/mixins/traverse.py +++ b/docarray/array/mixins/traverse.py @@ -129,6 +129,9 @@ def flatten(self) -> 'DocumentArray': """ from .. import DocumentArray + if hasattr(self, '_flattened') and getattr(self, '_flattened'): + return self + visited = set() def _yield_all(): @@ -145,7 +148,9 @@ def _yield_nest(doc: 'Document'): yield doc - return DocumentArray(_yield_all()) + da = DocumentArray(_yield_all()) + da._flattened = True + return da @staticmethod def _flatten(sequence) -> 'DocumentArray': From 459b6c84823556db9f8a6d7d276c0ea454942713 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Tue, 1 Feb 2022 12:17:43 +0100 Subject: [PATCH 41/41] feat: do not allow setting by traversal paths with diff ID --- docarray/array/storage/base/getsetdel.py | 4 ++++ tests/unit/array/test_advance_indexing.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index 85ed140fb56..2d831a715bf 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -150,6 +150,10 @@ def _set_doc_value_pairs_nested( ) for _d, _v in zip(docs, values): + if _d.id != _v.id: + raise ValueError( + 'Setting Documents by traversal paths with different IDs is not supported' + ) _d._data = _v._data if _d not in self: root_d = self._find_root_doc_and_modify(_d) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 8d75601991c..e12c5080faf 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -273,6 +273,10 @@ def test_path_syntax_indexing_set(storage, start_weaviate): da['@m'] = [Document(id=f'm{i}', text='c') for i in range(3 * 7)] assert da['@m', 'text'] == repeat('c', 3 * 7) + # setting by traversal paths with different IDs is not supported + with pytest.raises(ValueError): + da['@m'] = [Document() for _ in range(3 * 7)] + # TODO also test cases like da[1, ['text', 'id']], # where first index is str/int and second is attr