-
Notifications
You must be signed in to change notification settings - Fork 238
feat: Implementation of embed_and_evaluate #702
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9c2b492
429a585
bca2f07
901c5cb
84a1331
8acc865
c5fb5df
c0d17ee
19a3e84
562d1cb
a64b2f8
4ba9f25
ebbd9dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| import warnings | ||
| from typing import Optional, Union, TYPE_CHECKING, Callable, List, Dict | ||
| from typing import Optional, Union, TYPE_CHECKING, Callable, List, Dict, Tuple | ||
|
|
||
| from functools import wraps | ||
|
|
||
|
|
@@ -10,6 +10,8 @@ | |
|
|
||
| if TYPE_CHECKING: # pragma: no cover | ||
| from docarray import Document, DocumentArray | ||
| from docarray.array.mixins.embed import CollateFnType | ||
| from docarray.typing import ArrayType, AnyDNN | ||
|
|
||
|
|
||
| def _evaluate_deprecation(f): | ||
|
|
@@ -94,7 +96,7 @@ def evaluate( | |
| This method will fill the `evaluations` field of Documents inside this | ||
| `DocumentArray` and will return the average of the computations | ||
|
|
||
| :param metrics: list of metric names or metric functions to be computed | ||
| :param metrics: List of metric names or metric functions to be computed | ||
| :param ground_truth: The ground_truth `DocumentArray` that the `DocumentArray` | ||
| compares to. | ||
| :param hash_fn: For the evaluation against a `ground_truth` DocumentArray, | ||
|
|
@@ -205,3 +207,237 @@ def evaluate( | |
| metric_name: float(np.mean(values)) | ||
| for metric_name, values in results.items() | ||
| } | ||
|
|
||
def embed_and_evaluate(
    self,
    metrics: List[Union[str, Callable[..., float]]],
    index_data: Optional['DocumentArray'] = None,
    ground_truth: Optional['DocumentArray'] = None,
    metric_names: Optional[str] = None,
    strict: bool = True,
    label_tag: str = 'label',
    embed_models: Optional[Union['AnyDNN', Tuple['AnyDNN', 'AnyDNN']]] = None,
    embed_funcs: Optional[Union[Callable, Tuple[Callable, Callable]]] = None,
    device: str = 'cpu',
    batch_size: Union[int, Tuple[int, int]] = 256,
    collate_fns: Union[
        Optional['CollateFnType'],
        Tuple[Optional['CollateFnType'], Optional['CollateFnType']],
    ] = None,
    distance: Union[
        str, Callable[['ArrayType', 'ArrayType'], 'np.ndarray']
    ] = 'cosine',
    limit: Optional[Union[int, float]] = 20,
    normalization: Optional[Tuple[float, float]] = None,
    exclude_self: bool = False,
    use_scipy: bool = False,
    match_batch_size: int = 100_000,
    query_sample_size: int = 1_000,
    **kwargs,
) -> Optional[Dict[str, float]]:  # average score for each metric
    """
    Computes ranking evaluation metrics for a given `DocumentArray`. This
    function does embedding and matching in the same turn. Thus, you don't need to
    call ``embed`` and ``match`` before it. Instead, it embeds the documents in
    `self` (and `index_data` when provided) and computes the nearest neighbours
    itself. This might be done in batches for the `index_data` object to reduce
    the memory consumption of the evaluation process. The evaluation itself can be
    done against a `ground_truth` DocumentArray or on the basis of labels like it
    is possible with the :func:``evaluate`` function.

    :param metrics: List of metric names or metric functions to be computed
    :param index_data: The other DocumentArray to match against, if not given,
        `self` will be matched against itself. This means that every document in
        `self` will be compared to all other documents in `self` to determine
        the nearest neighbors.
    :param ground_truth: The ground_truth `DocumentArray` that the `DocumentArray`
        compares to.
    :param metric_names: If provided, the results of the metrics computation will be
        stored in the `evaluations` field of each Document with this names. If not
        provided, the names will be derived from the metric function names.
    :param strict: If set, then left and right sides are required to be fully
        aligned: on the length, and on the semantic of length. These are preventing
        you to evaluate on irrelevant matches accidentally.
    :param label_tag: Specifies the tag which contains the labels.
    :param embed_models: One or two embedding model written in Keras / Pytorch /
        Paddle for embedding `self` and `index_data`.
    :param embed_funcs: As an alternative to embedding models, custom embedding
        functions can be provided.
    :param device: the computational device for `embed_models`, can be either
        `cpu` or `cuda`.
    :param batch_size: Number of documents in a batch for embedding.
    :param collate_fns: For each embedding function the respective collate
        function creates a mini-batch of input(s) from the given `DocumentArray`.
        If not provided a default built-in collate_fn uses the `tensors` of the
        documents to create input batches.
    :param distance: The distance metric.
    :param limit: The maximum number of matches, when not given defaults to 20.
    :param normalization: A tuple [a, b] to be used with min-max normalization,
        the min distance will be rescaled to `a`, the max distance will be
        rescaled to `b` all values will be rescaled into range `[a, b]`.
    :param exclude_self: If set, Documents in ``index_data`` with same ``id``
        as the left-hand values will not be considered as matches.
    :param use_scipy: if set, use ``scipy`` as the computation backend. Note,
        ``scipy`` does not support distance on sparse matrix.
    :param match_batch_size: The number of documents which are embedded and
        matched at once. Set this value to a lower value, if you experience high
        memory consumption.
    :param query_sample_size: For a large number of documents in `self` the
        evaluation becomes infeasible, especially, if `index_data` is large.
        Therefore, queries are sampled if the number of documents in `self` exceeds
        `query_sample_size`. Usually, this has only small impact on the mean metric
        values returned by this function. To prevent sampling, you can set
        `query_sample_size` to None.
    :param kwargs: Additional keyword arguments to be passed to the metric
        functions.
    :return: A dictionary which stores for each metric name the average evaluation
        score.
    """

    from docarray import Document, DocumentArray

    if not query_sample_size:
        # `None` (or 0) disables sampling: treat every document as a query
        query_sample_size = len(self)

    query_data = self
    only_one_dataset = not index_data
    apply_sampling = len(self) > query_sample_size

    if only_one_dataset:
        # if the user does not provide a separate set of documents for indexing,
        # the matching is done on the documents themselves; a copy is required
        # whenever queries and index are embedded differently (two functions /
        # two models) or queries are sampled, so that embedding the index does
        # not clobber the query embeddings
        copy_flag = (
            apply_sampling
            or isinstance(embed_funcs, tuple)
            or (embed_funcs is None and isinstance(embed_models, tuple))
        )
        index_data = DocumentArray(self, copy=True) if copy_flag else self

    if apply_sampling:
        rng = np.random.default_rng()
        query_data = DocumentArray(
            rng.choice(self, size=query_sample_size, replace=False)
        )

    if ground_truth and apply_sampling:
        # restrict the ground truth to the sampled queries
        ground_truth = DocumentArray(
            [ground_truth[d.id] for d in query_data if d.id in ground_truth]
        )
        if len(ground_truth) != len(query_data):
            raise ValueError(
                'The DocumentArray provided in the ground_truth attribute does '
                'not contain all the documents in self.'
            )

    index_data_labels = None
    if not ground_truth:
        # label-based evaluation: queries and index documents both need labels
        if label_tag not in query_data[0].tags:
            raise ValueError(
                'Either a ground_truth `DocumentArray` or labels are '
                'required for the evaluation.'
            )
        if label_tag not in index_data[0].tags:
            # NOTE(review): only the first index document is validated here;
            # documents with missing labels later on surface as a KeyError in
            # the label-assignment step below
            raise ValueError(
                'The `DocumentArray` provided in `index_data` misses labels.'
            )
        index_data_labels = {
            id_value: tags[label_tag]
            for id_value, tags in zip(index_data[:, 'id'], index_data[:, 'tags'])
        }

    if embed_funcs is None:
        # derive embed functions from the embed model(s)
        if embed_models is None:
            raise RuntimeError(
                'For embedding the documents you need to provide either embedding '
                'model(s) or embedding function(s)'
            )
        if not isinstance(embed_models, tuple):
            # use the same model for queries and index
            embed_models = (embed_models, embed_models)
        embed_args = [
            {
                'embed_model': model,
                'device': device,
                'batch_size': batch_size,
                'collate_fn': collate_fns[i]
                if isinstance(collate_fns, tuple)
                else collate_fns,
            }
            for i, model in enumerate(embed_models)
        ]
    else:
        if not isinstance(embed_funcs, tuple):
            # use the same embedding function for queries and index
            embed_funcs = (embed_funcs, embed_funcs)

    # embed queries:
    if embed_funcs:
        embed_funcs[0](query_data)
    else:
        query_data.embed(**embed_args[0])

    for doc in query_data:
        doc.matches.clear()

    # lightweight query copies (id + embedding only) used for batched matching
    local_queries = DocumentArray(
        [Document(id=doc.id, embedding=doc.embedding) for doc in query_data]
    )

    def fuse_matches(global_matches: DocumentArray, local_matches: DocumentArray):
        # merge the matches of the current index batch into the global top-k
        # (ascending by distance score, truncated to `limit`)
        global_matches.extend(local_matches)
        global_matches = sorted(
            global_matches,
            key=lambda x: x.scores[distance].value,
        )[:limit]
        return DocumentArray(global_matches)

    for batch in index_data.batch(match_batch_size):
        if (
            apply_sampling
            or (batch.embeddings is None)
            or (batch[0].embedding[0] == 0)
        ):
            # NOTE(review): `batch[0].embedding[0] == 0` is a heuristic for
            # "not yet embedded" — an embedding whose first component is
            # legitimately zero would be re-embedded here; confirm intent
            if embed_funcs:
                embed_funcs[1](batch)
            else:
                batch.embed(**embed_args[1])

        local_queries.match(
            batch,
            limit=limit,
            metric=distance,
            normalization=normalization,
            exclude_self=exclude_self,
            use_scipy=use_scipy,
            only_id=True,
        )

        for doc in local_queries:
            query_data[doc.id, 'matches'] = fuse_matches(
                query_data[doc.id].matches,
                doc.matches,
            )

        # free the batch embeddings to keep peak memory low
        batch.embeddings = None
    # set labels if necessary
    if not ground_truth:
        for doc in query_data:
            new_matches = DocumentArray()
            for m in doc.matches:
                m.tags = {label_tag: index_data_labels[m.id]}
                new_matches.append(m)
            query_data[doc.id, 'matches'] = new_matches

    metrics_resp = query_data.evaluate(
        ground_truth=ground_truth,
        metrics=metrics,
        metric_names=metric_names,
        strict=strict,
        **kwargs,
    )

    return metrics_resp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we describe what it means to match against itself? Will every document inside the DocumentArray be matched against all the other ones? Or does it split the DocumentArray into query and index randomly and perform the search from query to index?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, yes I can explain this. At the moment it will match everything against everything. I might do another small PR afterwards to enable sampling.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
matching everything against everything is not tractable tbh, why not just do the sampling right now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: We discussed that it might make sense, but it is better to implement it in this PR to avoid a breaking change.