From d77023388ed64da3322a728707d34cf7a42ba261 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Fri, 21 Apr 2023 13:26:49 +0200 Subject: [PATCH] feat: add install instructions for hnswlib and elastic doc index Signed-off-by: anna-charlotte --- docarray/index/backends/elastic.py | 29 ++++++++++++++++++++--------- docarray/index/backends/hnswlib.py | 2 +- docarray/utils/_internal/misc.py | 1 + 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index 646e23e8cbd..74688a2715d 100644 --- a/docarray/index/backends/elastic.py +++ b/docarray/index/backends/elastic.py @@ -4,6 +4,7 @@ from collections import defaultdict from dataclasses import dataclass, field from typing import ( + TYPE_CHECKING, Any, Dict, Generator, @@ -21,9 +22,6 @@ ) import numpy as np -from elastic_transport import NodeConfig -from elasticsearch import Elasticsearch -from elasticsearch.helpers import parallel_bulk from pydantic import parse_obj_as import docarray.typing @@ -32,7 +30,7 @@ from docarray.typing import AnyTensor from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.typing.tensor.ndarray import NdArray -from docarray.utils._internal.misc import is_tf_available, is_torch_available +from docarray.utils._internal.misc import import_library from docarray.utils.find import _FindResult, _FindResultBatched TSchema = TypeVar('TSchema', bound=BaseDoc) @@ -40,14 +38,27 @@ ELASTIC_PY_VEC_TYPES: List[Any] = [list, tuple, np.ndarray, AbstractTensor] -if is_torch_available(): - import torch - ELASTIC_PY_VEC_TYPES.append(torch.Tensor) +if TYPE_CHECKING: + from elastic_transport import NodeConfig + from elasticsearch import Elasticsearch + from elasticsearch.helpers import parallel_bulk +else: + elasticsearch = import_library('elasticsearch', raise_error=True) + from elasticsearch import Elasticsearch + from elasticsearch.helpers import parallel_bulk + + elastic_transport = import_library('elastic_transport', raise_error=True) + from elastic_transport import NodeConfig + + torch = import_library('torch', raise_error=False) + tf = import_library('tensorflow', raise_error=False) -if is_tf_available(): - import tensorflow as tf # type: ignore +if torch is not None: + ELASTIC_PY_VEC_TYPES.append(torch.Tensor) + +if tf is not None: from docarray.typing import TensorFlowTensor ELASTIC_PY_VEC_TYPES.append(tf.Tensor) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index d657053b59c..4c66dc52de8 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -41,7 +41,7 @@ from docarray.typing import TensorFlowTensor else: - hnswlib = import_library('hnswlib', raise_error=False) + hnswlib = import_library('hnswlib', raise_error=True) torch = import_library('torch', raise_error=False) tf = import_library('tensorflow', raise_error=False) if tf is not None: diff --git a/docarray/utils/_internal/misc.py b/docarray/utils/_internal/misc.py index 4ec86fbfd68..ea1b7399ffd 100644 --- a/docarray/utils/_internal/misc.py +++ b/docarray/utils/_internal/misc.py @@ -32,6 +32,7 @@ 'trimesh': '"docarray[mesh]"', 'hnswlib': '"docarray[hnswlib]"', 'elasticsearch': '"docarray[elasticsearch]"', + 'elastic_transport': '"docarray[elasticsearch]"', 'weaviate': '"docarray[weaviate]"', 'qdrant_client': '"docarray[qdrant]"', 'fastapi': '"docarray[web]"',