From 9e00a008ec1cd26aac14b685446cc440ca00a2f2 Mon Sep 17 00:00:00 2001
From: Han Xiao
Date: Thu, 18 Aug 2022 15:09:30 +0200
Subject: [PATCH] feat: push meta data along with docarray

---
Review notes (this section is not part of the commit message):
- Adds the `cloud_list` and `cloud_delete` classmethods and the
  `cloud_push` / `cloud_pull` aliases to `PushPullMixin`.
- Decorates `push`, `pull` and both new classmethods with
  `hubble.login_required`; the `hubble` imports move from function scope
  to module scope.
- Table header typo fixed relative to the original submission:
  'Create at' -> 'Created at' (consistent with 'Updated at').
- `_get_length_from_summary` now skips summary entries that carry no
  'name' key instead of raising KeyError.

 docarray/array/mixins/io/pushpull.py   | 77 +++++++++++++++++++++++---
 tests/unit/array/mixins/test_io.py     | 10 ++++
 tests/unit/array/test_from_to_bytes.py |  2 +-
 3 files changed, 79 insertions(+), 10 deletions(-)

diff --git a/docarray/array/mixins/io/pushpull.py b/docarray/array/mixins/io/pushpull.py
index 5e9cbbad1bc..37808d2b07f 100644
--- a/docarray/array/mixins/io/pushpull.py
+++ b/docarray/array/mixins/io/pushpull.py
@@ -1,7 +1,11 @@
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Type, TYPE_CHECKING, Optional
+from typing import Dict, Type, TYPE_CHECKING, List, Optional
+
+import hubble
+from hubble import Client
+from hubble.client.endpoints import EndpointsV2
 
 from docarray.helper import get_request_header, __cache_path__
 
@@ -9,11 +13,70 @@
     from docarray.typing import T
 
 
+def _get_length_from_summary(summary: List[Dict]) -> Optional[int]:
+    """Return the value of the 'Length' entry in ``summary``, or None."""
+    for item in summary:
+        if item.get('name') == 'Length':
+            return item['value']
+
+
 class PushPullMixin:
     """Transmitting :class:`DocumentArray` via Jina Cloud Service"""
 
     _max_bytes = 4 * 1024 * 1024 * 1024
 
+    @classmethod
+    @hubble.login_required
+    def cloud_list(cls, show_table: bool = False) -> List[str]:
+        """List all available arrays in the cloud.
+
+        :param show_table: if true, show the table of the arrays.
+        :returns: List of available DocumentArray's names.
+        """
+
+        result = []
+        from rich.table import Table
+        from rich import box
+
+        table = Table(
+            title='Your DocumentArray on the cloud', box=box.SIMPLE, highlight=True
+        )
+        table.add_column('Name')
+        table.add_column('Length')
+        table.add_column('Visibility')
+        table.add_column('Created at', justify='center')
+        table.add_column('Updated at', justify='center')
+
+        for da in Client(jsonify=True).list_artifacts(
+            filter={'type': 'documentArray'}, sort={'createdAt': 1}
+        )['data']:
+            if da['type'] == 'documentArray':
+                result.append(da['name'])
+
+                table.add_row(
+                    da['name'],
+                    str(_get_length_from_summary(da['metaData'].get('summary', []))),
+                    da['visibility'],
+                    da['createdAt'],
+                    da['updatedAt'],
+                )
+
+        if show_table:
+            from rich import print
+
+            print(table)
+        return result
+
+    @classmethod
+    @hubble.login_required
+    def cloud_delete(cls, name: str) -> None:
+        """
+        Delete a DocumentArray from the cloud.
+        :param name: the name of the DocumentArray to delete.
+        """
+        Client(jsonify=True).delete_artifact(name)
+
+    @hubble.login_required
     def push(
         self,
         name: str,
@@ -51,7 +114,6 @@ def push(
         )
 
         headers = {'Content-Type': ctype, **get_request_header()}
-        import hubble
 
         auth_token = hubble.get_token()
         if auth_token:
@@ -98,8 +160,6 @@ def _get_chunk(_batch):
             yield _tail
 
         with pbar:
-            from hubble import Client
-            from hubble.client.endpoints import EndpointsV2
 
             response = requests.post(
                 Client()._base_url + EndpointsV2.upload_artifact,
@@ -113,6 +173,7 @@ def _get_chunk(_batch):
             response.raise_for_status()
 
     @classmethod
+    @hubble.login_required
     def pull(
         cls: Type['T'],
         name: str,
@@ -133,16 +194,11 @@ def pull(
 
         headers = {}
 
-        import hubble
-
         auth_token = hubble.get_token()
 
         if auth_token:
             headers['Authorization'] = f'token {auth_token}'
 
-        from hubble import Client
-        from hubble.client.endpoints import EndpointsV2
-
         url = Client()._base_url + EndpointsV2.download_artifact + f'?name={name}'
         response = requests.get(url, headers=headers)
 
@@ -183,3 +239,6 @@ def pull(
                     fp.write(_source.content)
 
         return r
+
+    cloud_push = push
+    cloud_pull = pull
diff --git a/tests/unit/array/mixins/test_io.py b/tests/unit/array/mixins/test_io.py
index 383020fb3d2..026befb2fe3 100644
--- a/tests/unit/array/mixins/test_io.py
+++ b/tests/unit/array/mixins/test_io.py
@@ -237,6 +237,16 @@ def test_push_pull_io(da_cls, config, show_progress, start_storage):
     assert len(da1) == len(da2) == 10
     assert da1.texts == da2.texts == random_texts
 
+    all_names = DocumentArray.cloud_list()
+
+    assert name in all_names
+
+    DocumentArray.cloud_delete(name)
+
+    all_names = DocumentArray.cloud_list()
+
+    assert name not in all_names
+
 
 @pytest.mark.parametrize(
     'protocol', ['protobuf', 'pickle', 'protobuf-array', 'pickle-array']
diff --git a/tests/unit/array/test_from_to_bytes.py b/tests/unit/array/test_from_to_bytes.py
index a8b77d90327..fd500e8ce67 100644
--- a/tests/unit/array/test_from_to_bytes.py
+++ b/tests/unit/array/test_from_to_bytes.py
@@ -110,7 +110,7 @@ def test_from_to_safe_list(target, protocol, to_fn):
 
 @pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
 @pytest.mark.parametrize('show_progress', [True, False])
-def test_push_pull_show_progress(show_progress, protocol):
+def test_to_bytes_show_progress(show_progress, protocol):
     da = DocumentArray.empty(1000)
     r = da.to_bytes(_show_progress=show_progress, protocol=protocol)
     da_r = DocumentArray.from_bytes(r, _show_progress=show_progress, protocol=protocol)