diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index 83aebe7ef9c..e126b059342 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -6,6 +6,7 @@ from functools import reduce from pathlib import Path from typing import ( + TYPE_CHECKING, Any, Callable, ContextManager, @@ -63,12 +64,8 @@ raise FeastExtrasDependencyImportError("snowflake", str(e)) -try: +if TYPE_CHECKING: from pyspark.sql import DataFrame, SparkSession -except ImportError as e: - from feast.errors import FeastExtrasDependencyImportError - - raise FeastExtrasDependencyImportError("spark", str(e)) warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -462,7 +459,7 @@ def to_sql(self) -> str: with self._query_generator() as query: return query - def to_spark_df(self, spark_session: SparkSession) -> DataFrame: + def to_spark_df(self, spark_session: "SparkSession") -> "DataFrame": """ Method to convert snowflake query results to pyspark data frame. @@ -473,6 +470,13 @@ def to_spark_df(self, spark_session: SparkSession) -> DataFrame: spark_df: A pyspark dataframe. """ + try: + from pyspark.sql import DataFrame, SparkSession + except ImportError as e: + from feast.errors import FeastExtrasDependencyImportError + + raise FeastExtrasDependencyImportError("spark", str(e)) + if isinstance(spark_session, SparkSession): with self._query_generator() as query: