diff --git a/docarray/index/backends/weaviate.py b/docarray/index/backends/weaviate.py index 13eb689375..13c08c2d73 100644 --- a/docarray/index/backends/weaviate.py +++ b/docarray/index/backends/weaviate.py @@ -18,6 +18,8 @@ TypeVar, Union, cast, + get_origin, + get_args, ) import numpy as np @@ -43,7 +45,6 @@ TSchema = TypeVar('TSchema', bound=BaseDoc) T = TypeVar('T', bound='WeaviateDocumentIndex') - DEFAULT_BATCH_CONFIG = { "batch_size": 20, "dynamic": False, @@ -210,6 +211,8 @@ def _create_schema(self) -> None: self.bytes_columns.append(column_name) if column_info.db_type == 'number[]': self.nonembedding_array_columns.append(column_name) + if column_info.db_type == 'text[]': + self.nonembedding_array_columns.append(column_name) prop = { "name": column_name if column_name != 'id' @@ -253,6 +256,8 @@ class DBConfig(BaseDocIndex.DBConfig): 'number': {}, 'boolean': {}, 'number[]': {}, + 'int[]': {}, + 'text[]': {}, 'blob': {}, } ) @@ -717,6 +722,23 @@ def python_type_to_db_type(self, python_type: Type) -> Any: bytes: 'blob', } + if get_origin(python_type) == list: + py_weaviate_list_type_map = { + int: 'int[]', + float: 'number[]', + str: 'text[]', + } + + container_type = None + args = get_args(python_type) + if args: + container_type = args[0] + if ( + container_type is not None + and container_type in py_weaviate_list_type_map + ): + return py_weaviate_list_type_map[container_type] + for py_type, weaviate_type in py_weaviate_type_map.items(): if safe_issubclass(python_type, py_type): return weaviate_type diff --git a/tests/index/weaviate/docker-compose.yml b/tests/index/weaviate/docker-compose.yml index 5cca1e722e..518e14bd9c 100644 --- a/tests/index/weaviate/docker-compose.yml +++ b/tests/index/weaviate/docker-compose.yml @@ -1,4 +1,4 @@ -version: '3.8' +version: '3.3' services: @@ -10,7 +10,7 @@ services: - '8080' - --scheme - http - image: semitechnologies/weaviate:1.18.3 + image: semitechnologies/weaviate:1.21.1 ports: - "8080:8080" restart: on-failure:0 @@ -24,4 +24,4 @@ services: LOG_LEVEL: debug # verbose LOG_FORMAT: text # LOG_LEVEL: trace # very verbose - GODEBUG: gctrace=1 # make go garbage collector verbose \ No newline at end of file + GODEBUG: gctrace=1 # make go garbage collector verbose diff --git a/tests/index/weaviate/test_index_get_del_weaviate.py b/tests/index/weaviate/test_index_get_del_weaviate.py index 10ac0acd82..956ecf51dd 100644 --- a/tests/index/weaviate/test_index_get_del_weaviate.py +++ b/tests/index/weaviate/test_index_get_del_weaviate.py @@ -6,6 +6,7 @@ import numpy as np import pytest from pydantic import Field +from typing import List from docarray import BaseDoc from docarray.documents import ImageDoc, TextDoc @@ -31,6 +32,9 @@ class SimpleDoc(BaseDoc): class Document(BaseDoc): embedding: NdArray[2] = Field(dim=2, is_embedding=True) text: str = Field() + texts: List[str] = Field(default=[]) + # integers: List[int] = Field(default=[]) + floats: List[float] = Field(default=[]) class NestedDocument(BaseDoc): @@ -50,7 +54,14 @@ def documents(): # create the docs by enumerating from 1 and use that as the id docs = [ - Document(id=str(i), embedding=embedding, text=text) + Document( + id=str(i), + embedding=embedding, + text=text, + texts=[f'text{i}_0', f'text{i}_1'], + integers=[i, i], + floats=[1.5 * i, 2.5 * i], + ) for i, (embedding, text) in enumerate(zip(embeddings, texts)) ] @@ -170,6 +181,8 @@ class Document(BaseDoc): ({"path": ["text"], "operator": "Equal", "valueText": "lorem ipsum"}, 1), ({"path": ["text"], "operator": "Equal", "valueText": "foo"}, 0), ({"path": ["id"], "operator": "Equal", "valueString": "1"}, 1), + ({"path": ["texts"], "operator": "ContainsAny", "valueText": ["text"]}, 3), + ({"path": ["texts"], "operator": "ContainsAny", "valueText": ["text1_"]}, 1), ], ) def test_filter(test_index, filter_query, expected_num_docs):