From 94d17c916dabd35d219e0bcffd2f6f435b9b642d Mon Sep 17 00:00:00 2001 From: guenthermi Date: Wed, 12 Oct 2022 17:13:52 +0200 Subject: [PATCH 01/22] feat: add support for labeled dataset to evaluate function --- docarray/array/mixins/evaluation.py | 80 ++++++++--- docs/fundamentals/documentarray/evaluation.md | 8 +- tests/unit/array/mixins/test_eval_class.py | 127 ++++++++++++++++-- 3 files changed, 184 insertions(+), 31 deletions(-) diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index c7107832cfb..be730afee00 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -14,31 +14,64 @@ class EvaluationMixin: def evaluate( self, - other: 'DocumentArray', metric: Union[str, Callable[..., float]], + ground_truth: Optional['DocumentArray'] = None, hash_fn: Optional[Callable[['Document'], str]] = None, metric_name: Optional[str] = None, strict: bool = True, + label_tag='label', **kwargs, ) -> Optional[float]: - """Compute ranking evaluation metrics for a given `DocumentArray` when compared with a groundtruth. + """ + Compute ranking evaluation metrics for a given `DocumentArray` when compared + with a groundtruth. - This implementation expects to provide a `groundtruth` DocumentArray that is structurally identical to `self`. It is based - on comparing the `matches` of `documents` inside the `DocumentArray. + This implementation expects the documents and their matches to have labels + annotated inside the tag with the key specified in the `label_tag` attribute. + Alternatively, one can provide a `ground_truth` DocumentArray that is + structurally identical to `self`. In this case, this function compares the + `matches` of `documents` inside the `DocumentArray`. - This method will fill the `evaluations` field of Documents inside this `DocumentArray` and will return the average of the computations + This method will fill the `evaluations` field of Documents inside this + `DocumentArray` and will return the average of the computations - :param other: The groundtruth DocumentArray` that the `DocumentArray` compares to. :param metric: The name of the metric, or multiple metrics to be computed - :param hash_fn: The function used for identifying the uniqueness of Documents. If not given, then ``Document.id`` is used. - :param metric_name: If provided, the results of the metrics computation will be stored in the `evaluations` field of each Document. If not provided, the name will be computed based on the metrics name. - :param strict: If set, then left and right sides are required to be fully aligned: on the length, and on the semantic of length. These are preventing + :param ground_truth: The ground_truth `DocumentArray` that the `DocumentArray` + compares to. + :param hash_fn: The function used for identifying the uniqueness of Documents. + If not given, then ``Document.id`` is used. + :param metric_name: If provided, the results of the metrics computation will be + stored in the `evaluations` field of each Document. If not provided, the + name will be computed based on the metrics name. + :param strict: If set, then left and right sides are required to be fully + aligned: on the length, and on the semantic of length. These are preventing you to evaluate on irrelevant matches accidentally. :param kwargs: Additional keyword arguments to be passed to `metric_fn` - :return: The average evaluation computed or a list of them if multiple metrics are required + :return: The average evaluation computed or a list of them if multiple metrics + are required """ + if len(self) == 0: + raise ValueError('It is not possible to evaluate an empty DocumentArray') + if ground_truth and len(ground_truth) > 0 and ground_truth[0].matches: + ground_truth_type = 'matches' + elif label_tag in self[0].tags: + if ground_truth: + warnings.warn( + 'An ground_truth attribute is provided but does not ' + 'contain matches. The labels are used instead and ' + 'ground_truth is ignored.' + ) + ground_truth = self + ground_truth_type = 'labels' + + else: + raise RuntimeError( + 'Could not find proper ground truth data. Either labels or the ' + 'ground_truth attribute with matches is required' + ) + if strict: - self._check_length(len(other)) + self._check_length(len(ground_truth)) if hash_fn is None: hash_fn = lambda d: d.id @@ -53,7 +86,7 @@ def evaluate( metric_name = metric_name or metric_fn.__name__ results = [] caller_max_rel = kwargs.pop('max_rel', None) - for d, gd in zip(self, other): + for d, gd in zip(self, ground_truth): max_rel = caller_max_rel or len(gd.matches) if strict and hash_fn(d) != hash_fn(gd): raise ValueError( @@ -68,14 +101,23 @@ def evaluate( ) targets = gd.matches[:max_rel] - desired = {hash_fn(m) for m in targets} - if len(desired) != len(targets): - warnings.warn( - f'{hash_fn!r} may not be valid, as it maps multiple Documents into the same hash. ' - f'Evaluation results may be affected' - ) - binary_relevance = [1 if hash_fn(m) in desired else 0 for m in d.matches] + if ground_truth_type == 'matches': + desired = {hash_fn(m) for m in targets} + if len(desired) != len(targets): + warnings.warn( + f'{hash_fn!r} may not be valid, as it maps multiple Documents into the same hash. ' + f'Evaluation results may be affected' + ) + binary_relevance = [ + 1 if hash_fn(m) in desired else 0 for m in d.matches + ] + elif ground_truth_type == 'labels': + binary_relevance = [ + 1 if m.tags[label_tag] == d.tags[label_tag] else 0 for m in targets + ] + else: + raise RuntimeError(f'Unsupported groundtruth type {ground_truth_type}') r = metric_fn(binary_relevance, max_rel=max_rel, **kwargs) d.evaluations[metric_name] = NamedScore( diff --git a/docs/fundamentals/documentarray/evaluation.md b/docs/fundamentals/documentarray/evaluation.md index 65a9be846fa..82e89334bc6 100644 --- a/docs/fundamentals/documentarray/evaluation.md +++ b/docs/fundamentals/documentarray/evaluation.md @@ -3,7 +3,7 @@ After you get `.matches` from the last chapter, you can easily evaluate matches against the groundtruth via {meth}`~docarray.array.mixins.evaluation.EvaluationMixin.evaluate`. ```python -da_predict.evaluate(da_groundtruth, metric='...', **kwargs) +da_predict.evaluate(ground_truth=da_groundtruth, metric='...', **kwargs) ``` The results are stored in `.evaluations` field of each Document. @@ -91,7 +91,7 @@ da2['@m'].summary() Now `da2` is our prediction, and `da` is our groundtruth. If we evaluate the average Precision@10, we should get something close to 0.5 (we have 10 real matches, we mixed in 10 fake matches and shuffle it, so top-10 would have approximate 10/20 real matches): ```python -da2.evaluate(da, metric='precision_at_k', k=5) +da2.evaluate(ground_truth=da, metric='precision_at_k', k=5) ``` ```text @@ -142,7 +142,7 @@ for d in g_da: Now when you do evaluate, you will receive an error: ```python -p_da.evaluate(g_da, 'average_precision') +p_da.evaluate('average_precision', groundtruth=g_da) ``` ```text @@ -154,7 +154,7 @@ This basically saying that based on `.id` (default identifier), the given two Do If we override the hash function as following the evaluation can be conducted: ```python -p_da.evaluate(g_da, 'average_precision', hash_fn=lambda d: d.text[:2]) +p_da.evaluate('average_precision', ground_truth=g_da, hash_fn=lambda d: d.text[:2]) ``` ```text diff --git a/tests/unit/array/mixins/test_eval_class.py b/tests/unit/array/mixins/test_eval_class.py index c8431dcf6db..4c48bf95352 100644 --- a/tests/unit/array/mixins/test_eval_class.py +++ b/tests/unit/array/mixins/test_eval_class.py @@ -36,13 +36,120 @@ def test_eval_mixin_perfect_match(metric_fn, kwargs, storage, config, start_stor da1.embeddings = np.random.random([10, 256]) da1_index = DocumentArray(da1, storage=storage, config=config) da1.match(da1_index, exclude_self=True) - r = da1.evaluate(da1, metric=metric_fn, strict=False, **kwargs) + r = da1.evaluate(ground_truth=da1, metric=metric_fn, strict=False, **kwargs) assert isinstance(r, float) assert r == 1.0 for d in da1: assert d.evaluations[metric_fn].value == 1.0 +@pytest.mark.parametrize( + 'storage, config', + [ + ('memory', {}), + ('weaviate', {}), + ('sqlite', {}), + ('annlite', {'n_dim': 256}), + ('qdrant', {'n_dim': 256}), + ('elasticsearch', {'n_dim': 256}), + ('redis', {'n_dim': 256}), + ], +) +@pytest.mark.parametrize( + 'metric_fn, kwargs', + [ + ('r_precision', {}), + ('precision_at_k', {}), + ('hit_at_k', {}), + ('average_precision', {}), + ('reciprocal_rank', {}), + ('recall_at_k', {'max_rel': 9}), + ('f1_score_at_k', {'max_rel': 9}), + ('ndcg_at_k', {}), + ], +) +def test_eval_mixin_perfect_match_labeled( + metric_fn, kwargs, storage, config, start_storage +): + da1 = DocumentArray.empty(10) + for d in da1: + d.tags = {'label': 'A'} + da1.embeddings = np.random.random([10, 256]) + da1_index = DocumentArray(da1, storage=storage, config=config) + for d in da1_index: + d.tags = {'label': 'A'} + da1.match(da1_index, exclude_self=True) + r = da1.evaluate(metric=metric_fn, **kwargs) + assert isinstance(r, float) + assert r == 1.0 + for d in da1: + assert d.evaluations[metric_fn].value == 1.0 + + +@pytest.mark.parametrize( + 'storage, config', + [ + ('memory', {}), + ('weaviate', {}), + ('sqlite', {}), + ('annlite', {'n_dim': 256}), + ('qdrant', {'n_dim': 256}), + ('elasticsearch', {'n_dim': 256}), + ('redis', {'n_dim': 256}), + ], +) +@pytest.mark.parametrize( + 'metric_fn, kwargs', + [ + ('r_precision', {}), + ('precision_at_k', {}), + ('hit_at_k', {}), + ('average_precision', {}), + ('reciprocal_rank', {}), + ('recall_at_k', {'max_rel': 9}), + ('f1_score_at_k', {'max_rel': 9}), + ('ndcg_at_k', {}), + ], +) +def test_eval_mixin_zero_labeled(storage, config, metric_fn, start_storage, kwargs): + da1 = DocumentArray.empty(10) + for d in da1: + d.tags = {'label': 'A'} + da1.embeddings = np.random.random([10, 256]) + da2 = copy.deepcopy(da1) + for d in da2: + d.tags = {'label': 'B'} + da1_index = DocumentArray(da2, storage=storage, config=config) + da1.match(da1_index, exclude_self=True) + r = da1.evaluate(metric_fn, **kwargs) + assert isinstance(r, float) + assert r == 0.0 + for d in da1: + d: Document + assert d.evaluations[metric_fn].value == 0.0 + + +@pytest.mark.parametrize( + 'metric_fn, metric_score', + [ + ('r_precision', 0.6111111), # might be changed + ('precision_at_k', 1.0 / 3), + ('hit_at_k', 1.0), + ('average_precision', (1.0 + 0.5 + (1.0 / 3)) / 3), + ('reciprocal_rank', (1.0 + 0.5 + (1.0 / 3)) / 3), + ('recall_at_k', 1.0 / 3), + ('f1_score_at_k', 1.0 / 3), + ('dcg_at_k', (1.0 + 1.0 + 0.6309) / 3), + ], +) +def test_eval_mixin_one_of_n_labeled(metric_fn, metric_score): + da = DocumentArray([Document(text=str(i), tags={'label': i}) for i in range(3)]) + for d in da: + d.matches = da + r = da.evaluate(metric_fn) + assert abs(r - metric_score) < 0.001 + + @pytest.mark.parametrize( 'storage, config', [ @@ -79,7 +186,7 @@ def test_eval_mixin_zero_match(storage, config, metric_fn, start_storage, kwargs da2_index = DocumentArray(da2, storage=storage, config=config) da2.match(da2_index, exclude_self=True) - r = da1.evaluate(da2, metric=metric_fn, **kwargs) + r = da1.evaluate(ground_truth=da2, metric=metric_fn, **kwargs) assert isinstance(r, float) assert r == 1.0 for d in da1: @@ -102,8 +209,10 @@ def test_eval_mixin_zero_match(storage, config, metric_fn, start_storage, kwargs def test_diff_len_should_raise(storage, config, start_storage): da1 = DocumentArray.empty(10) da2 = DocumentArray.empty(5, storage=storage, config=config) + for d in da2: + d.matches.append(da2[0]) with pytest.raises(ValueError): - da1.evaluate(da2, metric='precision_at_k') + da1.evaluate(ground_truth=da2, metric='precision_at_k') @pytest.mark.parametrize( @@ -121,8 +230,10 @@ def test_diff_len_should_raise(storage, config, start_storage): def test_diff_hash_fun_should_raise(storage, config, start_storage): da1 = DocumentArray.empty(10) da2 = DocumentArray.empty(10, storage=storage, config=config) + for d in da2: + d.matches.append(da2[0]) with pytest.raises(ValueError): - da1.evaluate(da2, metric='precision_at_k') + da1.evaluate(ground_truth=da2, metric='precision_at_k') @pytest.mark.parametrize( @@ -147,11 +258,11 @@ def test_same_hash_same_len_fun_should_work(storage, config, start_storage): da2_index = DocumentArray(da1, storage=storage, config=config) da2.match(da2_index) with pytest.raises(ValueError): - da1.evaluate(da2, metric='precision_at_k') + da1.evaluate(ground_truth=da2, metric='precision_at_k') for d1, d2 in zip(da1, da2): d1.id = d2.id - da1.evaluate(da2, metric='precision_at_k') + da1.evaluate(ground_truth=da2, metric='precision_at_k') @pytest.mark.parametrize( @@ -179,7 +290,7 @@ def test_adding_noise(storage, config, start_storage): d.matches.extend(DocumentArray.empty(10)) d.matches = d.matches.shuffle() - assert da2.evaluate(da, metric='precision_at_k', k=10) < 1.0 + assert da2.evaluate(ground_truth=da, metric='precision_at_k', k=10) < 1.0 for d in da2: assert 0.0 < d.evaluations['precision_at_k'].value < 1.0 @@ -217,7 +328,7 @@ def test_diff_match_len_in_gd(storage, config, metric_fn, start_storage, kwargs) # pop some matches from first document da2[0].matches.pop(8) - r = da1.evaluate(da2, metric=metric_fn, **kwargs) + r = da1.evaluate(ground_truth=da2, metric=metric_fn, **kwargs) assert isinstance(r, float) np.testing.assert_allclose(r, 1.0, rtol=1e-2) # for d in da1: From 99e3e66f15894fff6e9371bad4d1bc571331e2d1 Mon Sep 17 00:00:00 2001 From: guenthermi Date: Wed, 12 Oct 2022 18:24:38 +0200 Subject: [PATCH 02/22] refactor: add matches before transfering to backend, add example to docs --- docs/fundamentals/documentarray/evaluation.md | 15 +++++++++++++++ tests/unit/array/mixins/test_eval_class.py | 5 +++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/docs/fundamentals/documentarray/evaluation.md b/docs/fundamentals/documentarray/evaluation.md index 82e89334bc6..cd731751646 100644 --- a/docs/fundamentals/documentarray/evaluation.md +++ b/docs/fundamentals/documentarray/evaluation.md @@ -6,6 +6,21 @@ After you get `.matches` from the last chapter, you can easily evaluate matches da_predict.evaluate(ground_truth=da_groundtruth, metric='...', **kwargs) ``` +Alternatively, you can add labels to your documents to evaluate them. +In this case, a match is considered as relevant to its root document, if it has the same label. + +```python +import numpy as np +from docarray import Document, DocumentArray + +example_da = DocumentArray([Document(tags={'label': (i % 2)}) for i in range(10)]) +example_da.embeddings = np.random.random([10, 3]) + +example_da.match(example_da) + +example_da.evaluate(metric='precision_at_k') +``` + The results are stored in `.evaluations` field of each Document. DocArray provides some common metrics used in the information retrieval community that allows one to evaluate the nearest-neighbour matches. Different metric accepts different arguments as `kwargs`: diff --git a/tests/unit/array/mixins/test_eval_class.py b/tests/unit/array/mixins/test_eval_class.py index 4c48bf95352..10c2d187ec3 100644 --- a/tests/unit/array/mixins/test_eval_class.py +++ b/tests/unit/array/mixins/test_eval_class.py @@ -208,9 +208,10 @@ def test_eval_mixin_zero_match(storage, config, metric_fn, start_storage, kwargs ) def test_diff_len_should_raise(storage, config, start_storage): da1 = DocumentArray.empty(10) - da2 = DocumentArray.empty(5, storage=storage, config=config) + da2 = DocumentArray.empty(5) for d in da2: d.matches.append(da2[0]) + da2 = DocumentArray(da2, storage=storage, config=config) with pytest.raises(ValueError): da1.evaluate(ground_truth=da2, metric='precision_at_k') @@ -229,9 +230,9 @@ def test_diff_len_should_raise(storage, config, start_storage): ) def test_diff_hash_fun_should_raise(storage, config, start_storage): da1 = DocumentArray.empty(10) - da2 = DocumentArray.empty(10, storage=storage, config=config) for d in da2: d.matches.append(da2[0]) + da2 = DocumentArray(da2, storage=storage, config=config) with pytest.raises(ValueError): da1.evaluate(ground_truth=da2, metric='precision_at_k') From bedfc7a43c8cf48a8544d9f38af114127aeba55c Mon Sep 17 00:00:00 2001 From: guenthermi Date: Thu, 13 Oct 2022 09:05:20 +0200 Subject: [PATCH 03/22] fix: missing initialization of d2 --- tests/unit/array/mixins/test_eval_class.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/array/mixins/test_eval_class.py b/tests/unit/array/mixins/test_eval_class.py index 10c2d187ec3..40a46ee5d1f 100644 --- a/tests/unit/array/mixins/test_eval_class.py +++ b/tests/unit/array/mixins/test_eval_class.py @@ -230,6 +230,7 @@ def test_diff_len_should_raise(storage, config, start_storage): ) def test_diff_hash_fun_should_raise(storage, config, start_storage): da1 = DocumentArray.empty(10) + da2 = DocumentArray.empty(5) for d in da2: d.matches.append(da2[0]) da2 = DocumentArray(da2, storage=storage, config=config) From 23fafd34ea51f56337ab797e2c1337e34cba6e37 Mon Sep 17 00:00:00 2001 From: guenthermi Date: Thu, 13 Oct 2022 10:29:09 +0200 Subject: [PATCH 04/22] test: add tests to check if exceptions and warnings are raised --- tests/unit/array/mixins/test_eval_class.py | 62 ++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/unit/array/mixins/test_eval_class.py b/tests/unit/array/mixins/test_eval_class.py index 40a46ee5d1f..6c217909608 100644 --- a/tests/unit/array/mixins/test_eval_class.py +++ b/tests/unit/array/mixins/test_eval_class.py @@ -337,3 +337,65 @@ def test_diff_match_len_in_gd(storage, config, metric_fn, start_storage, kwargs) d: Document # f1_score does not yield 1 for the first document as one of the match is missing assert d.evaluations[metric_fn].value > 0.9 + + +@pytest.mark.parametrize( + 'storage, config', + [ + ('memory', {}), + ('weaviate', {}), + ('sqlite', {}), + ('annlite', {'n_dim': 256}), + ('qdrant', {'n_dim': 256}), + ('elasticsearch', {'n_dim': 256}), + ('redis', {'n_dim': 256}), + ], +) +def test_empty_da_should_raise(storage, config, start_storage): + da = DocumentArray([], storage=storage, config=config) + with pytest.raises(ValueError): + da.evaluate(metric='precision_at_k') + + +@pytest.mark.parametrize( + 'storage, config', + [ + ('memory', {}), + ('weaviate', {}), + ('sqlite', {}), + ('annlite', {'n_dim': 256}), + ('qdrant', {'n_dim': 256}), + ('elasticsearch', {'n_dim': 256}), + ('redis', {'n_dim': 256}), + ], +) +def test_missing_groundtruth_should_raise(storage, config, start_storage): + da = DocumentArray(DocumentArray.empty(10), storage=storage, config=config) + with pytest.raises(RuntimeError): + da.evaluate(metric='precision_at_k') + + +@pytest.mark.parametrize( + 'storage, config', + [ + ('memory', {}), + ('weaviate', {}), + ('sqlite', {}), + ('annlite', {'n_dim': 256}), + ('qdrant', {'n_dim': 256}), + ('elasticsearch', {'n_dim': 256}), + ('redis', {'n_dim': 256}), + ], +) +def test_missing_groundtruth_should_raise(storage, config, start_storage): + da1 = DocumentArray.empty(10) + for d in da1: + d.tags = {'label': 'A'} + da1.embeddings = np.random.random([10, 256]) + da1_index = DocumentArray(da1, storage=storage, config=config) + for d in da1_index: + d.tags = {'label': 'A'} + da1.match(da1_index, exclude_self=True) + da2 = DocumentArray.empty(10) + with pytest.warns(UserWarning): + da1.evaluate(ground_truth=da2, metric='precision_at_k') From dfa4aa31d8c6ff02c5cabc198fbd47f91e7521cd Mon Sep 17 00:00:00 2001 From: guenthermi Date: Thu, 13 Oct 2022 11:52:14 +0200 Subject: [PATCH 05/22] fix: duplicate test name --- tests/unit/array/mixins/test_eval_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/array/mixins/test_eval_class.py b/tests/unit/array/mixins/test_eval_class.py index 6c217909608..10c48238562 100644 --- a/tests/unit/array/mixins/test_eval_class.py +++ b/tests/unit/array/mixins/test_eval_class.py @@ -387,7 +387,7 @@ def test_missing_groundtruth_should_raise(storage, config, start_storage): ('redis', {'n_dim': 256}), ], ) -def test_missing_groundtruth_should_raise(storage, config, start_storage): +def test_useless_groundtruth_warning_should_raise(storage, config, start_storage): da1 = DocumentArray.empty(10) for d in da1: d.tags = {'label': 'A'} From cfb6b2d1e41c69797b37290757f0971ee3c2bb8b Mon Sep 17 00:00:00 2001 From: guenthermi Date: Mon, 17 Oct 2022 14:20:15 +0200 Subject: [PATCH 06/22] refactor: implement review notes --- docarray/array/mixins/evaluation.py | 51 ++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index be730afee00..11922f832d7 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -9,9 +9,38 @@ from docarray import Document, DocumentArray +def _evaluate_deprecation(f): + """Raises a deprecation warning if the user executes the evaluate function with + the old interface and adjust the input to fit the new interface.""" + + def func(*args, **kwargs): + if len(args) > 1: + if not (isinstance(args[1], Callable) or isinstance(args[1], str)): + kwargs['ground_truth'] = args[1] + args = [args[0]] + list(args[2:]) + warnings.warn( + 'The `other` attribute in `evaluate()` is transfered from a ' + 'positional attribute into the keyword attribute `ground_truth`.' + 'Using it as a positional attribute is deprecated and will be removed ' + 'in the next version.', + DeprecationWarning, + ) + if 'other' in kwargs: + kwargs['ground_truth'] = kwargs['other'] + warnings.warn( + '`other` is renamed to `groundtruth` in `evaluate()`, the usage of `other` is ' + 'deprecated and will be removed in the next version.', + DeprecationWarning, + ) + return f(*args, **kwargs) + + return func + + class EvaluationMixin: """A mixin that provides ranking evaluation functionality to DocumentArrayLike objects""" + @_evaluate_deprecation def evaluate( self, metric: Union[str, Callable[..., float]], @@ -19,18 +48,20 @@ def evaluate( hash_fn: Optional[Callable[['Document'], str]] = None, metric_name: Optional[str] = None, strict: bool = True, - label_tag='label', + label_tag: Optional[str] = 'label', **kwargs, ) -> Optional[float]: """ Compute ranking evaluation metrics for a given `DocumentArray` when compared with a groundtruth. - This implementation expects the documents and their matches to have labels - annotated inside the tag with the key specified in the `label_tag` attribute. - Alternatively, one can provide a `ground_truth` DocumentArray that is - structurally identical to `self`. In this case, this function compares the - `matches` of `documents` inside the `DocumentArray`. + If one provides a `ground_truth` DocumentArray that is structurally identical + to `self`, this function compares the `matches` of `documents` inside the + `DocumentArray` to this `ground_truth`. + Alternatively, one can directly annotate the documents by adding labels in the + form of tags with the key specified in the `label_tag` attribute. + Those tags need to be added to `self` as well as to the documents in the + matches properties. This method will fill the `evaluations` field of Documents inside this `DocumentArray` and will return the average of the computations @@ -46,6 +77,7 @@ def evaluate( :param strict: If set, then left and right sides are required to be fully aligned: on the length, and on the semantic of length. These are preventing you to evaluate on irrelevant matches accidentally. + :param label_tag: Specifies the tag which contains the labels. :param kwargs: Additional keyword arguments to be passed to `metric_fn` :return: The average evaluation computed or a list of them if multiple metrics are required @@ -57,7 +89,7 @@ def evaluate( elif label_tag in self[0].tags: if ground_truth: warnings.warn( - 'An ground_truth attribute is provided but does not ' + 'A ground_truth attribute is provided but does not ' 'contain matches. The labels are used instead and ' 'ground_truth is ignored.' ) @@ -117,7 +149,10 @@ def evaluate( 1 if m.tags[label_tag] == d.tags[label_tag] else 0 for m in targets ] else: - raise RuntimeError(f'Unsupported groundtruth type {ground_truth_type}') + raise RuntimeError( + 'Could not identify which kind of ground truth' + 'information is provided to evaluate the matches.' + ) r = metric_fn(binary_relevance, max_rel=max_rel, **kwargs) d.evaluations[metric_name] = NamedScore( From c499634cc57e4f5fb279589aa8aa17c82376aa3c Mon Sep 17 00:00:00 2001 From: guenthermi Date: Mon, 17 Oct 2022 14:28:29 +0200 Subject: [PATCH 07/22] fix: update r_precision test --- tests/unit/array/mixins/test_eval_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/array/mixins/test_eval_class.py b/tests/unit/array/mixins/test_eval_class.py index 10c48238562..a29c9254aab 100644 --- a/tests/unit/array/mixins/test_eval_class.py +++ b/tests/unit/array/mixins/test_eval_class.py @@ -132,7 +132,7 @@ def test_eval_mixin_zero_labeled(storage, config, metric_fn, start_storage, kwar @pytest.mark.parametrize( 'metric_fn, metric_score', [ - ('r_precision', 0.6111111), # might be changed + ('r_precision', 1.0 / 3), # might be changed ('precision_at_k', 1.0 / 3), ('hit_at_k', 1.0), ('average_precision', (1.0 + 0.5 + (1.0 / 3)) / 3), From 15ea32db3ddfc388e02d958bfb4364c1c226d691 Mon Sep 17 00:00:00 2001 From: guenthermi Date: Mon, 17 Oct 2022 14:40:43 +0200 Subject: [PATCH 08/22] refactor: remove comment --- tests/unit/array/mixins/test_eval_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/array/mixins/test_eval_class.py b/tests/unit/array/mixins/test_eval_class.py index a29c9254aab..2c0439f6902 100644 --- a/tests/unit/array/mixins/test_eval_class.py +++ b/tests/unit/array/mixins/test_eval_class.py @@ -132,7 +132,7 @@ def test_eval_mixin_zero_labeled(storage, config, metric_fn, start_storage, kwar @pytest.mark.parametrize( 'metric_fn, metric_score', [ - ('r_precision', 1.0 / 3), # might be changed + ('r_precision', 1.0 / 3), ('precision_at_k', 1.0 / 3), ('hit_at_k', 1.0), ('average_precision', (1.0 + 0.5 + (1.0 / 3)) / 3), From 798d94faf3b7aab94c1caa476bdd0a8a90705f0e Mon Sep 17 00:00:00 2001 From: guenthermi Date: Mon, 17 Oct 2022 16:14:50 +0200 Subject: [PATCH 09/22] feat: support multiple metrics in evaluation function --- docarray/array/mixins/evaluation.py | 79 ++++++++++++------- docs/fundamentals/documentarray/evaluation.md | 4 +- tests/unit/array/mixins/test_eval_class.py | 14 ++-- 3 files changed, 62 insertions(+), 35 deletions(-) diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index 11922f832d7..ff472f49fd9 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -1,7 +1,8 @@ import warnings -from typing import Optional, Union, TYPE_CHECKING, Callable +from typing import Optional, Union, TYPE_CHECKING, Callable, List, Dict import numpy as np +from collections import defaultdict from docarray.score import NamedScore @@ -15,7 +16,11 @@ def _evaluate_deprecation(f): def func(*args, **kwargs): if len(args) > 1: - if not (isinstance(args[1], Callable) or isinstance(args[1], str)): + if not ( + isinstance(args[1], Callable) + or isinstance(args[1], str) + or isinstance(args[1], list) + ): kwargs['ground_truth'] = args[1] args = [args[0]] + list(args[2:]) warnings.warn( @@ -43,14 +48,16 @@ class EvaluationMixin: @_evaluate_deprecation def evaluate( self, - metric: Union[str, Callable[..., float]], + metrics: Union[ + Union[str, Callable[..., float]], List[Union[str, Callable[..., float]]] + ], ground_truth: Optional['DocumentArray'] = None, hash_fn: Optional[Callable[['Document'], str]] = None, - metric_name: Optional[str] = None, + metric_names: Optional[Union[str, List[str]]] = None, strict: bool = True, label_tag: Optional[str] = 'label', **kwargs, - ) -> Optional[float]: + ) -> Dict[str, float]: """ Compute ranking evaluation metrics for a given `DocumentArray` when compared with a groundtruth. @@ -66,21 +73,22 @@ def evaluate( This method will fill the `evaluations` field of Documents inside this `DocumentArray` and will return the average of the computations - :param metric: The name of the metric, or multiple metrics to be computed + :param metrics: One or multiple name of the metrics or metric functions to be computed :param ground_truth: The ground_truth `DocumentArray` that the `DocumentArray` compares to. :param hash_fn: The function used for identifying the uniqueness of Documents. If not given, then ``Document.id`` is used. - :param metric_name: If provided, the results of the metrics computation will be - stored in the `evaluations` field of each Document. If not provided, the - name will be computed based on the metrics name. + :param metric_names: If provided, the results of the metrics computation will be + stored in the `evaluations` field of each Document with this name. If not + provided, the names will be derived from the metric function names. :param strict: If set, then left and right sides are required to be fully aligned: on the length, and on the semantic of length. These are preventing you to evaluate on irrelevant matches accidentally. :param label_tag: Specifies the tag which contains the labels. - :param kwargs: Additional keyword arguments to be passed to `metric_fn` - :return: The average evaluation computed or a list of them if multiple metrics - are required + :param kwargs: Additional keyword arguments to be passed to the metric + functions. + :return: A dictionary which stores for each metric name the average evaluation + score. """ if len(self) == 0: raise ValueError('It is not possible to evaluate an empty DocumentArray') @@ -108,15 +116,30 @@ def evaluate( if hash_fn is None: hash_fn = lambda d: d.id - if callable(metric): - metric_fn = metric - elif isinstance(metric, str): - from docarray.math import evaluation + if type(metrics) is not list: + metrics = [metrics] + if type(metric_names) is str: + metric_names = [metric_names] - metric_fn = getattr(evaluation, metric) + metric_fns = [] + for metric in metrics: + if callable(metric): + metric_fns.append(metric) + elif isinstance(metric, str): + from docarray.math import evaluation - metric_name = metric_name or metric_fn.__name__ - results = [] + metric_fns.append(getattr(evaluation, metric)) + + if not metric_names: + metric_names = [metric_fn.__name__ for metric_fn in metric_fns] + + if len(metric_names) != len(metrics): + raise ValueError( + 'Could not match metric names to the metrics since the number of ' + 'metric names does not match the number of metrics' + ) + + results = defaultdict(list) caller_max_rel = kwargs.pop('max_rel', None) for d, gd in zip(self, ground_truth): max_rel = caller_max_rel or len(gd.matches) @@ -153,11 +176,13 @@ def evaluate( 'Could not identify which kind of ground truth' 'information is provided to evaluate the matches.' ) - - r = metric_fn(binary_relevance, max_rel=max_rel, **kwargs) - d.evaluations[metric_name] = NamedScore( - value=r, op_name=str(metric_fn), ref_id=d.id - ) - results.append(r) - if results: - return float(np.mean(results)) + for metric_name, metric_fn in zip(metric_names, metric_fns): + r = metric_fn(binary_relevance, max_rel=max_rel, **kwargs) + d.evaluations[metric_name] = NamedScore( + value=r, op_name=str(metric_fn), ref_id=d.id + ) + results[metric_name].append(r) + return { + metric_name: float(np.mean(values)) + for metric_name, values in results.items() + } diff --git a/docs/fundamentals/documentarray/evaluation.md b/docs/fundamentals/documentarray/evaluation.md index aaeec9e3e44..994bc0d82ad 100644 --- a/docs/fundamentals/documentarray/evaluation.md +++ b/docs/fundamentals/documentarray/evaluation.md @@ -116,7 +116,7 @@ da2.evaluate(ground_truth=da, metric='precision_at_k', k=5) ``` ```text -0.48 +{'precision_at_k': 0.48} ``` Note that this value is an average number over all Documents of `da2`. If you want to look at the individual evaluation, you can check {attr}`~docarray.Document.evaluations` attribute, e.g. @@ -179,7 +179,7 @@ p_da.evaluate('average_precision', ground_truth=g_da, hash_fn=lambda d: d.text[: ``` ```text -1.0 +{'average_precision': 1.0} ``` It is correct as we define the evaluation as checking if the first two characters in `.text` are the same. diff --git a/tests/unit/array/mixins/test_eval_class.py b/tests/unit/array/mixins/test_eval_class.py index 2c0439f6902..7db5b2aed9e 100644 --- a/tests/unit/array/mixins/test_eval_class.py +++ b/tests/unit/array/mixins/test_eval_class.py @@ -36,7 +36,9 @@ def test_eval_mixin_perfect_match(metric_fn, kwargs, storage, config, start_stor da1.embeddings = np.random.random([10, 256]) da1_index = DocumentArray(da1, storage=storage, config=config) da1.match(da1_index, exclude_self=True) - r = da1.evaluate(ground_truth=da1, metric=metric_fn, strict=False, **kwargs) + r = da1.evaluate(ground_truth=da1, metric=metric_fn, strict=False, **kwargs)[ + metric_fn + ] assert isinstance(r, float) assert r == 1.0 for d in da1: @@ -79,7 +81,7 @@ def test_eval_mixin_perfect_match_labeled( for d in da1_index: d.tags = {'label': 'A'} da1.match(da1_index, exclude_self=True) - r = da1.evaluate(metric=metric_fn, **kwargs) + r = da1.evaluate(metric=metric_fn, **kwargs)[metric_fn] assert isinstance(r, float) assert r == 1.0 for d in da1: @@ -121,7 +123,7 @@ def test_eval_mixin_zero_labeled(storage, config, metric_fn, start_storage, kwar d.tags = {'label': 'B'} da1_index = DocumentArray(da2, storage=storage, config=config) da1.match(da1_index, exclude_self=True) - r = da1.evaluate(metric_fn, **kwargs) + r = da1.evaluate(metric_fn, **kwargs)[metric_fn] assert isinstance(r, float) assert r == 0.0 for d in da1: @@ -146,7 +148,7 @@ def test_eval_mixin_one_of_n_labeled(metric_fn, metric_score): da = DocumentArray([Document(text=str(i), tags={'label': i}) for i in range(3)]) for d in da: d.matches = da - r = da.evaluate(metric_fn) + r = da.evaluate(metric_fn)[metric_fn] assert abs(r - metric_score) < 0.001 @@ -186,7 +188,7 @@ def test_eval_mixin_zero_match(storage, config, metric_fn, start_storage, kwargs da2_index = DocumentArray(da2, storage=storage, config=config) da2.match(da2_index, exclude_self=True) - r = da1.evaluate(ground_truth=da2, metric=metric_fn, **kwargs) + r = da1.evaluate(ground_truth=da2, metric=metric_fn, **kwargs)[metric_fn] assert isinstance(r, float) assert r == 1.0 for d in da1: @@ -330,7 +332,7 @@ def test_diff_match_len_in_gd(storage, config, metric_fn, start_storage, kwargs) # pop some matches from first document da2[0].matches.pop(8) - r = da1.evaluate(ground_truth=da2, metric=metric_fn, **kwargs) + r = da1.evaluate(ground_truth=da2, metric=metric_fn, **kwargs)[metric_fn] assert isinstance(r, float) np.testing.assert_allclose(r, 1.0, rtol=1e-2) # for d in da1: From 087d7d67ce997bfa950c86847534755202c8478d Mon Sep 17 00:00:00 2001 From: guenthermi Date: Mon, 17 Oct 2022 17:17:39 +0200 Subject: [PATCH 10/22] docs: update evaluate functions in docs --- docs/fundamentals/documentarray/evaluation.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/fundamentals/documentarray/evaluation.md b/docs/fundamentals/documentarray/evaluation.md index 994bc0d82ad..b13d7281555 100644 --- a/docs/fundamentals/documentarray/evaluation.md +++ b/docs/fundamentals/documentarray/evaluation.md @@ -3,7 +3,7 @@ After you get `.matches` from the last chapter, you can easily evaluate matches against the groundtruth via {meth}`~docarray.array.mixins.evaluation.EvaluationMixin.evaluate`. ```python -da_predict.evaluate(ground_truth=da_groundtruth, metric='...', **kwargs) +da_predict.evaluate(ground_truth=da_groundtruth, metrics='...', **kwargs) ``` Alternatively, you can add labels to your documents to evaluate them. @@ -18,7 +18,7 @@ example_da.embeddings = np.random.random([10, 3]) example_da.match(example_da) -example_da.evaluate(metric='precision_at_k') +example_da.evaluate(metrics='precision_at_k') ``` The results are stored in `.evaluations` field of each Document. @@ -109,10 +109,10 @@ da2['@m'].summary() -Now `da2` is our prediction, and `da` is our groundtruth. If we evaluate the average Precision@10, we should get something close to 0.5 (we have 10 real matches, we mixed in 10 fake matches and shuffle it, so top-10 would have approximate 10/20 real matches): +Now `da2` is our prediction, and `da` is our groundtruth. If we evaluate the average Precision@10, we should get something close to 0.47 (we have 9 real matches, we mixed in 10 fake matches and shuffle it, so top-10 would have approximate 9/19 real matches): ```python -da2.evaluate(ground_truth=da, metric='precision_at_k', k=5) +da2.evaluate(ground_truth=da, metrics='precision_at_k', k=10) ``` ```text From a5a697b28c44ac6b9df1e417f71c471abf601a5f Mon Sep 17 00:00:00 2001 From: guenthermi Date: Tue, 18 Oct 2022 09:29:50 +0200 Subject: [PATCH 11/22] fix: change metric to metrics in tests --- tests/unit/array/mixins/test_eval_class.py | 24 +++++++++++----------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/unit/array/mixins/test_eval_class.py b/tests/unit/array/mixins/test_eval_class.py index 7db5b2aed9e..7bc0c284a04 100644 --- a/tests/unit/array/mixins/test_eval_class.py +++ b/tests/unit/array/mixins/test_eval_class.py @@ -36,7 +36,7 @@ def test_eval_mixin_perfect_match(metric_fn, kwargs, storage, config, start_stor da1.embeddings = np.random.random([10, 256]) da1_index = DocumentArray(da1, storage=storage, config=config) da1.match(da1_index, exclude_self=True) - r = da1.evaluate(ground_truth=da1, metric=metric_fn, strict=False, **kwargs)[ + r = da1.evaluate(ground_truth=da1, metrics=metric_fn, strict=False, **kwargs)[ metric_fn ] assert isinstance(r, float) @@ -81,7 +81,7 @@ def test_eval_mixin_perfect_match_labeled( for d in da1_index: d.tags = {'label': 'A'} da1.match(da1_index, exclude_self=True) - r = da1.evaluate(metric=metric_fn, **kwargs)[metric_fn] + r = da1.evaluate(metrics=metric_fn, **kwargs)[metric_fn] assert isinstance(r, float) assert r == 1.0 for d in da1: @@ -188,7 +188,7 @@ def test_eval_mixin_zero_match(storage, config, metric_fn, start_storage, kwargs da2_index = DocumentArray(da2, storage=storage, config=config) da2.match(da2_index, exclude_self=True) - r = da1.evaluate(ground_truth=da2, metric=metric_fn, **kwargs)[metric_fn] + r = da1.evaluate(ground_truth=da2, metrics=metric_fn, **kwargs)[metric_fn] assert isinstance(r, float) assert r == 1.0 for d in da1: @@ -215,7 +215,7 @@ def test_diff_len_should_raise(storage, config, start_storage): d.matches.append(da2[0]) da2 = DocumentArray(da2, storage=storage, config=config) with pytest.raises(ValueError): - da1.evaluate(ground_truth=da2, metric='precision_at_k') + da1.evaluate(ground_truth=da2, metrics='precision_at_k') @pytest.mark.parametrize( @@ -237,7 +237,7 @@ def test_diff_hash_fun_should_raise(storage, config, start_storage): d.matches.append(da2[0]) da2 = DocumentArray(da2, storage=storage, config=config) with pytest.raises(ValueError): - da1.evaluate(ground_truth=da2, metric='precision_at_k') + da1.evaluate(ground_truth=da2, metrics='precision_at_k') @pytest.mark.parametrize( @@ -262,11 +262,11 @@ def test_same_hash_same_len_fun_should_work(storage, config, start_storage): da2_index = DocumentArray(da1, storage=storage, config=config) da2.match(da2_index) with pytest.raises(ValueError): - da1.evaluate(ground_truth=da2, metric='precision_at_k') + da1.evaluate(ground_truth=da2, metrics='precision_at_k') for d1, d2 in zip(da1, da2): d1.id = d2.id - da1.evaluate(ground_truth=da2, metric='precision_at_k') + da1.evaluate(ground_truth=da2, metrics='precision_at_k') @pytest.mark.parametrize( @@ -294,7 +294,7 @@ def test_adding_noise(storage, config, start_storage): d.matches.extend(DocumentArray.empty(10)) d.matches = d.matches.shuffle() - assert da2.evaluate(ground_truth=da, metric='precision_at_k', k=10) < 1.0 + assert da2.evaluate(ground_truth=da, metrics='precision_at_k', k=10) < 1.0 for d in da2: assert 0.0 < d.evaluations['precision_at_k'].value < 1.0 @@ -332,7 +332,7 @@ def test_diff_match_len_in_gd(storage, config, metric_fn, start_storage, kwargs) # pop some matches from first document da2[0].matches.pop(8) - r = da1.evaluate(ground_truth=da2, metric=metric_fn, **kwargs)[metric_fn] + r = da1.evaluate(ground_truth=da2, metrics=metric_fn, **kwargs)[metric_fn] assert isinstance(r, float) np.testing.assert_allclose(r, 1.0, rtol=1e-2) # for d in da1: @@ -356,7 +356,7 @@ def test_diff_match_len_in_gd(storage, config, metric_fn, start_storage, kwargs) def test_empty_da_should_raise(storage, config, start_storage): da = DocumentArray([], storage=storage, config=config) with pytest.raises(ValueError): - da.evaluate(metric='precision_at_k') + da.evaluate(metrics='precision_at_k') @pytest.mark.parametrize( @@ -374,7 +374,7 @@ def test_empty_da_should_raise(storage, config, start_storage): def test_missing_groundtruth_should_raise(storage, config, start_storage): da = DocumentArray(DocumentArray.empty(10), storage=storage, config=config) with pytest.raises(RuntimeError): - da.evaluate(metric='precision_at_k') + da.evaluate(metrics='precision_at_k') @pytest.mark.parametrize( @@ -400,4 +400,4 @@ def test_useless_groundtruth_warning_should_raise(storage, config, start_storage da1.match(da1_index, exclude_self=True) da2 = DocumentArray.empty(10) with pytest.warns(UserWarning): - da1.evaluate(ground_truth=da2, metric='precision_at_k') + da1.evaluate(ground_truth=da2, metrics='precision_at_k') From 49426766b9e03990c18c69dd32394c57a15b5bc3 Mon Sep 17 00:00:00 2001 From: guenthermi Date: Tue, 18 Oct 2022 17:16:38 +0200 Subject: [PATCH 12/22] test: add test for multiple metrics --- tests/unit/array/mixins/test_eval_class.py | 42 +++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/tests/unit/array/mixins/test_eval_class.py b/tests/unit/array/mixins/test_eval_class.py index 8e036734669..0f7bdb7b414 100644 --- a/tests/unit/array/mixins/test_eval_class.py +++ b/tests/unit/array/mixins/test_eval_class.py @@ -45,6 +45,43 @@ def test_eval_mixin_perfect_match(metric_fn, kwargs, storage, config, start_stor assert d.evaluations[metric_fn].value == 1.0 +@pytest.mark.parametrize( + 'storage, config', + [ + ('memory', {}), + ('weaviate', {}), + ('sqlite', {}), + ('annlite', {'n_dim': 256}), + ('qdrant', {'n_dim': 256}), + ('elasticsearch', {'n_dim': 256}), + ('redis', {'n_dim': 256}), + ], +) +def test_eval_mixin_perfect_match_multiple_metrics(storage, config, start_storage): + metric_fns = [ + 'r_precision', + 'precision_at_k', + 'hit_at_k', + 'average_precision', + 'reciprocal_rank', + 'recall_at_k', + 'f1_score_at_k', + 'ndcg_at_k', + ] + kwargs = {'max_rel': 9} + da1 = DocumentArray.empty(10) + da1.embeddings = np.random.random([10, 256]) + da1_index = DocumentArray(da1, storage=storage, config=config) + da1.match(da1_index, exclude_self=True) + r = da1.evaluate(ground_truth=da1, metrics=metric_fns, strict=False, **kwargs) + for metric_fn in metric_fns: + assert metric_fn in r + assert isinstance(r[metric_fn], float) + assert r[metric_fn] == 1.0 + for d in da1: + assert d.evaluations[metric_fn].value == 1.0 + + @pytest.mark.parametrize( 'storage, config', [ @@ -291,7 +328,10 @@ def test_adding_noise(storage, config, start_storage): d.matches.extend(DocumentArray.empty(10)) d.matches = d.matches.shuffle() - assert da2.evaluate(ground_truth=da, metrics='precision_at_k', k=10) < 1.0 + assert ( + da2.evaluate(ground_truth=da, metrics='precision_at_k', k=10)['precision_at_k'] + < 1.0 + ) for d in da2: assert 0.0 < d.evaluations['precision_at_k'].value < 1.0 From c607fcfea192ca7a4f60314e0cdc570459a52047 Mon Sep 17 00:00:00 2001 From: guenthermi Date: Tue, 18 Oct 2022 17:37:53 +0200 Subject: [PATCH 13/22] refactor: handle metric and metric_name in deprecation decorator --- docarray/array/mixins/evaluation.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index 792e193cb0b..caf5a11e8f0 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -30,13 +30,14 @@ def func(*args, **kwargs): 'soon.', DeprecationWarning, ) - if 'other' in kwargs: - kwargs['ground_truth'] = kwargs['other'] - warnings.warn( - '`other` is renamed to `groundtruth` in `evaluate()`, the usage of `other` is ' - 'deprecated and will be removed soon.', - DeprecationWarning, - ) + for old_key, new_key in zip(['other', 'metric', 'metric_name'], ['ground_truth', 'metrics', 'metric_names']): + if old_key in kwargs: + kwargs[new_key] = kwargs[old_key] + warnings.warn( + '`other` is renamed to `groundtruth` in `evaluate()`, the usage of `other` is ' + 'deprecated and will be removed soon.', + DeprecationWarning, + ) return f(*args, **kwargs) return func @@ -55,7 +56,7 @@ def evaluate( hash_fn: Optional[Callable[['Document'], str]] = None, metric_names: Optional[Union[str, List[str]]] = None, strict: bool = True, - label_tag: Optional[str] = 'label', + label_tag: str = 'label', **kwargs, ) -> Dict[str, float]: """ From 241c2944bb66b936381348529f1383c8a8630979 Mon Sep 17 00:00:00 2001 From: guenthermi Date: Tue, 18 Oct 2022 17:45:34 +0200 Subject: [PATCH 14/22] fix: black --- docarray/array/mixins/evaluation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index caf5a11e8f0..5d8e38681f9 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -30,7 +30,10 @@ def func(*args, **kwargs): 'soon.', DeprecationWarning, ) - for old_key, new_key in zip(['other', 'metric', 'metric_name'], ['ground_truth', 'metrics', 'metric_names']): + for old_key, new_key in zip( + ['other', 'metric', 'metric_name'], + ['ground_truth', 'metrics', 'metric_names'], + ): if old_key in kwargs: kwargs[new_key] = kwargs[old_key] warnings.warn( From 70b44883fa6fcda9148ba9917b30e39100dca317 Mon Sep 17 00:00:00 2001 From: guenthermi Date: Tue, 18 Oct 2022 17:57:29 +0200 Subject: [PATCH 15/22] fix: add max_rel only to function which have it --- docarray/array/mixins/evaluation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index 5d8e38681f9..c21e04c2f54 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -181,7 +181,9 @@ def evaluate( 'information is provided to evaluate the matches.' ) for metric_name, metric_fn in zip(metric_names, metric_fns): - r = metric_fn(binary_relevance, max_rel=max_rel, **kwargs) + if 'max_rel' in metric_fn.__code__.co_varnames: + kwargs['max_rel'] = max_rel + r = metric_fn(binary_relevance, **kwargs) d.evaluations[metric_name] = NamedScore( value=r, op_name=str(metric_fn), ref_id=d.id ) From 6dac178309156e9424a2a083b9c96fcb05958070 Mon Sep 17 00:00:00 2001 From: guenthermi Date: Wed, 19 Oct 2022 10:00:03 +0200 Subject: [PATCH 16/22] fix: set old_key and new_key in deprecation warning --- docarray/array/mixins/evaluation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index c21e04c2f54..1aea06a227d 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -37,8 +37,8 @@ def func(*args, **kwargs): if old_key in kwargs: kwargs[new_key] = kwargs[old_key] warnings.warn( - '`other` is renamed to `groundtruth` in `evaluate()`, the usage of `other` is ' - 'deprecated and will be removed soon.', + f'`{old_key}` is renamed to `{new_key}` in `evaluate()`, the ' + f'usage of `{old_key}` is deprecated and will be removed soon.', DeprecationWarning, ) return f(*args, **kwargs) From 30da1eae906a25f0d905497233bd73810b271497 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20G=C3=BCnther?= Date: Wed, 19 Oct 2022 10:39:02 +0200 Subject: [PATCH 17/22] docs: improve formulation in evaluation.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Messner <44071807+JohannesMessner@users.noreply.github.com> Signed-off-by: Michael Günther --- docs/fundamentals/documentarray/evaluation.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/fundamentals/documentarray/evaluation.md b/docs/fundamentals/documentarray/evaluation.md index 82b033602da..74a885c383f 100644 --- a/docs/fundamentals/documentarray/evaluation.md +++ b/docs/fundamentals/documentarray/evaluation.md @@ -1,6 +1,6 @@ # Evaluate Matches -After you get `.matches` from the last chapter, you can easily evaluate matches against the groundtruth via {meth}`~docarray.array.mixins.evaluation.EvaluationMixin.evaluate`. +After you get `.matches`, you can evaluate matches against the groundtruth via {meth}`~docarray.array.mixins.evaluation.EvaluationMixin.evaluate`. ```python da_predict.evaluate(ground_truth=da_groundtruth, metrics='...', **kwargs) From 8ac740230c9ca64a63f76192a975fa17bebf3be5 Mon Sep 17 00:00:00 2001 From: guenthermi Date: Wed, 19 Oct 2022 10:53:56 +0200 Subject: [PATCH 18/22] docs: add multi metric example to docs --- docs/fundamentals/documentarray/evaluation.md | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/docs/fundamentals/documentarray/evaluation.md b/docs/fundamentals/documentarray/evaluation.md index 74a885c383f..26f27ee35bf 100644 --- a/docs/fundamentals/documentarray/evaluation.md +++ b/docs/fundamentals/documentarray/evaluation.md @@ -127,18 +127,30 @@ for d in da2: ``` ```text +0.5 0.4 -0.4 -0.6 +0.3 0.6 -0.2 -0.4 -0.8 -0.8 -0.2 +0.5 +0.3 0.4 +0.6 +0.5 +0.7 +``` + +If you want to evaluate your data with multiple metric functions, you can pass a list of metrics: + +```python +da2.evaluate(ground_truth=da, metrics=['precision_at_k', 'reciprocal_rank'], k=10) +``` + +```text +{'precision_at_k': 0.48, 'reciprocal_rank': 0.6333333333333333} ``` +In this case, the keyword attribute `k` is passed to all metric functions, even though it does not fulfill any specific function for the calculation of the reciprocal rank. + ## Document identifier Note that `.evaluate()` works only when two DocumentArray have the same length and their nested structure are same. It makes no sense to evaluate on two completely irrelevant DocumentArrays. From 3f6f2dea66fa6731c793e6822a3ddea913d8fcf7 Mon Sep 17 00:00:00 2001 From: guenthermi Date: Wed, 19 Oct 2022 12:15:27 +0200 Subject: [PATCH 19/22] refactor: only support lists --- docarray/array/mixins/evaluation.py | 30 ++++++++++++------- docs/fundamentals/documentarray/evaluation.md | 6 ++-- tests/unit/array/mixins/test_eval_class.py | 26 ++++++++-------- 3 files changed, 36 insertions(+), 26 deletions(-) diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index 1aea06a227d..b3cba473547 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -41,6 +41,21 @@ def func(*args, **kwargs): f'usage of `{old_key}` is deprecated and will be removed soon.', DeprecationWarning, ) + + # transfer metrics and metric_names into lists + list_warning_msg = ( + 'The attribute `%s` now accepts a list instead of a ' + 'single element. Passing a single element is deprecated and will soon not ' + 'be supported anymore.' + ) + if len(args) > 1: + if type(args[1]) is str: + args[1] = [args[1]] + warnings.warn(list_warning_msg % 'metrics', DeprecationWarning) + for key in ['metrics', 'metric_names']: + if key in kwargs and type(kwargs[key]) is str: + kwargs[key] = [kwargs[key]] + warnings.warn(list_warning_msg % key, DeprecationWarning) return f(*args, **kwargs) return func @@ -52,12 +67,10 @@ class EvaluationMixin: @_evaluate_deprecation def evaluate( self, - metrics: Union[ - Union[str, Callable[..., float]], List[Union[str, Callable[..., float]]] - ], + metrics: List[Union[str, Callable[..., float]]], ground_truth: Optional['DocumentArray'] = None, hash_fn: Optional[Callable[['Document'], str]] = None, - metric_names: Optional[Union[str, List[str]]] = None, + metric_names: Optional[List[str]] = None, strict: bool = True, label_tag: str = 'label', **kwargs, @@ -77,13 +90,13 @@ def evaluate( This method will fill the `evaluations` field of Documents inside this `DocumentArray` and will return the average of the computations - :param metrics: One or multiple name of the metrics or metric functions to be computed + :param metrics: list of metric names or metric functions to be computed :param ground_truth: The ground_truth `DocumentArray` that the `DocumentArray` compares to. :param hash_fn: The function used for identifying the uniqueness of Documents. If not given, then ``Document.id`` is used. :param metric_names: If provided, the results of the metrics computation will be - stored in the `evaluations` field of each Document with this name. If not + stored in the `evaluations` field of each Document with this names. If not provided, the names will be derived from the metric function names. :param strict: If set, then left and right sides are required to be fully aligned: on the length, and on the semantic of length. These are preventing @@ -120,11 +133,6 @@ def evaluate( if hash_fn is None: hash_fn = lambda d: d.id - if type(metrics) is not list: - metrics = [metrics] - if type(metric_names) is str: - metric_names = [metric_names] - metric_fns = [] for metric in metrics: if callable(metric): diff --git a/docs/fundamentals/documentarray/evaluation.md b/docs/fundamentals/documentarray/evaluation.md index 26f27ee35bf..33454c09c19 100644 --- a/docs/fundamentals/documentarray/evaluation.md +++ b/docs/fundamentals/documentarray/evaluation.md @@ -3,7 +3,7 @@ After you get `.matches`, you can evaluate matches against the groundtruth via {meth}`~docarray.array.mixins.evaluation.EvaluationMixin.evaluate`. ```python -da_predict.evaluate(ground_truth=da_groundtruth, metrics='...', **kwargs) +da_predict.evaluate(ground_truth=da_groundtruth, metrics=['...'], **kwargs) ``` Alternatively, you can add labels to your documents to evaluate them. @@ -18,7 +18,7 @@ example_da.embeddings = np.random.random([10, 3]) example_da.match(example_da) -example_da.evaluate(metrics='precision_at_k') +example_da.evaluate(metrics=['precision_at_k']) ``` The results are stored in `.evaluations` field of each Document. @@ -112,7 +112,7 @@ da2['@m'].summary() Now `da2` is our prediction, and `da` is our groundtruth. If we evaluate the average Precision@10, we should get something close to 0.47 (we have 9 real matches, we mixed in 10 fake matches and shuffle it, so top-10 would have approximate 9/19 real matches): ```python -da2.evaluate(ground_truth=da, metrics='precision_at_k', k=10) +da2.evaluate(ground_truth=da, metrics=['precision_at_k'], k=10) ``` ```text diff --git a/tests/unit/array/mixins/test_eval_class.py b/tests/unit/array/mixins/test_eval_class.py index 0f7bdb7b414..3c802f7eefe 100644 --- a/tests/unit/array/mixins/test_eval_class.py +++ b/tests/unit/array/mixins/test_eval_class.py @@ -36,7 +36,7 @@ def test_eval_mixin_perfect_match(metric_fn, kwargs, storage, config, start_stor da1.embeddings = np.random.random([10, 256]) da1_index = DocumentArray(da1, storage=storage, config=config) da1.match(da1_index, exclude_self=True) - r = da1.evaluate(ground_truth=da1, metrics=metric_fn, strict=False, **kwargs)[ + r = da1.evaluate(ground_truth=da1, metrics=[metric_fn], strict=False, **kwargs)[ metric_fn ] assert isinstance(r, float) @@ -116,7 +116,7 @@ def test_eval_mixin_perfect_match_labeled( da1.embeddings = np.random.random([10, 256]) da1_index = DocumentArray(da1, storage=storage, config=config) da1.match(da1_index, exclude_self=True) - r = da1.evaluate(metrics=metric_fn, **kwargs)[metric_fn] + r = da1.evaluate(metrics=[metric_fn], **kwargs)[metric_fn] assert isinstance(r, float) assert r == 1.0 for d in da1: @@ -222,7 +222,7 @@ def test_eval_mixin_zero_match(storage, config, metric_fn, start_storage, kwargs da2_index = DocumentArray(da2, storage=storage, config=config) da2.match(da2_index, exclude_self=True) - r = da1.evaluate(ground_truth=da2, metrics=metric_fn, **kwargs)[metric_fn] + r = da1.evaluate(ground_truth=da2, metrics=[metric_fn], **kwargs)[metric_fn] assert isinstance(r, float) assert r == 1.0 for d in da1: @@ -249,7 +249,7 @@ def test_diff_len_should_raise(storage, config, start_storage): d.matches.append(da2[0]) da2 = DocumentArray(da2, storage=storage, config=config) with pytest.raises(ValueError): - da1.evaluate(ground_truth=da2, metrics='precision_at_k') + da1.evaluate(ground_truth=da2, metrics=['precision_at_k']) @pytest.mark.parametrize( @@ -271,7 +271,7 @@ def test_diff_hash_fun_should_raise(storage, config, start_storage): d.matches.append(da2[0]) da2 = DocumentArray(da2, storage=storage, config=config) with pytest.raises(ValueError): - da1.evaluate(ground_truth=da2, metrics='precision_at_k') + da1.evaluate(ground_truth=da2, metrics=['precision_at_k']) @pytest.mark.parametrize( @@ -296,11 +296,11 @@ def test_same_hash_same_len_fun_should_work(storage, config, start_storage): da2_index = DocumentArray(da1, storage=storage, config=config) da2.match(da2_index) with pytest.raises(ValueError): - da1.evaluate(ground_truth=da2, metrics='precision_at_k') + da1.evaluate(ground_truth=da2, metrics=['precision_at_k']) for d1, d2 in zip(da1, da2): d1.id = d2.id - da1.evaluate(ground_truth=da2, metrics='precision_at_k') + da1.evaluate(ground_truth=da2, metrics=['precision_at_k']) @pytest.mark.parametrize( @@ -329,7 +329,9 @@ def test_adding_noise(storage, config, start_storage): d.matches = d.matches.shuffle() assert ( - da2.evaluate(ground_truth=da, metrics='precision_at_k', k=10)['precision_at_k'] + da2.evaluate(ground_truth=da, metrics=['precision_at_k'], k=10)[ + 'precision_at_k' + ] < 1.0 ) @@ -369,7 +371,7 @@ def test_diff_match_len_in_gd(storage, config, metric_fn, start_storage, kwargs) # pop some matches from first document da2[0].matches.pop(8) - r = da1.evaluate(ground_truth=da2, metrics=metric_fn, **kwargs)[metric_fn] + r = da1.evaluate(ground_truth=da2, metrics=[metric_fn], **kwargs)[metric_fn] assert isinstance(r, float) np.testing.assert_allclose(r, 1.0, rtol=1e-2) # for d in da1: @@ -393,7 +395,7 @@ def test_diff_match_len_in_gd(storage, config, metric_fn, start_storage, kwargs) def test_empty_da_should_raise(storage, config, start_storage): da = DocumentArray([], storage=storage, config=config) with pytest.raises(ValueError): - da.evaluate(metrics='precision_at_k') + da.evaluate(metrics=['precision_at_k']) @pytest.mark.parametrize( @@ -411,7 +413,7 @@ def test_empty_da_should_raise(storage, config, start_storage): def test_missing_groundtruth_should_raise(storage, config, start_storage): da = DocumentArray(DocumentArray.empty(10), storage=storage, config=config) with pytest.raises(RuntimeError): - da.evaluate(metrics='precision_at_k') + da.evaluate(metrics=['precision_at_k']) @pytest.mark.parametrize( @@ -435,4 +437,4 @@ def test_useless_groundtruth_warning_should_raise(storage, config, start_storage da1.match(da1_index, exclude_self=True) da2 = DocumentArray.empty(10) with pytest.warns(UserWarning): - da1.evaluate(ground_truth=da2, metrics='precision_at_k') + da1.evaluate(ground_truth=da2, metrics=['precision_at_k']) From b50a6420fbc35d0d9d36341f9d13aac783894164 Mon Sep 17 00:00:00 2001 From: guenthermi Date: Wed, 19 Oct 2022 13:12:43 +0200 Subject: [PATCH 20/22] fix: tuple error in deprecation decorator --- docarray/array/mixins/evaluation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index b3cba473547..deeffa9add5 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -50,6 +50,7 @@ def func(*args, **kwargs): ) if len(args) > 1: if type(args[1]) is str: + args = list(args) args[1] = [args[1]] warnings.warn(list_warning_msg % 'metrics', DeprecationWarning) for key in ['metrics', 'metric_names']: From 3fefd2dbd7bc4c8197203c93d2edc5824fc66743 Mon Sep 17 00:00:00 2001 From: guenthermi Date: Wed, 19 Oct 2022 13:20:07 +0200 Subject: [PATCH 21/22] refactor: change some tests to pass metrics as list --- tests/unit/array/mixins/test_eval_class.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/array/mixins/test_eval_class.py b/tests/unit/array/mixins/test_eval_class.py index 3c802f7eefe..5961dc4abfb 100644 --- a/tests/unit/array/mixins/test_eval_class.py +++ b/tests/unit/array/mixins/test_eval_class.py @@ -158,7 +158,7 @@ def test_eval_mixin_zero_labeled(storage, config, metric_fn, start_storage, kwar d.tags = {'label': 'B'} da1_index = DocumentArray(da2, storage=storage, config=config) da1.match(da1_index, exclude_self=True) - r = da1.evaluate(metric_fn, **kwargs)[metric_fn] + r = da1.evaluate([metric_fn], **kwargs)[metric_fn] assert isinstance(r, float) assert r == 0.0 for d in da1: @@ -182,7 +182,7 @@ def test_eval_mixin_one_of_n_labeled(metric_fn, metric_score): da = DocumentArray([Document(text=str(i), tags={'label': i}) for i in range(3)]) for d in da: d.matches = da - r = da.evaluate(metric_fn)[metric_fn] + r = da.evaluate([metric_fn])[metric_fn] assert abs(r - metric_score) < 0.001 From b9415db1eaf790f2fff7175929b3c1df3564656b Mon Sep 17 00:00:00 2001 From: guenthermi Date: Wed, 19 Oct 2022 14:09:24 +0200 Subject: [PATCH 22/22] fix: add wraps decorator to retain docstring --- docarray/array/mixins/evaluation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index deeffa9add5..893913a8d5a 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -1,6 +1,8 @@ import warnings from typing import Optional, Union, TYPE_CHECKING, Callable, List, Dict +from functools import wraps + import numpy as np from collections import defaultdict @@ -14,6 +16,7 @@ def _evaluate_deprecation(f): """Raises a deprecation warning if the user executes the evaluate function with the old interface and adjust the input to fit the new interface.""" + @wraps(f) def func(*args, **kwargs): if len(args) > 1: if not (