From 69a6902916f41e4ad7ec6070bf905f47d44ebff4 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 30 Aug 2022 20:40:41 +0200 Subject: [PATCH] fix(find): make device more generic --- docarray/math/distance/torch.py | 15 ++++++--------- docs/advanced/document-store/index.md | 4 ++-- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/docarray/math/distance/torch.py b/docarray/math/distance/torch.py index 84c1add5a94..cdd27b9e9d2 100644 --- a/docarray/math/distance/torch.py +++ b/docarray/math/distance/torch.py @@ -18,9 +18,8 @@ def cosine( :param device: the computational device for `embed_model`, can be either `cpu` or `cuda`. :return: np.ndarray with ndim=2 """ - if device == 'cuda': - x_mat = x_mat.cuda() - y_mat = y_mat.cuda() + x_mat = x_mat.to(device) + y_mat = y_mat.to(device) a_n, b_n = x_mat.norm(dim=1)[:, None], y_mat.norm(dim=1)[:, None] a_norm = x_mat / torch.clamp(a_n, min=eps) @@ -37,9 +36,8 @@ def euclidean(x_mat: 'tensor', y_mat: 'tensor', device: str = 'cpu') -> 'numpy.n :param device: the computational device for `embed_model`, can be either `cpu` or `cuda`. :return: np.ndarray with ndim=2 """ - if device == 'cuda': - x_mat = x_mat.cuda() - y_mat = y_mat.cuda() + x_mat = x_mat.to(device) + y_mat = y_mat.to(device) return torch.cdist(x_mat, y_mat).cpu().detach().numpy() @@ -54,8 +52,7 @@ def sqeuclidean( :param device: the computational device for `embed_model`, can be either `cpu` or `cuda`. :return: np.ndarray with ndim=2 """ - if device == 'cuda': - x_mat = x_mat.cuda() - y_mat = y_mat.cuda() + x_mat = x_mat.to(device) + y_mat = y_mat.to(device) return (torch.cdist(x_mat, y_mat) ** 2).cpu().detach().numpy() diff --git a/docs/advanced/document-store/index.md b/docs/advanced/document-store/index.md index 74a02bd5665..19ec7d84f6e 100644 --- a/docs/advanced/document-store/index.md +++ b/docs/advanced/document-store/index.md @@ -157,9 +157,9 @@ DocArray supports multiple storage backends with different search features. The | [`Sqlite`](./sqlite.md) | `DocumentArray(storage='sqlite')` | ❌ | ❌ | ✅ | | [`Weaviate`](./weaviate.md) | `DocumentArray(storage='weaviate')` | ✅ | ✅ | ✅ | | [`Qdrant`](./qdrant.md) | `DocumentArray(storage='qdrant')` | ✅ | ✅ | ❌ | -| [`Annlite`](./annlite.md) | `DocumentArray(storage='annlite')` | ✅ | ✅ | ✅ | +| [`Annlite`](./annlite.md) | `DocumentArray(storage='annlite')` | ✅ | ✅ | ✅ | | [`ElasticSearch`](./elasticsearch.md) | `DocumentArray(storage='elasticsearch')` | ✅ | ✅ | ✅ | -| [`Redis`](./redis.md) | `DocumentArray(storage='elasticsearch')` | ✅ | ✅ | ✅ | +| [`Redis`](./redis.md) | `DocumentArray(storage='redis')` | ✅ | ✅ | ✅ | Here we understand by