This guide covers how to extend numr with custom operations and kernels.
numr separates operation traits (the interface) from kernel implementations (the computation):
```
┌─────────────────────────────────────────────────────┐
│                   Operation Trait                   │
│        (defines interface, validates inputs)        │
└──────────────────────────┬──────────────────────────┘
                           │
          ┌────────────────┼────────────────┐
          │                │                │
     ┌────▼────┐      ┌────▼────┐      ┌────▼────┐
     │   CPU   │      │  CUDA   │      │ WebGPU  │
     │  Impl   │      │  Impl   │      │  Impl   │
     └────┬────┘      └────┬────┘      └────┬────┘
          │                │                │
     ┌────▼────┐      ┌────▼────┐      ┌────▼────┐
     │  SIMD   │      │   PTX   │      │  WGSL   │
     │ Kernel  │      │ Kernel  │      │ Shader  │
     └─────────┘      └─────────┘      └─────────┘
```
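In code, the split looks roughly like this. This is a minimal sketch with illustrative names, not numr's actual trait definitions: the trait layer validates, the kernel layer computes.

```rust
// Interface layer: a trait per operation family.
trait AddOp {
    fn add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, String>;
}

struct CpuBackend;

impl AddOp for CpuBackend {
    fn add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, String> {
        // The trait impl validates inputs before touching the kernel.
        if a.len() != b.len() {
            return Err("shape mismatch".into());
        }
        let mut out = vec![0.0; a.len()];
        add_kernel_scalar(a, b, &mut out); // kernel layer does the work
        Ok(out)
    }
}

// Kernel layer: pure computation, no validation.
fn add_kernel_scalar(a: &[f32], b: &[f32], out: &mut [f32]) {
    for i in 0..a.len() {
        out[i] = a[i] + b[i];
    }
}
```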
numr provides default kernels for all operations. You can:

- Use default kernels: works out of the box for all supported operations
- Replace specific kernels: swap in optimized kernels for hot paths
- Add new operations: define new traits and implement kernels for all backends
All operations work immediately with numr's built-in kernels:

```rust
use numr::prelude::*;

let a = Tensor::<CpuRuntime>::randn(&[1024, 1024], &device)?;
let b = Tensor::<CpuRuntime>::randn(&[1024, 1024], &device)?;

// Uses numr's tiled GEMM kernel
let c = client.matmul(&a, &b)?;

// Uses numr's SIMD reduction kernel
let sum = client.sum(&c)?;
```

To replace numr's default kernel with your own optimized version:
```rust
use numr::ops::KernelDispatch;
use numr::runtime::cpu::CpuClient;

// Define your custom kernel
fn my_optimized_add_kernel(a: &[f32], b: &[f32], out: &mut [f32]) {
    // Your optimized implementation,
    // e.g., using specific SIMD intrinsics, cache-aware tiling, etc.
    for i in 0..a.len() {
        out[i] = a[i] + b[i];
    }
}

// A type to hang the custom dispatch on
struct MyCustomOps;

// Register the custom kernel
impl KernelDispatch for MyCustomOps {
    fn dispatch_add_f32(&self, a: &[f32], b: &[f32], out: &mut [f32]) {
        my_optimized_add_kernel(a, b, out);
    }
}
```

numr's CPU kernels use portable SIMD. Here's the pattern:
```rust
// Portable SIMD is nightly-only: requires #![feature(portable_simd)].
use std::simd::f32x8;

pub fn add_simd_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    let n = a.len();
    let simd_width = 8;
    let simd_end = n - (n % simd_width);

    // SIMD loop: 8 lanes per iteration
    for i in (0..simd_end).step_by(simd_width) {
        let va = f32x8::from_slice(&a[i..]);
        let vb = f32x8::from_slice(&b[i..]);
        let vout = va + vb;
        vout.copy_to_slice(&mut out[i..]);
    }

    // Scalar remainder
    for i in simd_end..n {
        out[i] = a[i] + b[i];
    }
}
```

For architecture-specific intrinsics:
```rust
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// # Safety
/// The caller must ensure the CPU supports AVX-512F,
/// e.g., via is_x86_feature_detected!("avx512f").
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn add_avx512_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    let n = a.len();
    let mut i = 0;

    // AVX-512 loop: 16 f32 lanes per iteration
    while i + 16 <= n {
        let va = _mm512_loadu_ps(a.as_ptr().add(i));
        let vb = _mm512_loadu_ps(b.as_ptr().add(i));
        let vout = _mm512_add_ps(va, vb);
        _mm512_storeu_ps(out.as_mut_ptr().add(i), vout);
        i += 16;
    }

    // Handle remainder with scalar or smaller SIMD
    while i < n {
        out[i] = a[i] + b[i];
        i += 1;
    }
}
```
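Because the function above is only safe on CPUs with AVX-512F, guard it with runtime feature detection. A minimal dispatch sketch around `add_avx512_f32`:

```rust
pub fn add_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            // SAFETY: the feature check above guarantees AVX-512F support.
            unsafe { add_avx512_f32(a, b, out) };
            return;
        }
    }
    // Portable scalar fallback
    for i in 0..a.len() {
        out[i] = a[i] + b[i];
    }
}
```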
numr compiles CUDA kernels to PTX at build time. Here's the pattern:

```cuda
// src/runtime/cuda/kernels/my_op.cu

// Shared template body. Note: this must be a __device__ function, not a
// __global__ kernel, because the extern "C" wrappers below call it directly.
template<typename T>
__device__ void my_op_kernel(
    const T* __restrict__ a,
    const T* __restrict__ b,
    T* __restrict__ out,
    unsigned int n
) {
    unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        out[idx] = a[idx] + b[idx];  // your operation here
    }
}

// Explicit instantiations for each dtype
extern "C" __global__ void my_op_f32(const float* a, const float* b, float* out, unsigned int n) {
    my_op_kernel<float>(a, b, out, n);
}
extern "C" __global__ void my_op_f64(const double* a, const double* b, double* out, unsigned int n) {
    my_op_kernel<double>(a, b, out, n);
}
extern "C" __global__ void my_op_i32(const int* a, const int* b, int* out, unsigned int n) {
    my_op_kernel<int>(a, b, out, n);
}
```
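One way to wire up the build-time PTX compilation is a `build.rs` that shells out to nvcc. This is a sketch, not numr's actual build script; it assumes nvcc is on PATH and the output path is illustrative:

```rust
// build.rs (sketch): compile the .cu file to PTX with nvcc.
use std::process::Command;

fn main() {
    println!("cargo:rerun-if-changed=src/runtime/cuda/kernels/my_op.cu");
    let out_dir = std::env::var("OUT_DIR").unwrap();
    let status = Command::new("nvcc")
        .args(["--ptx", "src/runtime/cuda/kernels/my_op.cu", "-o"])
        .arg(format!("{out_dir}/my_op.ptx"))
        .status()
        .expect("failed to run nvcc");
    assert!(status.success(), "nvcc failed");
}
```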
Then launch the compiled kernel from the Rust side:

```rust
use numr::runtime::cuda::{CudaClient, CudaRuntime};
use numr::tensor::Tensor;
// Import paths below assume numr's module layout shown later in this guide.
use numr::dtype::DType;
use numr::error::{Error, Result};
use numr::ops::traits::MyOps;

impl MyOps<CudaRuntime> for CudaClient {
    fn my_op(
        &self,
        a: &Tensor<CudaRuntime>,
        b: &Tensor<CudaRuntime>,
    ) -> Result<Tensor<CudaRuntime>> {
        let n = a.numel();
        let output = self.allocate_like(a)?;

        // Select kernel based on dtype
        let kernel_name = match a.dtype() {
            DType::F32 => "my_op_f32",
            DType::F64 => "my_op_f64",
            DType::I32 => "my_op_i32",
            _ => return Err(Error::UnsupportedDType { dtype: a.dtype(), op: "my_op" }),
        };

        // Launch kernel: one thread per element
        let block_size = 256;
        let grid_size = (n + block_size - 1) / block_size;
        self.launch_kernel(
            kernel_name,
            grid_size as u32,
            block_size as u32,
            &[
                a.as_kernel_arg(),
                b.as_kernel_arg(),
                output.as_kernel_arg(),
                &(n as u32),
            ],
        )?;

        Ok(output)
    }
}
```

For operations that benefit from shared memory, use a tiled kernel:
```cuda
template<typename T, int TILE_SIZE>
__global__ void tiled_matmul_kernel(
    const T* __restrict__ A,
    const T* __restrict__ B,
    T* __restrict__ C,
    int M, int N, int K
) {
    __shared__ T As[TILE_SIZE][TILE_SIZE];
    __shared__ T Bs[TILE_SIZE][TILE_SIZE];

    int row = blockIdx.y * TILE_SIZE + threadIdx.y;
    int col = blockIdx.x * TILE_SIZE + threadIdx.x;
    T sum = 0;

    for (int tile = 0; tile < (K + TILE_SIZE - 1) / TILE_SIZE; tile++) {
        // Load tile into shared memory (zero-pad out-of-bounds reads)
        if (row < M && tile * TILE_SIZE + threadIdx.x < K)
            As[threadIdx.y][threadIdx.x] = A[row * K + tile * TILE_SIZE + threadIdx.x];
        else
            As[threadIdx.y][threadIdx.x] = 0;
        if (col < N && tile * TILE_SIZE + threadIdx.y < K)
            Bs[threadIdx.y][threadIdx.x] = B[(tile * TILE_SIZE + threadIdx.y) * N + col];
        else
            Bs[threadIdx.y][threadIdx.x] = 0;
        __syncthreads();

        // Compute partial dot product
        for (int k = 0; k < TILE_SIZE; k++) {
            sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        }
        __syncthreads();
    }

    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}
```
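The tile size fixes the launch geometry: each block is TILE_SIZE × TILE_SIZE threads, and the grid covers the output matrix. A host-side launch sketch in plain CUDA C++ (not numr's launcher API; `d_A`, `d_B`, `d_C` are device pointers you have allocated):

```cuda
void launch_tiled_matmul(const float* d_A, const float* d_B, float* d_C,
                         int M, int N, int K) {
    constexpr int TILE = 16;  // common choice; tune per GPU
    dim3 block(TILE, TILE);   // one thread per output element in the tile
    dim3 grid((N + TILE - 1) / TILE,   // grid.x covers columns (N)
              (M + TILE - 1) / TILE);  // grid.y covers rows (M)
    tiled_matmul_kernel<float, TILE><<<grid, block>>>(d_A, d_B, d_C, M, N, K);
}
```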
WebGPU kernels are WGSL compute shaders:

```wgsl
// src/runtime/wgpu/shaders/my_op.wgsl
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;

struct Params {
    size: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn my_op(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.size) {
        return;
    }
    out[idx] = a[idx] + b[idx];  // your operation here
}
```

WebGPU requires separate shaders per dtype. Use shader generation:
```rust
fn generate_my_op_shader(dtype: DType) -> String {
    let wgsl_type = match dtype {
        DType::F32 => "f32",
        DType::F16 => "f16",  // requires the f16 shader extension
        DType::I32 => "i32",
        DType::U32 => "u32",
        _ => panic!("Unsupported dtype for WebGPU: {:?}", dtype),
    };
    format!(
        r#"
@group(0) @binding(0) var<storage, read> a: array<{wgsl_type}>;
@group(0) @binding(1) var<storage, read> b: array<{wgsl_type}>;
@group(0) @binding(2) var<storage, read_write> out: array<{wgsl_type}>;

@compute @workgroup_size(256)
fn my_op(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let idx = gid.x;
    if (idx >= arrayLength(&out)) {{
        return;
    }}
    out[idx] = a[idx] + b[idx];
}}
"#
    )
}
```

Then dispatch it from the Rust side:
```rust
use numr::runtime::wgpu::{WgpuClient, WgpuRuntime};
// Import paths below assume numr's module layout shown later in this guide.
use numr::error::Result;
use numr::ops::traits::MyOps;
use numr::tensor::Tensor;

impl MyOps<WgpuRuntime> for WgpuClient {
    fn my_op(
        &self,
        a: &Tensor<WgpuRuntime>,
        b: &Tensor<WgpuRuntime>,
    ) -> Result<Tensor<WgpuRuntime>> {
        let output = self.allocate_like(a)?;

        // Get or compile shader for this dtype
        let shader = self.get_or_compile_shader("my_op", a.dtype())?;

        // Calculate workgroup dispatch: one invocation per element
        let workgroup_size = 256;
        let workgroup_count = (a.numel() + workgroup_size - 1) / workgroup_size;

        // Dispatch compute shader
        self.dispatch_compute(
            &shader,
            &[a.buffer(), b.buffer(), output.buffer()],
            workgroup_count as u32,
        )?;

        Ok(output)
    }
}
```

To add a completely new operation to numr:
Create src/ops/traits/my_category.rs:

```rust
use crate::error::Result;
use crate::runtime::Runtime;
use crate::tensor::Tensor;

/// Operations for my category
pub trait MyOps<R: Runtime> {
    /// Computes my_op(a, b) element-wise
    fn my_op(&self, a: &Tensor<R>, b: &Tensor<R>) -> Result<Tensor<R>>;

    /// Computes my_other_op with a scalar parameter
    fn my_other_op(&self, a: &Tensor<R>, scalar: f64) -> Result<Tensor<R>>;
}
```
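Because the trait is generic over the runtime, downstream code can stay backend-agnostic. A small sketch (the function name is illustrative; import paths assume the layout above):

```rust
use numr::error::Result;
use numr::ops::traits::MyOps;
use numr::runtime::Runtime;
use numr::tensor::Tensor;

// Works with any client on any backend that implements MyOps.
fn fused_step<R: Runtime, C: MyOps<R>>(
    client: &C,
    a: &Tensor<R>,
    b: &Tensor<R>,
) -> Result<Tensor<R>> {
    let x = client.my_op(a, b)?;
    client.my_other_op(&x, 0.5)
}
```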
Add to src/ops/traits/mod.rs:

```rust
pub mod my_category;
pub use my_category::MyOps;
```

Create src/ops/cpu/my_category.rs:
```rust
use crate::dtype::dispatch_dtype;
use crate::error::Result;
use crate::ops::traits::MyOps;
use crate::runtime::cpu::{CpuClient, CpuRuntime};
use crate::tensor::Tensor;

impl MyOps<CpuRuntime> for CpuClient {
    fn my_op(&self, a: &Tensor<CpuRuntime>, b: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
        // Validate shapes and compute the broadcast output shape
        let broadcast_shape = a.shape().broadcast_with(b.shape())?;
        let output = self.allocate_output_shape(&broadcast_shape)?;

        // Dispatch based on dtype
        dispatch_dtype!(a.dtype(), T => {
            let a_data = a.as_slice::<T>()?;
            let b_data = b.as_slice::<T>()?;
            let out_data = output.as_mut_slice::<T>()?;

            // Call kernel
            kernels::my_op_kernel::<T>(a_data, b_data, out_data);
        });

        Ok(output)
    }

    fn my_other_op(&self, a: &Tensor<CpuRuntime>, scalar: f64) -> Result<Tensor<CpuRuntime>> {
        // Implementation
        todo!()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_my_op() {
        // Test implementation
    }
}
```

Create src/ops/cuda/my_category.rs with similar structure, launching PTX kernels.
Create src/ops/wgpu/my_category.rs with similar structure, dispatching WGSL shaders.
Add to each backend's mod.rs:

```rust
pub mod my_category;
```

numr uses the dispatch_dtype! macro for runtime dtype dispatch:
```rust
use crate::dtype::dispatch_dtype;

fn process_tensor<R: Runtime>(tensor: &Tensor<R>) -> Result<()> {
    dispatch_dtype!(tensor.dtype(), T => {
        // T is now the concrete type (f32, f64, i32, etc.)
        let data = tensor.as_slice::<T>()?;
        process_data::<T>(data);
    });
    Ok(())
}
```

For operations that only support specific dtypes:
```rust
dispatch_dtype!(tensor.dtype(), T => {
    match tensor.dtype() {
        DType::F32 | DType::F64 => {
            // Float-only operation
        }
        _ => return Err(Error::UnsupportedDType {
            dtype: tensor.dtype(),
            op: "my_float_op",
        }),
    }
});
```

Test kernel accuracy against a reference implementation:

```rust
#[test]
fn test_kernel_accuracy() -> Result<()> {
    let a = Tensor::<CpuRuntime>::randn(&[1024], &device)?;
    let b = Tensor::<CpuRuntime>::randn(&[1024], &device)?;

    let result = client.my_op(&a, &b)?;
    let expected = compute_expected(&a, &b);

    assert_allclose(&result, &expected, 1e-6, 1e-7);
    Ok(())
}
```

Verify that all backends produce matching results:
```rust
#[test]
fn test_backend_parity() -> Result<()> {
    let a_cpu = Tensor::<CpuRuntime>::randn(&[1024], &cpu_device)?;
    let b_cpu = Tensor::<CpuRuntime>::randn(&[1024], &cpu_device)?;
    let result_cpu = cpu_client.my_op(&a_cpu, &b_cpu)?;

    #[cfg(feature = "cuda")]
    {
        let a_cuda = a_cpu.to_device(&cuda_device)?;
        let b_cuda = b_cpu.to_device(&cuda_device)?;
        let result_cuda = cuda_client.my_op(&a_cuda, &b_cuda)?;
        assert_allclose(
            &result_cpu.to_vec::<f32>()?,
            &result_cuda.to_cpu()?.to_vec::<f32>()?,
            1e-5,
            1e-6,
        );
    }
    Ok(())
}
```

Always test the following edge cases (a sketch follows the list):
- Empty tensors (`[]`)
- Scalar tensors (`[1]`)
- Non-contiguous tensors (from transpose/slice)
- Broadcasting edge cases
- Special values (NaN, Inf, subnormals)
- All supported dtypes
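A sketch of the empty and non-contiguous cases. The `transpose(0, 1)` signature is an assumption for illustration; `assert_allclose` and the `device`/`client` bindings follow the tests above:

```rust
#[test]
fn test_my_op_edge_cases() -> Result<()> {
    // Empty tensor: must succeed and return an empty result
    let empty = Tensor::<CpuRuntime>::randn(&[0], &device)?;
    let out = client.my_op(&empty, &empty)?;
    assert_eq!(out.numel(), 0);

    // Non-contiguous input: transpose produces a strided view
    let a = Tensor::<CpuRuntime>::randn(&[4, 8], &device)?;
    let at = a.transpose(0, 1)?; // shape [8, 4], non-contiguous
    let b = Tensor::<CpuRuntime>::randn(&[8, 4], &device)?;

    // Strided and densified inputs must agree
    let strided = client.my_op(&at, &b)?;
    let contiguous = client.my_op(&at.contiguous()?, &b)?;
    assert_allclose(&strided, &contiguous, 1e-6, 1e-7);
    Ok(())
}
```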
Performance tips for kernels:

Memory access:

- Coalesced access: adjacent threads should access adjacent memory
- Avoid strided access: it wastes cache and memory bandwidth
- Contiguous tensors: call `.contiguous()` before kernels if needed (see the sketch after this list)

CUDA:

- Occupancy: balance registers vs. threads per block
- Shared memory: use it for tile-based algorithms
- Warp divergence: avoid conditionals that split threads within a warp

WebGPU:

- Workgroup size: 256 is a safe default; tune per GPU
- Buffer alignment: use 256-byte alignment for storage buffers
- Uniform buffers: use them for small, read-only parameters
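For the contiguity tip, a small guard before a kernel launch. This is a sketch: the doc only shows `.contiguous()`, so `is_contiguous()` is an assumed helper and exact numr method names may differ:

```rust
use numr::error::Result;
use numr::runtime::cpu::CpuRuntime;
use numr::tensor::Tensor;

// Densify an input before handing its buffer to a kernel.
fn ensure_dense(t: Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
    if t.is_contiguous() { // assumed helper, see lead-in
        Ok(t)
    } else {
        t.contiguous()
    }
}
```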
To debug CPU kernels, print intermediate values:

```rust
eprintln!("a[0..4] = {:?}", &a.as_slice::<f32>()?[0..4]);
```

For CUDA, use cuda-gdb or print from the device:
```cuda
if (idx == 0) {
    printf("First element: %f\n", a[0]);
}
```

For WebGPU, use browser WebGPU debugging tools or dump buffers:
```rust
let data = client.read_buffer(&output)?;
eprintln!("Output: {:?}", &data[0..10]);
```