From 693e9dfec5e9fcbc0b574e040805a2b093aa6661 Mon Sep 17 00:00:00 2001 From: toping4445 Date: Sun, 24 Jul 2022 16:29:44 +0900 Subject: [PATCH 01/11] fixed bugs, cleaned code, added AthenaDataSourceCreator Signed-off-by: Youngkyu OH --- protos/feast/core/DataSource.proto | 18 + protos/feast/core/FeatureService.proto | 6 + protos/feast/core/SavedDataset.proto | 1 + sdk/python/feast/__init__.py | 4 + sdk/python/feast/data_source.py | 1 + .../contrib/athena_offline_store/__init__.py | 0 .../contrib/athena_offline_store/athena.py | 690 ++++++++++++++++++ .../athena_offline_store/athena_source.py | 347 +++++++++ .../infra/offline_stores/offline_utils.py | 7 + sdk/python/feast/infra/utils/aws_utils.py | 326 ++++++++- sdk/python/feast/repo_config.py | 1 + sdk/python/feast/templates/athena/__init__.py | 0 sdk/python/feast/templates/athena/example.py | 107 +++ .../feast/templates/athena/feature_store.yaml | 12 + sdk/python/feast/type_map.py | 58 ++ .../universal/data_sources/athena.py | 112 +++ 16 files changed, 1688 insertions(+), 2 deletions(-) create mode 100644 sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/__init__.py create mode 100644 sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py create mode 100644 sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py create mode 100644 sdk/python/feast/templates/athena/__init__.py create mode 100644 sdk/python/feast/templates/athena/example.py create mode 100644 sdk/python/feast/templates/athena/feature_store.yaml create mode 100644 sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py diff --git a/protos/feast/core/DataSource.proto b/protos/feast/core/DataSource.proto index 62f5859ee8e..5258618f3bd 100644 --- a/protos/feast/core/DataSource.proto +++ b/protos/feast/core/DataSource.proto @@ -49,6 +49,7 @@ message DataSource { PUSH_SOURCE = 9; BATCH_TRINO = 10; BATCH_SPARK = 11; + BATCH_ATHENA = 12; } // Unique name of data source 
within the project @@ -171,6 +172,22 @@ message DataSource { string database = 4; } + // Defines options for DataSource that sources features from a Athena Query + message AthenaOptions { + // Athena table name + string table = 1; + + // SQL query that returns a table containing feature data. Must contain an event_timestamp column, and respective + // entity columns + string query = 2; + + // Athena database name + string database = 3; + + // Athena schema name + string data_source = 4; + } + // Defines options for DataSource that sources features from a Snowflake Query message SnowflakeOptions { // Snowflake table name @@ -242,5 +259,6 @@ message DataSource { PushOptions push_options = 22; SparkOptions spark_options = 27; TrinoOptions trino_options = 30; + AthenaOptions athena_options = 35; } } diff --git a/protos/feast/core/FeatureService.proto b/protos/feast/core/FeatureService.proto index 51b9c6c02a2..80d32eb4dec 100644 --- a/protos/feast/core/FeatureService.proto +++ b/protos/feast/core/FeatureService.proto @@ -60,6 +60,7 @@ message LoggingConfig { RedshiftDestination redshift_destination = 5; SnowflakeDestination snowflake_destination = 6; CustomDestination custom_destination = 7; + AthenaDestination athena_destination = 8; } message FileDestination { @@ -80,6 +81,11 @@ message LoggingConfig { string table_name = 1; } + message AthenaDestination { + // Destination table name. data_source and database will be taken from an offline store config + string table_name = 1; + } + message SnowflakeDestination { // Destination table name. 
Schema and database will be taken from an offline store config string table_name = 1; diff --git a/protos/feast/core/SavedDataset.proto b/protos/feast/core/SavedDataset.proto index 53f06f73a98..111548aa480 100644 --- a/protos/feast/core/SavedDataset.proto +++ b/protos/feast/core/SavedDataset.proto @@ -59,6 +59,7 @@ message SavedDatasetStorage { DataSource.TrinoOptions trino_storage = 8; DataSource.SparkOptions spark_storage = 9; DataSource.CustomSourceOptions custom_storage = 10; + DataSource.AthenaOptions athena_storage = 11; } } diff --git a/sdk/python/feast/__init__.py b/sdk/python/feast/__init__.py index 5d1663f7cbc..d592c35bbee 100644 --- a/sdk/python/feast/__init__.py +++ b/sdk/python/feast/__init__.py @@ -8,6 +8,9 @@ from feast.infra.offline_stores.file_source import FileSource from feast.infra.offline_stores.redshift_source import RedshiftSource from feast.infra.offline_stores.snowflake_source import SnowflakeSource +from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import ( + AthenaSource, +) from .batch_feature_view import BatchFeatureView from .data_source import KafkaSource, KinesisSource, PushSource, RequestSource @@ -50,4 +53,5 @@ "SnowflakeSource", "PushSource", "RequestSource", + "AthenaSource", ] diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py index 931568f4e2e..89136d2eeed 100644 --- a/sdk/python/feast/data_source.py +++ b/sdk/python/feast/data_source.py @@ -160,6 +160,7 @@ def to_proto(self) -> DataSourceProto.KinesisOptions: DataSourceProto.SourceType.STREAM_KINESIS: "feast.data_source.KinesisSource", DataSourceProto.SourceType.REQUEST_SOURCE: "feast.data_source.RequestSource", DataSourceProto.SourceType.PUSH_SOURCE: "feast.data_source.PushSource", + DataSourceProto.SourceType.BATCH_ATHENA: "feast.infra.offline_stores.contrib.athena_offline_store.athena_source.AthenaSource", } diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/__init__.py 
b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py new file mode 100644 index 00000000000..7959322a30a --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py @@ -0,0 +1,690 @@ +import contextlib +import uuid +from datetime import datetime +from pathlib import Path +from typing import ( + Callable, + ContextManager, + Dict, + Iterator, + List, + Optional, + Tuple, + Union, +) + +import numpy as np +import pandas as pd +import pyarrow +import pyarrow as pa +from dateutil import parser +from pydantic import StrictStr +from pydantic.typing import Literal +from pytz import utc + +from feast import OnDemandFeatureView +from feast.data_source import DataSource +from feast.errors import InvalidEntityType +from feast.feature_logging import LoggingConfig, LoggingSource, LoggingDestination +from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView +from feast.infra.offline_stores.offline_store import ( + OfflineStore, + RetrievalJob, + RetrievalMetadata, +) + +from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import ( + AthenaSource, + AthenaLoggingDestination, + SavedDatasetAthenaStorage, +) +from feast.infra.utils import aws_utils +from feast.infra.offline_stores import offline_utils + +from feast.registry import Registry +from feast.repo_config import FeastConfigBaseModel, RepoConfig +from feast.saved_dataset import SavedDatasetStorage +from feast.usage import log_exceptions_and_usage + + +class AthenaOfflineStoreConfig(FeastConfigBaseModel): + """Offline store config for AWS Athena""" + + type: Literal["athena"] = "athena" + """ Offline store type selector""" + + data_source: StrictStr + """ athena data source ex) AwsDataCatalog """ + + 
region: StrictStr + """ Athena's AWS region """ + + database: StrictStr + """ Athena database name """ + + s3_staging_location: StrictStr + """ S3 path for importing & exporting data to Athena """ + + +class AthenaOfflineStore(OfflineStore): + @staticmethod + @log_exceptions_and_usage(offline_store="athena") + def pull_latest_from_table_or_query( + config: RepoConfig, + data_source: DataSource, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_field: str, + created_timestamp_column: Optional[str], + start_date: datetime, + end_date: datetime, + ) -> RetrievalJob: + assert isinstance(data_source, AthenaSource) + assert isinstance(config.offline_store, AthenaOfflineStoreConfig) + + from_expression = data_source.get_table_query_string() + + partition_by_join_key_string = ", ".join(join_key_columns) + if partition_by_join_key_string != "": + partition_by_join_key_string = ( + "PARTITION BY " + partition_by_join_key_string + ) + timestamp_columns = [timestamp_field] + if created_timestamp_column: + timestamp_columns.append(created_timestamp_column) + timestamp_desc_string = " DESC, ".join(timestamp_columns) + " DESC" + field_string = ", ".join( + join_key_columns + feature_name_columns + timestamp_columns + ) + + date_partition_column = data_source.date_partition_column + + athena_client = aws_utils.get_athena_data_client( + config.offline_store.region + ) + s3_resource = aws_utils.get_s3_resource(config.offline_store.region) + + start_date = start_date.astimezone(tz=utc) + end_date = end_date.astimezone(tz=utc) + + query = f""" + + SELECT + {field_string} + {f", {repr(DUMMY_ENTITY_VAL)} AS {DUMMY_ENTITY_ID}" if not join_key_columns else ""} + FROM ( + SELECT {field_string}, + ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS _feast_row + FROM {from_expression} + WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date.strftime('%Y-%m-%d %H:%M:%S')}' AND TIMESTAMP '{end_date.strftime('%Y-%m-%d %H:%M:%S')}' + 
{"AND "+date_partition_column+" >= '"+start_date.strftime('%Y-%m-%d')+"' AND "+date_partition_column+" <= '"+end_date.strftime('%Y-%m-%d')+"' " if date_partition_column != "" and date_partition_column is not None else ''} + ) + WHERE _feast_row = 1 + """ + # When materializing a single feature view, we don't need full feature names. On demand transforms aren't materialized + return AthenaRetrievalJob( + query=query, + athena_client=athena_client, + s3_resource=s3_resource, + config=config, + full_feature_names=False, + ) + + @staticmethod + @log_exceptions_and_usage(offline_store="athena") + def pull_all_from_table_or_query( + config: RepoConfig, + data_source: DataSource, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_field: str, + start_date: datetime, + end_date: datetime, + ) -> RetrievalJob: + assert isinstance(data_source, AthenaSource) + from_expression = data_source.get_table_query_string() + + field_string = ", ".join( + join_key_columns + feature_name_columns + [timestamp_field] + ) + + athena_client = aws_utils.get_athena_data_client( + config.offline_store.region + ) + s3_resource = aws_utils.get_s3_resource(config.offline_store.region) + + date_partition_column = data_source.date_partition_column + + start_date = start_date.astimezone(tz=utc) + end_date = end_date.astimezone(tz=utc) + + query = f""" + SELECT {field_string} + FROM {from_expression} + WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}' + {"AND "+date_partition_column+" >= '"+start_date.strftime('%Y-%m-%d')+"' AND "+date_partition_column+" <= '"+end_date.strftime('%Y-%m-%d')+"' " if date_partition_column != "" and date_partition_column is not None else ''} + """ + + return AthenaRetrievalJob( + query=query, + athena_client=athena_client, + s3_resource=s3_resource, + config=config, + full_feature_names=False, + ) + + @staticmethod + @log_exceptions_and_usage(offline_store="athena") + def get_historical_features( + config: 
RepoConfig, + feature_views: List[FeatureView], + feature_refs: List[str], + entity_df: Union[pd.DataFrame, str], + registry: Registry, + project: str, + full_feature_names: bool = False, + ) -> RetrievalJob: + assert isinstance(config.offline_store, AthenaOfflineStoreConfig) + + athena_client = aws_utils.get_athena_data_client( + config.offline_store.region + ) + s3_resource = aws_utils.get_s3_resource(config.offline_store.region) + + # get pandas dataframe consisting of 1 row (LIMIT 1) and generate the schema out of it + entity_schema = _get_entity_schema( + entity_df, athena_client, config, s3_resource + ) + + # find timestamp column of entity df.(default = "event_timestamp"). Exception occurs if there are more than two timestamp columns. + entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( + entity_schema + ) + + # get min,max of event_timestamp. + entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( + entity_df, entity_df_event_timestamp_col, athena_client, config, + ) + + @contextlib.contextmanager + def query_generator() -> Iterator[str]: + + table_name = offline_utils.get_temp_entity_table_name() + + _upload_entity_df( + entity_df, athena_client, config, s3_resource, table_name + ) + + expected_join_keys = offline_utils.get_expected_join_keys( + project, feature_views, registry + ) + + offline_utils.assert_expected_columns_in_entity_df( + entity_schema, expected_join_keys, entity_df_event_timestamp_col + ) + + # Build a query context containing all information required to template the Athena SQL query + query_context = offline_utils.get_feature_view_query_context( + feature_refs, + feature_views, + registry, + project, + entity_df_event_timestamp_range, + ) + + + # Generate the Athena SQL query from the query context + query = offline_utils.build_point_in_time_query( + query_context, + left_table_query_string=table_name, + entity_df_event_timestamp_col=entity_df_event_timestamp_col, + 
entity_df_columns=entity_schema.keys(), + query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN, + full_feature_names=full_feature_names, + ) + + try: + yield query + finally: + + #Always clean up the temp Athena table + aws_utils.execute_athena_query( + athena_client, + config.offline_store.data_source, + config.offline_store.database, + f"DROP TABLE IF EXISTS {config.offline_store.database}.{table_name}", + ) + + bucket = config.offline_store.s3_staging_location.replace("s3://", "").split("/", 1)[0] + aws_utils.delete_s3_directory(s3_resource,bucket, "entity_df/"+table_name+"/") + + + return AthenaRetrievalJob( + query=query_generator, + athena_client=athena_client, + s3_resource=s3_resource, + config=config, + full_feature_names=full_feature_names, + on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs( + feature_refs, project, registry + ), + metadata=RetrievalMetadata( + features=feature_refs, + keys=list(entity_schema.keys() - {entity_df_event_timestamp_col}), + min_event_timestamp=entity_df_event_timestamp_range[0], + max_event_timestamp=entity_df_event_timestamp_range[1], + ), + ) + + + @staticmethod + def write_logged_features( + config: RepoConfig, + data: Union[pyarrow.Table, Path], + source: LoggingSource, + logging_config: LoggingConfig, + registry: Registry, + ): + destination = logging_config.destination + assert isinstance(destination, AthenaLoggingDestination) + + athena_client = aws_utils.get_athena_data_client( + config.offline_store.region + ) + s3_resource = aws_utils.get_s3_resource(config.offline_store.region) + if isinstance(data, Path): + s3_path = f"{config.offline_store.s3_staging_location}/logged_features/{uuid.uuid4()}" + else: + s3_path = f"{config.offline_store.s3_staging_location}/logged_features/{uuid.uuid4()}.parquet" + + aws_utils.upload_arrow_table_to_athena( + table=data, + athena_data_client=athena_client, + data_source=config.offline_store.data_source, + database=config.offline_store.database, + 
s3_resource=s3_resource, + s3_path=s3_path, + table_name=destination.table_name, + schema=source.get_schema(registry), + fail_if_exists=False, + ) + + +class AthenaRetrievalJob(RetrievalJob): + def __init__( + self, + query: Union[str, Callable[[], ContextManager[str]]], + athena_client, + s3_resource, + config: RepoConfig, + full_feature_names: bool, + on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, + metadata: Optional[RetrievalMetadata] = None, + ): + """Initialize AthenaRetrievalJob object. + + Args: + query: Athena SQL query to execute. Either a string, or a generator function that handles the artifact cleanup. + athena_client: boto3 athena client + s3_resource: boto3 s3 resource object + config: Feast repo config + full_feature_names: Whether to add the feature view prefixes to the feature names + on_demand_feature_views (optional): A list of on demand transforms to apply at retrieval time + """ + + + if not isinstance(query, str): + self._query_generator = query + else: + + @contextlib.contextmanager + def query_generator() -> Iterator[str]: + assert isinstance(query, str) + yield query + + self._query_generator = query_generator + self._athena_client = athena_client + self._s3_resource = s3_resource + self._config = config + self._full_feature_names = full_feature_names + self._on_demand_feature_views = ( + on_demand_feature_views if on_demand_feature_views else [] + ) + self._metadata = metadata + + + @property + def full_feature_names(self) -> bool: + return self._full_feature_names + + @property + def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]: + return self._on_demand_feature_views + + def get_temp_s3_path(self) -> str: + return self._config.offline_store.s3_staging_location + "/unload/" + str(uuid.uuid4()) + + def get_temp_table_dml_header(self, temp_table_name:str, temp_external_location:str) -> str: + temp_table_dml_header = f""" + CREATE TABLE {temp_table_name} + WITH ( + external_location = 
'{temp_external_location}', + format = 'parquet', + write_compression = 'snappy' + ) + as + """ + return temp_table_dml_header + + @log_exceptions_and_usage + def _to_df_internal(self) -> pd.DataFrame: + with self._query_generator() as query: + temp_table_name = "_" + str(uuid.uuid4()).replace("-", "") + temp_external_location = self.get_temp_s3_path() + return aws_utils.unload_athena_query_to_df( + self._athena_client, + self._config.offline_store.data_source, + self._config.offline_store.database, + self._s3_resource, + temp_external_location, + self.get_temp_table_dml_header(temp_table_name, temp_external_location) + query, + temp_table_name, + ) + + @log_exceptions_and_usage + def _to_arrow_internal(self) -> pa.Table: + with self._query_generator() as query: + temp_table_name = "_" + str(uuid.uuid4()).replace("-", "") + temp_external_location = self.get_temp_s3_path() + return aws_utils.unload_athena_query_to_pa( + self._athena_client, + self._config.offline_store.data_source, + self._config.offline_store.database, + self._s3_resource, + temp_external_location, + self.get_temp_table_dml_header(temp_table_name, temp_external_location) + query, + temp_table_name, + ) + + @property + def metadata(self) -> Optional[RetrievalMetadata]: + return self._metadata + + def persist(self, storage: SavedDatasetStorage): + assert isinstance(storage, SavedDatasetAthenaStorage) + # self.to_athena(table_name=storage.athena_options.table) + + +def _upload_entity_df( + entity_df: Union[pd.DataFrame, str], + athena_client, + config: RepoConfig, + s3_resource, + table_name: str, +): + if isinstance(entity_df, pd.DataFrame): + # If the entity_df is a pandas dataframe, upload it to Athena + aws_utils.upload_df_to_athena( + athena_client, + config.offline_store.data_source, + config.offline_store.database, + s3_resource, + f"{config.offline_store.s3_staging_location}/entity_df/{table_name}/{table_name}.parquet", + table_name, + entity_df, + ) + elif isinstance(entity_df, str): + # If 
the entity_df is a string (SQL query), create a Athena table out of it + aws_utils.execute_athena_query( + athena_client, + config.offline_store.data_source, + config.offline_store.database, + f"CREATE TABLE {table_name} AS ({entity_df})", + ) + else: + raise InvalidEntityType(type(entity_df)) + + +def _get_entity_schema( + entity_df: Union[pd.DataFrame, str], + athena_client, + config: RepoConfig, + s3_resource, +) -> Dict[str, np.dtype]: + if isinstance(entity_df, pd.DataFrame): + return dict(zip(entity_df.columns, entity_df.dtypes)) + + elif isinstance(entity_df, str): + # get pandas dataframe consisting of 1 row (LIMIT 1) and generate the schema out of it + entity_df_sample = AthenaRetrievalJob( + f"SELECT * FROM ({entity_df}) LIMIT 1", + athena_client, + s3_resource, + config, + full_feature_names=False, + ).to_df() + return dict(zip(entity_df_sample.columns, entity_df_sample.dtypes)) + else: + raise InvalidEntityType(type(entity_df)) + + +def _get_entity_df_event_timestamp_range( + entity_df: Union[pd.DataFrame, str], + entity_df_event_timestamp_col: str, + athena_client, + config: RepoConfig, +) -> Tuple[datetime, datetime]: + if isinstance(entity_df, pd.DataFrame): + entity_df_event_timestamp = entity_df.loc[ + :, entity_df_event_timestamp_col + ].infer_objects() + if pd.api.types.is_string_dtype(entity_df_event_timestamp): + entity_df_event_timestamp = pd.to_datetime( + entity_df_event_timestamp, utc=True + ) + entity_df_event_timestamp_range = ( + entity_df_event_timestamp.min().to_pydatetime(), + entity_df_event_timestamp.max().to_pydatetime(), + ) + elif isinstance(entity_df, str): + # If the entity_df is a string (SQL query), determine range + # from table + statement_id = aws_utils.execute_athena_query( + athena_client, + config.offline_store.data_source, + config.offline_store.database, + f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max " + f"FROM ({entity_df})", + ) + res = 
aws_utils.get_athena_query_result(athena_client, statement_id)[ + "Records" + ][0] + entity_df_event_timestamp_range = ( + res.parse(res[0]["stringValue"]), + res.parse(res[1]["stringValue"]), + ) + else: + raise InvalidEntityType(type(entity_df)) + + return entity_df_event_timestamp_range + + +MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """ +/* + Compute a deterministic hash for the `left_table_query_string` that will be used throughout + all the logic as the field to GROUP BY the data +*/ +WITH entity_dataframe AS ( + SELECT *, + {{entity_df_event_timestamp_col}} AS entity_timestamp + {% for featureview in featureviews %} + {% if featureview.entities %} + ,( + {% for entity in featureview.entities %} + CAST({{entity}} as VARCHAR) || + {% endfor %} + CAST({{entity_df_event_timestamp_col}} AS VARCHAR) + ) AS {{featureview.name}}__entity_row_unique_id + {% else %} + ,CAST({{entity_df_event_timestamp_col}} AS VARCHAR) AS {{featureview.name}}__entity_row_unique_id + {% endif %} + {% endfor %} + FROM {{ left_table_query_string }} +), + +{% for featureview in featureviews %} + +{{ featureview.name }}__entity_dataframe AS ( + SELECT + {{ featureview.entities | join(', ')}}{% if featureview.entities %},{% else %}{% endif %} + entity_timestamp, + {{featureview.name}}__entity_row_unique_id + FROM entity_dataframe + GROUP BY + {{ featureview.entities | join(', ')}}{% if featureview.entities %},{% else %}{% endif %} + entity_timestamp, + {{featureview.name}}__entity_row_unique_id +), + +/* + This query template performs the point-in-time correctness join for a single feature set table + to the provided entity table. + + 1. We first join the current feature_view to the entity dataframe that has been passed. 
+ This JOIN has the following logic: + - For each row of the entity dataframe, only keep the rows where the `timestamp_field` + is less than the one provided in the entity dataframe + - If there a TTL for the current feature_view, also keep the rows where the `timestamp_field` + is higher the the one provided minus the TTL + - For each row, Join on the entity key and retrieve the `entity_row_unique_id` that has been + computed previously + + The output of this CTE will contain all the necessary information and already filtered out most + of the data that is not relevant. +*/ + +{{ featureview.name }}__subquery AS ( + SELECT + {{ featureview.timestamp_field }} as event_timestamp, + {{ featureview.created_timestamp_column ~ ' as created_timestamp,' if featureview.created_timestamp_column else '' }} + {{ featureview.entity_selections | join(', ')}}{% if featureview.entity_selections %},{% else %}{% endif %} + {% for feature in featureview.features %} + {{ feature }} as {% if full_feature_names %}{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}{% else %}{{ featureview.field_mapping.get(feature, feature) }}{% endif %}{% if loop.last %}{% else %}, {% endif %} + {% endfor %} + FROM {{ featureview.table_subquery }} + WHERE {{ featureview.timestamp_field }} <= from_iso8601_timestamp('{{ featureview.max_event_timestamp }}') + {% if featureview.date_partition_column != "" and featureview.date_partition_column is not none %} + AND {{ featureview.date_partition_column }} <= '{{ featureview.max_event_timestamp[:10] }}' + {% endif %} + + {% if featureview.ttl == 0 %}{% else %} + AND {{ featureview.timestamp_field }} >= from_iso8601_timestamp('{{ featureview.min_event_timestamp }}') + {% if featureview.date_partition_column != "" and featureview.date_partition_column is not none %} + AND {{ featureview.date_partition_column }} >= '{{ featureview.min_event_timestamp[:10] }}' + {% endif %} + {% endif %} + +), + +{{ featureview.name }}__base AS ( + SELECT + 
subquery.*, + entity_dataframe.entity_timestamp, + entity_dataframe.{{featureview.name}}__entity_row_unique_id + FROM {{ featureview.name }}__subquery AS subquery + INNER JOIN {{ featureview.name }}__entity_dataframe AS entity_dataframe + ON TRUE + AND subquery.event_timestamp <= entity_dataframe.entity_timestamp + + {% if featureview.ttl == 0 %}{% else %} + AND subquery.event_timestamp >= entity_dataframe.entity_timestamp - {{ featureview.ttl }} * interval '1' second + {% endif %} + + {% for entity in featureview.entities %} + AND subquery.{{ entity }} = entity_dataframe.{{ entity }} + {% endfor %} +), + +/* + 2. If the `created_timestamp_column` has been set, we need to + deduplicate the data first. This is done by calculating the + `MAX(created_at_timestamp)` for each event_timestamp. + We then join the data on the next CTE +*/ +{% if featureview.created_timestamp_column %} +{{ featureview.name }}__dedup AS ( + SELECT + {{featureview.name}}__entity_row_unique_id, + event_timestamp, + MAX(created_timestamp) as created_timestamp + FROM {{ featureview.name }}__base + GROUP BY {{featureview.name}}__entity_row_unique_id, event_timestamp +), +{% endif %} + +/* + 3. The data has been filtered during the first CTE "*__base" + Thus we only need to compute the latest timestamp of each feature. 
+*/ +{{ featureview.name }}__latest AS ( + SELECT + event_timestamp, + {% if featureview.created_timestamp_column %}created_timestamp,{% endif %} + {{featureview.name}}__entity_row_unique_id + FROM + ( + SELECT base.*, + ROW_NUMBER() OVER( + PARTITION BY base.{{featureview.name}}__entity_row_unique_id + ORDER BY base.event_timestamp DESC{% if featureview.created_timestamp_column %},base.created_timestamp DESC{% endif %} + ) AS row_number + FROM {{ featureview.name }}__base as base + {% if featureview.created_timestamp_column %} + INNER JOIN {{ featureview.name }}__dedup as dedup + ON TRUE + AND base.{{featureview.name}}__entity_row_unique_id = dedup.{{featureview.name}}__entity_row_unique_id + AND base.event_timestamp = dedup.event_timestamp + AND base.created_timestamp = dedup.created_timestamp + {% endif %} + ) + WHERE row_number = 1 +), + +/* + 4. Once we know the latest value of each feature for a given timestamp, + we can join again the data back to the original "base" dataset +*/ +{{ featureview.name }}__cleaned AS ( + SELECT base.* + FROM {{ featureview.name }}__base as base + INNER JOIN {{ featureview.name }}__latest as latest + ON TRUE + AND base.{{featureview.name}}__entity_row_unique_id = latest.{{featureview.name}}__entity_row_unique_id + AND base.event_timestamp = latest.event_timestamp + {% if featureview.created_timestamp_column %} + AND base.created_timestamp = latest.created_timestamp + {% endif %} +){% if loop.last %}{% else %}, {% endif %} + + +{% endfor %} +/* + Joins the outputs of multiple time travel joins to a single table. + The entity_dataframe dataset being our source of truth here. 
+ */ + +SELECT {{ final_output_feature_names | join(', ')}} +FROM entity_dataframe as entity_df +{% for featureview in featureviews %} +LEFT JOIN ( + SELECT + {{featureview.name}}__entity_row_unique_id + {% for feature in featureview.features %} + ,{% if full_feature_names %}{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}{% else %}{{ featureview.field_mapping.get(feature, feature) }}{% endif %} + {% endfor %} + FROM {{ featureview.name }}__cleaned +) as cleaned +ON TRUE +AND entity_df.{{featureview.name}}__entity_row_unique_id = cleaned.{{featureview.name}}__entity_row_unique_id +{% endfor %} +""" diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py new file mode 100644 index 00000000000..facd8ed80c0 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py @@ -0,0 +1,347 @@ +import warnings +from typing import Callable, Dict, Iterable, Optional, Tuple + +#from feast import type_map +from feast import type_map + +from feast.data_source import DataSource +from feast.errors import DataSourceNotFoundException, RedshiftCredentialsError +from feast.feature_logging import LoggingDestination +from feast.protos.feast.core.FeatureService_pb2 import ( + LoggingConfig as LoggingConfigProto, +) +#from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto +from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto + + +''' +from feast.protos.feast.core.SavedDataset_pb2 import ( + SavedDatasetStorage as SavedDatasetStorageProto, +) +''' +from feast.protos.feast.core.SavedDataset_pb2 import ( + SavedDatasetStorage as SavedDatasetStorageProto, +) +from feast.repo_config import RepoConfig +from feast.saved_dataset import SavedDatasetStorage +from feast.value_type import ValueType + + +class AthenaSource(DataSource): + def __init__( + 
self, + *, + timestamp_field: Optional[str] = "", + table: Optional[str] = None, + database: Optional[str] = None, + data_source: Optional[str] = None, + created_timestamp_column: Optional[str] = None, + field_mapping: Optional[Dict[str, str]] = None, + date_partition_column: Optional[str] = None, + query: Optional[str] = None, + name: Optional[str] = None, + description: Optional[str] = "", + tags: Optional[Dict[str, str]] = None, + owner: Optional[str] = "", + ): + """ + Creates a AthenaSource object. + + Args: + timestamp_field : event timestamp column. + table (optional): Athena table where the features are stored. + database: Athena Database Name + data_source (optional): Athena data source + created_timestamp_column (optional): Timestamp column indicating when the + row was created, used for deduplicating rows. + field_mapping (optional): A dictionary mapping of column names in this data + source to column names in a feature table or view. + date_partition_column : Timestamp column used for partitioning. + query (optional): The query to be executed to obtain the features. + name (optional): Name for the source. Defaults to the table_ref if not specified. + description (optional): A human-readable description. + tags (optional): A dictionary of key-value pairs to store arbitrary metadata. + owner (optional): The owner of the athena source, typically the email of the primary + maintainer. + + + """ + # The default Athena schema is named "public". 
+ _database = "default" if table and not database else database + self.athena_options = AthenaOptions( + table=table, query=query, database=_database, data_source=data_source + ) + + if table is None and query is None: + raise ValueError('No "table" argument provided.') + _name = name + if not _name: + if table: + _name = table + else: + warnings.warn( + ( + f"Starting in Feast 0.21, Feast will require either a name for a data source (if using query) " + f"or `table`: {self.query}" + ), + DeprecationWarning, + ) + + super().__init__( + name=_name if _name else "", + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + field_mapping=field_mapping, + date_partition_column = date_partition_column, + description=description, + tags=tags, + owner=owner, + ) + + @staticmethod + def from_proto(data_source: DataSourceProto): + """ + Creates a AthenaSource from a protobuf representation of a AthenaSource. + + Args: + data_source: A protobuf representation of a AthenaSource + + Returns: + A AthenaSource object based on the data_source protobuf. + """ + return AthenaSource( + name=data_source.name, + timestamp_field=data_source.timestamp_field, + table=data_source.athena_options.table, + database=data_source.athena_options.database, + data_source=data_source.athena_options.data_source, + created_timestamp_column=data_source.created_timestamp_column, + field_mapping=dict(data_source.field_mapping), + date_partition_column=data_source.date_partition_column, + query=data_source.athena_options.query, + description=data_source.description, + tags=dict(data_source.tags), + ) + + # Note: Python requires redefining hash in child classes that override __eq__ + def __hash__(self): + return super().__hash__() + + def __eq__(self, other): + if not isinstance(other, AthenaSource): + raise TypeError( + "Comparisons should only involve AthenaSource class objects." 
+ ) + + return ( + super().__eq__(other) + and self.athena_options.table == other.athena_options.table + and self.athena_options.query == other.athena_options.query + and self.athena_options.database == other.athena_options.database + and self.athena_options.data_source == other.athena_options.data_source + ) + + @property + def table(self): + """Returns the table of this Athena source.""" + return self.athena_options.table + + @property + def database(self): + """Returns the database of this Athena source.""" + return self.athena_options.database + + @property + def query(self): + """Returns the Athena query of this Athena source.""" + return self.athena_options.query + + @property + def data_source(self): + """Returns the Athena data_source of this Athena source.""" + return self.athena_options.data_source + + def to_proto(self) -> DataSourceProto: + """ + Converts a RedshiftSource object to its protobuf representation. + + Returns: + A DataSourceProto object. + """ + data_source_proto = DataSourceProto( + type=DataSourceProto.BATCH_ATHENA, + name=self.name, + timestamp_field=self.timestamp_field, + created_timestamp_column=self.created_timestamp_column, + field_mapping=self.field_mapping, + date_partition_column=self.date_partition_column, + description=self.description, + tags=self.tags, + athena_options=self.athena_options.to_proto(), + ) + + return data_source_proto + + def validate(self, config: RepoConfig): + # As long as the query gets successfully executed, or the table exists, + # the data source is validated. We don't need the results though. 
+ self.get_table_column_names_and_types(config) + + def get_table_query_string(self) -> str: + """Returns a string that can directly be used to reference this table in SQL.""" + if self.table: + return f'"{self.data_source}"."{self.database}"."{self.table}"' + else: + return f"({self.query})" + + @staticmethod + def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: + return type_map.athena_to_feast_value_type + + def get_table_column_names_and_types( + self, config: RepoConfig + ) -> Iterable[Tuple[str, str]]: + """ + Returns a mapping of column names to types for this Athena source. + + Args: + config: A RepoConfig describing the feature repo + """ + from botocore.exceptions import ClientError + + from feast.infra.offline_stores.contrib.athena_offline_store.athena import AthenaOfflineStoreConfig + from feast.infra.utils import aws_utils + + assert isinstance(config.offline_store, AthenaOfflineStoreConfig) + + client = aws_utils.get_athena_data_client(config.offline_store.region) + if self.table: + try: + table = client.get_table_metadata( + CatalogName=self.data_source, + DatabaseName=self.database, + TableName=self.table, + ) + except ClientError as e: + raise aws_utils.AthenaError(e) + + # The API returns valid JSON with empty column list when the table doesn't exist + if len(table["TableMetadata"]["Columns"]) == 0: + raise DataSourceNotFoundException(self.table) + + columns = table["TableMetadata"]["Columns"] + else: + statement_id = aws_utils.execute_athena_query( + client, + config.offline_store.data_source, + config.offline_store.database, + f"SELECT * FROM ({self.query}) LIMIT 1", + ) + columns = aws_utils.get_athena_query_result(client, statement_id)["ResultSetMetadata"]["ColumnInfo"] + + return [(column["Name"], column["Type"].upper()) for column in columns] + + +class AthenaOptions: + """ + Configuration options for a Athena data source. 
+ """ + + def __init__( + self, + table: Optional[str], + query: Optional[str], + database: Optional[str], + data_source: Optional[str], + ): + self.table = table or "" + self.query = query or "" + self.database = database or "" + self.data_source = data_source or "" + + @classmethod + def from_proto(cls, athena_options_proto: DataSourceProto.AthenaOptions): + """ + Creates a AthenaOptions from a protobuf representation of a Athena option. + + Args: + athena_options_proto: A protobuf representation of a DataSource + + Returns: + A AthenaOptions object based on the athena_options protobuf. + """ + athena_options = cls( + table=athena_options_proto.table, + query=athena_options_proto.query, + database=athena_options_proto.database, + data_source=athena_options_proto.data_source, + ) + + return athena_options + + def to_proto(self) -> DataSourceProto.AthenaOptions: + """ + Converts an AthenaOptionsProto object to its protobuf representation. + + Returns: + A AthenaOptionsProto protobuf. + """ + athena_options_proto = DataSourceProto.AthenaOptions( + table=self.table, + query=self.query, + database=self.database, + data_source=self.data_source, + ) + + return athena_options_proto + + +class SavedDatasetAthenaStorage(SavedDatasetStorage): + _proto_attr_name = "athena_storage" + + athena_options: AthenaOptions + + def __init__(self, table_ref: str): + self.athena_options = AthenaOptions( + table=table_ref, query=None, database=None, data_source=None + ) + + @staticmethod + def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage: + + return SavedDatasetAthenaStorage( + table_ref=AthenaOptions.from_proto(storage_proto.athena_storage).table + ) + + def to_proto(self) -> SavedDatasetStorageProto: + return SavedDatasetStorageProto( + athena_storage=self.athena_options.to_proto() + ) + + def to_data_source(self) -> DataSource: + return AthenaSource(table=self.athena_options.table) + + +class AthenaLoggingDestination(LoggingDestination): + _proto_kind = 
"athena_destination" + + table_name: str + + def __init__(self, *, table_name: str): + self.table_name = table_name + + @classmethod + def from_proto(cls, config_proto: LoggingConfigProto) -> "LoggingDestination": + return AthenaLoggingDestination( + table_name=config_proto.athena_destination.table_name, + ) + + def to_proto(self) -> LoggingConfigProto: + return LoggingConfigProto( + athena_destination=LoggingConfigProto.AthenaDestination( + table_name=self.table_name + ) + ) + + def to_data_source(self) -> DataSource: + return AthenaSource(table=self.table_name) diff --git a/sdk/python/feast/infra/offline_stores/offline_utils.py b/sdk/python/feast/infra/offline_stores/offline_utils.py index 8b963a864bc..a1dc117b35b 100644 --- a/sdk/python/feast/infra/offline_stores/offline_utils.py +++ b/sdk/python/feast/infra/offline_stores/offline_utils.py @@ -93,6 +93,7 @@ class FeatureViewQueryContext: entity_selections: List[str] min_event_timestamp: Optional[str] max_event_timestamp: str + date_partition_column: Optional[str] # this attribute is added because partition pruning affects Athena's query performance. 
def get_feature_view_query_context( @@ -142,6 +143,11 @@ def get_feature_view_query_context( feature_view.batch_source.created_timestamp_column, ) + date_partition_column = reverse_field_mapping.get( + feature_view.batch_source.date_partition_column, + feature_view.batch_source.date_partition_column, + ) + max_event_timestamp = to_naive_utc(entity_df_timestamp_range[1]).isoformat() min_event_timestamp = None if feature_view.ttl: @@ -162,6 +168,7 @@ def get_feature_view_query_context( entity_selections=entity_selections, min_event_timestamp=min_event_timestamp, max_event_timestamp=max_event_timestamp, + date_partition_column=date_partition_column ) query_context.append(context) diff --git a/sdk/python/feast/infra/utils/aws_utils.py b/sdk/python/feast/infra/utils/aws_utils.py index 3c8ad9d71b0..a3da377025c 100644 --- a/sdk/python/feast/infra/utils/aws_utils.py +++ b/sdk/python/feast/infra/utils/aws_utils.py @@ -21,7 +21,7 @@ RedshiftQueryError, RedshiftTableNameTooLong, ) -from feast.type_map import pa_to_redshift_value_type +from feast.type_map import pa_to_redshift_value_type,pa_to_athena_value_type try: import boto3 @@ -32,7 +32,6 @@ raise FeastExtrasDependencyImportError("aws", str(e)) - REDSHIFT_TABLE_NAME_MAX_LENGTH = 127 @@ -672,3 +671,326 @@ def list_s3_files(aws_region: str, path: str) -> List[str]: contents = objects["Contents"] files = [f"s3://{bucket}/{content['Key']}" for content in contents] return files + + +# Athena + +def get_athena_data_client(aws_region: str): + """ + Get the athena Data API Service client for the given AWS region. + """ + return boto3.client("athena", config=Config(region_name=aws_region)) + + +@retry( + wait=wait_exponential(multiplier=1, max=4), + retry=retry_if_exception_type(ConnectionClosedError), + stop=stop_after_attempt(5), + reraise=True, +) +def execute_athena_query_async( + athena_data_client, data_source: str, database: str, query: str +) -> dict: + """Execute Athena statement asynchronously. 
Does not wait for the query to finish. + + Raises AthenaCredentialsError if the statement couldn't be executed due to the validation error. + + Args: + athena_data_client: athena Data API Service client + data_source: athena Cluster Identifier + database: athena Database Name + query: The SQL query to execute + + Returns: JSON response + + """ + try: + # return athena_data_client.execute_statement( + return athena_data_client.start_query_execution( + QueryString=query, + QueryExecutionContext={ + 'Database': database + }, + WorkGroup='primary' + ) + + except ClientError as e: + raise AthenaQueryError() + + +class AthenaStatementNotFinishedError(Exception): + pass + + +@retry( + wait=wait_exponential(multiplier=1, max=30), + retry=retry_if_exception_type(AthenaStatementNotFinishedError), + reraise=True, +) +def wait_for_athena_execution(athena_data_client, execution: dict) -> None: + """Waits for the Athena statement to finish. Raises AthenaQueryError if the statement didn't succeed. + + We use exponential backoff for checking the query state until it's not running. The backoff starts with + 0.1 seconds and doubles exponentially until reaching 30 seconds, at which point the backoff is fixed. + + Args: + athena_data_client: athena Service boto3 client + execution: The athena execution to wait for (result of execute_athena_statement) + + Returns: None + + """ + response = athena_data_client.get_query_execution(QueryExecutionId=execution["QueryExecutionId"]) + if response["QueryExecution"]["Status"]["State"] in ("QUEUED", "RUNNING"): + raise AthenaStatementNotFinishedError # Retry + if response["QueryExecution"]["Status"]["State"] != "SUCCEEDED": + raise AthenaQueryError(response) # Don't retry. Raise exception. 
+ + +def drop_temp_table(athena_data_client, data_source: str, database: str, temp_table: str): + query = f'DROP TABLE `{database}.{temp_table}`' + execute_athena_query_async( + athena_data_client, data_source, database, query + ) + + +def execute_athena_query( + athena_data_client, data_source: str, database: str, query: str, temp_table: str = None +) -> str: + """Execute athena statement synchronously. Waits for the query to finish. + + Raises athenaCredentialsError if the statement couldn't be executed due to the validation error. + Raises athenaQueryError if the query runs but finishes with errors. + + + Args: + athena_data_client: athena Data API Service client + data_source: athena data source Name + database: athena Database Name + query: The SQL query to execute + temp_table: temp table name to be deleted after query execution. + + Returns: Statement ID + + """ + + execution = execute_athena_query_async( + athena_data_client, data_source, database, query + ) + wait_for_athena_execution(athena_data_client, execution) + if temp_table is not None: + drop_temp_table(athena_data_client, data_source, database, temp_table) + + return execution["QueryExecutionId"] + + +def get_athena_query_result(athena_data_client, query_execution_id: str) -> dict: + """Get the athena query result""" + response = athena_data_client.get_query_results(QueryExecutionId=query_execution_id) + return response["ResultSet"] + + +class AthenaError(Exception): + def __init__(self, details): + super().__init__(f"Athena API failed. Details: {details}") + + +class AthenaQueryError(Exception): + def __init__(self, details): + super().__init__(f"Athena SQL Query failed to finish. Details: {details}") + + +class AthenaTableNameTooLong(Exception): + def __init__(self, table_name: str): + super().__init__( + f"Athena table(Data catalog) names have a maximum length of 255 characters, but the table name {table_name} has length {len(table_name)} characters." 
+ ) + + +def unload_athena_query_to_pa( + athena_data_client, + data_source: str, + database: str, + s3_resource, + s3_path: str, + query: str, + temp_table: str, +) -> pa.Table: + """Unload Athena Query results to S3 and get the results in PyArrow Table format""" + bucket, key = get_bucket_and_key(s3_path) + + execute_athena_query_and_unload_to_s3( + athena_data_client, data_source, database, query, temp_table + ) + + with tempfile.TemporaryDirectory() as temp_dir: + download_s3_directory(s3_resource, bucket, key, temp_dir) + delete_s3_directory(s3_resource, bucket, key) + return pq.read_table(temp_dir) + + +def unload_athena_query_to_df( + athena_data_client, + data_source: str, + database: str, + s3_resource, + s3_path: str, + query: str, + temp_table: str, +) -> pd.DataFrame: + """Unload Athena Query results to S3 and get the results in Pandas DataFrame format""" + table = unload_athena_query_to_pa( + athena_data_client, + data_source, + database, + s3_resource, + s3_path, + query, + temp_table + ) + return table.to_pandas() + + +def execute_athena_query_and_unload_to_s3( + athena_data_client, + data_source: str, + database: str, + query: str, + temp_table: str, +) -> None: + """Unload Athena Query results to S3 + + Args: + athena_data_client: Athena Data API Service client + data_source: Athena data source + database: Redshift Database Name + query: The SQL query to execute + temp_table: temp table name to be deleted after query execution. + + """ + + execute_athena_query(athena_data_client, data_source, database, query, temp_table) + + +def upload_df_to_athena( + athena_client, + data_source: str, + database: str, + s3_resource, + s3_path: str, + table_name: str, + df: pd.DataFrame, +): + """Uploads a Pandas DataFrame to S3(Athena) as a new table. + + The caller is responsible for deleting the table when no longer necessary. 
+ + Args: + athena_client: Athena API Service client + data_source: Athena Data Source + database: Athena Database Name + s3_resource: S3 Resource object + s3_path: S3 path where the Parquet file is temporarily uploaded + table_name: The name of the new Data Catalog table where we copy the dataframe + df: The Pandas DataFrame to upload + + Raises: + AthenaTableNameTooLong: The specified table name is too long. + """ + + # Drop the index so that we dont have unnecessary columns + df.reset_index(drop=True, inplace=True) + + # Convert Pandas DataFrame into PyArrow table and compile the Athena table schema. + # Note, if the underlying data has missing values, + # pandas will convert those values to np.nan if the dtypes are numerical (floats, ints, etc.) or boolean. + # If the dtype is 'object', then missing values are inferred as python `None`s. + # More details at: + # https://pandas.pydata.org/pandas-docs/stable/user_guide/missing_data.html#values-considered-missing + table = pa.Table.from_pandas(df) + upload_arrow_table_to_athena( + table, + athena_client, + data_source=data_source, + database=database, + s3_resource=s3_resource, + s3_path=s3_path, + table_name=table_name, + ) + + +def upload_arrow_table_to_athena( + table: Union[pyarrow.Table, Path], + athena_client, + data_source: str, + database: str, + s3_resource, + s3_path: str, + table_name: str, + schema: Optional[pyarrow.Schema] = None, + fail_if_exists: bool = True, +): + """Uploads an Arrow Table to S3(Athena). + + Here's how the upload process works: + 1. PyArrow Table is serialized into a Parquet format on local disk + 2. The Parquet file is uploaded to S3 + 3. an Athena(data catalog) table is created. the S3 directory(in number 2) will be set as an external location. + 4. 
The local temp file is cleaned up (NOTE: S3 cleanup is currently disabled — see the commented-out delete in the `finally` block)
b/sdk/python/feast/repo_config.py @@ -57,6 +57,7 @@ "spark": "feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStore", "trino": "feast.infra.offline_stores.contrib.trino_offline_store.trino.TrinoOfflineStore", "postgres": "feast.infra.offline_stores.contrib.postgres_offline_store.postgres.PostgreSQLOfflineStore", + "athena": "feast.infra.offline_stores.contrib.athena_offline_store.athena.AthenaOfflineStore", } FEATURE_SERVER_CONFIG_CLASS_FOR_TYPE = { diff --git a/sdk/python/feast/templates/athena/__init__.py b/sdk/python/feast/templates/athena/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/feast/templates/athena/example.py b/sdk/python/feast/templates/athena/example.py new file mode 100644 index 00000000000..d5fec1b8ea2 --- /dev/null +++ b/sdk/python/feast/templates/athena/example.py @@ -0,0 +1,107 @@ +import importlib +import os +from datetime import datetime, timedelta + +import pandas as pd + +from feast import Entity, Feature, FeatureStore, FeatureView, ValueType +from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import ( + AthenaSource, +) + + +def test_end_to_end(): + + try: + fs = FeatureStore("feature_repo/") + + driver_hourly_stats = AthenaSource( + timestamp_field="event_timestamp", + table="driver_stats", + # table="driver_stats_partitioned", + database="sampledb", + data_source="AwsDataCatalog", + created_timestamp_column="created", + # date_partition_column="std_date" + ) + + driver = Entity( + name="driver_id", value_type=ValueType.INT64, description="driver id", + ) + + driver_hourly_stats_view = FeatureView( + name="driver_hourly_stats", + entities=["driver_id"], + ttl=timedelta(days=365), + features=[ + Feature(name="conv_rate", dtype=ValueType.FLOAT), + Feature(name="acc_rate", dtype=ValueType.FLOAT), + Feature(name="avg_daily_trips", dtype=ValueType.INT64), + ], + online=True, + batch_source=driver_hourly_stats, + ) + + # apply repository + 
fs.apply([driver_hourly_stats, driver, driver_hourly_stats_view]) + + print(fs.list_data_sources()) + print(fs.list_feature_views()) + + entity_df = pd.DataFrame( + {"driver_id": [1001], "event_timestamp": [datetime.now()]} + ) + + # Read features from offline store + + feature_vector = ( + fs.get_historical_features( + features=["driver_hourly_stats:conv_rate"], entity_df=entity_df + ) + .to_df() + .to_dict() + ) + conv_rate = feature_vector["conv_rate"][0] + print(conv_rate) + assert conv_rate > 0 + + # load data into online store + fs.materialize_incremental(end_date=datetime.now()) + + online_response = fs.get_online_features( + features=[ + "driver_hourly_stats:conv_rate", + "driver_hourly_stats:acc_rate", + "driver_hourly_stats:avg_daily_trips", + ], + entity_rows=[{"driver_id": 1002}], + ) + online_response_dict = online_response.to_dict() + print(online_response_dict) + + except Exception as e: + print(e) + finally: + # tear down feature store + fs.teardown() + + +def test_cli(): + os.system("PYTHONPATH=$PYTHONPATH:/$(pwd) feast -c feature_repo apply") + try: + os.system("PYTHONPATH=$PYTHONPATH:/$(pwd) ") + with open("output", "r") as f: + output = f.read() + + if "Pulling latest features from my offline store" not in output: + raise Exception( + 'Failed to successfully use provider from CLI. See "output" for more details.' 
+ ) + finally: + os.system("PYTHONPATH=$PYTHONPATH:/$(pwd) feast -c feature_repo teardown") + + +if __name__ == "__main__": + # pass + test_end_to_end() + test_cli() \ No newline at end of file diff --git a/sdk/python/feast/templates/athena/feature_store.yaml b/sdk/python/feast/templates/athena/feature_store.yaml new file mode 100644 index 00000000000..ee88bda72a1 --- /dev/null +++ b/sdk/python/feast/templates/athena/feature_store.yaml @@ -0,0 +1,12 @@ +project: repo +registry: registry.db +provider: aws +online_store: + type: sqlite + path: online_store.db +offline_store: + type: athena + region: ap-northeast-2 + database: sampledb + data_source: AwsDataCatalog + s3_staging_location: s3://sagemaker-yelo-test \ No newline at end of file diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py index ed4b7cba594..bb0fb360d42 100644 --- a/sdk/python/feast/type_map.py +++ b/sdk/python/feast/type_map.py @@ -791,3 +791,61 @@ def pg_type_code_to_arrow(code: int) -> str: return feast_value_type_to_pa( pg_type_to_feast_value_type(pg_type_code_to_pg_type(code)) ) + + + +def athena_to_feast_value_type(athena_type_as_str: str) -> ValueType: + # Type names from https://docs.aws.amazon.com/athena/latest/ug/data-types.html + type_map = { + "null": ValueType.UNKNOWN, + "boolean": ValueType.BOOL, + "tinyint": ValueType.INT32, + "smallint": ValueType.INT32, + "int": ValueType.INT32, + "bigint": ValueType.INT64, + "double": ValueType.DOUBLE, + "float": ValueType.FLOAT, + "binary": ValueType.BYTES, + "char": ValueType.STRING, + "varchar": ValueType.STRING, + "string": ValueType.STRING, + "timestamp": ValueType.UNIX_TIMESTAMP, + # skip date,decimal,array,map,struct + } + return type_map[athena_type_as_str.lower()] + + +def pa_to_athena_value_type(pa_type: pyarrow.DataType) -> str: + # PyArrow types: https://arrow.apache.org/docs/python/api/datatypes.html + # Type names from https://docs.aws.amazon.com/athena/latest/ug/data-types.html + pa_type_as_str = 
str(pa_type).lower() + if pa_type_as_str.startswith("timestamp"): + return "timestamp" + + if pa_type_as_str.startswith("date"): + return "date" + + if pa_type_as_str.startswith("decimal"): + return pa_type_as_str + + # We have to take into account how arrow types map to parquet types as well. + # For example, null type maps to int32 in parquet, so we have to use int4 in Redshift. + # Other mappings have also been adjusted accordingly. + type_map = { + "null": "null", + "bool": "bool", + "int8": "tinyint", + "int16": "smallint", + "int32": "int", + "int64": "bigint", + "uint8": "tinyint", + "uint16": "tinyint", + "uint32": "tinyint", + "uint64": "tinyint", + "float": "float", + "double": "double", + "binary": "binary", + "string": "string", + } + + return type_map[pa_type_as_str] diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py new file mode 100644 index 00000000000..d34af4f373d --- /dev/null +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py @@ -0,0 +1,112 @@ +import uuid +from typing import Any, Dict, List, Optional + +import pandas as pd + +from feast import AthenaSource +from feast.data_source import DataSource +from feast.feature_logging import LoggingDestination +from feast.infra.offline_stores.contrib.athena_offline_store.athena import ( + AthenaOfflineStoreConfig, +) +from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import ( + AthenaLoggingDestination, + SavedDatasetAthenaStorage, +) +from feast.infra.utils import aws_utils +from feast.repo_config import FeastConfigBaseModel +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from tests.integration.feature_repos.universal.data_source_creator import ( + DataSourceCreator, +) + + +class AthenaDataSourceCreator(DataSourceCreator): + + tables: List[str] = [] + + def __init__(self, 
project_name: str): + super().__init__(project_name) + self.client = aws_utils.get_athena_data_client("ap-northeast-2") + self.s3 = aws_utils.get_s3_resource("ap-northeast-2") + + self.offline_store_config = AthenaOfflineStoreConfig( + data_source="AwsDataCatalog", + region="ap-northeast-2", + database="sampledb", + s3_staging_location="s3://sagemaker-yelo-test", + ) + + def create_data_source( + self, + df: pd.DataFrame, + destination_name: str, + suffix: Optional[str] = None, + timestamp_field="ts", + created_timestamp_column="created_ts", + field_mapping: Dict[str, str] = None, + ) -> DataSource: + + destination_name = self.get_prefixed_table_name(destination_name) + + aws_utils.upload_df_to_athena( + self.client, + self.offline_store_config.data_source, + self.offline_store_config.database, + self.s3, + self.offline_store_config.s3_staging_location, + destination_name, + df, + ) + + self.tables.append(destination_name) + + return AthenaSource( + table=destination_name, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + field_mapping=field_mapping or {"ts_1": "ts"}, + database=self.offline_store_config.database, + data_source=self.offline_store_config.data_source, + ) + + def create_saved_dataset_destination(self) -> SavedDatasetAthenaStorage: + table = self.get_prefixed_table_name( + f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" + ) + self.tables.append(table) + + return SavedDatasetAthenaStorage(table_ref=table) + + def create_logged_features_destination(self) -> LoggingDestination: + table = self.get_prefixed_table_name( + f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" + ) + self.tables.append(table) + + return AthenaLoggingDestination(table_name=table) + + def create_offline_store_config(self) -> FeastConfigBaseModel: + return self.offline_store_config + + def get_prefixed_table_name(self, suffix: str) -> str: + return f"{self.project_name}.{suffix}" + + def teardown(self): + for table in self.tables: + 
aws_utils.execute_athena_query( + self.client, + self.offline_store_config.data_source, + self.offline_store_config.database, + f"DROP TABLE IF EXISTS {table}", + ) + + +FULL_REPO_CONFIGS = [ + IntegrationTestRepoConfig(), + IntegrationTestRepoConfig( + provider="aws", offline_store_creator=AthenaDataSourceCreator, + ), +] \ No newline at end of file From 7cbd23293e9a73abdf273e42043cb58ec56add9d Mon Sep 17 00:00:00 2001 From: toping4445 Date: Sun, 31 Jul 2022 20:48:38 +0900 Subject: [PATCH 02/11] fixed bugs, cleaned code, added some methods. test_universal_historical_retrieval - 100% passed Signed-off-by: Youngkyu OH --- sdk/python/feast/__init__.py | 6 +- sdk/python/feast/batch_feature_view.py | 1 + .../contrib/athena_offline_store/athena.py | 126 +++++++++++------- .../athena_offline_store/athena_source.py | 48 ++++--- .../athena_offline_store/tests/__init__.py | 0 .../athena_offline_store/tests/data_source.py | 104 +++++++++++++++ .../contrib/athena_repo_configuration.py | 18 +++ .../infra/offline_stores/offline_utils.py | 6 +- sdk/python/feast/infra/utils/aws_utils.py | 107 ++++++++------- sdk/python/feast/templates/athena/example.py | 6 +- sdk/python/feast/type_map.py | 1 - .../feature_repos/repo_configuration.py | 6 + .../universal/data_sources/athena.py | 42 ++++-- 13 files changed, 333 insertions(+), 138 deletions(-) create mode 100644 sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/__init__.py create mode 100644 sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py create mode 100644 sdk/python/feast/infra/offline_stores/contrib/athena_repo_configuration.py diff --git a/sdk/python/feast/__init__.py b/sdk/python/feast/__init__.py index d592c35bbee..d043f1a9738 100644 --- a/sdk/python/feast/__init__.py +++ b/sdk/python/feast/__init__.py @@ -5,12 +5,12 @@ from importlib_metadata import PackageNotFoundError, version as _version # type: ignore from feast.infra.offline_stores.bigquery_source import 
BigQuerySource -from feast.infra.offline_stores.file_source import FileSource -from feast.infra.offline_stores.redshift_source import RedshiftSource -from feast.infra.offline_stores.snowflake_source import SnowflakeSource from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import ( AthenaSource, ) +from feast.infra.offline_stores.file_source import FileSource +from feast.infra.offline_stores.redshift_source import RedshiftSource +from feast.infra.offline_stores.snowflake_source import SnowflakeSource from .batch_feature_view import BatchFeatureView from .data_source import KafkaSource, KinesisSource, PushSource, RequestSource diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py index db967046404..f714573810b 100644 --- a/sdk/python/feast/batch_feature_view.py +++ b/sdk/python/feast/batch_feature_view.py @@ -14,6 +14,7 @@ "SnowflakeSource", "SparkSource", "TrinoSource", + "AthenaSource", } diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py index 7959322a30a..79d4461aa7d 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py @@ -25,23 +25,21 @@ from feast import OnDemandFeatureView from feast.data_source import DataSource from feast.errors import InvalidEntityType -from feast.feature_logging import LoggingConfig, LoggingSource, LoggingDestination +from feast.feature_logging import LoggingConfig, LoggingDestination, LoggingSource from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView +from feast.infra.offline_stores import offline_utils +from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import ( + AthenaLoggingDestination, + AthenaSource, + SavedDatasetAthenaStorage, +) from feast.infra.offline_stores.offline_store import ( 
OfflineStore, RetrievalJob, RetrievalMetadata, ) - -from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import ( - AthenaSource, - AthenaLoggingDestination, - SavedDatasetAthenaStorage, -) from feast.infra.utils import aws_utils -from feast.infra.offline_stores import offline_utils - -from feast.registry import Registry +from feast.registry import Registry, BaseRegistry from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage from feast.usage import log_exceptions_and_usage @@ -82,7 +80,7 @@ def pull_latest_from_table_or_query( assert isinstance(data_source, AthenaSource) assert isinstance(config.offline_store, AthenaOfflineStoreConfig) - from_expression = data_source.get_table_query_string() + from_expression = data_source.get_table_query_string(config) partition_by_join_key_string = ", ".join(join_key_columns) if partition_by_join_key_string != "": @@ -99,9 +97,7 @@ def pull_latest_from_table_or_query( date_partition_column = data_source.date_partition_column - athena_client = aws_utils.get_athena_data_client( - config.offline_store.region - ) + athena_client = aws_utils.get_athena_data_client(config.offline_store.region) s3_resource = aws_utils.get_s3_resource(config.offline_store.region) start_date = start_date.astimezone(tz=utc) @@ -142,15 +138,13 @@ def pull_all_from_table_or_query( end_date: datetime, ) -> RetrievalJob: assert isinstance(data_source, AthenaSource) - from_expression = data_source.get_table_query_string() + from_expression = data_source.get_table_query_string(config) field_string = ", ".join( join_key_columns + feature_name_columns + [timestamp_field] ) - athena_client = aws_utils.get_athena_data_client( - config.offline_store.region - ) + athena_client = aws_utils.get_athena_data_client(config.offline_store.region) s3_resource = aws_utils.get_s3_resource(config.offline_store.region) date_partition_column = data_source.date_partition_column @@ -186,9 +180,7 @@ def 
get_historical_features( ) -> RetrievalJob: assert isinstance(config.offline_store, AthenaOfflineStoreConfig) - athena_client = aws_utils.get_athena_data_client( - config.offline_store.region - ) + athena_client = aws_utils.get_athena_data_client(config.offline_store.region) s3_resource = aws_utils.get_s3_resource(config.offline_store.region) # get pandas dataframe consisting of 1 row (LIMIT 1) and generate the schema out of it @@ -197,13 +189,16 @@ def get_historical_features( ) # find timestamp column of entity df.(default = "event_timestamp"). Exception occurs if there are more than two timestamp columns. - entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( - entity_schema + entity_df_event_timestamp_col = ( + offline_utils.infer_event_timestamp_from_entity_df(entity_schema) ) # get min,max of event_timestamp. entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( - entity_df, entity_df_event_timestamp_col, athena_client, config, + entity_df, + entity_df_event_timestamp_col, + athena_client, + config, ) @contextlib.contextmanager @@ -211,9 +206,7 @@ def query_generator() -> Iterator[str]: table_name = offline_utils.get_temp_entity_table_name() - _upload_entity_df( - entity_df, athena_client, config, s3_resource, table_name - ) + _upload_entity_df(entity_df, athena_client, config, s3_resource, table_name) expected_join_keys = offline_utils.get_expected_join_keys( project, feature_views, registry @@ -232,7 +225,6 @@ def query_generator() -> Iterator[str]: entity_df_event_timestamp_range, ) - # Generate the Athena SQL query from the query context query = offline_utils.build_point_in_time_query( query_context, @@ -247,7 +239,7 @@ def query_generator() -> Iterator[str]: yield query finally: - #Always clean up the temp Athena table + # Always clean up the temp Athena table aws_utils.execute_athena_query( athena_client, config.offline_store.data_source, @@ -255,9 +247,12 @@ def query_generator() -> Iterator[str]: f"DROP 
TABLE IF EXISTS {config.offline_store.database}.{table_name}", ) - bucket = config.offline_store.s3_staging_location.replace("s3://", "").split("/", 1)[0] - aws_utils.delete_s3_directory(s3_resource,bucket, "entity_df/"+table_name+"/") - + bucket = config.offline_store.s3_staging_location.replace( + "s3://", "" + ).split("/", 1)[0] + aws_utils.delete_s3_directory( + s3_resource, bucket, "entity_df/" + table_name + "/" + ) return AthenaRetrievalJob( query=query_generator, @@ -276,21 +271,18 @@ def query_generator() -> Iterator[str]: ), ) - @staticmethod def write_logged_features( config: RepoConfig, data: Union[pyarrow.Table, Path], source: LoggingSource, logging_config: LoggingConfig, - registry: Registry, + registry: BaseRegistry, ): destination = logging_config.destination assert isinstance(destination, AthenaLoggingDestination) - athena_client = aws_utils.get_athena_data_client( - config.offline_store.region - ) + athena_client = aws_utils.get_athena_data_client(config.offline_store.region) s3_resource = aws_utils.get_s3_resource(config.offline_store.region) if isinstance(data, Path): s3_path = f"{config.offline_store.s3_staging_location}/logged_features/{uuid.uuid4()}" @@ -299,7 +291,7 @@ def write_logged_features( aws_utils.upload_arrow_table_to_athena( table=data, - athena_data_client=athena_client, + athena_client=athena_client, data_source=config.offline_store.data_source, database=config.offline_store.database, s3_resource=s3_resource, @@ -332,7 +324,6 @@ def __init__( on_demand_feature_views (optional): A list of on demand transforms to apply at retrieval time """ - if not isinstance(query, str): self._query_generator = query else: @@ -352,7 +343,6 @@ def query_generator() -> Iterator[str]: ) self._metadata = metadata - @property def full_feature_names(self) -> bool: return self._full_feature_names @@ -362,9 +352,15 @@ def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]: return self._on_demand_feature_views def get_temp_s3_path(self) 
-> str: - return self._config.offline_store.s3_staging_location + "/unload/" + str(uuid.uuid4()) + return ( + self._config.offline_store.s3_staging_location + + "/unload/" + + str(uuid.uuid4()) + ) - def get_temp_table_dml_header(self, temp_table_name:str, temp_external_location:str) -> str: + def get_temp_table_dml_header( + self, temp_table_name: str, temp_external_location: str + ) -> str: temp_table_dml_header = f""" CREATE TABLE {temp_table_name} WITH ( @@ -387,7 +383,8 @@ def _to_df_internal(self) -> pd.DataFrame: self._config.offline_store.database, self._s3_resource, temp_external_location, - self.get_temp_table_dml_header(temp_table_name, temp_external_location) + query, + self.get_temp_table_dml_header(temp_table_name, temp_external_location) + + query, temp_table_name, ) @@ -402,7 +399,8 @@ def _to_arrow_internal(self) -> pa.Table: self._config.offline_store.database, self._s3_resource, temp_external_location, - self.get_temp_table_dml_header(temp_table_name, temp_external_location) + query, + self.get_temp_table_dml_header(temp_table_name, temp_external_location) + + query, temp_table_name, ) @@ -412,7 +410,33 @@ def metadata(self) -> Optional[RetrievalMetadata]: def persist(self, storage: SavedDatasetStorage): assert isinstance(storage, SavedDatasetAthenaStorage) - # self.to_athena(table_name=storage.athena_options.table) + self.to_athena(table_name=storage.athena_options.table) + + @log_exceptions_and_usage + def to_athena(self, table_name: str) -> None: + + if self.on_demand_feature_views: + transformed_df = self.to_df() + + _upload_entity_df( + transformed_df, + self._athena_client, + self._config, + self._s3_resource, + table_name, + ) + + return + + with self._query_generator() as query: + query = f'CREATE TABLE "{table_name}" AS ({query});\n' + + aws_utils.execute_athena_query( + self._athena_client, + self._config.offline_store.data_source, + self._config.offline_store.database, + query, + ) def _upload_entity_df( @@ -496,12 +520,14 @@ def 
_get_entity_df_event_timestamp_range( f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max " f"FROM ({entity_df})", ) - res = aws_utils.get_athena_query_result(athena_client, statement_id)[ - "Records" - ][0] + res = aws_utils.get_athena_query_result(athena_client, statement_id) entity_df_event_timestamp_range = ( - res.parse(res[0]["stringValue"]), - res.parse(res[1]["stringValue"]), + datetime.strptime( + res["Rows"][1]["Data"][0]["VarCharValue"], "%Y-%m-%d %H:%M:%S.%f" + ), + datetime.strptime( + res["Rows"][1]["Data"][1]["VarCharValue"], "%Y-%m-%d %H:%M:%S.%f" + ), ) else: raise InvalidEntityType(type(entity_df)) diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py index facd8ed80c0..1a34de73599 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py @@ -1,24 +1,23 @@ import warnings from typing import Callable, Dict, Iterable, Optional, Tuple -#from feast import type_map +# from feast import type_map from feast import type_map - from feast.data_source import DataSource from feast.errors import DataSourceNotFoundException, RedshiftCredentialsError from feast.feature_logging import LoggingDestination + +# from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto +from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.protos.feast.core.FeatureService_pb2 import ( LoggingConfig as LoggingConfigProto, ) -#from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto -from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto - -''' +""" from feast.protos.feast.core.SavedDataset_pb2 import ( SavedDatasetStorage as SavedDatasetStorageProto, ) -''' +""" from 
feast.protos.feast.core.SavedDataset_pb2 import ( SavedDatasetStorage as SavedDatasetStorageProto, ) @@ -69,7 +68,7 @@ def __init__( # The default Athena schema is named "public". _database = "default" if table and not database else database self.athena_options = AthenaOptions( - table=table, query=query, database=_database, data_source=data_source + table=table, query=query, database=_database, data_source=data_source ) if table is None and query is None: @@ -92,7 +91,7 @@ def __init__( timestamp_field=timestamp_field, created_timestamp_column=created_timestamp_column, field_mapping=field_mapping, - date_partition_column = date_partition_column, + date_partition_column=date_partition_column, description=description, tags=tags, owner=owner, @@ -187,10 +186,15 @@ def validate(self, config: RepoConfig): # the data source is validated. We don't need the results though. self.get_table_column_names_and_types(config) - def get_table_query_string(self) -> str: + def get_table_query_string(self, config: Optional[RepoConfig] = None) -> str: """Returns a string that can directly be used to reference this table in SQL.""" if self.table: - return f'"{self.data_source}"."{self.database}"."{self.table}"' + data_source = self.data_source + database = self.database + if config: + data_source = config.offline_store.data_source + database = config.offline_store.database + return f'"{data_source}"."{database}"."{self.table}"' else: return f"({self.query})" @@ -209,7 +213,9 @@ def get_table_column_names_and_types( """ from botocore.exceptions import ClientError - from feast.infra.offline_stores.contrib.athena_offline_store.athena import AthenaOfflineStoreConfig + from feast.infra.offline_stores.contrib.athena_offline_store.athena import ( + AthenaOfflineStoreConfig, + ) from feast.infra.utils import aws_utils assert isinstance(config.offline_store, AthenaOfflineStoreConfig) @@ -237,7 +243,9 @@ def get_table_column_names_and_types( config.offline_store.database, f"SELECT * FROM 
({self.query}) LIMIT 1", ) - columns = aws_utils.get_athena_query_result(client, statement_id)["ResultSetMetadata"]["ColumnInfo"] + columns = aws_utils.get_athena_query_result(client, statement_id)[ + "ResultSetMetadata" + ]["ColumnInfo"] return [(column["Name"], column["Type"].upper()) for column in columns] @@ -301,9 +309,15 @@ class SavedDatasetAthenaStorage(SavedDatasetStorage): athena_options: AthenaOptions - def __init__(self, table_ref: str): + def __init__( + self, + table_ref: str, + query: str = None, + database: str = None, + data_source: str = None, + ): self.athena_options = AthenaOptions( - table=table_ref, query=None, database=None, data_source=None + table=table_ref, query=query, database=database, data_source=data_source ) @staticmethod @@ -314,9 +328,7 @@ def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage: ) def to_proto(self) -> SavedDatasetStorageProto: - return SavedDatasetStorageProto( - athena_storage=self.athena_options.to_proto() - ) + return SavedDatasetStorageProto(athena_storage=self.athena_options.to_proto()) def to_data_source(self) -> DataSource: return AthenaSource(table=self.athena_options.table) diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/__init__.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py new file mode 100644 index 00000000000..b582e2f92ed --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py @@ -0,0 +1,104 @@ +import uuid +from typing import Any, Dict, List, Optional + +import pandas as pd + +from feast import AthenaSource +from feast.data_source import DataSource +from feast.feature_logging import LoggingDestination +from 
feast.infra.offline_stores.contrib.athena_offline_store.athena import ( + AthenaOfflineStoreConfig, +) +from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import ( + AthenaLoggingDestination, + SavedDatasetAthenaStorage, +) +from feast.infra.utils import aws_utils +from feast.repo_config import FeastConfigBaseModel +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from tests.integration.feature_repos.universal.data_source_creator import ( + DataSourceCreator, +) + + +class AthenaDataSourceCreator(DataSourceCreator): + + tables: List[str] = [] + + def __init__(self, project_name: str, *args, **kwargs): + super().__init__(project_name) + self.client = aws_utils.get_athena_data_client("ap-northeast-2") + self.s3 = aws_utils.get_s3_resource("ap-northeast-2") + + self.offline_store_config = AthenaOfflineStoreConfig( + data_source="AwsDataCatalog", + region="ap-northeast-2", + database="sampledb", + s3_staging_location="s3://sagemaker-yelo-test/test_dir", + ) + + def create_data_source( + self, + df: pd.DataFrame, + destination_name: str, + suffix: Optional[str] = None, + timestamp_field="ts", + created_timestamp_column="created_ts", + field_mapping: Dict[str, str] = None, + ) -> DataSource: + + destination_name = self.get_prefixed_table_name(destination_name) + + aws_utils.upload_df_to_athena( + self.client, + self.offline_store_config.data_source, + self.offline_store_config.database, + self.s3, + self.offline_store_config.s3_staging_location, + destination_name, + df, + ) + + self.tables.append(destination_name) + + return AthenaSource( + table=destination_name, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + field_mapping=field_mapping or {"ts_1": "ts"}, + database=self.offline_store_config.database, + data_source=self.offline_store_config.data_source, + ) + + def create_saved_dataset_destination(self) -> SavedDatasetAthenaStorage: + table = 
self.get_prefixed_table_name( + f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" + ) + self.tables.append(table) + + return SavedDatasetAthenaStorage(table_ref=table) + + def create_logged_features_destination(self) -> LoggingDestination: + table = self.get_prefixed_table_name( + f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" + ) + self.tables.append(table) + + return AthenaLoggingDestination(table_name=table) + + def create_offline_store_config(self) -> FeastConfigBaseModel: + return self.offline_store_config + + def get_prefixed_table_name(self, suffix: str) -> str: + return f"{self.project_name}.{suffix}" + + def teardown(self): + for table in self.tables: + aws_utils.execute_athena_query( + self.client, + self.offline_store_config.data_source, + self.offline_store_config.database, + f"DROP TABLE IF EXISTS {table}", + ) diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_repo_configuration.py b/sdk/python/feast/infra/offline_stores/contrib/athena_repo_configuration.py new file mode 100644 index 00000000000..cd74e00aafe --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_repo_configuration.py @@ -0,0 +1,18 @@ +# from feast.infra.offline_stores.contrib.athena_offline_store.tests.data_source import AthenaDataSourceCreator + +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from tests.integration.feature_repos.universal.data_sources.athena import ( + AthenaDataSourceCreator, +) + +FULL_REPO_CONFIGS = [ + IntegrationTestRepoConfig(), + IntegrationTestRepoConfig( + provider="aws", + offline_store_creator=AthenaDataSourceCreator, + ), +] + +AVAILABLE_OFFLINE_STORES = [("aws", AthenaDataSourceCreator)] diff --git a/sdk/python/feast/infra/offline_stores/offline_utils.py b/sdk/python/feast/infra/offline_stores/offline_utils.py index a1dc117b35b..829d46c5ca6 100644 --- a/sdk/python/feast/infra/offline_stores/offline_utils.py +++ 
b/sdk/python/feast/infra/offline_stores/offline_utils.py @@ -93,7 +93,9 @@ class FeatureViewQueryContext: entity_selections: List[str] min_event_timestamp: Optional[str] max_event_timestamp: str - date_partition_column: Optional[str] # this attribute is added because partition pruning affects Athena's query performance. + date_partition_column: Optional[ + str + ] # this attribute is added because partition pruning affects Athena's query performance. def get_feature_view_query_context( @@ -168,7 +170,7 @@ def get_feature_view_query_context( entity_selections=entity_selections, min_event_timestamp=min_event_timestamp, max_event_timestamp=max_event_timestamp, - date_partition_column=date_partition_column + date_partition_column=date_partition_column, ) query_context.append(context) diff --git a/sdk/python/feast/infra/utils/aws_utils.py b/sdk/python/feast/infra/utils/aws_utils.py index a3da377025c..c212bd09268 100644 --- a/sdk/python/feast/infra/utils/aws_utils.py +++ b/sdk/python/feast/infra/utils/aws_utils.py @@ -21,7 +21,7 @@ RedshiftQueryError, RedshiftTableNameTooLong, ) -from feast.type_map import pa_to_redshift_value_type,pa_to_athena_value_type +from feast.type_map import pa_to_athena_value_type, pa_to_redshift_value_type try: import boto3 @@ -675,6 +675,7 @@ def list_s3_files(aws_region: str, path: str) -> List[str]: # Athena + def get_athena_data_client(aws_region: str): """ Get the athena Data API Service client for the given AWS region. @@ -689,7 +690,7 @@ def get_athena_data_client(aws_region: str): reraise=True, ) def execute_athena_query_async( - athena_data_client, data_source: str, database: str, query: str + athena_data_client, data_source: str, database: str, query: str ) -> dict: """Execute Athena statement asynchronously. Does not wait for the query to finish. 
@@ -708,14 +709,12 @@ def execute_athena_query_async( # return athena_data_client.execute_statement( return athena_data_client.start_query_execution( QueryString=query, - QueryExecutionContext={ - 'Database': database - }, - WorkGroup='primary' + QueryExecutionContext={"Database": database}, + WorkGroup="primary", ) except ClientError as e: - raise AthenaQueryError() + raise AthenaQueryError(e) class AthenaStatementNotFinishedError(Exception): @@ -740,22 +739,28 @@ def wait_for_athena_execution(athena_data_client, execution: dict) -> None: Returns: None """ - response = athena_data_client.get_query_execution(QueryExecutionId=execution["QueryExecutionId"]) + response = athena_data_client.get_query_execution( + QueryExecutionId=execution["QueryExecutionId"] + ) if response["QueryExecution"]["Status"]["State"] in ("QUEUED", "RUNNING"): raise AthenaStatementNotFinishedError # Retry if response["QueryExecution"]["Status"]["State"] != "SUCCEEDED": raise AthenaQueryError(response) # Don't retry. Raise exception. -def drop_temp_table(athena_data_client, data_source: str, database: str, temp_table: str): - query = f'DROP TABLE `{database}.{temp_table}`' - execute_athena_query_async( - athena_data_client, data_source, database, query - ) +def drop_temp_table( + athena_data_client, data_source: str, database: str, temp_table: str +): + query = f"DROP TABLE `{database}.{temp_table}`" + execute_athena_query_async(athena_data_client, data_source, database, query) def execute_athena_query( - athena_data_client, data_source: str, database: str, query: str, temp_table: str = None + athena_data_client, + data_source: str, + database: str, + query: str, + temp_table: str = None, ) -> str: """Execute athena statement synchronously. Waits for the query to finish. 
@@ -808,13 +813,13 @@ def __init__(self, table_name: str): def unload_athena_query_to_pa( - athena_data_client, - data_source: str, - database: str, - s3_resource, - s3_path: str, - query: str, - temp_table: str, + athena_data_client, + data_source: str, + database: str, + s3_resource, + s3_path: str, + query: str, + temp_table: str, ) -> pa.Table: """Unload Athena Query results to S3 and get the results in PyArrow Table format""" bucket, key = get_bucket_and_key(s3_path) @@ -830,13 +835,13 @@ def unload_athena_query_to_pa( def unload_athena_query_to_df( - athena_data_client, - data_source: str, - database: str, - s3_resource, - s3_path: str, - query: str, - temp_table: str, + athena_data_client, + data_source: str, + database: str, + s3_resource, + s3_path: str, + query: str, + temp_table: str, ) -> pd.DataFrame: """Unload Athena Query results to S3 and get the results in Pandas DataFrame format""" table = unload_athena_query_to_pa( @@ -846,17 +851,17 @@ def unload_athena_query_to_df( s3_resource, s3_path, query, - temp_table + temp_table, ) return table.to_pandas() def execute_athena_query_and_unload_to_s3( - athena_data_client, - data_source: str, - database: str, - query: str, - temp_table: str, + athena_data_client, + data_source: str, + database: str, + query: str, + temp_table: str, ) -> None: """Unload Athena Query results to S3 @@ -873,13 +878,13 @@ def execute_athena_query_and_unload_to_s3( def upload_df_to_athena( - athena_client, - data_source: str, - database: str, - s3_resource, - s3_path: str, - table_name: str, - df: pd.DataFrame, + athena_client, + data_source: str, + database: str, + s3_resource, + s3_path: str, + table_name: str, + df: pd.DataFrame, ): """Uploads a Pandas DataFrame to S3(Athena) as a new table. 
@@ -920,15 +925,15 @@ def upload_df_to_athena( def upload_arrow_table_to_athena( - table: Union[pyarrow.Table, Path], - athena_client, - data_source: str, - database: str, - s3_resource, - s3_path: str, - table_name: str, - schema: Optional[pyarrow.Schema] = None, - fail_if_exists: bool = True, + table: Union[pyarrow.Table, Path], + athena_client, + data_source: str, + database: str, + s3_resource, + s3_path: str, + table_name: str, + schema: Optional[pyarrow.Schema] = None, + fail_if_exists: bool = True, ): """Uploads an Arrow Table to S3(Athena). @@ -978,7 +983,7 @@ def upload_arrow_table_to_athena( f"CREATE EXTERNAL TABLE {database}.{table_name} " f"({column_query_list}) " f"STORED AS PARQUET " - f"LOCATION '{s3_path[:s3_path.rfind('/')]}' " + f"LOCATION '{s3_path[:s3_path.rfind('/')]}' " f"TBLPROPERTIES('parquet.compress' = 'SNAPPY') " ) diff --git a/sdk/python/feast/templates/athena/example.py b/sdk/python/feast/templates/athena/example.py index d5fec1b8ea2..7e8c2eb6f05 100644 --- a/sdk/python/feast/templates/athena/example.py +++ b/sdk/python/feast/templates/athena/example.py @@ -26,7 +26,9 @@ def test_end_to_end(): ) driver = Entity( - name="driver_id", value_type=ValueType.INT64, description="driver id", + name="driver_id", + value_type=ValueType.INT64, + description="driver id", ) driver_hourly_stats_view = FeatureView( @@ -104,4 +106,4 @@ def test_cli(): if __name__ == "__main__": # pass test_end_to_end() - test_cli() \ No newline at end of file + test_cli() diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py index bb0fb360d42..3f4c047229a 100644 --- a/sdk/python/feast/type_map.py +++ b/sdk/python/feast/type_map.py @@ -793,7 +793,6 @@ def pg_type_code_to_arrow(code: int) -> str: ) - def athena_to_feast_value_type(athena_type_as_str: str) -> ValueType: # Type names from https://docs.aws.amazon.com/athena/latest/ug/data-types.html type_map = { diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py 
b/sdk/python/tests/integration/feature_repos/repo_configuration.py index c2cf286fdc4..9767ec24f02 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -39,6 +39,10 @@ from tests.integration.feature_repos.universal.data_sources.snowflake import ( SnowflakeDataSourceCreator, ) +from tests.integration.feature_repos.universal.data_sources.athena import ( + AthenaDataSourceCreator, +) + from tests.integration.feature_repos.universal.feature_views import ( conv_rate_plus_100_feature_view, create_conv_rate_request_source, @@ -89,6 +93,7 @@ "bigquery": ("gcp", BigQueryDataSourceCreator), "redshift": ("aws", RedshiftDataSourceCreator), "snowflake": ("aws", SnowflakeDataSourceCreator), + "athena": ("aws", AthenaDataSourceCreator), } AVAILABLE_OFFLINE_STORES: List[Tuple[str, Type[DataSourceCreator]]] = [ @@ -108,6 +113,7 @@ ("gcp", BigQueryDataSourceCreator), ("aws", RedshiftDataSourceCreator), ("aws", SnowflakeDataSourceCreator), + ("aws", AthenaDataSourceCreator), ] ) diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py index d34af4f373d..20add87dfdf 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py @@ -27,7 +27,7 @@ class AthenaDataSourceCreator(DataSourceCreator): tables: List[str] = [] - def __init__(self, project_name: str): + def __init__(self, project_name: str, *args, **kwargs): super().__init__(project_name) self.client = aws_utils.get_athena_data_client("ap-northeast-2") self.s3 = aws_utils.get_s3_resource("ap-northeast-2") @@ -36,7 +36,7 @@ def __init__(self, project_name: str): data_source="AwsDataCatalog", region="ap-northeast-2", database="sampledb", - s3_staging_location="s3://sagemaker-yelo-test", + 
s3_staging_location="s3://sagemaker-yelo-test/test_dir", ) def create_data_source( @@ -49,22 +49,35 @@ def create_data_source( field_mapping: Dict[str, str] = None, ) -> DataSource: - destination_name = self.get_prefixed_table_name(destination_name) + # destination_name = self.get_prefixed_table_name(destination_name) + # test_name, table_name = destination_name.split('.') + + table_name = destination_name + s3_target = ( + self.offline_store_config.s3_staging_location + + "/" + + self.project_name + + "/" + + table_name + + "/" + + table_name + + ".parquet" + ) aws_utils.upload_df_to_athena( self.client, self.offline_store_config.data_source, self.offline_store_config.database, self.s3, - self.offline_store_config.s3_staging_location, - destination_name, + s3_target, + table_name, df, ) - self.tables.append(destination_name) + self.tables.append(table_name) return AthenaSource( - table=destination_name, + table=table_name, timestamp_field=timestamp_field, created_timestamp_column=created_timestamp_column, field_mapping=field_mapping or {"ts_1": "ts"}, @@ -78,7 +91,11 @@ def create_saved_dataset_destination(self) -> SavedDatasetAthenaStorage: ) self.tables.append(table) - return SavedDatasetAthenaStorage(table_ref=table) + return SavedDatasetAthenaStorage( + table_ref=table, + database=self.offline_store_config.database, + data_source=self.offline_store_config.data_source, + ) def create_logged_features_destination(self) -> LoggingDestination: table = self.get_prefixed_table_name( @@ -92,7 +109,7 @@ def create_offline_store_config(self) -> FeastConfigBaseModel: return self.offline_store_config def get_prefixed_table_name(self, suffix: str) -> str: - return f"{self.project_name}.{suffix}" + return f"{self.project_name}_{suffix}" def teardown(self): for table in self.tables: @@ -107,6 +124,9 @@ def teardown(self): FULL_REPO_CONFIGS = [ IntegrationTestRepoConfig(), IntegrationTestRepoConfig( - provider="aws", offline_store_creator=AthenaDataSourceCreator, + 
provider="aws", + offline_store_creator=AthenaDataSourceCreator, ), -] \ No newline at end of file +] + +AVAILABLE_OFFLINE_STORES = [("aws", AthenaDataSourceCreator)] From 901103633b69e1d6b48b9c52e1ad76ff08702ad6 Mon Sep 17 00:00:00 2001 From: toping4445 Date: Mon, 8 Aug 2022 02:47:34 +0900 Subject: [PATCH 03/11] fixed bugs to pass test_validation Signed-off-by: Youngkyu OH --- .../offline_stores/contrib/athena_offline_store/athena.py | 4 ++-- sdk/python/feast/infra/utils/aws_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py index 79d4461aa7d..3145d43970c 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py @@ -149,8 +149,8 @@ def pull_all_from_table_or_query( date_partition_column = data_source.date_partition_column - start_date = start_date.astimezone(tz=utc) - end_date = end_date.astimezone(tz=utc) + start_date = start_date.astimezone(tz=utc).strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] + end_date = end_date.astimezone(tz=utc).strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] query = f""" SELECT {field_string} diff --git a/sdk/python/feast/infra/utils/aws_utils.py b/sdk/python/feast/infra/utils/aws_utils.py index c212bd09268..72c40e4fc2e 100644 --- a/sdk/python/feast/infra/utils/aws_utils.py +++ b/sdk/python/feast/infra/utils/aws_utils.py @@ -971,7 +971,7 @@ def upload_arrow_table_to_athena( bucket, key = get_bucket_and_key(s3_path) column_query_list = ", ".join( - [f"{field.name} {pa_to_athena_value_type(field.type)}" for field in schema] + [f"`{field.name}` {pa_to_athena_value_type(field.type)}" for field in schema] ) with tempfile.TemporaryFile(suffix=".parquet") as parquet_temp_file: From 23905a164710df2d4c9c42fa3bf0149c5ce36e5c Mon Sep 17 00:00:00 2001 From: toping4445 Date: 
Mon, 8 Aug 2022 09:36:35 +0900 Subject: [PATCH 04/11] changed boolean data type mapping Signed-off-by: Youngkyu OH --- sdk/python/feast/type_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py index 3f4c047229a..a9dc4e25da1 100644 --- a/sdk/python/feast/type_map.py +++ b/sdk/python/feast/type_map.py @@ -832,7 +832,7 @@ def pa_to_athena_value_type(pa_type: pyarrow.DataType) -> str: # Other mappings have also been adjusted accordingly. type_map = { "null": "null", - "bool": "bool", + "bool": "boolean", "int8": "tinyint", "int16": "smallint", "int32": "int", From 5fc9e6eaa61563273111121aeda868baf1f414ff Mon Sep 17 00:00:00 2001 From: toping4445 Date: Mon, 8 Aug 2022 10:27:53 +0900 Subject: [PATCH 05/11] 1.added test-python-universal-athena in Makefile 2.replaced database,bucket_name hardcoding to variable in AthenaDataSourceCreator Signed-off-by: Youngkyu OH --- Makefile | 23 ++++++++++- .../athena_offline_store/tests/data_source.py | 39 +++++++++++++------ .../universal/data_sources/athena.py | 14 +++---- 3 files changed, 57 insertions(+), 19 deletions(-) diff --git a/Makefile b/Makefile index 67be3ba2486..06e19ab7c68 100644 --- a/Makefile +++ b/Makefile @@ -139,6 +139,27 @@ test-python-universal-trino: not test_universal_types" \ sdk/python/tests +test-python-universal-athena: + PYTHONPATH='.' 
\ + FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.athena_repo_configuration \ + PYTEST_PLUGINS=feast.infra.offline_stores.contrib.athena_offline_store.tests \ + FEAST_USAGE=False IS_TEST=True \ + S3_DATABASE=sampledb \ + S3_BUCKET_NAME=sagemaker-yelo-test \ + python -m pytest -n 1 --integration \ + -k "not test_go_feature_server and \ + not test_logged_features_validation and \ + not test_lambda and \ + not test_feature_logging and \ + not test_offline_write and \ + not test_push_offline and \ + not test_historical_retrieval_with_validation and \ + not test_historical_features_persisting and \ + not test_historical_retrieval_fails_on_validation" \ + sdk/python/tests + + + test-python-universal-postgres: PYTHONPATH='.' \ FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.postgres_repo_configuration \ @@ -229,7 +250,7 @@ install-go-ci-dependencies: python -m pip install pybindgen==0.22.0 protobuf==3.20.1 install-protoc-dependencies: - pip install grpcio-tools==1.47.0 mypy-protobuf==3.1.0 + pip install grpcio-tools==1.48.0 mypy-protobuf==3.1.0 compile-protos-go: install-go-proto-dependencies install-protoc-dependencies python setup.py build_go_protos diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py index b582e2f92ed..ea070c4d426 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py @@ -1,4 +1,5 @@ import uuid +import os from typing import Any, Dict, List, Optional import pandas as pd @@ -31,12 +32,14 @@ def __init__(self, project_name: str, *args, **kwargs): super().__init__(project_name) self.client = aws_utils.get_athena_data_client("ap-northeast-2") self.s3 = aws_utils.get_s3_resource("ap-northeast-2") - + data_source = 
os.environ.get("S3_DATA_SOURCE") if os.environ.get("S3_DATA_SOURCE") else "AwsDataCatalog" + database = os.environ.get("S3_DATABASE") if os.environ.get("S3_DATABASE") else "default" + bucket_name = os.environ.get("S3_BUCKET_NAME") if os.environ.get("S3_BUCKET_NAME") else "feast-integration-tests" self.offline_store_config = AthenaOfflineStoreConfig( - data_source="AwsDataCatalog", + data_source=f"{data_source}", region="ap-northeast-2", - database="sampledb", - s3_staging_location="s3://sagemaker-yelo-test/test_dir", + database=f"{database}", + s3_staging_location=f"s3://{bucket_name}/test_dir", ) def create_data_source( @@ -49,22 +52,32 @@ def create_data_source( field_mapping: Dict[str, str] = None, ) -> DataSource: - destination_name = self.get_prefixed_table_name(destination_name) + table_name = destination_name + s3_target = ( + self.offline_store_config.s3_staging_location + + "/" + + self.project_name + + "/" + + table_name + + "/" + + table_name + + ".parquet" + ) aws_utils.upload_df_to_athena( self.client, self.offline_store_config.data_source, self.offline_store_config.database, self.s3, - self.offline_store_config.s3_staging_location, - destination_name, + s3_target, + table_name, df, ) - self.tables.append(destination_name) + self.tables.append(table_name) return AthenaSource( - table=destination_name, + table=table_name, timestamp_field=timestamp_field, created_timestamp_column=created_timestamp_column, field_mapping=field_mapping or {"ts_1": "ts"}, @@ -78,7 +91,11 @@ def create_saved_dataset_destination(self) -> SavedDatasetAthenaStorage: ) self.tables.append(table) - return SavedDatasetAthenaStorage(table_ref=table) + return SavedDatasetAthenaStorage( + table_ref=table, + database=self.offline_store_config.database, + data_source=self.offline_store_config.data_source, + ) def create_logged_features_destination(self) -> LoggingDestination: table = self.get_prefixed_table_name( @@ -92,7 +109,7 @@ def create_offline_store_config(self) -> 
FeastConfigBaseModel: return self.offline_store_config def get_prefixed_table_name(self, suffix: str) -> str: - return f"{self.project_name}.{suffix}" + return f"{self.project_name}_{suffix}" def teardown(self): for table in self.tables: diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py index 20add87dfdf..4cc078fdd78 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py @@ -1,4 +1,5 @@ import uuid +import os from typing import Any, Dict, List, Optional import pandas as pd @@ -31,12 +32,14 @@ def __init__(self, project_name: str, *args, **kwargs): super().__init__(project_name) self.client = aws_utils.get_athena_data_client("ap-northeast-2") self.s3 = aws_utils.get_s3_resource("ap-northeast-2") - + data_source = os.environ.get("S3_DATA_SOURCE") if os.environ.get("S3_DATA_SOURCE") else "AwsDataCatalog" + database = os.environ.get("S3_DATABASE") if os.environ.get("S3_DATABASE") else "sampledb" + bucket_name = os.environ.get("S3_BUCKET_NAME") if os.environ.get("S3_BUCKET_NAME") else "feast-integration-tests" self.offline_store_config = AthenaOfflineStoreConfig( - data_source="AwsDataCatalog", + data_source=f"{data_source}", region="ap-northeast-2", - database="sampledb", - s3_staging_location="s3://sagemaker-yelo-test/test_dir", + database=f"{database}", + s3_staging_location=f"s3://{bucket_name}/test_dir", ) def create_data_source( @@ -49,9 +52,6 @@ def create_data_source( field_mapping: Dict[str, str] = None, ) -> DataSource: - # destination_name = self.get_prefixed_table_name(destination_name) - # test_name, table_name = destination_name.split('.') - table_name = destination_name s3_target = ( self.offline_store_config.s3_staging_location From e8db74874be794d8a34f9f0926a592d009748c4c Mon Sep 17 00:00:00 2001 From: toping4445 Date: Tue, 
9 Aug 2022 17:28:22 +0900 Subject: [PATCH 06/11] format,run lint Signed-off-by: Youngkyu OH --- Makefile | 2 +- sdk/python/feast/data_source.py | 365 +++++++++++++++--- .../contrib/athena_offline_store/athena.py | 10 +- .../athena_offline_store/athena_source.py | 11 +- .../athena_offline_store/tests/data_source.py | 20 +- .../integration/e2e/test_go_feature_server.py | 263 ------------- .../feature_repos/repo_configuration.py | 7 +- .../universal/data_sources/athena.py | 20 +- 8 files changed, 342 insertions(+), 356 deletions(-) delete mode 100644 sdk/python/tests/integration/e2e/test_go_feature_server.py diff --git a/Makefile b/Makefile index 06e19ab7c68..7b0746fec98 100644 --- a/Makefile +++ b/Makefile @@ -250,7 +250,7 @@ install-go-ci-dependencies: python -m pip install pybindgen==0.22.0 protobuf==3.20.1 install-protoc-dependencies: - pip install grpcio-tools==1.48.0 mypy-protobuf==3.1.0 + pip install grpcio-tools==1.47.0 mypy-protobuf==3.1.0 compile-protos-go: install-go-proto-dependencies install-protoc-dependencies python setup.py build_go_protos diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py index 89136d2eeed..f53e33c7e74 100644 --- a/sdk/python/feast/data_source.py +++ b/sdk/python/feast/data_source.py @@ -16,7 +16,7 @@ import warnings from abc import ABC, abstractmethod from datetime import timedelta -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from google.protobuf.duration_pb2 import Duration from google.protobuf.json_format import MessageToJson @@ -31,6 +31,19 @@ from feast.value_type import ValueType +class SourceType(enum.Enum): + """ + DataSource value type. Used to define source types in DataSource. 
+ """ + + UNKNOWN = 0 + BATCH_FILE = 1 + BATCH_BIGQUERY = 2 + STREAM_KAFKA = 3 + STREAM_KINESIS = 4 + BATCH_TRINO = 5 + + class KafkaOptions: """ DataSource Kafka options used to source features from Kafka messages @@ -171,13 +184,14 @@ class DataSource(ABC): Args: name: Name of data source, which should be unique within a project - timestamp_field (optional): Event timestamp field used for point-in-time joins of - feature values. + event_timestamp_column (optional): (Deprecated in favor of timestamp_field) Event + timestamp column used for point in time joins of feature values. created_timestamp_column (optional): Timestamp column indicating when the row was created, used for deduplicating rows. field_mapping (optional): A dictionary mapping of column names in this data source to feature names in a feature table or view. Only used for feature columns, not entity or timestamp columns. + date_partition_column (optional): Timestamp column used for partitioning. description (optional) A human-readable description. tags (optional): A dictionary of key-value pairs to store arbitrary metadata. owner (optional): The owner of the data source, typically the email of the primary @@ -190,6 +204,7 @@ class DataSource(ABC): timestamp_field: str created_timestamp_column: str field_mapping: Dict[str, str] + date_partition_column: str description: str tags: Dict[str, str] owner: str @@ -197,37 +212,68 @@ class DataSource(ABC): def __init__( self, *, - name: str, - timestamp_field: Optional[str] = None, + event_timestamp_column: Optional[str] = None, created_timestamp_column: Optional[str] = None, field_mapping: Optional[Dict[str, str]] = None, + date_partition_column: Optional[str] = None, description: Optional[str] = "", tags: Optional[Dict[str, str]] = None, owner: Optional[str] = "", + name: Optional[str] = None, + timestamp_field: Optional[str] = None, ): """ Creates a DataSource object. - Args: - name: Name of data source, which should be unique within a project. 
- timestamp_field (optional): Event timestamp field used for point-in-time joins of - feature values. + name: Name of data source, which should be unique within a project + event_timestamp_column (optional): (Deprecated in favor of timestamp_field) Event + timestamp column used for point in time joins of feature values. created_timestamp_column (optional): Timestamp column indicating when the row was created, used for deduplicating rows. field_mapping (optional): A dictionary mapping of column names in this data source to feature names in a feature table or view. Only used for feature columns, not entity or timestamp columns. + date_partition_column (optional): Timestamp column used for partitioning. description (optional): A human-readable description. tags (optional): A dictionary of key-value pairs to store arbitrary metadata. owner (optional): The owner of the data source, typically the email of the primary maintainer. + timestamp_field (optional): Event timestamp field used for point + in time joins of feature values. """ - self.name = name - self.timestamp_field = timestamp_field or "" + if not name: + warnings.warn( + ( + "Names for data sources need to be supplied. " + "Data sources without names will not be supported after Feast 0.24." + ), + UserWarning, + ) + self.name = name or "" + if not timestamp_field and event_timestamp_column: + warnings.warn( + ( + "The argument 'event_timestamp_column' is being deprecated. Please use 'timestamp_field' instead. " + "Feast 0.24 and onwards will not support the argument 'event_timestamp_column' for datasources."
+ ), + DeprecationWarning, + ) + self.timestamp_field = timestamp_field or event_timestamp_column or "" self.created_timestamp_column = ( created_timestamp_column if created_timestamp_column else "" ) self.field_mapping = field_mapping if field_mapping else {} + self.date_partition_column = ( + date_partition_column if date_partition_column else "" + ) + if date_partition_column: + warnings.warn( + ( + "The argument 'date_partition_column' is being deprecated. " + "Feast 0.25 and onwards will not support 'date_partition_column' for data sources." + ), + DeprecationWarning, + ) if ( self.timestamp_field and self.timestamp_field == self.created_timestamp_column @@ -257,6 +303,7 @@ def __eq__(self, other): or self.timestamp_field != other.timestamp_field or self.created_timestamp_column != other.created_timestamp_column or self.field_mapping != other.field_mapping + or self.date_partition_column != other.date_partition_column or self.description != other.description or self.tags != other.tags or self.owner != other.owner @@ -341,18 +388,20 @@ def get_table_query_string(self) -> str: class KafkaSource(DataSource): def __init__( self, - *, - name: str, - timestamp_field: str, - message_format: StreamFormat, + *args, + name: Optional[str] = None, + event_timestamp_column: Optional[str] = "", bootstrap_servers: Optional[str] = None, kafka_bootstrap_servers: Optional[str] = None, + message_format: Optional[StreamFormat] = None, topic: Optional[str] = None, created_timestamp_column: Optional[str] = "", field_mapping: Optional[Dict[str, str]] = None, + date_partition_column: Optional[str] = "", description: Optional[str] = "", tags: Optional[Dict[str, str]] = None, owner: Optional[str] = "", + timestamp_field: Optional[str] = "", batch_source: Optional[DataSource] = None, watermark_delay_threshold: Optional[timedelta] = None, ): @@ -361,24 +410,41 @@ def __init__( Args: name: Name of data source, which should be unique within a project - timestamp_field: Event timestamp
field used for point-in-time joins of feature values. - message_format: StreamFormat of serialized messages. + event_timestamp_column (optional): (Deprecated in favor of timestamp_field) Event + timestamp column used for point in time joins of feature values. bootstrap_servers: (Deprecated) The servers of the kafka broker in the form "localhost:9092". - kafka_bootstrap_servers (optional): The servers of the kafka broker in the form "localhost:9092". - topic (optional): The name of the topic to read from in the kafka source. + kafka_bootstrap_servers: The servers of the kafka broker in the form "localhost:9092". + message_format: StreamFormat of serialized messages. + topic: The name of the topic to read from in the kafka source. created_timestamp_column (optional): Timestamp column indicating when the row was created, used for deduplicating rows. field_mapping (optional): A dictionary mapping of column names in this data source to feature names in a feature table or view. Only used for feature columns, not entity or timestamp columns. + date_partition_column (optional): Timestamp column used for partitioning. description (optional): A human-readable description. tags (optional): A dictionary of key-value pairs to store arbitrary metadata. owner (optional): The owner of the data source, typically the email of the primary maintainer. - batch_source (optional): The datasource that acts as a batch source. - watermark_delay_threshold (optional): The watermark delay threshold for stream data. - Specifically how late stream data can arrive without being discarded. + timestamp_field (optional): Event timestamp field used for point + in time joins of feature values. + batch_source: The datasource that acts as a batch source. + watermark_delay_threshold: The watermark delay threshold for stream data. Specifically how + late stream data can arrive without being discarded. 
""" + positional_attributes = [ + "name", + "event_timestamp_column", + "bootstrap_servers", + "message_format", + "topic", + ] + _name = name + _event_timestamp_column = event_timestamp_column + _kafka_bootstrap_servers = kafka_bootstrap_servers or bootstrap_servers or "" + _message_format = message_format + _topic = topic or "" + if bootstrap_servers: warnings.warn( ( @@ -388,24 +454,53 @@ def __init__( DeprecationWarning, ) + if args: + warnings.warn( + ( + "Kafka parameters should be specified as a keyword argument instead of a positional arg." + "Feast 0.24+ will not support positional arguments to construct Kafka sources" + ), + DeprecationWarning, + ) + if len(args) > len(positional_attributes): + raise ValueError( + f"Only {', '.join(positional_attributes)} are allowed as positional args when defining " + f"Kafka sources, for backwards compatibility." + ) + if len(args) >= 1: + _name = args[0] + if len(args) >= 2: + _event_timestamp_column = args[1] + if len(args) >= 3: + _kafka_bootstrap_servers = args[2] + if len(args) >= 4: + _message_format = args[3] + if len(args) >= 5: + _topic = args[4] + + if _message_format is None: + raise ValueError("Message format must be specified for Kafka source") + + if not timestamp_field and not _event_timestamp_column: + raise ValueError("Timestamp field must be specified for Kafka source") + super().__init__( - name=name, - timestamp_field=timestamp_field, + event_timestamp_column=_event_timestamp_column, created_timestamp_column=created_timestamp_column, field_mapping=field_mapping, + date_partition_column=date_partition_column, description=description, tags=tags, owner=owner, + name=_name, + timestamp_field=timestamp_field, ) self.batch_source = batch_source - kafka_bootstrap_servers = kafka_bootstrap_servers or bootstrap_servers or "" - topic = topic or "" - self.kafka_options = KafkaOptions( - kafka_bootstrap_servers=kafka_bootstrap_servers, - message_format=message_format, - topic=topic, + 
kafka_bootstrap_servers=_kafka_bootstrap_servers, + message_format=_message_format, + topic=_topic, watermark_delay_threshold=watermark_delay_threshold, ) @@ -445,6 +540,7 @@ def from_proto(data_source: DataSourceProto): ) return KafkaSource( name=data_source.name, + event_timestamp_column=data_source.timestamp_field, field_mapping=dict(data_source.field_mapping), kafka_bootstrap_servers=data_source.kafka_options.kafka_bootstrap_servers, message_format=StreamFormat.from_proto( @@ -454,6 +550,7 @@ def from_proto(data_source: DataSourceProto): topic=data_source.kafka_options.topic, created_timestamp_column=data_source.created_timestamp_column, timestamp_field=data_source.timestamp_field, + date_partition_column=data_source.date_partition_column, description=data_source.description, tags=dict(data_source.tags), owner=data_source.owner, @@ -475,6 +572,7 @@ def to_proto(self) -> DataSourceProto: data_source_proto.timestamp_field = self.timestamp_field data_source_proto.created_timestamp_column = self.created_timestamp_column + data_source_proto.date_partition_column = self.date_partition_column if self.batch_source: data_source_proto.batch_source.MergeFrom(self.batch_source.to_proto()) return data_source_proto @@ -517,16 +615,55 @@ class RequestSource(DataSource): def __init__( self, - *, - name: str, - schema: List[Field], + *args, + name: Optional[str] = None, + schema: Optional[Union[Dict[str, ValueType], List[Field]]] = None, description: Optional[str] = "", tags: Optional[Dict[str, str]] = None, owner: Optional[str] = "", ): """Creates a RequestSource object.""" - super().__init__(name=name, description=description, tags=tags, owner=owner) - self.schema = schema + positional_attributes = ["name", "schema"] + _name = name + _schema = schema + if args: + warnings.warn( + ( + "Request source parameters should be specified as a keyword argument instead of a positional arg." 
+ "Feast 0.24+ will not support positional arguments to construct request sources" + ), + DeprecationWarning, + ) + if len(args) > len(positional_attributes): + raise ValueError( + f"Only {', '.join(positional_attributes)} are allowed as positional args when defining " + f"feature views, for backwards compatibility." + ) + if len(args) >= 1: + _name = args[0] + if len(args) >= 2: + _schema = args[1] + + super().__init__(name=_name, description=description, tags=tags, owner=owner) + if not _schema: + raise ValueError("Schema needs to be provided for Request Source") + if isinstance(_schema, Dict): + warnings.warn( + "Schema in RequestSource is changing type. The schema data type Dict[str, ValueType] is being deprecated in Feast 0.24. " + "Please use List[Field] instead for the schema", + DeprecationWarning, + ) + schema_list = [] + for key, value_type in _schema.items(): + schema_list.append(Field(name=key, dtype=from_value_type(value_type))) + self.schema = schema_list + elif isinstance(_schema, List): + self.schema = _schema + else: + raise Exception( + "Schema type must be either dictionary or list, not " + + str(type(_schema)) + ) def validate(self, config: RepoConfig): pass @@ -558,18 +695,38 @@ def __hash__(self): @staticmethod def from_proto(data_source: DataSourceProto): + + deprecated_schema = data_source.request_data_options.deprecated_schema schema_pb = data_source.request_data_options.schema - list_schema = [] - for field_proto in schema_pb: - list_schema.append(Field.from_proto(field_proto)) - return RequestSource( - name=data_source.name, - schema=list_schema, - description=data_source.description, - tags=dict(data_source.tags), - owner=data_source.owner, - ) + if deprecated_schema and not schema_pb: + warnings.warn( + "Schema in RequestSource is changing type. The schema data type Dict[str, ValueType] is being deprecated in Feast 0.24. 
" + "Please use List[Field] instead for the schema", + DeprecationWarning, + ) + dict_schema = {} + for key, val in deprecated_schema.items(): + dict_schema[key] = ValueType(val) + return RequestSource( + name=data_source.name, + schema=dict_schema, + description=data_source.description, + tags=dict(data_source.tags), + owner=data_source.owner, + ) + else: + list_schema = [] + for field_proto in schema_pb: + list_schema.append(Field.from_proto(field_proto)) + + return RequestSource( + name=data_source.name, + schema=list_schema, + description=data_source.description, + tags=dict(data_source.tags), + owner=data_source.owner, + ) def to_proto(self) -> DataSourceProto: @@ -602,6 +759,16 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: raise NotImplementedError +@typechecked +class RequestDataSource(RequestSource): + def __init__(self, *args, **kwargs): + warnings.warn( + "The 'RequestDataSource' class is deprecated and was renamed to RequestSource. Please use RequestSource instead. 
This class name will be removed in Feast 0.24.", + DeprecationWarning, + ) + super().__init__(*args, **kwargs) + + @typechecked class KinesisSource(DataSource): def validate(self, config: RepoConfig): @@ -616,7 +783,7 @@ def get_table_column_names_and_types( def from_proto(data_source: DataSourceProto): return KinesisSource( name=data_source.name, - timestamp_field=data_source.timestamp_field, + event_timestamp_column=data_source.timestamp_field, field_mapping=dict(data_source.field_mapping), record_format=StreamFormat.from_proto( data_source.kinesis_options.record_format @@ -624,6 +791,8 @@ def from_proto(data_source: DataSourceProto): region=data_source.kinesis_options.region, stream_name=data_source.kinesis_options.stream_name, created_timestamp_column=data_source.created_timestamp_column, + timestamp_field=data_source.timestamp_field, + date_partition_column=data_source.date_partition_column, description=data_source.description, tags=dict(data_source.tags), owner=data_source.owner, @@ -641,34 +810,78 @@ def get_table_query_string(self) -> str: def __init__( self, - *, - name: str, - record_format: StreamFormat, - region: str, - stream_name: str, - timestamp_field: Optional[str] = "", + *args, + name: Optional[str] = None, + event_timestamp_column: Optional[str] = "", created_timestamp_column: Optional[str] = "", + record_format: Optional[StreamFormat] = None, + region: Optional[str] = "", + stream_name: Optional[str] = "", field_mapping: Optional[Dict[str, str]] = None, + date_partition_column: Optional[str] = "", description: Optional[str] = "", tags: Optional[Dict[str, str]] = None, owner: Optional[str] = "", + timestamp_field: Optional[str] = "", batch_source: Optional[DataSource] = None, ): - if record_format is None: + positional_attributes = [ + "name", + "event_timestamp_column", + "created_timestamp_column", + "record_format", + "region", + "stream_name", + ] + _name = name + _event_timestamp_column = event_timestamp_column + _created_timestamp_column = 
created_timestamp_column + _record_format = record_format + _region = region or "" + _stream_name = stream_name or "" + if args: + warnings.warn( + ( + "Kinesis parameters should be specified as a keyword argument instead of a positional arg." + "Feast 0.24+ will not support positional arguments to construct kinesis sources" + ), + DeprecationWarning, + ) + if len(args) > len(positional_attributes): + raise ValueError( + f"Only {', '.join(positional_attributes)} are allowed as positional args when defining " + f"kinesis sources, for backwards compatibility." + ) + if len(args) >= 1: + _name = args[0] + if len(args) >= 2: + _event_timestamp_column = args[1] + if len(args) >= 3: + _created_timestamp_column = args[2] + if len(args) >= 4: + _record_format = args[3] + if len(args) >= 5: + _region = args[4] + if len(args) >= 6: + _stream_name = args[5] + + if _record_format is None: raise ValueError("Record format must be specified for kinesis source") super().__init__( - name=name, - timestamp_field=timestamp_field, - created_timestamp_column=created_timestamp_column, + name=_name, + event_timestamp_column=_event_timestamp_column, + created_timestamp_column=_created_timestamp_column, field_mapping=field_mapping, + date_partition_column=date_partition_column, description=description, tags=tags, owner=owner, + timestamp_field=timestamp_field, ) self.batch_source = batch_source self.kinesis_options = KinesisOptions( - record_format=record_format, region=region, stream_name=stream_name + record_format=_record_format, region=_region, stream_name=_stream_name ) def __eq__(self, other): @@ -705,6 +918,7 @@ def to_proto(self) -> DataSourceProto: data_source_proto.timestamp_field = self.timestamp_field data_source_proto.created_timestamp_column = self.created_timestamp_column + data_source_proto.date_partition_column = self.date_partition_column if self.batch_source: data_source_proto.batch_source.MergeFrom(self.batch_source.to_proto()) @@ -729,16 +943,15 @@ class 
PushSource(DataSource): def __init__( self, - *, - name: str, - batch_source: DataSource, + *args, + name: Optional[str] = None, + batch_source: Optional[DataSource] = None, description: Optional[str] = "", tags: Optional[Dict[str, str]] = None, owner: Optional[str] = "", ): """ Creates a PushSource object. - Args: name: Name of the push source batch_source: The batch source that backs this push source. It's used when materializing from the offline @@ -747,9 +960,35 @@ def __init__( tags (optional): A dictionary of key-value pairs to store arbitrary metadata. owner (optional): The owner of the data source, typically the email of the primary maintainer. + """ - super().__init__(name=name, description=description, tags=tags, owner=owner) - self.batch_source = batch_source + positional_attributes = ["name", "batch_source"] + _name = name + _batch_source = batch_source + if args: + warnings.warn( + ( + "Push source parameters should be specified as a keyword argument instead of a positional arg." + "Feast 0.24+ will not support positional arguments to construct push sources" + ), + DeprecationWarning, + ) + if len(args) > len(positional_attributes): + raise ValueError( + f"Only {', '.join(positional_attributes)} are allowed as positional args when defining " + f"push sources, for backwards compatibility." 
+ ) + if len(args) >= 1: + _name = args[0] + if len(args) >= 2: + _batch_source = args[1] + + super().__init__(name=_name, description=description, tags=tags, owner=owner) + if not _batch_source: + raise ValueError( + f"batch_source parameter is needed for push source {self.name}" + ) + self.batch_source = _batch_source def __eq__(self, other): if not isinstance(other, PushSource): diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py index 3145d43970c..5c33efb9a22 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py @@ -17,7 +17,6 @@ import pandas as pd import pyarrow import pyarrow as pa -from dateutil import parser from pydantic import StrictStr from pydantic.typing import Literal from pytz import utc @@ -25,7 +24,7 @@ from feast import OnDemandFeatureView from feast.data_source import DataSource from feast.errors import InvalidEntityType -from feast.feature_logging import LoggingConfig, LoggingDestination, LoggingSource +from feast.feature_logging import LoggingConfig, LoggingSource from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView from feast.infra.offline_stores import offline_utils from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import ( @@ -39,7 +38,7 @@ RetrievalMetadata, ) from feast.infra.utils import aws_utils -from feast.registry import Registry, BaseRegistry +from feast.registry import BaseRegistry, Registry from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage from feast.usage import log_exceptions_and_usage @@ -149,13 +148,10 @@ def pull_all_from_table_or_query( date_partition_column = data_source.date_partition_column - start_date = start_date.astimezone(tz=utc).strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] - 
end_date = end_date.astimezone(tz=utc).strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] - query = f""" SELECT {field_string} FROM {from_expression} - WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}' + WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date.astimezone(tz=utc).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]}' AND TIMESTAMP '{end_date.astimezone(tz=utc).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]}' {"AND "+date_partition_column+" >= '"+start_date.strftime('%Y-%m-%d')+"' AND "+date_partition_column+" <= '"+end_date.strftime('%Y-%m-%d')+"' " if date_partition_column != "" and date_partition_column is not None else ''} """ diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py index 1a34de73599..542ee5606b8 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py @@ -1,23 +1,14 @@ import warnings from typing import Callable, Dict, Iterable, Optional, Tuple -# from feast import type_map from feast import type_map from feast.data_source import DataSource -from feast.errors import DataSourceNotFoundException, RedshiftCredentialsError +from feast.errors import DataSourceNotFoundException from feast.feature_logging import LoggingDestination - -# from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.protos.feast.core.FeatureService_pb2 import ( LoggingConfig as LoggingConfigProto, ) - -""" -from feast.protos.feast.core.SavedDataset_pb2 import ( - SavedDatasetStorage as SavedDatasetStorageProto, -) -""" from feast.protos.feast.core.SavedDataset_pb2 import ( SavedDatasetStorage as SavedDatasetStorageProto, ) diff --git 
a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py index ea070c4d426..75a148a8aa0 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py @@ -1,5 +1,5 @@ -import uuid import os +import uuid from typing import Any, Dict, List, Optional import pandas as pd @@ -32,9 +32,21 @@ def __init__(self, project_name: str, *args, **kwargs): super().__init__(project_name) self.client = aws_utils.get_athena_data_client("ap-northeast-2") self.s3 = aws_utils.get_s3_resource("ap-northeast-2") - data_source = os.environ.get("S3_DATA_SOURCE") if os.environ.get("S3_DATA_SOURCE") else "AwsDataCatalog" - database = os.environ.get("S3_DATABASE") if os.environ.get("S3_DATABASE") else "default" - bucket_name = os.environ.get("S3_BUCKET_NAME") if os.environ.get("S3_BUCKET_NAME") else "feast-integration-tests" + data_source = ( + os.environ.get("S3_DATA_SOURCE") + if os.environ.get("S3_DATA_SOURCE") + else "AwsDataCatalog" + ) + database = ( + os.environ.get("S3_DATABASE") + if os.environ.get("S3_DATABASE") + else "default" + ) + bucket_name = ( + os.environ.get("S3_BUCKET_NAME") + if os.environ.get("S3_BUCKET_NAME") + else "feast-integration-tests" + ) self.offline_store_config = AthenaOfflineStoreConfig( data_source=f"{data_source}", region="ap-northeast-2", diff --git a/sdk/python/tests/integration/e2e/test_go_feature_server.py b/sdk/python/tests/integration/e2e/test_go_feature_server.py deleted file mode 100644 index 0f972e45df5..00000000000 --- a/sdk/python/tests/integration/e2e/test_go_feature_server.py +++ /dev/null @@ -1,263 +0,0 @@ -import threading -import time -from datetime import datetime -from typing import List - -import grpc -import pandas as pd -import pytest -import pytz -import requests - -from 
feast.embedded_go.online_features_service import EmbeddedOnlineFeatureServer -from feast.feast_object import FeastObject -from feast.feature_logging import LoggingConfig -from feast.feature_service import FeatureService -from feast.infra.feature_servers.base_config import FeatureLoggingConfig -from feast.protos.feast.serving.ServingService_pb2 import ( - FieldStatus, - GetOnlineFeaturesRequest, - GetOnlineFeaturesResponse, -) -from feast.protos.feast.serving.ServingService_pb2_grpc import ServingServiceStub -from feast.protos.feast.types.Value_pb2 import RepeatedValue -from feast.type_map import python_values_to_proto_values -from feast.value_type import ValueType -from feast.wait import wait_retry_backoff -from tests.integration.feature_repos.repo_configuration import ( - construct_universal_feature_views, -) -from tests.integration.feature_repos.universal.entities import ( - customer, - driver, - location, -) -from tests.utils.http_server import check_port_open, free_port -from tests.utils.test_log_creator import generate_expected_logs, get_latest_rows - - -@pytest.mark.integration -@pytest.mark.goserver -def test_go_grpc_server(grpc_client): - resp: GetOnlineFeaturesResponse = grpc_client.GetOnlineFeatures( - GetOnlineFeaturesRequest( - feature_service="driver_features", - entities={ - "driver_id": RepeatedValue( - val=python_values_to_proto_values( - [5001, 5002], feature_type=ValueType.INT64 - ) - ) - }, - full_feature_names=True, - ) - ) - assert list(resp.metadata.feature_names.val) == [ - "driver_id", - "driver_stats__conv_rate", - "driver_stats__acc_rate", - "driver_stats__avg_daily_trips", - ] - for vector in resp.results: - assert all([s == FieldStatus.PRESENT for s in vector.statuses]) - - -@pytest.mark.integration -@pytest.mark.goserver -def test_go_http_server(http_server_port): - response = requests.post( - f"http://localhost:{http_server_port}/get-online-features", - json={ - "feature_service": "driver_features", - "entities": {"driver_id": [5001, 
5002]}, - "full_feature_names": True, - }, - ) - assert response.status_code == 200, response.text - response = response.json() - assert set(response.keys()) == {"metadata", "results"} - metadata = response["metadata"] - results = response["results"] - assert response["metadata"] == { - "feature_names": [ - "driver_id", - "driver_stats__conv_rate", - "driver_stats__acc_rate", - "driver_stats__avg_daily_trips", - ] - }, metadata - assert len(results) == 4, results - assert all( - set(result.keys()) == {"event_timestamps", "statuses", "values"} - for result in results - ), results - assert all( - result["statuses"] == ["PRESENT", "PRESENT"] for result in results - ), results - assert results[0]["values"] == [5001, 5002], results - for result in results[1:]: - assert len(result["values"]) == 2, result - assert all(value is not None for value in result["values"]), result - - -@pytest.mark.integration -@pytest.mark.goserver -@pytest.mark.universal_offline_stores -@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) -def test_feature_logging( - grpc_client, environment, universal_data_sources, full_feature_names -): - fs = environment.feature_store - feature_service = fs.get_feature_service("driver_features") - log_start_date = datetime.now().astimezone(pytz.UTC) - driver_ids = list(range(5001, 5011)) - - for driver_id in driver_ids: - # send each driver id in separate request - grpc_client.GetOnlineFeatures( - GetOnlineFeaturesRequest( - feature_service="driver_features", - entities={ - "driver_id": RepeatedValue( - val=python_values_to_proto_values( - [driver_id], feature_type=ValueType.INT64 - ) - ) - }, - full_feature_names=full_feature_names, - ) - ) - # with some pause - time.sleep(0.1) - - _, datasets, _ = universal_data_sources - latest_rows = get_latest_rows(datasets.driver_df, "driver_id", driver_ids) - feature_view = fs.get_feature_view("driver_stats") - features = [ - feature.name - for proj in 
feature_service.feature_view_projections - for feature in proj.features - ] - expected_logs = generate_expected_logs( - latest_rows, feature_view, features, ["driver_id"], "event_timestamp" - ) - - def retrieve(): - retrieval_job = fs._get_provider().retrieve_feature_service_logs( - feature_service=feature_service, - start_date=log_start_date, - end_date=datetime.now().astimezone(pytz.UTC), - config=fs.config, - registry=fs._registry, - ) - try: - df = retrieval_job.to_df() - except Exception: - # Table or directory was not created yet - return None, False - - return df, df.shape[0] == len(driver_ids) - - persisted_logs = wait_retry_backoff( - retrieve, timeout_secs=60, timeout_msg="Logs retrieval failed" - ) - - persisted_logs = persisted_logs.sort_values(by="driver_id").reset_index(drop=True) - persisted_logs = persisted_logs[expected_logs.columns] - pd.testing.assert_frame_equal(expected_logs, persisted_logs, check_dtype=False) - - -""" -Start go feature server either on http or grpc based on the repo configuration for testing. 
-""" - - -def _server_port(environment, server_type: str): - if not environment.test_repo_config.go_feature_serving: - pytest.skip("Only for Go path") - - fs = environment.feature_store - - embedded = EmbeddedOnlineFeatureServer( - repo_path=str(fs.repo_path.absolute()), - repo_config=fs.config, - feature_store=fs, - ) - port = free_port() - if server_type == "grpc": - target = embedded.start_grpc_server - elif server_type == "http": - target = embedded.start_http_server - else: - raise ValueError("Server Type must be either 'http' or 'grpc'") - - t = threading.Thread( - target=target, - args=("127.0.0.1", port), - kwargs=dict( - enable_logging=True, - logging_options=FeatureLoggingConfig( - enabled=True, - queue_capacity=100, - write_to_disk_interval_secs=1, - flush_interval_secs=1, - emit_timeout_micro_secs=10000, - ), - ), - ) - t.start() - - wait_retry_backoff( - lambda: (None, check_port_open("127.0.0.1", port)), timeout_secs=15 - ) - - yield port - if server_type == "grpc": - embedded.stop_grpc_server() - else: - embedded.stop_http_server() - - # wait for graceful stop - time.sleep(5) - - -# Go test fixtures - - -@pytest.fixture -def initialized_registry(environment, universal_data_sources): - fs = environment.feature_store - - _, _, data_sources = universal_data_sources - feature_views = construct_universal_feature_views(data_sources) - - feature_service = FeatureService( - name="driver_features", - features=[feature_views.driver], - logging_config=LoggingConfig( - destination=environment.data_source_creator.create_logged_features_destination(), - sample_rate=1.0, - ), - ) - feast_objects: List[FeastObject] = [feature_service] - feast_objects.extend(feature_views.values()) - feast_objects.extend([driver(), customer(), location()]) - - fs.apply(feast_objects) - fs.materialize(environment.start_date, environment.end_date) - - -@pytest.fixture -def grpc_server_port(environment, initialized_registry): - yield from _server_port(environment, "grpc") - - 
-@pytest.fixture -def http_server_port(environment, initialized_registry): - yield from _server_port(environment, "http") - - -@pytest.fixture -def grpc_client(grpc_server_port): - ch = grpc.insecure_channel(f"localhost:{grpc_server_port}") - yield ServingServiceStub(ch) diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index 9767ec24f02..ce6b5fa873e 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -27,6 +27,9 @@ from tests.integration.feature_repos.universal.data_source_creator import ( DataSourceCreator, ) +from tests.integration.feature_repos.universal.data_sources.athena import ( + AthenaDataSourceCreator, +) from tests.integration.feature_repos.universal.data_sources.bigquery import ( BigQueryDataSourceCreator, ) @@ -39,10 +42,6 @@ from tests.integration.feature_repos.universal.data_sources.snowflake import ( SnowflakeDataSourceCreator, ) -from tests.integration.feature_repos.universal.data_sources.athena import ( - AthenaDataSourceCreator, -) - from tests.integration.feature_repos.universal.feature_views import ( conv_rate_plus_100_feature_view, create_conv_rate_request_source, diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py index 4cc078fdd78..3369fc4290b 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py @@ -1,5 +1,5 @@ -import uuid import os +import uuid from typing import Any, Dict, List, Optional import pandas as pd @@ -32,9 +32,21 @@ def __init__(self, project_name: str, *args, **kwargs): super().__init__(project_name) self.client = aws_utils.get_athena_data_client("ap-northeast-2") self.s3 = 
aws_utils.get_s3_resource("ap-northeast-2") - data_source = os.environ.get("S3_DATA_SOURCE") if os.environ.get("S3_DATA_SOURCE") else "AwsDataCatalog" - database = os.environ.get("S3_DATABASE") if os.environ.get("S3_DATABASE") else "sampledb" - bucket_name = os.environ.get("S3_BUCKET_NAME") if os.environ.get("S3_BUCKET_NAME") else "feast-integration-tests" + data_source = ( + os.environ.get("S3_DATA_SOURCE") + if os.environ.get("S3_DATA_SOURCE") + else "AwsDataCatalog" + ) + database = ( + os.environ.get("S3_DATABASE") + if os.environ.get("S3_DATABASE") + else "sampledb" + ) + bucket_name = ( + os.environ.get("S3_BUCKET_NAME") + if os.environ.get("S3_BUCKET_NAME") + else "feast-integration-tests" + ) self.offline_store_config = AthenaOfflineStoreConfig( data_source=f"{data_source}", region="ap-northeast-2", From 228555c40ba9b93c294b077bd7693a5b6dda5499 Mon Sep 17 00:00:00 2001 From: Danny Chiao Date: Tue, 9 Aug 2022 10:51:07 -0400 Subject: [PATCH 07/11] revert merge changes Signed-off-by: Danny Chiao --- Makefile | 6 +- sdk/python/feast/data_source.py | 373 ++++-------------- .../contrib/athena_offline_store/athena.py | 9 +- .../athena_offline_store/tests/data_source.py | 5 +- .../contrib/athena_repo_configuration.py | 3 - sdk/python/feast/templates/athena/example.py | 1 - .../feature_repos/repo_configuration.py | 5 - .../universal/data_sources/athena.py | 144 ------- 8 files changed, 80 insertions(+), 466 deletions(-) delete mode 100644 sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py diff --git a/Makefile b/Makefile index 7b0746fec98..9ed25e68f38 100644 --- a/Makefile +++ b/Makefile @@ -146,7 +146,7 @@ test-python-universal-athena: FEAST_USAGE=False IS_TEST=True \ S3_DATABASE=sampledb \ S3_BUCKET_NAME=sagemaker-yelo-test \ - python -m pytest -n 1 --integration \ + python -m pytest -n 8 --integration \ -k "not test_go_feature_server and \ not test_logged_features_validation and \ not test_lambda and \ @@ -155,7 +155,9 @@ 
test-python-universal-athena: not test_push_offline and \ not test_historical_retrieval_with_validation and \ not test_historical_features_persisting and \ - not test_historical_retrieval_fails_on_validation" \ + not test_historical_retrieval_fails_on_validation and \ + not gcs_registry and \ + not s3_registry" \ sdk/python/tests diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py index f53e33c7e74..76b012e5856 100644 --- a/sdk/python/feast/data_source.py +++ b/sdk/python/feast/data_source.py @@ -16,7 +16,7 @@ import warnings from abc import ABC, abstractmethod from datetime import timedelta -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from google.protobuf.duration_pb2 import Duration from google.protobuf.json_format import MessageToJson @@ -31,19 +31,6 @@ from feast.value_type import ValueType -class SourceType(enum.Enum): - """ - DataSource value type. Used to define source types in DataSource. 
- """ - - UNKNOWN = 0 - BATCH_FILE = 1 - BATCH_BIGQUERY = 2 - STREAM_KAFKA = 3 - STREAM_KINESIS = 4 - BATCH_TRINO = 5 - - class KafkaOptions: """ DataSource Kafka options used to source features from Kafka messages @@ -169,11 +156,11 @@ def to_proto(self) -> DataSourceProto.KinesisOptions: DataSourceProto.SourceType.BATCH_SNOWFLAKE: "feast.infra.offline_stores.snowflake_source.SnowflakeSource", DataSourceProto.SourceType.BATCH_TRINO: "feast.infra.offline_stores.contrib.trino_offline_store.trino_source.TrinoSource", DataSourceProto.SourceType.BATCH_SPARK: "feast.infra.offline_stores.contrib.spark_offline_store.spark_source.SparkSource", + DataSourceProto.SourceType.BATCH_ATHENA: "feast.infra.offline_stores.contrib.athena_offline_store.athena_source.AthenaSource", DataSourceProto.SourceType.STREAM_KAFKA: "feast.data_source.KafkaSource", DataSourceProto.SourceType.STREAM_KINESIS: "feast.data_source.KinesisSource", DataSourceProto.SourceType.REQUEST_SOURCE: "feast.data_source.RequestSource", DataSourceProto.SourceType.PUSH_SOURCE: "feast.data_source.PushSource", - DataSourceProto.SourceType.BATCH_ATHENA: "feast.infra.offline_stores.contrib.athena_offline_store.athena_source.AthenaSource", } @@ -184,96 +171,67 @@ class DataSource(ABC): Args: name: Name of data source, which should be unique within a project - event_timestamp_column (optional): (Deprecated in favor of timestamp_field) Event - timestamp column used for point in time joins of feature values. + timestamp_field (optional): Event timestamp field used for point-in-time joins of + feature values. created_timestamp_column (optional): Timestamp column indicating when the row was created, used for deduplicating rows. field_mapping (optional): A dictionary mapping of column names in this data source to feature names in a feature table or view. Only used for feature columns, not entity or timestamp columns. - date_partition_column (optional): Timestamp column used for partitioning. 
description (optional) A human-readable description. tags (optional): A dictionary of key-value pairs to store arbitrary metadata. owner (optional): The owner of the data source, typically the email of the primary maintainer. timestamp_field (optional): Event timestamp field used for point in time joins of feature values. + date_partition_column (optional): Timestamp column used for partitioning. Not supported by all offline stores. """ name: str timestamp_field: str created_timestamp_column: str field_mapping: Dict[str, str] - date_partition_column: str description: str tags: Dict[str, str] owner: str + date_partition_column: str def __init__( self, *, - event_timestamp_column: Optional[str] = None, + name: str, + timestamp_field: Optional[str] = None, created_timestamp_column: Optional[str] = None, field_mapping: Optional[Dict[str, str]] = None, - date_partition_column: Optional[str] = None, description: Optional[str] = "", tags: Optional[Dict[str, str]] = None, owner: Optional[str] = "", - name: Optional[str] = None, - timestamp_field: Optional[str] = None, + date_partition_column: Optional[str] = None, ): """ Creates a DataSource object. + Args: - name: Name of data source, which should be unique within a project - event_timestamp_column (optional): (Deprecated in favor of timestamp_field) Event - timestamp column used for point in time joins of feature values. + name: Name of data source, which should be unique within a project. + timestamp_field (optional): Event timestamp field used for point-in-time joins of + feature values. created_timestamp_column (optional): Timestamp column indicating when the row was created, used for deduplicating rows. field_mapping (optional): A dictionary mapping of column names in this data source to feature names in a feature table or view. Only used for feature columns, not entity or timestamp columns. - date_partition_column (optional): Timestamp column used for partitioning. 
description (optional): A human-readable description. tags (optional): A dictionary of key-value pairs to store arbitrary metadata. owner (optional): The owner of the data source, typically the email of the primary maintainer. - timestamp_field (optional): Event timestamp field used for point - in time joins of feature values. + date_partition_column (optional): Timestamp column used for partitioning. Not supported by all stores """ - if not name: - warnings.warn( - ( - "Names for data sources need to be supplied. " - "Data sources without names will not be supported after Feast 0.24." - ), - UserWarning, - ) - self.name = name or "" - if not timestamp_field and event_timestamp_column: - warnings.warn( - ( - "The argument 'event_timestamp_column' is being deprecated. Please use 'timestamp_field' instead. " - "instead. Feast 0.24 and onwards will not support the argument 'event_timestamp_column' for datasources." - ), - DeprecationWarning, - ) - self.timestamp_field = timestamp_field or event_timestamp_column or "" + self.name = name + self.timestamp_field = timestamp_field or "" self.created_timestamp_column = ( created_timestamp_column if created_timestamp_column else "" ) self.field_mapping = field_mapping if field_mapping else {} - self.date_partition_column = ( - date_partition_column if date_partition_column else "" - ) - if date_partition_column: - warnings.warn( - ( - "The argument 'date_partition_column' is being deprecated. " - "Feast 0.25 and onwards will not support 'date_timestamp_column' for data sources." 
- ), - DeprecationWarning, - ) if ( self.timestamp_field and self.timestamp_field == self.created_timestamp_column @@ -284,6 +242,9 @@ def __init__( self.description = description or "" self.tags = tags or {} self.owner = owner or "" + self.date_partition_column = ( + date_partition_column if date_partition_column else "" + ) def __hash__(self): return hash((self.name, self.timestamp_field)) @@ -388,20 +349,18 @@ def get_table_query_string(self) -> str: class KafkaSource(DataSource): def __init__( self, - *args, - name: Optional[str] = None, - event_timestamp_column: Optional[str] = "", + *, + name: str, + timestamp_field: str, + message_format: StreamFormat, bootstrap_servers: Optional[str] = None, kafka_bootstrap_servers: Optional[str] = None, - message_format: Optional[StreamFormat] = None, topic: Optional[str] = None, created_timestamp_column: Optional[str] = "", field_mapping: Optional[Dict[str, str]] = None, - date_partition_column: Optional[str] = "", description: Optional[str] = "", tags: Optional[Dict[str, str]] = None, owner: Optional[str] = "", - timestamp_field: Optional[str] = "", batch_source: Optional[DataSource] = None, watermark_delay_threshold: Optional[timedelta] = None, ): @@ -410,41 +369,24 @@ def __init__( Args: name: Name of data source, which should be unique within a project - event_timestamp_column (optional): (Deprecated in favor of timestamp_field) Event - timestamp column used for point in time joins of feature values. - bootstrap_servers: (Deprecated) The servers of the kafka broker in the form "localhost:9092". - kafka_bootstrap_servers: The servers of the kafka broker in the form "localhost:9092". + timestamp_field: Event timestamp field used for point-in-time joins of feature values. message_format: StreamFormat of serialized messages. - topic: The name of the topic to read from in the kafka source. + bootstrap_servers: (Deprecated) The servers of the kafka broker in the form "localhost:9092". 
+ kafka_bootstrap_servers (optional): The servers of the kafka broker in the form "localhost:9092". + topic (optional): The name of the topic to read from in the kafka source. created_timestamp_column (optional): Timestamp column indicating when the row was created, used for deduplicating rows. field_mapping (optional): A dictionary mapping of column names in this data source to feature names in a feature table or view. Only used for feature columns, not entity or timestamp columns. - date_partition_column (optional): Timestamp column used for partitioning. description (optional): A human-readable description. tags (optional): A dictionary of key-value pairs to store arbitrary metadata. owner (optional): The owner of the data source, typically the email of the primary maintainer. - timestamp_field (optional): Event timestamp field used for point - in time joins of feature values. - batch_source: The datasource that acts as a batch source. - watermark_delay_threshold: The watermark delay threshold for stream data. Specifically how - late stream data can arrive without being discarded. + batch_source (optional): The datasource that acts as a batch source. + watermark_delay_threshold (optional): The watermark delay threshold for stream data. + Specifically how late stream data can arrive without being discarded. """ - positional_attributes = [ - "name", - "event_timestamp_column", - "bootstrap_servers", - "message_format", - "topic", - ] - _name = name - _event_timestamp_column = event_timestamp_column - _kafka_bootstrap_servers = kafka_bootstrap_servers or bootstrap_servers or "" - _message_format = message_format - _topic = topic or "" - if bootstrap_servers: warnings.warn( ( @@ -454,53 +396,24 @@ def __init__( DeprecationWarning, ) - if args: - warnings.warn( - ( - "Kafka parameters should be specified as a keyword argument instead of a positional arg." 
- "Feast 0.24+ will not support positional arguments to construct Kafka sources" - ), - DeprecationWarning, - ) - if len(args) > len(positional_attributes): - raise ValueError( - f"Only {', '.join(positional_attributes)} are allowed as positional args when defining " - f"Kafka sources, for backwards compatibility." - ) - if len(args) >= 1: - _name = args[0] - if len(args) >= 2: - _event_timestamp_column = args[1] - if len(args) >= 3: - _kafka_bootstrap_servers = args[2] - if len(args) >= 4: - _message_format = args[3] - if len(args) >= 5: - _topic = args[4] - - if _message_format is None: - raise ValueError("Message format must be specified for Kafka source") - - if not timestamp_field and not _event_timestamp_column: - raise ValueError("Timestamp field must be specified for Kafka source") - super().__init__( - event_timestamp_column=_event_timestamp_column, + name=name, + timestamp_field=timestamp_field, created_timestamp_column=created_timestamp_column, field_mapping=field_mapping, - date_partition_column=date_partition_column, description=description, tags=tags, owner=owner, - name=_name, - timestamp_field=timestamp_field, ) self.batch_source = batch_source + kafka_bootstrap_servers = kafka_bootstrap_servers or bootstrap_servers or "" + topic = topic or "" + self.kafka_options = KafkaOptions( - kafka_bootstrap_servers=_kafka_bootstrap_servers, - message_format=_message_format, - topic=_topic, + kafka_bootstrap_servers=kafka_bootstrap_servers, + message_format=message_format, + topic=topic, watermark_delay_threshold=watermark_delay_threshold, ) @@ -540,7 +453,6 @@ def from_proto(data_source: DataSourceProto): ) return KafkaSource( name=data_source.name, - event_timestamp_column=data_source.timestamp_field, field_mapping=dict(data_source.field_mapping), kafka_bootstrap_servers=data_source.kafka_options.kafka_bootstrap_servers, message_format=StreamFormat.from_proto( @@ -550,7 +462,6 @@ def from_proto(data_source: DataSourceProto): 
topic=data_source.kafka_options.topic, created_timestamp_column=data_source.created_timestamp_column, timestamp_field=data_source.timestamp_field, - date_partition_column=data_source.date_partition_column, description=data_source.description, tags=dict(data_source.tags), owner=data_source.owner, @@ -572,7 +483,6 @@ def to_proto(self) -> DataSourceProto: data_source_proto.timestamp_field = self.timestamp_field data_source_proto.created_timestamp_column = self.created_timestamp_column - data_source_proto.date_partition_column = self.date_partition_column if self.batch_source: data_source_proto.batch_source.MergeFrom(self.batch_source.to_proto()) return data_source_proto @@ -615,55 +525,16 @@ class RequestSource(DataSource): def __init__( self, - *args, - name: Optional[str] = None, - schema: Optional[Union[Dict[str, ValueType], List[Field]]] = None, + *, + name: str, + schema: List[Field], description: Optional[str] = "", tags: Optional[Dict[str, str]] = None, owner: Optional[str] = "", ): """Creates a RequestSource object.""" - positional_attributes = ["name", "schema"] - _name = name - _schema = schema - if args: - warnings.warn( - ( - "Request source parameters should be specified as a keyword argument instead of a positional arg." - "Feast 0.24+ will not support positional arguments to construct request sources" - ), - DeprecationWarning, - ) - if len(args) > len(positional_attributes): - raise ValueError( - f"Only {', '.join(positional_attributes)} are allowed as positional args when defining " - f"feature views, for backwards compatibility." - ) - if len(args) >= 1: - _name = args[0] - if len(args) >= 2: - _schema = args[1] - - super().__init__(name=_name, description=description, tags=tags, owner=owner) - if not _schema: - raise ValueError("Schema needs to be provided for Request Source") - if isinstance(_schema, Dict): - warnings.warn( - "Schema in RequestSource is changing type. The schema data type Dict[str, ValueType] is being deprecated in Feast 0.24. 
" - "Please use List[Field] instead for the schema", - DeprecationWarning, - ) - schema_list = [] - for key, value_type in _schema.items(): - schema_list.append(Field(name=key, dtype=from_value_type(value_type))) - self.schema = schema_list - elif isinstance(_schema, List): - self.schema = _schema - else: - raise Exception( - "Schema type must be either dictionary or list, not " - + str(type(_schema)) - ) + super().__init__(name=name, description=description, tags=tags, owner=owner) + self.schema = schema def validate(self, config: RepoConfig): pass @@ -695,38 +566,18 @@ def __hash__(self): @staticmethod def from_proto(data_source: DataSourceProto): - - deprecated_schema = data_source.request_data_options.deprecated_schema schema_pb = data_source.request_data_options.schema + list_schema = [] + for field_proto in schema_pb: + list_schema.append(Field.from_proto(field_proto)) - if deprecated_schema and not schema_pb: - warnings.warn( - "Schema in RequestSource is changing type. The schema data type Dict[str, ValueType] is being deprecated in Feast 0.24. 
" - "Please use List[Field] instead for the schema", - DeprecationWarning, - ) - dict_schema = {} - for key, val in deprecated_schema.items(): - dict_schema[key] = ValueType(val) - return RequestSource( - name=data_source.name, - schema=dict_schema, - description=data_source.description, - tags=dict(data_source.tags), - owner=data_source.owner, - ) - else: - list_schema = [] - for field_proto in schema_pb: - list_schema.append(Field.from_proto(field_proto)) - - return RequestSource( - name=data_source.name, - schema=list_schema, - description=data_source.description, - tags=dict(data_source.tags), - owner=data_source.owner, - ) + return RequestSource( + name=data_source.name, + schema=list_schema, + description=data_source.description, + tags=dict(data_source.tags), + owner=data_source.owner, + ) def to_proto(self) -> DataSourceProto: @@ -759,16 +610,6 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: raise NotImplementedError -@typechecked -class RequestDataSource(RequestSource): - def __init__(self, *args, **kwargs): - warnings.warn( - "The 'RequestDataSource' class is deprecated and was renamed to RequestSource. Please use RequestSource instead. 
This class name will be removed in Feast 0.24.", - DeprecationWarning, - ) - super().__init__(*args, **kwargs) - - @typechecked class KinesisSource(DataSource): def validate(self, config: RepoConfig): @@ -783,7 +624,7 @@ def get_table_column_names_and_types( def from_proto(data_source: DataSourceProto): return KinesisSource( name=data_source.name, - event_timestamp_column=data_source.timestamp_field, + timestamp_field=data_source.timestamp_field, field_mapping=dict(data_source.field_mapping), record_format=StreamFormat.from_proto( data_source.kinesis_options.record_format @@ -791,8 +632,6 @@ def from_proto(data_source: DataSourceProto): region=data_source.kinesis_options.region, stream_name=data_source.kinesis_options.stream_name, created_timestamp_column=data_source.created_timestamp_column, - timestamp_field=data_source.timestamp_field, - date_partition_column=data_source.date_partition_column, description=data_source.description, tags=dict(data_source.tags), owner=data_source.owner, @@ -810,78 +649,34 @@ def get_table_query_string(self) -> str: def __init__( self, - *args, - name: Optional[str] = None, - event_timestamp_column: Optional[str] = "", + *, + name: str, + record_format: StreamFormat, + region: str, + stream_name: str, + timestamp_field: Optional[str] = "", created_timestamp_column: Optional[str] = "", - record_format: Optional[StreamFormat] = None, - region: Optional[str] = "", - stream_name: Optional[str] = "", field_mapping: Optional[Dict[str, str]] = None, - date_partition_column: Optional[str] = "", description: Optional[str] = "", tags: Optional[Dict[str, str]] = None, owner: Optional[str] = "", - timestamp_field: Optional[str] = "", batch_source: Optional[DataSource] = None, ): - positional_attributes = [ - "name", - "event_timestamp_column", - "created_timestamp_column", - "record_format", - "region", - "stream_name", - ] - _name = name - _event_timestamp_column = event_timestamp_column - _created_timestamp_column = created_timestamp_column - 
_record_format = record_format - _region = region or "" - _stream_name = stream_name or "" - if args: - warnings.warn( - ( - "Kinesis parameters should be specified as a keyword argument instead of a positional arg." - "Feast 0.24+ will not support positional arguments to construct kinesis sources" - ), - DeprecationWarning, - ) - if len(args) > len(positional_attributes): - raise ValueError( - f"Only {', '.join(positional_attributes)} are allowed as positional args when defining " - f"kinesis sources, for backwards compatibility." - ) - if len(args) >= 1: - _name = args[0] - if len(args) >= 2: - _event_timestamp_column = args[1] - if len(args) >= 3: - _created_timestamp_column = args[2] - if len(args) >= 4: - _record_format = args[3] - if len(args) >= 5: - _region = args[4] - if len(args) >= 6: - _stream_name = args[5] - - if _record_format is None: + if record_format is None: raise ValueError("Record format must be specified for kinesis source") super().__init__( - name=_name, - event_timestamp_column=_event_timestamp_column, - created_timestamp_column=_created_timestamp_column, + name=name, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, field_mapping=field_mapping, - date_partition_column=date_partition_column, description=description, tags=tags, owner=owner, - timestamp_field=timestamp_field, ) self.batch_source = batch_source self.kinesis_options = KinesisOptions( - record_format=_record_format, region=_region, stream_name=_stream_name + record_format=record_format, region=region, stream_name=stream_name ) def __eq__(self, other): @@ -918,7 +713,6 @@ def to_proto(self) -> DataSourceProto: data_source_proto.timestamp_field = self.timestamp_field data_source_proto.created_timestamp_column = self.created_timestamp_column - data_source_proto.date_partition_column = self.date_partition_column if self.batch_source: data_source_proto.batch_source.MergeFrom(self.batch_source.to_proto()) @@ -943,15 +737,16 @@ class 
PushSource(DataSource): def __init__( self, - *args, - name: Optional[str] = None, - batch_source: Optional[DataSource] = None, + *, + name: str, + batch_source: DataSource, description: Optional[str] = "", tags: Optional[Dict[str, str]] = None, owner: Optional[str] = "", ): """ Creates a PushSource object. + Args: name: Name of the push source batch_source: The batch source that backs this push source. It's used when materializing from the offline @@ -960,35 +755,9 @@ def __init__( tags (optional): A dictionary of key-value pairs to store arbitrary metadata. owner (optional): The owner of the data source, typically the email of the primary maintainer. - """ - positional_attributes = ["name", "batch_source"] - _name = name - _batch_source = batch_source - if args: - warnings.warn( - ( - "Push source parameters should be specified as a keyword argument instead of a positional arg." - "Feast 0.24+ will not support positional arguments to construct push sources" - ), - DeprecationWarning, - ) - if len(args) > len(positional_attributes): - raise ValueError( - f"Only {', '.join(positional_attributes)} are allowed as positional args when defining " - f"push sources, for backwards compatibility." 
- ) - if len(args) >= 1: - _name = args[0] - if len(args) >= 2: - _batch_source = args[1] - - super().__init__(name=_name, description=description, tags=tags, owner=owner) - if not _batch_source: - raise ValueError( - f"batch_source parameter is needed for push source {self.name}" - ) - self.batch_source = _batch_source + super().__init__(name=name, description=description, tags=tags, owner=owner) + self.batch_source = batch_source def __eq__(self, other): if not isinstance(other, PushSource): diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py index 5c33efb9a22..bbbc6170e1b 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py @@ -103,7 +103,6 @@ def pull_latest_from_table_or_query( end_date = end_date.astimezone(tz=utc) query = f""" - SELECT {field_string} {f", {repr(DUMMY_ENTITY_VAL)} AS {DUMMY_ENTITY_ID}" if not join_key_columns else ""} @@ -358,13 +357,13 @@ def get_temp_table_dml_header( self, temp_table_name: str, temp_external_location: str ) -> str: temp_table_dml_header = f""" - CREATE TABLE {temp_table_name} + CREATE TABLE {temp_table_name} WITH ( - external_location = '{temp_external_location}', + external_location = '{temp_external_location}', format = 'parquet', write_compression = 'snappy' ) - as + as """ return temp_table_dml_header @@ -598,7 +597,7 @@ def _get_entity_df_event_timestamp_range( {% if featureview.date_partition_column != "" and featureview.date_partition_column is not none %} AND {{ featureview.date_partition_column }} <= '{{ featureview.max_event_timestamp[:10] }}' {% endif %} - + {% if featureview.ttl == 0 %}{% else %} AND {{ featureview.timestamp_field }} >= from_iso8601_timestamp('{{ featureview.min_event_timestamp }}') {% if featureview.date_partition_column != "" and 
featureview.date_partition_column is not none %} diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py index 75a148a8aa0..2020e78d36e 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py @@ -1,6 +1,6 @@ import os import uuid -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import pandas as pd @@ -16,9 +16,6 @@ ) from feast.infra.utils import aws_utils from feast.repo_config import FeastConfigBaseModel -from tests.integration.feature_repos.integration_test_repo_config import ( - IntegrationTestRepoConfig, -) from tests.integration.feature_repos.universal.data_source_creator import ( DataSourceCreator, ) diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_repo_configuration.py b/sdk/python/feast/infra/offline_stores/contrib/athena_repo_configuration.py index cd74e00aafe..32376eb6527 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_repo_configuration.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_repo_configuration.py @@ -1,5 +1,3 @@ -# from feast.infra.offline_stores.contrib.athena_offline_store.tests.data_source import AthenaDataSourceCreator - from tests.integration.feature_repos.integration_test_repo_config import ( IntegrationTestRepoConfig, ) @@ -8,7 +6,6 @@ ) FULL_REPO_CONFIGS = [ - IntegrationTestRepoConfig(), IntegrationTestRepoConfig( provider="aws", offline_store_creator=AthenaDataSourceCreator, diff --git a/sdk/python/feast/templates/athena/example.py b/sdk/python/feast/templates/athena/example.py index 7e8c2eb6f05..768a2709dc6 100644 --- a/sdk/python/feast/templates/athena/example.py +++ b/sdk/python/feast/templates/athena/example.py @@ -1,4 +1,3 @@ -import importlib import os from datetime import 
datetime, timedelta diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index ce6b5fa873e..c2cf286fdc4 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -27,9 +27,6 @@ from tests.integration.feature_repos.universal.data_source_creator import ( DataSourceCreator, ) -from tests.integration.feature_repos.universal.data_sources.athena import ( - AthenaDataSourceCreator, -) from tests.integration.feature_repos.universal.data_sources.bigquery import ( BigQueryDataSourceCreator, ) @@ -92,7 +89,6 @@ "bigquery": ("gcp", BigQueryDataSourceCreator), "redshift": ("aws", RedshiftDataSourceCreator), "snowflake": ("aws", SnowflakeDataSourceCreator), - "athena": ("aws", AthenaDataSourceCreator), } AVAILABLE_OFFLINE_STORES: List[Tuple[str, Type[DataSourceCreator]]] = [ @@ -112,7 +108,6 @@ ("gcp", BigQueryDataSourceCreator), ("aws", RedshiftDataSourceCreator), ("aws", SnowflakeDataSourceCreator), - ("aws", AthenaDataSourceCreator), ] ) diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py deleted file mode 100644 index 3369fc4290b..00000000000 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/athena.py +++ /dev/null @@ -1,144 +0,0 @@ -import os -import uuid -from typing import Any, Dict, List, Optional - -import pandas as pd - -from feast import AthenaSource -from feast.data_source import DataSource -from feast.feature_logging import LoggingDestination -from feast.infra.offline_stores.contrib.athena_offline_store.athena import ( - AthenaOfflineStoreConfig, -) -from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import ( - AthenaLoggingDestination, - SavedDatasetAthenaStorage, -) -from feast.infra.utils import aws_utils -from 
feast.repo_config import FeastConfigBaseModel -from tests.integration.feature_repos.integration_test_repo_config import ( - IntegrationTestRepoConfig, -) -from tests.integration.feature_repos.universal.data_source_creator import ( - DataSourceCreator, -) - - -class AthenaDataSourceCreator(DataSourceCreator): - - tables: List[str] = [] - - def __init__(self, project_name: str, *args, **kwargs): - super().__init__(project_name) - self.client = aws_utils.get_athena_data_client("ap-northeast-2") - self.s3 = aws_utils.get_s3_resource("ap-northeast-2") - data_source = ( - os.environ.get("S3_DATA_SOURCE") - if os.environ.get("S3_DATA_SOURCE") - else "AwsDataCatalog" - ) - database = ( - os.environ.get("S3_DATABASE") - if os.environ.get("S3_DATABASE") - else "sampledb" - ) - bucket_name = ( - os.environ.get("S3_BUCKET_NAME") - if os.environ.get("S3_BUCKET_NAME") - else "feast-integration-tests" - ) - self.offline_store_config = AthenaOfflineStoreConfig( - data_source=f"{data_source}", - region="ap-northeast-2", - database=f"{database}", - s3_staging_location=f"s3://{bucket_name}/test_dir", - ) - - def create_data_source( - self, - df: pd.DataFrame, - destination_name: str, - suffix: Optional[str] = None, - timestamp_field="ts", - created_timestamp_column="created_ts", - field_mapping: Dict[str, str] = None, - ) -> DataSource: - - table_name = destination_name - s3_target = ( - self.offline_store_config.s3_staging_location - + "/" - + self.project_name - + "/" - + table_name - + "/" - + table_name - + ".parquet" - ) - - aws_utils.upload_df_to_athena( - self.client, - self.offline_store_config.data_source, - self.offline_store_config.database, - self.s3, - s3_target, - table_name, - df, - ) - - self.tables.append(table_name) - - return AthenaSource( - table=table_name, - timestamp_field=timestamp_field, - created_timestamp_column=created_timestamp_column, - field_mapping=field_mapping or {"ts_1": "ts"}, - database=self.offline_store_config.database, - 
data_source=self.offline_store_config.data_source, - ) - - def create_saved_dataset_destination(self) -> SavedDatasetAthenaStorage: - table = self.get_prefixed_table_name( - f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" - ) - self.tables.append(table) - - return SavedDatasetAthenaStorage( - table_ref=table, - database=self.offline_store_config.database, - data_source=self.offline_store_config.data_source, - ) - - def create_logged_features_destination(self) -> LoggingDestination: - table = self.get_prefixed_table_name( - f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" - ) - self.tables.append(table) - - return AthenaLoggingDestination(table_name=table) - - def create_offline_store_config(self) -> FeastConfigBaseModel: - return self.offline_store_config - - def get_prefixed_table_name(self, suffix: str) -> str: - return f"{self.project_name}_{suffix}" - - def teardown(self): - for table in self.tables: - aws_utils.execute_athena_query( - self.client, - self.offline_store_config.data_source, - self.offline_store_config.database, - f"DROP TABLE IF EXISTS {table}", - ) - - -FULL_REPO_CONFIGS = [ - IntegrationTestRepoConfig(), - IntegrationTestRepoConfig( - provider="aws", - offline_store_creator=AthenaDataSourceCreator, - ), -] - -AVAILABLE_OFFLINE_STORES = [("aws", AthenaDataSourceCreator)] From e9b6180e229964282f719874aabda2af1340473d Mon Sep 17 00:00:00 2001 From: Danny Chiao Date: Tue, 9 Aug 2022 11:20:55 -0400 Subject: [PATCH 08/11] add entity_key_serialization Signed-off-by: Danny Chiao --- sdk/python/feast/templates/athena/feature_store.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/python/feast/templates/athena/feature_store.yaml b/sdk/python/feast/templates/athena/feature_store.yaml index ee88bda72a1..13e7898e861 100644 --- a/sdk/python/feast/templates/athena/feature_store.yaml +++ b/sdk/python/feast/templates/athena/feature_store.yaml @@ -9,4 +9,5 @@ offline_store: region: ap-northeast-2 database: sampledb 
data_source: AwsDataCatalog - s3_staging_location: s3://sagemaker-yelo-test \ No newline at end of file + s3_staging_location: s3://sagemaker-yelo-test +entity_key_serialization_version: 2 \ No newline at end of file From c7e59a33a6f59e20c6a9aa88fbf5f7e94a4ebf38 Mon Sep 17 00:00:00 2001 From: Danny Chiao Date: Tue, 9 Aug 2022 11:24:08 -0400 Subject: [PATCH 09/11] restore deleted file Signed-off-by: Danny Chiao --- .../integration/e2e/test_go_feature_server.py | 263 ++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 sdk/python/tests/integration/e2e/test_go_feature_server.py diff --git a/sdk/python/tests/integration/e2e/test_go_feature_server.py b/sdk/python/tests/integration/e2e/test_go_feature_server.py new file mode 100644 index 00000000000..0f972e45df5 --- /dev/null +++ b/sdk/python/tests/integration/e2e/test_go_feature_server.py @@ -0,0 +1,263 @@ +import threading +import time +from datetime import datetime +from typing import List + +import grpc +import pandas as pd +import pytest +import pytz +import requests + +from feast.embedded_go.online_features_service import EmbeddedOnlineFeatureServer +from feast.feast_object import FeastObject +from feast.feature_logging import LoggingConfig +from feast.feature_service import FeatureService +from feast.infra.feature_servers.base_config import FeatureLoggingConfig +from feast.protos.feast.serving.ServingService_pb2 import ( + FieldStatus, + GetOnlineFeaturesRequest, + GetOnlineFeaturesResponse, +) +from feast.protos.feast.serving.ServingService_pb2_grpc import ServingServiceStub +from feast.protos.feast.types.Value_pb2 import RepeatedValue +from feast.type_map import python_values_to_proto_values +from feast.value_type import ValueType +from feast.wait import wait_retry_backoff +from tests.integration.feature_repos.repo_configuration import ( + construct_universal_feature_views, +) +from tests.integration.feature_repos.universal.entities import ( + customer, + driver, + location, +) +from 
tests.utils.http_server import check_port_open, free_port +from tests.utils.test_log_creator import generate_expected_logs, get_latest_rows + + +@pytest.mark.integration +@pytest.mark.goserver +def test_go_grpc_server(grpc_client): + resp: GetOnlineFeaturesResponse = grpc_client.GetOnlineFeatures( + GetOnlineFeaturesRequest( + feature_service="driver_features", + entities={ + "driver_id": RepeatedValue( + val=python_values_to_proto_values( + [5001, 5002], feature_type=ValueType.INT64 + ) + ) + }, + full_feature_names=True, + ) + ) + assert list(resp.metadata.feature_names.val) == [ + "driver_id", + "driver_stats__conv_rate", + "driver_stats__acc_rate", + "driver_stats__avg_daily_trips", + ] + for vector in resp.results: + assert all([s == FieldStatus.PRESENT for s in vector.statuses]) + + +@pytest.mark.integration +@pytest.mark.goserver +def test_go_http_server(http_server_port): + response = requests.post( + f"http://localhost:{http_server_port}/get-online-features", + json={ + "feature_service": "driver_features", + "entities": {"driver_id": [5001, 5002]}, + "full_feature_names": True, + }, + ) + assert response.status_code == 200, response.text + response = response.json() + assert set(response.keys()) == {"metadata", "results"} + metadata = response["metadata"] + results = response["results"] + assert response["metadata"] == { + "feature_names": [ + "driver_id", + "driver_stats__conv_rate", + "driver_stats__acc_rate", + "driver_stats__avg_daily_trips", + ] + }, metadata + assert len(results) == 4, results + assert all( + set(result.keys()) == {"event_timestamps", "statuses", "values"} + for result in results + ), results + assert all( + result["statuses"] == ["PRESENT", "PRESENT"] for result in results + ), results + assert results[0]["values"] == [5001, 5002], results + for result in results[1:]: + assert len(result["values"]) == 2, result + assert all(value is not None for value in result["values"]), result + + +@pytest.mark.integration +@pytest.mark.goserver 
+@pytest.mark.universal_offline_stores +@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) +def test_feature_logging( + grpc_client, environment, universal_data_sources, full_feature_names +): + fs = environment.feature_store + feature_service = fs.get_feature_service("driver_features") + log_start_date = datetime.now().astimezone(pytz.UTC) + driver_ids = list(range(5001, 5011)) + + for driver_id in driver_ids: + # send each driver id in separate request + grpc_client.GetOnlineFeatures( + GetOnlineFeaturesRequest( + feature_service="driver_features", + entities={ + "driver_id": RepeatedValue( + val=python_values_to_proto_values( + [driver_id], feature_type=ValueType.INT64 + ) + ) + }, + full_feature_names=full_feature_names, + ) + ) + # with some pause + time.sleep(0.1) + + _, datasets, _ = universal_data_sources + latest_rows = get_latest_rows(datasets.driver_df, "driver_id", driver_ids) + feature_view = fs.get_feature_view("driver_stats") + features = [ + feature.name + for proj in feature_service.feature_view_projections + for feature in proj.features + ] + expected_logs = generate_expected_logs( + latest_rows, feature_view, features, ["driver_id"], "event_timestamp" + ) + + def retrieve(): + retrieval_job = fs._get_provider().retrieve_feature_service_logs( + feature_service=feature_service, + start_date=log_start_date, + end_date=datetime.now().astimezone(pytz.UTC), + config=fs.config, + registry=fs._registry, + ) + try: + df = retrieval_job.to_df() + except Exception: + # Table or directory was not created yet + return None, False + + return df, df.shape[0] == len(driver_ids) + + persisted_logs = wait_retry_backoff( + retrieve, timeout_secs=60, timeout_msg="Logs retrieval failed" + ) + + persisted_logs = persisted_logs.sort_values(by="driver_id").reset_index(drop=True) + persisted_logs = persisted_logs[expected_logs.columns] + pd.testing.assert_frame_equal(expected_logs, persisted_logs, check_dtype=False) + + +""" +Start go 
feature server either on http or grpc based on the repo configuration for testing. +""" + + +def _server_port(environment, server_type: str): + if not environment.test_repo_config.go_feature_serving: + pytest.skip("Only for Go path") + + fs = environment.feature_store + + embedded = EmbeddedOnlineFeatureServer( + repo_path=str(fs.repo_path.absolute()), + repo_config=fs.config, + feature_store=fs, + ) + port = free_port() + if server_type == "grpc": + target = embedded.start_grpc_server + elif server_type == "http": + target = embedded.start_http_server + else: + raise ValueError("Server Type must be either 'http' or 'grpc'") + + t = threading.Thread( + target=target, + args=("127.0.0.1", port), + kwargs=dict( + enable_logging=True, + logging_options=FeatureLoggingConfig( + enabled=True, + queue_capacity=100, + write_to_disk_interval_secs=1, + flush_interval_secs=1, + emit_timeout_micro_secs=10000, + ), + ), + ) + t.start() + + wait_retry_backoff( + lambda: (None, check_port_open("127.0.0.1", port)), timeout_secs=15 + ) + + yield port + if server_type == "grpc": + embedded.stop_grpc_server() + else: + embedded.stop_http_server() + + # wait for graceful stop + time.sleep(5) + + +# Go test fixtures + + +@pytest.fixture +def initialized_registry(environment, universal_data_sources): + fs = environment.feature_store + + _, _, data_sources = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + + feature_service = FeatureService( + name="driver_features", + features=[feature_views.driver], + logging_config=LoggingConfig( + destination=environment.data_source_creator.create_logged_features_destination(), + sample_rate=1.0, + ), + ) + feast_objects: List[FeastObject] = [feature_service] + feast_objects.extend(feature_views.values()) + feast_objects.extend([driver(), customer(), location()]) + + fs.apply(feast_objects) + fs.materialize(environment.start_date, environment.end_date) + + +@pytest.fixture +def grpc_server_port(environment, 
initialized_registry): + yield from _server_port(environment, "grpc") + + +@pytest.fixture +def http_server_port(environment, initialized_registry): + yield from _server_port(environment, "http") + + +@pytest.fixture +def grpc_client(grpc_server_port): + ch = grpc.insecure_channel(f"localhost:{grpc_server_port}") + yield ServingServiceStub(ch) From 327289edb60ba3a8a17ec59b6bd02e22926e4645 Mon Sep 17 00:00:00 2001 From: Youngkyu OH Date: Wed, 10 Aug 2022 03:05:46 +0900 Subject: [PATCH 10/11] modified confusing environment variable names, added how to use Athena Signed-off-by: Youngkyu OH --- Makefile | 8 ++++++-- .../athena_offline_store/tests/data_source.py | 12 ++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index 9ed25e68f38..d5f1519a8b2 100644 --- a/Makefile +++ b/Makefile @@ -139,13 +139,17 @@ test-python-universal-trino: not test_universal_types" \ sdk/python/tests +#To use Athena as an offline store, you need to create an Athena database and an S3 bucket on AWS. https://docs.aws.amazon.com/athena/latest/ug/getting-started.html +#Modify environment variables ATHENA_DATA_SOURCE, ATHENA_DATABASE, ATHENA_S3_BUCKET_NAME if you want to change the data source, database, and bucket name of S3 to use. +#If tests fail with the pytest -n 8 option, change the number to 1. test-python-universal-athena: PYTHONPATH='.' 
\ FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.athena_repo_configuration \ PYTEST_PLUGINS=feast.infra.offline_stores.contrib.athena_offline_store.tests \ FEAST_USAGE=False IS_TEST=True \ - S3_DATABASE=sampledb \ - S3_BUCKET_NAME=sagemaker-yelo-test \ + ATHENA_DATA_SOURCE=AwsDataCatalog \ + ATHENA_DATABASE=default \ + ATHENA_S3_BUCKET_NAME=feast-integration-tests \ python -m pytest -n 8 --integration \ -k "not test_go_feature_server and \ not test_logged_features_validation and \ diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py index 2020e78d36e..92e0d6e5f60 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py @@ -30,18 +30,18 @@ def __init__(self, project_name: str, *args, **kwargs): self.client = aws_utils.get_athena_data_client("ap-northeast-2") self.s3 = aws_utils.get_s3_resource("ap-northeast-2") data_source = ( - os.environ.get("S3_DATA_SOURCE") - if os.environ.get("S3_DATA_SOURCE") + os.environ.get("ATHENA_DATA_SOURCE") + if os.environ.get("ATHENA_DATA_SOURCE") else "AwsDataCatalog" ) database = ( - os.environ.get("S3_DATABASE") - if os.environ.get("S3_DATABASE") + os.environ.get("ATHENA_DATABASE") + if os.environ.get("ATHENA_DATABASE") else "default" ) bucket_name = ( - os.environ.get("S3_BUCKET_NAME") - if os.environ.get("S3_BUCKET_NAME") + os.environ.get("ATHENA_S3_BUCKET_NAME") + if os.environ.get("ATHENA_S3_BUCKET_NAME") else "feast-integration-tests" ) self.offline_store_config = AthenaOfflineStoreConfig( From ce1710c9f08d9fb69a0380e329f6aa1ae92dacdd Mon Sep 17 00:00:00 2001 From: Youngkyu OH Date: Wed, 10 Aug 2022 10:19:23 +0900 Subject: [PATCH 11/11] enforce AthenaSource to have a name Signed-off-by: Youngkyu OH --- 
.../athena_offline_store/athena_source.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py index 542ee5606b8..f96dc0d048f 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py @@ -1,9 +1,8 @@ -import warnings from typing import Callable, Dict, Iterable, Optional, Tuple from feast import type_map from feast.data_source import DataSource -from feast.errors import DataSourceNotFoundException +from feast.errors import DataSourceNoNameException, DataSourceNotFoundException from feast.feature_logging import LoggingDestination from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.protos.feast.core.FeatureService_pb2 import ( @@ -56,7 +55,7 @@ def __init__( """ - # The default Athena schema is named "public". + _database = "default" if table and not database else database self.athena_options = AthenaOptions( table=table, query=query, database=_database, data_source=data_source @@ -64,18 +63,12 @@ def __init__( if table is None and query is None: raise ValueError('No "table" argument provided.') - _name = name - if not _name: - if table: - _name = table - else: - warnings.warn( - ( - f"Starting in Feast 0.21, Feast will require either a name for a data source (if using query) " - f"or `table`: {self.query}" - ), - DeprecationWarning, - ) + + # If no name, use the table as the default name. + if name is None and table is None: + raise DataSourceNoNameException() + _name = name or table + assert _name super().__init__( name=_name if _name else "",