diff --git a/docarray/array/mixins/evaluation.py b/docarray/array/mixins/evaluation.py index 2de9433fb33..6d0cf107aba 100644 --- a/docarray/array/mixins/evaluation.py +++ b/docarray/array/mixins/evaluation.py @@ -458,14 +458,16 @@ def fuse_matches(global_matches: DocumentArray, local_matches: DocumentArray): new_matches.append(m) query_data[doc.id, 'matches'] = new_matches - if ground_truth and label_tag in ground_truth[0].tags: - num_relevant_documents_per_label = dict( - Counter([d.tags[label_tag] for d in ground_truth]) - ) - elif not ground_truth and label_tag in query_data[0].tags: + if ( + not (ground_truth and len(ground_truth) > 0 and ground_truth[0].matches) + and label_tag in index_data[0].tags + ): num_relevant_documents_per_label = dict( - Counter([d.tags[label_tag] for d in query_data]) + Counter([d.tags[label_tag] for d in index_data]) ) + if only_one_dataset and exclude_self: + for k, v in num_relevant_documents_per_label.items(): + num_relevant_documents_per_label[k] -= 1 else: num_relevant_documents_per_label = None diff --git a/tests/unit/array/mixins/oldproto/test_eval_class.py b/tests/unit/array/mixins/oldproto/test_eval_class.py index 78addbefb82..2645a4cab07 100644 --- a/tests/unit/array/mixins/oldproto/test_eval_class.py +++ b/tests/unit/array/mixins/oldproto/test_eval_class.py @@ -528,6 +528,52 @@ def test_embed_and_evaluate_single_da(storage, config, start_storage): assert all([v == 1.0 for v in res.values()]) +@pytest.mark.parametrize( + 'exclude_self, expected_results', + [ + ( + True, + { + 'r_precision': 1, + 'precision_at_k': 1.0, + 'hit_at_k': 1.0, + 'average_precision': 1.0, + 'reciprocal_rank': 1.0, + 'recall_at_k': 5.0 / 9.0, + 'f1_score_at_k': (10.0 / 9.0) / (5.0 / 9.0 + 1), + }, + ), + ( + False, + { + 'r_precision': 1.0, + 'precision_at_k': 1.0, + 'hit_at_k': 1.0, + 'average_precision': 1.0, + 'reciprocal_rank': 1.0, + 'recall_at_k': 0.5, + 'f1_score_at_k': 1.0 / 1.5, + }, + ), + ], +) +def test_embed_and_evaluate_with_and_without_exclude_self( + exclude_self, expected_results +): + queries_da = DocumentArray( + [Document(text=str(i % 10), label=i % 10) for i in range(100)] + ) + res = queries_da.embed_and_evaluate( + metrics=list(expected_results.keys()), + exclude_self=exclude_self, + embed_funcs=dummy_embed_function, + match_batch_size=1, + limit=5, + ) + for key in expected_results: + assert abs(res[key] - expected_results[key]) < 1e-5 + + @pytest.mark.parametrize( 'sample_size', [None, 10], @@ -567,6 +613,35 @@ def test_embed_and_evaluate_two_das(storage, config, sample_size, start_storage) assert all([v == 1.0 for v in res.values()]) +def test_embed_and_evaluate_two_different_das(): + queries_da = DocumentArray([Document(text=str(i), label=i % 10) for i in range(10)]) + index_da = DocumentArray( + [Document(text=str(i % 10), label=i % 10) for i in range(10, 210)] + ) + res = queries_da.embed_and_evaluate( + index_data=index_da, + metrics=[ + 'r_precision', + 'precision_at_k', + 'hit_at_k', + 'average_precision', + 'reciprocal_rank', + 'recall_at_k', + 'f1_score_at_k', + ], + embed_funcs=dummy_embed_function, + match_batch_size=1, + limit=10, + ) + assert res['r_precision'] == 1.0 + assert res['precision_at_k'] == 1.0 + assert res['hit_at_k'] == 1.0 + assert res['average_precision'] == 1.0 + assert res['reciprocal_rank'] == 1.0 + assert res['recall_at_k'] == 0.5 + assert abs(res['f1_score_at_k'] - 1.0 / 1.5) < 1e-5 + + @pytest.mark.parametrize( 'use_index, expected, label_tag', [ @@ -639,7 +714,16 @@ def emb_func(da): ], ) def test_embed_and_evaluate_on_real_data(two_embed_funcs, kwargs): - metric_names = ['precision_at_k', 'reciprocal_rank'] + metric_names = [ + 'r_precision', + 'precision_at_k', + 'hit_at_k', + 'average_precision', + 'reciprocal_rank', + 'recall_at_k', + 'f1_score_at_k', + 'dcg_at_k', + ] labels = ['18828_alt.atheism', '18828_comp.graphics'] news = [load_dataset('newsgroup', label) for label in labels] @@ -673,10 +757,16 @@ def emb_func(da): ) # re-calculate manually + num_relevant_documents_per_label = dict( + Counter([d.tags['label'] for d in index_docs]) + ) emb_func(query_docs) emb_func(index_docs) query_docs.match(index_docs) - res2 = query_docs.evaluate(metrics=metric_names) + res2 = query_docs.evaluate( + metrics=metric_names, + num_relevant_documents_per_label=num_relevant_documents_per_label, + ) for key in res: assert key in res2