noamelata/InvFusion

InvFusion: Bridging Supervised and Zero-shot Diffusion for Inverse Problems
Official PyTorch implementation

[Figure: Restored images]

Noam Elata*, Hyungjin Chung*, Jong Chul Ye, Tomer Michaeli, Michael Elad
https://arxiv.org/pdf/2504.01689

Abstract

Diffusion Models have demonstrated remarkable capabilities in handling inverse problems, offering high-quality posterior-sampling-based solutions. Despite significant advances, a fundamental trade-off persists regarding the way the conditioned synthesis is employed: Zero-shot approaches can accommodate any linear degradation but rely on approximations that reduce accuracy. In contrast, training-based methods model the posterior correctly, but cannot adapt to the degradation at test-time. Here we introduce InvFusion, the first training-based degradation-aware posterior sampler. InvFusion combines the best of both worlds -- the strong performance of supervised approaches and the flexibility of zero-shot methods. This is achieved through a novel architectural design that seamlessly integrates the degradation operator directly into the diffusion denoiser. We compare InvFusion against existing general-purpose posterior samplers, both degradation-aware zero-shot techniques and blind training-based methods. Experiments on the FFHQ and ImageNet datasets demonstrate state-of-the-art performance. Beyond posterior sampling, we further demonstrate the applicability of our architecture, operating as a general Minimum Mean Square Error predictor, and as a Neural Posterior Principal Component estimator.

* Equal contribution

Prerequisites

Please complete the following steps:

  1. Install PyTorch for your hardware, along with the following packages:

    click Pillow psutil requests scipy tqdm diffusers accelerate


    Install natten matching your hardware and PyTorch version.

    wandb is optional and is used for logging.

  2. Download the relevant datasets and place them in the data/ folder.

  3. Pre-compute the reference statistics for your dataset:

torchrun --standalone --nproc_per_node=<num gpus> calculate_metrics.py ref --data datautils.<dataset name>.<dataset class> --dest dataset-refs/<ref name>.pkl
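Before running anything, it can help to sanity-check that the prerequisites above are importable. A minimal sketch (the package list mirrors step 1; note that Pillow imports as PIL, and the treatment of natten/wandb here is an assumption):

```python
from importlib.util import find_spec

# Packages from step 1; natten is required, wandb is optional (logging only).
REQUIRED = ["torch", "click", "PIL", "psutil", "requests", "scipy",
            "tqdm", "diffusers", "accelerate", "natten"]
OPTIONAL = ["wandb"]

def missing_packages(names):
    """Return the subset of `names` that cannot be imported."""
    return [n for n in names if find_spec(n) is None]

if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing required packages:", ", ".join(missing))
    else:
        print("All required packages found.")
    for opt in missing_packages(OPTIONAL):
        print(f"Optional package not found: {opt} (used for logging)")
```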

Inference

Model checkpoints coming soon!

To calculate metrics for a trained model, please run:

torchrun --standalone --nproc_per_node=<num gpus> calculate_metrics.py  gen   \
      --net <path to pretrained model>                                        \
      --data datautils.<dataset name>.<dataset class>                         \
      --degradation degradation.<degradation to test>                         \
      --ref dataset-refs/<ref name>.pkl                                       \
      [--outdir <optional path for saving outputs>]

List of available degradations:

  • degradation.RandomDegradation
  • degradation.MotionBlur
  • degradation.MissingPatches
  • degradation.MatrixDegradation
  • degradation.Box
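Conceptually, each of these degradations is a known linear operator y = A(x) that the denoiser is conditioned on. Below is a minimal NumPy sketch of a MissingPatches-style operator; the class name and interface are illustrative only, not the repository's actual degradation API:

```python
import numpy as np

class MissingPatchesSketch:
    """Illustrative linear degradation: zero out square patches of an image.

    Computes y = A(x), where A is an elementwise binary mask.
    """

    def __init__(self, image_size, patch_size, num_patches, seed=0):
        rng = np.random.default_rng(seed)
        self.mask = np.ones((image_size, image_size), dtype=np.float32)
        for _ in range(num_patches):
            top = rng.integers(0, image_size - patch_size + 1)
            left = rng.integers(0, image_size - patch_size + 1)
            self.mask[top:top + patch_size, left:left + patch_size] = 0.0

    def __call__(self, x):
        # Apply the mask to an (H, W) or (H, W, C) image.
        if x.ndim == 3:
            return x * self.mask[..., None]
        return x * self.mask

    def transpose(self, y):
        # A binary mask is self-adjoint (A = A^T), so A^T y = A y.
        return self(y)
```

Because A is linear and known at test time, a degradation-aware sampler can use both A and its transpose; for masking operators like this one the two coincide.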

For example, the following command evaluates FID and PSNR for motion-blur restoration on ImageNet $64\times 64$ using 8 GPUs:

torchrun --standalone --nproc_per_node=8 calculate_metrics.py  gen  \
      --net checkpoints/imagenet/checkpoint.pt                      \
      --data datautils.imagenet.ImageNet64                          \
      --degradation degradation.MotionBlur                          \
      --ref dataset-refs/imagenet-10k-64.pkl                        \
      --cfg 2.0                                                     \
      --metrics fid,psnr
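Of the reported metrics, PSNR can be sketched in a few lines. This is the generic definition for images in [0, max_val], not necessarily the repository's exact implementation:

```python
import numpy as np

def psnr(reference, restored, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two images in [0, max_val]."""
    mse = np.mean((np.asarray(reference, dtype=np.float64)
                   - np.asarray(restored, dtype=np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)
```

For instance, a uniform pixel error of 0.1 on images in [0, 1] gives an MSE of 0.01 and hence a PSNR of 20 dB.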

Train

torchrun --standalone --nproc_per_node=<num gpus> train.py    \
      --data-class datautils.<dataset name>.<dataset class>   \
      --ref-path dataset-refs/<ref name>.pkl                  \
      --degradation degradation.RandomDegradation             \
      --name invfusion

Remove the --degradation argument to train an unconditional model. Add the --blind-train argument to train a degradation-blind model. Run python train.py -h for the full list of available arguments.

Citation

@article{elata2025invfusion,
   title    = {InvFusion: Bridging Supervised and Zero-shot Diffusion for Inverse Problems},
   author   = {Noam Elata and Hyungjin Chung and Jong Chul Ye and Tomer Michaeli and Michael Elad},
   journal  = {arXiv preprint arXiv:2504.01689},
   year     = {2025}
}

Acknowledgments and References

This code uses implementations from EDM2, EDM, HDiT, and MotionBlur. Please also consider citing:

@inproceedings{Karras2024edm2,
   title     = {Analyzing and Improving the Training Dynamics of Diffusion Models},
   author    = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
   booktitle = {Proc. CVPR},
   year      = {2024},
}
@inproceedings{Karras2022edm,
   author    = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
   title     = {Elucidating the Design Space of Diffusion-Based Generative Models},
   booktitle = {Proc. NeurIPS},
   year      = {2022}
}
@inproceedings{crowson2024hourglass,
   title     = {Scalable High-Resolution Pixel-Space Image Synthesis with Hourglass Diffusion Transformers},
   author    = {Crowson, Katherine and Baumann, Stefan Andreas and Birch, Alex and Abraham, Tanishq Mathew and Kaplan, Daniel Z and Shippole, Enrico},
   booktitle = {Proc. ICML},
   year      = {2024}
}
