Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 80 additions & 64 deletions tests/unit/array/mixins/oldproto/test_eval_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def test_eval_mixin_perfect_match(metric_fn, kwargs, storage, config, start_stor
da1 = DocumentArray.empty(10)
da1.embeddings = np.random.random([10, 256])
da1_index = DocumentArray(da1, storage=storage, config=config)
da1.match(da1_index, exclude_self=True)
with da1_index:
da1.match(da1_index, exclude_self=True)
r = da1.evaluate(ground_truth=da1, metrics=[metric_fn], strict=False, **kwargs)[
metric_fn
]
Expand Down Expand Up @@ -80,7 +81,8 @@ def test_eval_mixin_perfect_match_multiple_metrics(storage, config, start_storag
da1 = DocumentArray.empty(10)
da1.embeddings = np.random.random([10, 256])
da1_index = DocumentArray(da1, storage=storage, config=config)
da1.match(da1_index, exclude_self=True)
with da1_index:
da1.match(da1_index, exclude_self=True)
r = da1.evaluate(ground_truth=da1, metrics=metric_fns, strict=False, **kwargs)
for metric_fn in metric_fns:
assert metric_fn in r
Expand Down Expand Up @@ -123,7 +125,8 @@ def test_eval_mixin_perfect_match_labeled(
d.tags = {'label': 'A'}
da1.embeddings = np.random.random([10, 256])
da1_index = DocumentArray(da1, storage=storage, config=config)
da1.match(da1_index, exclude_self=True)
with da1_index:
da1.match(da1_index, exclude_self=True)
r = da1.evaluate(metrics=[metric_fn], **kwargs)[metric_fn]
assert isinstance(r, float)
assert r == 1.0
Expand Down Expand Up @@ -166,7 +169,8 @@ def test_eval_mixin_zero_labeled(storage, config, metric_fn, start_storage, kwar
for d in da2:
d.tags = {'label': 'B'}
da1_index = DocumentArray(da2, storage=storage, config=config)
da1.match(da1_index, exclude_self=True)
with da1_index:
da1.match(da1_index, exclude_self=True)
r = da1.evaluate([metric_fn], **kwargs)[metric_fn]
assert isinstance(r, float)
assert r == 0.0
Expand Down Expand Up @@ -264,9 +268,10 @@ def test_eval_mixin_zero_match(storage, config, metric_fn, start_storage, kwargs
da2 = copy.deepcopy(da1)
da2.embeddings = np.random.random([10, 256])
da2_index = DocumentArray(da2, storage=storage, config=config)
da2.match(da2_index, exclude_self=True)
with da2_index:
da2.match(da2_index, exclude_self=True)

r = da1.evaluate(ground_truth=da2, metrics=[metric_fn], **kwargs)[metric_fn]
r = da1.evaluate(ground_truth=da2, metrics=[metric_fn], **kwargs)[metric_fn]
assert isinstance(r, float)
assert r == 1.0
for d in da1:
Expand Down Expand Up @@ -337,17 +342,20 @@ def test_same_hash_same_len_fun_should_work(storage, config, start_storage):
da1 = DocumentArray.empty(10)
da1.embeddings = np.random.random([10, 3])
da1_index = DocumentArray(da1, storage=storage, config=config)
da1.match(da1_index)
with da1_index:
da1.match(da1_index)
da2 = DocumentArray.empty(10)
da2.embeddings = np.random.random([10, 3])
da2_index = DocumentArray(da1, storage=storage, config=config)
da2.match(da2_index)
with pytest.raises(ValueError):
da1.evaluate(ground_truth=da2, metrics=['precision_at_k'])
for d1, d2 in zip(da1, da2):
d1.id = d2.id
with da2_index:
da2.match(da2_index)
with da1_index, da2_index:
with pytest.raises(ValueError):
da1.evaluate(ground_truth=da2, metrics=['precision_at_k'])
for d1, d2 in zip(da1, da2):
d1.id = d2.id

da1.evaluate(ground_truth=da2, metrics=['precision_at_k'])
da1.evaluate(ground_truth=da2, metrics=['precision_at_k'])


@pytest.mark.parametrize(
Expand All @@ -368,7 +376,8 @@ def test_adding_noise(storage, config, start_storage):

da.embeddings = np.random.random([10, 3])
da_index = DocumentArray(da, storage=storage, config=config)
da.match(da_index, exclude_self=True)
with da_index:
da.match(da_index, exclude_self=True)

da2 = copy.deepcopy(da)

Expand Down Expand Up @@ -410,17 +419,18 @@ def test_adding_noise(storage, config, start_storage):
def test_diff_match_len_in_gd(storage, config, metric_fn, start_storage, kwargs):
da1 = DocumentArray.empty(10)
da1.embeddings = np.random.random([10, 128])
da1_index = DocumentArray(da1, storage=storage, config=config)
# da1_index = DocumentArray(da1, storage=storage, config=config)
da1.match(da1, exclude_self=True)

da2 = copy.deepcopy(da1)
da2.embeddings = np.random.random([10, 128])
da2_index = DocumentArray(da2, storage=storage, config=config)
da2.match(da2_index, exclude_self=True)
# pop some matches from first document
da2[0].matches.pop(8)
with da2_index:
da2.match(da2_index, exclude_self=True)
# pop some matches from first document
da2[0].matches.pop(8)

r = da1.evaluate(ground_truth=da2, metrics=[metric_fn], **kwargs)[metric_fn]
r = da1.evaluate(ground_truth=da2, metrics=[metric_fn], **kwargs)[metric_fn]
assert isinstance(r, float)
np.testing.assert_allclose(r, 1.0, rtol=1e-2) #
for d in da1:
Expand Down Expand Up @@ -486,7 +496,8 @@ def test_useless_groundtruth_warning_should_raise(storage, config, start_storage
d.tags = {'label': 'A'}
da1.embeddings = np.random.random([10, 256])
da1_index = DocumentArray(da1, storage=storage, config=config)
da1.match(da1_index, exclude_self=True)
with da1_index:
da1.match(da1_index, exclude_self=True)
da2 = DocumentArray.empty(10)
with pytest.warns(UserWarning):
da1.evaluate(ground_truth=da2, metrics=['precision_at_k'])
Expand Down Expand Up @@ -518,13 +529,14 @@ def test_embed_and_evaluate_single_da(storage, config, start_storage):
dummy_embed_function(gt)
gt.match(gt, limit=3)

res = queries_da.embed_and_evaluate(
ground_truth=gt,
metrics=['precision_at_k', 'reciprocal_rank'],
embed_funcs=dummy_embed_function,
match_batch_size=1,
limit=3,
)
with queries_da:
res = queries_da.embed_and_evaluate(
ground_truth=gt,
metrics=['precision_at_k', 'reciprocal_rank'],
embed_funcs=dummy_embed_function,
match_batch_size=1,
limit=3,
)
assert all([v == 1.0 for v in res.values()])


Expand Down Expand Up @@ -601,15 +613,16 @@ def test_embed_and_evaluate_two_das(storage, config, sample_size, start_storage)
dummy_embed_function(gt_index)
gt_queries.match(gt_index, limit=3)

res = queries_da.embed_and_evaluate(
ground_truth=gt_queries,
index_data=index_da,
metrics=['precision_at_k', 'reciprocal_rank'],
embed_funcs=dummy_embed_function,
match_batch_size=1,
limit=3,
query_sample_size=sample_size,
)
with index_da:
res = queries_da.embed_and_evaluate(
ground_truth=gt_queries,
index_data=index_da,
metrics=['precision_at_k', 'reciprocal_rank'],
embed_funcs=dummy_embed_function,
match_batch_size=1,
limit=3,
query_sample_size=sample_size,
)
assert all([v == 1.0 for v in res.values()])


Expand Down Expand Up @@ -681,25 +694,26 @@ def emb_func(da):
da1 = DocumentArray([Document(text=str(i), tags={label_tag: i}) for i in range(3)])
da2 = DocumentArray(da1, storage=storage, config=config, copy=True)

if (
use_index
): # query and index da are distinct # (different embeddings are generated)
res = da1.embed_and_evaluate(
index_data=da2,
metrics=metric_fns,
embed_funcs=emb_func,
match_batch_size=1,
limit=3,
label_tag=label_tag,
)
else: # query and index are the same (embeddings of both das are equal)
res = da2.embed_and_evaluate(
metrics=metric_fns,
embed_funcs=emb_func,
match_batch_size=1,
limit=3,
label_tag=label_tag,
)
with da2:
if (
use_index
): # query and index da are distinct # (different embeddings are generated)
res = da1.embed_and_evaluate(
index_data=da2,
metrics=metric_fns,
embed_funcs=emb_func,
match_batch_size=1,
limit=3,
label_tag=label_tag,
)
else: # query and index are the same (embeddings of both das are equal)
res = da2.embed_and_evaluate(
metrics=metric_fns,
embed_funcs=emb_func,
match_batch_size=1,
limit=3,
label_tag=label_tag,
)
for key in metric_fns:
assert key in res
assert abs(res[key] - expected[key]) < 1e-4
Expand Down Expand Up @@ -799,9 +813,10 @@ def test_embed_and_evaluate_with_embed_model(
[Document(text=f'some text {i}', tags={'label': str(i)}) for i in range(5)]
)
da = DocumentArray(da, storage=storage, config=config)
res = da.embed_and_evaluate(
metrics=['precision_at_k'], embed_models=model, collate_fns=collate_fn
)
with da:
res = da.embed_and_evaluate(
metrics=['precision_at_k'], embed_models=model, collate_fns=collate_fn
)
assert res
assert res['precision_at_k'] == 0.2

Expand Down Expand Up @@ -873,12 +888,13 @@ def emb_func(da):
)
da2 = DocumentArray(da1, storage=storage, config=config, copy=True)

res = da1.embed_and_evaluate(
index_data=da2,
metrics=metric_fns,
embed_funcs=emb_func,
query_sample_size=sample_size,
)
with da2:
res = da1.embed_and_evaluate(
index_data=da2,
metrics=metric_fns,
embed_funcs=emb_func,
query_sample_size=sample_size,
)
expected_size = (
sample_size if sample_size and (sample_size < len(da1)) else len(da1)
)
Expand Down