diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index d1605292143..f69200efdea 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -32,6 +32,8 @@ class SqliteConfig: table_name: Optional[str] = None serialize_config: Dict = field(default_factory=dict) conn_config: Dict = field(default_factory=dict) + journal_mode: str = 'DELETE' + synchronous: str = 'OFF' class BackendMixin(BaseBackendMixin): @@ -69,7 +71,9 @@ def _init_storage( 'Document', lambda x: Document.from_bytes(x, **config.serialize_config) ) - _conn_kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES) + _conn_kwargs = dict( + detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False + ) _conn_kwargs.update(config.conn_config) if config.connection is None: self._connection = sqlite3.connect( @@ -83,6 +87,8 @@ def _init_storage( raise TypeError( f'connection argument must be None or a string or a sqlite3.Connection, not `{type(config.connection)}`' ) + self._connection.execute(f'PRAGMA synchronous={config.synchronous}') + self._connection.execute(f'PRAGMA journal_mode={config.journal_mode}') self._table_name = ( _sanitize_table_name(self.__class__.__name__ + random_identity()) diff --git a/tests/unit/array/mixins/test_parallel.py b/tests/unit/array/mixins/test_parallel.py index 9f332ae7264..396a472b81d 100644 --- a/tests/unit/array/mixins/test_parallel.py +++ b/tests/unit/array/mixins/test_parallel.py @@ -110,3 +110,22 @@ def test_map_lambda(pytestconfig, da_cls): for d in da.map(lambda x: x.load_uri_to_image_tensor()): assert d.tensor is not None + + +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +@pytest.mark.parametrize('backend', ['thread', 'process']) +def test_apply_diff_backend_storage(storage, backend): + da = DocumentArray( + (Document(text='hello world she smiled too much') for _ in range(1000)), + storage=storage, + ) + da.apply(lambda d: d.embed_feature_hashing(), backend=backend) + + q = ( + Document(text='she smiled too much') + .embed_feature_hashing() + .match(da, metric='jaccard', use_scipy=True) + ) + + assert len(q.matches[:5, ('text', 'scores__jaccard__value')]) == 2 + assert len(q.matches[:5, ('text', 'scores__jaccard__value')][0]) == 5