diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index 4c21b88109d..db9c01c6dad 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -231,6 +231,7 @@ def embed_and_evaluate( normalization: Optional[Tuple[float, float]] = None, exclude_self: bool = False, use_scipy: bool = False, + num_worker: int = 1, match_batch_size: int = 100_000, query_sample_size: int = 1_000, **kwargs, @@ -263,8 +264,8 @@ def embed_and_evaluate( Paddle for embedding `self` and `index_data`. :param embed_funcs: As an alternative to embedding models, custom embedding functions can be provided. - :param device: the computational device for `embed_models`, can be either - `cpu` or `cuda`. + :param device: the computational device for `embed_models` and the matching, + can be either `cpu` or `cuda`. :param batch_size: Number of documents in a batch for embedding. :param collate_fns: For each embedding function the respective collate function creates a mini-batch of input(s) from the given `DocumentArray`. @@ -279,6 +280,8 @@ def embed_and_evaluate( as the left-hand values will not be considered as matches. :param use_scipy: if set, use ``scipy`` as the computation backend. Note, ``scipy`` does not support distance on sparse matrix. + :param num_worker: Specifies the number of workers for the execution of the + match function. :parma match_batch_size: The number of documents which are embedded and matched at once. Set this value to a lower value, if you experience high memory consumption. 
@@ -413,6 +416,9 @@ def fuse_matches(global_matches: DocumentArray, local_matches: DocumentArray): normalization=normalization, exclude_self=exclude_self, use_scipy=use_scipy, + num_worker=num_worker, + device=device, + batch_size=int(len(batch) / num_worker) if num_worker > 1 else None, only_id=True, ) @@ -437,6 +443,7 @@ def fuse_matches(global_matches: DocumentArray, local_matches: DocumentArray): metrics=metrics, metric_names=metric_names, strict=strict, + label_tag=label_tag, **kwargs, ) diff --git a/tests/unit/array/mixins/oldproto/test_eval_class.py b/tests/unit/array/mixins/oldproto/test_eval_class.py index a533d1bb228..3eb0b79a3a2 100644 --- a/tests/unit/array/mixins/oldproto/test_eval_class.py +++ b/tests/unit/array/mixins/oldproto/test_eval_class.py @@ -171,6 +171,7 @@ def test_eval_mixin_zero_labeled(storage, config, metric_fn, start_storage, kwar assert d.evaluations[metric_fn].value == 0.0 +@pytest.mark.parametrize('label_tag', ['label', 'custom_tag']) @pytest.mark.parametrize( 'metric_fn, metric_score', [ @@ -184,11 +185,11 @@ def test_eval_mixin_zero_labeled(storage, config, metric_fn, start_storage, kwar ('dcg_at_k', (1.0 + 1.0 + 0.6309) / 3), ], ) -def test_eval_mixin_one_of_n_labeled(metric_fn, metric_score): - da = DocumentArray([Document(text=str(i), tags={'label': i}) for i in range(3)]) +def test_eval_mixin_one_of_n_labeled(metric_fn, metric_score, label_tag): + da = DocumentArray([Document(text=str(i), tags={label_tag: i}) for i in range(3)]) for d in da: d.matches = da - r = da.evaluate([metric_fn])[metric_fn] + r = da.evaluate([metric_fn], label_tag=label_tag)[metric_fn] assert abs(r - metric_score) < 0.001 @@ -522,10 +523,14 @@ def test_embed_and_evaluate_two_das(storage, config, sample_size, start_storage) @pytest.mark.parametrize( - 'use_index, expected', + 'use_index, expected, label_tag', [ - (False, {'precision_at_k': 1.0 / 3, 'reciprocal_rank': 1.0}), - (True, {'precision_at_k': 1.0 / 3, 'reciprocal_rank': 11.0 / 18.0}), + (False, 
{'precision_at_k': 1.0 / 3, 'reciprocal_rank': 1.0}, 'label'), + ( + True, + {'precision_at_k': 1.0 / 3, 'reciprocal_rank': 11.0 / 18.0}, + 'custom_tag', + ), ], ) @pytest.mark.parametrize( @@ -541,7 +546,7 @@ def test_embed_and_evaluate_two_das(storage, config, sample_size, start_storage) ], ) def test_embed_and_evaluate_labeled_dataset( - storage, config, start_storage, use_index, expected + storage, config, start_storage, use_index, expected, label_tag ): metric_fns = list(expected.keys()) @@ -549,7 +554,7 @@ def emb_func(da): np.random.seed(0) # makes sure that embeddings are always equal da[:, 'embedding'] = np.random.random((len(da), 5)) - da1 = DocumentArray([Document(text=str(i), tags={'label': i}) for i in range(3)]) + da1 = DocumentArray([Document(text=str(i), tags={label_tag: i}) for i in range(3)]) da2 = DocumentArray(da1, storage=storage, config=config, copy=True) if ( @@ -561,6 +566,7 @@ def emb_func(da): embed_funcs=emb_func, match_batch_size=1, limit=3, + label_tag=label_tag, ) else: # query and index are the same (embeddings of both das are equal) res = da2.embed_and_evaluate( @@ -568,6 +574,7 @@ def emb_func(da): embed_funcs=emb_func, match_batch_size=1, limit=3, + label_tag=label_tag, ) for key in metric_fns: assert key in res