Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 70 additions & 9 deletions docarray/array/mixins/io/pushpull.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,83 @@
import warnings
from collections import Counter
from pathlib import Path
from typing import Dict, Type, TYPE_CHECKING, Any, List
from typing import Dict, Type, TYPE_CHECKING, List, Optional

import hubble
from hubble import Client
from hubble.client.endpoints import EndpointsV2


from docarray.helper import get_request_header, __cache_path__

if TYPE_CHECKING:
from docarray.typing import T


def _get_length_from_summary(summary: List[Dict]) -> Optional[int]:
"""Get the length from summary."""
for item in summary:
if 'Length' == item['name']:
return item['value']


class PushPullMixin:
"""Transmitting :class:`DocumentArray` via Jina Cloud Service"""

_max_bytes = 4 * 1024 * 1024 * 1024

@classmethod
@hubble.login_required
def cloud_list(cls, show_table: bool = False) -> List[str]:
    """List all available arrays in the cloud.

    :param show_table: if true, print a table with the name, length,
        visibility and creation/update timestamps of each array.
    :returns: List of available DocumentArray's names.
    """
    artifacts = Client(jsonify=True).list_artifacts(
        filter={'type': 'documentArray'}, sort={'createdAt': 1}
    )['data']
    # The request already passes a type filter; the per-item check is kept
    # so the returned names are the same as before.
    arrays = [da for da in artifacts if da['type'] == 'documentArray']
    result = [da['name'] for da in arrays]

    if show_table:
        # rich is only needed for rendering, so import it and build the
        # table only when the caller actually asks for it.
        from rich import box
        from rich import print as rich_print
        from rich.table import Table

        table = Table(
            title='Your DocumentArray on the cloud', box=box.SIMPLE, highlight=True
        )
        table.add_column('Name')
        table.add_column('Length')
        table.add_column('Visibility')
        table.add_column('Create at', justify='center')
        table.add_column('Updated at', justify='center')

        for da in arrays:
            # Tolerate artifacts whose metadata or summary is missing/null
            # instead of failing the whole listing.
            summary = (da.get('metaData') or {}).get('summary') or []
            table.add_row(
                da['name'],
                str(_get_length_from_summary(summary)),
                da['visibility'],
                da['createdAt'],
                da['updatedAt'],
            )

        rich_print(table)

    return result

@classmethod
@hubble.login_required
def cloud_delete(cls, name: str) -> None:
    """Remove a DocumentArray stored in the cloud.

    :param name: name of the DocumentArray to remove.
    """
    client = Client(jsonify=True)
    client.delete_artifact(name)


def _get_raw_summary(self) -> List[Dict[str, Any]]:
all_attrs = self._get_attributes('non_empty_fields')
# remove underscore attribute
Expand Down Expand Up @@ -106,6 +170,7 @@ def _get_raw_summary(self) -> List[Dict[str, Any]]:

return items

@hubble.login_required
def push(
self,
name: str,
Expand Down Expand Up @@ -149,7 +214,6 @@ def push(
)

headers = {'Content-Type': ctype, **get_request_header()}
import hubble

auth_token = hubble.get_token()
if auth_token:
Expand Down Expand Up @@ -196,8 +260,6 @@ def _get_chunk(_batch):
yield _tail

with pbar:
from hubble import Client
from hubble.client.endpoints import EndpointsV2

response = requests.post(
Client()._base_url + EndpointsV2.upload_artifact,
Expand All @@ -211,6 +273,7 @@ def _get_chunk(_batch):
response.raise_for_status()

@classmethod
@hubble.login_required
def pull(
cls: Type['T'],
name: str,
Expand All @@ -231,16 +294,11 @@ def pull(

headers = {}

import hubble

auth_token = hubble.get_token()

if auth_token:
headers['Authorization'] = f'token {auth_token}'

from hubble import Client
from hubble.client.endpoints import EndpointsV2

url = Client()._base_url + EndpointsV2.download_artifact + f'?name={name}'
response = requests.get(url, headers=headers)

Expand Down Expand Up @@ -281,3 +339,6 @@ def pull(
fp.write(_source.content)

return r

cloud_push = push
cloud_pull = pull
10 changes: 10 additions & 0 deletions tests/unit/array/mixins/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,16 @@ def test_push_pull_io(da_cls, config, show_progress, start_storage):
assert len(da1) == len(da2) == 10
assert da1.texts == da2.texts == random_texts

all_names = DocumentArray.cloud_list()

assert name in all_names

DocumentArray.cloud_delete(name)

all_names = DocumentArray.cloud_list()

assert name not in all_names


@pytest.mark.parametrize(
'protocol', ['protobuf', 'pickle', 'protobuf-array', 'pickle-array']
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/array/test_from_to_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_from_to_safe_list(target, protocol, to_fn):

@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
@pytest.mark.parametrize('show_progress', [True, False])
def test_push_pull_show_progress(show_progress, protocol):
def test_to_bytes_show_progress(show_progress, protocol):
da = DocumentArray.empty(1000)
r = da.to_bytes(_show_progress=show_progress, protocol=protocol)
da_r = DocumentArray.from_bytes(r, _show_progress=show_progress, protocol=protocol)
Expand Down