diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index db9c01c6dad..2de9433fb33 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -1,10 +1,10 @@ import warnings -from typing import Optional, Union, TYPE_CHECKING, Callable, List, Dict, Tuple +from typing import Optional, Union, TYPE_CHECKING, Callable, List, Dict, Tuple, Any from functools import wraps import numpy as np -from collections import defaultdict +from collections import defaultdict, Counter from docarray.score import NamedScore @@ -79,6 +79,7 @@ def evaluate( metric_names: Optional[List[str]] = None, strict: bool = True, label_tag: str = 'label', + num_relevant_documents_per_label: Optional[Dict[Any, int]] = None, **kwargs, ) -> Dict[str, float]: """ @@ -109,6 +110,10 @@ def evaluate( aligned: on the length, and on the semantic of length. These are preventing you to evaluate on irrelevant matches accidentally. :param label_tag: Specifies the tag which contains the labels. + :param num_relevant_documents_per_label: Some metrics, e.g., recall@k, require + the number of relevant documents. To apply those to a labeled dataset, one + can provide a dictionary which maps labels to the total number of documents + with this label. :param kwargs: Additional keyword arguments to be passed to the metric functions. 
:return: A dictionary which stores for each metric name the average evaluation @@ -161,7 +166,22 @@ def evaluate( results = defaultdict(list) caller_max_rel = kwargs.pop('max_rel', None) for d, gd in zip(self, ground_truth): - max_rel = caller_max_rel or len(gd.matches) + if caller_max_rel: + max_rel = caller_max_rel + elif ground_truth_type == 'labels': + if num_relevant_documents_per_label: + max_rel = num_relevant_documents_per_label.get( + d.tags[label_tag], None + ) + if max_rel is None: + raise ValueError( + '`num_relevant_documents_per_label` misses the label ' + + str(d.tags[label_tag]) + ) + else: + max_rel = None + else: + max_rel = len(gd.matches) if strict and hash_fn(d) != hash_fn(gd): raise ValueError( f'Document {d} from the left-hand side and ' @@ -174,7 +194,7 @@ def evaluate( f'Document {d!r} or {gd!r} has no matches, please check your Document' ) - targets = gd.matches[:max_rel] + targets = gd.matches if ground_truth_type == 'matches': desired = {hash_fn(m) for m in targets} @@ -438,12 +458,24 @@ def fuse_matches(global_matches: DocumentArray, local_matches: DocumentArray): new_matches.append(m) query_data[doc.id, 'matches'] = new_matches + if ground_truth and label_tag in ground_truth[0].tags: + num_relevant_documents_per_label = dict( + Counter([d.tags[label_tag] for d in ground_truth]) + ) + elif not ground_truth and label_tag in query_data[0].tags: + num_relevant_documents_per_label = dict( + Counter([d.tags[label_tag] for d in query_data]) + ) + else: + num_relevant_documents_per_label = None + metrics_resp = query_data.evaluate( ground_truth=ground_truth, metrics=metrics, metric_names=metric_names, strict=strict, label_tag=label_tag, + num_relevant_documents_per_label=num_relevant_documents_per_label, **kwargs, ) diff --git a/docarray/math/evaluation.py b/docarray/math/evaluation.py index d05a0642a43..eb2e0ef16fc 100644 --- a/docarray/math/evaluation.py +++ b/docarray/math/evaluation.py @@ -98,6 +98,8 @@ def recall_at_k( """ _check_k(k) 
binary_relevance = np.array(binary_relevance[:k]) != 0 + if max_rel is None: + raise ValueError('The metric recall_at_k requires a max_rel parameter') if np.sum(binary_relevance) > max_rel: raise ValueError(f'Number of relevant Documents retrieved > {max_rel}') return np.sum(binary_relevance) / max_rel diff --git a/docs/fundamentals/documentarray/evaluation.md b/docs/fundamentals/documentarray/evaluation.md index 7ad289ca2ff..a0a45fbf14b 100644 --- a/docs/fundamentals/documentarray/evaluation.md +++ b/docs/fundamentals/documentarray/evaluation.md @@ -69,10 +69,10 @@ da_prediction['@m'].summary() To evaluate the matches against a ground truth array, you simply provide a DocumentArray to the evaluate function like `da_groundtruth` in the call below: ```python -da_predict.evaluate(ground_truth=da_groundtruth, metrics=['...'], **kwargs) +da_prediction.evaluate(ground_truth=da_groundtruth, metrics=['...'], **kwargs) ``` -Thereby, `da_groundtruth` should contain the same documents as in `da_prediction` where each `matches` attribute contains exactly those documents which are relevant to the respective root document. +Thereby, `da_groundtruth` should contain the same Documents as in `da_prediction` where each `matches` attribute contains exactly those Documents which are relevant to the respective root Document. The `metrics` argument determines the metric you want to use for your evaluation, e.g., `precision_at_k`. In the code cell below, we evaluate the array `da_prediction` with the noisy matches against the original one `da_original`: @@ -111,7 +111,8 @@ for d in da_prediction: Note that the evaluation against a ground truth DocumentArray only works if both DocumentArrays have the same length and their nested structure is the same. It makes no sense to evaluate with a completely different DocumentArray. -While evaluating, Document pairs are recognized as correct if they share the same identifier. By default, it simply uses {attr}`~docarray.Document.id`. 
One can customize this behavior by specifying `hash_fn`. +While evaluating, Document pairs are recognized as correct if they share the same identifier. By default, it simply uses {attr}`~docarray.Document.id`. +You can customize this behavior by specifying `hash_fn`. Let's see an example by creating two DocumentArrays with some matches with identical texts. @@ -157,8 +158,8 @@ It is correct as we define the evaluation as checking if the first two character ## Evaluation via labels -Alternatively, you can add labels to your documents to evaluate them. -In this case, a match is considered relevant to its root document if it has the same label: +Alternatively, you can add labels to your Documents to evaluate them. +In this case, a match is considered relevant to its root Document if it has the same label: ```python import numpy as np @@ -198,7 +199,7 @@ Some of those metrics accept additional arguments as `kwargs` which you can simp ```{danger} These metric scores might change if the `limit` argument of the match function is set differently. -**Note:** Not all of these metrics can be applied to a Top-K result, i.e., `ndcg_at_k` and `r_precision` are calculated correctly only if the limit is set equal or higher than the number of documents in the `DocumentArray` provided to the match function. +**Note:** Not all of these metrics can be applied to a Top-K result, i.e., `ndcg_at_k` and `r_precision` are calculated correctly only if the limit is set equal or higher than the number of Documents in the `DocumentArray` provided to the match function. ``` You can evaluate multiple metric functions at once, as you can see below: @@ -215,13 +216,57 @@ da_prediction.evaluate( In this case, the keyword argument `k` is passed to all metric functions, even though it does not fulfill any specific function for the calculation of the reciprocal rank. +### The max_rel parameter + +Some metric functions shown in the table above require a `max_rel` parameter. 
+This parameter should be set to the number of relevant Documents in the Document collection. +Without the knowledge of this number, metrics like `recall_at_k` and `f1_score_at_k` cannot be calculated. + +In the `evaluate` function, you can provide a keyword argument `max_rel`, which is then used for all queries. +In the example below, we can use the datasets `da_prediction` and `da_original` from the beginning, where each query has nine relevant Documents. +Therefore, we set `max_rel=9`. + +```python +da_prediction.evaluate(ground_truth=da_original, metrics=['recall_at_k'], max_rel=9) +``` + +```text +{'recall_at_k': 1.0} +``` + +Since all relevant Documents are in the matches, the recall is one. +However, this only makes sense if the number of relevant Documents is equal for each query. +If you provide a `ground_truth` parameter to the `evaluate` function, `max_rel` is set to the number of matches of the query Document. + +```python +da_prediction.evaluate(ground_truth=da_original, metrics=['recall_at_k']) +``` +```text +{'recall_at_k': 1.0} +``` + +For labeled datasets, this is not possible. +Here, you can set the `num_relevant_documents_per_label` parameter of `evaluate`. +It accepts a dictionary that contains the number of relevant Documents for each label. +In this way, the function can set `max_rel` to the correct value for each query Document. + +```python +example_da.evaluate( + metrics=['recall_at_k'], num_relevant_documents_per_label={0: 5, 1: 5} +) +``` + +```text +{'recall_at_k': 1.0} +``` + ### Custom metrics If the pre-defined metrics do not fit your use-case, you can define a custom metric function. It should take as input a list of binary relevance judgements of a query (`1` and `0` values). The evaluate function already calculates this binary list from the `matches` attribute so that each number represents the relevancy of a match. 
+Let's write a custom metric function, which counts the number of relevant Documents per query: ```python def count_relevant(binary_relevance): @@ -282,20 +327,22 @@ print(result) {'reciprocal_rank': 0.7583333333333333} ``` +For metric functions which require a `max_rel` parameter, the `embed_and_evaluate` function (described later in this section) automatically constructs the dictionary for `num_relevant_documents_per_label` based on the `index_data` argument. + ### Batch-wise matching -The ``embed_and_evaluate`` function is especially useful, when you need to evaluate the queries on a very large document collection (`example_index` in the code snippet above), which is too large to store the embeddings of all documents in main-memory. -In this case, ``embed_and_evaluate`` matches the queries to batches of the document collection. +The ``embed_and_evaluate`` function is especially useful when you need to evaluate the queries on a very large Document collection (`example_index` in the code snippet above), which is too large to store the embeddings of all Documents in main memory. +In this case, ``embed_and_evaluate`` matches the queries to batches of the Document collection. After the batch is processed all embeddings are deleted. By default, the batch size for the matching (`match_batch_size`) is set to `100_000`. If you want to reduce the memory footprint, you can set it to a lower value. ### Sampling Queries -If you want to evaluate a large dataset, it might be useful to sample query documents. +If you want to evaluate a large dataset, it might be useful to sample query Documents. Since the metric values returned by the `embed_and_evaluate` are mean values, sampling should not change the result significantly if the sample is large enough. -By default, sampling is applied for `DocumentArray` objects with more than 1,000 documents.
-However, it is only applied on the `DocumentArray` itself and not on the document provided in `index_data`. +By default, sampling is applied for `DocumentArray` objects with more than 1,000 Documents. +However, it is only applied on the `DocumentArray` itself and not on the Documents provided in `index_data`. If you want to change the number of samples, you can adjust the `query_sample_size` argument. In the following code block an evaluation is done with 100 samples: @@ -323,7 +370,7 @@ da.embed_and_evaluate( {'precision_at_k': 0.13649999999999998} ``` -Please note that in this way only documents which are actually evaluated obtain an `.evaluations` attribute. +Please note that in this way only Documents which are actually evaluated obtain an `.evaluations` attribute. To test how close it is to the exact result, we execute the function again with `query_sample_size` set to 1,000: diff --git a/tests/unit/array/mixins/oldproto/test_eval_class.py b/tests/unit/array/mixins/oldproto/test_eval_class.py index 32b6bc2ceef..78addbefb82 100644 --- a/tests/unit/array/mixins/oldproto/test_eval_class.py +++ b/tests/unit/array/mixins/oldproto/test_eval_class.py @@ -192,10 +192,43 @@ def test_eval_mixin_one_of_n_labeled(metric_fn, metric_score, label_tag): da = DocumentArray([Document(text=str(i), tags={label_tag: i}) for i in range(3)]) for d in da: d.matches = da - r = da.evaluate([metric_fn], label_tag=label_tag)[metric_fn] + r = da.evaluate([metric_fn], label_tag=label_tag, max_rel=3)[metric_fn] assert abs(r - metric_score) < 0.001 +@pytest.mark.parametrize('label_tag', ['label', 'custom_tag']) +@pytest.mark.parametrize( + 'metric_fn, metric_score', + [ + ('recall_at_k', 1.0), + ('f1_score_at_k', 0.5), + ], +) +def test_num_relevant_documents_per_label(metric_fn, metric_score, label_tag): + da = DocumentArray([Document(text=str(i), tags={label_tag: i}) for i in range(3)]) + num_relevant_documents_per_label = {i: 1 for i in range(3)} + for d in da: + d.matches = da + r =
da.evaluate( + [metric_fn], + label_tag=label_tag, + num_relevant_documents_per_label=num_relevant_documents_per_label, + )[metric_fn] + assert abs(r - metric_score) < 0.001 + + +def test_missing_max_rel_should_raise(): + da = DocumentArray([Document(text=str(i), tags={'label': i}) for i in range(3)]) + num_relevant_documents_per_label = {i: 1 for i in range(2)} + for d in da: + d.matches = da + with pytest.raises(ValueError): + da.evaluate( + ['recall_at_k'], + num_relevant_documents_per_label=num_relevant_documents_per_label, + ) + + @pytest.mark.parametrize( 'storage, config', [ @@ -540,7 +573,11 @@ def test_embed_and_evaluate_two_das(storage, config, sample_size, start_storage) (False, {'precision_at_k': 1.0 / 3, 'reciprocal_rank': 1.0}, 'label'), ( True, - {'precision_at_k': 1.0 / 3, 'reciprocal_rank': 11.0 / 18.0}, + { + 'precision_at_k': 1.0 / 3, + 'reciprocal_rank': 11.0 / 18.0, + 'recall_at_k': 1.0, + }, 'custom_tag', ), ],