diff --git a/sdk/python/feast/inference.py b/sdk/python/feast/inference.py index b3e51b48162..51c4e9d78ec 100644 --- a/sdk/python/feast/inference.py +++ b/sdk/python/feast/inference.py @@ -111,7 +111,9 @@ def update_data_sources_with_inferred_event_timestamp_col( assert ( isinstance(data_source, FileSource) or isinstance(data_source, BigQuerySource) + or isinstance(data_source, RedshiftSource) or isinstance(data_source, SnowflakeSource) + or "SparkSource" == data_source.__class__.__name__ ) # loop through table columns to find singular match diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index 61cad5606f1..ef57977fcbe 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -203,6 +203,9 @@ class UniversalDataSources: global_ds: DataSource field_mapping: DataSource + def values(self): + return dataclasses.asdict(self).values() + def construct_universal_data_sources( datasets: UniversalDatasets, data_source_creator: DataSourceCreator diff --git a/sdk/python/tests/integration/registration/test_inference.py b/sdk/python/tests/integration/registration/test_inference.py index 2582e69ea37..0ea6276669d 100644 --- a/sdk/python/tests/integration/registration/test_inference.py +++ b/sdk/python/tests/integration/registration/test_inference.py @@ -1,3 +1,5 @@ +from copy import deepcopy + import pandas as pd import pytest @@ -111,7 +113,7 @@ def test_infer_datasource_names_dwh(): @pytest.mark.integration -def test_update_data_sources_with_inferred_event_timestamp_col(simple_dataset_1): +def test_update_file_data_source_with_inferred_event_timestamp_col(simple_dataset_1): df_with_two_viable_timestamp_cols = simple_dataset_1.copy(deep=True) df_with_two_viable_timestamp_cols["ts_2"] = simple_dataset_1["ts_1"] @@ -138,6 +140,28 @@ def test_update_data_sources_with_inferred_event_timestamp_col(simple_dataset_1) ) +@pytest.mark.integration +@pytest.mark.universal +def test_update_data_sources_with_inferred_event_timestamp_col(universal_data_sources): + (_, _, data_sources) = universal_data_sources + data_sources_copy = deepcopy(data_sources) + + # remove defined event_timestamp_column to allow for inference + for data_source in data_sources_copy.values(): + data_source.event_timestamp_column = None + + update_data_sources_with_inferred_event_timestamp_col( + data_sources_copy.values(), RepoConfig(provider="local", project="test"), + ) + actual_event_timestamp_cols = [ + source.event_timestamp_column for source in data_sources_copy.values() + ] + + assert actual_event_timestamp_cols == ["event_timestamp"] * len( + data_sources_copy.values() + ) + + def test_on_demand_features_type_inference(): # Create Feature Views date_request = RequestDataSource(