diff --git a/docarray/array/mixins/plot.py b/docarray/array/mixins/plot.py index 94a0b21dafa..f966e7f2dc5 100644 --- a/docarray/array/mixins/plot.py +++ b/docarray/array/mixins/plot.py @@ -33,6 +33,9 @@ def summary(self): from rich.console import Console from rich import box + tables = [] + console = Console() + all_attrs = self._get_attributes('non_empty_fields') attr_counter = Counter(all_attrs) @@ -69,38 +72,43 @@ def summary(self): _text = f'{_doc_text} attributes' table.add_row(_text, str(_a)) - console = Console() + tables.append(table) + all_attrs_names = tuple(sorted(all_attrs_names)) - if not all_attrs_names: - console.print(table) - return - - attr_table = Table(box=box.SIMPLE, title='Attributes Summary') - attr_table.add_column('Attribute') - attr_table.add_column('Data type') - attr_table.add_column('#Unique values') - attr_table.add_column('Has empty value') - - all_attrs_values = self._get_attributes(*all_attrs_names) - if len(all_attrs_names) == 1: - all_attrs_values = [all_attrs_values] - for _a, _a_name in zip(all_attrs_values, all_attrs_names): - try: - _a = set(_a) - except: - pass # intentional ignore as some fields are not hashable - _set_type_a = set(type(_aa).__name__ for _aa in _a) - attr_table.add_row( - _a_name, - str(tuple(_set_type_a)), - str(len(_a)), - str(any(_aa is None for _aa in _a)), - ) + if all_attrs_names: + + attr_table = Table(box=box.SIMPLE, title='Attributes Summary') + attr_table.add_column('Attribute') + attr_table.add_column('Data type') + attr_table.add_column('#Unique values') + attr_table.add_column('Has empty value') + + all_attrs_values = self._get_attributes(*all_attrs_names) + if len(all_attrs_names) == 1: + all_attrs_values = [all_attrs_values] + for _a, _a_name in zip(all_attrs_values, all_attrs_names): + try: + _a = set(_a) + except: + pass # intentional ignore as some fields are not hashable + _set_type_a = set(type(_aa).__name__ for _aa in _a) + attr_table.add_row( + _a_name, + str(tuple(_set_type_a)), + str(len(_a)), + str(any(_aa is None for _aa in _a)), + ) + tables.append(attr_table) storage_table = Table(box=box.SIMPLE, title='Storage Summary') - self._fill_storage_table(storage_table) + storage_table.show_header = False + storage_infos = self._get_storage_infos() + for k, v in storage_infos.items(): + storage_table.add_row(k, v) + + tables.append(storage_table) - console.print(table, attr_table, storage_table) + console.print(*tables) def plot_embeddings( self, diff --git a/docarray/array/storage/base/backend.py b/docarray/array/storage/base/backend.py index 4ff70d67fb4..29f3d20caff 100644 --- a/docarray/array/storage/base/backend.py +++ b/docarray/array/storage/base/backend.py @@ -1,8 +1,5 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from rich.table import Table +from typing import Dict class BaseBackendMixin(ABC): @@ -10,6 +7,5 @@ class BaseBackendMixin(ABC): def _init_storage(self, *args, **kwargs): ... - def _fill_storage_table(self, table: 'Table'): - table.show_header = False - table.add_row('Class', self.__class__.__name__) + def _get_storage_infos(self) -> Dict: + return {'Class': self.__class__.__name__} diff --git a/docarray/array/storage/memory/backend.py b/docarray/array/storage/memory/backend.py index 42c2cc7005f..6e921ae43a0 100644 --- a/docarray/array/storage/memory/backend.py +++ b/docarray/array/storage/memory/backend.py @@ -17,7 +17,6 @@ from ....types import ( DocumentArraySourceType, ) - from rich.table import Table def needs_id2offset_rebuild(func) -> Callable: @@ -88,6 +87,7 @@ def _init_storage( else: self.append(_docs) - def _fill_storage_table(self, table: 'Table'): - super()._fill_storage_table(table) - table.add_row('Backend', 'In Memory') + def _get_storage_infos(self) -> Dict: + storage_infos = super()._get_storage_infos() + storage_infos['Backend'] = 'In Memory' + return storage_infos diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index b97b8719bee..4332467c964 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -21,7 +21,6 @@ from ....types import ( DocumentArraySourceType, ) - from rich.table import Table def _sanitize_table_name(table_name: str) -> str: @@ -138,11 +137,12 @@ def __setstate__(self, state): **_conn_kwargs, ) - def _fill_storage_table(self, table: 'Table'): - super()._fill_storage_table(table) - table.add_row('Backend', 'SQLite (https://www.sqlite.org)') - table.add_row('Connection', self._config.connection) - table.add_row('Table Name', self._table_name) - table.add_row( - 'Serialization Protocol', self._config.serialize_config.get('protocol') - ) + def _get_storage_infos(self) -> Dict: + storage_infos = super()._get_storage_infos() + return { + 'Backend': 'SQLite (https://www.sqlite.org)', + 'Connection': self._config.connection, + 'Table Name': self._table_name, + 'Serialization Protocol': self._config.serialize_config.get('protocol'), + **storage_infos, + } diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index 4195f062941..3618bbc2cf1 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -23,7 +23,6 @@ from ....types import ( DocumentArraySourceType, ) - from rich.table import Table @dataclass @@ -306,11 +305,12 @@ def wmap(self, doc_id: str): # daw2[0, 'text'] == 'hi' # this will be False if we don't append class name return str(uuid.uuid5(uuid.NAMESPACE_URL, doc_id + self._class_name)) - def _fill_storage_table(self, table: 'Table'): - super()._fill_storage_table(table) - table.add_row('Backend', 'Weaviate (www.semi.technology/developers/weaviate)') - table.add_row('Hostname', self._config.client) - table.add_row('Schema Name', self._config.name) - table.add_row( - 'Serialization Protocol', self._config.serialize_config.get('protocol') - ) + def _get_storage_infos(self) -> Dict: + storage_infos = super()._get_storage_infos() + return { + 'Backend': 'Weaviate (www.semi.technology/developers/weaviate)', + 'Hostname': self._config.client, + 'Schema Name': self._config.name, + 'Serialization Protocol': self._config.serialize_config.get('protocol'), + **storage_infos, + }