Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 102 additions & 4 deletions docarray/array/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
Sequence,
Iterable,
overload,
Any,
List,
)

import numpy as np
Expand All @@ -24,6 +26,8 @@
DocumentArrayIndexType,
DocumentArraySingletonIndexType,
DocumentArrayMultipleIndexType,
DocumentArrayMultipleAttributeType,
DocumentArraySingleAttributeType,
)


Expand Down Expand Up @@ -109,7 +113,17 @@ def __getitem__(self, index: 'DocumentArraySingletonIndexType') -> 'Document':
...

@overload
def __getitem__(self, index: 'DocumentArrayMultipleIndexType') -> 'Document':
def __getitem__(self, index: 'DocumentArrayMultipleIndexType') -> 'DocumentArray':
...

@overload
def __getitem__(self, index: 'DocumentArraySingleAttributeType') -> List[Any]:
...

@overload
def __getitem__(
self, index: 'DocumentArrayMultipleAttributeType'
) -> List[List[Any]]:
...

def __getitem__(
Expand All @@ -127,7 +141,18 @@ def __getitem__(
elif index is Ellipsis:
return self.flatten()
elif isinstance(index, Sequence):
if isinstance(index[0], bool):
if (
isinstance(index, tuple)
and len(index) == 2
and isinstance(index[0], (slice, Sequence))
):
_docs = self[index[0]]
_attrs = index[1]
if isinstance(_attrs, str):
_attrs = (index[1],)

return _docs.get_attributes(*_attrs)
elif isinstance(index[0], bool):
return DocumentArray(itertools.compress(self._data, index))
elif isinstance(index[0], int):
return DocumentArray(self._data[t] for t in index)
Expand All @@ -143,6 +168,38 @@ def __getitem__(
)
raise IndexError(f'Unsupported index type {typename(index)}: {index}')

@overload
def __setitem__(
self,
index: 'DocumentArrayMultipleAttributeType',
value: List[List['Any']],
):
...

@overload
def __setitem__(
self,
index: 'DocumentArraySingleAttributeType',
value: List['Any'],
):
...

@overload
def __setitem__(
self,
index: 'DocumentArraySingletonIndexType',
value: 'Document',
):
...

@overload
def __setitem__(
self,
index: 'DocumentArrayMultipleIndexType',
value: Sequence['Document'],
):
...

def __setitem__(
self,
index: 'DocumentArrayIndexType',
Expand All @@ -169,7 +226,37 @@ def __setitem__(
_d._data = _v._data
self._rebuild_id2offset()
elif isinstance(index, Sequence):
if isinstance(index[0], bool):
if (
isinstance(index, tuple)
and len(index) == 2
and isinstance(index[0], (slice, Sequence))
):
_docs = self[index[0]]
_attrs = index[1]

if isinstance(_attrs, str):
# a -> [a]
# [a, a] -> [a, a]
_attrs = (index[1],)
if isinstance(value, (list, tuple)) and not any(
isinstance(el, (tuple, list)) for el in value
):
# [x] -> [[x]]
# [[x], [y]] -> [[x], [y]]
value = (value,)
if not isinstance(value, (list, tuple)):
# x -> [x]
value = (value,)

for _a, _v in zip(_attrs, value):
if _a == 'blob':
_docs.blobs = _v
elif _a == 'embedding':
_docs.embeddings = _v
else:
for _d, _vv in zip(_docs, _v):
setattr(_d, _a, _vv)
elif isinstance(index[0], bool):
if len(index) != len(self._data):
raise IndexError(
f'Boolean mask index is required to have the same length as {len(self._data)}, '
Expand Down Expand Up @@ -221,7 +308,18 @@ def __delitem__(self, index: 'DocumentArrayIndexType'):
self._data.clear()
self._id2offset.clear()
elif isinstance(index, Sequence):
if isinstance(index[0], bool):
if (
isinstance(index, tuple)
and len(index) == 2
and isinstance(index[0], (slice, Sequence))
):
_docs = self[index[0]]
_attrs = index[1]
if isinstance(_attrs, str):
_attrs = (index[1],)
for _d in _docs:
_d.pop(*_attrs)
elif isinstance(index[0], bool):
self._data = list(
itertools.compress(self._data, (not _i for _i in index))
)
Expand Down
73 changes: 38 additions & 35 deletions docarray/array/mixins/getattr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from typing import Union, List, Tuple, TYPE_CHECKING

if TYPE_CHECKING:
from ... import DocumentArray
from typing import List


class GetAttributeMixin:
Expand All @@ -14,34 +11,40 @@ def get_attributes(self, *fields: str) -> List:
:return: Returns a list of the values for these fields.
When `fields` has multiple values, then it returns a list of list.
"""
contents = [doc.get_attributes(*fields) for doc in self]

if len(fields) > 1:
contents = list(map(list, zip(*contents)))

return contents

def get_attributes_with_docs(
self,
*fields: str,
) -> Tuple[List, 'DocumentArray']:
"""Return all nonempty values of the fields together with their nonempty docs

:param fields: Variable length argument with the name of the fields to extract
:return: Returns a tuple. The first element is a list of the values for these fields.
When `fields` has multiple values, then it returns a list of list. The second element is the non-empty docs.
"""

contents = []
docs_pts = []

for doc in self:
contents.append(doc.get_attributes(*fields))
docs_pts.append(doc)

if len(fields) > 1:
contents = list(map(list, zip(*contents)))

from ... import DocumentArray

return contents, DocumentArray(docs_pts)
e_index, b_index = None, None
fields = list(fields)
if 'embedding' in fields:
e_index = fields.index('embedding')
if 'blob' in fields:
b_index = fields.index('blob')
fields.remove('blob')

if 'embedding' in fields:
fields.remove('embedding')
if 'blob' in fields:
fields.remove('blob')

if fields:
contents = [doc.get_attributes(*fields) for doc in self]
if len(fields) > 1:
contents = list(map(list, zip(*contents)))
if b_index is None and e_index is None:
return contents

contents = [contents]
if b_index is not None:
contents.insert(b_index, self.blobs)
if e_index is not None:
contents.insert(e_index, self.embeddings)
return contents

if b_index is not None and e_index is None:
return self.blobs
if b_index is None and e_index is not None:
return self.embeddings
if b_index is not None and e_index is not None:
return (
[self.embeddings, self.blobs]
if b_index > e_index
else [self.blobs, self.embeddings]
)
28 changes: 22 additions & 6 deletions docarray/array/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ def summary(self):
is_homo = len(attr_counter) == 1
table.add_row('Homogenous Documents', str(is_homo))

all_attrs_names = set(v for k in all_attrs for v in k)
_nested_in = []
if 'chunks' in all_attrs_names:
_nested_in.append('chunks')

if 'matches' in all_attrs_names:
_nested_in.append('matches')

if _nested_in:
table.add_row('Has nested Documents in', str(tuple(_nested_in)))

if is_homo:
table.add_row('Common Attributes', str(list(attr_counter.items())[0][0]))
else:
Expand All @@ -44,25 +55,30 @@ def summary(self):
_text = f'{_doc_text} attributes'
table.add_row(_text, str(_a))

console = Console()
all_attrs_names = tuple(sorted(all_attrs_names))
if not all_attrs_names:
console.print(table)
return

attr_table = Table(box=box.SIMPLE, title='Attributes Summary')
attr_table.add_column('Attribute')
attr_table.add_column('Data type')
attr_table.add_column('#Unique values')
attr_table.add_column('Has empty value')

all_attrs_names = tuple(sorted(set(v for k in all_attrs for v in k)))
all_attrs_values = self.get_attributes(*all_attrs_names)
if len(all_attrs_names) == 1:
all_attrs_values = [all_attrs_values]
for _a, _a_name in zip(all_attrs_values, all_attrs_names):
_counter_a = Counter(_a)
_set_a = set(_a)
try:
_a = set(_a)
except:
pass
_set_type_a = set(type(_aa).__name__ for _aa in _a)
attr_table.add_row(
_a_name, str(tuple(_set_type_a)), str(len(_set_a)), str(None in _set_a)
_a_name, str(tuple(_set_type_a)), str(len(_a)), str(None in _a)
)

console = Console()
console.print(table, attr_table)

def plot_embeddings(
Expand Down
47 changes: 37 additions & 10 deletions docarray/array/mixins/traverse.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import itertools
import re
from typing import (
Iterable,
TYPE_CHECKING,
Optional,
Callable,
Union,
Tuple,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -50,24 +51,26 @@ def _traverse(
path: str,
filter_fn: Optional[Callable[['Document'], bool]] = None,
):
path = path.strip()
path = re.sub(r'\s+', '', path)
if path:
loc = path[0]
if loc == 'r':
yield from TraverseMixin._traverse(docs, path[1:], filter_fn=filter_fn)
elif loc == 'm':
cur_loc, cur_slice, _left = _parse_path_string(path)
if cur_loc == 'r':
yield from TraverseMixin._traverse(
docs[cur_slice], _left, filter_fn=filter_fn
)
elif cur_loc == 'm':
for d in docs:
yield from TraverseMixin._traverse(
d.matches, path[1:], filter_fn=filter_fn
d.matches[cur_slice], _left, filter_fn=filter_fn
)
elif loc == 'c':
elif cur_loc == 'c':
for d in docs:
yield from TraverseMixin._traverse(
d.chunks, path[1:], filter_fn=filter_fn
d.chunks[cur_slice], _left, filter_fn=filter_fn
)
else:
raise ValueError(
f'`path`:{loc} is invalid, must be one of `c`, `r`, `m`'
f'`path`:{path} is invalid, please refer to https://docarray.jina.ai/fundamentals/documentarray/access-elements/#index-by-nested-structure'
)
elif filter_fn is None:
yield docs
Expand Down Expand Up @@ -148,3 +151,27 @@ def _flatten(sequence) -> 'DocumentArray':
from ... import DocumentArray

return DocumentArray(list(itertools.chain.from_iterable(sequence)))


def _parse_path_string(p: str) -> Tuple[str, slice, str]:
g = re.match(r'^([rcm])([-\d:]+)?([rcm].*)?$', p)
_this = g.group(1)
slice_str = g.group(2)
_next = g.group(3)
return _this, _parse_slice(slice_str or ':'), _next or ''


def _parse_slice(value):
"""
Parses a `slice()` from string, like `start:stop:step`.
"""
if value:
parts = value.split(':')
if len(parts) == 1:
# slice(stop)
parts = [None, parts[0]]
# else: slice(start, stop[, step])
else:
# slice()
parts = []
return slice(*[int(p) if p else None for p in parts])
Loading