Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docarray/array/mixins/io/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Union, BinaryIO, TYPE_CHECKING, Type, Optional, Generator

from ....helper import (
__windows__,
get_compress_ctx,
decompress_bytes,
protocol_and_compress_from_file_path,
Expand Down
37 changes: 34 additions & 3 deletions docarray/array/mixins/io/pushpull.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,21 @@ def read(self, n=-1):

@classmethod
def pull(
cls: Type['T'], token: str, show_progress: bool = False, *args, **kwargs
cls: Type['T'],
token: str,
show_progress: bool = False,
local_cache: bool = False,
*args,
**kwargs,
) -> 'T':
"""Pulling a :class:`DocumentArray` from Jina Cloud Service to local.

:param token: the upload token set during :meth:`.push`
:param show_progress: if to show a progress bar on pulling
:param local_cache: store the downloaded DocumentArray to local folder
:return: a :class:`DocumentArray` object
"""

import requests

url = f'{_get_cloud_api()}/v2/rpc/da.pull?token={token}'
Expand All @@ -127,9 +134,27 @@ def pull(
headers=get_request_header(),
) as r, progress:
r.raise_for_status()

_da_len = int(r.headers['Content-length'])

if local_cache and os.path.exists(f'.cache/{token}'):
_cache_len = os.path.getsize(f'.cache/{token}')
if _cache_len == _da_len:
if show_progress:
progress.stop()

return cls.load_binary(
f'.cache/{token}',
protocol='protobuf',
compress='gzip',
_show_progress=show_progress,
*args,
**kwargs,
)

if show_progress:
task_id = progress.add_task('download', start=False)
progress.update(task_id, total=int(r.headers['Content-length']))
progress.update(task_id, total=int(_da_len))
with io.BytesIO() as f:
chunk_size = 8192
if show_progress:
Expand All @@ -139,8 +164,14 @@ def pull(
if show_progress:
progress.update(task_id, advance=len(chunk))

if local_cache:
os.makedirs('.cache', exist_ok=True)
with open(f'.cache/{token}', 'wb') as fp:
fp.write(f.getbuffer())

if show_progress:
progress.stop()

return cls.from_bytes(
f.getvalue(),
protocol='protobuf',
Expand Down Expand Up @@ -180,7 +211,7 @@ def _get_progressbar(show_progress):
)

return Progress(
BarColumn(bar_width=None),
BarColumn(),
"[progress.percentage]{task.percentage:>3.1f}%",
"•",
DownloadColumn(),
Expand Down
34 changes: 32 additions & 2 deletions docarray/array/mixins/parallel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from math import ceil
from types import LambdaType
from typing import Callable, TYPE_CHECKING, Generator, Optional, overload, TypeVar

Expand All @@ -18,6 +19,7 @@ def apply(
func: Callable[['Document'], 'Document'],
backend: str = 'thread',
num_worker: Optional[int] = None,
show_progress: bool = False,
) -> 'T':
"""Apply each element in itself with ``func``, return itself after modified.

Expand All @@ -33,6 +35,7 @@ def apply(
and the original object do **not** share the same memory.

:param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
:param show_progress: show a progress bar

"""
...
Expand All @@ -56,6 +59,7 @@ def map(
func: Callable[['Document'], 'T'],
backend: str = 'thread',
num_worker: Optional[int] = None,
show_progress: bool = False,
) -> Generator['T', None, None]:
"""Return an iterator that applies function to every **element** of iterable in parallel, yielding the results.

Expand All @@ -76,12 +80,22 @@ def map(
and the original object do **not** share the same memory.

:param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
:param show_progress: show a progress bar

:yield: anything return from ``func``
"""
if _is_lambda_or_local_function(func) and backend == 'process':
func = _globalize_lambda_function(func)

if show_progress:
from rich.progress import track as _track

track = lambda x: _track(x, total=len(self))
else:
track = lambda x: x

with _get_pool(backend, num_worker) as p:
for x in p.imap(func, self):
for x in track(p.imap(func, self)):
yield x

@overload
Expand All @@ -92,6 +106,7 @@ def apply_batch(
backend: str = 'thread',
num_worker: Optional[int] = None,
shuffle: bool = False,
show_progress: bool = False,
) -> 'T':
"""Apply each element in itself with ``func``, return itself after modified.

Expand All @@ -109,6 +124,8 @@ def apply_batch(
:param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
:param batch_size: Size of each generated batch (except the last one, which might be smaller, default: 32)
:param shuffle: If set, shuffle the Documents before dividing into minibatches.
:param show_progress: show a progress bar

"""
...

Expand All @@ -135,6 +152,7 @@ def map_batch(
backend: str = 'thread',
num_worker: Optional[int] = None,
shuffle: bool = False,
show_progress: bool = False,
) -> Generator['T', None, None]:
"""Return an iterator that applies function to every **minibatch** of iterable in parallel, yielding the results.
Each element in the returned iterator is :class:`DocumentArray`.
Expand All @@ -158,13 +176,25 @@ def map_batch(
and the original object do **not** share the same memory.

:param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
:param show_progress: show a progress bar

:yield: anything return from ``func``
"""

if _is_lambda_or_local_function(func) and backend == 'process':
func = _globalize_lambda_function(func)

if show_progress:
from rich.progress import track as _track

track = lambda x: _track(x, total=ceil(len(self) / batch_size))
else:
track = lambda x: x

with _get_pool(backend, num_worker) as p:
for x in p.imap(func, self.batch(batch_size=batch_size, shuffle=shuffle)):
for x in track(
p.imap(func, self.batch(batch_size=batch_size, shuffle=shuffle))
):
yield x


Expand Down
11 changes: 11 additions & 0 deletions docarray/array/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def plot_image_sprites(
min_size: int = 16,
channel_axis: int = -1,
image_source: str = 'tensor',
skip_empty: bool = False,
) -> None:
"""Generate a sprite image for all image tensors in this DocumentArray-like object.

Expand All @@ -314,6 +315,7 @@ def plot_image_sprites(
:param min_size: the minimum size of the image
:param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
:param image_source: specify where the image comes from, can be ``uri`` or ``tensor``. empty tensor will fallback to uri
:param skip_empty: skip Document who has no .uri or .tensor.
"""
if not self:
raise ValueError(f'{self!r} is empty')
Expand All @@ -335,6 +337,15 @@ def plot_image_sprites(
img_id = 0
try:
for d in self:

if not d.uri and d.tensor is None:
if skip_empty:
continue
else:
raise ValueError(
f'Document has neither `uri` nor `tensor`, can not be plotted'
)

_d = copy.deepcopy(d)

if image_source == 'uri' or (
Expand Down
5 changes: 3 additions & 2 deletions docs/fundamentals/documentarray/serialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ Considering you are working on a GPU machine via Google Colab/Jupyter. After pre
from docarray import DocumentArray

da = DocumentArray(...) # heavylifting, processing, GPU task, ...
da.push(token='myda123')
da.push(token='myda123', show_progress=True)
```

```{figure} images/da-push.png
Expand All @@ -377,11 +377,12 @@ Then on your local laptop, simply pull it:
```python
from docarray import DocumentArray

da = DocumentArray.pull(token='myda123')
da = DocumentArray.pull(token='myda123', show_progress=True)
```

Now you can continue working locally, analyzing `da` or visualizing it. Friends and colleagues who know the token `myda123` can also pull that DocumentArray, which is useful when you want to quickly share results with them.

The maximum size of an upload is 4 GB under the `protocol='protobuf'` and `compress='gzip'` settings. An upload is kept for one week after its creation.

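To avoid an unnecessary download when the upstream DocumentArray is unchanged, pass `local_cache=True`, i.e. `DocumentArray.pull(..., local_cache=True)`. The downloaded DocumentArray is then stored in a local `.cache` folder and reused on later pulls when its size matches the size reported by the server.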

31 changes: 26 additions & 5 deletions tests/unit/array/mixins/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def foo_batch(da: DocumentArray):
)
@pytest.mark.parametrize('backend', ['process', 'thread'])
@pytest.mark.parametrize('num_worker', [1, 2, None])
def test_parallel_map(pytestconfig, da_cls, config, backend, num_worker, start_storage):
@pytest.mark.parametrize('show_progress', [True, False])
def test_parallel_map(
pytestconfig, da_cls, config, backend, num_worker, start_storage, show_progress
):
if __name__ == '__main__':

if config:
Expand All @@ -47,7 +50,9 @@ def test_parallel_map(pytestconfig, da_cls, config, backend, num_worker, start_s
da = da_cls.from_files(f'{pytestconfig.rootdir}/**/*.jpeg')[:10]

# use a generator
for d in da.map(foo, backend, num_worker=num_worker):
for d in da.map(
foo, backend, num_worker=num_worker, show_progress=show_progress
):
assert d.tensor.shape == (3, 222, 222)

if config:
Expand Down Expand Up @@ -87,8 +92,16 @@ def test_parallel_map(pytestconfig, da_cls, config, backend, num_worker, start_s
@pytest.mark.parametrize('backend', ['thread'])
@pytest.mark.parametrize('num_worker', [1, 2, None])
@pytest.mark.parametrize('b_size', [1, 2, 256])
@pytest.mark.parametrize('show_progress', [True, False])
def test_parallel_map_batch(
pytestconfig, da_cls, config, backend, num_worker, b_size, start_storage
pytestconfig,
da_cls,
config,
backend,
num_worker,
b_size,
start_storage,
show_progress,
):
if __name__ == '__main__':

Expand All @@ -101,7 +114,11 @@ def test_parallel_map_batch(

# use a generator
for _da in da.map_batch(
foo_batch, batch_size=b_size, backend=backend, num_worker=num_worker
foo_batch,
batch_size=b_size,
backend=backend,
num_worker=num_worker,
show_progress=True,
):
for d in _da:
assert d.tensor.shape == (3, 222, 222)
Expand All @@ -116,7 +133,11 @@ def test_parallel_map_batch(
# use as list, here the caveat is when using process backend you can not modify thing in-place
list(
da.map_batch(
foo_batch, batch_size=b_size, backend=backend, num_worker=num_worker
foo_batch,
batch_size=b_size,
backend=backend,
num_worker=num_worker,
show_progress=True,
)
)
if backend == 'thread':
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/array/mixins/test_pushpull.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
class PushMockResponse:
def __init__(self, status_code: int = 200):
self.status_code = status_code
self.headers = {'Content-length': 1}

def json(self):
return {'code': self.status_code}
Expand All @@ -16,6 +17,7 @@ def json(self):
class PullMockResponse:
def __init__(self, status_code: int = 200):
self.status_code = status_code
self.headers = {'Content-length': 1}

def json(self):
return {
Expand All @@ -27,6 +29,7 @@ def json(self):
class DownloadMockResponse:
def __init__(self, status_code: int = 200):
self.status_code = status_code
self.headers = {'Content-length': 1}

def raise_for_status(self):
pass
Expand Down