From 0e6296bb7a2ee5f06aeb5495a38a569baafec61b Mon Sep 17 00:00:00 2001 From: Vanshika Vanshika Date: Wed, 8 Apr 2026 13:31:15 +0530 Subject: [PATCH] mlflow-feast integration Signed-off-by: Vanshika Vanshika rh-pre-commit.version: 2.3.2 rh-pre-commit.check-secrets: ENABLED --- pyproject.toml | 1 + sdk/python/feast/feature_store.py | 137 +++++- .../feast/mlflow_integration/__init__.py | 52 +++ sdk/python/feast/mlflow_integration/config.py | 37 ++ .../mlflow_integration/entity_df_builder.py | 114 +++++ sdk/python/feast/mlflow_integration/logger.py | 151 +++++++ .../mlflow_integration/model_resolver.py | 121 ++++++ sdk/python/feast/repo_config.py | 20 + sdk/python/feast/ui_server.py | 69 +++ .../tests/unit/test_mlflow_integration.py | 400 ++++++++++++++++++ ui/src/components/RegistryVisualization.tsx | 75 +++- .../components/RegistryVisualizationTab.tsx | 4 + ui/src/hooks/useFCOExploreSuggestions.ts | 1 + ui/src/parsers/types.ts | 1 + ui/src/queries/useLoadMlflowRuns.ts | 44 ++ 15 files changed, 1217 insertions(+), 10 deletions(-) create mode 100644 sdk/python/feast/mlflow_integration/__init__.py create mode 100644 sdk/python/feast/mlflow_integration/config.py create mode 100644 sdk/python/feast/mlflow_integration/entity_df_builder.py create mode 100644 sdk/python/feast/mlflow_integration/logger.py create mode 100644 sdk/python/feast/mlflow_integration/model_resolver.py create mode 100644 sdk/python/tests/unit/test_mlflow_integration.py create mode 100644 ui/src/queries/useLoadMlflowRuns.ts diff --git a/pyproject.toml b/pyproject.toml index 2e45d1820e7..0f31a0e764c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,6 +132,7 @@ snowflake = [ ] sqlite_vec = ["sqlite-vec==v0.1.6"] mcp = ["fastapi_mcp"] +mlflow = ["mlflow>=2.10.0"] dbt = ["dbt-artifacts-parser"] diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index f95bbf10c03..4336abe362c 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ 
-105,6 +105,25 @@ _track_materialization = None # Lazy-loaded on first materialization call _track_materialization_loaded = False +_mlflow_log_fn = None # Lazy-loaded on first feature retrieval +_mlflow_log_fn_loaded = False + + +def _get_mlflow_log_fn(): + """Lazy-import mlflow logger only when MLflow integration is configured.""" + global _mlflow_log_fn, _mlflow_log_fn_loaded + if not _mlflow_log_fn_loaded: + _mlflow_log_fn_loaded = True + try: + from feast.mlflow_integration.logger import ( + log_feature_retrieval_to_mlflow, + ) + + _mlflow_log_fn = log_feature_retrieval_to_mlflow + except Exception: + _mlflow_log_fn = None + return _mlflow_log_fn + def _get_track_materialization(): """Lazy-import feast.metrics only when materialization tracking is needed. @@ -194,6 +213,54 @@ def __init__( # Initialize feature service cache for performance optimization self._feature_service_cache = {} + # Configure MLflow tracking URI globally from config + self._init_mlflow_tracking() + + def _init_mlflow_tracking(self): + """Configure MLflow globally from feature_store.yaml. + + Sets the tracking URI and experiment name so the user never needs + to call mlflow.set_tracking_uri() or mlflow.set_experiment() in + their scripts. The experiment is named after the Feast project. + + When no tracking_uri is specified, defaults to http://127.0.0.1:5000 + (a local MLflow tracking server). This ensures that train.py, + predict.py, feast ui, and the MLflow UI all share the same backend. 
+ """ + try: + mlflow_cfg = self.config.mlflow + if mlflow_cfg is None or not mlflow_cfg.enabled: + return + + import mlflow + + tracking_uri = mlflow_cfg.tracking_uri or "http://127.0.0.1:5000" + mlflow.set_tracking_uri(tracking_uri) + mlflow.set_experiment(self.config.project) + except ImportError: + pass + except Exception as e: + warnings.warn(f"Failed to configure MLflow tracking: {e}") + + def _resolve_feature_service_name( + self, feature_refs: List[str] + ) -> Optional[str]: + """Try to find a feature service that covers the given feature refs.""" + try: + ref_set = set(feature_refs) + for fs in self.registry.list_feature_services( + self.project, allow_cache=True + ): + fs_refs = set() + for proj in fs.feature_view_projections: + for feat in proj.features: + fs_refs.add(f"{proj.name}:{feat.name}") + if ref_set == fs_refs or ref_set.issubset(fs_refs): + return fs.name + except Exception: + pass + return None + def _init_openlineage_emitter(self) -> Optional[Any]: """Initialize OpenLineage emitter if configured and enabled.""" try: @@ -1483,6 +1550,8 @@ def get_historical_features( if end_date is not None: kwargs["end_date"] = end_date + _retrieval_start = time.monotonic() + job = provider.get_historical_features( self.config, feature_views, @@ -1494,6 +1563,31 @@ def get_historical_features( **kwargs, ) + # Auto-log to MLflow if configured + if ( + self.config.mlflow is not None + and self.config.mlflow.enabled + and self.config.mlflow.auto_log + ): + _log_fn = _get_mlflow_log_fn() + if _log_fn is not None: + _duration = time.monotonic() - _retrieval_start + _entity_count = ( + len(entity_df) if isinstance(entity_df, pd.DataFrame) else 0 + ) + _fs = features if isinstance(features, FeatureService) else None + _fs_name = features.name if isinstance(features, FeatureService) else self._resolve_feature_service_name(_feature_refs) + _log_fn( + feature_refs=_feature_refs, + entity_count=_entity_count, + duration_seconds=_duration, + retrieval_type="historical", 
+ feature_service=_fs, + feature_service_name=_fs_name, + project=self.project, + tracking_uri=self.config.mlflow.tracking_uri, + ) + return job def create_saved_dataset( @@ -2621,6 +2715,8 @@ def get_online_features( """ provider = self._get_provider() + _retrieval_start = time.monotonic() + response = provider.get_online_features( config=self.config, features=features, @@ -2631,6 +2727,36 @@ def get_online_features( include_feature_view_version_metadata=include_feature_view_version_metadata, ) + # Auto-log to MLflow if configured + if ( + self.config.mlflow is not None + and self.config.mlflow.enabled + and self.config.mlflow.auto_log + ): + _log_fn = _get_mlflow_log_fn() + if _log_fn is not None: + _duration = time.monotonic() - _retrieval_start + _feature_refs = utils._get_features( + self.registry, self.project, features, allow_cache=True + ) + _entity_count = ( + len(entity_rows) + if isinstance(entity_rows, list) + else 0 + ) + _fs = features if isinstance(features, FeatureService) else None + _fs_name = features.name if isinstance(features, FeatureService) else self._resolve_feature_service_name(_feature_refs) + _log_fn( + feature_refs=_feature_refs, + entity_count=_entity_count, + duration_seconds=_duration, + retrieval_type="online", + feature_service=_fs, + feature_service_name=_fs_name, + project=self.project, + tracking_uri=self.config.mlflow.tracking_uri, + ) + return response async def get_online_features_async( @@ -2781,10 +2907,7 @@ def _doc_feature(x): online_features_response=online_features_response, data=requested_features_data, ) - feature_types = { - f.name: f.dtype.to_value_type() for f in requested_feature_view.features - } - return OnlineResponse(online_features_response, feature_types=feature_types) + return OnlineResponse(online_features_response) def retrieve_online_documents_v2( self, @@ -3074,8 +3197,7 @@ def _retrieve_from_online_store_v2( online_features_response.metadata.feature_names.val.extend( features_to_request ) - 
feature_types = {f.name: f.dtype.to_value_type() for f in table.features} - return OnlineResponse(online_features_response, feature_types=feature_types) + return OnlineResponse(online_features_response) table_entity_values, idxs, output_len = utils._get_unique_entities_from_values( entity_key_dict, @@ -3098,8 +3220,7 @@ def _retrieve_from_online_store_v2( data=entity_key_dict, ) - feature_types = {f.name: f.dtype.to_value_type() for f in table.features} - return OnlineResponse(online_features_response, feature_types=feature_types) + return OnlineResponse(online_features_response) def serve( self, diff --git a/sdk/python/feast/mlflow_integration/__init__.py b/sdk/python/feast/mlflow_integration/__init__.py new file mode 100644 index 00000000000..04cee304ad6 --- /dev/null +++ b/sdk/python/feast/mlflow_integration/__init__.py @@ -0,0 +1,52 @@ +""" +MLflow integration for Feast Feature Store. + +This module provides seamless integration between Feast and MLflow for +automatic experiment tracking of feature retrieval operations. When enabled +in feature_store.yaml, feature metadata is logged automatically to MLflow +during get_historical_features and get_online_features calls. + +Usage: + Configure MLflow in your feature_store.yaml: + + project: my_project + # ... other config ... + + mlflow: + enabled: true + tracking_uri: http://localhost:5000 + auto_log: true + + Then use Feast normally - feature retrieval metadata is logged automatically + to any active MLflow run. + + For advanced use cases, the module also provides: + - resolve_feature_service_from_model_uri: Map an MLflow model to its Feast + feature service. + - get_entity_df_from_mlflow_run: Reproduce training by pulling entity data + from a previous MLflow run's artifacts. 
+""" + +from feast.mlflow_integration.config import MlflowConfig +from feast.mlflow_integration.entity_df_builder import ( + FeastMlflowEntityDfError, + get_entity_df_from_mlflow_run, +) +from feast.mlflow_integration.logger import ( + log_feature_retrieval_to_mlflow, + log_training_dataset_to_mlflow, +) +from feast.mlflow_integration.model_resolver import ( + FeastMlflowModelResolutionError, + resolve_feature_service_from_model_uri, +) + +__all__ = [ + "MlflowConfig", + "log_feature_retrieval_to_mlflow", + "log_training_dataset_to_mlflow", + "resolve_feature_service_from_model_uri", + "FeastMlflowModelResolutionError", + "get_entity_df_from_mlflow_run", + "FeastMlflowEntityDfError", +] diff --git a/sdk/python/feast/mlflow_integration/config.py b/sdk/python/feast/mlflow_integration/config.py new file mode 100644 index 00000000000..b10aad4e426 --- /dev/null +++ b/sdk/python/feast/mlflow_integration/config.py @@ -0,0 +1,37 @@ +from typing import Optional + +from pydantic import StrictBool, StrictStr + +from feast.repo_config import FeastBaseModel + + +class MlflowConfig(FeastBaseModel): + """Configuration for MLflow integration. + + This enables automatic logging of feature retrieval metadata to MLflow + during get_historical_features and get_online_features calls. + + Example configuration in feature_store.yaml: + mlflow: + enabled: true + tracking_uri: http://localhost:5000 + auto_log: true + """ + + enabled: StrictBool = False + """ bool: Whether MLflow integration is enabled. Defaults to False. """ + + tracking_uri: Optional[StrictStr] = None + """ str: MLflow tracking URI. If not set, uses MLflow's default + (MLFLOW_TRACKING_URI env var or local ./mlruns). """ + + auto_log: StrictBool = True + """ bool: Automatically log feature retrieval metadata to the active + MLflow run when get_historical_features or get_online_features is + called. Defaults to True. 
""" + + auto_log_dataset: StrictBool = False + """ bool: When True, the training DataFrame produced by + get_historical_features().to_df() is logged as an MLflow dataset + input on the active run. Defaults to False because the DataFrame + can be large. """ diff --git a/sdk/python/feast/mlflow_integration/entity_df_builder.py b/sdk/python/feast/mlflow_integration/entity_df_builder.py new file mode 100644 index 00000000000..4435c85c058 --- /dev/null +++ b/sdk/python/feast/mlflow_integration/entity_df_builder.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import logging +from typing import Optional + +import pandas as pd + +_logger = logging.getLogger(__name__) + + +class FeastMlflowEntityDfError(Exception): + """Raised when an entity DataFrame cannot be built from an MLflow run.""" + + pass + + +def get_entity_df_from_mlflow_run( + run_id: str, + tracking_uri: Optional[str] = None, + timestamp_column: str = "event_timestamp", +) -> pd.DataFrame: + """Build an entity DataFrame from an MLflow run's artifacts or params. + + Convention: the run should have an artifact named ``entity_df.parquet`` + (or ``entity_df.csv``). Alternatively, a run param + ``feast.entity_df_path`` pointing to a local/remote file path. + + Args: + run_id: The MLflow run ID. + tracking_uri: Optional MLflow tracking URI. + timestamp_column: Expected name of the timestamp column in the + entity DataFrame. + + Returns: + A ``pd.DataFrame`` suitable for passing to + ``store.get_historical_features(entity_df=...)``. + + Raises: + FeastMlflowEntityDfError: If mlflow is not installed, run not found, + or no entity data is available on the run. + """ + try: + import mlflow + from mlflow.exceptions import MlflowException + except ImportError: + raise FeastMlflowEntityDfError( + "mlflow is not installed. 
Install with: pip install feast[mlflow]" + ) + + if tracking_uri: + mlflow.set_tracking_uri(tracking_uri) + + client = mlflow.MlflowClient() + + try: + run = client.get_run(run_id) + except MlflowException as e: + raise FeastMlflowEntityDfError(f"Run '{run_id}' not found: {e}") + + # Strategy 1: artifact entity_df.parquet + df = _try_artifact(client, run_id, "entity_df.parquet", "parquet") + if df is not None: + _validate_timestamp_col(df, timestamp_column) + return df + + # Strategy 2: artifact entity_df.csv + df = _try_artifact(client, run_id, "entity_df.csv", "csv") + if df is not None: + _validate_timestamp_col(df, timestamp_column) + return df + + # Strategy 3: run param feast.entity_df_path + params = run.data.params + path = params.get("feast.entity_df_path") + if path: + try: + if path.endswith(".parquet"): + df = pd.read_parquet(path) + else: + df = pd.read_csv(path) + _validate_timestamp_col(df, timestamp_column) + return df + except FeastMlflowEntityDfError: + raise + except Exception as e: + raise FeastMlflowEntityDfError( + f"Could not load entity df from param path '{path}': {e}" + ) + + raise FeastMlflowEntityDfError( + f"No entity data found for run '{run_id}'. " + f"Expected artifact 'entity_df.parquet' or 'entity_df.csv', " + f"or param 'feast.entity_df_path'." + ) + + +def _try_artifact(client, run_id: str, artifact_name: str, fmt: str): + """Try to download and load an artifact as a DataFrame.""" + try: + local_path = client.download_artifacts(run_id, artifact_name) + if fmt == "parquet": + return pd.read_parquet(local_path) + return pd.read_csv(local_path) + except Exception: + return None + + +def _validate_timestamp_col(df: pd.DataFrame, col: str): + """Ensure the expected timestamp column exists.""" + if col not in df.columns: + raise FeastMlflowEntityDfError( + f"Entity DataFrame missing required timestamp column '{col}'. 
" + f"Available columns: {list(df.columns)}" + ) diff --git a/sdk/python/feast/mlflow_integration/logger.py b/sdk/python/feast/mlflow_integration/logger.py new file mode 100644 index 00000000000..d35fe620cfc --- /dev/null +++ b/sdk/python/feast/mlflow_integration/logger.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, List, Optional + +import pandas as pd + +if TYPE_CHECKING: + from feast.feature_service import FeatureService + +_logger = logging.getLogger(__name__) + +_mlflow = None +_mlflow_checked = False + + +def _get_mlflow(): + """Lazy-import mlflow. Returns the module or None if not installed.""" + global _mlflow, _mlflow_checked + if not _mlflow_checked: + _mlflow_checked = True + try: + import mlflow as _m + + _mlflow = _m + except ImportError: + _mlflow = None + return _mlflow + + +def log_feature_retrieval_to_mlflow( + feature_refs: List[str], + entity_count: int, + duration_seconds: float, + retrieval_type: str = "historical", + feature_service: Optional["FeatureService"] = None, + feature_service_name: Optional[str] = None, + project: Optional[str] = None, + tracking_uri: Optional[str] = None, +) -> bool: + """Log feature retrieval metadata to the active MLflow run. + + This function is a no-op when: + - mlflow is not installed + - no MLflow run is currently active + + Args: + feature_refs: List of feature references (e.g. ["fv:feat1", "fv:feat2"]). + entity_count: Number of entity rows in the request. + duration_seconds: Wall-clock time for the retrieval in seconds. + retrieval_type: Either "historical" or "online". + feature_service: Optional FeatureService object used for the retrieval. + feature_service_name: Optional feature service name (resolved from refs). + project: Optional Feast project name. + tracking_uri: Optional MLflow tracking URI override. + + Returns: + True if metadata was logged successfully, False otherwise. 
+ """ + mlflow = _get_mlflow() + if mlflow is None: + return False + + if tracking_uri: + mlflow.set_tracking_uri(tracking_uri) + + active_run = mlflow.active_run() + if active_run is None: + return False + + try: + client = mlflow.MlflowClient() + run_id = active_run.info.run_id + + # Tags (immutable metadata about the run's Feast context) + if project: + client.set_tag(run_id, "feast.project", project) + client.set_tag(run_id, "feast.retrieval_type", retrieval_type) + + # Resolve feature service name + fs_name = None + if feature_service is not None: + fs_name = feature_service.name + elif feature_service_name is not None: + fs_name = feature_service_name + if fs_name: + client.set_tag(run_id, "feast.feature_service", fs_name) + + # Extract unique feature view names from refs + fv_names = sorted({ref.split(":")[0] for ref in feature_refs if ":" in ref}) + if fv_names: + client.set_tag(run_id, "feast.feature_views", ",".join(fv_names)) + + # Params (input configuration for this retrieval) + # MLflow params have a 500-char limit; truncate feature refs + refs_str = ",".join(feature_refs) + if len(refs_str) > 490: + refs_str = refs_str[:487] + "..." + client.log_param(run_id, "feast.feature_refs", refs_str) + client.log_param(run_id, "feast.entity_count", str(entity_count)) + client.log_param(run_id, "feast.feature_count", str(len(feature_refs))) + + # Metrics (measured values) + client.log_metric( + run_id, "feast.retrieval_duration_sec", round(duration_seconds, 4) + ) + + return True + except Exception as e: + _logger.warning("Failed to log feature retrieval to MLflow: %s", e) + return False + + +def log_training_dataset_to_mlflow( + df: pd.DataFrame, + dataset_name: str = "feast_training_data", + source: Optional[str] = None, +) -> bool: + """Log a training DataFrame as an MLflow dataset input on the active run. 
+ + This enables dataset versioning: each MLflow run records exactly which + training data was used, bridging the gap that Feast does not + version-control datasets on its own. + + Args: + df: The training DataFrame (output of get_historical_features().to_df()). + dataset_name: Name for the dataset in MLflow. + source: Optional source description (e.g. feature service name). + + Returns: + True if the dataset was logged, False otherwise. + """ + mlflow = _get_mlflow() + if mlflow is None: + return False + + active_run = mlflow.active_run() + if active_run is None: + return False + + try: + dataset = mlflow.data.from_pandas( + df, + name=dataset_name, + source=source or "feast.get_historical_features", + ) + mlflow.log_input(dataset, context="training") + return True + except Exception as e: + _logger.warning("Failed to log training dataset to MLflow: %s", e) + return False diff --git a/sdk/python/feast/mlflow_integration/model_resolver.py b/sdk/python/feast/mlflow_integration/model_resolver.py new file mode 100644 index 00000000000..5d38d2e61a1 --- /dev/null +++ b/sdk/python/feast/mlflow_integration/model_resolver.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import json +import logging +import re +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from feast import FeatureStore + +_logger = logging.getLogger(__name__) + + +class FeastMlflowModelResolutionError(Exception): + """Raised when a model URI cannot be resolved to a feature service.""" + + pass + + +def resolve_feature_service_from_model_uri( + model_uri: str, + store: Optional["FeatureStore"] = None, +) -> str: + """Resolve the Feast feature service name for a given MLflow model URI. + + Resolution order: + 1. Model tag ``feast.feature_service`` (explicit override). + 2. Naming convention: ``{model_name}_v{version}``. + + Args: + model_uri: MLflow model URI, e.g. ``models:/fraud-model/Production`` + or ``models:/fraud-model/1``. + store: Optional FeatureStore instance. 
When provided the resolved + feature service is validated against the registry, and if the + model has an artifact ``required_features.json`` the feature + list is checked for consistency. + + Returns: + Feature service name string. + + Raises: + FeastMlflowModelResolutionError: If mlflow is not installed, URI is + invalid, or validation against the store fails. + """ + try: + import mlflow + from mlflow.exceptions import MlflowException + except ImportError: + raise FeastMlflowModelResolutionError( + "mlflow is not installed. Install with: pip install feast[mlflow]" + ) + + pattern = r"^models:/([^/]+)/(.+)$" + match = re.match(pattern, model_uri) + if not match: + raise FeastMlflowModelResolutionError( + f"Invalid model_uri format: '{model_uri}'. " + f"Expected 'models://'." + ) + + model_name, version_or_alias = match.group(1), match.group(2) + client = mlflow.MlflowClient() + + try: + if version_or_alias.isdigit(): + mv = client.get_model_version(model_name, version_or_alias) + else: + mv = client.get_model_version_by_alias(model_name, version_or_alias) + except MlflowException as e: + raise FeastMlflowModelResolutionError( + f"Could not resolve model '{model_uri}': {e}" + ) + + tags = mv.tags or {} + if "feast.feature_service" in tags: + fs_name = tags["feast.feature_service"] + else: + fs_name = f"{model_name}_v{mv.version}" + + if store is not None: + _validate_feature_service(store, fs_name, client, mv) + + return fs_name + + +def _validate_feature_service(store, fs_name, client, model_version): + """Validate the feature service exists and features match if artifact present.""" + try: + fs = store.get_feature_service(fs_name) + except Exception: + raise FeastMlflowModelResolutionError( + f"Feature service '{fs_name}' not found in the Feast registry." 
+ ) + + try: + local_path = client.download_artifacts( + model_version.run_id, "required_features.json" + ) + with open(local_path) as f: + expected_features = json.load(f) + + actual_features = [] + for proj in fs.feature_view_projections: + for feat in proj.features: + actual_features.append(f"{proj.name}:{feat.name}") + + expected_set = set(expected_features) + actual_set = set(actual_features) + + if expected_set != actual_set: + missing = expected_set - actual_set + extra = actual_set - expected_set + raise FeastMlflowModelResolutionError( + f"Feature mismatch for service '{fs_name}'. " + f"Missing: {missing}, Extra: {extra}" + ) + except FeastMlflowModelResolutionError: + raise + except Exception: + # No artifact or download failed — skip validation silently + pass diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 208307dc5d5..7a26d9a38aa 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -339,6 +339,9 @@ class RepoConfig(FeastBaseModel): openlineage_config: Optional[OpenLineageConfig] = Field(None, alias="openlineage") """ OpenLineageConfig: Configuration for OpenLineage data lineage integration (optional). """ + mlflow_config: Optional[Any] = Field(None, alias="mlflow") + """ MlflowConfig: Configuration for MLflow experiment tracking integration (optional). """ + def __init__(self, **data: Any): super().__init__(**data) @@ -379,6 +382,11 @@ def __init__(self, **data: Any): if "openlineage" in data: self.openlineage_config = data["openlineage"] + # Initialize MLflow configuration + self._mlflow = None + if "mlflow" in data: + self.mlflow_config = data["mlflow"] + if self.entity_key_serialization_version < 3: warnings.warn( "The serialization version below 3 are deprecated. 
" @@ -478,6 +486,18 @@ def openlineage(self) -> Optional[OpenLineageConfig]: self._openlineage = self.openlineage_config return self._openlineage + @property + def mlflow(self): + """Get the MLflow configuration.""" + if not self._mlflow: + if isinstance(self.mlflow_config, Dict): + from feast.mlflow_integration.config import MlflowConfig + + self._mlflow = MlflowConfig(**self.mlflow_config) + elif self.mlflow_config: + self._mlflow = self.mlflow_config + return self._mlflow + @model_validator(mode="before") def _validate_auth_config(cls, values: Any) -> Any: from feast.permissions.auth_model import AuthConfig diff --git a/sdk/python/feast/ui_server.py b/sdk/python/feast/ui_server.py index 99a4abc9c81..da889dd9264 100644 --- a/sdk/python/feast/ui_server.py +++ b/sdk/python/feast/ui_server.py @@ -115,6 +115,75 @@ def health(): else Response(status_code=status.HTTP_503_SERVICE_UNAVAILABLE) ) + @app.get("/api/mlflow-runs") + def get_mlflow_runs(): + """Return MLflow runs linked to this Feast project via auto-logging.""" + mlflow_cfg = getattr(store.config, "mlflow", None) + if not mlflow_cfg or not mlflow_cfg.enabled: + return {"runs": [], "mlflow_uri": None} + + try: + import mlflow + + if mlflow_cfg.tracking_uri: + mlflow.set_tracking_uri(mlflow_cfg.tracking_uri) + + client = mlflow.MlflowClient() + experiments = client.search_experiments() + experiment_ids = [e.experiment_id for e in experiments] + + if not experiment_ids: + return {"runs": [], "mlflow_uri": mlflow_cfg.tracking_uri} + + runs = client.search_runs( + experiment_ids=experiment_ids, + filter_string="tags.`feast.retrieval_type` != ''", + max_results=50, + order_by=["start_time DESC"], + ) + + result = [] + tracking_uri = mlflow_cfg.tracking_uri or "" + mlflow_ui_base = tracking_uri if tracking_uri else "http://127.0.0.1:5000" + for run in runs: + run_tags = run.data.tags + run_params = run.data.params + result.append( + { + "run_id": run.info.run_id, + "run_name": run.info.run_name, + "status": 
run.info.status, + "start_time": run.info.start_time, + "feature_service": run_tags.get( + "feast.feature_service" + ), + "feature_views": run_tags.get( + "feast.feature_views", "" + ).split(","), + "feature_refs": run_params.get( + "feast.feature_refs", "" + ).split(","), + "retrieval_type": run_tags.get( + "feast.retrieval_type" + ), + "entity_count": run_params.get("feast.entity_count"), + "mlflow_url": ( + f"{mlflow_ui_base}/#/experiments/" + f"{run.info.experiment_id}/runs/{run.info.run_id}" + ), + } + ) + + return {"runs": result, "mlflow_uri": tracking_uri} + except ImportError: + return { + "runs": [], + "mlflow_uri": None, + "error": "mlflow is not installed", + } + except Exception as e: + return {"runs": [], "mlflow_uri": None, "error": str(e)} + # For all other paths (such as paths that would otherwise be handled by react router), pass to React @app.api_route("/p/{path_name:path}", methods=["GET"]) def catch_all(): diff --git a/sdk/python/tests/unit/test_mlflow_integration.py b/sdk/python/tests/unit/test_mlflow_integration.py new file mode 100644 index 00000000000..c8b99709156 --- /dev/null +++ b/sdk/python/tests/unit/test_mlflow_integration.py @@ -0,0 +1,400 @@ +import os +import tempfile +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + + +# --------------------------------------------------------------------------- +# Tests for logger.py +# --------------------------------------------------------------------------- +class TestLogFeatureRetrieval: + """Tests for feast.mlflow_integration.logger.log_feature_retrieval_to_mlflow.""" + + def setup_method(self): + import feast.mlflow_integration.logger as mod + + mod._mlflow = None + mod._mlflow_checked = False + + @patch("feast.mlflow_integration.logger._get_mlflow") + def test_logs_tags_params_metrics_when_active_run(self, mock_get_mlflow): + mock_mlflow = MagicMock() + mock_run = MagicMock() + mock_run.info.run_id = "run_abc123" + mock_mlflow.active_run.return_value = 
mock_run + mock_get_mlflow.return_value = mock_mlflow + + from feast.mlflow_integration.logger import log_feature_retrieval_to_mlflow + + result = log_feature_retrieval_to_mlflow( + feature_refs=["fv:feat1", "fv:feat2"], + entity_count=100, + duration_seconds=1.5, + retrieval_type="historical", + project="my_project", + feature_service_name="my_svc", + ) + + assert result is True + client = mock_mlflow.MlflowClient() + client.set_tag.assert_any_call("run_abc123", "feast.project", "my_project") + client.set_tag.assert_any_call("run_abc123", "feast.retrieval_type", "historical") + client.set_tag.assert_any_call("run_abc123", "feast.feature_service", "my_svc") + client.set_tag.assert_any_call("run_abc123", "feast.feature_views", "fv") + client.log_param.assert_any_call("run_abc123", "feast.feature_refs", "fv:feat1,fv:feat2") + client.log_param.assert_any_call("run_abc123", "feast.entity_count", "100") + client.log_param.assert_any_call("run_abc123", "feast.feature_count", "2") + client.log_metric.assert_any_call("run_abc123", "feast.retrieval_duration_sec", 1.5) + + @patch("feast.mlflow_integration.logger._get_mlflow") + def test_logs_feature_service_from_object(self, mock_get_mlflow): + mock_mlflow = MagicMock() + mock_run = MagicMock() + mock_run.info.run_id = "run_xyz" + mock_mlflow.active_run.return_value = mock_run + mock_get_mlflow.return_value = mock_mlflow + + mock_fs = MagicMock() + mock_fs.name = "fraud_detection_service" + + from feast.mlflow_integration.logger import log_feature_retrieval_to_mlflow + + result = log_feature_retrieval_to_mlflow( + feature_refs=["fv:feat1"], + entity_count=10, + duration_seconds=0.3, + retrieval_type="online", + feature_service=mock_fs, + ) + + assert result is True + client = mock_mlflow.MlflowClient() + client.set_tag.assert_any_call("run_xyz", "feast.feature_service", "fraud_detection_service") + + @patch("feast.mlflow_integration.logger._get_mlflow") + def test_noop_when_no_active_run(self, mock_get_mlflow): + mock_mlflow 
= MagicMock() + mock_mlflow.active_run.return_value = None + mock_get_mlflow.return_value = mock_mlflow + + from feast.mlflow_integration.logger import log_feature_retrieval_to_mlflow + + result = log_feature_retrieval_to_mlflow( + feature_refs=["fv:feat1"], entity_count=10, duration_seconds=0.5, + ) + assert result is False + + @patch("feast.mlflow_integration.logger._get_mlflow") + def test_noop_when_mlflow_not_installed(self, mock_get_mlflow): + mock_get_mlflow.return_value = None + + from feast.mlflow_integration.logger import log_feature_retrieval_to_mlflow + + result = log_feature_retrieval_to_mlflow( + feature_refs=["fv:feat1"], entity_count=10, duration_seconds=0.5, + ) + assert result is False + + @patch("feast.mlflow_integration.logger._get_mlflow") + def test_truncates_long_feature_refs(self, mock_get_mlflow): + mock_mlflow = MagicMock() + mock_run = MagicMock() + mock_run.info.run_id = "run_trunc" + mock_mlflow.active_run.return_value = mock_run + mock_get_mlflow.return_value = mock_mlflow + + from feast.mlflow_integration.logger import log_feature_retrieval_to_mlflow + + long_refs = [f"feature_view:feature_{i}" for i in range(100)] + result = log_feature_retrieval_to_mlflow( + feature_refs=long_refs, entity_count=50, duration_seconds=2.0, + ) + + assert result is True + client = mock_mlflow.MlflowClient() + logged_refs = [c for c in client.log_param.call_args_list if c[0][1] == "feast.feature_refs"] + assert len(logged_refs) == 1 + assert len(logged_refs[0][0][2]) <= 500 + + @patch("feast.mlflow_integration.logger._get_mlflow") + def test_handles_mlflow_exception_gracefully(self, mock_get_mlflow): + mock_mlflow = MagicMock() + mock_run = MagicMock() + mock_run.info.run_id = "run_err" + mock_mlflow.active_run.return_value = mock_run + mock_mlflow.MlflowClient().set_tag.side_effect = Exception("boom") + mock_get_mlflow.return_value = mock_mlflow + + from feast.mlflow_integration.logger import log_feature_retrieval_to_mlflow + + result = 
log_feature_retrieval_to_mlflow( + feature_refs=["fv:feat1"], entity_count=5, duration_seconds=0.1, + ) + assert result is False + + @patch("feast.mlflow_integration.logger._get_mlflow") + def test_extracts_feature_view_names(self, mock_get_mlflow): + mock_mlflow = MagicMock() + mock_run = MagicMock() + mock_run.info.run_id = "run_fv" + mock_mlflow.active_run.return_value = mock_run + mock_get_mlflow.return_value = mock_mlflow + + from feast.mlflow_integration.logger import log_feature_retrieval_to_mlflow + + result = log_feature_retrieval_to_mlflow( + feature_refs=["fv_a:feat1", "fv_b:feat2", "fv_a:feat3"], + entity_count=10, + duration_seconds=0.1, + ) + + assert result is True + client = mock_mlflow.MlflowClient() + client.set_tag.assert_any_call("run_fv", "feast.feature_views", "fv_a,fv_b") + + +# --------------------------------------------------------------------------- +# Tests for model_resolver.py +# --------------------------------------------------------------------------- +class TestResolveFeatureService: + + def test_invalid_uri_raises(self): + from feast.mlflow_integration.model_resolver import ( + FeastMlflowModelResolutionError, + resolve_feature_service_from_model_uri, + ) + + with pytest.raises(FeastMlflowModelResolutionError, match="Invalid model_uri"): + resolve_feature_service_from_model_uri("bad-uri") + + @patch("mlflow.MlflowClient") + def test_resolves_from_tag(self, mock_client_cls): + mock_client = mock_client_cls.return_value + mock_mv = MagicMock() + mock_mv.tags = {"feast.feature_service": "my_fraud_svc"} + mock_mv.version = "1" + mock_client.get_model_version.return_value = mock_mv + + from feast.mlflow_integration.model_resolver import resolve_feature_service_from_model_uri + + assert resolve_feature_service_from_model_uri("models:/fraud-model/1") == "my_fraud_svc" + + @patch("mlflow.MlflowClient") + def test_falls_back_to_convention(self, mock_client_cls): + mock_client = mock_client_cls.return_value + mock_mv = MagicMock() + 
mock_mv.tags = {} + mock_mv.version = "3" + mock_client.get_model_version.return_value = mock_mv + + from feast.mlflow_integration.model_resolver import resolve_feature_service_from_model_uri + + assert resolve_feature_service_from_model_uri("models:/fraud-model/3") == "fraud-model_v3" + + @patch("mlflow.MlflowClient") + def test_resolves_alias(self, mock_client_cls): + mock_client = mock_client_cls.return_value + mock_mv = MagicMock() + mock_mv.tags = {"feast.feature_service": "prod_features"} + mock_mv.version = "5" + mock_client.get_model_version_by_alias.return_value = mock_mv + + from feast.mlflow_integration.model_resolver import resolve_feature_service_from_model_uri + + assert resolve_feature_service_from_model_uri("models:/fraud-model/Production") == "prod_features" + + @patch("mlflow.MlflowClient") + def test_validates_against_store(self, mock_client_cls): + mock_client = mock_client_cls.return_value + mock_mv = MagicMock() + mock_mv.tags = {"feast.feature_service": "nonexistent_svc"} + mock_mv.version = "1" + mock_mv.run_id = "run_123" + mock_client.get_model_version.return_value = mock_mv + mock_client.download_artifacts.side_effect = Exception("no artifact") + + mock_store = MagicMock() + mock_store.get_feature_service.side_effect = Exception("not found") + + from feast.mlflow_integration.model_resolver import ( + FeastMlflowModelResolutionError, + resolve_feature_service_from_model_uri, + ) + + with pytest.raises(FeastMlflowModelResolutionError, match="not found in the Feast registry"): + resolve_feature_service_from_model_uri("models:/fraud-model/1", store=mock_store) + + +# --------------------------------------------------------------------------- +# Tests for entity_df_builder.py +# --------------------------------------------------------------------------- +class TestGetEntityDfFromMlflowRun: + + @patch("mlflow.MlflowClient") + def test_loads_parquet_artifact(self, mock_client_cls): + mock_client = mock_client_cls.return_value + mock_run = 
MagicMock() + mock_run.data.params = {} + mock_client.get_run.return_value = mock_run + + with tempfile.TemporaryDirectory() as tmpdir: + parquet_path = os.path.join(tmpdir, "entity_df.parquet") + df = pd.DataFrame({ + "user_id": [1, 2, 3], + "event_timestamp": pd.to_datetime(["2026-01-01", "2026-01-02", "2026-01-03"]), + }) + df.to_parquet(parquet_path) + mock_client.download_artifacts.return_value = parquet_path + + from feast.mlflow_integration.entity_df_builder import get_entity_df_from_mlflow_run + + result = get_entity_df_from_mlflow_run("run_abc") + assert len(result) == 3 + assert "event_timestamp" in result.columns + + @patch("mlflow.MlflowClient") + def test_raises_when_no_entity_data(self, mock_client_cls): + mock_client = mock_client_cls.return_value + mock_run = MagicMock() + mock_run.data.params = {} + mock_client.get_run.return_value = mock_run + mock_client.download_artifacts.side_effect = Exception("not found") + + from feast.mlflow_integration.entity_df_builder import ( + FeastMlflowEntityDfError, get_entity_df_from_mlflow_run, + ) + + with pytest.raises(FeastMlflowEntityDfError, match="No entity data found"): + get_entity_df_from_mlflow_run("run_abc") + + @patch("mlflow.MlflowClient") + def test_raises_when_timestamp_col_missing(self, mock_client_cls): + mock_client = mock_client_cls.return_value + mock_run = MagicMock() + mock_run.data.params = {} + mock_client.get_run.return_value = mock_run + + with tempfile.TemporaryDirectory() as tmpdir: + parquet_path = os.path.join(tmpdir, "entity_df.parquet") + pd.DataFrame({"user_id": [1, 2], "ts": [1, 2]}).to_parquet(parquet_path) + mock_client.download_artifacts.return_value = parquet_path + + from feast.mlflow_integration.entity_df_builder import ( + FeastMlflowEntityDfError, get_entity_df_from_mlflow_run, + ) + + with pytest.raises(FeastMlflowEntityDfError, match="missing required timestamp"): + get_entity_df_from_mlflow_run("run_abc") + + @patch("mlflow.MlflowClient") + def 
test_loads_from_param_path(self, mock_client_cls): + mock_client = mock_client_cls.return_value + mock_client.download_artifacts.side_effect = Exception("no artifact") + + with tempfile.TemporaryDirectory() as tmpdir: + parquet_path = os.path.join(tmpdir, "entities.parquet") + pd.DataFrame({ + "driver_id": [10, 20], + "event_timestamp": pd.to_datetime(["2026-03-01", "2026-03-02"]), + }).to_parquet(parquet_path) + + mock_run = MagicMock() + mock_run.data.params = {"feast.entity_df_path": parquet_path} + mock_client.get_run.return_value = mock_run + + from feast.mlflow_integration.entity_df_builder import get_entity_df_from_mlflow_run + + result = get_entity_df_from_mlflow_run("run_param") + assert len(result) == 2 + + +# --------------------------------------------------------------------------- +# Tests for logger.py (log_training_dataset_to_mlflow) +# --------------------------------------------------------------------------- +class TestLogTrainingDataset: + """Tests for feast.mlflow_integration.logger.log_training_dataset_to_mlflow.""" + + def setup_method(self): + import feast.mlflow_integration.logger as mod + + mod._mlflow = None + mod._mlflow_checked = False + + @patch("feast.mlflow_integration.logger._get_mlflow") + def test_logs_dataset_when_active_run(self, mock_get_mlflow): + mock_mlflow = MagicMock() + mock_run = MagicMock() + mock_mlflow.active_run.return_value = mock_run + mock_dataset = MagicMock() + mock_mlflow.data.from_pandas.return_value = mock_dataset + mock_get_mlflow.return_value = mock_mlflow + + from feast.mlflow_integration.logger import log_training_dataset_to_mlflow + + df = pd.DataFrame({"user_id": [1, 2], "feature": [0.5, 0.8]}) + result = log_training_dataset_to_mlflow(df, dataset_name="my_data") + + assert result is True + mock_mlflow.data.from_pandas.assert_called_once() + mock_mlflow.log_input.assert_called_once_with(mock_dataset, context="training") + + @patch("feast.mlflow_integration.logger._get_mlflow") + def test_noop_when_no_active_run(self, mock_get_mlflow):
+ mock_mlflow = MagicMock() + mock_mlflow.active_run.return_value = None + mock_get_mlflow.return_value = mock_mlflow + + from feast.mlflow_integration.logger import log_training_dataset_to_mlflow + + df = pd.DataFrame({"a": [1]}) + result = log_training_dataset_to_mlflow(df) + assert result is False + + @patch("feast.mlflow_integration.logger._get_mlflow") + def test_noop_when_mlflow_not_installed(self, mock_get_mlflow): + mock_get_mlflow.return_value = None + + from feast.mlflow_integration.logger import log_training_dataset_to_mlflow + + df = pd.DataFrame({"a": [1]}) + result = log_training_dataset_to_mlflow(df) + assert result is False + + @patch("feast.mlflow_integration.logger._get_mlflow") + def test_handles_exception_gracefully(self, mock_get_mlflow): + mock_mlflow = MagicMock() + mock_mlflow.active_run.return_value = MagicMock() + mock_mlflow.data.from_pandas.side_effect = Exception("dataset error") + mock_get_mlflow.return_value = mock_mlflow + + from feast.mlflow_integration.logger import log_training_dataset_to_mlflow + + df = pd.DataFrame({"a": [1]}) + result = log_training_dataset_to_mlflow(df) + assert result is False + + +# --------------------------------------------------------------------------- +# Tests for config.py +# --------------------------------------------------------------------------- +class TestMlflowConfig: + + def test_defaults(self): + from feast.mlflow_integration.config import MlflowConfig + + cfg = MlflowConfig() + assert cfg.enabled is False + assert cfg.tracking_uri is None + assert cfg.auto_log is True + assert cfg.auto_log_dataset is False + + def test_from_dict(self): + from feast.mlflow_integration.config import MlflowConfig + + cfg = MlflowConfig(enabled=True, tracking_uri="http://localhost:5000", auto_log=False, auto_log_dataset=True) + assert cfg.enabled is True + assert cfg.tracking_uri == "http://localhost:5000" + assert cfg.auto_log is False + assert cfg.auto_log_dataset is True diff --git 
a/ui/src/components/RegistryVisualization.tsx b/ui/src/components/RegistryVisualization.tsx index d3479078618..fcbf0509a11 100644 --- a/ui/src/components/RegistryVisualization.tsx +++ b/ui/src/components/RegistryVisualization.tsx @@ -25,6 +25,7 @@ import { } from "@elastic/eui"; import { FEAST_FCO_TYPES } from "../parsers/types"; import { EntityRelation } from "../parsers/parseEntityRelationships"; +import { MlflowRunData } from "../queries/useLoadMlflowRuns"; import { feast } from "../protos"; import { useTheme } from "../contexts/ThemeContext"; import { @@ -79,6 +80,8 @@ const getNodeColor = (type: FEAST_FCO_TYPES) => { return "#ff8000"; // Orange case FEAST_FCO_TYPES.dataSource: return "#cc0000"; // Red + case FEAST_FCO_TYPES.mlflowRun: + return "#0194e2"; // MLflow brand blue default: return "#666666"; // Gray } @@ -94,6 +97,8 @@ const getLightNodeColor = (type: FEAST_FCO_TYPES) => { return "#fff2e6"; // Light orange case FEAST_FCO_TYPES.dataSource: return "#ffe6e6"; // Light red + case FEAST_FCO_TYPES.mlflowRun: + return "#e6f6fd"; // Light MLflow blue default: return "#f0f0f0"; // Light gray } @@ -109,6 +114,8 @@ const getNodeIcon = (type: FEAST_FCO_TYPES) => { return "▲"; // Triangle for entity case FEAST_FCO_TYPES.dataSource: return "◆"; // Diamond for data source + case FEAST_FCO_TYPES.mlflowRun: + return "⬡"; // Hexagon for MLflow run default: return "●"; // Default circle } @@ -125,6 +132,10 @@ const CustomNode = ({ data }: { data: NodeData }) => { const hasVersion = data.versionNumber != null && data.versionNumber > 1; const handleClick = () => { + if (data.type === FEAST_FCO_TYPES.mlflowRun && data.metadata?.mlflow_url) { + window.open(data.metadata.mlflow_url, "_blank", "noopener,noreferrer"); + return; + } let path; switch (data.type) { case FEAST_FCO_TYPES.dataSource: @@ -183,7 +194,9 @@ const CustomNode = ({ data }: { data: NodeData }) => { zIndex: 5, }} > - View Details + {data.type === FEAST_FCO_TYPES.mlflowRun + ? 
"Open in MLflow ↗" + : "View Details"} )} @@ -398,6 +411,7 @@ const getLayoutedElements = ( [FEAST_FCO_TYPES.entity]: [], [FEAST_FCO_TYPES.featureView]: [], [FEAST_FCO_TYPES.featureService]: [], + [FEAST_FCO_TYPES.mlflowRun]: [], }; isolatedNodes.forEach((node) => { @@ -454,6 +468,7 @@ const Legend = () => { { type: FEAST_FCO_TYPES.featureView, label: "Feature View" }, { type: FEAST_FCO_TYPES.entity, label: "Entity" }, { type: FEAST_FCO_TYPES.dataSource, label: "Data Source" }, + { type: FEAST_FCO_TYPES.mlflowRun, label: "MLflow Run" }, ]; const isDarkMode = colorMode === "dark"; @@ -535,6 +550,7 @@ const registryToFlow = ( relationships: EntityRelation[], permissions?: any[], versionHistory?: feast.core.IFeatureViewVersionRecord[], + mlflowRuns?: MlflowRunData[], ) => { const nodes: Node[] = []; const edges: Edge[] = []; @@ -743,6 +759,55 @@ const registryToFlow = ( }); }); + if (mlflowRuns && mlflowRuns.length > 0) { + mlflowRuns.forEach((run) => { + const runLabel = run.run_name || run.run_id.substring(0, 8); + nodes.push({ + id: `mlflow-${run.run_id}`, + type: "custom", + data: { + label: runLabel, + type: FEAST_FCO_TYPES.mlflowRun, + metadata: { + mlflow_url: run.mlflow_url, + retrieval_type: run.retrieval_type, + status: run.status, + run_id: run.run_id, + }, + }, + position: { x: 0, y: 0 }, + }); + + if (run.feature_service) { + const fsNodeId = `fs-${run.feature_service}`; + const fsNodeExists = nodes.some((n) => n.id === fsNodeId); + if (fsNodeExists) { + edges.push({ + id: `edge-mlflow-${run.run_id}`, + source: fsNodeId, + sourceHandle: "source", + target: `mlflow-${run.run_id}`, + targetHandle: "target", + animated: true, + style: { + strokeWidth: 3, + stroke: "#0194e2", + strokeDasharray: "10 5", + animation: "dataflow 2s linear infinite", + }, + type: "smoothstep", + markerEnd: { + type: MarkerType.ArrowClosed, + width: 20, + height: 20, + color: "#0194e2", + }, + }); + } + } + }); + } + return { nodes, edges }; }; @@ -756,6 +821,8 @@ const 
getNodePrefix = (type: FEAST_FCO_TYPES) => { return "entity"; case FEAST_FCO_TYPES.dataSource: return "ds"; + case FEAST_FCO_TYPES.mlflowRun: + return "mlflow"; default: return "unknown"; } @@ -766,7 +833,8 @@ interface RegistryVisualizationProps { relationships: EntityRelation[]; indirectRelationships: EntityRelation[]; filterNode?: { type: FEAST_FCO_TYPES; name: string }; - permissions?: any[]; // Add permissions field + permissions?: any[]; + mlflowRuns?: MlflowRunData[]; } const RegistryVisualization: React.FC = ({ @@ -775,6 +843,7 @@ const RegistryVisualization: React.FC = ({ indirectRelationships, filterNode, permissions, + mlflowRuns, }) => { const [nodes, setNodes, onNodesChange] = useNodesState([]); const [edges, setEdges, onEdgesChange] = useEdgesState([]); @@ -851,6 +920,7 @@ const RegistryVisualization: React.FC = ({ validRelationships, permissions, versionRecords as feast.core.IFeatureViewVersionRecord[] | undefined, + mlflowRuns, ); const { nodes: layoutedNodes, edges: layoutedEdges } = @@ -873,6 +943,7 @@ const RegistryVisualization: React.FC = ({ showIsolatedNodes, filterNode, permissions, + mlflowRuns, setNodes, setEdges, ]); diff --git a/ui/src/components/RegistryVisualizationTab.tsx b/ui/src/components/RegistryVisualizationTab.tsx index ebc77604322..023f1515457 100644 --- a/ui/src/components/RegistryVisualizationTab.tsx +++ b/ui/src/components/RegistryVisualizationTab.tsx @@ -10,6 +10,7 @@ import { EuiFlexItem, } from "@elastic/eui"; import useLoadRegistry from "../queries/useLoadRegistry"; +import useLoadMlflowRuns from "../queries/useLoadMlflowRuns"; import RegistryPathContext from "../contexts/RegistryPathContext"; import RegistryVisualization from "./RegistryVisualization"; import { FEAST_FCO_TYPES } from "../parsers/types"; @@ -22,6 +23,7 @@ const RegistryVisualizationTab = () => { registryUrl, projectName, ); + const { data: mlflowData } = useLoadMlflowRuns(); const [selectedObjectType, setSelectedObjectType] = useState(""); const 
[selectedObjectName, setSelectedObjectName] = useState(""); const [selectedPermissionAction, setSelectedPermissionAction] = useState(""); @@ -92,6 +94,7 @@ const RegistryVisualizationTab = () => { { value: "entity", text: "Entity" }, { value: "featureView", text: "Feature View" }, { value: "featureService", text: "Feature Service" }, + { value: "mlflowRun", text: "MLflow Run" }, ]} value={selectedObjectType} onChange={(e) => { @@ -162,6 +165,7 @@ const RegistryVisualizationTab = () => { } : undefined } + mlflowRuns={mlflowData?.runs} /> )} diff --git a/ui/src/hooks/useFCOExploreSuggestions.ts b/ui/src/hooks/useFCOExploreSuggestions.ts index 43a0e1bea3f..19b5d89224a 100644 --- a/ui/src/hooks/useFCOExploreSuggestions.ts +++ b/ui/src/hooks/useFCOExploreSuggestions.ts @@ -22,6 +22,7 @@ const FCO_TO_URL_NAME_MAP: Record = { entity: "/entity", featureView: "/feature-view", featureService: "/feature-service", + mlflowRun: "/mlflow-run", }; const createSearchLink = ( diff --git a/ui/src/parsers/types.ts b/ui/src/parsers/types.ts index 1e515f23f34..98d4f2651ed 100644 --- a/ui/src/parsers/types.ts +++ b/ui/src/parsers/types.ts @@ -3,6 +3,7 @@ enum FEAST_FCO_TYPES { entity = "entity", featureView = "featureView", featureService = "featureService", + mlflowRun = "mlflowRun", } export { FEAST_FCO_TYPES }; diff --git a/ui/src/queries/useLoadMlflowRuns.ts b/ui/src/queries/useLoadMlflowRuns.ts new file mode 100644 index 00000000000..bd374223add --- /dev/null +++ b/ui/src/queries/useLoadMlflowRuns.ts @@ -0,0 +1,44 @@ +import { useQuery } from "react-query"; + +export interface MlflowRunData { + run_id: string; + run_name: string; + status: string; + start_time: number; + feature_service: string | null; + feature_views: string[]; + feature_refs: string[]; + retrieval_type: string | null; + entity_count: string | null; + mlflow_url: string; +} + +interface MlflowRunsResponse { + runs: MlflowRunData[]; + mlflow_uri: string | null; + error?: string; +} + +const useLoadMlflowRuns = () 
=> { + return useQuery( + "mlflow-runs", + () => { + return fetch("/api/mlflow-runs") + .then((res) => { + if (!res.ok) { + return { runs: [], mlflow_uri: null }; + } + return res.json(); + }) + .catch(() => { + return { runs: [], mlflow_uri: null }; + }); + }, + { + staleTime: 30000, + retry: false, + }, + ); +}; + +export default useLoadMlflowRuns;