diff --git a/docarray/array/mixins/io/pushpull.py b/docarray/array/mixins/io/pushpull.py index 90c2d09f7be..4cec5f4a14d 100644 --- a/docarray/array/mixins/io/pushpull.py +++ b/docarray/array/mixins/io/pushpull.py @@ -4,7 +4,12 @@ import warnings from collections import Counter from pathlib import Path -from typing import Dict, Type, TYPE_CHECKING, Any, List +from typing import Dict, Type, TYPE_CHECKING, List, Optional + +import hubble +from hubble import Client +from hubble.client.endpoints import EndpointsV2 + from docarray.helper import get_request_header, __cache_path__ @@ -12,11 +17,70 @@ from docarray.typing import T +def _get_length_from_summary(summary: List[Dict]) -> Optional[int]: + """Get the length from summary.""" + for item in summary: + if 'Length' == item['name']: + return item['value'] + + class PushPullMixin: """Transmitting :class:`DocumentArray` via Jina Cloud Service""" _max_bytes = 4 * 1024 * 1024 * 1024 + @classmethod + @hubble.login_required + def cloud_list(cls, show_table: bool = False) -> List[str]: + """List all available arrays in the cloud. + + :param show_table: if true, show the table of the arrays. + :returns: List of available DocumentArray's names. 
# NOTE: both functions below are methods of ``PushPullMixin``; indent them one
# level when placing them inside the class body.
@classmethod
@hubble.login_required
def cloud_list(cls, show_table: bool = False) -> List[str]:
    """List all available arrays in the cloud.

    Queries the artifact listing of the hubble ``Client`` for entries of type
    ``documentArray``, sorted by ``createdAt``.

    :param show_table: if true, additionally print a table with name, length,
        visibility, creation time and update time of each array.
    :returns: List of available DocumentArray's names.
    """
    artifacts = Client(jsonify=True).list_artifacts(
        filter={'type': 'documentArray'}, sort={'createdAt': 1}
    )['data']
    # The request already filters by type; the check is kept as a guard in
    # case the response contains other artifact types.
    arrays = [da for da in artifacts if da['type'] == 'documentArray']
    result = [da['name'] for da in arrays]

    if show_table:
        # rich is only needed for rendering, so import it and build the table
        # lazily instead of paying for it on every call.
        from rich import box
        from rich import print as rich_print
        from rich.table import Table

        table = Table(
            title='Your DocumentArray on the cloud', box=box.SIMPLE, highlight=True
        )
        table.add_column('Name')
        table.add_column('Length')
        table.add_column('Visibility')
        table.add_column('Created at', justify='center')
        table.add_column('Updated at', justify='center')

        for da in arrays:
            # ``metaData`` or its ``summary`` may be absent or None.
            summary = (da.get('metaData') or {}).get('summary') or []
            table.add_row(
                da['name'],
                str(_get_length_from_summary(summary)),
                da['visibility'],
                da['createdAt'],
                da['updatedAt'],
            )

        rich_print(table)
    return result


@classmethod
@hubble.login_required
def cloud_delete(cls, name: str) -> None:
    """Delete a DocumentArray from the cloud.

    :param name: the name of the DocumentArray to delete.
    """
    # NOTE(review): ``name`` is passed positionally; confirm against the
    # installed hubble SDK that the first parameter of ``delete_artifact``
    # is the artifact name rather than its id.
    Client(jsonify=True).delete_artifact(name)
import Client - from hubble.client.endpoints import EndpointsV2 - url = Client()._base_url + EndpointsV2.download_artifact + f'?name={name}' response = requests.get(url, headers=headers) @@ -281,3 +339,6 @@ def pull( fp.write(_source.content) return r + + cloud_push = push + cloud_pull = pull diff --git a/tests/unit/array/mixins/test_io.py b/tests/unit/array/mixins/test_io.py index 383020fb3d2..026befb2fe3 100644 --- a/tests/unit/array/mixins/test_io.py +++ b/tests/unit/array/mixins/test_io.py @@ -237,6 +237,16 @@ def test_push_pull_io(da_cls, config, show_progress, start_storage): assert len(da1) == len(da2) == 10 assert da1.texts == da2.texts == random_texts + all_names = DocumentArray.cloud_list() + + assert name in all_names + + DocumentArray.cloud_delete(name) + + all_names = DocumentArray.cloud_list() + + assert name not in all_names + @pytest.mark.parametrize( 'protocol', ['protobuf', 'pickle', 'protobuf-array', 'pickle-array'] diff --git a/tests/unit/array/test_from_to_bytes.py b/tests/unit/array/test_from_to_bytes.py index a8b77d90327..fd500e8ce67 100644 --- a/tests/unit/array/test_from_to_bytes.py +++ b/tests/unit/array/test_from_to_bytes.py @@ -110,7 +110,7 @@ def test_from_to_safe_list(target, protocol, to_fn): @pytest.mark.parametrize('protocol', ['protobuf', 'pickle']) @pytest.mark.parametrize('show_progress', [True, False]) -def test_push_pull_show_progress(show_progress, protocol): +def test_to_bytes_show_progress(show_progress, protocol): da = DocumentArray.empty(1000) r = da.to_bytes(_show_progress=show_progress, protocol=protocol) da_r = DocumentArray.from_bytes(r, _show_progress=show_progress, protocol=protocol)