From 6f2f5b6c6e02e1803cae90fe467de9c5550f84c6 Mon Sep 17 00:00:00 2001 From: Michael Guenther Date: Thu, 15 Dec 2022 16:22:02 +0100 Subject: [PATCH 1/4] fix: calculate relevant docs on index instead of queries Signed-off-by: Michael Guenther --- docarray/array/mixins/evaluation.py | 11 ++++---- .../array/mixins/oldproto/test_eval_class.py | 25 +++++++++++++++++-- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index 2de9433fb33..8055859d704 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -458,13 +458,12 @@ def fuse_matches(global_matches: DocumentArray, local_matches: DocumentArray): new_matches.append(m) query_data[doc.id, 'matches'] = new_matches - if ground_truth and label_tag in ground_truth[0].tags: - num_relevant_documents_per_label = dict( - Counter([d.tags[label_tag] for d in ground_truth]) - ) - elif not ground_truth and label_tag in query_data[0].tags: + if ( + not (ground_truth and len(ground_truth) > 0 and ground_truth[0].matches) + and label_tag in index_data[0].tags + ): num_relevant_documents_per_label = dict( - Counter([d.tags[label_tag] for d in query_data]) + Counter([d.tags[label_tag] for d in index_data]) ) else: num_relevant_documents_per_label = None diff --git a/tests/unit/array/mixins/oldproto/test_eval_class.py b/tests/unit/array/mixins/oldproto/test_eval_class.py index 78addbefb82..2adee9787e1 100644 --- a/tests/unit/array/mixins/oldproto/test_eval_class.py +++ b/tests/unit/array/mixins/oldproto/test_eval_class.py @@ -567,6 +567,21 @@ def test_embed_and_evaluate_two_das(storage, config, sample_size, start_storage) assert all([v == 1.0 for v in res.values()]) +def test_embed_and_evaluate_two_different_das(): + queries_da = DocumentArray([Document(text=str(i), label=i % 10) for i in range(10)]) + index_da = DocumentArray( + [Document(text=str(i), label=i % 10) for i in range(10, 110)] + ) + res = queries_da.embed_and_evaluate( + index_data=index_da, + metrics=['precision_at_k', 'reciprocal_rank', 'recall_at_k', 'f1_score_at_k'], + embed_funcs=dummy_embed_function, + match_batch_size=1, + limit=10, + ) + print(res) + + @pytest.mark.parametrize( 'use_index, expected, label_tag', [ @@ -639,7 +654,7 @@ def emb_func(da): ], ) def test_embed_and_evaluate_on_real_data(two_embed_funcs, kwargs): - metric_names = ['precision_at_k', 'reciprocal_rank'] + metric_names = ['precision_at_k', 'reciprocal_rank', 'recall_at_k'] labels = ['18828_alt.atheism', '18828_comp.graphics'] news = [load_dataset('newsgroup', label) for label in labels] @@ -673,10 +688,16 @@ def emb_func(da): ) # re-calculate manually + num_relevant_documents_per_label = dict( + Counter([d.tags['label'] for d in index_docs]) + ) emb_func(query_docs) emb_func(index_docs) query_docs.match(index_docs) - res2 = query_docs.evaluate(metrics=metric_names) + res2 = query_docs.evaluate( + metrics=metric_names, + num_relevant_documents_per_label=num_relevant_documents_per_label, + ) for key in res: assert key in res2 From 17c3f9ef9efb7aac7547dada6b5f135a64755cd1 Mon Sep 17 00:00:00 2001 From: Michael Guenther Date: Thu, 15 Dec 2022 16:45:03 +0100 Subject: [PATCH 2/4] refactor: reduce num labels if necessary Signed-off-by: Michael Guenther --- docarray/array/mixins/evaluation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index 8055859d704..3da234234b3 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -465,6 +465,9 @@ def fuse_matches(global_matches: DocumentArray, local_matches: DocumentArray): num_relevant_documents_per_label = dict( Counter([d.tags[label_tag] for d in index_data]) ) + if only_one_dataset and exclude_self: + for k, v in num_relevant_documents_per_label: + num_relevant_documents_per_label[k] -= 1 else: num_relevant_documents_per_label = None From 79000149362ffe7e5c32e6a8b1495bd662cce15b Mon Sep 17 00:00:00 2001 From: Michael Guenther Date: Thu, 15 Dec 2022 17:23:31 +0100 Subject: [PATCH 3/4] test: add test for exclude_self Signed-off-by: Michael Guenther --- docarray/array/mixins/evaluation.py | 2 +- .../array/mixins/oldproto/test_eval_class.py | 66 ++++++++++++++++++- 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index 3da234234b3..6d0cf107aba 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -466,7 +466,7 @@ def fuse_matches(global_matches: DocumentArray, local_matches: DocumentArray): Counter([d.tags[label_tag] for d in index_data]) ) if only_one_dataset and exclude_self: - for k, v in num_relevant_documents_per_label: + for k, v in num_relevant_documents_per_label.items(): num_relevant_documents_per_label[k] -= 1 else: num_relevant_documents_per_label = None diff --git a/tests/unit/array/mixins/oldproto/test_eval_class.py b/tests/unit/array/mixins/oldproto/test_eval_class.py index 2adee9787e1..d76c01c6f02 100644 --- a/tests/unit/array/mixins/oldproto/test_eval_class.py +++ b/tests/unit/array/mixins/oldproto/test_eval_class.py @@ -528,6 +528,54 @@ def test_embed_and_evaluate_single_da(storage, config, start_storage): assert all([v == 1.0 for v in res.values()]) +@pytest.mark.parametrize( + 'exclude_self, expected_results', + [ + ( + True, + { + 'r_precision': 1, + 'precision_at_k': 1.0, + 'hit_at_k': 1.0, + 'average_precision': 1.0, + 'reciprocal_rank': 1.0, + 'recall_at_k': 5.0 / 9.0, + 'f1_score_at_k': (10.0 / 9.0) / (5.0 / 9.0 + 1), + 'ndcg_at_k': 1.0, + }, + ), + ( + False, + { + 'r_precision': 1.0, + 'precision_at_k': 1.0, + 'hit_at_k': 1.0, + 'average_precision': 1.0, + 'reciprocal_rank': 1.0, + 'recall_at_k': 0.5, + 'f1_score_at_k': 1.0 / 1.5, + 'ndcg_at_k': 1.0, + }, + ), + ], +) +def test_embed_and_evaluate_with_and_without_exclude_self( + exclude_self, expected_results +): + queries_da = DocumentArray( + [Document(text=str(i % 10), label=i % 10) for i in range(100)] + ) + res = queries_da.embed_and_evaluate( + metrics=list(expected_results.keys()), + exclude_self=exclude_self, + embed_funcs=dummy_embed_function, + match_batch_size=1, + limit=5, + ) + for key in expected_results: + assert abs(res[key] - expected_results[key]) < 1e-5 + + @pytest.mark.parametrize( 'sample_size', [None, 10], @@ -570,7 +618,7 @@ def test_embed_and_evaluate_two_das(storage, config, sample_size, start_storage) def test_embed_and_evaluate_two_different_das(): queries_da = DocumentArray([Document(text=str(i), label=i % 10) for i in range(10)]) index_da = DocumentArray( - [Document(text=str(i), label=i % 10) for i in range(10, 110)] + [Document(text=str(i % 10), label=i % 10) for i in range(10, 210)] ) res = queries_da.embed_and_evaluate( index_data=index_da, @@ -579,7 +627,10 @@ def test_embed_and_evaluate_two_different_das(): match_batch_size=1, limit=10, ) - print(res) + assert res['precision_at_k'] == 1.0 + assert res['reciprocal_rank'] == 1.0 + assert res['recall_at_k'] == 0.5 + assert abs(res['f1_score_at_k'] - 1.0 / 1.5) < 1e-5 @pytest.mark.parametrize( @@ -654,7 +705,16 @@ def emb_func(da): ], ) def test_embed_and_evaluate_on_real_data(two_embed_funcs, kwargs): - metric_names = ['precision_at_k', 'reciprocal_rank', 'recall_at_k'] + metric_names = [ + 'r_precision', + 'precision_at_k', + 'hit_at_k', + 'average_precision', + 'reciprocal_rank', + 'recall_at_k', + 'f1_score_at_k', + 'ndcg_at_k', + ] labels = ['18828_alt.atheism', '18828_comp.graphics'] news = [load_dataset('newsgroup', label) for label in labels] From e1a74b928033a4c3cc964a520a7d7bd0fd799bfb Mon Sep 17 00:00:00 2001 From: Michael Guenther Date: Thu, 15 Dec 2022 17:36:50 +0100 Subject: [PATCH 4/4] refactor: change metrics in tests Signed-off-by: Michael Guenther --- .../array/mixins/oldproto/test_eval_class.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/unit/array/mixins/oldproto/test_eval_class.py b/tests/unit/array/mixins/oldproto/test_eval_class.py index d76c01c6f02..2645a4cab07 100644 --- a/tests/unit/array/mixins/oldproto/test_eval_class.py +++ b/tests/unit/array/mixins/oldproto/test_eval_class.py @@ -541,7 +541,6 @@ def test_embed_and_evaluate_single_da(storage, config, start_storage): 'reciprocal_rank': 1.0, 'recall_at_k': 5.0 / 9.0, 'f1_score_at_k': (10.0 / 9.0) / (5.0 / 9.0 + 1), - 'ndcg_at_k': 1.0, }, ), ( @@ -554,7 +553,6 @@ def test_embed_and_evaluate_single_da(storage, config, start_storage): 'reciprocal_rank': 1.0, 'recall_at_k': 0.5, 'f1_score_at_k': 1.0 / 1.5, - 'ndcg_at_k': 1.0, }, ), ], @@ -622,12 +620,23 @@ def test_embed_and_evaluate_two_different_das(): ) res = queries_da.embed_and_evaluate( index_data=index_da, - metrics=['precision_at_k', 'reciprocal_rank', 'recall_at_k', 'f1_score_at_k'], + metrics=[ + 'r_precision', + 'precision_at_k', + 'hit_at_k', + 'average_precision', + 'reciprocal_rank', + 'recall_at_k', + 'f1_score_at_k', + ], embed_funcs=dummy_embed_function, match_batch_size=1, limit=10, ) + assert res['r_precision'] == 1.0 assert res['precision_at_k'] == 1.0 + assert res['hit_at_k'] == 1.0 + assert res['average_precision'] == 1.0 assert res['reciprocal_rank'] == 1.0 assert res['recall_at_k'] == 0.5 assert abs(res['f1_score_at_k'] - 1.0 / 1.5) < 1e-5 @@ -713,7 +722,7 @@ def test_embed_and_evaluate_on_real_data(two_embed_funcs, kwargs): 'reciprocal_rank', 'recall_at_k', 'f1_score_at_k', - 'ndcg_at_k', + 'dcg_at_k', ] labels = ['18828_alt.atheism', '18828_comp.graphics']