diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py index a27065fb5ed..1083cc56278 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py @@ -185,6 +185,17 @@ def get_table_query_string(self) -> str: return f"`{tmp_table_name}`" + # Note: Python requires redefining hash in child classes that override __eq__ + def __hash__(self): + return super().__hash__() + + def __eq__(self, other): + if not isinstance(other, SparkSource): + raise TypeError( + "Comparisons should only involve SparkSource class objects." + ) + return super().__eq__(other) and self.spark_options == other.spark_options + class SparkOptions: allowed_formats = [format.value for format in SparkSourceFormat] @@ -282,6 +294,19 @@ def to_proto(self) -> DataSourceProto.SparkOptions: return spark_options_proto + def __eq__(self, other: object) -> bool: + if not isinstance(other, SparkOptions): + raise TypeError( + "Comparisons should only involve SparkOptions class objects."
+ ) + + return ( + self.table == other.table + and self.query == other.query + and self.path == other.path + and self.file_format == other.file_format + ) + class SavedDatasetSparkStorage(SavedDatasetStorage): _proto_attr_name = "spark_storage" diff --git a/sdk/python/tests/unit/test_data_sources.py b/sdk/python/tests/unit/test_data_sources.py index 990c5d3b698..990752621b0 100644 --- a/sdk/python/tests/unit/test_data_sources.py +++ b/sdk/python/tests/unit/test_data_sources.py @@ -10,6 +10,9 @@ ) from feast.field import Field from feast.infra.offline_stores.bigquery_source import BigQuerySource +from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( + SparkSource, +) from feast.infra.offline_stores.file_source import FileSource from feast.infra.offline_stores.redshift_source import RedshiftSource from feast.infra.offline_stores.snowflake_source import SnowflakeSource @@ -233,3 +236,38 @@ def test_redshift_fully_qualified_table_name(source_kwargs, expected_name): ) assert redshift_source.redshift_options.fully_qualified_table_name == expected_name + + +@pytest.mark.parametrize( + "test_data,are_equal", + [ + ( + SparkSource( + name="name", table="table", query="query", file_format="file_format" + ), + True, + ), + (SparkSource(table="table", query="query", file_format="file_format"), False), + ( + SparkSource( + name="name", table="table", query="query", file_format="file_format1" + ), + False, + ), + ( + SparkSource( + name="name", table="table", query="query1", file_format="file_format" + ), + False, + ), + ], +) +def test_spark_source_equality(test_data, are_equal): + default = SparkSource( + name="name", table="table", query="query", file_format="file_format" + ) + if are_equal: + assert default == test_data + else: + assert default != test_data +