A paint app running inside a world model, built from three components:
- MiniPaint — Pygame canvas for drawing and data collection
- VAE — Variational autoencoder that compresses images into latent representations
- LatentCNN — Residual CNN that predicts latent-space changes from brush actions
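How the pieces fit together can be sketched as a rollout loop. This is a minimal pure-Python illustration, not the project's code: it assumes the LatentCNN's output is an additive change to the current latent (the description says it predicts latent-space changes), and rollout and toy_delta are hypothetical names. In the real pipeline the starting latent would presumably come from the VAE encoder, MiniPaint would supply the brush actions, and each predicted latent would be decoded by the VAE for display.

```python
from typing import Callable, List, Sequence

Latent = List[float]
Action = Sequence[float]


def rollout(
    z0: Latent,
    actions: Sequence[Action],
    predict_delta: Callable[[Latent, Action], Latent],
) -> List[Latent]:
    """Roll a latent state forward one action at a time.

    Assumes an additive update: z_{t+1} = z_t + predict_delta(z_t, a_t).
    """
    states = [list(z0)]
    for action in actions:
        z = states[-1]
        delta = predict_delta(z, action)  # stands in for one LatentCNN forward pass
        states.append([zi + di for zi, di in zip(z, delta)])
    return states


def toy_delta(z: Latent, action: Action) -> Latent:
    """Hypothetical stand-in for the LatentCNN: shift every latent value by action[0]."""
    return [action[0]] * len(z)


states = rollout([0.0, 1.0], [(0.5,), (0.25,)], toy_delta)
print(states[-1])  # [0.75, 1.75]
```

Predicting a change instead of the whole next latent means an action that barely touches the canvas only needs a near-zero output, which is a common motivation for residual formulations like this one.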
Requirements: Python 3.12+
# Install dependencies (using uv)
uv sync
# Or with pip
pip install -e .
# Download pre-trained checkpoints (from GitHub releases)
bash setupinfinipaint.sh
# Download LPIPS model weights (only needed for training and evaluation)
bash setupmodels.sh
# Copy and fill in environment variables
cp .env.example .env

# View raw episodes from a shard
python scripts/run_viewers.py --shard path/to/shard.tar.gz
# View VAE reconstructions
python scripts/run_viewers.py --recon path/to/shard.tar.gz
# Interactive painting with live model inference
python scripts/run_viewers.py --combined

# Generate shards (interactive prompts for count, headless mode, etc.)
python scripts/generate_shards.py -n 20 --headless --no-continue

The training CLI auto-detects model types from configs/ and offers interactive selection when no arguments are given.
# Interactive mode (prompts for model type and config)
python scripts/run_trainers.py
# Train VAE
python scripts/run_trainers.py configs/vae/base_vae_ldm.yaml
# Train LatentCNN (requires a trained VAE checkpoint)
python scripts/run_trainers.py configs/latent_cnn/small_latent_cnn.yaml
# Resume from checkpoint
python scripts/run_trainers.py --resume checkpoints/vae/base_vae_ldm_best.safetensors

# Evaluate a LatentCNN checkpoint on the test set
python scripts/run_eval.py checkpoints/latent_cnn/small_latent_cnn_best.safetensors
# With custom settings
python scripts/run_eval.py --change-threshold 0.02 --batch-size 16

Project layout:

├── configs/
│ ├── vae/ # VAE model configs (YAML)
│ └── latent_cnn/ # LatentCNN model configs (YAML)
├── src/infinipaint/
│ ├── models/
│ │ ├── vae.py # VAE encoder/decoder
│ │ ├── latent_cnn.py # LatentCNN world model
│ │ ├── blocks.py # ResBlock, AdaGNResBlock
│ │ └── lpips.py # LPIPS perceptual loss
│ ├── training/
│ │ ├── core.py # Abstract trainer, training loop, checkpointing
│ │ ├── vae_trainer.py # VAE loss and step functions
│ │ ├── latent_cnn_trainer.py # LatentCNN loss and step functions
│ │ ├── dataloader.py # Episode streaming, precomputation
│ │ ├── metrics.py # SSIM, PSNR, KL divergence
│ │ └── utils.py # Config loading, checkpoint I/O
│ ├── data/
│ │ ├── generator.py # Synthetic episode generation
│ │ ├── viewer.py # Episode viewer
│ │ └── combined_viewer.py # Full pipeline viewer
│ └── minipaint/
│ └── gui.py # Pygame painting interface
├── scripts/
│ ├── run_trainers.py # Training CLI
│ ├── run_eval.py # Evaluation CLI
│ ├── run_viewers.py # Viewer CLI
│ └── generate_shards.py # Data generation CLI
├── dataset/
│ ├── shards/ # Compressed episode archives
│ └── models/ # Pre-trained weights (LPIPS)
└── checkpoints/
├── vae/ # VAE checkpoints
└── latent_cnn/ # LatentCNN checkpoints
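Of the metrics listed for training/metrics.py, PSNR and the VAE's KL term have short closed forms. The sketch below uses the standard textbook definitions on flat lists of floats; it assumes pixel values in [0, 1] and a diagonal Gaussian posterior regularized toward a standard normal prior, which is the usual VAE setup but is not confirmed for this project. The function names are hypothetical, and the project's tensor-based implementation (batching, reduction over dimensions) may differ.

```python
import math
from typing import Sequence


def psnr(pred: Sequence[float], target: Sequence[float], max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(max_val**2 / MSE)."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return math.inf  # identical inputs
    return 10.0 * math.log10(max_val**2 / mse)


def gaussian_kl(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL(N(mu, exp(logvar)) || N(0, 1)) for a diagonal Gaussian, summed over dims."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


print(psnr([0.0, 0.0], [0.1, 0.1]))  # ~20.0 dB, since MSE = 0.01
print(gaussian_kl([0.0], [0.0]))     # 0.0: posterior equals the prior
print(gaussian_kl([1.0], [0.0]))     # 0.5
```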
This project is licensed under the MIT License.
