This project is a PyTorch implementation of "Self-classifying MNIST Digits: Achieving Distributed Coordination with Neural Cellular Automata" (2020) by Randazzo, Mordvintsev, Niklasson, Levin, and Greydanus.
This repo replicates the functionality of the official TensorFlow Jupyter-notebook implementation while adding comments and a refactored structure that make the code easier to digest and extend.
- The model is a 3 × 3 convolution followed by two 1 × 1 convolutions with ReLU nonlinearities.
- The model has a total of 22,499 parameters.
- Cells have a stochastic firing rate. This is a form of dropout and acts as a regularizer.
- Cells that are not part of the digit are masked out and never updated.
- There are 20 channels per cell: 1 channel to represent the liveness of cells, 9 general purpose communication channels, and 10 output channels for digit classification.
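One concrete configuration consistent with the bullets above can be sketched in a few lines. The hidden width of 80 and the 19-channel output (every channel except the fixed liveness mask) are assumptions, chosen because they reproduce the stated parameter count exactly:

```python
import torch.nn as nn

# Hypothetical sketch of the update network. Widths are assumptions:
# hidden=80 and a 19-channel residual output yield 22,499 parameters.
N_CHANNELS = 20  # 1 liveness + 9 communication + 10 classification
HIDDEN = 80      # assumed hidden width

model = nn.Sequential(
    nn.Conv2d(N_CHANNELS, HIDDEN, kernel_size=3, padding=1),  # 3x3 perception
    nn.ReLU(),
    nn.Conv2d(HIDDEN, HIDDEN, kernel_size=1),                 # per-cell update
    nn.ReLU(),
    nn.Conv2d(HIDDEN, N_CHANNELS - 1, kernel_size=1),         # residual for the 19 non-liveness channels
)

n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 22499
```

Counting by hand: 20·9·80 + 80 = 14,480 for the 3×3 layer, 80·80 + 80 = 6,480 and 80·19 + 19 = 1,539 for the two 1×1 layers, totaling 22,499.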
Channel layout
--------------
|liveness|c1|c2|c3|c4|c5|c6|c7|c8|c9|o1|o2|o3|o4|o5|o6|o7|o8|o9|o10|

- Twenty update ticks are simulated before evaluating the loss at each training step.
- A sample pool is maintained. Some samples persist between training steps, which encourages stability of the digits over time and acts as a learning curriculum.
- Some samples in the pool are replaced with fresh training samples each training step. This avoids catastrophic forgetting from earlier in the curriculum.
- Each training step, some samples in the pool are mutated by replacing the digit with a new digit. Where the two digits overlap, cell states are copied from the first. This encourages "dynamic homeostasis" where a digit can reclassify itself after significant changes to its form.
- The additive/residual RNN avoids the vanishing/exploding gradient problem and works with Adam to speed up training convergence.
- Gaussian noise is added to residuals as a form of regularization.
- A pixel-wise L2 loss on one-hot targets worked better than cross-entropy loss. Cross-entropy rewards an ever-growing margin between the true logit and the others, whereas with L2 the gradient shrinks to zero once predictions reach the one-hot target, enabling convergence to a fixed point without flickering.
- The learning rate schedule can significantly help later stages of training converge more quickly.
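A single update tick described in the bullets above can be sketched as follows. This is a minimal illustration, not the repo's implementation; the function name, `fire_rate`, and `noise_std` values are hypothetical:

```python
import torch

def ca_step(state, model, alive_mask, fire_rate=0.5, noise_std=0.02):
    """One hypothetical update tick: compute a residual update, add Gaussian
    noise, apply stochastic firing (dropout-style), and freeze dead cells.
    fire_rate and noise_std are illustrative values, not the repo's settings."""
    residual = model(state)                                       # (B, 19, H, W) update for non-liveness channels
    residual = residual + noise_std * torch.randn_like(residual)  # noise as a regularizer
    # Each cell fires independently with probability fire_rate.
    fire = (torch.rand(state.shape[0], 1, *state.shape[2:]) < fire_rate).float()
    update = residual * fire * alive_mask                         # dead cells are never updated
    new_state = state.clone()
    new_state[:, 1:] += update                                    # additive/residual update; channel 0 is liveness
    return new_state
```

Because the update is additive, gradients flow through the identity path across the twenty ticks, which is the residual-RNN property mentioned above.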
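The mutation step from the sample-pool bullets can be sketched like this (function and argument names are illustrative, not the repo's API):

```python
import torch

def mutate(sample_state, new_digit_mask, old_digit_mask):
    """Swap in a new digit shape, copying cell state only where the old and
    new digits overlap; everywhere else the state resets to zero.
    sample_state: (C, H, W); masks: (H, W) binary tensors."""
    overlap = new_digit_mask * old_digit_mask
    mutated = torch.zeros_like(sample_state)
    mutated[0] = new_digit_mask                 # channel 0: liveness follows the new digit
    mutated[1:] = sample_state[1:] * overlap    # keep old state on the overlap only
    return mutated
```

Cells in the overlap wake up already "believing" the old classification, so the model is forced to learn to revise that belief, which is the dynamic-homeostasis behavior described above.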
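The pixel-wise L2 loss can be sketched as below, under the assumption (from the channel layout above) that the last 10 channels hold the classification outputs and that only living cells contribute to the loss:

```python
import torch
import torch.nn.functional as F

def pixel_l2_loss(state, labels):
    """Pixel-wise L2 loss against one-hot targets, masked to living cells.
    An illustrative sketch; the masking and reduction details are assumptions."""
    logits = state[:, -10:]                        # (B, 10, H, W) classification channels
    target = F.one_hot(labels, 10).float()         # (B, 10) one-hot per image
    target = target[:, :, None, None].expand_as(logits)
    alive = state[:, :1]                           # (B, 1, H, W) liveness mask
    return ((logits - target) ** 2 * alive).mean()
```

Once every living cell's output channels match the one-hot target, both the loss and its gradient vanish, so the cell states can settle instead of flickering.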
Install using uv
----------------

```shell
uv sync
```

Run the `train` command to start training, and see `train.py` for configuration options. You should see loss history figures appear periodically in a directory named `outputs_train`.

```shell
ca train
```

Run the `evaluate` command to generate the video below, and see `evaluate.py` for configuration options. You can specify a custom model checkpoint or use a pre-trained checkpoint included in this repo. Outputs will be saved in a directory named `outputs_evaluate`.

```shell
ca evaluate checkpoints/model-steps-100000
```

Here's another command to generate a slowed-down animation of a single digit, which helps to visualize the local message-passing dynamics. You can see how a region of a digit can start off "confused" about its identity and then be "convinced" by neighboring cells to converge on the correct classification.

```shell
ca evaluate checkpoints/model-steps-100000 --batch-size 1 --no-slider --zoom-scale 20 --fps 3 --n-stages 1 --n-steps-per-stage 30 --seed 84 --annotations
```

| Framework | Hardware | Training speed | Est. total train time \* |
|---|---|---|---|
| TensorFlow | MacBook M2 CPU | ~400 steps/min | ~4.2 hrs |
| PyTorch | MacBook M2 CPU | ~150 steps/min | ~11 hrs |
| PyTorch | MacBook M2 w/ MPS | ~720 steps/min | ~2.3 hrs |
| PyTorch | H100 GPU | ~2000 steps/min | ~0.83 hrs |
\* Assuming 100,000 training steps total