InvFusion: Bridging Supervised and Zero-shot Diffusion for Inverse Problems
Official PyTorch implementation
Noam Elata*, Hyungjin Chung*, Jong Chul Ye, Tomer Michaeli, Michael Elad
https://arxiv.org/pdf/2504.01689
Diffusion Models have demonstrated remarkable capabilities in handling inverse problems, offering high-quality posterior-sampling-based solutions. Despite significant advances, a fundamental trade-off persists regarding the way the conditioned synthesis is employed: Zero-shot approaches can accommodate any linear degradation but rely on approximations that reduce accuracy. In contrast, training-based methods model the posterior correctly, but cannot adapt to the degradation at test-time. Here we introduce InvFusion, the first training-based degradation-aware posterior sampler. InvFusion combines the best of both worlds -- the strong performance of supervised approaches and the flexibility of zero-shot methods. This is achieved through a novel architectural design that seamlessly integrates the degradation operator directly into the diffusion denoiser. We compare InvFusion against existing general-purpose posterior samplers, both degradation-aware zero-shot techniques and blind training-based methods. Experiments on the FFHQ and ImageNet datasets demonstrate state-of-the-art performance. Beyond posterior sampling, we further demonstrate the applicability of our architecture, operating as a general Minimum Mean Square Error predictor, and as a Neural Posterior Principal Component estimator.
* Equal contribution
Please complete the following steps:

- Install PyTorch for the appropriate hardware, along with the following packages: `click`, `Pillow`, `psutil`, `requests`, `scipy`, `tqdm`, `diffusers`, `accelerate`. Install `natten` matching the appropriate hardware and PyTorch version. `wandb` is optional and is used for logging.
- Download the relevant datasets and place them in the `data/` folder.
- Pre-compute the reference statistics for your dataset:
```
torchrun --standalone --nproc_per_node=<num gpus> calculate_metrics.py ref \
    --data datautils.<dataset name>.<dataset class> \
    --dest dataset-refs/<ref name>.pkl
```

To calculate metrics for a trained model, please run:

```
torchrun --standalone --nproc_per_node=<num gpus> calculate_metrics.py gen \
    --net <path to pretrained model> \
    --data datautils.<dataset name>.<dataset class> \
    --degradation degradation.<degradation to test> \
    --ref dataset-refs/<ref name>.pkl \
    [--outdir <optional path for saving outputs>]
```

List of available degradations:

- `degradation.RandomDegradation`
- `degradation.MotionBlur`
- `degradation.MissingPatches`
- `degradation.MatrixDegradation`
- `degradation.Box`
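Returning to the first setup step, the environment installation can be sketched as follows. The exact `pip` invocations are an assumption (in particular, the PyTorch and `natten` commands depend on your CUDA version — follow their official install instructions for your hardware):

```shell
# Install PyTorch first, matching your hardware; see https://pytorch.org
# for the command specific to your CUDA version (plain CPU wheel shown here).
pip install torch torchvision

# Packages listed in the setup instructions.
pip install click Pillow psutil requests scipy tqdm diffusers accelerate

# natten must match your hardware and installed PyTorch version;
# consult its install notes for the right wheel.
pip install natten

# Optional, used for logging.
pip install wandb
```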
For example, the following evaluates FID and PSNR for Motion Blur restoration on ImageNet:

```
torchrun --standalone --nproc_per_node=8 calculate_metrics.py gen \
    --net checkpoints/imagenet/checkpoint.pt \
    --data datautils.imagenet.ImageNet64 \
    --degradation degradation.MotionBlur \
    --ref dataset-refs/imagenet-10k-64.pkl \
    --cfg 2.0 \
    --metrics fid,psnr
```

To train a model, run:

```
torchrun --standalone --nproc_per_node=<num gpus> train.py \
    --data-class datautils.<dataset name>.<dataset class> \
    --ref-path dataset-refs/<ref name>.pkl \
    --degradation degradation.RandomDegradation \
    --name invfusion
```

Remove the `--degradation` argument to train an unconditional model.
Add the `--blind-train` argument to train a degradation-blind model.
Use `python train.py -h` to get the list of available arguments.
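As a concrete illustration, a training run on ImageNet mirroring the paths from the metrics example above might look like the following. The GPU count and the run name passed to `--name` are placeholders chosen for this sketch:

```shell
torchrun --standalone --nproc_per_node=8 train.py \
    --data-class datautils.imagenet.ImageNet64 \
    --ref-path dataset-refs/imagenet-10k-64.pkl \
    --degradation degradation.RandomDegradation \
    --name invfusion-imagenet
```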
```
@article{elata2025invfusion,
    title = {InvFusion: Bridging Supervised and Zero-shot Diffusion for Inverse Problems},
    author = {Noam Elata and Hyungjin Chung and Jong Chul Ye and Tomer Michaeli and Michael Elad},
    journal = {arXiv preprint arXiv:2504.01689},
    year = {2025}
}
```

This code uses implementations from EDM2, EDM, HDiT and MotionBlur. Please consider also citing:
```
@inproceedings{Karras2024edm2,
    title = {Analyzing and Improving the Training Dynamics of Diffusion Models},
    author = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
    booktitle = {Proc. CVPR},
    year = {2024}
}
```
```
@inproceedings{Karras2022edm,
    author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
    title = {Elucidating the Design Space of Diffusion-Based Generative Models},
    booktitle = {Proc. NeurIPS},
    year = {2022}
}
```
```
@inproceedings{crowson2024hourglass,
    title = {Scalable High-Resolution Pixel-Space Image Synthesis with Hourglass Diffusion Transformers},
    author = {Crowson, Katherine and Baumann, Stefan Andreas and Birch, Alex and Abraham, Tanishq Mathew and Kaplan, Daniel Z and Shippole, Enrico},
    booktitle = {Proceedings of the 41st International Conference on Machine Learning},
    pages = {9550--9575},
    volume = {235},
    series = {Proceedings of Machine Learning Research},
    publisher = {PMLR},
    year = {2024},
    url = {https://proceedings.mlr.press/v235/crowson24a.html}
}
```