diff --git a/docarray/math/distance/torch.py b/docarray/math/distance/torch.py index 84c1add5a94..cdd27b9e9d2 100644 --- a/docarray/math/distance/torch.py +++ b/docarray/math/distance/torch.py @@ -18,9 +18,8 @@ def cosine( :param device: the computational device for `embed_model`, can be either `cpu` or `cuda`. :return: np.ndarray with ndim=2 """ - if device == 'cuda': - x_mat = x_mat.cuda() - y_mat = y_mat.cuda() + x_mat = x_mat.to(device) + y_mat = y_mat.to(device) a_n, b_n = x_mat.norm(dim=1)[:, None], y_mat.norm(dim=1)[:, None] a_norm = x_mat / torch.clamp(a_n, min=eps) @@ -37,9 +36,8 @@ def euclidean(x_mat: 'tensor', y_mat: 'tensor', device: str = 'cpu') -> 'numpy.n :param device: the computational device for `embed_model`, can be either `cpu` or `cuda`. :return: np.ndarray with ndim=2 """ - if device == 'cuda': - x_mat = x_mat.cuda() - y_mat = y_mat.cuda() + x_mat = x_mat.to(device) + y_mat = y_mat.to(device) return torch.cdist(x_mat, y_mat).cpu().detach().numpy() @@ -54,8 +52,7 @@ def sqeuclidean( :param device: the computational device for `embed_model`, can be either `cpu` or `cuda`. :return: np.ndarray with ndim=2 """ - if device == 'cuda': - x_mat = x_mat.cuda() - y_mat = y_mat.cuda() + x_mat = x_mat.to(device) + y_mat = y_mat.to(device) return (torch.cdist(x_mat, y_mat) ** 2).cpu().detach().numpy() diff --git a/docs/advanced/document-store/index.md b/docs/advanced/document-store/index.md index 74a02bd5665..19ec7d84f6e 100644 --- a/docs/advanced/document-store/index.md +++ b/docs/advanced/document-store/index.md @@ -157,9 +157,9 @@ DocArray supports multiple storage backends with different search features. The | [`Sqlite`](./sqlite.md) | `DocumentArray(storage='sqlite')` | ❌ | ❌ | ✅ | | [`Weaviate`](./weaviate.md) | `DocumentArray(storage='weaviate')` | ✅ | ✅ | ✅ | | [`Qdrant`](./qdrant.md) | `DocumentArray(storage='qdrant')` | ✅ | ✅ | ❌ | -| [`Annlite`](./annlite.md) | `DocumentArray(storage='annlite')` | ✅ | ✅ | ✅ | +| [`Annlite`](./annlite.md) | `DocumentArray(storage='annlite')` | ✅ | ✅ | ✅ | | [`ElasticSearch`](./elasticsearch.md) | `DocumentArray(storage='elasticsearch')` | ✅ | ✅ | ✅ | -| [`Redis`](./redis.md) | `DocumentArray(storage='elasticsearch')` | ✅ | ✅ | ✅ | +| [`Redis`](./redis.md) | `DocumentArray(storage='redis')` | ✅ | ✅ | ✅ | Here we understand by