diff --git a/docarray/array/document.py b/docarray/array/document.py index 688ea90228b..7c535b25330 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -1,3 +1,4 @@ +from contextlib import ExitStack from typing import Optional, overload, TYPE_CHECKING, Dict, Union from docarray.array.base import BaseDocumentArray @@ -140,13 +141,18 @@ def __new__( ... def __enter__(self): + self._exit_stack = ExitStack() + # Ensure that we sync the data to the storage backend when exiting the context manager + self._exit_stack.callback(self.sync) + # Enter (and then exit) context of all subindices + if getattr(self, '_subindices', None): + for selector, da in self._subindices.items(): + self._exit_stack.enter_context(da) return self def __exit__(self, *args, **kwargs): - """ - Ensures that we sync the data to the storage backend when exiting the context manager - """ - self.sync() + # Trigger all __exit__()s and callbacks added in self.__enter__() + self._exit_stack.close() def __new__(cls, *args, storage: str = 'memory', **kwargs): if cls is DocumentArray: diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index ed9db6558a6..2b29e81d7ab 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -326,7 +326,3 @@ def _save_offset2ids(self): def sync(self): if hasattr(self, '_offset2ids'): self._save_offset2ids() - - if getattr(self, '_subindices', None): - for selector, da in self._subindices.items(): - da.sync()