implemented:
- llama (gguf parsing and dequantization (dropped my own llama.cpp dequantization implementation and used candle's instead), tokenizer, rope, kv-cache, rmsnorm, attention)
- cnn (im2col conv: 26ms vs 173ms for the naive loop) / vit / llava / autoencoder inference with onnx (parser and graph executor, but only the ops i needed)
- training - autograd for dynamic graphs, graph compiling for static ones
- also implemented mini-cubecl to study how to properly use proc macros and transpile rust code to wgsl (dynamic, can be swapped for any other runtime like cuda)
- ndarray_lite - a study reimplementation of ndarray to get a full intuition
- wrote some gpu kernels to practice using the original cubecl - attention, flash_attention, sigmoid, softmax, blelloch scan, ssm (simple mamba)
- and cli's to test things out (cargo run -p inference-cli etc.)
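for context, the kv-cache in the llama path works roughly like this - a toy single-head sketch where each decode step appends one k/v row and attends over everything cached (names and shapes here are illustrative, not the actual implementation):

```rust
// Toy single-head kv-cache: keys/values for past tokens are stored once,
// each decode step appends the new token's k/v and attends q over the cache.
struct KvCache {
    keys: Vec<Vec<f32>>,   // one entry per cached token
    values: Vec<Vec<f32>>,
}

impl KvCache {
    fn new() -> Self {
        Self { keys: Vec::new(), values: Vec::new() }
    }

    /// One decode step: append this token's k/v, then attend q over the cache.
    fn attend(&mut self, q: &[f32], k: Vec<f32>, v: Vec<f32>) -> Vec<f32> {
        self.keys.push(k);
        self.values.push(v);
        // scores_i = q . k_i, then a numerically stable softmax over them
        let scores: Vec<f32> = self
            .keys
            .iter()
            .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>())
            .collect();
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        // output = softmax-weighted sum of cached values
        let mut out = vec![0.0; q.len()];
        for (w, v) in exps.iter().zip(&self.values) {
            for d in 0..out.len() {
                out[d] += w / sum * v[d];
            }
        }
        out
    }
}

fn main() {
    let mut cache = KvCache::new();
    // first token: the cache holds one entry, so attention returns its value exactly
    let out = cache.attend(&[1.0, 0.0], vec![1.0, 0.0], vec![0.5, 0.25]);
    assert_eq!(out, vec![0.5, 0.25]);
    assert_eq!(cache.keys.len(), 1);
    println!("{:?}", out);
}
```

the point of the cache is that each step is O(seq_len) instead of recomputing attention over the whole prefix from scratch.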
there is no Backend god-trait to reuse training and inference logic like in burn, because i didn't plan to connect those parts
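the 26ms-vs-173ms cnn gap comes from im2col turning convolution into one big matmul; a minimal single-channel sketch of the idea (illustrative names, stride 1, no padding - not the project's actual api):

```rust
// im2col: unroll input patches into rows so the convolution becomes one matmul.

/// Unroll an HxW single-channel input into rows of k*k patches
/// (stride 1, no padding). Output is (out_h*out_w) x (k*k), row-major.
fn im2col(input: &[f32], h: usize, w: usize, k: usize) -> (Vec<f32>, usize, usize) {
    let out_h = h - k + 1;
    let out_w = w - k + 1;
    let mut cols = Vec::with_capacity(out_h * out_w * k * k);
    for oy in 0..out_h {
        for ox in 0..out_w {
            for ky in 0..k {
                for kx in 0..k {
                    cols.push(input[(oy + ky) * w + (ox + kx)]);
                }
            }
        }
    }
    (cols, out_h, out_w)
}

/// Convolve by multiplying each unrolled row with the flattened kernel
/// (a matrix-vector product, since there is a single output channel here).
fn conv_via_im2col(input: &[f32], h: usize, w: usize, kernel: &[f32], k: usize) -> Vec<f32> {
    let (cols, out_h, out_w) = im2col(input, h, w, k);
    let patch = k * k;
    (0..out_h * out_w)
        .map(|r| {
            cols[r * patch..(r + 1) * patch]
                .iter()
                .zip(kernel)
                .map(|(a, b)| a * b)
                .sum::<f32>()
        })
        .collect()
}

fn main() {
    // 3x3 input, 2x2 kernel of ones -> each output is the sum of a 2x2 patch
    let input = [1., 2., 3., 4., 5., 6., 7., 8., 9.];
    let kernel = [1., 1., 1., 1.];
    let out = conv_via_im2col(&input, 3, 3, &kernel, 2);
    assert_eq!(out, vec![12., 16., 24., 28.]);
    println!("{:?}", out);
}
```

the speedup mostly comes from the matmul's cache-friendly contiguous access and the vectorization it enables, at the cost of duplicating overlapping patches in memory.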
tested models (hugging face):
gguf llama - tinyllama-15M-stories-Q3_K_M.gguf
gguf moe - Tiny-Moe.Q2_K.gguf
onnx vision - mnist-12-int8.onnx
onnx vis transformer - vit_base_patch16_224.augreg2_in21k_ft_in1k-ONNX___model_fp16.onnx
onnx llava - onevision-qwen2-0.5b (vision_encoder + embed_tokens + decoder_model_merged)
also got a funny schrödinger's cat effect - monitoring the llama output of llama_cpp_inference_ref.rs via callbacks altered the output itself
the mini-cubecl macros work kinda like this:
let x = a + 5; ->>> let x_var = builder.add(a_var, builder.constant_u32(5));
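that one-line expansion can be sketched end-to-end with a tiny builder that records ops instead of computing them, then emits wgsl-ish text (all types and names here are illustrative stand-ins, not cubecl's actual api):

```rust
// Toy IR builder: the macro rewrites `a + 5` into calls that *record* the
// operation, so the recorded program can later be emitted as WGSL (or any
// other target) instead of being executed directly.

#[derive(Debug, Clone, Copy, PartialEq)]
struct Var(usize); // index of the op that produced this value

#[derive(Debug, PartialEq)]
enum Op {
    ConstU32(u32),
    Add(Var, Var),
}

#[derive(Default)]
struct Builder {
    ops: Vec<Op>,
}

impl Builder {
    fn push(&mut self, op: Op) -> Var {
        self.ops.push(op);
        Var(self.ops.len() - 1)
    }
    fn constant_u32(&mut self, v: u32) -> Var {
        self.push(Op::ConstU32(v))
    }
    fn add(&mut self, lhs: Var, rhs: Var) -> Var {
        self.push(Op::Add(lhs, rhs))
    }
    /// Trivial WGSL-ish codegen over the recorded ops.
    fn emit(&self) -> String {
        self.ops
            .iter()
            .enumerate()
            .map(|(i, op)| match op {
                Op::ConstU32(v) => format!("let v{i}: u32 = {v}u;"),
                Op::Add(a, b) => format!("let v{i} = v{} + v{};", a.0, b.0),
            })
            .collect::<Vec<_>>()
            .join("\n")
    }
}

fn main() {
    // `let x = a + 5;` after macro expansion:
    let mut builder = Builder::default();
    let a_var = builder.constant_u32(3); // stand-in for the kernel input `a`
    let five = builder.constant_u32(5);
    let x_var = builder.add(a_var, five);
    assert_eq!(x_var, Var(2));
    println!("{}", builder.emit());
}
```

since the builder only records an op list, swapping wgsl for another target is just a different `emit` - which is the "can replace with any other runtime" part.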
some logs (debug build):
CPU (ndarray): 316.39 ms
GPU (CubeCL): 0.81 ms
GPU (Burn native): 0.35 ms
CubeCL vs CPU: 391.14x speedup
Burn vs CPU: 894.01x speedup
CubeCL vs Burn: 0.44x
at 4096 seq len:
naive attention with 2d distribution - 75ms
flash attention with global mem - 341ms
- tiling with shared mem - 278ms
- with causal - 257ms
now getting "Speedup: 0.30x | Memory saved: 99%" (1mb of memory for flash vs 135mb for naive)
sooo, to support 2d-parallel execution instead of 1d we'd have to move to "flash attention v2", which merges the online softmax into a gemm instead of per-row dot products, etc
or maybe first add multi-head batching, mixed precision support (is that even possible in cubecl?), tensor cores, line vectorization, subgroup sync ("planes" in cubecl terms - a wgpu experimental feature, but no problems on the vulkan/hip/cuda backends), double buffering (or warp specialization in v3), etc.
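the online softmax mentioned above is the core trick that lets flash attention process a row in tiles without ever materializing it - a running max and running sum get rescaled whenever a new tile raises the max. a cpu sketch of just that piece (real flash attention also carries the output accumulator through the same rescaling):

```rust
/// Numerically stable softmax whose normalizer is computed one tile at a time,
/// as in flash attention's online softmax. `m` is the running max, `l` the
/// running sum of exp(x - m); both are rescaled when a new tile raises the max.
fn online_softmax(tiles: &[&[f32]]) -> Vec<f32> {
    let mut m = f32::NEG_INFINITY; // running max
    let mut l = 0.0f32;            // running sum of exp(x - m)
    for tile in tiles {
        let tile_max = tile.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let new_m = m.max(tile_max);
        // rescale the old sum to the new max, then add this tile's terms
        l = l * (m - new_m).exp() + tile.iter().map(|x| (x - new_m).exp()).sum::<f32>();
        m = new_m;
    }
    // second pass just to produce the probabilities for the demo;
    // the kernel never does this pass, it folds values in online instead
    let mut probs = Vec::new();
    for tile in tiles {
        for x in tile.iter() {
            probs.push((x - m).exp() / l);
        }
    }
    probs
}

fn main() {
    // same result as a full-row softmax, but the normalizer was built tile by tile
    let probs = online_softmax(&[&[1.0, 2.0], &[3.0, 4.0]]);
    let total: f32 = probs.iter().sum();
    assert!((total - 1.0).abs() < 1e-5);
    println!("{:?}", probs);
}
```

v2's change is essentially doing this merge inside the gemm loop over k/v tiles, which is what makes the 2d work distribution possible.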
loading onnx model from: mnist-12-int8.onnx
[PARSER] found Input3_Convolution28_QuantizeLinear node, op_type QuantizeLinear
[PARSER] found Convolution28_quant node, op_type QLinearConv
[PARSER] auto-calculated pads for Convolution28_quant: [2, 2, 2, 2]
[PARSER] found Pooling66_quant node, op_type MaxPool
[PARSER] found Convolution110_quant node, op_type QLinearConv
[PARSER] auto-calculated pads for Convolution110_quant: [2, 2, 2, 2]
[PARSER] found Convolution110_DequantizeLinear_0 node, op_type DequantizeLinear
[PARSER] found Pooling160 node, op_type MaxPool
[PARSER] found Times212_reshape0 node, op_type Reshape
[PARSER] found Pooling160_Output_0_reshape0_gemm_MatMul_QuantizeLinear node, op_type QuantizeLinear
[PARSER] found gemm_MatMul_quant node, op_type QLinearMatMul
[PARSER] found gemm_Add_quant node, op_type QLinearAdd
[PARSER] found Plus214_Output_0_DequantizeLinear_0 node, op_type DequantizeLinear
successfully built vision model with 11 nodes
running inference on mnist_3_28x28.png...
[EXECUTOR] running node 'Input3_Convolution28_QuantizeLinear', op: QuantizeLinear
[EXECUTOR] running node 'Convolution28_quant', op: Conv(Conv { stride: (1, 1), padding: (2, 2) })
[EXECUTOR] running node 'Pooling66_quant', op: MaxPool(MaxPool2d { kernel_shape: (2, 2), stride: (2, 2), padding: (0, 0) })
[EXECUTOR] running node 'Convolution110_quant', op: Conv(Conv { stride: (1, 1), padding: (2, 2) })
[EXECUTOR] running node 'Convolution110_DequantizeLinear_0', op: DequantizeLinear
[EXECUTOR] running node 'Pooling160', op: MaxPool(MaxPool2d { kernel_shape: (3, 3), stride: (3, 3), padding: (0, 0) })
[EXECUTOR] running node 'Times212_reshape0', op: Reshape
[EXECUTOR] running node 'Pooling160_Output_0_reshape0_gemm_MatMul_QuantizeLinear', op: QuantizeLinear
[EXECUTOR] running node 'gemm_MatMul_quant', op: MatMul
[EXECUTOR] running node 'gemm_Add_quant', op: Add
[EXECUTOR] running node 'Plus214_Output_0_DequantizeLinear_0', op: DequantizeLinear
model prediction: 3
takes ~30 SECONDS to naively run the parsed onnx vision transformer (vit_base_patch16_224.augreg2_in21k_ft_in1k-ONNX___model_fp32)
so i have to replace the eager executor with a static one (graph-compiler crate) and add simd/gpu usage
for comparison, python onnxruntime takes... like a second
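the eager-to-static switch boils down to computing the execution order once at load time instead of resolving dependencies on every run; a minimal sketch of that plan-compilation step (node/tensor names are made up, and this is the general technique, not the real graph-compiler crate's api):

```rust
// Kahn-style topological sort over a tiny ONNX-like graph: a node becomes
// ready once all the tensors it consumes have been produced. The resulting
// plan is a fixed Vec that per-inference runs can just walk in order.

struct Node {
    name: &'static str,
    inputs: Vec<&'static str>,  // tensor names consumed
    outputs: Vec<&'static str>, // tensor names produced
}

/// Compile a fixed execution order from an unordered node list.
fn compile_plan(nodes: Vec<Node>, graph_inputs: &[&'static str]) -> Vec<&'static str> {
    let mut available: Vec<&'static str> = graph_inputs.to_vec();
    let mut pending = nodes;
    let mut plan = Vec::new();
    while !pending.is_empty() {
        // pick any node whose inputs are all available
        let i = pending
            .iter()
            .position(|n| n.inputs.iter().all(|t| available.contains(t)))
            .expect("cycle or missing tensor in graph");
        let node = pending.remove(i);
        available.extend_from_slice(&node.outputs);
        plan.push(node.name);
    }
    plan
}

fn main() {
    // deliberately listed out of order; the plan puts conv before relu
    let nodes = vec![
        Node { name: "relu", inputs: vec!["conv_out"], outputs: vec!["relu_out"] },
        Node { name: "conv", inputs: vec!["image"], outputs: vec!["conv_out"] },
    ];
    let plan = compile_plan(nodes, &["image"]);
    assert_eq!(plan, vec!["conv", "relu"]);
    println!("{:?}", plan);
}
```

once the plan is fixed, buffer lifetimes are also known up front, which is what makes preallocating/reusing tensor memory (and batching work for simd/gpu) possible.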
let's ignore how slow llava is hahaha, and i still have to fix "[EXECUTOR] missing tensor 'past_key_values.0.key' for node 'graph_input_cast2'" but i'm kinda tired of the project