Skip to content
2 changes: 1 addition & 1 deletion sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class VersionedOnlineReadNotSupported(FeastError):
def __init__(self, store_name: str, version: int):
super().__init__(
f"Versioned feature reads (@v{version}) are not yet supported by {store_name}. "
f"Currently only SQLite supports version-qualified feature references. "
f"Currently only SQLite, PostgreSQL, and MySQL support version-qualified feature references. "
)


Expand Down
12 changes: 12 additions & 0 deletions sdk/python/feast/infra/online_stores/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,15 @@ def _to_naive_utc(ts: datetime) -> datetime:
return ts
else:
return ts.astimezone(tz=timezone.utc).replace(tzinfo=None)


def compute_table_id(project: str, table: Any, enable_versioning: bool = False) -> str:
"""Build the online-store table name, appending a version suffix when versioning is enabled."""
name = table.name
if enable_versioning:
version = getattr(table.projection, "version_tag", None)
if version is None:
version = getattr(table, "current_version_number", None)
if version is not None and version > 0:
name = f"{table.name}_v{version}"
return f"{project}_{name}"
72 changes: 54 additions & 18 deletions sdk/python/feast/infra/online_stores/mysql_online_store/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from feast import Entity, FeatureView, RepoConfig
from feast.infra.key_encoding_utils import serialize_entity_key
from feast.infra.online_stores.helpers import compute_table_id
from feast.infra.online_stores.online_store import OnlineStore
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
Expand Down Expand Up @@ -70,6 +71,7 @@ def online_write_batch(
cur = conn.cursor()

project = config.project
versioning = config.registry.enable_online_feature_view_versioning

batch_write = config.online_store.batch_write
if not batch_write:
Expand All @@ -92,6 +94,7 @@ def online_write_batch(
table,
timestamp,
val,
versioning,
)
conn.commit()
if progress:
Expand Down Expand Up @@ -124,7 +127,9 @@ def online_write_batch(

if len(insert_values) >= batch_size:
try:
self._execute_batch(cur, project, table, insert_values)
self._execute_batch(
cur, project, table, insert_values, versioning
)
conn.commit()
if progress:
progress(len(insert_values))
Expand All @@ -135,17 +140,20 @@ def online_write_batch(

if insert_values:
try:
self._execute_batch(cur, project, table, insert_values)
self._execute_batch(cur, project, table, insert_values, versioning)
conn.commit()
if progress:
progress(len(insert_values))
except Exception as e:
conn.rollback()
raise e

def _execute_batch(self, cur, project, table, insert_values):
sql = f"""
INSERT INTO {_table_id(project, table)}
def _execute_batch(
self, cur, project, table, insert_values, enable_versioning=False
):
table_name = _table_id(project, table, enable_versioning)
stmt = f"""
INSERT INTO {table_name}
(entity_key, feature_name, value, event_ts, created_ts)
values (%s, %s, %s, %s, %s)
ON DUPLICATE KEY UPDATE
Expand All @@ -154,22 +162,29 @@ def _execute_batch(self, cur, project, table, insert_values):
created_ts = VALUES(created_ts);
"""
try:
cur.executemany(sql, insert_values)
cur.executemany(stmt, insert_values)
except Exception as e:
# Log SQL info for debugging without leaking sensitive data
first_sample = insert_values[0] if insert_values else None
raise RuntimeError(
f"Failed to execute batch insert into table '{_table_id(project, table)}' "
f"Failed to execute batch insert into table '{table_name}' "
f"(rows={len(insert_values)}, sample={first_sample}): {e}"
) from e

@staticmethod
def write_to_table(
created_ts, cur, entity_key_bin, feature_name, project, table, timestamp, val
created_ts,
cur,
entity_key_bin,
feature_name,
project,
table,
timestamp,
val,
enable_versioning=False,
) -> None:
cur.execute(
f"""
INSERT INTO {_table_id(project, table)}
INSERT INTO {_table_id(project, table, enable_versioning)}
(entity_key, feature_name, value, event_ts, created_ts)
values (%s, %s, %s, %s, %s)
ON DUPLICATE KEY UPDATE
Expand Down Expand Up @@ -204,14 +219,15 @@ def online_read(
result: List[Tuple[Optional[datetime], Optional[Dict[str, Any]]]] = []

project = config.project
versioning = config.registry.enable_online_feature_view_versioning
for entity_key in entity_keys:
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=3,
).hex()

cur.execute(
f"SELECT feature_name, value, event_ts FROM {_table_id(project, table)} WHERE entity_key = %s",
f"SELECT feature_name, value, event_ts FROM {_table_id(project, table, versioning)} WHERE entity_key = %s",
(entity_key_bin,),
)

Expand Down Expand Up @@ -243,10 +259,11 @@ def update(
conn = self._get_conn(config)
cur = conn.cursor()
project = config.project
versioning = config.registry.enable_online_feature_view_versioning

# We don't create any special state for the entities in this implementation.
for table in tables_to_keep:
table_name = _table_id(project, table)
table_name = _table_id(project, table, versioning)
index_name = f"{table_name}_ek"
cur.execute(
f"""CREATE TABLE IF NOT EXISTS {table_name} (entity_key VARCHAR(512),
Expand All @@ -269,7 +286,10 @@ def update(
)

for table in tables_to_delete:
_drop_table_and_index(cur, project, table)
if versioning:
_drop_all_version_tables(cur, project, table)
else:
_drop_table_and_index(cur, _table_id(project, table))

def teardown(
self,
Expand All @@ -280,16 +300,32 @@ def teardown(
conn = self._get_conn(config)
cur = conn.cursor()
project = config.project
versioning = config.registry.enable_online_feature_view_versioning

for table in tables:
_drop_table_and_index(cur, project, table)
if versioning:
_drop_all_version_tables(cur, project, table)
else:
_drop_table_and_index(cur, _table_id(project, table))


def _drop_table_and_index(cur: Cursor, project: str, table: FeatureView) -> None:
table_name = _table_id(project, table)
def _drop_table_and_index(cur: Cursor, table_name: str) -> None:
cur.execute(f"DROP INDEX {table_name}_ek ON {table_name};")
cur.execute(f"DROP TABLE IF EXISTS {table_name}")


def _table_id(project: str, table: FeatureView) -> str:
return f"{project}_{table.name}"
def _drop_all_version_tables(cur: Cursor, project: str, table: FeatureView) -> None:
"""Drop the base table and all versioned tables (e.g. _v1, _v2, ...)."""
base = f"{project}_{table.name}"
cur.execute(
"SELECT table_name FROM information_schema.tables "
"WHERE table_schema = DATABASE() AND (table_name = %s OR table_name REGEXP %s)",
(base, f"^{base}_v[0-9]+$"),
)
for (name,) in cur.fetchall():
cur.execute(f"DROP INDEX IF EXISTS {name}_ek ON {name};")
cur.execute(f"DROP TABLE IF EXISTS {name}")


def _table_id(project: str, table: FeatureView, enable_versioning: bool = False) -> str:
return compute_table_id(project, table, enable_versioning)
20 changes: 19 additions & 1 deletion sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,25 @@ def _check_versioned_read_support(self, grouped_refs):
"""Raise an error if versioned reads are attempted on unsupported stores."""
from feast.infra.online_stores.sqlite import SqliteOnlineStore

if isinstance(self, SqliteOnlineStore):
supported_types: list[type] = [SqliteOnlineStore]
try:
from feast.infra.online_stores.mysql_online_store.mysql import (
MySQLOnlineStore,
)

supported_types.append(MySQLOnlineStore)
except ImportError:
pass
try:
from feast.infra.online_stores.postgres_online_store.postgres import (
PostgreSQLOnlineStore,
)

supported_types.append(PostgreSQLOnlineStore)
except ImportError:
pass

if isinstance(self, tuple(supported_types)):
return
for table, _ in grouped_refs:
version_tag = getattr(table.projection, "version_tag", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from feast import Entity, FeatureView, ValueType
from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key
from feast.infra.online_stores.helpers import _to_naive_utc
from feast.infra.online_stores.helpers import _to_naive_utc, compute_table_id
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.online_stores.vector_store import VectorStoreConfig
from feast.infra.utils.postgres.connection_utils import (
Expand Down Expand Up @@ -152,7 +152,15 @@ def online_write_batch(
event_ts = EXCLUDED.event_ts,
created_ts = EXCLUDED.created_ts;
"""
).format(sql.Identifier(_table_id(config.project, table)))
).format(
sql.Identifier(
_table_id(
config.project,
table,
config.registry.enable_online_feature_view_versioning,
)
)
)

# Push data into the online store
with self._get_conn(config) as conn, conn.cursor() as cur:
Expand Down Expand Up @@ -214,7 +222,13 @@ def _construct_query_and_params(
FROM {} WHERE entity_key = ANY(%s) AND feature_name = ANY(%s);
"""
).format(
sql.Identifier(_table_id(config.project, table)),
sql.Identifier(
_table_id(
config.project,
table,
config.registry.enable_online_feature_view_versioning,
)
),
)
params = (keys, requested_features)
else:
Expand All @@ -224,7 +238,13 @@ def _construct_query_and_params(
FROM {} WHERE entity_key = ANY(%s);
"""
).format(
sql.Identifier(_table_id(config.project, table)),
sql.Identifier(
_table_id(
config.project,
table,
config.registry.enable_online_feature_view_versioning,
)
),
)
params = (keys, [])
return query, params
Expand Down Expand Up @@ -304,12 +324,16 @@ def update(
),
)

versioning = config.registry.enable_online_feature_view_versioning
for table in tables_to_delete:
table_name = _table_id(project, table)
cur.execute(_drop_table_and_index(table_name))
if versioning:
_drop_all_version_tables(cur, project, table, schema_name)
else:
table_name = _table_id(project, table)
cur.execute(_drop_table_and_index(table_name))

for table in tables_to_keep:
table_name = _table_id(project, table)
table_name = _table_id(project, table, versioning)
if config.online_store.vector_enabled:
vector_value_type = "vector"
else:
Expand Down Expand Up @@ -363,11 +387,16 @@ def teardown(
entities: Sequence[Entity],
):
project = config.project
schema_name = config.online_store.db_schema or config.online_store.user
versioning = config.registry.enable_online_feature_view_versioning
try:
with self._get_conn(config) as conn, conn.cursor() as cur:
for table in tables:
table_name = _table_id(project, table)
cur.execute(_drop_table_and_index(table_name))
if versioning:
_drop_all_version_tables(cur, project, table, schema_name)
else:
table_name = _table_id(project, table)
cur.execute(_drop_table_and_index(table_name))
conn.commit()
except Exception:
logging.exception("Teardown failed")
Expand Down Expand Up @@ -432,7 +461,9 @@ def retrieve_online_documents(
]
] = []
with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur:
table_name = _table_id(project, table)
table_name = _table_id(
project, table, config.registry.enable_online_feature_view_versioning
)

# Search query template to find the top k items that are closest to the given embedding
# SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5;
Expand Down Expand Up @@ -533,7 +564,11 @@ def retrieve_online_documents_v2(
and feature.name in requested_features
]

table_name = _table_id(config.project, table)
table_name = _table_id(
config.project,
table,
config.registry.enable_online_feature_view_versioning,
)

with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur:
query = None
Expand Down Expand Up @@ -794,8 +829,8 @@ def retrieve_online_documents_v2(
return result


def _table_id(project: str, table: FeatureView) -> str:
return f"{project}_{table.name}"
def _table_id(project: str, table: FeatureView, enable_versioning: bool = False) -> str:
return compute_table_id(project, table, enable_versioning)


def _drop_table_and_index(table_name):
Expand All @@ -808,3 +843,23 @@ def _drop_table_and_index(table_name):
sql.Identifier(table_name),
sql.Identifier(f"{table_name}_ek"),
)


def _drop_all_version_tables(
cur, project: str, table: FeatureView, schema_name: Optional[str] = None
) -> None:
"""Drop the base table and all versioned tables (e.g. _v1, _v2, ...)."""
base = f"{project}_{table.name}"
if schema_name:
cur.execute(
"SELECT tablename FROM pg_tables "
"WHERE schemaname = %s AND (tablename = %s OR tablename ~ %s)",
(schema_name, base, f"^{base}_v[0-9]+$"),
)
else:
cur.execute(
"SELECT tablename FROM pg_tables WHERE tablename = %s OR tablename ~ %s",
(base, f"^{base}_v[0-9]+$"),
)
for (name,) in cur.fetchall():
cur.execute(_drop_table_and_index(name))
Loading
Loading