Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
94d17c9
feat: add support for labeled dataset to evaluate function
guenthermi Oct 12, 2022
99e3e66
refactor: add matches before transfering to backend, add example to docs
guenthermi Oct 12, 2022
bedfc7a
fix: missing initialization of d2
guenthermi Oct 13, 2022
23fafd3
test: add tests to check if exceptions and warnings are raised
guenthermi Oct 13, 2022
dfa4aa3
fix: duplicate test name
guenthermi Oct 13, 2022
cfb6b2d
refactor: implement review notes
guenthermi Oct 17, 2022
4df372a
Merge branch 'main' into feat-support-labels-in-evaluate
guenthermi Oct 17, 2022
c499634
fix: update r_precision test
guenthermi Oct 17, 2022
15ea32d
refactor: remove comment
guenthermi Oct 17, 2022
798d94f
feat: support multiple metrics in evaluation function
guenthermi Oct 17, 2022
087d7d6
docs: update evaluate functions in docs
guenthermi Oct 17, 2022
a5a697b
fix: change metric to metrics in tests
guenthermi Oct 18, 2022
9461ed5
refactor: solve merge conflict
guenthermi Oct 18, 2022
4942676
test: add test for multiple metrics
guenthermi Oct 18, 2022
c607fcf
refactor: handle metric and metric_name in deprecation decorator
guenthermi Oct 18, 2022
241c294
fix: black
guenthermi Oct 18, 2022
011cdf1
Merge branch 'main' into feat-multiple-metrics-in-evaluate
guenthermi Oct 18, 2022
70b4488
fix: add max_rel only to function which have it
guenthermi Oct 18, 2022
6dac178
fix: set old_key and new_key in deprecation warning
guenthermi Oct 19, 2022
30da1ea
docs: improve formulation in evaluation.md
guenthermi Oct 19, 2022
8ac7402
docs: add multi metric example to docs
guenthermi Oct 19, 2022
3f6f2de
refactor: only support lists
guenthermi Oct 19, 2022
b50a642
fix: tuple error in deprecation decorator
guenthermi Oct 19, 2022
3fefd2d
refactor: change some tests to pass metrics as list
guenthermi Oct 19, 2022
b9415db
fix: add wraps decorator to retain docstring
guenthermi Oct 19, 2022
096dbee
Merge branch 'main' into feat-multiple-metrics-in-evaluate
Oct 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 77 additions & 34 deletions docarray/array/mixins/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import warnings
from typing import Optional, Union, TYPE_CHECKING, Callable
from typing import Optional, Union, TYPE_CHECKING, Callable, List, Dict

from functools import wraps

import numpy as np
from collections import defaultdict

from docarray.score import NamedScore

Expand All @@ -13,9 +16,14 @@ def _evaluate_deprecation(f):
"""Raises a deprecation warning if the user executes the evaluate function with
the old interface and adjust the input to fit the new interface."""

@wraps(f)
def func(*args, **kwargs):
if len(args) > 1:
if not (isinstance(args[1], Callable) or isinstance(args[1], str)):
if not (
isinstance(args[1], Callable)
or isinstance(args[1], str)
or isinstance(args[1], list)
):
kwargs['ground_truth'] = args[1]
args = [args[0]] + list(args[2:])
warnings.warn(
Expand All @@ -25,13 +33,33 @@ def func(*args, **kwargs):
'soon.',
DeprecationWarning,
)
if 'other' in kwargs:
kwargs['ground_truth'] = kwargs['other']
warnings.warn(
'`other` is renamed to `groundtruth` in `evaluate()`, the usage of `other` is '
'deprecated and will be removed soon.',
DeprecationWarning,
)
for old_key, new_key in zip(
['other', 'metric', 'metric_name'],
['ground_truth', 'metrics', 'metric_names'],
):
if old_key in kwargs:
kwargs[new_key] = kwargs[old_key]
warnings.warn(
f'`{old_key}` is renamed to `{new_key}` in `evaluate()`, the '
f'usage of `{old_key}` is deprecated and will be removed soon.',
DeprecationWarning,
)

# transfer metrics and metric_names into lists
list_warning_msg = (
'The attribute `%s` now accepts a list instead of a '
'single element. Passing a single element is deprecated and will soon not '
'be supported anymore.'
)
if len(args) > 1:
if type(args[1]) is str:
args = list(args)
args[1] = [args[1]]
warnings.warn(list_warning_msg % 'metrics', DeprecationWarning)
for key in ['metrics', 'metric_names']:
if key in kwargs and type(kwargs[key]) is str:
kwargs[key] = [kwargs[key]]
warnings.warn(list_warning_msg % key, DeprecationWarning)
return f(*args, **kwargs)

return func
Expand All @@ -43,14 +71,14 @@ class EvaluationMixin:
@_evaluate_deprecation
def evaluate(
self,
metric: Union[str, Callable[..., float]],
metrics: List[Union[str, Callable[..., float]]],
ground_truth: Optional['DocumentArray'] = None,
hash_fn: Optional[Callable[['Document'], str]] = None,
metric_name: Optional[str] = None,
metric_names: Optional[List[str]] = None,
strict: bool = True,
label_tag: str = 'label',
**kwargs,
) -> Optional[float]:
) -> Dict[str, float]:
"""
Compute ranking evaluation metrics for a given `DocumentArray` when compared
with a groundtruth.
Expand All @@ -66,21 +94,22 @@ def evaluate(
This method will fill the `evaluations` field of Documents inside this
`DocumentArray` and will return the average of the computations

:param metric: The name of the metric, or multiple metrics to be computed
:param metrics: list of metric names or metric functions to be computed
:param ground_truth: The ground_truth `DocumentArray` that the `DocumentArray`
compares to.
:param hash_fn: The function used for identifying the uniqueness of Documents.
If not given, then ``Document.id`` is used.
:param metric_name: If provided, the results of the metrics computation will be
stored in the `evaluations` field of each Document. If not provided, the
name will be computed based on the metrics name.
:param metric_names: If provided, the results of the metrics computation will be
stored in the `evaluations` field of each Document with this names. If not
provided, the names will be derived from the metric function names.
:param strict: If set, then left and right sides are required to be fully
aligned: on the length, and on the semantic of length. These are preventing
you to evaluate on irrelevant matches accidentally.
:param label_tag: Specifies the tag which contains the labels.
:param kwargs: Additional keyword arguments to be passed to `metric_fn`
:return: The average evaluation computed or a list of them if multiple metrics
are required
:param kwargs: Additional keyword arguments to be passed to the metric
functions.
:return: A dictionary which stores for each metric name the average evaluation
score.
"""
if len(self) == 0:
raise ValueError('It is not possible to evaluate an empty DocumentArray')
Expand Down Expand Up @@ -108,15 +137,25 @@ def evaluate(
if hash_fn is None:
hash_fn = lambda d: d.id

if callable(metric):
metric_fn = metric
elif isinstance(metric, str):
from docarray.math import evaluation
metric_fns = []
for metric in metrics:
if callable(metric):
metric_fns.append(metric)
elif isinstance(metric, str):
from docarray.math import evaluation

metric_fn = getattr(evaluation, metric)
metric_fns.append(getattr(evaluation, metric))

metric_name = metric_name or metric_fn.__name__
results = []
if not metric_names:
metric_names = [metric_fn.__name__ for metric_fn in metric_fns]

if len(metric_names) != len(metrics):
raise ValueError(
'Could not match metric names to the metrics since the number of '
'metric names does not match the number of metrics'
)

results = defaultdict(list)
caller_max_rel = kwargs.pop('max_rel', None)
for d, gd in zip(self, ground_truth):
max_rel = caller_max_rel or len(gd.matches)
Expand Down Expand Up @@ -153,11 +192,15 @@ def evaluate(
'Could not identify which kind of ground truth'
'information is provided to evaluate the matches.'
)

r = metric_fn(binary_relevance, max_rel=max_rel, **kwargs)
d.evaluations[metric_name] = NamedScore(
value=r, op_name=str(metric_fn), ref_id=d.id
)
results.append(r)
if results:
return float(np.mean(results))
for metric_name, metric_fn in zip(metric_names, metric_fns):
if 'max_rel' in metric_fn.__code__.co_varnames:
kwargs['max_rel'] = max_rel
r = metric_fn(binary_relevance, **kwargs)
d.evaluations[metric_name] = NamedScore(
value=r, op_name=str(metric_fn), ref_id=d.id
)
results[metric_name].append(r)
return {
metric_name: float(np.mean(values))
for metric_name, values in results.items()
}
40 changes: 26 additions & 14 deletions docs/fundamentals/documentarray/evaluation.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Evaluate Matches

After you get `.matches` from the last chapter, you can easily evaluate matches against the groundtruth via {meth}`~docarray.array.mixins.evaluation.EvaluationMixin.evaluate`.
After you get `.matches`, you can evaluate matches against the groundtruth via {meth}`~docarray.array.mixins.evaluation.EvaluationMixin.evaluate`.

```python
da_predict.evaluate(ground_truth=da_groundtruth, metric='...', **kwargs)
da_predict.evaluate(ground_truth=da_groundtruth, metrics=['...'], **kwargs)
```

Alternatively, you can add labels to your documents to evaluate them.
Expand All @@ -18,7 +18,7 @@ example_da.embeddings = np.random.random([10, 3])

example_da.match(example_da)

example_da.evaluate(metric='precision_at_k')
example_da.evaluate(metrics=['precision_at_k'])
```

The results are stored in `.evaluations` field of each Document.
Expand Down Expand Up @@ -109,14 +109,14 @@ da2['@m'].summary()



Now `da2` is our prediction, and `da` is our groundtruth. If we evaluate the average Precision@10, we should get something close to 0.5 (we have 10 real matches, we mixed in 10 fake matches and shuffle it, so top-10 would have approximate 10/20 real matches):
Now `da2` is our prediction, and `da` is our groundtruth. If we evaluate the average Precision@10, we should get something close to 0.47 (we have 9 real matches, we mixed in 10 fake matches and shuffled them, so the top-10 would contain approximately 9/19 real matches):

```python
da2.evaluate(ground_truth=da, metric='precision_at_k', k=5)
da2.evaluate(ground_truth=da, metrics=['precision_at_k'], k=10)
```

```text
0.48
{'precision_at_k': 0.48}
```

Note that this value is an average number over all Documents of `da2`. If you want to look at the individual evaluation, you can check {attr}`~docarray.Document.evaluations` attribute, e.g.
Expand All @@ -127,18 +127,30 @@ for d in da2:
```

```text
0.5
0.4
0.4
0.6
0.3
0.6
0.2
0.4
0.8
0.8
0.2
0.5
0.3
0.4
0.6
0.5
0.7
```

If you want to evaluate your data with multiple metric functions, you can pass a list of metrics:

```python
da2.evaluate(ground_truth=da, metrics=['precision_at_k', 'reciprocal_rank'], k=10)
```

```text
{'precision_at_k': 0.48, 'reciprocal_rank': 0.6333333333333333}
```

In this case, the keyword argument `k` is passed to all metric functions, even though it has no effect on the calculation of the reciprocal rank.

## Document identifier

Note that `.evaluate()` works only when two DocumentArray have the same length and their nested structure are same. It makes no sense to evaluate on two completely irrelevant DocumentArrays.
Expand Down Expand Up @@ -179,7 +191,7 @@ p_da.evaluate('average_precision', ground_truth=g_da, hash_fn=lambda d: d.text[:
```

```text
1.0
{'average_precision': 1.0}
```

It is correct as we define the evaluation as checking if the first two characters in `.text` are the same.
Expand Down
Loading