This repository implements the Multi-Marginal Schrödinger Bridge (Multi-Marginal-SB) framework for generative modeling tasks, such as face editing, image translation (e.g., face-to-dog), and Brownian bridge generation. The project leverages models like ALAE and StyleGAN2 for latent space manipulations and is inspired by advancements in optimal transport and diffusion processes.
For detailed technical explanations, model architectures, and experimental results, refer to the ICLR Submission Document.
- Multi-Marginal SB Training: Train models for multi-distribution bridging.
- Face Editing: Edit faces based on text prompts.
- Image Translation: Translate between domains like faces and dogs.
- Brownian Bridge Generation: Generate interpolations and animations.
- Pre-trained Checkpoints: Available for quick inference.
- Python 3.10+
- PyTorch 1.10+
- Additional dependencies: numpy, matplotlib, torchvision, pillow, tqdm
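A quick way to confirm that an environment meets the requirements listed above is a small check script. The following is a minimal sketch and not part of the repository: the helper names `missing_modules` and `python_is_supported` are hypothetical, and the check only verifies that each package can be found, not that PyTorch is at version 1.10 or newer.

```python
import importlib.util
import sys

# Import names for the dependencies listed above. The "pillow"
# distribution is imported as "PIL"; PyTorch is imported as "torch".
REQUIRED_MODULES = ["torch", "numpy", "matplotlib", "torchvision", "PIL", "tqdm"]


def missing_modules(names):
    """Return the subset of `names` that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def python_is_supported(minimum=(3, 10)):
    """Check the running interpreter against the Python 3.10+ requirement."""
    return sys.version_info >= minimum


if __name__ == "__main__":
    if not python_is_supported():
        print("Python 3.10+ is required, found", sys.version.split()[0])
    missing = missing_modules(REQUIRED_MODULES)
    print("Missing packages:", ", ".join(missing) if missing else "none")
```

Any package reported as missing can be installed with pip before running the training or generation scripts.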
- Clone the repository:

      git clone https://github.com/yourusername/Multi-Marginal-SB.git
      cd Multi-Marginal-SB

- Install dependencies:

      pip install -r requirements.txt

  (Note: create a requirements.txt file with the listed dependencies if one is not present.)

- Download pre-trained checkpoints from Google Drive. Extract and place them in the appropriate directories (e.g., checkpoints/).
Below is a list of main scripts and their purposes. For full details on each script's functionality, parameters, and outputs, refer to the ICLR Submission Document.
- train_unpaired_ALAE.py: Trains an unpaired LightSB model for latent space encoding.
- encode_dogs_with_stylegan2.py: Encodes dog images into StyleGAN2 latent space using projection.
- train_face2dog.py: Trains a LightSB model for face-to-dog translation.
- train_face_editing_sb.py: Trains MultiMarginalSB for face editing tasks using text prompts.
- train_text2face_paired.py: Trains a paired model for text-to-face generation.
- train_dog2face.py: Trains a model for dog-to-face translation.
- train_mapper_only.py: Trains a mapper network independently.
- train_multi_marginal.py: Trains the core multi-marginal Schrödinger Bridge model.
- latent_bridge_adult2children.py: Generates latent bridges from adult to children faces.
- generate_gif_brownian_bridge.py: Generates GIF animations of Brownian bridges.
- generate_lambda_sweep_demo.py: Demonstrates lambda parameter sweeps for generation.
- generate_prompt_brownian_bridge.py: Generates Brownian bridges based on prompts.
- encode_captions.py: Encodes text captions into latent representations.
- face_edit.py: Performs face editing inference using trained models.
- generate_from_noise.py: Generates images from noise inputs.
- generate_from_text.py: Generates images from text descriptions.
- display_dog2face.py: Displays dog-to-face translation results.
- distributions.py: Utility for handling data distributions.
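As background for the Brownian bridge scripts, the sketch below samples a discrete Brownian bridge between two latent vectors. It is an illustration of the general technique only and is not taken from generate_gif_brownian_bridge.py or any other script here; the function name `brownian_bridge_path` and the parameters `n_steps` and `epsilon` (the noise variance) are assumptions made for this example. It uses plain Python lists to stay dependency-free; real latents would more likely be NumPy arrays or PyTorch tensors.

```python
import math
import random


def brownian_bridge_path(x0, x1, n_steps=10, epsilon=1.0, rng=None):
    """Sample a path from x0 (time 0) to x1 (time 1) on a uniform time grid.

    Each step draws from the bridge's conditional law given the current
    point and the fixed endpoint:
        mean     = x_t + dt / (1 - t) * (x1 - x_t)
        variance = epsilon * dt * (1 - t - dt) / (1 - t)
    """
    if n_steps < 1:
        raise ValueError("n_steps must be at least 1")
    if rng is None:
        rng = random.Random()
    dt = 1.0 / n_steps
    current = list(x0)
    path = [list(current)]
    for k in range(n_steps - 1):
        remaining = 1.0 - k * dt  # time left until the endpoint is reached
        std = math.sqrt(epsilon * dt * (remaining - dt) / remaining)
        current = [
            c + (dt / remaining) * (target - c) + std * rng.gauss(0.0, 1.0)
            for c, target in zip(current, x1)
        ]
        path.append(current)
    path.append(list(x1))  # the bridge is pinned to x1 at time 1
    return path
```

With `epsilon=0.0` the path reduces to straight-line interpolation between the two latents; larger values add noise in the middle of the path while both endpoints stay fixed. Decoding each point of such a path with a generator is one plausible way to obtain animation frames.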
- Train the face editing model:

      python train_face_editing_sb.py

- Generate from text:

      python generate_from_text.py --prompt "a smiling face"

For custom parameters, refer to each script's configuration section.
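The commands above can also be launched from Python, for example to chain training and generation in one run. This is a minimal sketch under stated assumptions: `build_command` and `run_script` are hypothetical helpers, not part of the repository, and the only option shown (`--prompt`) is the one that appears in the usage example above.

```python
import shlex
import subprocess
import sys
from pathlib import Path


def build_command(script, **options):
    """Build an argv list such as [python, script, "--prompt", "a smiling face"]."""
    command = [sys.executable, script]
    for name, value in options.items():
        command += [f"--{name}", str(value)]
    return command


def run_script(script, **options):
    """Run a script from the current directory, failing early if it is missing."""
    if not Path(script).is_file():
        raise FileNotFoundError(f"{script} not found; run this from the repository root")
    command = build_command(script, **options)
    print("Running:", shlex.join(command))
    # check=True raises CalledProcessError if the script exits with an error.
    subprocess.run(command, check=True)


# Example usage, mirroring the two commands above:
# run_script("train_face_editing_sb.py")
# run_script("generate_from_text.py", prompt="a smiling face")
```

Using `sys.executable` keeps the scripts on the same interpreter and virtual environment as the caller.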
Download pre-trained models and checkpoints from this Google Drive folder. Files include:
- ALAE.zip: ALAE model checkpoints.
- checkpoint.zip: Various training checkpoints.
- clip-vit-base-patch16.zip: CLIP model for text encoding.
- data.zip: Encoded image and caption latents.
Extract and place them in the project folder.
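Extraction can be scripted once the archives have been downloaded. The sketch below is an illustration, not part of the repository: the function name `extract_archives` and the download location are assumptions, and it unpacks each archive as-is into the target folder, so the resulting layout depends on how the archives were packed.

```python
import zipfile
from pathlib import Path

# Archive names as listed above.
ARCHIVES = ["ALAE.zip", "checkpoint.zip", "clip-vit-base-patch16.zip", "data.zip"]


def extract_archives(download_dir, project_dir, names=ARCHIVES):
    """Extract every listed archive found in `download_dir` into `project_dir`.

    Returns the names of the archives that were extracted; archives that are
    not present are reported and skipped.
    """
    download_dir, project_dir = Path(download_dir), Path(project_dir)
    project_dir.mkdir(parents=True, exist_ok=True)
    extracted = []
    for name in names:
        archive = download_dir / name
        if not archive.is_file():
            print(f"Skipping {name}: not found in {download_dir}")
            continue
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(project_dir)
        extracted.append(name)
    return extracted
```

For example, `extract_archives("downloads", ".")` would unpack whichever of the four archives exist in a hypothetical `downloads` folder into the current directory.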
Contributions are welcome! Please open an issue or submit a pull request for bug fixes, features, or improvements.
This project is licensed under the MIT License - see the LICENSE file for details.
- Based on ALAE and StyleGAN2 implementations.
- Inspired by Schrödinger Bridge and optimal transport literature.
For questions, contact [[email protected]].