Skip to content
11 changes: 9 additions & 2 deletions docarray/array/mixins/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def embed_and_evaluate(
normalization: Optional[Tuple[float, float]] = None,
exclude_self: bool = False,
use_scipy: bool = False,
num_worker: int = 1,
match_batch_size: int = 100_000,
query_sample_size: int = 1_000,
**kwargs,
Expand Down Expand Up @@ -263,8 +264,8 @@ def embed_and_evaluate(
Paddle for embedding `self` and `index_data`.
:param embed_funcs: As an alternative to embedding models, custom embedding
functions can be provided.
:param device: the computational device for `embed_models`, can be either
`cpu` or `cuda`.
:param device: the computational device used for `embed_models` and for
    matching; can be either `cpu` or `cuda`.
:param batch_size: Number of documents in a batch for embedding.
:param collate_fns: For each embedding function the respective collate
function creates a mini-batch of input(s) from the given `DocumentArray`.
Expand All @@ -279,6 +280,8 @@ def embed_and_evaluate(
as the left-hand values will not be considered as matches.
:param use_scipy: if set, use ``scipy`` as the computation backend. Note,
``scipy`` does not support distance on sparse matrix.
:param num_worker: Specifies the number of workers for the execution of the
match function.
:param match_batch_size: The number of documents which are embedded and
    matched at once. Lower this value, if you experience high memory
    consumption.
Expand Down Expand Up @@ -413,6 +416,9 @@ def fuse_matches(global_matches: DocumentArray, local_matches: DocumentArray):
normalization=normalization,
exclude_self=exclude_self,
use_scipy=use_scipy,
num_worker=num_worker,
device=device,
batch_size=int(len(batch) / num_worker) if num_worker > 1 else None,
only_id=True,
)

Expand All @@ -437,6 +443,7 @@ def fuse_matches(global_matches: DocumentArray, local_matches: DocumentArray):
metrics=metrics,
metric_names=metric_names,
strict=strict,
label_tag=label_tag,
**kwargs,
)

Expand Down
23 changes: 15 additions & 8 deletions tests/unit/array/mixins/oldproto/test_eval_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def test_eval_mixin_zero_labeled(storage, config, metric_fn, start_storage, kwar
assert d.evaluations[metric_fn].value == 0.0


@pytest.mark.parametrize('label_tag', ['label', 'custom_tag'])
@pytest.mark.parametrize(
'metric_fn, metric_score',
[
Expand All @@ -184,11 +185,11 @@ def test_eval_mixin_zero_labeled(storage, config, metric_fn, start_storage, kwar
('dcg_at_k', (1.0 + 1.0 + 0.6309) / 3),
],
)
def test_eval_mixin_one_of_n_labeled(metric_fn, metric_score):
da = DocumentArray([Document(text=str(i), tags={'label': i}) for i in range(3)])
def test_eval_mixin_one_of_n_labeled(metric_fn, metric_score, label_tag):
da = DocumentArray([Document(text=str(i), tags={label_tag: i}) for i in range(3)])
for d in da:
d.matches = da
r = da.evaluate([metric_fn])[metric_fn]
r = da.evaluate([metric_fn], label_tag=label_tag)[metric_fn]
assert abs(r - metric_score) < 0.001


Expand Down Expand Up @@ -522,10 +523,14 @@ def test_embed_and_evaluate_two_das(storage, config, sample_size, start_storage)


@pytest.mark.parametrize(
'use_index, expected',
'use_index, expected, label_tag',
[
(False, {'precision_at_k': 1.0 / 3, 'reciprocal_rank': 1.0}),
(True, {'precision_at_k': 1.0 / 3, 'reciprocal_rank': 11.0 / 18.0}),
(False, {'precision_at_k': 1.0 / 3, 'reciprocal_rank': 1.0}, 'label'),
(
True,
{'precision_at_k': 1.0 / 3, 'reciprocal_rank': 11.0 / 18.0},
'custom_tag',
),
],
)
@pytest.mark.parametrize(
Expand All @@ -541,15 +546,15 @@ def test_embed_and_evaluate_two_das(storage, config, sample_size, start_storage)
],
)
def test_embed_and_evaluate_labeled_dataset(
storage, config, start_storage, use_index, expected
storage, config, start_storage, use_index, expected, label_tag
):
metric_fns = list(expected.keys())

def emb_func(da):
np.random.seed(0) # makes sure that embeddings are always equal
da[:, 'embedding'] = np.random.random((len(da), 5))

da1 = DocumentArray([Document(text=str(i), tags={'label': i}) for i in range(3)])
da1 = DocumentArray([Document(text=str(i), tags={label_tag: i}) for i in range(3)])
da2 = DocumentArray(da1, storage=storage, config=config, copy=True)

if (
Expand All @@ -561,13 +566,15 @@ def emb_func(da):
embed_funcs=emb_func,
match_batch_size=1,
limit=3,
label_tag=label_tag,
)
else: # query and index are the same (embeddings of both das are equal)
res = da2.embed_and_evaluate(
metrics=metric_fns,
embed_funcs=emb_func,
match_batch_size=1,
limit=3,
label_tag=label_tag,
)
for key in metric_fns:
assert key in res
Expand Down