diff --git a/docarray/array/document.py b/docarray/array/document.py index 58b1abf2de4..be70a655bb9 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -151,7 +151,7 @@ def __getitem__( if isinstance(_attrs, str): _attrs = (index[1],) - return _docs.get_attributes(*_attrs) + return _docs._get_attributes(*_attrs) elif isinstance(index[0], bool): return DocumentArray(itertools.compress(self._data, index)) elif isinstance(index[0], int): diff --git a/docarray/array/mixins/getattr.py b/docarray/array/mixins/getattr.py index 5a23549bf6f..598e20ab6a0 100644 --- a/docarray/array/mixins/getattr.py +++ b/docarray/array/mixins/getattr.py @@ -4,7 +4,7 @@ class GetAttributeMixin: """Helpers that provide attributes getter in bulk """ - def get_attributes(self, *fields: str) -> List: + def _get_attributes(self, *fields: str) -> List: """Return all nonempty values of the fields from all docs this array contains :param fields: Variable length argument with the name of the fields to extract @@ -25,7 +25,7 @@ def get_attributes(self, *fields: str) -> List: fields.remove('blob') if fields: - contents = [doc.get_attributes(*fields) for doc in self] + contents = [doc._get_attributes(*fields) for doc in self] if len(fields) > 1: contents = list(map(list, zip(*contents))) if b_index is None and e_index is None: diff --git a/docarray/array/mixins/plot.py b/docarray/array/mixins/plot.py index bde1bde1459..93425b96167 100644 --- a/docarray/array/mixins/plot.py +++ b/docarray/array/mixins/plot.py @@ -26,7 +26,7 @@ def summary(self): from rich.console import Console from rich import box - all_attrs = self.get_attributes('non_empty_fields') + all_attrs = self._get_attributes('non_empty_fields') attr_counter = Counter(all_attrs) table = Table(box=box.SIMPLE, title='Documents Summary') @@ -74,7 +74,7 @@ def summary(self): attr_table.add_column('#Unique values') attr_table.add_column('Has empty value') - all_attrs_values = self.get_attributes(*all_attrs_names) + all_attrs_values = self._get_attributes(*all_attrs_names) if len(all_attrs_names) == 1: all_attrs_values = [all_attrs_values] for _a, _a_name in zip(all_attrs_values, all_attrs_names): diff --git a/docarray/document/mixins/attribute.py b/docarray/document/mixins/attribute.py index ce265c797c2..ec9a703c79d 100644 --- a/docarray/document/mixins/attribute.py +++ b/docarray/document/mixins/attribute.py @@ -7,7 +7,7 @@ class GetAttributesMixin: """Provide helper functions for :class:`Document` to allow advanced set and get attributes """ - def get_attributes(self, *fields: str) -> Union[Any, List[Any]]: + def _get_attributes(self, *fields: str) -> Union[Any, List[Any]]: """Bulk fetch Document fields and return a list of the values of these fields :param fields: the variable length values to extract from the document diff --git a/tests/unit/array/mixins/test_getset.py b/tests/unit/array/mixins/test_getset.py index ab48a661b35..1a747f041c5 100644 --- a/tests/unit/array/mixins/test_getset.py +++ b/tests/unit/array/mixins/test_getset.py @@ -34,7 +34,8 @@ def test_set_embeddings_multi_kind(array): @pytest.mark.parametrize('da', da_and_dam()) def test_da_get_embeddings(da): - np.testing.assert_almost_equal(da.get_attributes('embedding'), da.embeddings) + np.testing.assert_almost_equal(da._get_attributes('embedding'), da.embeddings) + np.testing.assert_almost_equal(da[:, 'embedding'], da.embeddings) @pytest.mark.parametrize('da', da_and_dam()) @@ -65,7 +66,6 @@ def test_blobs_getter_da(da): blobs = np.random.random((100, 10, 10)) da.blobs = blobs assert len(da) == 100 - np.testing.assert_almost_equal(da.get_attributes('blob'), da.blobs) np.testing.assert_almost_equal(da.blobs, blobs) da.blobs = None @@ -77,7 +77,7 @@ def test_blobs_getter_da(da): @pytest.mark.parametrize('da', da_and_dam()) def test_texts_getter_da(da): assert len(da.texts) == 100 - assert da.texts == da.get_attributes('text') + assert da.texts == da[:, 'text'] texts = ['text' for _ in range(100)] da.texts = texts assert da.texts == texts diff --git a/tests/unit/array/mixins/test_group.py b/tests/unit/array/mixins/test_group.py index 77e1650ef47..90316f3dfe4 100644 --- a/tests/unit/array/mixins/test_group.py +++ b/tests/unit/array/mixins/test_group.py @@ -82,9 +82,9 @@ def test_batching(da, batch_size, shuffle): all_ids = [] for v in da.batch(batch_size=batch_size, shuffle=shuffle): assert len(v) <= batch_size - all_ids.extend(v.get_attributes('id')) + all_ids.extend(v[:, 'id']) if shuffle: - assert all_ids != da.get_attributes('id') + assert all_ids != da[:, 'id'] else: - assert all_ids == da.get_attributes('id') + assert all_ids == da[:, 'id'] diff --git a/tests/unit/array/mixins/test_match.py b/tests/unit/array/mixins/test_match.py index 7a708954e20..43592259593 100644 --- a/tests/unit/array/mixins/test_match.py +++ b/tests/unit/array/mixins/test_match.py @@ -85,7 +85,7 @@ def test_matching_retrieves_correct_number( D1.match( D2, metric='sqeuclidean', limit=limit, batch_size=batch_size, only_id=only_id ) - for m in D1.get_attributes('matches'): + for m in D1[:, 'matches']: if limit is None: assert len(m) == len(D2) else: @@ -106,14 +106,14 @@ def test_matching_same_results_with_sparse( # use match with numpy arrays D1.match(D2, metric=metric, only_id=only_id) distances = [] - for m in D1.get_attributes('matches'): + for m in D1[:, 'matches']: for d in m: distances.extend([d.scores[metric].value]) # use match with sparse arrays D1_sp.match(D2_sp, metric=metric, is_sparse=True) distances_sparse = [] - for m in D1.get_attributes('matches'): + for m in D1[:, 'matches']: for d in m: distances_sparse.extend([d.scores[metric].value]) @@ -132,7 +132,7 @@ def test_matching_same_results_with_batch( # use match without batches D1.match(D2, metric=metric, only_id=only_id) distances = [] - for m in D1.get_attributes('matches'): + for m in D1[:, 'matches']: for d in m: distances.extend([d.scores[metric].value]) @@ -140,7 +140,7 @@ def test_matching_same_results_with_batch( D1_batch.match(D2_batch, metric=metric, batch_size=10) distances_batch = [] - for m in D1.get_attributes('matches'): + for m in D1[:, 'matches']: for d in m: distances_batch.extend([d.scores[metric].value]) @@ -161,14 +161,14 @@ def scipy_cdist_metric(X, Y, *args): # match with our custom metric D1.match(D2, metric=metric) distances = [] - for m in D1.get_attributes('matches'): + for m in D1[:, 'matches']: for d in m: distances.extend([d.scores[metric].value]) # match with callable cdist function from scipy D1_scipy.match(D2, metric=scipy_cdist_metric, only_id=only_id) distances_scipy = [] - for m in D1.get_attributes('matches'): + for m in D1[:, 'matches']: for d in m: distances_scipy.extend([d.scores[metric].value]) diff --git a/tests/unit/array/mixins/test_plot.py b/tests/unit/array/mixins/test_plot.py index bd99c6c0289..1ec701244b8 100644 --- a/tests/unit/array/mixins/test_plot.py +++ b/tests/unit/array/mixins/test_plot.py @@ -59,7 +59,7 @@ def test_plot_embeddings_same_path(tmpdir): def test_summary_homo_hetero(): da = DocumentArray.empty(100) - da.get_attributes() + da._get_attributes() da.summary() da[0].pop('id') @@ -69,4 +69,4 @@ def test_summary_homo_hetero(): def test_empty_get_attributes(): da = DocumentArray.empty(10) da[0].pop('id') - print(da.get_attributes('id')) + print(da[:, 'id']) diff --git a/tests/unit/array/mixins/test_traverse.py b/tests/unit/array/mixins/test_traverse.py index 9094e77d3ac..627b7de1db0 100644 --- a/tests/unit/array/mixins/test_traverse.py +++ b/tests/unit/array/mixins/test_traverse.py @@ -88,7 +88,7 @@ def test_traverse_root_match_chunk(doc_req, filter_fn): @pytest.mark.parametrize('filter_fn', [(lambda d: True), None]) def test_traverse_flatten_embedding(doc_req, filter_fn): flattened_results = doc_req.traverse_flat('r,c', filter_fn=filter_fn) - ds = np.stack(flattened_results.get_attributes('embedding')) + ds = flattened_results.embeddings assert ds.shape == (num_docs + num_chunks_per_doc * num_docs, 10) @@ -137,10 +137,10 @@ def test_traverse_flatten_root_match_chunk(doc_req, filter_fn): @pytest.mark.parametrize('filter_fn', [(lambda d: True), None]) def test_traverse_flattened_per_path_embedding(doc_req, filter_fn): flattened_results = list(doc_req.traverse_flat_per_path('r,c', filter_fn=filter_fn)) - ds = np.stack(flattened_results[0].get_attributes('embedding')) + ds = flattened_results[0].embeddings assert ds.shape == (num_docs, 10) - ds = np.stack(flattened_results[1].get_attributes('embedding')) + ds = flattened_results[1].embeddings assert ds.shape == (num_docs * num_chunks_per_doc, 10) diff --git a/tests/unit/document/test_docdata.py b/tests/unit/document/test_docdata.py index 8750869f71c..ff39b24bb68 100644 --- a/tests/unit/document/test_docdata.py +++ b/tests/unit/document/test_docdata.py @@ -164,7 +164,7 @@ def test_get_attr_values(): 'tags__id', 'tags__e__2__f', ] - res = d.get_attributes(*required_keys) + res = d._get_attributes(*required_keys) assert len(res) == len(required_keys) assert res[required_keys.index('id')] == '123' assert res[required_keys.index('tags__feature1')] == 121 @@ -176,18 +176,18 @@ def test_get_attr_values(): assert res[required_keys.index('tags__e__2__f')] == 'g' required_keys_2 = ['tags', 'text'] - res2 = d.get_attributes(*required_keys_2) + res2 = d._get_attributes(*required_keys_2) assert len(res2) == 2 assert res2[required_keys_2.index('text')] == 'document' assert res2[required_keys_2.index('tags')] == d.tags d = Document({'id': '123', 'tags': {'outterkey': {'innerkey': 'real_value'}}}) required_keys_3 = ['tags__outterkey__innerkey'] - res3 = d.get_attributes(*required_keys_3) + res3 = d._get_attributes(*required_keys_3) assert res3 == 'real_value' d = Document(content=np.array([1, 2, 3])) - res4 = np.stack(d.get_attributes(*['blob'])) + res4 = np.stack(d._get_attributes(*['blob'])) np.testing.assert_equal(res4, np.array([1, 2, 3]))