From e274a2311e1d222ee900c776850633fb29445d64 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Mon, 15 Apr 2024 23:59:15 -0700 Subject: [PATCH 1/6] update postgres online store to make it work with write and read APIs. --- Makefile | 19 ++++++++ sdk/python/feast/infra/key_encoding_utils.py | 4 +- .../contrib/pgvector_repo_configuration.py | 12 +++++ .../infra/online_stores/contrib/postgres.py | 48 +++++++++---------- .../contrib/postgres_repo_configuration.py | 6 --- .../universal/online_store/init.sql | 1 + .../universal/online_store/postgres.py | 13 +++-- 7 files changed, 68 insertions(+), 35 deletions(-) create mode 100644 sdk/python/feast/infra/online_stores/contrib/pgvector_repo_configuration.py create mode 100644 sdk/python/tests/integration/feature_repos/universal/online_store/init.sql diff --git a/Makefile b/Makefile index 6fcf95dc7da..bf2d876b7f1 100644 --- a/Makefile +++ b/Makefile @@ -216,6 +216,25 @@ test-python-universal-postgres-online: not test_snowflake" \ sdk/python/tests + test-python-universal-pgvector-online: + PYTHONPATH='.' \ + FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.pgvector_repo_configuration \ + PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.postgres \ + python -m pytest -n 8 --integration \ + -k "not test_universal_cli and \ + not test_go_feature_server and \ + not test_feature_logging and \ + not test_reorder_columns and \ + not test_logged_features_validation and \ + not test_lambda_materialization_consistency and \ + not test_offline_write and \ + not test_push_features_to_offline_store and \ + not gcs_registry and \ + not s3_registry and \ + not test_universal_types and \ + not test_snowflake" \ + sdk/python/tests + test-python-universal-mysql-online: PYTHONPATH='.' 
\ FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.mysql_repo_configuration \ diff --git a/sdk/python/feast/infra/key_encoding_utils.py b/sdk/python/feast/infra/key_encoding_utils.py index e50e438c3de..bdfeb83c4ee 100644 --- a/sdk/python/feast/infra/key_encoding_utils.py +++ b/sdk/python/feast/infra/key_encoding_utils.py @@ -74,8 +74,8 @@ def serialize_entity_key( return b"".join(output) -def get_val_str(val): - accept_value_types = ["float_list_val", "double_list_val", "int_list_val"] +def get_list_val_str(val: ValueProto): + accept_value_types = ["float_list_val", "double_list_val", "int32_list_val", "int64_list_val"] for accept_type in accept_value_types: if val.HasField(accept_type): return str(getattr(val, accept_type).val) diff --git a/sdk/python/feast/infra/online_stores/contrib/pgvector_repo_configuration.py b/sdk/python/feast/infra/online_stores/contrib/pgvector_repo_configuration.py new file mode 100644 index 00000000000..26b05613158 --- /dev/null +++ b/sdk/python/feast/infra/online_stores/contrib/pgvector_repo_configuration.py @@ -0,0 +1,12 @@ +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from tests.integration.feature_repos.universal.online_store.postgres import ( + PGVectorOnlineStoreCreator, +) + +FULL_REPO_CONFIGS = [ + IntegrationTestRepoConfig( + online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator + ), +] diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 2dcb6187837..8b0dd58f918 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -2,7 +2,7 @@ import logging from collections import defaultdict from datetime import datetime -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, 
Sequence, Tuple import psycopg2 import pytz @@ -12,7 +12,7 @@ from feast import Entity from feast.feature_view import FeatureView -from feast.infra.key_encoding_utils import get_val_str, serialize_entity_key +from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key from feast.infra.online_stores.online_store import OnlineStore from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig @@ -74,19 +74,15 @@ def online_write_batch( created_ts = _to_naive_utc(created_ts) for feature_name, val in values.items(): - val_str: Union[str, bytes] - if ( - "pgvector_enabled" in config.online_config - and config.online_config["pgvector_enabled"] - ): - val_str = get_val_str(val) - else: - val_str = val.SerializeToString() + vector_val = None + if "pgvector_enabled" in config.online_store and config.online_store.pgvector_enabled: + vector_val = get_list_val_str(val) insert_values.append( ( entity_key_bin, feature_name, - val_str, + val.SerializeToString(), + vector_val, timestamp, created_ts, ) @@ -100,11 +96,12 @@ def online_write_batch( sql.SQL( """ INSERT INTO {} - (entity_key, feature_name, value, event_ts, created_ts) + (entity_key, feature_name, value, vector_value, event_ts, created_ts) VALUES %s ON CONFLICT (entity_key, feature_name) DO UPDATE SET value = EXCLUDED.value, + vector_value = EXCLUDED.vector_value, event_ts = EXCLUDED.event_ts, created_ts = EXCLUDED.created_ts; """, @@ -226,12 +223,11 @@ def update( for table in tables_to_keep: table_name = _table_id(project, table) - value_type = "BYTEA" - if ( - "pgvector_enabled" in config.online_config - and config.online_config["pgvector_enabled"] - ): - value_type = f'vector({config.online_config["vector_len"]})' + if "pgvector_enabled" in config.online_store and config.online_store.pgvector_enabled: + vector_value_type = f'vector({config.online_store.vector_len})' + else: + # keep 
the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility + vector_value_type = 'BYTEA' cur.execute( sql.SQL( """ @@ -239,7 +235,8 @@ def update( ( entity_key BYTEA, feature_name TEXT, - value {}, + value BYTEA, + vector_value {} NULL, event_ts TIMESTAMPTZ, created_ts TIMESTAMPTZ, PRIMARY KEY(entity_key, feature_name) @@ -248,7 +245,7 @@ def update( """ ).format( sql.Identifier(table_name), - sql.SQL(value_type), + sql.SQL(vector_value_type), sql.Identifier(f"{table_name}_ek"), sql.Identifier(table_name), ) @@ -294,6 +291,9 @@ def retrieve_online_documents( """ project = config.project + if "pgvector_enabled" not in config.online_store or not config.online_store.pgvector_enabled: + raise ValueError("pgvector is not enabled in the online store configuration") + # Convert the embedding to a string to be used in postgres vector search query_embedding_str = f"[{','.join(str(el) for el in embedding)}]" @@ -311,8 +311,8 @@ def retrieve_online_documents( SELECT entity_key, feature_name, - value, - value <-> %s as distance, + vector_value, + vector_value <-> %s as distance, event_ts FROM {table_name} WHERE feature_name = {feature_name} ORDER BY distance @@ -327,13 +327,13 @@ def retrieve_online_documents( ) rows = cur.fetchall() - for entity_key, feature_name, value, distance, event_ts in rows: + for entity_key, feature_name, vector_value, distance, event_ts in rows: # TODO Deserialize entity_key to return the entity in response # entity_key_proto = EntityKeyProto() # entity_key_proto_bin = bytes(entity_key) # TODO Convert to List[float] for value type proto - feature_value_proto = ValueProto(string_val=value) + feature_value_proto = ValueProto(string_val=vector_value) distance_value_proto = ValueProto(float_val=distance) result.append((event_ts, feature_value_proto, distance_value_proto)) diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py 
b/sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py index 6e4ca3f9501..ea975ec808f 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py @@ -2,7 +2,6 @@ IntegrationTestRepoConfig, ) from tests.integration.feature_repos.universal.online_store.postgres import ( - PGVectorOnlineStoreCreator, PostgresOnlineStoreCreator, ) @@ -10,9 +9,4 @@ IntegrationTestRepoConfig( online_store="postgres", online_store_creator=PostgresOnlineStoreCreator ), - IntegrationTestRepoConfig( - online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator - ), ] - -AVAILABLE_ONLINE_STORES = {"pgvector": PGVectorOnlineStoreCreator} diff --git a/sdk/python/tests/integration/feature_repos/universal/online_store/init.sql b/sdk/python/tests/integration/feature_repos/universal/online_store/init.sql new file mode 100644 index 00000000000..64f04f61ad3 --- /dev/null +++ b/sdk/python/tests/integration/feature_repos/universal/online_store/init.sql @@ -0,0 +1 @@ +CREATE EXTENSION IF NOT EXISTS vector; \ No newline at end of file diff --git a/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py b/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py index 58e7af9c468..8f019032f95 100644 --- a/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py +++ b/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py @@ -3,7 +3,7 @@ from testcontainers.core.container import DockerContainer from testcontainers.core.waiting_utils import wait_for_logs from testcontainers.postgres import PostgresContainer - +import os from tests.integration.feature_repos.universal.online_store_creator import ( OnlineStoreCreator, ) @@ -37,12 +37,17 @@ def teardown(self): class PGVectorOnlineStoreCreator(OnlineStoreCreator): def __init__(self, project_name: str, **kwargs): 
super().__init__(project_name) + script_directory = os.path.dirname(os.path.abspath(__file__)) self.container = ( DockerContainer("pgvector/pgvector:pg16") .with_env("POSTGRES_USER", "root") .with_env("POSTGRES_PASSWORD", "test") .with_env("POSTGRES_DB", "test") .with_exposed_ports(5432) + .with_volume_mapping( + os.path.join(script_directory, 'init.sql'), + "/docker-entrypoint-initdb.d/init.sql", + ) ) def create_online_store(self) -> Dict[str, str]: @@ -51,8 +56,10 @@ def create_online_store(self) -> Dict[str, str]: wait_for_logs( container=self.container, predicate=log_string_to_wait_for, timeout=10 ) - command = "psql -h localhost -p 5432 -U root -d test -c 'CREATE EXTENSION IF NOT EXISTS vector;'" - self.container.exec(command) + init_log_string_to_wait_for = "PostgreSQL init process complete" + wait_for_logs( + container=self.container, predicate=init_log_string_to_wait_for, timeout=10 + ) return { "host": "localhost", "type": "postgres", From 6abffbfa3fd5726d3de43cc3a50e89bf827b6eb2 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 16 Apr 2024 00:00:12 -0700 Subject: [PATCH 2/6] format --- sdk/python/feast/infra/key_encoding_utils.py | 9 ++++++-- .../infra/online_stores/contrib/postgres.py | 23 ++++++++++++++----- .../universal/online_store/postgres.py | 5 ++-- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/sdk/python/feast/infra/key_encoding_utils.py b/sdk/python/feast/infra/key_encoding_utils.py index bdfeb83c4ee..ca834f19176 100644 --- a/sdk/python/feast/infra/key_encoding_utils.py +++ b/sdk/python/feast/infra/key_encoding_utils.py @@ -74,8 +74,13 @@ def serialize_entity_key( return b"".join(output) -def get_list_val_str(val: ValueProto): - accept_value_types = ["float_list_val", "double_list_val", "int32_list_val", "int64_list_val"] +def get_list_val_str(val): + accept_value_types = [ + "float_list_val", + "double_list_val", + "int32_list_val", + "int64_list_val", + ] for accept_type in accept_value_types: if val.HasField(accept_type): 
return str(getattr(val, accept_type).val) diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 8b0dd58f918..2890f60746b 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -75,7 +75,10 @@ def online_write_batch( for feature_name, val in values.items(): vector_val = None - if "pgvector_enabled" in config.online_store and config.online_store.pgvector_enabled: + if ( + "pgvector_enabled" in config.online_store + and config.online_store.pgvector_enabled + ): vector_val = get_list_val_str(val) insert_values.append( ( @@ -223,11 +226,14 @@ def update( for table in tables_to_keep: table_name = _table_id(project, table) - if "pgvector_enabled" in config.online_store and config.online_store.pgvector_enabled: - vector_value_type = f'vector({config.online_store.vector_len})' + if ( + "pgvector_enabled" in config.online_store + and config.online_store.pgvector_enabled + ): + vector_value_type = f"vector({config.online_store.vector_len})" else: # keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility - vector_value_type = 'BYTEA' + vector_value_type = "BYTEA" cur.execute( sql.SQL( """ @@ -291,8 +297,13 @@ def retrieve_online_documents( """ project = config.project - if "pgvector_enabled" not in config.online_store or not config.online_store.pgvector_enabled: - raise ValueError("pgvector is not enabled in the online store configuration") + if ( + "pgvector_enabled" not in config.online_store + or not config.online_store.pgvector_enabled + ): + raise ValueError( + "pgvector is not enabled in the online store configuration" + ) # Convert the embedding to a string to be used in postgres vector search query_embedding_str = f"[{','.join(str(el) for el in embedding)}]" diff --git a/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py 
b/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py index 8f019032f95..7b4156fffe0 100644 --- a/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py +++ b/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py @@ -1,9 +1,10 @@ +import os from typing import Dict from testcontainers.core.container import DockerContainer from testcontainers.core.waiting_utils import wait_for_logs from testcontainers.postgres import PostgresContainer -import os + from tests.integration.feature_repos.universal.online_store_creator import ( OnlineStoreCreator, ) @@ -45,7 +46,7 @@ def __init__(self, project_name: str, **kwargs): .with_env("POSTGRES_DB", "test") .with_exposed_ports(5432) .with_volume_mapping( - os.path.join(script_directory, 'init.sql'), + os.path.join(script_directory, "init.sql"), "/docker-entrypoint-initdb.d/init.sql", ) ) From ce854d48865f4aeb1492ede12e5eb3c16f791c71 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 16 Apr 2024 00:13:46 -0700 Subject: [PATCH 3/6] update doc Signed-off-by: cmuhao --- docs/reference/online-stores/postgres.md | 28 ++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/docs/reference/online-stores/postgres.md b/docs/reference/online-stores/postgres.md index 3885867dd26..57d6366f830 100644 --- a/docs/reference/online-stores/postgres.md +++ b/docs/reference/online-stores/postgres.md @@ -30,6 +30,8 @@ online_store: sslkey_path: /path/to/client-key.pem sslcert_path: /path/to/client-cert.pem sslrootcert_path: /path/to/server-ca.pem + pgvector_enabled: false + vector_len: 512 ``` {% endcode %} @@ -60,3 +62,29 @@ Below is a matrix indicating which functionality is supported by the Postgres on | collocated by entity key | no | To compare this set of functionality against other online stores, please see the full [functionality matrix](overview.md#functionality-matrix). 
+ +## PGVector +The Postgres online store can use [PGVector](https://pgvector.dev/) to store feature values as vectors and search them by similarity. +To enable PGVector, set `pgvector_enabled: true` in the online store configuration; `retrieve_online_documents` raises a `ValueError` when it is not enabled. +The `vector_len` parameter sets the dimension of the `vector` column that holds the embeddings, so it must match the length of the vectors you write. The default value is 512. + +Then you can use `retrieve_online_documents` to retrieve the `top_k` stored vectors closest to a query vector, ordered by distance. + +{% code title="python" %} +```python +from feast import FeatureStore +from feast.infra.online_stores.postgres import retrieve_online_documents + +feature_store = FeatureStore(repo_path=".") + +query_vector = [0.1, 0.2, 0.3, 0.4, 0.5] +top_k = 5 + +feature_values = retrieve_online_documents( + feature_store=feature_store, + feature_view_name="document_fv:embedding_float", + query_vector=query_vector, + top_k=top_k, +) +``` +{% endcode %} \ No newline at end of file From 1bad023e4ab425d7ce4b1b64bfbbfacfc9efd01f Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 16 Apr 2024 00:19:09 -0700 Subject: [PATCH 4/6] update doc Signed-off-by: cmuhao --- docs/reference/online-stores/postgres.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/reference/online-stores/postgres.md b/docs/reference/online-stores/postgres.md index 57d6366f830..277494868c8 100644 --- a/docs/reference/online-stores/postgres.md +++ b/docs/reference/online-stores/postgres.md @@ -87,4 +87,4 @@ feature_values = retrieve_online_documents( top_k=top_k, ) ``` -{% endcode %} \ No newline at end of file +{% endcode %} From 1c62a34f43b54f8abd3ffefd47217a574a63d5a9 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 16 Apr 2024 22:22:17 -0700 Subject: [PATCH 5/6] minor change on return value Signed-off-by: cmuhao --- sdk/python/feast/feature_store.py | 17 ++++++----- .../infra/online_stores/contrib/postgres.py | 29 +++++++------------ .../feast/infra/online_stores/online_store.py | 2 +- sdk/python/feast/infra/provider.py | 2 +- sdk/python/tests/foo_provider.py | 2 +- 5 files 
changed, 24 insertions(+), 28 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 15598e1d609..15ac0f20e41 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1740,12 +1740,14 @@ def _retrieve_online_documents( query, top_k, ) - document_feature_vals = [feature[2] for feature in document_features] - document_feature_distance_vals = [feature[3] for feature in document_features] - online_features_response = GetOnlineFeaturesResponse(results=[]) # TODO Refactor to better way of populating result # TODO populate entity in the response after returning entity in document_features is supported + # TODO currently not return the vector value since it is same as feature value, if embedding is supported, + # the feature value can be raw text before embedded + document_feature_vals = [feature[2] for feature in document_features] + document_feature_distance_vals = [feature[4] for feature in document_features] + online_features_response = GetOnlineFeaturesResponse(results=[]) self._populate_result_rows_from_columnar( online_features_response=online_features_response, data={requested_feature: document_feature_vals}, @@ -1979,7 +1981,7 @@ def _retrieve_from_online_store( requested_feature: str, query: List[float], top_k: int, - ) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value]]: + ) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value, Value]]: """ Search and return document features from the online document store. 
""" @@ -1994,19 +1996,20 @@ def _retrieve_from_online_store( read_row_protos = [] row_ts_proto = Timestamp() - for row_ts, feature_val, distance_val in documents: + for row_ts, feature_val, vector_value, distance_val in documents: # Reset timestamp to default or update if row_ts is not None if row_ts is not None: row_ts_proto.FromDatetime(row_ts) - if feature_val is None or distance_val is None: + if feature_val is None or vector_value is None or distance_val is None: feature_val = Value() + vector_value = Value() distance_val = Value() status = FieldStatus.NOT_FOUND else: status = FieldStatus.PRESENT - read_row_protos.append((row_ts_proto, status, feature_val, distance_val)) + read_row_protos.append((row_ts_proto, status, feature_val, vector_value, distance_val)) return read_row_protos @staticmethod diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 2890f60746b..25b80ce001a 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -75,10 +75,7 @@ def online_write_batch( for feature_name, val in values.items(): vector_val = None - if ( - "pgvector_enabled" in config.online_store - and config.online_store.pgvector_enabled - ): + if config.online_store.pgvector_enabled: vector_val = get_list_val_str(val) insert_values.append( ( @@ -226,10 +223,7 @@ def update( for table in tables_to_keep: table_name = _table_id(project, table) - if ( - "pgvector_enabled" in config.online_store - and config.online_store.pgvector_enabled - ): + if config.online_store.pgvector_enabled: vector_value_type = f"vector({config.online_store.vector_len})" else: # keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility @@ -282,7 +276,7 @@ def retrieve_online_documents( requested_feature: str, embedding: List[float], top_k: int, - ) -> List[Tuple[Optional[datetime], Optional[ValueProto], 
Optional[ValueProto]]]: + ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto]]]: """ Args: @@ -297,10 +291,7 @@ def retrieve_online_documents( """ project = config.project - if ( - "pgvector_enabled" not in config.online_store - or not config.online_store.pgvector_enabled - ): + if not config.online_store.pgvector_enabled: raise ValueError( "pgvector is not enabled in the online store configuration" ) @@ -309,7 +300,7 @@ def retrieve_online_documents( query_embedding_str = f"[{','.join(str(el) for el in embedding)}]" result: List[ - Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]] + Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto]] ] = [] with self._get_conn(config) as conn, conn.cursor() as cur: table_name = _table_id(project, table) @@ -322,6 +313,7 @@ def retrieve_online_documents( SELECT entity_key, feature_name, + value, vector_value, vector_value <-> %s as distance, event_ts FROM {table_name} @@ -338,16 +330,17 @@ def retrieve_online_documents( ) rows = cur.fetchall() - for entity_key, feature_name, vector_value, distance, event_ts in rows: + for entity_key, feature_name, value, vector_value, distance, event_ts in rows: # TODO Deserialize entity_key to return the entity in response # entity_key_proto = EntityKeyProto() # entity_key_proto_bin = bytes(entity_key) - # TODO Convert to List[float] for value type proto - feature_value_proto = ValueProto(string_val=vector_value) + feature_value_proto = ValueProto() + feature_value_proto.ParseFromString(bytes(value)) + vector_value_proto = ValueProto(string_val=vector_value) distance_value_proto = ValueProto(float_val=distance) - result.append((event_ts, feature_value_proto, distance_value_proto)) + result.append((event_ts, feature_value_proto, vector_value_proto, distance_value_proto)) return result diff --git a/sdk/python/feast/infra/online_stores/online_store.py 
b/sdk/python/feast/infra/online_stores/online_store.py index fc1b3d4ad30..be31c38f1ef 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -142,7 +142,7 @@ def retrieve_online_documents( requested_feature: str, embedding: List[float], top_k: int, - ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]: + ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto]]]: """ Retrieves online feature values for the specified embeddings. diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index e71e87488d7..18aa7c411c5 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -303,7 +303,7 @@ def retrieve_online_documents( requested_feature: str, query: List[float], top_k: int, - ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]: + ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto]]]: """ Searches for the top-k nearest neighbors of the given document in the online document store. 
diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index 7ba4adb114b..7464ab864c5 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -111,5 +111,5 @@ def retrieve_online_documents( requested_feature: str, query: List[float], top_k: int, - ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]: + ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto]]]: return [] From d5d93e7aacc68f5d31771ab9df30af306e812c2c Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 16 Apr 2024 22:22:45 -0700 Subject: [PATCH 6/6] minor change on return value Signed-off-by: cmuhao --- sdk/python/feast/feature_store.py | 4 ++- .../infra/online_stores/contrib/postgres.py | 34 ++++++++++++++++--- .../feast/infra/online_stores/online_store.py | 9 ++++- sdk/python/feast/infra/provider.py | 9 ++++- sdk/python/tests/foo_provider.py | 9 ++++- 5 files changed, 57 insertions(+), 8 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 15ac0f20e41..f42cced11cf 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -2009,7 +2009,9 @@ def _retrieve_from_online_store( else: status = FieldStatus.PRESENT - read_row_protos.append((row_ts_proto, status, feature_val, vector_value, distance_val)) + read_row_protos.append( + (row_ts_proto, status, feature_val, vector_value, distance_val) + ) return read_row_protos @staticmethod diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 25b80ce001a..6ed0885d138 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -276,7 +276,14 @@ def retrieve_online_documents( requested_feature: str, embedding: List[float], top_k: int, - ) -> List[Tuple[Optional[datetime], Optional[ValueProto], 
Optional[ValueProto], Optional[ValueProto]]]: + ) -> List[ + Tuple[ + Optional[datetime], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], + ] + ]: """ Args: @@ -300,7 +307,12 @@ def retrieve_online_documents( query_embedding_str = f"[{','.join(str(el) for el in embedding)}]" result: List[ - Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto]] + Tuple[ + Optional[datetime], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], + ] ] = [] with self._get_conn(config) as conn, conn.cursor() as cur: table_name = _table_id(project, table) @@ -330,7 +342,14 @@ def retrieve_online_documents( ) rows = cur.fetchall() - for entity_key, feature_name, value, vector_value, distance, event_ts in rows: + for ( + entity_key, + feature_name, + value, + vector_value, + distance, + event_ts, + ) in rows: # TODO Deserialize entity_key to return the entity in response # entity_key_proto = EntityKeyProto() # entity_key_proto_bin = bytes(entity_key) @@ -340,7 +359,14 @@ def retrieve_online_documents( vector_value_proto = ValueProto(string_val=vector_value) distance_value_proto = ValueProto(float_val=distance) - result.append((event_ts, feature_value_proto, vector_value_proto, distance_value_proto)) + result.append( + ( + event_ts, + feature_value_proto, + vector_value_proto, + distance_value_proto, + ) + ) return result diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index be31c38f1ef..67c5a931dda 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -142,7 +142,14 @@ def retrieve_online_documents( requested_feature: str, embedding: List[float], top_k: int, - ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto]]]: + ) -> List[ + Tuple[ + Optional[datetime], + Optional[ValueProto], + Optional[ValueProto], + 
Optional[ValueProto], + ] + ]: """ Retrieves online feature values for the specified embeddings. diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 18aa7c411c5..a45051a1b6b 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -303,7 +303,14 @@ def retrieve_online_documents( requested_feature: str, query: List[float], top_k: int, - ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto]]]: + ) -> List[ + Tuple[ + Optional[datetime], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], + ] + ]: """ Searches for the top-k nearest neighbors of the given document in the online document store. diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index 7464ab864c5..2a830d424cc 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -111,5 +111,12 @@ def retrieve_online_documents( requested_feature: str, query: List[float], top_k: int, - ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto]]]: + ) -> List[ + Tuple[ + Optional[datetime], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], + ] + ]: return []