We're fine-tuning Stable Diffusion 3 Medium on our PixelProse dataset with the research goal of enhancing text rendering capabilities in large diffusion models. This project extends the text context window length of SD3-M, allowing for longer and more detailed text prompts in image generation.
Specifically, we're aiming for high GenEval scores across categories pertaining to text rendering, counts, shape, color, and object relations.

- Full fine-tuning with custom dataset & dataloader
- Extended context window for SD3-Medium
- Optimizations for memory & runtime (caching text latents for faster I/O, distributed data sharding, and offloading model weights to CPU)
- Validation & checkpointing
- Gradient accumulation and mixed precision training
- Cosine warm-up schedule for trainable parameters

Training setup:
- Model: Stable Diffusion 3 Medium
- GPU: NVIDIA A100s
- Cluster: Polaris & Nexus
- Framework: PyTorch with DeepSpeed
- Dataset: PixelProse 16M with cached latents

Clone this repository:
git clone https://github.com/sukritipaul5/PixelProse-PowerUp.git
cd PixelProse-PowerUp/scripts/SD3-Medium-ContextWindow

Install dependencies:
Installation instructions will be added soon.

The scripts/ directory contains the scripts we used to fine-tune SD3-M on our dataset. Its two folders correspond to full fine-tuning (SD3-Medium-FFT) and context window extension (SD3-Medium-ContextWindow), respectively.
You can qsub the job scripts as follows:
qsub SD3-Medium-ContextWindow/launch.sh
qsub SD3-Medium-FFT/launch.sh

Example training command:
python train_sd3m_cw.py \
--pretrained_model_name_or_path="stabilityai/stable-diffusion-3-base" \
--output_dir="./sd3m_context_window" \
--train_batch_size=1 \
--num_train_epochs=1 \
--learning_rate=1e-4 \
--max_grad_norm=1.0 \
--validation_step=500 \
--validation_prompts_file="/path/to/validation_prompts.txt" \
--mixed_precision="bf16" \
--gradient_accumulation_steps=4 \
--seed=42

Main files:
- train_sd3m_cw.py: Main training script
- utils_sd3.py: Utility functions for SD3
- datapipe_new.py: Custom data pipeline for PixelProse dataset with cached latents

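The example command above passes every setting as a command-line flag. As a rough illustration, the flags could be declared with argparse as follows. The flag names come from the command; the types, defaults, and the `build_parser` helper are assumptions made for this sketch, not the actual parser in train_sd3m_cw.py.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags in the example command."""
    parser = argparse.ArgumentParser(
        description="Sketch of the train_sd3m_cw.py command-line flags"
    )
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="./sd3m_context_window")
    parser.add_argument("--train_batch_size", type=int, default=1)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--validation_step", type=int, default=500)
    parser.add_argument("--validation_prompts_file", type=str, default=None)
    # The README lists fp16, bf16, and no as the accepted precision values.
    parser.add_argument("--mixed_precision", choices=["no", "fp16", "bf16"], default="bf16")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)
    return parser


# Parse the same values as the example command (passed as a list, not sys.argv).
args = build_parser().parse_args([
    "--pretrained_model_name_or_path=stabilityai/stable-diffusion-3-base",
    "--output_dir=./sd3m_context_window",
    "--train_batch_size=1",
    "--learning_rate=1e-4",
    "--validation_prompts_file=/path/to/validation_prompts.txt",
    "--mixed_precision=bf16",
    "--gradient_accumulation_steps=4",
    "--seed=42",
])

# Samples contributing to one optimizer update on a single process.
effective_batch_size = args.train_batch_size * args.gradient_accumulation_steps
```

With the values shown, each optimizer update on one process sees `train_batch_size * gradient_accumulation_steps` samples.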
Implementation details:
- Extended Context Window: We use a custom MLP to process extended context beyond the standard 77 tokens.
- Custom Dataloader: Our WebDatasetwithCachedLatents class handles large-scale data with pre-computed latents.
- Training Loop: The main training loop in train_sd3m_cw.py handles gradient accumulation, mixed precision training, and periodic validation.
- Validation: During training, we generate images using the current model state to track progress visually on wandb.
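
To make the gradient accumulation step in the training loop concrete, here is a toy, pure-Python sketch on a one-parameter model. It is not the code in train_sd3m_cw.py, which runs on PyTorch with DeepSpeed and also handles mixed precision; the model, data, and function names are hypothetical. It only illustrates that averaging gradients over several equally sized micro-batches matches one large-batch gradient, followed by the kind of norm clipping that --max_grad_norm controls.

```python
def mse_grad(w, batch):
    """Gradient of the mean squared error for the 1-D model y = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulate_gradient(w, micro_batches):
    """Average the gradients of several micro-batches before one update."""
    accum_steps = len(micro_batches)
    grad = 0.0
    for batch in micro_batches:
        # Dividing by accum_steps turns the running sum into a mean.
        grad += mse_grad(w, batch) / accum_steps
    return grad


def clip_gradient(grad, max_grad_norm):
    """Rescale the gradient if its norm (abs value for a scalar) is too large."""
    norm = abs(grad)
    if norm > max_grad_norm:
        grad = grad * (max_grad_norm / norm)
    return grad


# Hypothetical data: four micro-batches of size 1 (batch size 1, 4 accumulation steps).
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
micro_batches = [[sample] for sample in data]

w = 0.0
accumulated = accumulate_gradient(w, micro_batches)
full_batch = mse_grad(w, data)          # same value as the accumulated gradient

clipped = clip_gradient(accumulated, max_grad_norm=1.0)
w_new = w - 1e-4 * clipped              # one optimizer step at learning rate 1e-4
```

The equivalence only holds exactly when the micro-batches have equal size; with uneven micro-batches the per-batch means would need to be weighted.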
Key parameters in train_sd3m_cw.py:
- --pretrained_model_name_or_path: Path to the pretrained SD3-M model
- --output_dir: Directory to save checkpoints and logs
- --train_batch_size: Batch size for training
- --learning_rate: Initial learning rate
- --max_grad_norm: Maximum gradient norm for clipping
- --validation_step: Number of steps between validations
- --mixed_precision: Precision for mixed precision training (fp16, bf16, or no)
- --gradient_accumulation_steps: Number of steps to accumulate gradients before updating model weights
- --seed: Random seed for reproducibility
- --validation_prompts_file: Path to the file containing validation prompts
- --num_train_epochs: Number of training epochs
- --cosine_warmup_steps: Number of steps for cosine warm-up schedule
- --epsilon: Small constant for numerical stability in optimizer
- --weight_decay: Constant for L2 regularization
- --lr_scheduler_type: Type of learning rate scheduler (cosine, linear, etc.)
- --checkpoint_dir: Directory to save checkpoints
- --checkpoint_steps: Number of steps between checkpoints

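
The README does not spell out the exact shape of the cosine warm-up schedule, so the following is only a sketch under one common interpretation: a linear warm-up over --cosine_warmup_steps, followed by cosine decay to zero. The function name `lr_at_step` and the step counts are hypothetical, and the actual scheduler in train_sd3m_cw.py may differ.

```python
import math


def lr_at_step(step, base_lr, warmup_steps, total_steps):
    """Learning rate under linear warm-up followed by cosine decay (assumed shape)."""
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr during warm-up.
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warm-up phase completed, capped at 1.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    # Half-cosine from base_lr down to 0.
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Hypothetical run: base learning rate 1e-4, 100 warm-up steps, 1000 steps in total.
steps = (0, 50, 100, 550, 1000)
schedule = [lr_at_step(s, 1e-4, 100, 1000) for s in steps]
```

Under these assumptions the rate starts at zero, reaches the base value at the end of warm-up, is half the base value midway through the decay, and returns to zero at the final step.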

Acknowledgements:
- Stability AI for the original SD3-M model.
- The Hugging Face team for their Diffusers library and their DreamBooth fine-tuning repo.
- Databricks for MosaicML and efficient data sharding strategies.
- ALCF for compute.

Copyright (c) 2024 TomLab