FlowLet: Conditional 3D Brain MRI Synthesis using Wavelet Flow Matching

Overview

FlowLet is a conditional generative framework that synthesizes high-fidelity, age-conditioned 3D brain Magnetic Resonance Images (MRIs). It is designed to address key challenges in neuroimaging, particularly for applications like Brain Age Prediction (BAP), which require large, diverse, and age-balanced datasets.

Existing generative methods often struggle with the high dimensionality of MRI data, leading to slow inference, compression artifacts, or insufficient conditioning. FlowLet mitigates the "generative modeling trilemma" (the trade-off between sample quality, diversity, and speed) by integrating Flow Matching (FM) within an invertible 3D Haar wavelet domain. This approach preserves fine anatomical details and supports diversity in age-specific synthesis without the artifacts of learned compression, while allowing the model to generate samples in very few steps.

This repository contains the official PyTorch implementation for the paper: "FlowLet: Conditional 3D Brain MRI Synthesis using Wavelet Flow Matching".

FlowLet Architecture Diagram

Key Features

  • Wavelet-Based Flow Matching: Implements multiple Flow Matching formulations (RFM, CFM, VP, Trigonometric) in the 3D Haar wavelet domain for efficient and stable generative modeling.
  • Advanced Conditional Synthesis: Generates 3D brain MRIs conditioned on specified variables (e.g., Age) using a dual mechanism:
    • FiLM (Feature-wise Linear Modulation) in residual blocks for global, feature-wise control.
    • Spatial Conditioning in transformer blocks for spatially-aware, fine-grained anatomical conditioning with Cross-Attention.
  • High-Fidelity 3D Output: Designed to generate high-resolution volumetric NIfTI (.nii.gz) brain images that preserve anatomical coherence.
  • Efficient and Modular U-Net: Employs a robust 3D U-Net architecture built with modular blocks, with optional support for xformers for memory-efficient attention.
  • Comprehensive & Reproducible Workflow: Includes scripts for the entire pipeline: data preparation, training, generation, and quantitative evaluation (FID, MMD, MS-SSIM).

Flow Matching Formulations

FlowLet supports several flow matching strategies, allowing for a systematic evaluation of how trajectory curvature impacts training stability and synthesis quality. Each formulation defines a different path and target velocity field between noise and data.

  • Rectified Flow Matching (RFM): Performs a simple linear interpolation between noise and data. The straight-line path and constant velocity field promote stable training and produce high-quality, coherent anatomical structures.
  • Conditional Flow Matching (CFM): Also uses a linear path, but defines a time-dependent target velocity that points from the current state x_t to the data x_1.
  • Variance-Preserving (VP) Diffusion Matching: Defines a non-linear, curved path inspired by Denoising Diffusion Probabilistic Models (DDPMs), governed by a variance schedule.
  • Trigonometric Flow: Uses a circular interpolation path on a unit half-circle, introducing smooth, curved trajectories with a constant norm.
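As a rough illustration of how these paths differ, the sketch below writes each interpolant x_t and its target velocity in NumPy. These are textbook definitions for the linear and trigonometric cases, not the repository's exact code; the VP path is omitted because its variance schedule is implementation-specific.

```python
import numpy as np

def flow_pair(x0, x1, t, flow_type="rectified"):
    """Return (x_t, target velocity) for one flow matching formulation.

    x0: noise sample, x1: data sample, t: scalar in (0, 1).
    Illustrative textbook definitions; the repository's code may differ.
    """
    if flow_type == "rectified":
        # Straight-line path with a constant velocity field.
        xt = (1 - t) * x0 + t * x1
        v = x1 - x0
    elif flow_type == "cfm":
        # Same linear path; the target velocity points from x_t toward the
        # data (on the path itself this reduces to x1 - x0).
        xt = (1 - t) * x0 + t * x1
        v = (x1 - xt) / (1 - t)
    elif flow_type == "trigonometric":
        # Circular interpolation on a unit half-circle.
        c, s = np.cos(np.pi * t / 2), np.sin(np.pi * t / 2)
        xt = c * x0 + s * x1
        v = (np.pi / 2) * (c * x1 - s * x0)
    else:
        raise ValueError(f"unsupported flow_type: {flow_type}")
    return xt, v
```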

How to Use

You can select the desired formulation during training by using the --flow_type command-line argument. The available choices are rectified, cfm, vp_diffusion, and trigonometric. See the training_ablation.sh script for an example of how to launch training for all variants.

# Example of training with Conditional Flow Matching (CFM)
python scripts/train.py \
    --flow_type cfm \
    ... # other arguments

Installation

Data Availability

Due to patient privacy regulations and strict data use agreements, we cannot redistribute the 3D MRI scans directly. To use the same training data, researchers must apply for access to the original datasets, which are made available to the research community upon approved request.

The datasets used are:

To ensure precise replication, we provide the exact list of subjects and scans used for training in the file: Dataset_preparation/metadata/main_dataset_catalog.csv. After obtaining the data, you can use this catalog to construct the identical dataset cohort.

  1. Prerequisites: The framework was developed using Python 3.11 and CUDA 12.0.

  2. Create a Conda Environment:

    conda create -n flowlet_env python=3.11
    conda activate flowlet_env
  3. Install Dependencies:

    pip install -r requirements.txt
  4. Install xFormers (Optional, Recommended for efficiency): Follow the official instructions at the xFormers GitHub repository.

    # This command often works, but may vary based on your setup
    pip install -U xformers

Data Preparation

IMPORTANT: Raw T1-weighted MRI volumes must be preprocessed using the standardized pipeline provided in the MRI_preprocessing folder before running any experiments or evaluations.

The model can be trained using two methods for data loading. The metadata CSV approach is recommended as it is more robust and was used for the paper's experiments.

Method 1: Metadata CSV (Recommended)

  1. Structure: Prepare a single CSV file containing metadata for your entire preprocessed dataset. This CSV must include a column with the absolute path to each NIfTI file and columns for all conditions you wish to use (e.g., Age, Condition).
  2. Create the CSV: You can use the provided script to generate this file from a directory of NIfTI files that have been pre-named to include age information.
    # Example from Dataset_preparation/create_metadata_csv.sh
    PYTHONPATH=. python3 Dataset_preparation/create_metadata_csv.py \
        --input_dirs /path/to/your/nifti/data \
        --output_csv ./metadata/main_dataset_catalog.csv \
        --condition_label CN
  3. Training with CSV: During training, point to this file using the --metadata_csv argument. You can also filter the dataset using --csv_filter_col and --csv_filter_value.
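For concreteness, a minimal catalog in the layout described above might look like the following. The path-column name (`image_path` here) is a guess for illustration; match whatever column name the training script expects.

```python
import csv
import io

# Hypothetical catalog matching the layout described above; the actual
# path-column name used by the repository may differ.
catalog = """image_path,Age,Condition
/data/sub001_T1w.nii.gz,65.3,CN
/data/sub002_T1w.nii.gz,22.0,CN
/data/sub003_T1w.nii.gz,71.8,MCI
"""

rows = list(csv.DictReader(io.StringIO(catalog)))
# Equivalent of --csv_filter_col Condition --csv_filter_value CN:
cn_rows = [r for r in rows if r["Condition"] == "CN"]
print(len(cn_rows))  # 2 scans pass the filter
```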

Method 2: Filename Parsing

As a simpler alternative, you can point the trainer directly to a folder of NIfTI files.

  1. Dataset Structure: Place all .nii.gz files in a single directory.
  2. Filename Convention: The script extracts conditions from filenames. Ensure your filenames include the condition variables, like _AGE_. The default regex looks for [_-]AGE[_-]([0-9.]+).
    • Examples: subject001_AGE_65.3.nii.gz, sub-002-age-22.0.nii.gz
  3. Training with Folder: Use the --data_folder argument in scripts/train.py.
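A minimal parser for this convention might look like the sketch below. Case-insensitive matching is assumed so both example filenames parse, and note that a greedy `[0-9.]+` also captures the dot preceding the `.nii.gz` extension, so the sketch strips any trailing dot before converting; the repository's actual parsing code may handle this differently.

```python
import re

# Default pattern from the README, matched case-insensitively (assumption).
AGE_PATTERN = re.compile(r"[_-]AGE[_-]([0-9.]+)", re.IGNORECASE)

def parse_age(filename):
    """Extract the age condition from a NIfTI filename, or None if absent."""
    match = AGE_PATTERN.search(filename)
    if match is None:
        return None
    # Strip a trailing dot picked up from the ".nii.gz" extension.
    return float(match.group(1).rstrip("."))
```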

Training

The main training script is scripts/train.py.

Example Command (using Metadata CSV):

PYTHONPATH=. nohup python3 -u scripts/train.py \
    --metadata_csv ./metadata/main_dataset_catalog.csv \
    --run_name "FlowLet_RFM_Training" \
    --flow_type rectified \
    --condition_vars Age \
    --csv_filter_col Condition \
    --csv_filter_value CN \
    --epochs 200 \
    --batch_size 4 \
    --lr 3e-6 \
    --model_input_size 112 112 112 \
    --save_size 91 109 91 \
    --unet_model_channels 128 \
    --unet_channel_mult "1,2,4,8" \
    --unet_attention_res "4,8" \
    --wandb \
    --wandb_project "FlowLet_Project" > logs/training_rfm.log 2>&1 &

Key Training Arguments

  • --metadata_csv: Path to your metadata CSV file (recommended).
  • --data_folder: Path to your NIfTI dataset folder (alternative).
  • --run_name: A unique name for this experiment. Checkpoints and logs will be saved in checkpoints_flowlet/<run_name>.
  • --flow_type: The Flow Matching formulation to use. Choices: rectified, cfm, vp_diffusion, trigonometric.
  • --condition_vars: List of conditions to use (must be columns in the CSV or parseable from filenames).
  • --model_input_size: The spatial size to which images are padded before the DWT. Each dimension must be divisible by 2^(num_downsampling_layers).
  • --save_size: The final spatial size to crop generated images to after the IDWT.
  • --unet_*: Arguments to configure the U-Net architecture (channels, attention resolutions, etc.).
  • --wandb: Enable Weights & Biases logging.
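The divisibility constraint on --model_input_size can be checked with a one-liner. The total number of halvings (here assumed to be 4: one DWT level plus three U-Net downsamples, inferred from the four-entry channel multiplier) is an assumption; adjust it to your configuration.

```python
def valid_input_size(size, num_halvings):
    """Each spatial dimension must be divisible by 2**num_halvings."""
    factor = 2 ** num_halvings
    return all(dim % factor == 0 for dim in size)

print(valid_input_size((112, 112, 112), 4))  # True: 112 = 16 * 7
print(valid_input_size((91, 109, 91), 4))    # False: raw shape needs padding
```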

Generating Samples

1. Generating with Linearly Interpolated Ages

Use scripts/generate_linear.py to generate a sequence of samples where age is varied smoothly. This is ideal for visualization and for creating large, age-balanced datasets for downstream tasks.

PYTHONPATH=. python3 -u scripts/generate_linear.py \
    --checkpoint_path checkpoints_flowlet/<run_name>/fmw_best.pth \
    --output_dir ./generated_samples/<run_name>/linear_age \
    --condition_ranges_path checkpoints_flowlet/<run_name>/condition_ranges.json \
    --min_age 5.9 \
    --max_age 95.5 \
    --num_total_samples 3000 \
    --num_flow_steps 10 \
    --save_size 91 109 91
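Under the hood, the "linearly interpolated ages" amount to an evenly spaced sweep over the requested range, which can be sketched as follows (a sketch of the assumed behavior, not the script's exact code):

```python
import numpy as np

def linear_age_schedule(min_age, max_age, num_total_samples):
    """Evenly spaced age conditions covering [min_age, max_age] inclusive."""
    return np.linspace(min_age, max_age, num_total_samples)

# Mirrors the --min_age/--max_age/--num_total_samples arguments above.
ages = linear_age_schedule(5.9, 95.5, 3000)
```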

2. Generating for Specific Conditions

Use scripts/generate.py to create a fixed number of samples for specific, discrete conditions.

PYTHONPATH=. python3 -u scripts/generate.py \
    --checkpoint_path checkpoints_flowlet/<run_name>/fmw_best.pth \
    --output_dir ./generated_samples/<run_name>/specific \
    --condition_ranges_path checkpoints_flowlet/<run_name>/condition_ranges.json \
    --generation_conditions "Age=45" "Age=70.5" \
    --num_synthetic 50 \
    --save_size 91 109 91

Advanced: Generating Ablation Samples

The scripts/generate_linear.py script contains a generation_modes dictionary that lets you generate samples with specific conditioning mechanisms disabled: FiLM only, spatial conditioning only (cross-attention), or fully unconditional. This is the exact tool used to produce the samples for the ablation study in the paper.


Full Evaluation Experiments



Reproducing the Quantitative Evaluation

This section details how to perform the quantitative evaluation of generated samples, as reported in the paper. The evaluation is designed to assess image fidelity, distributional similarity, and sample diversity.

1. The Evaluation Script

The primary script for this task is Evaluation_metrics/Evaluation_FID_MMD_MSSIM.py. This script handles the calculation of all key metrics, including age-stratified analysis and statistical significance testing.

2. How it Works

The script compares one or more directories of generated 3D NIfTI files against a directory of real NIfTI files. It calculates three main metrics:

  1. Fréchet Inception Distance (FID): Measures the similarity between the distributions of real and generated images in a deep feature space. A lower FID indicates that the generated samples are more realistic and diverse.
  2. Maximum Mean Discrepancy (MMD): An alternative metric to FID for comparing distributions in a feature space, using a Gaussian kernel. A lower MMD indicates better similarity.
  3. Multi-Scale Structural Similarity (MS-SSIM): Measures the structural similarity between pairs of images within the generated set. A lower MS-SSIM score indicates higher diversity among the generated samples (i.e., the model is not collapsing to a single mode).
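The MMD term has a compact closed-form estimator over feature sets. The sketch below uses a biased (V-statistic) estimate with a fixed RBF bandwidth; the script's actual bandwidth selection and estimator variant may differ.

```python
import numpy as np

def gaussian_mmd2(X, Y, sigma=1.0):
    """Biased squared-MMD estimate with an RBF kernel between feature sets.

    X, Y: (n, d) arrays of feature vectors. The bandwidth and the biased
    (V-statistic) estimator here are illustrative choices.
    """
    def k(A, B):
        # Pairwise squared distances, then the Gaussian kernel.
        d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
        return np.exp(-d2 / (2 * sigma ** 2))
    return k(X, X).mean() + k(Y, Y).mean() - 2 * k(X, Y).mean()
```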

The evaluation pipeline follows these key steps:

  1. Feature Extractor: A pre-trained 3D Medical ResNet-50 is loaded. This network acts as a feature extractor, converting each 3D MRI into a high-dimensional feature vector (an "activation").
  2. Data Normalization: A global intensity normalization is calculated by sampling a subset of the real data. This ensures that both real and generated images are processed with a consistent intensity range.
  3. Activation Calculation: The script processes all real images and all generated images (from each provided directory) through the feature extractor to get their corresponding sets of feature vectors.
  4. Metric Calculation:
    • FID/MMD: The script compares the set of real activations against the set of generated activations. To ensure stable and reliable results, it uses a bootstrapping procedure: it repeatedly takes random subsamples from both activation sets, calculates FID and MMD on these subsamples, and then reports the mean and standard deviation of these scores over all iterations.
    • MS-SSIM: The script calculates the pairwise MS-SSIM between a large number of randomly sampled pairs from the generated dataset to assess intra-set diversity.
  5. Age-Stratified Analysis: If the filenames contain _AGE_ tags, the script automatically groups both real and generated samples into predefined age bands (e.g., 15-30, 40-55, 65-80) and repeats the metric calculations for each band. This provides a fine-grained understanding of model performance across different demographics.
  6. Statistical Testing: If more than one directory of generated samples is provided, the script automatically treats the first directory as the baseline. It performs a Wilcoxon rank-sum test on the distributions of bootstrapped FID/MMD scores to determine if the differences between models are statistically significant, applying a Bonferroni correction to account for multiple comparisons.
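Step 4's bootstrapping can be sketched as follows. The Fréchet distance here fits a Gaussian to each activation set; the trace of the matrix square root is computed from the eigenvalues of the covariance product, a standard trick that avoids a general matrix square root. Subsample sizes and iteration counts are placeholders, not the script's defaults.

```python
import numpy as np

def frechet_distance(feats_a, feats_b):
    """Frechet distance between Gaussian fits to two (n, d) feature sets."""
    mu_a, mu_b = feats_a.mean(0), feats_b.mean(0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr((Ca Cb)^{1/2}) from the eigenvalues of the PSD-product matrix.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0, None)).sum()
    diff = mu_a - mu_b
    return diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2 * tr_sqrt

def bootstrap_metric(real, gen, metric, subsample=100, iters=20, seed=0):
    """Mean/std of `metric` over repeated random subsamples of both sets."""
    rng = np.random.default_rng(seed)
    scores = []
    for _ in range(iters):
        r = real[rng.choice(len(real), subsample, replace=False)]
        g = gen[rng.choice(len(gen), subsample, replace=False)]
        scores.append(metric(r, g))
    return float(np.mean(scores)), float(np.std(scores))
```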

3. How to Run the Evaluation

You can run the full evaluation pipeline using the Evaluation_metrics/Evaluation_FID_MMD_MSSSIM.sh script.

Steps:

  1. Open the Evaluation_metrics/Evaluation_FID_MMD_MSSSIM.sh script.
  2. Set the --real_dir argument to the path of your directory containing the real NIfTI dataset.
  3. Set the --gen_dirs argument to a space-separated list of paths to the directories containing your generated NIfTI samples. The first path in this list will be treated as the baseline for statistical comparisons.
  4. Ensure the --medical_resnet_path points to the pre-trained weights file (resnet_50_epoch_110_batch_0.pth).
  5. Specify the --output_csv path where the final results table will be saved.
  6. (Optional) Adjust parameters like --max_fid_samples, --max_ssim_samples, and --num_bootstraps to balance computational cost and precision.
  7. Execute the script from the project's root directory:
    bash Evaluation_metrics/Evaluation_FID_MMD_MSSSIM.sh

Reproducing the Wavelet Selection Ablation Study

This section details how to reproduce the wavelet selection analysis reported in the paper. The goal of this analysis is to empirically determine which wavelet basis is most suitable for our task by measuring its reconstruction fidelity.

1. The Analysis Script

This analysis is performed by the script located at Evaluation_wavelet_ablations/calculate_wavelet_errors.py.

2. How it Works

The script systematically evaluates the "round-trip" reconstruction error for a list of different wavelet families. For each real 3D MRI in the dataset and for each specified wavelet (e.g., 'haar', 'db4', 'sym4'):

  1. Load Data: A 3D NIfTI file is loaded into a NumPy array.
  2. Forward DWT: The script performs a 3D Discrete Wavelet Transform on the image data using the selected wavelet.
  3. Inverse DWT: It immediately performs an Inverse DWT on the resulting coefficients to reconstruct the image.
  4. Calculate Error: It computes the Mean Absolute Error (MAE) between the original image and the reconstructed image. This MAE value quantifies how much information was lost or altered during the DWT/IDWT round-trip process. A lower MAE indicates higher fidelity.
  5. Aggregate Results: After processing all files in the dataset, the script calculates the mean and standard deviation of the MAE scores for each wavelet family.
  6. Report: The final results are sorted by the mean MAE (lowest error first), printed to the console, and saved to a CSV file. The wavelet with the lowest mean MAE is considered the best choice for preserving anatomical information.
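The Haar case, in particular, is exactly invertible up to floating-point precision. The NumPy sketch below demonstrates this with a hand-rolled 1-level orthonormal separable 3D Haar transform (the actual script presumably uses a wavelet library such as PyWavelets, and even-length dimensions are assumed here).

```python
import numpy as np

def haar_step(x, axis):
    """One orthonormal Haar analysis step along `axis` (even-length dim)."""
    a = np.take(x, range(0, x.shape[axis], 2), axis=axis)
    b = np.take(x, range(1, x.shape[axis], 2), axis=axis)
    return (a + b) / np.sqrt(2), (a - b) / np.sqrt(2)

def haar_inverse_step(lo, hi, axis):
    """Invert haar_step by re-interleaving the reconstructed pairs."""
    a = (lo + hi) / np.sqrt(2)
    b = (lo - hi) / np.sqrt(2)
    out = np.stack([a, b], axis=axis + 1)
    shape = list(lo.shape)
    shape[axis] *= 2
    return out.reshape(shape)

def dwt_axis(x, axis):
    lo, hi = haar_step(x, axis)
    return np.concatenate([lo, hi], axis=axis)

def idwt_axis(x, axis):
    n = x.shape[axis] // 2
    lo = np.take(x, range(n), axis=axis)
    hi = np.take(x, range(n, 2 * n), axis=axis)
    return haar_inverse_step(lo, hi, axis)

def roundtrip_mae(vol):
    """MAE after a 1-level separable 3D Haar DWT and its inverse."""
    coeffs = vol
    for axis in range(3):
        coeffs = dwt_axis(coeffs, axis)
    recon = coeffs
    for axis in reversed(range(3)):
        recon = idwt_axis(recon, axis)
    return float(np.abs(vol - recon).mean())
```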

3. How to Run the Analysis

You can run the analysis using the Evaluation_wavelet_ablations/calculate_wavelet_errors.sh script, which provides a ready-to-use command.

Steps:

  1. Open the Evaluation_wavelet_ablations/calculate_wavelet_errors.sh script.
  2. Modify the --input_dir to point to the directory containing your real NIfTI dataset.
  3. (Optional) Modify the --wavelets list to test different wavelet families. The default list in the script matches the one used for the paper's ablation.
  4. Specify the desired --output_csv path for the results.
  5. (Optional) For a quick test, you can use the --num_files argument to limit the analysis to a small subset of your data.
  6. Execute the script from the project's root directory:
    bash Evaluation_wavelet_ablations/calculate_wavelet_errors.sh

Reproducing the Sampling Time Benchmark

This section details how to reproduce the sampling time benchmarking for FlowLet as presented in the paper. The benchmark measures the wall-clock time required to generate a single 3D brain MRI sample for various numbers of ODE integration steps.

1. The Benchmarking Script

The entire process is managed by the script located at Benchmarking_time/benchmarking_time.py. This script is designed to provide accurate and stable timing measurements.

2. How it Works

The script systematically evaluates the model's sampling performance by following these steps for a predefined list of step counts (e.g., 1, 2, 5, 10, 100, 200):

  1. Model Loading: It loads a pre-trained model checkpoint and its associated configuration file. The model is set to evaluation mode (model.eval()) to disable operations like dropout.
  2. GPU Warm-up: Before any timing begins, the script generates a few "warm-up" samples. This is a crucial step to ensure that the GPU is fully initialized and any one-time memory allocations are completed. This prevents first-run overhead from skewing the timing results.
  3. Timed Measurement Loop: After the warm-up, the script generates a sequence of samples (e.g., 15) one by one. The generation time for each individual sample is measured.
  4. Accurate Timing: To ensure timing accuracy, especially on GPUs, the script uses the following precautions:
    • torch.cuda.synchronize(): This command is called immediately before starting the timer and after the generation is complete. This forces the CPU to wait for all pending GPU operations to finish, ensuring that the measured time accurately reflects the full duration of the GPU workload, not just the time it took to launch the kernels.
    • time.perf_counter(): Used to measure the intervals; it is a high-resolution, monotonic clock and therefore better suited to short-duration benchmarks than wall-clock functions such as time.time().
  5. Aggregation and Reporting: After timing all samples for a given step count, the script calculates the mean and standard deviation of the timings. This provides a stable average generation time and a measure of its variability.
  6. Results Export: The final results (Steps, Mean Time, Standard Deviation) are printed to the console and saved to a CSV file for easy analysis and plotting.
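The measurement loop described above reduces to the following pattern. This is a CPU-only sketch; on GPU you would call torch.cuda.synchronize() immediately before each timer read, as described in step 4.

```python
import statistics
import time

def benchmark(generate, num_warmup=3, num_timed=15):
    """Warm up, then time individual `generate()` calls (CPU-only sketch).

    On GPU, insert torch.cuda.synchronize() before each perf_counter() call
    so the measurement covers kernel execution, not just kernel launch.
    """
    for _ in range(num_warmup):
        generate()  # warm-up runs are deliberately untimed
    timings = []
    for _ in range(num_timed):
        start = time.perf_counter()
        generate()
        timings.append(time.perf_counter() - start)
    return statistics.mean(timings), statistics.stdev(timings)
```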

All Brain Age Prediction experiments, regional plausibility evaluations, and their associated preprocessing steps are provided in the corresponding folders within the repository.