This is Tiny Recursion Model (TRM), a research codebase for recursive reasoning on puzzle-solving tasks (ARC-AGI, Sudoku, Maze). TRM achieves 45% on ARC-AGI-1 using only 7M parameters through iterative answer refinement with cycles of latent reasoning (H_cycles) and answer updates (L_cycles). The core innovation is simplifying recursive reasoning to its essence: recursively update latent state z, then update answer y, repeat.
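The "update z, then update y, repeat" loop can be sketched as follows (a minimal sketch; `latent_update`, `answer_update`, and the cycle-count parameter names are illustrative stand-ins, not names from the repo):

```python
def trm_refine(latent_update, answer_update, x, y, z,
               answer_cycles=3, latent_cycles=6):
    """Sketch of TRM's core recursion: repeatedly refine the latent
    state z, then use it to improve the answer y."""
    for _ in range(answer_cycles):
        for _ in range(latent_cycles):
            z = latent_update(x, y, z)  # latent reasoning step
        y = answer_update(y, z)         # answer refinement step
    return y, z
```

The key design choice is that the same small network is reused across every cycle, so depth of reasoning comes from iteration, not parameter count.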
Key architectural components:
- `models/recursive_reasoning/trm.py`: Main TRM model with ACT (Adaptive Computation Time) halting
- `models/recursive_reasoning/hrm.py`: Hierarchical Reasoning Model baseline (more complex, hierarchical approach)
- `puzzle_dataset.py`: Unified dataset interface for all puzzle types with augmentation support
- `pretrain.py`: Main training script with distributed training support
All experiments are configured via Hydra with hierarchical YAML configs:
- `config/cfg_pretrain.yaml`: Base training config (data paths, hyperparams, evaluators)
- `config/arch/*.yaml`: Model architectures (trm, hrm, transformers_baseline, etc.)
- Each arch config defines `name: module.path@ClassName` for dynamic loading, e.g. `recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1`
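The `module.path@ClassName` convention can be resolved with a helper along these lines (a sketch of the idea; the repo's actual loader and package root may differ):

```python
import importlib

def resolve_class(identifier: str, package: str = "models"):
    """Resolve a 'module.path@ClassName' config string to a class.

    E.g. 'recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1'
    resolves to that class inside models/recursive_reasoning/trm.py.
    """
    module_path, class_name = identifier.split("@")
    full_path = f"{package}.{module_path}" if package else module_path
    module = importlib.import_module(full_path)
    return getattr(module, class_name)
```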
Override config from the CLI:

```bash
python pretrain.py arch=trm data_paths="[data/sudoku]" lr=1e-4 arch.H_cycles=3
```

TRM uses two separate optimizers with different learning rates:
- `CastedSparseEmbeddingSignSGD_Distributed`: For puzzle embeddings (per-puzzle learned vectors)
  - Operates on `model.model.puzzle_emb.buffers()`; note the sparse embedding uses buffers, not parameters
  - Custom SignSGD with all-reduce for distributed training
- `AdamATan2`: For model weights (AdamW variant with atan2-based adaptive learning rate)
The `puzzle_emb_lr` is typically 10-100x higher than the base `lr` (e.g., 1e-2 vs 1e-4).
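The SignSGD idea behind the sparse-embedding optimizer can be shown in isolation (a toy sketch of the update rule only, not the repo's distributed implementation):

```python
def signsgd_update(weights, grads, lr=1e-2):
    """One SignSGD step: move each weight a fixed distance in the
    direction opposite the gradient's sign, ignoring its magnitude."""
    sign = lambda g: (g > 0) - (g < 0)
    return [w - lr * sign(g) for w, g in zip(weights, grads)]
```

Because the step size is constant regardless of gradient magnitude, a higher learning rate (like `puzzle_emb_lr`) directly controls how fast the embeddings move.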
Each puzzle gets a unique learned embedding via `CastedSparseEmbedding`:
- Training: Copies the relevant embeddings to a `local_weights` buffer, computes gradients, then aggregates back into the main `weights`
- Inference: Direct lookup from `weights`, no gradient tracking
- Critical: The embedding dimension can differ from the model dimension; it is padded/truncated to fit via `puzzle_emb_len`
The model learns when to halt via Q-learning:
- `q_halt_logits`: Binary classifier predicting whether the current answer is correct
- `q_continue_logits`: Bootstrap target for continuing (disabled by default with `no_ACT_continue: True`)
- Loss components: `lm_loss` (cross-entropy) + `q_halt_loss` (binary halting) + optional `q_continue_loss`
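In numpy terms the combined loss looks roughly like this (a sketch using a standard softmax rather than the repo's stablemax variant; the function name and argument layout are illustrative):

```python
import numpy as np

IGNORE_LABEL_ID = -100

def act_loss(logits, labels, q_halt_logits, answer_is_correct):
    """lm_loss (token cross-entropy, skipping ignored positions) plus
    q_halt_loss (binary cross-entropy vs. an 'answer correct' target)."""
    # Token-level log-softmax, then average NLL over non-ignored tokens
    logp = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    idx = np.nonzero(labels != IGNORE_LABEL_ID)[0]
    lm_loss = -logp[idx, labels[idx]].mean()
    # Binary halting loss: should the model stop now?
    p = 1.0 / (1.0 + np.exp(-q_halt_logits))
    t = answer_is_correct.astype(float)
    q_halt_loss = -(t * np.log(p) + (1 - t) * np.log(1 - p)).mean()
    return lm_loss + q_halt_loss
```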
Training uses manual DDP-style patterns rather than the DDP wrapper:
- Manually broadcasts parameters from rank 0 after initialization
- Manual `all_reduce` on gradients in `train_batch()` instead of the DDP wrapper
- Per-GPU batch size: `global_batch_size // world_size`
- Launch with `torchrun`, which sets the `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` environment variables
- `torch.compile()` is enabled by default (disable with `DISABLE_COMPILE=1`)
- Checkpoints are saved as state_dicts at `checkpoint_path/step_{N}`
- Checkpoint loading handles puzzle embedding resizing: on a shape mismatch, it resets to the mean of the old embeddings
- Compiled model keys are prefixed with `_orig_mod.` in the state_dict
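Loading a compiled checkpoint into an uncompiled model therefore needs that prefix stripped, e.g. (a hypothetical helper, not from the repo):

```python
def strip_compile_prefix(state_dict, prefix="_orig_mod."):
    """Drop the torch.compile wrapper prefix from state_dict keys so a
    checkpoint saved from a compiled model loads into a plain one."""
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}
```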
All datasets must be pre-processed into a standardized format:

```bash
# ARC-AGI with dihedral augmentation (8 symmetries)
python -m dataset.build_arc_dataset \
    --input-file-prefix kaggle/combined/arc-agi \
    --output-dir data/arc1concept-aug-1000 \
    --subsets training evaluation concept \
    --test-set-name evaluation \
    --num-aug 1000

# Sudoku/Maze
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000
```

Output structure: each dataset folder contains:
- `metadata.json`: Vocab size, sequence length, padding IDs, number of puzzles
- `identifiers.json`: Maps puzzle IDs to names
- `train.npz`, `test.npz`: Packed sequences with `examples`, `puzzle_indices`, `group_indices` arrays
- `test_puzzles.json`: Original test examples for evaluation
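A split can then be loaded along these lines (a sketch assuming the file names above; `load_split` is a hypothetical helper, not a repo function):

```python
import json
from pathlib import Path

import numpy as np

def load_split(folder, split="train"):
    """Load one dataset split: metadata dict plus the packed npz arrays."""
    folder = Path(folder)
    meta = json.loads((folder / "metadata.json").read_text())
    data = np.load(folder / f"{split}.npz")
    return meta, {k: data[k] for k in data.files}
```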
- `L_level`: "Lower level"; operates on the latent state `z_L` and updates the answer
- `H_level`: "Higher level" (HRM only); hierarchical latent `z_H`
- TRM simplifies by using only one latent level with recursive updates
- `SwiGLU`: Gated FFN activation (`hidden_size` → `expansion*hidden_size` → `hidden_size`)
- `CastedLinear`/`CastedEmbedding`: Auto-casting to `forward_dtype` (bfloat16 by default)
- `rms_norm`: RMS normalization, post-norm architecture (norm after the residual)
- `mlp_t`: Optional flag to replace attention with an MLP over the sequence dimension
- The `pos_encodings` config controls positional encodings: `"rope"` (RoPE, default), `"learned"`, or `"none"`
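The SwiGLU shape flow can be written out in numpy (math sketch only; the repo's implementation is a PyTorch module):

```python
import numpy as np

def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def swiglu_ffn(x, w_gate, w_up, w_down):
    """SwiGLU FFN: SiLU-gated expansion, elementwise product, then
    projection back. Shapes: x (., h), w_gate/w_up (h, e*h), w_down (e*h, h)."""
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down
```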
Evaluators are registered in config and run during eval intervals:
```yaml
evaluators:
  - name: arc@ARC  # module.path@ClassName
    submission_K: 2
    pass_Ks: [1, 2, 5, 10]
```

Each evaluator:
- Declares `required_outputs` (e.g., `{"preds", "q_halt_logits", "puzzle_identifiers"}`)
- Implements the `begin_eval()`, `update_batch()`, `result()` lifecycle
- The ARC evaluator aggregates predictions with inverse augmentation and voting
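That lifecycle can be sketched with a toy evaluator (hypothetical class; the real ARC evaluator additionally inverse-transforms augmentations and votes across them):

```python
class ExactMatchEvaluator:
    """Minimal evaluator following the begin_eval/update_batch/result shape."""
    required_outputs = {"preds", "labels"}

    def begin_eval(self):
        # Reset accumulators at the start of each eval interval
        self.correct = 0
        self.total = 0

    def update_batch(self, outputs):
        # Accumulate per-example exact-match counts
        for pred, label in zip(outputs["preds"], outputs["labels"]):
            self.correct += int(pred == label)
            self.total += 1

    def result(self):
        return {"exact_accuracy": self.correct / max(self.total, 1)}
```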
```bash
# Single GPU
python pretrain.py arch=trm data_paths="[data/sudoku]" +run_name=my_experiment

# Multi-GPU (4 GPUs)
torchrun --nproc_per_node=4 pretrain.py arch=trm data_paths="[data/arc1]"

# Resume from checkpoint
python pretrain.py load_checkpoint=checkpoints/my_project/my_run/step_10000
```

- Use `DISABLE_COMPILE=1` to disable torch.compile for better stack traces
- Check WandB logs for metrics: `train/accuracy`, `train/exact_accuracy`, `train/steps`, `ARC/pass@K`
- EMA (Exponential Moving Average) can be enabled with `ema=True ema_rate=0.999` for smoother convergence
- Create `models/recursive_reasoning/my_model.py` with `MyModel_Inner` and a config class
- Add an architecture YAML in `config/arch/my_model.yaml`:

  ```yaml
  name: recursive_reasoning.my_model@MyModel_ACTV1
  loss:
    name: losses@ACTLossHead
    loss_type: stablemax_cross_entropy
  # ... hyperparams
  ```

- Models must implement `initial_carry()` and `forward()` returning `(new_carry, outputs)`
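A minimal skeleton of that contract (a sketch only; real models are `nn.Module`s whose carry holds latent tensors, and the output keys here are illustrative):

```python
class MyModel_ACTV1:
    """Illustrates the initial_carry()/forward() contract only."""

    def initial_carry(self, batch):
        # The carry threads recurrent state between ACT steps,
        # e.g. the latent z, the current answer y, and a step counter.
        return {"z": None, "y": None, "steps": 0}

    def forward(self, carry, batch):
        # Each call advances the recursion one step and returns
        # (new_carry, outputs); outputs feed the loss head and evaluators.
        new_carry = {**carry, "steps": carry["steps"] + 1}
        outputs = {"preds": batch["inputs"], "q_halt_logits": 0.0}
        return new_carry, outputs
```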
- `stablemax_cross_entropy`: Custom numerically stable softmax cross-entropy (not standard PyTorch)
- Non-autoregressive: All models are parallel (no causal masking) and predict the full output sequence at once
- Ignore label: Use `IGNORE_LABEL_ID = -100` for padding, not just for loss masking
- Dihedral transforms: Augmentations use 8 symmetries; the evaluator must inverse-transform predictions back
- WandB logging: Run `wandb login` before training to sync metrics