A lightweight implementation of neural networks from 'scratch' in Rust. It supports both training and inference. For now, it focuses on the MNIST dataset where the task is digit recognition on 28x28 black-and-white images.
Dependencies:
- The core dependency is ndarray, for efficient operations on n-dimensional arrays.
Layers available:
- Fully-connected layer
- Convolution, max pooling, flatten layers
- ReLU and softmax activations
Training:
- Stochastic gradient descent (SGD) with cross-entropy loss.
- Persistence: save and load models via JSON checkpoints (courtesy of serde).
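As an illustration of the training objective listed above, here is a minimal sketch of softmax, cross-entropy loss, and a plain SGD update in dependency-free Rust. The function names (`softmax`, `cross_entropy`, `grad_logits`, `sgd_step`) are hypothetical and chosen for this example; the engine's actual code is built on ndarray and may be organized differently.

```rust
/// Numerically stable softmax over a slice of logits.
fn softmax(logits: &[f32]) -> Vec<f32> {
    // Subtract the maximum logit so that exp() cannot overflow.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Cross-entropy loss for one sample with integer class label `target`.
fn cross_entropy(probs: &[f32], target: usize) -> f32 {
    // Clamp away from zero so that ln() stays finite.
    -(probs[target].max(f32::MIN_POSITIVE)).ln()
}

/// Gradient of the loss with respect to the logits: softmax(z) - one_hot(target).
fn grad_logits(probs: &[f32], target: usize) -> Vec<f32> {
    probs
        .iter()
        .enumerate()
        .map(|(i, &p)| if i == target { p - 1.0 } else { p })
        .collect()
}

/// One plain SGD update: param <- param - lr * grad.
fn sgd_step(params: &mut [f32], grads: &[f32], lr: f32) {
    for (p, g) in params.iter_mut().zip(grads) {
        *p -= lr * g;
    }
}
```

Combining softmax with cross-entropy gives the simple gradient `softmax(z) - one_hot(target)`, which is why the two are usually implemented together.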
ONNX export:
- Export models to the ONNX format for cross-platform deployment.
To implement the convolution operation efficiently, the img2col method is used. The idea is to turn this operation into a single matrix multiplication, which benefits from optimized kernels (see GEMM).
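The idea can be sketched in a few lines of dependency-free Rust. This is a simplified illustration, assuming a single-channel image, stride 1, and no padding; the names `im2col`, `matmul`, and `conv2d` are hypothetical and do not reflect the engine's actual ndarray-based API.

```rust
/// Unfold a single-channel `h x w` image (row-major) into a matrix whose rows
/// are the flattened `k x k` patches (stride 1, no padding).
/// Returns (matrix, number of rows, number of columns).
fn im2col(img: &[f32], h: usize, w: usize, k: usize) -> (Vec<f32>, usize, usize) {
    let (out_h, out_w) = (h - k + 1, w - k + 1);
    let (rows, cols) = (out_h * out_w, k * k);
    let mut m = Vec::with_capacity(rows * cols);
    for oy in 0..out_h {
        for ox in 0..out_w {
            // Copy the patch whose top-left corner is (oy, ox).
            for ky in 0..k {
                for kx in 0..k {
                    m.push(img[(oy + ky) * w + (ox + kx)]);
                }
            }
        }
    }
    (m, rows, cols)
}

/// Naive row-major matrix product: (n x m) * (m x p) -> (n x p).
/// A real implementation would call an optimized GEMM kernel here.
fn matmul(a: &[f32], b: &[f32], n: usize, m: usize, p: usize) -> Vec<f32> {
    let mut c = vec![0.0; n * p];
    for i in 0..n {
        for j in 0..m {
            let aij = a[i * m + j];
            for l in 0..p {
                c[i * p + l] += aij * b[j * p + l];
            }
        }
    }
    c
}

/// Convolution (cross-correlation, as in most deep learning libraries) of one
/// image with `f` filters stored as a (k*k) x f row-major matrix.
/// The result has one row per output pixel and one column per filter.
fn conv2d(img: &[f32], h: usize, w: usize, filters: &[f32], k: usize, f: usize) -> Vec<f32> {
    let (patches, rows, cols) = im2col(img, h, w, k);
    matmul(&patches, filters, rows, cols, f)
}
```

The unfolded matrix duplicates overlapping pixels, so the method trades extra memory for the ability to run all filters over all patches in one matrix product.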
Resource:
To train the model on MNIST, use the train command. You can pass standard
hyper-parameters as well as paths for checkpointing:
cargo run --release -- train \
--learning-rate 0.001 \
--batch-size 64 \
--nb-epochs 30 \
--train-data-dir data/ \
--checkpoint-folder ckpt/ \
--checkpoint-stride 5 \
--loss-csv-path loss.csv
Note: use the --help flag to get more info (cargo run -- train --help).
You can load a saved model and run inference on a bitmap image using the run
command:
cargo run --release -- run \
--checkpoint <path-to-model-checkpoint> \
--image-path <path-to-image.bin>
Models defined with this Rust engine can be exported to the ONNX format. Use the
export command:
cargo run -- export \
--checkpoint-path <path-to-checkpoint> \
--onnx-path <output-onnx-file-path>
Resources:
- MMAP blog post: ONNX introduction
- Doc:
We currently only have a single test. It tries to train a small CNN on a single batch of random data using SGD with momentum. The goal is to overfit the dataset within a given optimization 'step budget'.
This test ensures that we don't break the core training mechanics, at least with this optimizer.
To run it, you can use cargo test. However, I recommend testing in release mode
for faster execution and using the --nocapture flag to see debug prints:
cargo test --release -- --nocapture
Note: the test was slightly flaky due to potential bad luck in the network initialization. A simple retry strategy is implemented, which makes it reliable.
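A retry strategy of this kind can be sketched as follows. This is only an illustration under assumptions: the helper name `retry` and the idea of passing the attempt index as a seed are hypothetical, and the actual test may structure its retries differently.

```rust
/// Run `attempt` up to `max_attempts` times and report whether any run succeeded.
/// `attempt` receives the attempt index, which can double as an RNG seed so that
/// each retry starts from a different network initialization.
fn retry<F: FnMut(u64) -> bool>(max_attempts: u64, mut attempt: F) -> bool {
    // `any` stops at the first successful attempt.
    (0..max_attempts).any(|i| attempt(i))
}
```

In a test, the closure would train the small CNN for the step budget and return whether the final loss fell below the overfitting threshold; the test then asserts that `retry` returned `true`.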
