From 765ee2346510d26f1deee9d4bc2dc95042a85303 Mon Sep 17 00:00:00 2001 From: Delgermurun Purevkhuu Date: Thu, 17 Nov 2022 12:00:35 +0100 Subject: [PATCH] fix: pull docarray with username/da-name format Signed-off-by: Delgermurun Purevkhuu --- docarray/array/mixins/io/pushpull.py | 10 ++++++---- tests/unit/array/mixins/test_pushpull.py | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/docarray/array/mixins/io/pushpull.py b/docarray/array/mixins/io/pushpull.py index e266d168984..7bb196704c6 100644 --- a/docarray/array/mixins/io/pushpull.py +++ b/docarray/array/mixins/io/pushpull.py @@ -303,10 +303,12 @@ def pull( from docarray.array.mixins.io.binary import LazyRequestReader _source = LazyRequestReader(r) - if local_cache and os.path.exists(f'{__cache_path__}/{name}'): - _cache_len = os.path.getsize(f'{__cache_path__}/{name}') + + cache_file = f'{__cache_path__}/{name.replace("/", "_")}.da' + if local_cache and os.path.exists(cache_file): + _cache_len = os.path.getsize(cache_file) if _cache_len == _da_len: - _source = f'{__cache_path__}/{name}' + _source = cache_file r = cls.load_binary( _source, @@ -319,7 +321,7 @@ def pull( if isinstance(_source, LazyRequestReader) and local_cache: Path(__cache_path__).mkdir(parents=True, exist_ok=True) - with open(f'{__cache_path__}/{name}', 'wb') as fp: + with open(cache_file, 'wb') as fp: fp.write(_source.content) return r diff --git a/tests/unit/array/mixins/test_pushpull.py b/tests/unit/array/mixins/test_pushpull.py index 823e463b500..0ea5f7693e9 100644 --- a/tests/unit/array/mixins/test_pushpull.py +++ b/tests/unit/array/mixins/test_pushpull.py @@ -115,11 +115,12 @@ def test_push_with_public(mocker, monkeypatch, public): assert form_data['public'] == [str(public)] -def test_pull(mocker, monkeypatch): +@pytest.mark.parametrize('da_name', ['test_name', 'username/test_name']) +def test_pull(mocker, monkeypatch, da_name): mock = mocker.Mock() _mock_get(mock, monkeypatch) - DocumentArray.pull(name='test_name') + DocumentArray.pull(name=da_name) assert mock.call_count == 2 # 1 for pull, 1 for download