diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index a093795c42a..a4b50481e38 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -38,7 +38,7 @@ ) from fastapi.concurrency import run_in_threadpool from fastapi.logger import logger -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, ORJSONResponse from fastapi.staticfiles import StaticFiles from google.protobuf.json_format import MessageToDict from pydantic import BaseModel, field_validator @@ -50,9 +50,11 @@ from feast.data_source import PushMode from feast.errors import ( FeastError, + FeatureViewNotFoundException, ) from feast.feast_object import FeastObject from feast.feature_view_utils import get_feature_view_from_feature_store +from feast.filter_models import ComparisonFilter, CompoundFilter from feast.permissions.action import WRITE, AuthzedAction from feast.permissions.security_manager import assert_permissions from feast.permissions.server.rest import inject_user_details @@ -108,7 +110,42 @@ class GetOnlineDocumentsRequest(BaseModel): top_k: Optional[int] = None query: Optional[List[float]] = None query_string: Optional[str] = None + distance_metric: Optional[str] = None api_version: Optional[int] = 1 + filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None + + +class OpenAISearchMetadata(BaseModel): + features_to_retrieve: Optional[List[str]] = None + content_field: Optional[str] = None + + +class OpenAIComparisonFilter(BaseModel): + key: str + type: str + value: Union[str, int, float, bool, List[Union[str, int]]] + + +class OpenAICompoundFilter(BaseModel): + type: str + filters: List[Union[OpenAIComparisonFilter, "OpenAICompoundFilter"]] + + +OpenAICompoundFilter.model_rebuild() + + +class OpenAIRankingOptions(BaseModel): + ranker: Optional[str] = None + score_threshold: Optional[float] = None + + +class OpenAISearchRequest(BaseModel): + query: Union[str, List[str]] + filters: 
Optional[Union[OpenAIComparisonFilter, OpenAICompoundFilter]] = None + max_num_results: Optional[int] = 10 + ranking_options: Optional[OpenAIRankingOptions] = None + rewrite_query: Optional[bool] = None + metadata: Optional[OpenAISearchMetadata] = None class FeatureVectorResponse(BaseModel): @@ -418,6 +455,10 @@ async def retrieve_online_documents( ) if request.api_version == 2 and request.query_string is not None: read_params["query_string"] = request.query_string + if request.api_version == 2 and request.distance_metric is not None: + read_params["distance_metric"] = request.distance_metric + if request.api_version == 2 and request.filters is not None: + read_params["filters"] = request.filters if request.api_version == 2: response = await run_in_threadpool( @@ -436,6 +477,51 @@ async def retrieve_online_documents( ) return response_dict + @app.post( + "/v1/vector_stores/{vector_store_id}/search", + dependencies=[Depends(inject_user_details)], + ) + async def openai_vector_store_search( + vector_store_id: str, + request: OpenAISearchRequest, + ) -> ORJSONResponse: + with feast_metrics.track_request_latency( + "/v1/vector_stores/{vector_store_id}/search" + ): + try: + result = await run_in_threadpool( + lambda: store.retrieve_online_documents_openai( + vector_store_id=vector_store_id, + query=request.query, + max_num_results=request.max_num_results or 10, + filters=( + request.filters.model_dump() if request.filters else None + ), + ranking_options=( + request.ranking_options.model_dump() + if request.ranking_options + else None + ), + rewrite_query=request.rewrite_query, + features_to_retrieve=( + request.metadata.features_to_retrieve + if request.metadata + else None + ), + ) + ) + except FeatureViewNotFoundException: + return ORJSONResponse( + status_code=404, + content={ + "error": { + "message": f"No vector store found with id '{vector_store_id}'", + "type": "not_found_error", + } + }, + ) + return ORJSONResponse(content=result) + @app.post("/push", 
dependencies=[Depends(inject_user_details)]) async def push(request: PushFeaturesRequest) -> Response: with feast_metrics.track_request_latency("/push"): diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index fe0e7967345..41d29d3a7ec 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -69,6 +69,7 @@ from feast.feast_object import FeastObject from feast.feature_service import FeatureService from feast.feature_view import DUMMY_ENTITY, DUMMY_ENTITY_NAME, FeatureView +from feast.filter_models import ComparisonFilter, CompoundFilter, convert_dict_to_filter from feast.inference import ( update_data_sources_with_inferred_event_timestamp_col, update_feature_views_with_inferred_features_and_entities, @@ -2677,6 +2678,7 @@ def retrieve_online_documents_v2( text_weight: float = 0.5, image_weight: float = 0.5, combine_strategy: str = "weighted_sum", + filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None, ) -> OnlineResponse: """ Retrieves the top k closest document features. Note, embeddings are a subset of features. @@ -2829,8 +2831,164 @@ def retrieve_online_documents_v2( top_k, distance_metric, query_string, + filters, ) + def retrieve_online_documents_openai( + self, + vector_store_id: str, + query: Union[str, List[str]], + max_num_results: int = 10, + filters: Optional[Dict[str, Any]] = None, + ranking_options: Optional[Dict[str, Any]] = None, + rewrite_query: Optional[bool] = None, + features_to_retrieve: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """ + OpenAI-compatible vector store search. + + Accepts a raw query string, optionally embeds it via LiteLLM + (when ``query_embedding_model`` is configured in feature_store.yaml), + and returns results in OpenAI's ``vector_store.search_results.page`` + format. + + Args: + vector_store_id: Feature view name (maps to the OpenAI + ``vector_store_id`` path parameter). + query: Natural language query string, or list of strings. 
+ max_num_results: Maximum number of results to return. + filters: OpenAI-compatible filters; converted to typed + Feast filters and applied to the search. + ranking_options: OpenAI-compatible ranking options (accepted + but not yet applied). + rewrite_query: Whether to rewrite the query (accepted but + not yet applied). + features_to_retrieve: Specific feature names to return. + If None, all features from the feature view are used. + + Returns: + Dict matching the OpenAI ``vector_store.search_results.page`` + schema. + + Examples: + Basic search (uses the configured embedding model):: + + result = store.retrieve_online_documents_openai( + vector_store_id="city_embeddings", + query="cities in California", + max_num_results=5, + ) + + Vector search (embedding model configured in YAML):: + + # feature_store.yaml has: + # embedding_model: + # model: text-embedding-3-small + result = store.retrieve_online_documents_openai( + vector_store_id="product_embeddings", + query="wireless audio device", + max_num_results=3, + features_to_retrieve=["name", "description"], + ) + """ + feature_view = self.get_feature_view(vector_store_id) + + if features_to_retrieve: + feature_names = features_to_retrieve + else: + feature_names = [f.name for f in feature_view.features] + + features = [f"{feature_view.name}:{name}" for name in feature_names] + query_text = query if isinstance(query, str) else " ".join(query) + + embed_cfg = self.config.embedding_model + if embed_cfg is None: + raise ValueError( + "embedding_model is not configured in feature_store.yaml. " + "Add an 'embedding_model' section with at least a 'model' " + "field to use retrieve_online_documents_openai.\n" + "Example:\n" + " embedding_model:\n" + " model: text-embedding-3-small\n" + " api_key: sk-..." + ) + + try: + from litellm import embedding as litellm_embedding + except ImportError: + raise ImportError( + "litellm is required for query embedding. 
" + "Install with: pip install litellm" + ) + + litellm_kwargs: Dict[str, Any] = { + "model": embed_cfg.model, + "input": [query_text], + } + if embed_cfg.api_key: + litellm_kwargs["api_key"] = embed_cfg.api_key + if embed_cfg.api_base: + litellm_kwargs["api_base"] = embed_cfg.api_base + if embed_cfg.api_version: + litellm_kwargs["api_version"] = embed_cfg.api_version + if embed_cfg.dimensions: + litellm_kwargs["dimensions"] = embed_cfg.dimensions + + embed_response = litellm_embedding(**litellm_kwargs) + query_embedding = embed_response.data[0]["embedding"] + + typed_filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None + if filters is not None: + typed_filters = convert_dict_to_filter(filters) + + response = self.retrieve_online_documents_v2( + features=features, + query=query_embedding, + top_k=max_num_results, + filters=typed_filters, + ) + + response_dict = response.to_dict() + + result_data = [] + if response_dict: + first_key = next(iter(response_dict)) + num_rows = len(response_dict.get(first_key, [])) + for i in range(num_rows): + score = 0.0 + attributes: Dict[str, Any] = {} + content_parts: List[Dict[str, str]] = [] + + for key, values in response_dict.items(): + val = values[i] if i < len(values) else None + if key == "distance": + score = float(val) if val is not None else 0.0 + else: + attributes[key] = val + if isinstance(val, str): + content_parts.append({"type": "text", "text": val}) + + result_data.append( + { + "file_id": f"{vector_store_id}_{i}", + "filename": vector_store_id, + "score": score, + "attributes": attributes, + "content": content_parts + if content_parts + else [{"type": "text", "text": str(attributes)}], + } + ) + + search_query = query if isinstance(query, list) else [query] + return { + "object": "vector_store.search_results.page", + "search_query": search_query, + "data": result_data, + "has_more": False, + "next_page": None, + } + def _retrieve_from_online_store( self, provider: Provider, @@ -2893,6 +3051,7 @@ def 
_retrieve_from_online_store_v2( top_k: int, distance_metric: Optional[str], query_string: Optional[str], + filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None, ) -> OnlineResponse: """ Search and return document features from the online document store. @@ -2909,6 +3068,7 @@ def _retrieve_from_online_store_v2( top_k=top_k, distance_metric=distance_metric, query_string=query_string, + filters=filters, ) entity_key_dict: Dict[str, List[ValueProto]] = {} diff --git a/sdk/python/feast/filter_models.py b/sdk/python/feast/filter_models.py new file mode 100644 index 00000000000..7b5c94bff39 --- /dev/null +++ b/sdk/python/feast/filter_models.py @@ -0,0 +1,50 @@ +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel + + +class ComparisonFilter(BaseModel): + """A filter that compares a metadata field against a value. + + :param type: The comparison operator to apply + :param key: The metadata field name to filter on + :param value: The value to compare against + """ + + type: Literal["eq", "ne", "gt", "gte", "lt", "lte", "in", "nin"] + key: str + value: Any + + +class CompoundFilter(BaseModel): + """A filter that combines multiple filters with a logical operator. + + :param type: The logical operator ("and" requires all filters match, + "or" requires any filter matches) + :param filters: The list of filters to combine + """ + + type: Literal["and", "or"] + filters: List[Union[ComparisonFilter, "CompoundFilter"]] + + +CompoundFilter.model_rebuild() + +FilterType = Optional[Union[ComparisonFilter, CompoundFilter]] + + +def convert_dict_to_filter( + filter_dict: Dict[str, Any], +) -> Union[ComparisonFilter, CompoundFilter]: + """Convert a raw dict (e.g. 
from OpenAI-compatible JSON) into a typed filter object.""" + filter_type = filter_dict.get("type") + if filter_type in ("and", "or"): + return CompoundFilter( + type=filter_type, + filters=[convert_dict_to_filter(f) for f in filter_dict["filters"]], + ) + return ComparisonFilter( + type=filter_dict["type"], + key=filter_dict["key"], + value=filter_dict["value"], + ) diff --git a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py index 7e8e533281d..ad292ce6da7 100644 --- a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py +++ b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py @@ -5,11 +5,12 @@ import logging from collections import defaultdict from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from elasticsearch import Elasticsearch, helpers from feast import Entity, FeatureView, RepoConfig +from feast.filter_models import ComparisonFilter, CompoundFilter from feast.infra.key_encoding_utils import ( deserialize_entity_key, get_list_val_str, @@ -340,6 +341,79 @@ def retrieve_online_documents( ) return result + def _translate_filters( + self, + filters: Optional[Union[ComparisonFilter, CompoundFilter]], + ) -> List[Dict[str, Any]]: + """Translate filter objects into Elasticsearch Query DSL filter clauses. + + Returns a list of ES filter clause dicts suitable for insertion into + a ``bool.filter`` array. Returns an empty list when no filters are + provided. 
+ """ + if filters is None: + return [] + return [self._translate_single_filter(filters)] + + def _translate_single_filter( + self, + filter_obj: Union[ComparisonFilter, CompoundFilter], + ) -> Dict[str, Any]: + if isinstance(filter_obj, ComparisonFilter): + return self._translate_comparison_filter(filter_obj) + elif isinstance(filter_obj, CompoundFilter): + return self._translate_compound_filter(filter_obj) + raise ValueError(f"Unknown filter type: {type(filter_obj)}") + + def _translate_comparison_filter( + self, + f: ComparisonFilter, + ) -> Dict[str, Any]: + """Translate a ComparisonFilter to an ES Query DSL clause. + + Feature values in Elasticsearch are stored under + ``.value_text``, so filters target that nested path. + """ + field = f"{f.key}.value_text" + + if f.type == "eq": + return {"term": {field: str(f.value)}} + elif f.type == "ne": + return {"bool": {"must_not": [{"term": {field: str(f.value)}}]}} + elif f.type == "gt": + return {"range": {field: {"gt": f.value}}} + elif f.type == "gte": + return {"range": {field: {"gte": f.value}}} + elif f.type == "lt": + return {"range": {field: {"lt": f.value}}} + elif f.type == "lte": + return {"range": {field: {"lte": f.value}}} + elif f.type == "in": + if not isinstance(f.value, list): + raise ValueError( + f"'in' filter requires a list value, got {type(f.value)}" + ) + return {"terms": {field: [str(v) for v in f.value]}} + elif f.type == "nin": + if not isinstance(f.value, list): + raise ValueError( + f"'nin' filter requires a list value, got {type(f.value)}" + ) + return { + "bool": {"must_not": [{"terms": {field: [str(v) for v in f.value]}}]} + } + raise ValueError(f"Unsupported comparison operator: {f.type}") + + def _translate_compound_filter( + self, + f: CompoundFilter, + ) -> Dict[str, Any]: + clauses = [self._translate_single_filter(sub) for sub in f.filters] + if f.type == "and": + return {"bool": {"must": clauses}} + else: + return {"bool": {"should": clauses, "minimum_should_match": 1}} + def 
retrieve_online_documents_v2( self, config: RepoConfig, @@ -349,6 +423,7 @@ def retrieve_online_documents_v2( top_k: int, distance_metric: Optional[str] = None, query_string: Optional[str] = None, + filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None, ) -> List[ Tuple[ Optional[datetime], @@ -383,6 +458,8 @@ def retrieve_online_documents_v2( source_fields += composite_key_name body["_source"] = source_fields + metadata_filters = self._translate_filters(filters) + if embedding: similarity = (distance_metric or config.online_store.similarity).lower() vector_field_path = ( @@ -399,37 +476,47 @@ def retrieve_online_documents_v2( f"Unsupported similarity/distance_metric: {similarity}" ) - # Hybrid search if embedding and query_string: + bool_clause: Dict[str, Any] = { + "must": [ + {"query_string": {"query": f'"{query_string}"'}}, + {"exists": {"field": vector_field_path}}, + ] + } + if metadata_filters: + bool_clause["filter"] = metadata_filters body["query"] = { "script_score": { - "query": { - "bool": { - "must": [ - {"query_string": {"query": f'"{query_string}"'}}, - {"exists": {"field": vector_field_path}}, - ] - } - }, + "query": {"bool": bool_clause}, "script": { "source": script, "params": {"query_vector": embedding}, }, } } - # Vector search only elif embedding: + filter_clauses: List[Dict[str, Any]] = [ + {"exists": {"field": vector_field_path}} + ] + filter_clauses.extend(metadata_filters) body["query"] = { "script_score": { - "query": { - "bool": {"filter": [{"exists": {"field": vector_field_path}}]} - }, + "query": {"bool": {"filter": filter_clauses}}, "script": {"source": script, "params": {"query_vector": embedding}}, } } - # Keyword search only elif query_string: - body["query"] = {"query_string": {"query": f'"{query_string}"'}} + if metadata_filters: + body["query"] = { + "bool": { + "must": [ + {"query_string": {"query": f'"{query_string}"'}}, + ], + "filter": metadata_filters, + } + } + else: + body["query"] = {"query_string": {"query": 
f'"{query_string}"'}} response = self._get_client(config).search(index=es_index, body=body) @@ -443,7 +530,6 @@ def retrieve_online_documents_v2( timestamp = row["_source"]["timestamp"] timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") - # Create feature dict with all requested features feature_dict = {"distance": _to_value_proto(float(row["_score"]))} if query_string is not None: feature_dict["text_rank"] = _to_value_proto(float(row["_score"])) diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py index fb812f82b7b..0572f0dc683 100644 --- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py +++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py @@ -13,6 +13,7 @@ from feast import Entity from feast.feature_view import FeatureView +from feast.filter_models import ComparisonFilter, CompoundFilter from feast.infra.infra_object import InfraObject from feast.infra.key_encoding_utils import ( deserialize_entity_key, @@ -93,6 +94,17 @@ FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = milvus_type +def _milvus_fmt(value: Any) -> str: + """Format a Python value for use in a Milvus boolean expression. + + Feast's Milvus store maps all non-vector features to VARCHAR, so + every value is formatted as a quoted string. + """ + s = str(value) + escaped = s.replace("'", "\\'") + return f"'{escaped}'" + + class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): """ Configuration for the Milvus online store. @@ -513,6 +525,74 @@ def teardown( self.client.drop_collection(collection_name) self._collections.pop(collection_name, None) + def _translate_filters( + self, + filters: Optional[Union[ComparisonFilter, CompoundFilter]], + ) -> Optional[str]: + """Translate filter objects into a Milvus expression string. 
+ + Returns a Milvus boolean expression or ``None`` when no filters are + provided, so callers can pass the result directly to ``filter=``. + """ + if filters is None: + return None + return self._translate_single_filter(filters) + + def _translate_single_filter( + self, + filter_obj: Union[ComparisonFilter, CompoundFilter], + ) -> str: + if isinstance(filter_obj, ComparisonFilter): + return self._translate_comparison_filter(filter_obj) + elif isinstance(filter_obj, CompoundFilter): + return self._translate_compound_filter(filter_obj) + raise ValueError(f"Unknown filter type: {type(filter_obj)}") + + def _translate_comparison_filter( + self, + filter_obj: ComparisonFilter, + ) -> str: + """Translate a comparison filter to a Milvus boolean expression. + + Feast's Milvus store maps all non-vector features to VARCHAR + columns, so values are always formatted as quoted strings. + """ + key, value, op_type = filter_obj.key, filter_obj.value, filter_obj.type + + milvus_ops = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="} + + if op_type == "eq": + return f"{key} == {_milvus_fmt(value)}" + elif op_type == "ne": + return f"{key} != {_milvus_fmt(value)}" + elif op_type in milvus_ops: + return f"{key} {milvus_ops[op_type]} {_milvus_fmt(value)}" + elif op_type in ("in", "nin"): + if not isinstance(value, list): + raise ValueError( + f"'{op_type}' filter requires a list value, got {type(value)}" + ) + formatted = [_milvus_fmt(v) for v in value] + kw = "not in" if op_type == "nin" else "in" + return f"{key} {kw} [{', '.join(formatted)}]" + raise ValueError(f"Unsupported comparison operator: {op_type}") + + def _translate_compound_filter( + self, + filter_obj: CompoundFilter, + ) -> str: + if not filter_obj.filters: + return "" + clauses = [] + for sub_filter in filter_obj.filters: + clause = self._translate_single_filter(sub_filter) + if clause: + clauses.append(f"({clause})") + if not clauses: + return "" + operator = " and " if filter_obj.type == "and" else " or " + return 
operator.join(clauses) + def retrieve_online_documents_v2( self, config: RepoConfig, @@ -522,6 +602,7 @@ def retrieve_online_documents_v2( top_k: int, distance_metric: Optional[str] = None, query_string: Optional[str] = None, + filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None, ) -> List[ Tuple[ Optional[datetime], @@ -581,6 +662,17 @@ def retrieve_online_documents_v2( self.client.load_collection(collection_name) + metadata_filter_expr = self._translate_filters(filters) + + def _combine_exprs(*parts: Optional[str]) -> Optional[str]: + """Combine non-empty Milvus boolean expressions with AND.""" + active = [p for p in parts if p] + if not active: + return None + if len(active) == 1: + return active[0] + return " and ".join(f"({p})" for p in active) + if ( embedding is not None and query_string is not None @@ -598,22 +690,19 @@ def retrieve_online_documents_v2( "No string fields found in the feature view for text search in hybrid mode" ) - # Create a filter expression for text search filter_expressions = [] for field in string_field_list: if field in output_fields: filter_expressions.append(f"{field} LIKE '%{query_string}%'") - # Combine filter expressions with OR - filter_expr = " OR ".join(filter_expressions) if filter_expressions else "" + text_filter = " OR ".join(filter_expressions) if filter_expressions else "" + combined_filter = _combine_exprs(text_filter, metadata_filter_expr) - # Vector search with text filter search_params = { "metric_type": distance_metric or config.online_store.metric_type, "params": {"nprobe": 10}, } - # For hybrid search, use filter parameter instead of expr results = self.client.search( collection_name=collection_name, data=[embedding], @@ -621,7 +710,7 @@ def retrieve_online_documents_v2( search_params=search_params, limit=top_k, output_fields=output_fields, - filter=filter_expr if filter_expr else None, + filter=combined_filter, ) elif embedding is not None and config.online_store.vector_enabled: @@ -638,6 +727,7 @@ 
def retrieve_online_documents_v2( search_params=search_params, limit=top_k, output_fields=output_fields, + filter=metadata_filter_expr, ) elif query_string is not None: @@ -658,16 +748,18 @@ def retrieve_online_documents_v2( if field in output_fields: filter_expressions.append(f"{field} LIKE '%{query_string}%'") - filter_expr = " OR ".join(filter_expressions) + text_filter = " OR ".join(filter_expressions) - if not filter_expr: + if not text_filter: raise ValueError( "No text fields found in requested features for search" ) + combined_filter = _combine_exprs(text_filter, metadata_filter_expr) + query_results = self.client.query( collection_name=collection_name, - filter=filter_expr, + filter=combined_filter or text_filter, output_fields=output_fields, limit=top_k, ) diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index b77185229d5..0e2c7302fec 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -20,6 +20,7 @@ from feast.batch_feature_view import BatchFeatureView from feast.feature_service import FeatureService from feast.feature_view import FeatureView +from feast.filter_models import ComparisonFilter, CompoundFilter from feast.infra.infra_object import InfraObject from feast.infra.registry.base_registry import BaseRegistry from feast.infra.supported_async_methods import SupportedAsyncMethods @@ -436,6 +437,7 @@ def retrieve_online_documents_v2( top_k: int, distance_metric: Optional[str] = None, query_string: Optional[str] = None, + filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None, ) -> List[ Tuple[ Optional[datetime], @@ -454,6 +456,8 @@ def retrieve_online_documents_v2( embedding: The embeddings to use for retrieval (optional) top_k: The number of documents to retrieve. 
query_string: The query string to search for using keyword search (bm25) (optional) + filters: Optional metadata filters (ComparisonFilter or CompoundFilter) + to narrow results before ranking. Returns: object: A list of top k closest documents to the specified embedding. Each item in the list is a tuple diff --git a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py index e252280285e..0daa753c94c 100644 --- a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py +++ b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py @@ -21,6 +21,7 @@ from psycopg_pool import AsyncConnectionPool, ConnectionPool from feast import Entity, FeatureView, ValueType +from feast.filter_models import ComparisonFilter, CompoundFilter from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key from feast.infra.online_stores.helpers import _to_naive_utc from feast.infra.online_stores.online_store import OnlineStore @@ -44,6 +45,15 @@ "inner_product": "<#>", } +_PG_COMPARISON_OPS: Dict[str, str] = { + "eq": "=", + "ne": "!=", + "gt": ">", + "gte": ">=", + "lt": "<", + "lte": "<=", +} + class PostgreSQLOnlineStoreConfig(PostgreSQLConfig, VectorStoreConfig): type: Literal["postgres"] = "postgres" @@ -119,10 +129,21 @@ def online_write_batch( for feature_name, val in values.items(): vector_val = None value_text = None + value_num = None - # Check if the feature type is STRING - if val.WhichOneof("val") == "string_val": + val_type = val.WhichOneof("val") + if val_type == "string_val": value_text = val.string_val + elif val_type == "int64_val": + value_num = float(val.int64_val) + elif val_type == "int32_val": + value_num = float(val.int32_val) + elif val_type == "double_val": + value_num = val.double_val + elif val_type == "float_val": + value_num = float(val.float_val) + elif val_type == "bool_val": + value_num = 1.0 if val.bool_val else 0.0 if 
config.online_store.vector_enabled: vector_val = get_list_val_str(val) @@ -132,6 +153,7 @@ def online_write_batch( feature_name, val.SerializeToString(), value_text, + value_num, vector_val, timestamp, created_ts, @@ -142,12 +164,13 @@ def online_write_batch( sql_query = sql.SQL( """ INSERT INTO {} - (entity_key, feature_name, value, value_text, vector_value, event_ts, created_ts) - VALUES (%s, %s, %s, %s, %s, %s, %s) + (entity_key, feature_name, value, value_text, value_num, vector_value, event_ts, created_ts) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT (entity_key, feature_name) DO UPDATE SET value = EXCLUDED.value, value_text = EXCLUDED.value_text, + value_num = EXCLUDED.value_num, vector_value = EXCLUDED.vector_value, event_ts = EXCLUDED.event_ts, created_ts = EXCLUDED.created_ts; @@ -328,7 +351,8 @@ def update( entity_key BYTEA, feature_name TEXT, value BYTEA, - value_text TEXT NULL, -- Added for FTS + value_text TEXT NULL, + value_num DOUBLE PRECISION NULL, vector_value {} NULL, event_ts TIMESTAMPTZ, created_ts TIMESTAMPTZ, @@ -344,6 +368,12 @@ def update( ) ) + cur.execute( + sql.SQL( + """ALTER TABLE {} ADD COLUMN IF NOT EXISTS value_num DOUBLE PRECISION NULL;""" + ).format(sql.Identifier(table_name)) + ) + if has_string_features: cur.execute( sql.SQL( @@ -479,6 +509,130 @@ def retrieve_online_documents( return result + def _translate_filters( + self, + filters: Optional[Union[ComparisonFilter, CompoundFilter]], + table_name: str, + alias: Optional[str] = None, + ) -> Tuple[sql.Composable, List[Any]]: + """Translate filter objects into a SQL WHERE clause fragment and params. + + Returns a ``(clause, params)`` pair. When *filters* is ``None`` the + clause is empty and params is ``[]``, so callers can always append it + unconditionally. + + When *alias* is given (e.g. ``"t1"``), the generated clause references + ``t1.entity_key`` instead of the bare ``entity_key`` column, which is + necessary when the outer query joins multiple relations. 
+ """ + if filters is None: + return sql.SQL(""), [] + return self._translate_single_filter(filters, table_name, alias) + + def _translate_single_filter( + self, + filter_obj: Union[ComparisonFilter, CompoundFilter], + table_name: str, + alias: Optional[str] = None, + ) -> Tuple[sql.Composable, List[Any]]: + if isinstance(filter_obj, ComparisonFilter): + return self._translate_comparison_filter(filter_obj, table_name, alias) + elif isinstance(filter_obj, CompoundFilter): + return self._translate_compound_filter(filter_obj, table_name, alias) + raise ValueError(f"Unknown filter type: {type(filter_obj)}") + + @staticmethod + def _filter_col_and_val( + value: Any, + ) -> Tuple[str, Any]: + """Return the appropriate column name and DB-ready value for a filter value.""" + if isinstance(value, bool): + return "value_num", 1.0 if value else 0.0 + if isinstance(value, (int, float)): + return "value_num", float(value) + return "value_text", str(value) + + def _translate_comparison_filter( + self, + filter_obj: ComparisonFilter, + table_name: str, + alias: Optional[str] = None, + ) -> Tuple[sql.Composable, List[Any]]: + key, value, op_type = filter_obj.key, filter_obj.value, filter_obj.type + ek_col = f"{alias}.entity_key" if alias else "entity_key" + + if op_type in _PG_COMPARISON_OPS: + col, db_value = self._filter_col_and_val(value) + clause = sql.SQL( + "{ek_col} IN (SELECT entity_key FROM {tbl} WHERE feature_name = %s AND {col} {op} %s)" + ).format( + ek_col=sql.SQL(ek_col), + tbl=sql.Identifier(table_name), + col=sql.Identifier(col), + op=sql.SQL(_PG_COMPARISON_OPS[op_type]), + ) + return clause, [key, db_value] + + if op_type == "in": + if not isinstance(value, list): + raise ValueError( + f"'in' filter requires a list value, got {type(value)}" + ) + placeholders = sql.SQL(", ").join([sql.Placeholder()] * len(value)) + col, _ = ( + self._filter_col_and_val(value[0]) if value else ("value_text", None) + ) + db_values = [self._filter_col_and_val(v)[1] for v in value] + 
clause = sql.SQL( + "{ek_col} IN (SELECT entity_key FROM {tbl} WHERE feature_name = %s AND {col} IN ({phs}))" + ).format( + ek_col=sql.SQL(ek_col), + tbl=sql.Identifier(table_name), + col=sql.Identifier(col), + phs=placeholders, + ) + return clause, [key] + db_values + + if op_type == "nin": + if not isinstance(value, list): + raise ValueError( + f"'nin' filter requires a list value, got {type(value)}" + ) + placeholders = sql.SQL(", ").join([sql.Placeholder()] * len(value)) + col, _ = ( + self._filter_col_and_val(value[0]) if value else ("value_text", None) + ) + db_values = [self._filter_col_and_val(v)[1] for v in value] + clause = sql.SQL( + "{ek_col} IN (SELECT entity_key FROM {tbl} WHERE feature_name = %s AND {col} NOT IN ({phs}))" + ).format( + ek_col=sql.SQL(ek_col), + tbl=sql.Identifier(table_name), + col=sql.Identifier(col), + phs=placeholders, + ) + return clause, [key] + db_values + + raise ValueError(f"Unknown comparison operator: {op_type}") + + def _translate_compound_filter( + self, + filter_obj: CompoundFilter, + table_name: str, + alias: Optional[str] = None, + ) -> Tuple[sql.Composable, List[Any]]: + parts: List[sql.Composable] = [] + all_params: List[Any] = [] + for sub in filter_obj.filters: + sub_clause, sub_params = self._translate_single_filter( + sub, table_name, alias + ) + parts.append(sub_clause) + all_params.extend(sub_params) + joiner = sql.SQL(" AND " if filter_obj.type == "and" else " OR ") + combined = sql.SQL("(") + joiner.join(parts) + sql.SQL(")") + return combined, all_params + def retrieve_online_documents_v2( self, config: RepoConfig, @@ -488,6 +642,7 @@ def retrieve_online_documents_v2( top_k: int, distance_metric: Optional[str] = None, query_string: Optional[str] = None, + filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None, ) -> List[ Tuple[ Optional[datetime], @@ -538,9 +693,24 @@ def retrieve_online_documents_v2( query = None params: Any = None + filter_clause, filter_params = self._translate_filters( + 
filters, table_name, alias="t1" + ) + has_filters = bool(filter_params) or ( + filters is not None and not filter_params + ) + if embedding is not None and query_string is not None and string_fields: # Case 1: Hybrid Search (vector + text) tsquery_str = " & ".join(query_string.split()) + + outer_where_parts: list[sql.Composable] = [ + sql.SQL("t1.feature_name = ANY(%s)") + ] + if has_filters: + outer_where_parts.append(filter_clause) + outer_where = sql.SQL(" AND ").join(outer_where_parts) + query = sql.SQL( """ WITH vector_candidates AS ( @@ -600,15 +770,16 @@ def retrieve_online_documents_v2( t1.created_ts FROM {table_name} t1 INNER JOIN scored s ON t1.entity_key = s.entity_key - WHERE t1.feature_name = ANY(%s) + WHERE {outer_where} ORDER BY s.text_rank DESC, s.distance """ ).format( distance_metric_sql=sql.SQL(distance_metric_sql), table_name=sql.Identifier(table_name), + outer_where=outer_where, top_k=sql.Literal(top_k), ) - params = ( + base_params: list[Any] = [ embedding, tsquery_str, string_fields, @@ -617,9 +788,18 @@ def retrieve_online_documents_v2( tsquery_str, string_fields, requested_features, - ) + ] + if has_filters: + base_params.extend(filter_params) + params = tuple(base_params) + elif embedding is not None: # Case 2: Vector Search Only + outer_where_parts = [sql.SQL("t1.feature_name = ANY(%s)")] + if has_filters: + outer_where_parts.append(filter_clause) + outer_where = sql.SQL(" AND ").join(outer_where_parts) + query = sql.SQL( """ WITH vector_matches AS ( @@ -642,25 +822,36 @@ def retrieve_online_documents_v2( t1.created_ts FROM {table_name} t1 INNER JOIN vector_matches t2 ON t1.entity_key = t2.entity_key - WHERE t1.feature_name = ANY(%s) + WHERE {outer_where} ORDER BY t2.distance """ ).format( distance_metric_sql=sql.SQL(distance_metric_sql), table_name=sql.Identifier(table_name), + outer_where=outer_where, top_k=sql.Literal(top_k), ) - params = (embedding, requested_features) + base_params = [embedding, requested_features] + if has_filters: 
+ base_params.extend(filter_params) + params = tuple(base_params) elif query_string is not None and string_fields: # Case 3: Text Search Only tsquery_str = " & ".join(query_string.split()) + + outer_where_parts = [sql.SQL("t1.feature_name = ANY(%s)")] + if has_filters: + outer_where_parts.append(filter_clause) + outer_where = sql.SQL(" AND ").join(outer_where_parts) + query = sql.SQL( """ WITH text_matches AS ( SELECT DISTINCT entity_key, ts_rank(to_tsvector('english', value_text), to_tsquery('english', %s)) as text_rank FROM {table_name} - WHERE feature_name = ANY(%s) AND to_tsvector('english', value_text) @@ to_tsquery('english', %s) + WHERE feature_name = ANY(%s) + AND to_tsvector('english', value_text) @@ to_tsquery('english', %s) ORDER BY text_rank DESC LIMIT {top_k} ) @@ -675,14 +866,23 @@ def retrieve_online_documents_v2( t1.created_ts FROM {table_name} t1 INNER JOIN text_matches t2 ON t1.entity_key = t2.entity_key - WHERE t1.feature_name = ANY(%s) + WHERE {outer_where} ORDER BY t2.text_rank DESC """ ).format( table_name=sql.Identifier(table_name), + outer_where=outer_where, top_k=sql.Literal(top_k), ) - params = (tsquery_str, string_fields, tsquery_str, requested_features) + base_params = [ + tsquery_str, + string_fields, + tsquery_str, + requested_features, + ] + if has_filters: + base_params.extend(filter_params) + params = tuple(base_params) else: raise ValueError( diff --git a/sdk/python/feast/infra/online_stores/remote.py b/sdk/python/feast/infra/online_stores/remote.py index 5b5b04c362d..ad60476abb2 100644 --- a/sdk/python/feast/infra/online_stores/remote.py +++ b/sdk/python/feast/infra/online_stores/remote.py @@ -15,12 +15,13 @@ import logging from collections import defaultdict from datetime import datetime -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union import requests from pydantic import StrictStr from feast import Entity, 
FeatureView, RepoConfig +from feast.filter_models import ComparisonFilter, CompoundFilter from feast.infra.online_stores.helpers import _to_naive_utc from feast.infra.online_stores.online_store import OnlineStore from feast.permissions.client.http_auth_requests_wrapper import HttpSessionManager @@ -301,6 +302,7 @@ def retrieve_online_documents_v2( top_k: int, distance_metric: Optional[str] = None, query_string: Optional[str] = None, + filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None, ) -> List[ Tuple[ Optional[datetime], @@ -318,6 +320,7 @@ def retrieve_online_documents_v2( top_k, distance_metric, query_string, + filters=filters, api_version=2, ) response = get_remote_online_documents(config=config, req_body=req_body) @@ -326,20 +329,29 @@ def retrieve_online_documents_v2( response_json = json.loads(response.text) event_ts: Optional[datetime] = self._get_event_ts(response_json) + metadata = response_json.get("metadata") or {} + feature_names = metadata.get("feature_names") or [] + results = response_json.get("results") or [] + + if not feature_names or not results: + logger.debug( + "Empty metadata or results in retrieve_online_documents_v2 response." 
+ ) + return [] + # Create feature name to index mapping for efficient lookup feature_name_to_index = { - name: idx - for idx, name in enumerate(response_json["metadata"]["feature_names"]) + name: idx for idx, name in enumerate(feature_names) } # Process each result row - num_results = ( - len(response_json["results"][0]["values"]) - if response_json["results"] - else 0 - ) + first_result_values = results[0].get("values") or [] + num_results = len(first_result_values) result_tuples = [] - + if requested_features is None: + requested_features = [] + if "distance" not in requested_features: + requested_features.append("distance") for row_idx in range(num_results): # Build feature values dictionary for requested features feature_values_dict = {} @@ -501,6 +513,7 @@ def _construct_online_documents_v2_api_json_request( top_k: int, distance_metric: Optional[str] = None, query_string: Optional[str] = None, + filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None, api_version: Optional[int] = 2, ) -> dict: api_requested_features = [] @@ -508,7 +521,7 @@ def _construct_online_documents_v2_api_json_request( for requested_feature in requested_features: api_requested_features.append(f"{table.name}:{requested_feature}") - return { + body: Dict[str, Any] = { "features": api_requested_features, "query": embedding, "top_k": top_k, @@ -516,12 +529,19 @@ def _construct_online_documents_v2_api_json_request( "query_string": query_string, "api_version": api_version, } - - def _get_event_ts(self, response_json) -> datetime: - event_ts = "" - if len(response_json["results"]) > 1: - event_ts = response_json["results"][1]["event_timestamps"][0] - return datetime.fromisoformat(event_ts.replace("Z", "+00:00")) + if filters is not None: + body["filters"] = filters.model_dump() + return body + + def _get_event_ts(self, response_json) -> Optional[datetime]: + results = response_json.get("results") or [] + if len(results) > 1: + event_timestamps = results[1].get("event_timestamps") or 
[] + if event_timestamps: + return datetime.fromisoformat( + event_timestamps[0].replace("Z", "+00:00") + ) + return None def _construct_entity_key_from_response( self, diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py index 1be4141c650..eb18391b9d3 100644 --- a/sdk/python/feast/infra/online_stores/sqlite.py +++ b/sdk/python/feast/infra/online_stores/sqlite.py @@ -36,6 +36,7 @@ from feast import Entity from feast.feature_view import FeatureView from feast.field import Field +from feast.filter_models import ComparisonFilter, CompoundFilter from feast.infra.infra_object import SQLITE_INFRA_OBJECT_CLASS_TYPE, InfraObject from feast.infra.key_encoding_utils import ( deserialize_entity_key, @@ -100,6 +101,16 @@ def convert_timestamp(val: bytes): sqlite3.register_converter("timestamp", convert_timestamp) +_SQLITE_COMPARISON_OPS: Dict[str, str] = { + "eq": "=", + "ne": "!=", + "gt": ">", + "gte": ">=", + "lt": "<", + "lte": "<=", +} + + class SqliteOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): """Online store config for local (SQLite-based) store""" @@ -176,6 +187,21 @@ def online_write_batch( table_name = _table_id(project, table) for feature_name, val in values.items(): + value_text = None + value_num = None + val_type = val.WhichOneof("val") + if val_type == "string_val": + value_text = val.string_val + elif val_type in ( + "int64_val", + "int32_val", + "double_val", + "float_val", + ): + value_num = float(getattr(val, val_type)) + elif val_type == "bool_val": + value_num = 1.0 if val.bool_val else 0.0 + if config.online_store.vector_enabled: if ( feature_type_dict.get(feature_name, None) @@ -193,39 +219,47 @@ def online_write_batch( val_bin = feast_value_type_to_python_type(val) conn.execute( f""" - INSERT INTO {table_name} (entity_key, feature_name, value, vector_value, event_ts, created_ts) - VALUES (?, ?, ?, ?, ?, ?) 
+ INSERT INTO {table_name} (entity_key, feature_name, value, value_text, value_num, vector_value, event_ts, created_ts) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(entity_key, feature_name) DO UPDATE SET value = excluded.value, + value_text = excluded.value_text, + value_num = excluded.value_num, vector_value = excluded.vector_value, event_ts = excluded.event_ts, created_ts = excluded.created_ts; """, ( - entity_key_bin, # entity_key - feature_name, # feature_name - val.SerializeToString(), # value - val_bin, # vector_value - timestamp, # event_ts - created_ts, # created_ts + entity_key_bin, + feature_name, + val.SerializeToString(), + value_text, + value_num, + val_bin, + timestamp, + created_ts, ), ) else: conn.execute( f""" - INSERT INTO {table_name} (entity_key, feature_name, value, event_ts, created_ts) - VALUES (?, ?, ?, ?, ?) + INSERT INTO {table_name} (entity_key, feature_name, value, value_text, value_num, event_ts, created_ts) + VALUES (?, ?, ?, ?, ?, ?, ?) ON CONFLICT(entity_key, feature_name) DO UPDATE SET value = excluded.value, + value_text = excluded.value_text, + value_num = excluded.value_num, event_ts = excluded.event_ts, created_ts = excluded.created_ts; """, ( - entity_key_bin, # entity_key - feature_name, # feature_name - val.SerializeToString(), # value - timestamp, # event_ts - created_ts, # created_ts + entity_key_bin, + feature_name, + val.SerializeToString(), + value_text, + value_num, + timestamp, + created_ts, ), ) @@ -296,7 +330,7 @@ def update( for table in tables_to_keep: conn.execute( - f"CREATE TABLE IF NOT EXISTS {_table_id(project, table)} (entity_key BLOB, feature_name TEXT, value BLOB, vector_value BLOB, event_ts timestamp, created_ts timestamp, PRIMARY KEY(entity_key, feature_name))" + f"CREATE TABLE IF NOT EXISTS {_table_id(project, table)} (entity_key BLOB, feature_name TEXT, value BLOB, value_text TEXT, value_num REAL, vector_value BLOB, event_ts timestamp, created_ts timestamp, PRIMARY KEY(entity_key, feature_name))" ) 
conn.execute( f"CREATE INDEX IF NOT EXISTS {_table_id(project, table)}_ek ON {_table_id(project, table)} (entity_key);" @@ -455,6 +489,111 @@ def retrieve_online_documents( return result + def _translate_filters( + self, + filters: Optional[Union[ComparisonFilter, CompoundFilter]], + table_name: str, + alias: Optional[str] = None, + ) -> Tuple[str, List[Any]]: + """Translate filter objects into a SQL WHERE clause fragment and params. + + Returns a ``(clause, params)`` pair. When *filters* is ``None`` the + clause is empty and params is ``[]``. + """ + if filters is None: + return "", [] + return self._translate_single_filter(filters, table_name, alias) + + def _translate_single_filter( + self, + filter_obj: Union[ComparisonFilter, CompoundFilter], + table_name: str, + alias: Optional[str] = None, + ) -> Tuple[str, List[Any]]: + if isinstance(filter_obj, ComparisonFilter): + return self._translate_comparison_filter(filter_obj, table_name, alias) + elif isinstance(filter_obj, CompoundFilter): + return self._translate_compound_filter(filter_obj, table_name, alias) + raise ValueError(f"Unknown filter type: {type(filter_obj)}") + + @staticmethod + def _filter_col_and_val(value: Any) -> Tuple[str, Any]: + """Return the appropriate column name and DB-ready value for a filter value.""" + if isinstance(value, bool): + return "value_num", 1.0 if value else 0.0 + if isinstance(value, (int, float)): + return "value_num", float(value) + return "value_text", str(value) + + def _translate_comparison_filter( + self, + filter_obj: ComparisonFilter, + table_name: str, + alias: Optional[str] = None, + ) -> Tuple[str, List[Any]]: + key, value, op_type = filter_obj.key, filter_obj.value, filter_obj.type + ek_col = f"{alias}.entity_key" if alias else "entity_key" + + if op_type in _SQLITE_COMPARISON_OPS: + col, db_value = self._filter_col_and_val(value) + clause = ( + f"{ek_col} IN (SELECT entity_key FROM {table_name} " + f"WHERE feature_name = ? 
AND {col} {_SQLITE_COMPARISON_OPS[op_type]} ?)" + ) + return clause, [key, db_value] + + if op_type == "in": + if not isinstance(value, list): + raise ValueError( + f"'in' filter requires a list value, got {type(value)}" + ) + col, _ = ( + self._filter_col_and_val(value[0]) if value else ("value_text", None) + ) + db_values = [self._filter_col_and_val(v)[1] for v in value] + placeholders = ", ".join(["?"] * len(value)) + clause = ( + f"{ek_col} IN (SELECT entity_key FROM {table_name} " + f"WHERE feature_name = ? AND {col} IN ({placeholders}))" + ) + return clause, [key] + db_values + + if op_type == "nin": + if not isinstance(value, list): + raise ValueError( + f"'nin' filter requires a list value, got {type(value)}" + ) + col, _ = ( + self._filter_col_and_val(value[0]) if value else ("value_text", None) + ) + db_values = [self._filter_col_and_val(v)[1] for v in value] + placeholders = ", ".join(["?"] * len(value)) + clause = ( + f"{ek_col} IN (SELECT entity_key FROM {table_name} " + f"WHERE feature_name = ? 
AND {col} NOT IN ({placeholders}))" + ) + return clause, [key] + db_values + + raise ValueError(f"Unknown comparison operator: {op_type}") + + def _translate_compound_filter( + self, + filter_obj: CompoundFilter, + table_name: str, + alias: Optional[str] = None, + ) -> Tuple[str, List[Any]]: + parts: List[str] = [] + all_params: List[Any] = [] + for sub in filter_obj.filters: + sub_clause, sub_params = self._translate_single_filter( + sub, table_name, alias + ) + parts.append(sub_clause) + all_params.extend(sub_params) + joiner = " AND " if filter_obj.type == "and" else " OR " + combined = "(" + joiner.join(parts) + ")" + return combined, all_params + def retrieve_online_documents_v2( self, config: RepoConfig, @@ -464,6 +603,7 @@ def retrieve_online_documents_v2( top_k: int, distance_metric: Optional[str] = None, query_string: Optional[str] = None, + filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None, ) -> List[ Tuple[ Optional[datetime], @@ -502,6 +642,10 @@ def retrieve_online_documents_v2( table_name = _table_id(config.project, table) vector_field = _get_vector_field(table) + filter_clause, filter_params = self._translate_filters( + filters, table_name, alias="fv2" + ) + if online_store.vector_enabled: query_embedding_bin = serialize_f32(query, vector_field_length) # type: ignore cur.execute( @@ -523,7 +667,6 @@ def retrieve_online_documents_v2( f.name for f in table.features if f.dtype == PrimitiveFeastType.STRING ] string_fields = ", ".join(string_field_list) - # TODO: swap this for a value configurable in each Field() BM25_DEFAULT_WEIGHTS = ", ".join( [ str(1.0) @@ -542,6 +685,9 @@ def retrieve_online_documents_v2( table_name, string_field_list ) cur.execute(insert_query) + filter_clause, filter_params = self._translate_filters( + filters, table_name, alias="fv" + ) else: raise ValueError( @@ -549,6 +695,11 @@ def retrieve_online_documents_v2( ) if online_store.vector_enabled: + where_parts = [f'fv2.feature_name != "{vector_field}"'] + if 
filter_clause: + where_parts.append(filter_clause) + where_sql = " AND ".join(where_parts) + cur.execute( f""" select @@ -573,24 +724,28 @@ def retrieve_online_documents_v2( on f.rowid = fv.rowid left join {table_name} fv2 on fv.entity_key = fv2.entity_key - where fv2.feature_name != "{vector_field}" + where {where_sql} """, - ( - query_embedding_bin, - top_k, - ), + [query_embedding_bin, top_k] + filter_params, ) elif online_store.text_search_enabled: + where_parts_text = [] + if filter_clause: + where_parts_text.append(filter_clause) + where_sql_text = ( + "where " + " AND ".join(where_parts_text) if where_parts_text else "" + ) + cur.execute( f""" - select - fv.entity_key, - fv.feature_name, - fv.value, - fv.vector_value, - f.distance, - fv.event_ts, - fv.created_ts + select + fv.entity_key, + fv.feature_name, + fv.value, + fv.vector_value, + f.distance, + fv.event_ts, + fv.created_ts from {table_name} fv inner join ( select @@ -602,8 +757,9 @@ def retrieve_online_documents_v2( where search_table match ? order by distance limit ? 
) f on f.entity_key = fv.entity_key + {where_sql_text} """, - (query_string, top_k), + [query_string, top_k] + filter_params, ) else: @@ -763,6 +919,8 @@ def update(self): entity_key BLOB, feature_name TEXT, value BLOB, + value_text TEXT, + value_num REAL, vector_value BLOB, event_ts timestamp, created_ts timestamp, @@ -813,7 +971,7 @@ def _generate_bm25_search_insert_query( from_query = f"\nFROM (select rowid, * from {table_name} where feature_name = '{string_field_list[0]}') fv0" for i, string_field in enumerate(string_field_list): - query += f"\n\t,fv{i}.value as {string_field}" + query += f"\n\t,fv{i}.value_text as {string_field}" if i > 0: from_query += ( f"\nLEFT JOIN (select rowid, * from {table_name} where feature_name = '{string_field}') fv{i}" diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index 6830929e776..868c27e9486 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -24,6 +24,7 @@ from feast.feature_logging import FeatureServiceLoggingSource from feast.feature_service import FeatureService from feast.feature_view import FeatureView +from feast.filter_models import ComparisonFilter, CompoundFilter from feast.infra.common.materialization_job import ( MaterializationJobStatus, MaterializationTask, @@ -322,6 +323,7 @@ def retrieve_online_documents_v2( top_k: int, distance_metric: Optional[str] = None, query_string: Optional[str] = None, + filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None, ) -> List: result = [] if self.online_store: @@ -333,6 +335,7 @@ def retrieve_online_documents_v2( top_k, distance_metric, query_string, + filters, ) return result diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index c2879c1e2db..fd98bbc5df3 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -23,6 +23,7 @@ from feast.data_source import DataSource from 
feast.entity import Entity from feast.feature_view import FeatureView +from feast.filter_models import ComparisonFilter, CompoundFilter from feast.importer import import_class from feast.infra.infra_object import Infra from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob @@ -468,6 +469,7 @@ def retrieve_online_documents_v2( top_k: int, distance_metric: Optional[str] = None, query_string: Optional[str] = None, + filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None, ) -> List[ Tuple[ Optional[datetime], @@ -486,6 +488,8 @@ def retrieve_online_documents_v2( query: The query embedding to search for (optional). top_k: The number of documents to return. query_string: The query string to search for using keyword search (bm25) (optional) + filters: Optional metadata filters (ComparisonFilter or CompoundFilter) + to narrow results before ranking. Returns: A list of dictionaries, where each dictionary contains the datetime, entitykey, and a dictionary diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 02a0f13c733..fc7bb7a1034 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -250,6 +250,40 @@ def to_openlineage_config(self): ) +class EmbeddingModelConfig(FeastConfigBaseModel): + """Configuration for the LiteLLM-based query embedding model. + + Required when using ``retrieve_online_documents_openai`` or the + ``/v1/vector_stores/{vector_store_id}/search`` endpoint. + + Example in ``feature_store.yaml``:: + + embedding_model: + model: text-embedding-3-small + api_key: sk-... + api_base: https://api.openai.com/v1 + """ + + model: str + """LiteLLM model identifier (e.g. 'text-embedding-3-small', + 'cohere/embed-english-v3.0', 'azure/my-deployment').""" + + api_key: Optional[str] = None + """API key for the embedding provider. If not set, LiteLLM falls back + to the relevant environment variable (e.g. 
OPENAI_API_KEY).""" + + api_base: Optional[str] = None + """Custom API base URL for the embedding provider + (e.g. 'https://my-azure-deployment.openai.azure.com/').""" + + api_version: Optional[str] = None + """API version for the embedding provider (used by Azure OpenAI).""" + + dimensions: Optional[StrictInt] = None + """Output embedding dimensionality. Supported by text-embedding-3 and + newer models. If not set, the model's default dimension is used.""" + + class RepoConfig(FeastBaseModel): """Repo config. Typically loaded from `feature_store.yaml`""" @@ -289,6 +323,13 @@ class RepoConfig(FeastBaseModel): feature_server: Optional[Any] = None """ FeatureServerConfig: Feature server configuration (optional depending on provider) """ + embedding_model: Optional[EmbeddingModelConfig] = Field( + None, alias="embedding_model" + ) + """ EmbeddingModelConfig: LiteLLM embedding model configuration. + Required when using retrieve_online_documents_openai or the + OpenAI-compatible vector store search endpoint. 
""" + flags: Any = None """ Flags (deprecated field): Feature flags for experimental features """ diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index a04ff3cc456..fb0659cf834 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -19,6 +19,7 @@ from feast import Entity, FeatureService, FeatureView, RepoConfig from feast.data_source import DataSource +from feast.filter_models import ComparisonFilter, CompoundFilter from feast.infra.offline_stores.offline_store import RetrievalJob from feast.infra.provider import Provider from feast.infra.registry.base_registry import BaseRegistry @@ -174,6 +175,7 @@ def retrieve_online_documents_v2( top_k: int, distance_metric: Optional[str] = None, query_string: Optional[str] = None, + filters: Optional[Union[ComparisonFilter, CompoundFilter]] = None, ) -> List[ Tuple[ Optional[datetime], diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 0c27585139e..ead6ecc62c4 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -14,13 +14,15 @@ from feast import FeatureStore from feast.entity import Entity -from feast.errors import FeatureNameCollisionError +from feast.errors import FeatureNameCollisionError, FeatureViewNotFoundException from feast.feature_service import FeatureService from feast.feature_view import FeatureView from feast.field import Field +from feast.filter_models import ComparisonFilter, CompoundFilter from feast.infra.offline_stores.file_source import FileSource from feast.infra.utils.postgres.postgres_config import ConnectionType from feast.online_response import TIMESTAMP_POSTFIX +from feast.repo_config import EmbeddingModelConfig from feast.types import ( Array, Float32, @@ -1265,3 +1267,313 @@ def test_retrieve_online_documents_v2(environment, 
fake_document_data): assert len(no_match_results["text_field"]) == 0 assert "text_rank" in no_match_results assert len(no_match_results["text_rank"]) == 0 + + +def _setup_documents_with_categories(fs): + """Shared helper that creates and populates a feature view with embeddings, + text, and category fields. Returns (feature_view, entity, dataframe).""" + n_rows = 20 + vector_dim = 2 + random.seed(42) + + df = pd.DataFrame( + { + "item_id": list(range(n_rows)), + "embedding": [list(np.random.random(vector_dim)) for _ in range(n_rows)], + "text_field": [ + f"Document text content {i} with searchable keywords" + for i in range(n_rows) + ], + "category": [f"Category-{i % 5}" for i in range(n_rows)], + "event_timestamp": [datetime.now() for _ in range(n_rows)], + } + ) + + data_source = FileSource( + path="dummy_path.parquet", timestamp_field="event_timestamp" + ) + + item_entity = Entity( + name="item_id", + join_keys=["item_id"], + value_type=ValueType.INT64, + ) + + item_embeddings_fv = FeatureView( + name="item_embeddings", + entities=[item_entity], + schema=[ + Field(name="embedding", dtype=Array(Float32), vector_index=True), + Field(name="text_field", dtype=String), + Field(name="category", dtype=String), + Field(name="item_id", dtype=Int64), + ], + source=data_source, + ) + + fs.apply([item_embeddings_fv, item_entity]) + fs.write_to_online_store("item_embeddings", df) + return item_embeddings_fv, item_entity, df + + +@pytest.mark.integration +@pytest.mark.universal_online_stores(only=["pgvector", "elasticsearch"]) +def test_retrieve_online_documents_v2_with_filters(environment, fake_document_data): + """Test that metadata filters narrow down vector/text search results.""" + fs = environment.feature_store + fs.config.online_store.vector_enabled = True + + _, _, df = _setup_documents_with_categories(fs) + vector_dim = 2 + query_embedding = list(np.random.random(vector_dim)) + + # --- eq filter: only Category-0 rows --- + eq_filter = ComparisonFilter(type="eq", 
key="category", value="Category-0") + results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query=query_embedding, + top_k=10, + distance_metric="L2", + filters=eq_filter, + ).to_dict() + + assert len(results["category"]) > 0 + assert len(results["category"]) <= 4 # 20 rows / 5 categories + assert all(c == "Category-0" for c in results["category"]) + + # --- ne filter: exclude Category-0 --- + ne_filter = ComparisonFilter(type="ne", key="category", value="Category-0") + results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query=query_embedding, + top_k=10, + distance_metric="L2", + filters=ne_filter, + ).to_dict() + + assert len(results["category"]) > 0 + assert all(c != "Category-0" for c in results["category"]) + + # --- in filter: Category-0 or Category-1 --- + in_filter = ComparisonFilter( + type="in", key="category", value=["Category-0", "Category-1"] + ) + results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query=query_embedding, + top_k=10, + distance_metric="L2", + filters=in_filter, + ).to_dict() + + assert len(results["category"]) > 0 + assert all(c in ("Category-0", "Category-1") for c in results["category"]) + + # --- compound AND filter: category == Category-0 AND item_id >= 5 --- + and_filter = CompoundFilter( + type="and", + filters=[ + ComparisonFilter(type="eq", key="category", value="Category-0"), + ComparisonFilter(type="gte", key="item_id", value=5), + ], + ) + results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + 
+        query=query_embedding,
+        top_k=10,
+        distance_metric="L2",
+        filters=and_filter,
+    ).to_dict()
+
+    assert len(results["category"]) > 0
+    assert all(c == "Category-0" for c in results["category"])
+    assert all(i >= 5 for i in results["item_id"])
+
+    # --- text search + filter ---
+    text_filter = ComparisonFilter(type="eq", key="category", value="Category-2")
+    text_results = fs.retrieve_online_documents_v2(
+        features=[
+            "item_embeddings:embedding",
+            "item_embeddings:text_field",
+            "item_embeddings:category",
+            "item_embeddings:item_id",
+        ],
+        query_string="searchable keywords",
+        top_k=10,
+        filters=text_filter,
+    ).to_dict()
+
+    assert len(text_results["category"]) > 0
+    assert all(c == "Category-2" for c in text_results["category"])
+
+    # --- filter with no matches ---
+    empty_filter = ComparisonFilter(
+        type="eq", key="category", value="NonexistentCategory"
+    )
+    empty_results = fs.retrieve_online_documents_v2(
+        features=[
+            "item_embeddings:embedding",
+            "item_embeddings:text_field",
+            "item_embeddings:category",
+            "item_embeddings:item_id",
+        ],
+        query=query_embedding,
+        top_k=10,
+        distance_metric="L2",
+        filters=empty_filter,
+    ).to_dict()
+
+    assert len(empty_results.get("category", [])) == 0
+
+
+@pytest.mark.integration
+@pytest.mark.universal_online_stores(only=["pgvector", "elasticsearch"])
+def test_retrieve_online_documents_openai(environment, fake_document_data):
+    """Test OpenAI-compatible vector store search returns the correct response shape."""
+    fs = environment.feature_store
+    fs.config.online_store.vector_enabled = True
+
+    fv, _, df = _setup_documents_with_categories(fs)
+    vector_dim = 2
+
+    fs.config.embedding_model = EmbeddingModelConfig(model="text-embedding-3-small")
+
+    fake_embedding = list(np.random.random(vector_dim))
+    mock_embed_response = unittest.mock.MagicMock()
+    mock_embed_response.data = [{"embedding": fake_embedding}]
+
+    with unittest.mock.patch(
+        "feast.feature_store.litellm_embedding", create=True
+    ) as mock_litellm:
+        mock_litellm.return_value = mock_embed_response
+
+        with unittest.mock.patch(
+            "feast.feature_store.FeatureStore.retrieve_online_documents_openai",
+            wraps=fs.retrieve_online_documents_openai,
+        ):
+            # Patch the litellm import inside the method
+            with unittest.mock.patch.dict(
+                "sys.modules",
+                {"litellm": unittest.mock.MagicMock(embedding=mock_litellm)},
+            ):
+                result = fs.retrieve_online_documents_openai(
+                    vector_store_id="item_embeddings",
+                    query="test query",
+                    max_num_results=5,
+                )
+
+    # Validate top-level OpenAI response shape
+    assert result["object"] == "vector_store.search_results.page"
+    assert isinstance(result["search_query"], list)
+    assert result["search_query"] == ["test query"]
+    assert result["has_more"] is False
+    assert result["next_page"] is None
+
+    assert isinstance(result["data"], list)
+    assert len(result["data"]) > 0
+    assert len(result["data"]) <= 5
+
+    for item_result in result["data"]:
+        assert "file_id" in item_result
+        assert "filename" in item_result
+        assert item_result["filename"] == "item_embeddings"
+        assert "score" in item_result
+        assert isinstance(item_result["score"], float)
+        assert "attributes" in item_result
+        assert isinstance(item_result["attributes"], dict)
+        assert "content" in item_result
+        assert isinstance(item_result["content"], list)
+        for part in item_result["content"]:
+            assert "type" in part
+            assert part["type"] == "text"
+            assert "text" in part
+
+    # --- Test with features_to_retrieve ---
+    with unittest.mock.patch.dict(
+        "sys.modules",
+        {
+            "litellm": unittest.mock.MagicMock(
+                embedding=unittest.mock.MagicMock(return_value=mock_embed_response)
+            ),
+        },
+    ):
+        result_subset = fs.retrieve_online_documents_openai(
+            vector_store_id="item_embeddings",
+            query="test query",
+            max_num_results=5,
+            features_to_retrieve=["text_field", "category"],
+        )
+
+    assert len(result_subset["data"]) > 0
+    for item_result in result_subset["data"]:
+        attr_keys = set(item_result["attributes"].keys())
+        assert "embedding" not in attr_keys
+
+    # --- Test with list query ---
+    with unittest.mock.patch.dict(
+        "sys.modules",
+        {
+            "litellm": unittest.mock.MagicMock(
+                embedding=unittest.mock.MagicMock(return_value=mock_embed_response)
+            ),
+        },
+    ):
+        result_list = fs.retrieve_online_documents_openai(
+            vector_store_id="item_embeddings",
+            query=["term1", "term2"],
+            max_num_results=5,
+        )
+
+    assert result_list["search_query"] == ["term1", "term2"]
+
+
+@pytest.mark.integration
+@pytest.mark.universal_online_stores(only=["pgvector", "elasticsearch"])
+def test_retrieve_online_documents_openai_no_embedding_config(
+    environment, fake_document_data
+):
+    """Test that retrieve_online_documents_openai raises ValueError
+    when embedding_model is not configured."""
+    fs = environment.feature_store
+    fs.config.embedding_model = None
+
+    with pytest.raises(ValueError, match="embedding_model is not configured"):
+        fs.retrieve_online_documents_openai(
+            vector_store_id="item_embeddings",
+            query="test query",
+        )
+
+
+@pytest.mark.integration
+@pytest.mark.universal_online_stores(only=["pgvector", "elasticsearch"])
+def test_retrieve_online_documents_openai_not_found(environment, fake_document_data):
+    """Test that retrieve_online_documents_openai raises FeatureViewNotFoundException
+    for a non-existent feature view."""
+    fs = environment.feature_store
+    fs.config.embedding_model = EmbeddingModelConfig(model="text-embedding-3-small")
+
+    with pytest.raises(FeatureViewNotFoundException):
+        fs.retrieve_online_documents_openai(
+            vector_store_id="nonexistent_feature_view",
+            query="test query",
+        )
diff --git a/sdk/python/tests/unit/test_filter_models.py b/sdk/python/tests/unit/test_filter_models.py
new file mode 100644
index 00000000000..d5db20a1330
--- /dev/null
+++ b/sdk/python/tests/unit/test_filter_models.py
@@ -0,0 +1,160 @@
+import pytest
+from pydantic import ValidationError
+
+from feast.filter_models import (
+    ComparisonFilter,
+    CompoundFilter,
+    convert_dict_to_filter,
+)
+
+
+class TestComparisonFilter:
+    @pytest.mark.parametrize("op", ["eq", "ne", "gt", "gte", "lt", "lte", "in", "nin"])
+    def test_valid_operators(self, op):
+        f = ComparisonFilter(type=op, key="field", value="x")
+        assert f.type == op
+        assert f.key == "field"
+
+    def test_rejects_invalid_operator(self):
+        with pytest.raises(ValidationError):
+            ComparisonFilter(type="like", key="field", value="x")
+
+    def test_accepts_string_value(self):
+        f = ComparisonFilter(type="eq", key="city", value="LA")
+        assert f.value == "LA"
+
+    def test_accepts_int_value(self):
+        f = ComparisonFilter(type="gt", key="age", value=25)
+        assert f.value == 25
+
+    def test_accepts_float_value(self):
+        f = ComparisonFilter(type="lte", key="score", value=0.95)
+        assert f.value == 0.95
+
+    def test_accepts_bool_value(self):
+        f = ComparisonFilter(type="eq", key="active", value=True)
+        assert f.value is True
+
+    def test_accepts_list_value(self):
+        f = ComparisonFilter(type="in", key="status", value=["a", "b"])
+        assert f.value == ["a", "b"]
+
+
+class TestCompoundFilter:
+    def test_and_filter(self):
+        f = CompoundFilter(
+            type="and",
+            filters=[
+                ComparisonFilter(type="eq", key="a", value=1),
+                ComparisonFilter(type="gt", key="b", value=2),
+            ],
+        )
+        assert f.type == "and"
+        assert len(f.filters) == 2
+
+    def test_or_filter(self):
+        f = CompoundFilter(
+            type="or",
+            filters=[
+                ComparisonFilter(type="eq", key="a", value=1),
+                ComparisonFilter(type="eq", key="b", value=2),
+            ],
+        )
+        assert f.type == "or"
+        assert len(f.filters) == 2
+
+    def test_rejects_invalid_type(self):
+        with pytest.raises(ValidationError):
+            CompoundFilter(
+                type="xor",
+                filters=[ComparisonFilter(type="eq", key="a", value=1)],
+            )
+
+    def test_nested_compound(self):
+        f = CompoundFilter(
+            type="and",
+            filters=[
+                ComparisonFilter(type="eq", key="x", value=1),
+                CompoundFilter(
+                    type="or",
+                    filters=[
+                        ComparisonFilter(type="gt", key="y", value=5),
+                        ComparisonFilter(type="lt", key="z", value=10),
+                    ],
+                ),
+            ],
+        )
+        assert f.type == "and"
+        assert len(f.filters) == 2
+        inner = f.filters[1]
+        assert isinstance(inner, CompoundFilter)
+        assert inner.type == "or"
+        assert len(inner.filters) == 2
+
+
+class TestConvertDictToFilter:
+    def test_comparison_dict(self):
+        result = convert_dict_to_filter({"type": "eq", "key": "city", "value": "LA"})
+        assert isinstance(result, ComparisonFilter)
+        assert result.type == "eq"
+        assert result.key == "city"
+        assert result.value == "LA"
+
+    def test_compound_dict(self):
+        result = convert_dict_to_filter(
+            {
+                "type": "and",
+                "filters": [
+                    {"type": "eq", "key": "a", "value": 1},
+                    {"type": "gt", "key": "b", "value": 2},
+                ],
+            }
+        )
+        assert isinstance(result, CompoundFilter)
+        assert result.type == "and"
+        assert len(result.filters) == 2
+        assert all(isinstance(f, ComparisonFilter) for f in result.filters)
+
+    def test_nested_compound_dict(self):
+        result = convert_dict_to_filter(
+            {
+                "type": "or",
+                "filters": [
+                    {"type": "eq", "key": "x", "value": "a"},
+                    {
+                        "type": "and",
+                        "filters": [
+                            {"type": "gt", "key": "y", "value": 5},
+                            {"type": "lt", "key": "z", "value": 10},
+                        ],
+                    },
+                ],
+            }
+        )
+        assert isinstance(result, CompoundFilter)
+        assert result.type == "or"
+        inner = result.filters[1]
+        assert isinstance(inner, CompoundFilter)
+        assert inner.type == "and"
+        assert len(inner.filters) == 2
+
+    def test_or_compound_dict(self):
+        result = convert_dict_to_filter(
+            {
+                "type": "or",
+                "filters": [
+                    {"type": "eq", "key": "a", "value": 1},
+                    {"type": "eq", "key": "b", "value": 2},
+                ],
+            }
+        )
+        assert isinstance(result, CompoundFilter)
+        assert result.type == "or"
+
+    def test_in_operator_dict(self):
+        result = convert_dict_to_filter(
+            {"type": "in", "key": "status", "value": ["active", "pending"]}
+        )
+        assert isinstance(result, ComparisonFilter)
+        assert result.type == "in"
+        assert result.value == ["active", "pending"]
diff --git a/temp_remote/start_server.py b/temp_remote/start_server.py
new file mode 100644
index 00000000000..54713fe4ab3
--- /dev/null
+++ b/temp_remote/start_server.py
@@ -0,0 +1,146 @@
+"""
+Feast Feature Server — standalone launcher
+=============================================
+Sets up data, materializes into PostgreSQL (pgvector), and starts
+the feature server with uvicorn (single-process, no fork).
+
+Usage:
+    python start_server.py               # full setup + serve
+    python start_server.py --serve-only  # skip setup, just start serving
+
+Prerequisites:
+    docker run -d --name feast-pg -p 5432:5432 \
+        -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=feast_demo \
+        pgvector/pgvector:pg16
+"""
+
+import hashlib
+import importlib
+import os
+import sys
+from datetime import datetime, timedelta
+
+import numpy as np
+import pandas as pd
+
+from feast import FeatureStore
+
+EMBEDDING_DIM = 384
+HOST = "0.0.0.0"
+PORT = 6566
+
+DOCUMENTS = {
+    "Introduction to Machine Learning": [
+        "Machine learning is a subset of artificial intelligence that enables systems to learn from data. "
+        "Instead of being explicitly programmed, these systems improve their performance through experience.",
+        "Supervised learning uses labeled training data to learn a mapping from inputs to outputs. "
+        "Common algorithms include linear regression, decision trees, and neural networks.",
+        "Unsupervised learning finds hidden patterns in data without labeled responses. "
+        "Clustering and dimensionality reduction are key techniques in this category.",
+        "Reinforcement learning trains agents to make sequences of decisions by rewarding desired behaviors. "
+        "It has been successfully applied to game playing, robotics, and autonomous driving.",
+    ],
+    "Natural Language Processing": [
+        "Natural language processing (NLP) combines linguistics and computer science to help machines "
+        "understand, interpret, and generate human language.",
+        "Tokenization breaks text into smaller units like words or subwords. Modern models use byte-pair "
+        "encoding or WordPiece tokenization for better handling of rare words.",
+        "Transformer architectures revolutionized NLP by introducing self-attention mechanisms that capture "
+        "long-range dependencies in text more effectively than recurrent networks.",
+        "Large language models like GPT and BERT are pre-trained on massive text corpora and can be "
+        "fine-tuned for downstream tasks such as classification, translation, and question answering.",
+    ],
+    "Vector Databases": [
+        "Vector databases are specialized systems designed to store and query high-dimensional vector "
+        "embeddings efficiently, enabling similarity search at scale.",
+        "Approximate nearest neighbor (ANN) algorithms like HNSW and IVF trade a small amount of accuracy "
+        "for dramatic speed improvements over brute-force search.",
+        "Common distance metrics include cosine similarity, Euclidean distance, and inner product. "
+        "The choice depends on how embeddings were trained and what notion of similarity is needed.",
+        "Remote online stores allow Feast clients to query a centralized feature server without direct "
+        "database access, useful for multi-tenant or security-sensitive deployments.",
+    ],
+}
+
+
+def simple_embedding(text: str) -> list[float]:
+    h = hashlib.sha512(text.encode()).digest()
+    rng = np.random.RandomState(int.from_bytes(h[:4], "big"))
+    vec = rng.randn(EMBEDDING_DIM).astype(np.float32)
+    vec /= np.linalg.norm(vec)
+    return vec.tolist()
+
+
+def build_parquet():
+    rows = []
+    for doc_title, chunks in DOCUMENTS.items():
+        for i, text in enumerate(chunks):
+            chunk_id = f"{doc_title.lower().replace(' ', '_')}_chunk_{i}"
+            rows.append(
+                {
+                    "chunk_id": chunk_id,
+                    "chunk_text": text,
+                    "doc_title": doc_title,
+                    "chunk_index": i,
+                    "embedding": simple_embedding(text),
+                    "event_timestamp": datetime.now() - timedelta(hours=1),
+                    "created_timestamp": datetime.now() - timedelta(hours=1),
+                }
+            )
+
+    os.makedirs(os.path.join("server", "data"), exist_ok=True)
+    path = os.path.join("server", "data", "doc_chunks.parquet")
+    pd.DataFrame(rows).to_parquet(path, index=False)
+    print(f"[1/3] Wrote {len(rows)} chunks → {path}")
+
+
+def setup_and_materialize():
+    sys.path.insert(0, os.path.join(os.path.dirname(__file__), "server"))
+    doc_chunks_repo = importlib.import_module("doc_chunks_repo")
+
+    store = FeatureStore(repo_path="server")
+
+    pg_store = store._get_provider().online_store
+    table_name = f"{store.project}_{doc_chunks_repo.doc_chunks_fv.name}"
+    with pg_store._get_conn(store.config) as conn:
+        cur = conn.execute(
+            "SELECT tablename FROM pg_tables WHERE schemaname = 'public' AND tablename = %s",
+            (table_name,),
+        )
+        if cur.fetchone():
+            conn.execute(f'DROP TABLE IF EXISTS "{table_name}" CASCADE')
+            conn.commit()
+            print(f" Dropped old table '{table_name}'")
+    pg_store._conn = None
+
+    store.apply([doc_chunks_repo.chunk_entity, doc_chunks_repo.doc_chunks_fv])
+    store.materialize(
+        start_date=datetime.now() - timedelta(days=1),
+        end_date=datetime.now(),
+    )
+    print("[2/3] Applied definitions & materialized into PostgreSQL")
+
+
+def serve():
+    from feast.feature_server import get_app
+
+    import uvicorn
+
+    store = FeatureStore(repo_path="server")
+    app = get_app(store)
+    print(f"[3/3] Starting uvicorn on {HOST}:{PORT} ...")
+    uvicorn.run(app, host=HOST, port=PORT)
+
+
+if __name__ == "__main__":
+    os.chdir(os.path.dirname(os.path.abspath(__file__)))
+
+    serve_only = "--serve-only" in sys.argv
+
+    if serve_only:
+        print("Skipping setup (--serve-only), starting server directly.")
+    else:
+        build_parquet()
+        setup_and_materialize()
+
+    serve()