From f1bd2e8ccec24aa74b397fce1b740dd6cb204802 Mon Sep 17 00:00:00 2001 From: Jonathan Rowley Date: Tue, 27 Sep 2022 07:02:36 -0500 Subject: [PATCH] test(weaviate): add test cases for weaviate --- tests/unit/array/storage/weaviate/__init__.py | 0 .../weaviate/test_additional_properties.py | 27 +++++++++++++++++ .../storage/weaviate/test_query_params.py | 30 +++++++++++++++++++ 3 files changed, 57 insertions(+) create mode 100644 tests/unit/array/storage/weaviate/__init__.py create mode 100644 tests/unit/array/storage/weaviate/test_additional_properties.py create mode 100644 tests/unit/array/storage/weaviate/test_query_params.py diff --git a/tests/unit/array/storage/weaviate/__init__.py b/tests/unit/array/storage/weaviate/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/array/storage/weaviate/test_additional_properties.py b/tests/unit/array/storage/weaviate/test_additional_properties.py new file mode 100644 index 00000000000..baded17e6ec --- /dev/null +++ b/tests/unit/array/storage/weaviate/test_additional_properties.py @@ -0,0 +1,27 @@ +from docarray import Document, DocumentArray + + +def test_get_additional(start_storage): + da = DocumentArray(storage="weaviate", config={"n_dim": 3}) + + with da: + da.extend( + [ + Document(embedding=[0, 0, 0]), + Document(embedding=[2, 2, 2]), + Document(embedding=[4, 4, 4]), + Document(embedding=[2, 2, 2]), + Document(embedding=[4, 4, 4]), + ] + ) + + additional = ["creationTimeUnix", "lastUpdateTimeUnix"] + results = da.find( + DocumentArray([Document(embedding=[2, 2, 2])]), + limit=1, + additional=additional, + ) + + for res in results: + assert res[:, "tags__creationTimeUnix"][0] is not None + assert res[:, "tags__lastUpdateTimeUnix"][0] is not None diff --git a/tests/unit/array/storage/weaviate/test_query_params.py b/tests/unit/array/storage/weaviate/test_query_params.py new file mode 100644 index 00000000000..2b38d77b7c5 --- /dev/null +++ b/tests/unit/array/storage/weaviate/test_query_params.py @@ -0,0 +1,30 @@ +from docarray import Document, DocumentArray +import numpy as np + + +def find_random(da, target_certainty): + return da.find( + DocumentArray([Document(embedding=np.random.randint(10, size=10))]), + query_params={"certainty": target_certainty}, + )[0] + + +def test_certainty_filter(start_storage): + nrof_docs = 100 + target_certainty = 0.98 + da = DocumentArray(storage="weaviate", config={"n_dim": 10}) + + with da: + da.extend( + [ + Document(embedding=np.random.randint(10, size=10)) + for i in range(1, nrof_docs) + ], + ) + + results = [] + while len(results) == 0: + results = find_random(da, target_certainty) + + for res in results: + assert res.scores["weaviate_certainty"].value >= target_certainty