Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 45 additions & 42 deletions tests/unit/array/mixins/oldproto/test_eval_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
('ndcg_at_k', {}),
],
)
def test_eval_mixin_perfect_match(metric_fn, kwargs, config):
def test_eval_mixin_perfect_match(metric_fn, kwargs):
da1 = DocumentArray.empty(10)
da1.embeddings = np.random.random([10, 256])
da1_index = DocumentArray(da1, config=config)
da1_index = DocumentArray(da1)
with da1_index:
da1.match(da1_index, exclude_self=True)
r = da1.evaluate(ground_truth=da1, metrics=[metric_fn], strict=False, **kwargs)[
Expand All @@ -40,7 +40,7 @@ def test_eval_mixin_perfect_match(metric_fn, kwargs, config):
assert d.evaluations[metric_fn].value == 1.0


def test_eval_mixin_perfect_match_multiple_metrics(config):
def test_eval_mixin_perfect_match_multiple_metrics():
metric_fns = [
'r_precision',
'precision_at_k',
Expand All @@ -54,7 +54,7 @@ def test_eval_mixin_perfect_match_multiple_metrics(config):
kwargs = {'max_rel': 9}
da1 = DocumentArray.empty(10)
da1.embeddings = np.random.random([10, 256])
da1_index = DocumentArray(da1, config=config)
da1_index = DocumentArray(da1)
with da1_index:
da1.match(da1_index, exclude_self=True)
r = da1.evaluate(ground_truth=da1, metrics=metric_fns, strict=False, **kwargs)
Expand All @@ -79,12 +79,12 @@ def test_eval_mixin_perfect_match_multiple_metrics(config):
('ndcg_at_k', {}),
],
)
def test_eval_mixin_perfect_match_labeled(metric_fn, kwargs, config):
def test_eval_mixin_perfect_match_labeled(metric_fn, kwargs):
da1 = DocumentArray.empty(10)
for d in da1:
d.tags = {'label': 'A'}
da1.embeddings = np.random.random([10, 256])
da1_index = DocumentArray(da1, config=config)
da1_index = DocumentArray(da1)
with da1_index:
da1.match(da1_index, exclude_self=True)
r = da1.evaluate(metrics=[metric_fn], **kwargs)[metric_fn]
Expand All @@ -107,15 +107,16 @@ def test_eval_mixin_perfect_match_labeled(metric_fn, kwargs, config):
('ndcg_at_k', {}),
],
)
def test_eval_mixin_zero_labeled(config, metric_fn, kwargs):
def test_eval_mixin_zero_labeled(metric_fn, kwargs):

da1 = DocumentArray.empty(10)
for d in da1:
d.tags = {'label': 'A'}
da1.embeddings = np.random.random([10, 256])
da2 = copy.deepcopy(da1)
for d in da2:
d.tags = {'label': 'B'}
da1_index = DocumentArray(da2, config=config)
da1_index = DocumentArray(da2)
with da1_index:
da1.match(da1_index, exclude_self=True)
r = da1.evaluate([metric_fn], **kwargs)[metric_fn]
Expand Down Expand Up @@ -193,15 +194,15 @@ def test_missing_max_rel_should_raise():
('ndcg_at_k', {}),
],
)
def test_eval_mixin_zero_match(config, metric_fn, kwargs):
def test_eval_mixin_zero_match(metric_fn, kwargs):
da1 = DocumentArray.empty(10)
da1.embeddings = np.random.random([10, 256])
da1_index = DocumentArray(da1, config=config)
da1_index = DocumentArray(da1)
da1.match(da1_index, exclude_self=True)

da2 = copy.deepcopy(da1)
da2.embeddings = np.random.random([10, 256])
da2_index = DocumentArray(da2, config=config)
da2_index = DocumentArray(da2)
with da2_index:
da2.match(da2_index, exclude_self=True)

Expand All @@ -213,35 +214,35 @@ def test_eval_mixin_zero_match(config, metric_fn, kwargs):
assert d.evaluations[metric_fn].value == 1.0


def test_diff_len_should_raise(config):
def test_diff_len_should_raise():
da1 = DocumentArray.empty(10)
da2 = DocumentArray.empty(5)
for d in da2:
d.matches.append(da2[0])
da2 = DocumentArray(da2, config=config)
da2 = DocumentArray(da2)
with pytest.raises(ValueError):
da1.evaluate(ground_truth=da2, metrics=['precision_at_k'])


def test_diff_hash_fun_should_raise(config):
def test_diff_hash_fun_should_raise():
da1 = DocumentArray.empty(10)
da2 = DocumentArray.empty(5)
for d in da2:
d.matches.append(da2[0])
da2 = DocumentArray(da2, config=config)
da2 = DocumentArray(da2)
with pytest.raises(ValueError):
da1.evaluate(ground_truth=da2, metrics=['precision_at_k'])


def test_same_hash_same_len_fun_should_work(config):
def test_same_hash_same_len_fun_should_work():
da1 = DocumentArray.empty(10)
da1.embeddings = np.random.random([10, 3])
da1_index = DocumentArray(da1, config=config)
da1_index = DocumentArray(da1)
with da1_index:
da1.match(da1_index)
da2 = DocumentArray.empty(10)
da2.embeddings = np.random.random([10, 3])
da2_index = DocumentArray(da1, config=config)
da2_index = DocumentArray(da1)
with da2_index:
da2.match(da2_index)
with da1_index, da2_index:
Expand All @@ -253,11 +254,11 @@ def test_same_hash_same_len_fun_should_work(config):
da1.evaluate(ground_truth=da2, metrics=['precision_at_k'])


def test_adding_noise(config):
def test_adding_noise():
da = DocumentArray.empty(10)

da.embeddings = np.random.random([10, 3])
da_index = DocumentArray(da, config=config)
da_index = DocumentArray(da)
with da_index:
da.match(da_index, exclude_self=True)

Expand Down Expand Up @@ -285,15 +286,15 @@ def test_adding_noise(config):
('f1_score_at_k', {}),
],
)
def test_diff_match_len_in_gd(config, metric_fn, kwargs):
def test_diff_match_len_in_gd(metric_fn, kwargs):
da1 = DocumentArray.empty(10)
da1.embeddings = np.random.random([10, 128])
# da1_index = DocumentArray(da1, storage=storage, config=config)
da1.match(da1, exclude_self=True)

da2 = copy.deepcopy(da1)
da2.embeddings = np.random.random([10, 128])
da2_index = DocumentArray(da2, config=config)
da2_index = DocumentArray(da2)
with da2_index:
da2.match(da2_index, exclude_self=True)
# pop some matches from first document
Expand All @@ -308,24 +309,26 @@ def test_diff_match_len_in_gd(config, metric_fn, kwargs):
assert d.evaluations[metric_fn].value > 0.9


def test_empty_da_should_raise(config):
da = DocumentArray([], config=config)
def test_empty_da_should_raise():
da = DocumentArray(
[],
)
with pytest.raises(ValueError):
da.evaluate(metrics=['precision_at_k'])


def test_missing_groundtruth_should_raise(config):
da = DocumentArray(DocumentArray.empty(10), config=config)
def test_missing_groundtruth_should_raise():
da = DocumentArray(DocumentArray.empty(10))
with pytest.raises(RuntimeError):
da.evaluate(metrics=['precision_at_k'])


def test_useless_groundtruth_warning_should_raise(config):
def test_useless_groundtruth_warning_should_raise():
da1 = DocumentArray.empty(10)
for d in da1:
d.tags = {'label': 'A'}
da1.embeddings = np.random.random([10, 256])
da1_index = DocumentArray(da1, config=config)
da1_index = DocumentArray(da1)
with da1_index:
da1.match(da1_index, exclude_self=True)
da2 = DocumentArray.empty(10)
Expand All @@ -339,11 +342,11 @@ def dummy_embed_function(da):
da[i, 'embedding'] = np.random.random(5)


def test_embed_and_evaluate_single_da(config):
def test_embed_and_evaluate_single_da():

gt = DocumentArray([Document(text=str(i)) for i in range(10)])
queries_da = DocumentArray(gt, copy=True)
queries_da = DocumentArray(queries_da, config=config)
queries_da = DocumentArray(queries_da)
dummy_embed_function(gt)
gt.match(gt, limit=3)

Expand Down Expand Up @@ -408,13 +411,15 @@ def test_embed_and_evaluate_with_and_without_exclude_self(
'sample_size',
[None, 10],
)
def test_embed_and_evaluate_two_das(config, sample_size):
def test_embed_and_evaluate_two_das(sample_size):

gt_queries = DocumentArray([Document(text=str(i)) for i in range(100)])
gt_index = DocumentArray([Document(text=str(i)) for i in range(100, 200)])
queries_da = DocumentArray(gt_queries, copy=True)
index_da = DocumentArray(gt_index, copy=True)
index_da = DocumentArray(index_da, config=config)
index_da = DocumentArray(
index_da,
)
dummy_embed_function(gt_queries)
dummy_embed_function(gt_index)
gt_queries.match(gt_index, limit=3)
Expand Down Expand Up @@ -476,15 +481,15 @@ def test_embed_and_evaluate_two_different_das():
),
],
)
def test_embed_and_evaluate_labeled_dataset(config, use_index, expected, label_tag):
def test_embed_and_evaluate_labeled_dataset(use_index, expected, label_tag):
metric_fns = list(expected.keys())

def emb_func(da):
np.random.seed(0) # makes sure that embeddings are always equal
da[:, 'embedding'] = np.random.random((len(da), 5))

da1 = DocumentArray([Document(text=str(i), tags={label_tag: i}) for i in range(3)])
da2 = DocumentArray(da1, config=config, copy=True)
da2 = DocumentArray(da1, copy=True)

with da2:
if (
Expand Down Expand Up @@ -584,13 +589,13 @@ def bert_tokenizer():
return BertTokenizer.from_pretrained('bert-base-uncased')


def test_embed_and_evaluate_with_embed_model(config, bert_tokenizer):
def test_embed_and_evaluate_with_embed_model(bert_tokenizer):
model = BertModel(BertConfig())
collate_fn = lambda da: bert_tokenizer(da.texts, return_tensors='pt')
da = DocumentArray(
[Document(text=f'some text {i}', tags={'label': str(i)}) for i in range(5)]
)
da = DocumentArray(da, config=config)
da = DocumentArray(da)
with da:
res = da.embed_and_evaluate(
metrics=['precision_at_k'], embed_models=model, collate_fns=collate_fn
Expand All @@ -616,19 +621,17 @@ def test_embed_and_evaluate_with_embed_model(config, bert_tokenizer):
),
],
)
def test_embed_and_evaluate_invalid_input_should_raise(
config, queries, kwargs, exception
):
def test_embed_and_evaluate_invalid_input_should_raise(queries, kwargs, exception):
kwargs.update({'metrics': ['precision_at_k']})
if 'index_data' in kwargs:
kwargs['index_data'] = DocumentArray(kwargs['index_data'], config=config)
kwargs['index_data'] = DocumentArray(kwargs['index_data'])

with pytest.raises(exception):
queries.embed_and_evaluate(**kwargs)


@pytest.mark.parametrize('sample_size', [100, 1_000, 10_000])
def test_embed_and_evaluate_sampling(config, sample_size):
def test_embed_and_evaluate_sampling(sample_size):
metric_fns = ['precision_at_k', 'reciprocal_rank']

def emb_func(da):
Expand All @@ -638,7 +641,7 @@ def emb_func(da):
da1 = DocumentArray(
[Document(text=str(i), tags={'label': i % 20}) for i in range(2_000)]
)
da2 = DocumentArray(da1, config=config, copy=True)
da2 = DocumentArray(da1, copy=True)

with da2:
res = da1.embed_and_evaluate(
Expand Down