API Reference#

spd_learn:

SPD Learn: Deep Learning on Riemannian Manifolds.

A pure PyTorch library for Symmetric Positive Definite (SPD) matrix learning.

SPD Learn follows a functional-first design philosophy: low-level operations are implemented as pure functions in spd_learn.functional, which are then wrapped into stateful layers in spd_learn.modules, and finally composed into complete architectures in spd_learn.models.
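The layering above can be sketched in a few lines of plain PyTorch. The names `sym` and `Sym` below are hypothetical stand-ins, not part of the library; they only illustrate the pure-function → stateful-layer → composed-model pattern:

```python
import torch
from torch import nn

def sym(x: torch.Tensor) -> torch.Tensor:
    # Pure function (spd_learn.functional style): symmetrize a batch of matrices.
    return 0.5 * (x + x.transpose(-2, -1))

class Sym(nn.Module):
    # Stateful layer (spd_learn.modules style) wrapping the pure function.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return sym(x)

# Layers compose into architectures (spd_learn.models style).
model = nn.Sequential(Sym())
y = model(torch.randn(2, 4, 4))
```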

Functional#

spd_learn.functional:

Functional API for SPD matrix operations.

This module provides differentiable operations for Symmetric Positive Definite (SPD) matrices, organized into:

  • Core operations: Matrix logarithm, exponential, power, square root

  • Metrics: Riemannian metrics (AIRM, Log-Euclidean, Bures-Wasserstein, Log-Cholesky)

  • Transport: Parallel transport operations

  • Covariance: Covariance estimation functions

  • Regularization: Covariance regularization utilities

  • Vectorization: Batch vectorization and (un)vectorization helpers

  • Numerical: Numerical stability configuration

This module provides low-level functions for operations on tensors, particularly those representing SPD matrices or elements in related manifolds. These functions form the core computational backend for the layers in spd_learn.modules and models in spd_learn.models.

Core Matrix Operations#

Basic matrix operations commonly used in Riemannian geometry, such as matrix logarithm, exponential, power functions, and utilities for ensuring symmetry or clamping eigenvalues.

matrix_log(*args, **kwargs)

Matrix logarithm of a symmetric matrix.

matrix_exp(*args, **kwargs)

Matrix exponential of a symmetric matrix.

matrix_softplus(*args, **kwargs)

Matrix (scaled) SoftPlus of a symmetric matrix.

matrix_inv_softplus(*args, **kwargs)

Matrix inverse (scaled) SoftPlus of a symmetric matrix.

softplus(s)

Scaled SoftPlus function.

inv_softplus(s)

Inverse of the scaled SoftPlus function.

matrix_power(*args, **kwargs)

Computes the matrix power.

matrix_sqrt(*args, **kwargs)

Matrix square root.

matrix_inv_sqrt(*args, **kwargs)

Inverse matrix square root.

matrix_sqrt_inv(*args, **kwargs)

Matrix square root and inverse matrix square root.

clamp_eigvals(*args, **kwargs)

Rectification of the eigenvalues of a symmetric matrix.

abs_eigvals(*args, **kwargs)

Absolute value of the eigenvalues of a symmetric matrix.

ensure_sym(matrix)

Ensures that a matrix is symmetric.

orthogonal_polar_factor(W)

Compute the orthogonal polar factor of a matrix.
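All of these eigenvalue-based operations follow the same pattern: decompose, apply a scalar function to the eigenvalues, and reconstruct. A minimal sketch (the `_sketch` names are illustrative, not the library API) for the matrix logarithm and exponential:

```python
import torch

def matrix_log_sketch(A: torch.Tensor) -> torch.Tensor:
    # log(A) = U diag(log s) U^T, with (s, U) the eigendecomposition of A
    s, U = torch.linalg.eigh(A)
    return U @ torch.diag_embed(torch.log(s)) @ U.transpose(-2, -1)

def matrix_exp_sketch(S: torch.Tensor) -> torch.Tensor:
    # exp(S) = U diag(exp s) U^T; the result is always SPD
    s, U = torch.linalg.eigh(S)
    return U @ torch.diag_embed(torch.exp(s)) @ U.transpose(-2, -1)

torch.manual_seed(0)
X = torch.randn(4, 4, dtype=torch.float64)
A = X @ X.T + torch.eye(4, dtype=torch.float64)  # random SPD matrix
A_roundtrip = matrix_exp_sketch(matrix_log_sketch(A))
```

The round trip recovers the input, which is the basis for mapping between the SPD manifold and its tangent space.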

Covariance Estimation#

Functions for computing various types of covariance matrices from input data.

covariance(input)

Computes the covariance matrix of multivariate data.

sample_covariance(input)

Computes the sample covariance matrix of multivariate data.

real_covariance(X)

Computes the real-valued covariance matrix of time series data.

cross_covariance(X)

Computes the real-valued cross-frequency covariance matrix.
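As a reference point for what these estimators produce, here is a minimal sample-covariance sketch for batched multichannel time series (the `_sketch` name is illustrative, not the library function):

```python
import torch

def sample_covariance_sketch(x: torch.Tensor) -> torch.Tensor:
    # x: (batch, n_channels, n_times) -> (batch, n_channels, n_channels)
    xc = x - x.mean(dim=-1, keepdim=True)                 # center over time
    return xc @ xc.transpose(-2, -1) / (x.shape[-1] - 1)  # unbiased estimate

torch.manual_seed(0)
x = torch.randn(8, 6, 256, dtype=torch.float64)
C = sample_covariance_sketch(x)
```

With many more time samples than channels, the output is symmetric positive definite and can feed directly into the SPD layers below.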

Regularization#

Functional regularization utilities for covariance matrices.

trace_normalization(covariances[, epsilon])

Performs trace normalization on a batch of covariance matrices.

ledoit_wolf(covariances, shrinkage, ...)

Applies Ledoit-Wolf shrinkage to a batch of covariance matrices.

shrinkage_covariance(X, alpha, n_chans[, ...])

Apply shrinkage regularization to covariance matrices.
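Shrinkage blends the estimate with a scaled identity, lifting near-zero eigenvalues without changing the trace. A minimal sketch of this idea (illustrative only; the library functions above also estimate the shrinkage intensity):

```python
import torch

def shrinkage_sketch(C: torch.Tensor, alpha: float) -> torch.Tensor:
    # (1 - alpha) C + alpha (tr C / n) I: shrink toward the scaled identity.
    # Small eigenvalues are lifted while the trace is left unchanged.
    n = C.shape[-1]
    mu = torch.diagonal(C, dim1=-2, dim2=-1).sum(-1) / n
    eye = torch.eye(n, dtype=C.dtype)
    return (1 - alpha) * C + alpha * mu[..., None, None] * eye

torch.manual_seed(0)
X = torch.randn(4, 6, 5, dtype=torch.float64)   # fewer samples than channels
C = X @ X.transpose(-2, -1) / 5                  # rank-deficient covariance
C_shrunk = shrinkage_sketch(C, alpha=0.1)
```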

Riemannian Metrics#

SPD Learn implements four Riemannian metrics for SPD manifolds. Each metric provides distance computation, geodesic interpolation, and exponential/logarithmic maps. Choose based on your application’s needs:

| Metric            | Properties                        | Best For                                   |
|-------------------|-----------------------------------|--------------------------------------------|
| AIRM              | Affine-invariant, curvature-aware | Theoretical correctness, domain adaptation |
| Log-Euclidean     | Bi-invariant, closed-form mean    | Fast computation, deep learning            |
| Log-Cholesky      | Avoids eigendecomposition         | Numerical stability, large matrices        |
| Bures-Wasserstein | Optimal transport connection      | Covariance interpolation, statistics       |

AIRM (Affine-Invariant Riemannian Metric)#

The AIRM is the canonical Riemannian metric on SPD manifolds with affine invariance properties. It provides geodesic distances and interpolations that are invariant under congruence transformations.

airm_distance(A, B)

Compute the geodesic distance under the Affine-Invariant Riemannian Metric (AIRM).

airm_geodesic(A, B, t)

Geodesic interpolation on the SPD manifold under the Affine-Invariant Riemannian Metric (AIRM).

exp_map_airm(P, V[, t])

Riemannian exponential map under the Affine-Invariant metric.

log_map_airm(P, Q)

Riemannian logarithmic map under the Affine-Invariant metric.
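The AIRM distance has the closed form d(A, B) = ||log(A^{-1/2} B A^{-1/2})||_F, and its affine invariance can be checked numerically. A minimal sketch of the computation (illustrative only, not the library implementation):

```python
import torch

def airm_distance_sketch(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # d(A, B) = || log(A^{-1/2} B A^{-1/2}) ||_F
    s, U = torch.linalg.eigh(A)
    A_invsqrt = U @ torch.diag_embed(s.rsqrt()) @ U.transpose(-2, -1)
    M = A_invsqrt @ B @ A_invsqrt
    # Frobenius norm of log(M) equals the l2 norm of the log-eigenvalues.
    return torch.linalg.norm(torch.log(torch.linalg.eigvalsh(M)), dim=-1)

torch.manual_seed(0)
X, Y = torch.randn(2, 4, 4, dtype=torch.float64)
A = X @ X.T + torch.eye(4, dtype=torch.float64)
B = Y @ Y.T + torch.eye(4, dtype=torch.float64)
W = torch.randn(4, 4, dtype=torch.float64) + 4 * torch.eye(4, dtype=torch.float64)
d = airm_distance_sketch(A, B)
d_congruence = airm_distance_sketch(W @ A @ W.T, W @ B @ W.T)
```

The two distances agree: the metric is invariant under congruence by any invertible W.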

Log-Euclidean Metric#

The Log-Euclidean Metric (LEM) maps SPD matrices to a flat (Euclidean) space via matrix logarithm, enabling efficient closed-form computations.

log_euclidean_distance(A, B)

Computes the Log-Euclidean distance between SPD matrices.

log_euclidean_geodesic(A, B, t)

Geodesic interpolation under the Log-Euclidean metric.

log_euclidean_mean(weights, V)

Computes the weighted Log-Euclidean mean of a batch of SPD matrices.

log_euclidean_multiply(x, y)

Logarithmic multiplication of SPD matrices under the Log-Euclidean metric.

log_euclidean_scalar_multiply(alpha, x)

Logarithmic scalar multiplication of an SPD matrix.

exp_map_lem(P, V)

Riemannian exponential map under the Log-Euclidean metric.

log_map_lem(P, Q)

Riemannian logarithmic map under the Log-Euclidean metric.
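Because the Log-Euclidean metric is flat in log-space, distances and geodesics reduce to Euclidean operations on matrix logarithms. A minimal sketch (illustrative only):

```python
import torch

def _logm(A):
    s, U = torch.linalg.eigh(A)
    return U @ torch.diag_embed(torch.log(s)) @ U.transpose(-2, -1)

def _expm(S):
    s, U = torch.linalg.eigh(S)
    return U @ torch.diag_embed(torch.exp(s)) @ U.transpose(-2, -1)

def lem_distance_sketch(A, B):
    # d(A, B) = || log A - log B ||_F
    return torch.linalg.matrix_norm(_logm(A) - _logm(B))

def lem_geodesic_sketch(A, B, t):
    # gamma(t) = exp((1 - t) log A + t log B): a straight line in log-space
    return _expm((1 - t) * _logm(A) + t * _logm(B))

torch.manual_seed(0)
X, Y = torch.randn(2, 4, 4, dtype=torch.float64)
A = X @ X.T + torch.eye(4, dtype=torch.float64)
B = Y @ Y.T + torch.eye(4, dtype=torch.float64)
mid = lem_geodesic_sketch(A, B, 0.5)
```

The geodesic midpoint sits exactly halfway from each endpoint, which makes closed-form weighted means possible.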

Log-Cholesky Metric#

The Log-Cholesky metric uses the Cholesky decomposition to parameterize SPD matrices, avoiding expensive eigendecompositions while maintaining numerical stability.

cholesky_log(*args, **kwargs)

Matrix logarithm via Cholesky decomposition (Log-Cholesky map).

cholesky_exp(*args, **kwargs)

Inverse of the Log-Cholesky map (Cholesky exponential).

log_cholesky_distance(A, B)

Compute the distance in the Log-Cholesky metric.

log_cholesky_mean(matrices[, weights])

Compute the weighted mean in the Log-Cholesky space.

log_cholesky_geodesic(A, B, t)

Geodesic interpolation in the Log-Cholesky metric.
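The Log-Cholesky distance works directly on Cholesky factors: strictly lower-triangular parts are compared Euclidean-wise, diagonals in log-space. A minimal sketch of that formula (illustrative only):

```python
import torch

def log_cholesky_distance_sketch(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # With L = chol(A), K = chol(B) lower-triangular:
    # d^2 = ||strict_lower(L) - strict_lower(K)||_F^2
    #     + ||log diag(L) - log diag(K)||_2^2
    L = torch.linalg.cholesky(A)
    K = torch.linalg.cholesky(B)
    strict = torch.tril(L, -1) - torch.tril(K, -1)
    diag = torch.log(torch.diagonal(L, dim1=-2, dim2=-1)) \
         - torch.log(torch.diagonal(K, dim1=-2, dim2=-1))
    return torch.sqrt(strict.pow(2).sum((-2, -1)) + diag.pow(2).sum(-1))

torch.manual_seed(0)
X, Y = torch.randn(2, 4, 4, dtype=torch.float64)
A = X @ X.T + torch.eye(4, dtype=torch.float64)
B = Y @ Y.T + torch.eye(4, dtype=torch.float64)
d_ab = log_cholesky_distance_sketch(A, B)
```

No eigendecomposition is needed, which is the source of the metric's speed and stability advantages.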

Bures-Wasserstein Metric#

The Bures-Wasserstein (or Procrustes) metric has connections to optimal transport theory and is particularly useful for covariance interpolation.

bures_wasserstein_distance(A, B)

Compute the Bures-Wasserstein distance between SPD matrices.

bures_wasserstein_geodesic(A, B, t)

Compute the geodesic interpolation under the Bures-Wasserstein metric.

bures_wasserstein_mean(matrices[, weights, ...])

Compute the Bures-Wasserstein barycenter of SPD matrices.

bures_wasserstein_transport(A, B, X)

Compute the optimal transport map from A to B applied to X.
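The Bures-Wasserstein distance satisfies d^2(A, B) = tr(A) + tr(B) - 2 tr((A^{1/2} B A^{1/2})^{1/2}). A minimal sketch (illustrative only), sanity-checked on the commuting case where the distance reduces to the Euclidean distance between matrix square roots:

```python
import torch

def _sqrtm(A):
    s, U = torch.linalg.eigh(A)
    return U @ torch.diag_embed(s.clamp_min(0).sqrt()) @ U.transpose(-2, -1)

def bures_wasserstein_distance_sketch(A, B):
    # d^2 = tr(A) + tr(B) - 2 tr((A^{1/2} B A^{1/2})^{1/2})
    As = _sqrtm(A)
    cross = _sqrtm(As @ B @ As)
    d2 = torch.einsum('...ii->...', A + B - 2 * cross)   # batched trace
    return d2.clamp_min(0).sqrt()

# Diagonal (commuting) matrices: d = || sqrt(a) - sqrt(b) ||_2.
A = torch.diag(torch.tensor([1.0, 4.0], dtype=torch.float64))
B = torch.diag(torch.tensor([9.0, 16.0], dtype=torch.float64))
d = bures_wasserstein_distance_sketch(A, B)
```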

Parallel Transport#

Functions for parallel transport of tangent vectors along geodesics on the SPD manifold, essential for operations like domain adaptation.

parallel_transport_airm(v, p, q)

Parallel transport of tangent vector under the Affine-Invariant metric.

parallel_transport_lem(v, p, q)

Parallel transport of tangent vector under the Log-Euclidean metric.

parallel_transport_log_cholesky(v, p, q)

Parallel transport of tangent vector under the Log-Cholesky metric.

schild_ladder(v, p, q[, n_steps])

Parallel transport via Schild's ladder approximation.

pole_ladder(v, p, q)

Parallel transport via pole ladder approximation.

transport_tangent_vector(v, p, q[, metric])

Parallel transport of tangent vector with metric selection.
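Under AIRM, parallel transport from p to q has the closed form Γ(v) = E v E^T with E = (q p^{-1})^{1/2}. A minimal sketch (illustrative only) computing E via the symmetric factorization E = p^{1/2} (p^{-1/2} q p^{-1/2})^{1/2} p^{-1/2}:

```python
import torch

def parallel_transport_airm_sketch(v, p, q):
    # Gamma_{p->q}(v) = E v E^T, E = (q p^{-1})^{1/2}.
    # q p^{-1} is not symmetric, so compute its square root through the
    # similar symmetric matrix p^{-1/2} q p^{-1/2}.
    s, U = torch.linalg.eigh(p)
    p_half = U @ torch.diag_embed(s.sqrt()) @ U.transpose(-2, -1)
    p_ihalf = U @ torch.diag_embed(s.rsqrt()) @ U.transpose(-2, -1)
    m, V = torch.linalg.eigh(p_ihalf @ q @ p_ihalf)
    inner = V @ torch.diag_embed(m.sqrt()) @ V.transpose(-2, -1)
    E = p_half @ inner @ p_ihalf
    return E @ v @ E.transpose(-2, -1)

torch.manual_seed(0)
X, Y, Z = torch.randn(3, 4, 4, dtype=torch.float64)
p = X @ X.T + torch.eye(4, dtype=torch.float64)
q = Y @ Y.T + torch.eye(4, dtype=torch.float64)
v = 0.5 * (Z + Z.T)                        # symmetric tangent vector at p
w = parallel_transport_airm_sketch(v, p, q)
```

Transport to the same point is the identity, and symmetry of the tangent vector is preserved.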

Vectorization Utilities#

Vectorization helpers for batching, (un)vectorizing matrices, and symmetric matrix encodings.

vec_batch(X)

Vectorizes a batch of tensors along the last two dimensions.

unvec_batch(X_vec, n)

Unvectorizes a batch of tensors along the last dimension.

sym_to_upper(X[, preserve_norm, upper])

Vectorizes symmetric matrices by extracting triangular elements.

vec_to_sym(x_vec[, preserve_norm, upper])

Reconstructs symmetric matrices from vectorization.
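The norm-preserving option scales off-diagonal entries by sqrt(2), so the half-vectorization keeps the Frobenius norm of the original matrix. A minimal sketch of the extraction step (illustrative only):

```python
import torch

def sym_to_upper_sketch(X: torch.Tensor, preserve_norm: bool = True) -> torch.Tensor:
    # Keep the upper triangle (incl. diagonal). Scaling off-diagonal entries
    # by sqrt(2) accounts for each entry appearing twice in a symmetric matrix.
    n = X.shape[-1]
    i, j = torch.triu_indices(n, n)
    v = X[..., i, j]
    if preserve_norm:
        scale = torch.where(i == j,
                            torch.ones(i.numel(), dtype=X.dtype),
                            torch.full((i.numel(),), 2.0 ** 0.5, dtype=X.dtype))
        v = v * scale
    return v

torch.manual_seed(0)
Z = torch.randn(4, 4, dtype=torch.float64)
S = 0.5 * (Z + Z.T)                 # symmetric 4 x 4 -> vector of length 10
v = sym_to_upper_sketch(S)
```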

Dropout#

Functional implementation of dropout specifically designed for SPD tensors.

dropout_spd(input_mat[, p, use_scaling, ...])

Applies dropout to a batch of SPD matrices.

Autograd Helpers#

Custom forward and backward functions for operations like matrix eigen-decomposition, enabling gradient computation through these potentially complex steps.

modeig_backward(grad_output, s, U, ...)

Backward pass for the modified eigenvalue of a symmetric matrix.

modeig_forward(X, applied_fct, *args)

Forward pass for the modified eigenvalue of a symmetric matrix.

Batch Normalization Operations#

Functions for Riemannian batch normalization computations on SPD manifolds.

karcher_mean_iteration(X, current_mean[, detach])

Perform one iteration of the Karcher mean algorithm.

spd_centering(X, mean_invsqrt)

Center SPD matrices around a mean via congruence transformation.

spd_rebiasing(X, bias_sqrt)

Apply learnable rebiasing to centered SPD matrices.

tangent_space_variance(X_tangent, mean_tangent)

Compute scalar dispersion in the tangent space.

Bilinear Operations#

Bilinear transformations that preserve SPD properties.

bimap_transform(X, W)

Apply bilinear transformation to SPD matrices.

bimap_increase_dim(X, projection_matrix, ...)

Increase the dimension of SPD matrices via embedding.
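The bilinear map Y = W X W^T keeps outputs SPD whenever W has full row rank, which is why BiMap-style layers constrain W to the Stiefel manifold. A minimal sketch (illustrative only):

```python
import torch

def bimap_transform_sketch(X: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    # Y = W X W^T maps n x n SPD inputs to m x m SPD outputs for full-row-rank
    # W of shape (m, n); this is the SPD analogue of a linear layer.
    return W @ X @ W.transpose(-2, -1)

torch.manual_seed(0)
Q, _ = torch.linalg.qr(torch.randn(8, 8, dtype=torch.float64))
W = Q[:4, :]                                        # orthonormal rows (Stiefel)
Z = torch.randn(8, 8, dtype=torch.float64)
S = Z @ Z.T + torch.eye(8, dtype=torch.float64)     # SPD input
Y = bimap_transform_sketch(S, W)
```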

Wavelet Operations#

Time-frequency analysis using Gabor wavelets.

compute_gabor_wavelet(tt, foi, fwhm[, ...])

Compute a complex Gabor (Morlet) wavelet filterbank.

Numerical Stability#

Utilities for ensuring numerical stability when working with SPD matrices, including epsilon handling and eigenvalue clamping. See Numerical Stability for detailed guidance.

numerical_config

Global configuration for numerical stability thresholds.

NumericalConfig(eigval_clamp_scale, ...)

Global configuration for numerical stability thresholds.

NumericalContext(**kwargs)

Context manager for temporarily modifying numerical configuration.

get_epsilon(dtype[, name, config])

Get a dtype-aware epsilon value for numerical stability.

get_epsilon_tensor(dtype[, name, device, config])

Get a dtype-aware epsilon value as a tensor.

get_loewner_threshold(eigenvalues, *[, config])

Get threshold for detecting equal eigenvalues in Loewner matrix.

safe_clamp_eigenvalues(eigenvalues[, name, ...])

Safely clamp eigenvalues with dtype-aware threshold.

check_spd_eigenvalues(eigenvalues[, name, ...])

Check if eigenvalues satisfy SPD requirements.

is_half_precision(dtype)

Check if dtype is half precision (float16 or bfloat16).

recommend_dtype_for_spd(condition_number, *)

Recommend a dtype based on expected matrix condition number.

Modules#

spd_learn.modules:

This module provides neural network layers specifically designed for deep learning on Riemannian manifolds, particularly SPD matrices. These layers wrap the functional operations from spd_learn.functional into stateful torch.nn.Module components.

Covariance Layer#

Modules for computing covariance matrices, often used as the first step in SPD-based pipelines.

CovLayer(method[, device, dtype])

Covariance Estimation Layer for Neuroimaging Data.

Matrix Eigen-Operations#

Modules performing operations based on matrix eigenvalue decomposition (LogEig, ReEig, ExpEig), essential for mapping between the SPD manifold and Euclidean/tangent spaces or applying non-linearities.

LogEig([upper, flatten, autograd, device, dtype])

Logarithmic Eigenvalue Layer (LogEig).

ReEig([threshold, autograd, device, dtype])

Rectified Eigenvalue Layer (ReEig).

ExpEig([upper, flatten, autograd, device, dtype])

Exponential Eigenvalue Layer (ExpEig).

Manifold Parametrization#

Modules for parametrizing learnable SPD matrices, ensuring parameters remain on the SPD manifold during optimization. Supports both matrix exponential and softplus mappings from symmetric matrices to SPD matrices.

SymmetricPositiveDefinite([mapping, device, ...])

Symmetric Positive Definite Manifold parametrization.

PositiveDefiniteScalar([mapping, device, dtype])

Positive definite scalars parametrization.

Bilinear Mappings#

Layers implementing learnable bilinear transformations suitable for SPD matrices, acting as analogous operations to linear layers in Euclidean space.

BiMap(in_features, out_features[, ...])

Bilinear Mapping Layer for SPD Matrices.

BiMapIncreaseDim(in_features, out_features)

Bilinear Mapping Layer for SPD Matrix Dimensionality Expansion.

Batch Normalization#

Batch normalization layers specifically adapted for data on the SPD manifold or related representations.

SPDBatchNormMean(num_features[, momentum, ...])

Riemannian Batch Normalization for SPD Matrices (Mean-only).

SPDBatchNormMeanVar(num_features[, ...])

SPD Batch Normalization (Mean and Variance).

BatchReNorm(num_features[, momentum, ...])

Batch Re-Normalization.

Regularization#

Modules implementing covariance regularization methods, such as Ledoit-Wolf shrinkage for covariance estimation or trace normalization.

TraceNorm([epsilon, device, dtype])

Trace Normalization Layer for Scale-Invariant Covariance.

Shrinkage(n_chans[, init_shrinkage, ...])

Learnable Shrinkage Regularization for Covariance Matrices.

Dropout#

Dropout mechanisms designed for SPD matrix inputs or features derived from them.

SPDDropout([p, use_scaling, epsilon, ...])

Structured Dropout for SPD Matrices.

Residual Connections#

Modules for residual/skip connections on Riemannian manifolds, enabling deeper SPD networks with improved gradient flow.

LogEuclideanResidual([device, dtype])

Residual/skip connection for SPD networks using the Log-Euclidean metric.

Signal Processing#

Layers for signal processing operations, including learnable wavelet convolutions for time-frequency feature extraction.

WaveletConv(kernel_width_s, foi_init[, ...])

Parametrized Complex Gabor Wavelet Convolution Layer.

Utilities#

Utility layers for preprocessing, feature extraction, or other auxiliary tasks within Riemannian deep learning models.

PatchEmbeddingLayer(n_chans, n_patches[, ...])

Patch Embedding Layer.

Vec([device, dtype])

Vectorization Layer.

Vech([preserve_norm, upper, device, dtype])

Vectorize Triangular Part Layer.

Models#

spd_learn.models:

This module offers pre-built models for working with SPD matrices, using the building blocks from spd_learn.modules.

EEGSPDNet(n_chans, n_outputs[, n_filters, ...])

EE(G) SPDNet.

Green(n_outputs, n_chans, sfreq, ...)

Gabor Riemann EEGNet.

MAtt(n_patches, n_chans, n_outputs, ...[, ...])

Manifold Attention Network for EEG Decoding (MAtt).

PhaseSPDNet([subspacedim, input_type, ...])

Phase SPDNet.

SPDNet([input_type, cov_method, ...])

Symmetric Positive Definite Neural Network (SPDNet).

TensorCSPNet([n_chans, n_outputs, ...])

Tensor-CSPNet.

TSMNet([n_chans, n_temp_filters, ...])

Tangent Space Mapping Network (TSMNet).

Model Selection Guide#

Use this table to choose the right model for your application:

| Model        | Best For                          | Input Type          | Key Feature               | Complexity |
|--------------|-----------------------------------|---------------------|---------------------------|------------|
| SPDNet       | General SPD learning              | Covariance matrices | Foundational architecture | Low        |
| TensorCSPNet | Multi-band EEG                    | Filter bank data    | Temporal-spectral-spatial | Medium     |
| TSMNet       | Domain adaptation, interpretation | Raw EEG             | SPDBatchNorm              | Medium     |
| EEGSPDNet    | Channel-specific EEG              | Raw EEG             | Grouped convolution       | Medium     |
| MAtt         | Attention-based                   | Raw EEG             | Manifold attention        | High       |
| Green        | Interpretable features            | Raw EEG             | Learnable wavelets        | Medium     |
| PhaseSPDNet  | Nonlinear dynamics                | Raw EEG             | Phase-space embedding     | Low        |

Decision Flowchart#

START
  │
  ▼
┌─────────────────────────────────┐
│ What is your input data type?  │
└─────────────────────────────────┘
  │
  ├─── Pre-computed covariance matrices ──► SPDNet
  │
  ├─── Filter bank EEG (multiple bands) ──► TensorCSPNet
  │
  └─── Raw time series
        │
        ▼
      ┌─────────────────────────────────┐
      │ Do you need domain adaptation? │
      └─────────────────────────────────┘
        │
        ├─── Yes ──► TSMNet (with domain adaptation)
        │
        └─── No
              │
              ▼
            ┌─────────────────────────────────┐
            │ What is your priority?         │
            └─────────────────────────────────┘
              │
              ├─── Interpretability ──► TSMNet, Green
              │
              ├─── Attention mechanism ──► MAtt
              │
              ├─── Conv feature extraction ──► TSMNet, EEGSPDNet
              │
              └─── Nonlinear dynamics ──► PhaseSPDNet

Model Architectures#

SPDNet - Foundational architecture for SPD learning:

[CovLayer] → BiMap → ReEig → LogEig → Linear

TensorCSPNet - Multi-frequency EEG with temporal-spectral-spatial features:

Tensor Stacking → BiMap blocks → SPDBatchNormMean → LogEig → TCN → Linear

TSMNet - Domain adaptation with SPD batch normalization:

Conv2d → CovLayer → BiMap → ReEig → SPDBatchNormMeanVar → LogEig → Linear

EEGSPDNet - Channel-specific temporal filtering:

GroupedConv1d → CovLayer → BiMap/SPDDropout/ReEig blocks → LogEig → Linear

MAtt - Attention-based temporal weighting:

Conv2d → PatchCov → AttentionManifold → ReEig → LogEig → Linear

Green - Interpretable wavelet features:

WaveletConv → CovLayer → Shrinkage → BiMap → LogEig → BatchReNorm → MLP

PhaseSPDNet - Phase-space embedding for nonlinear dynamics:

PhaseDelay → SPDNet

Performance Comparison#

Based on motor imagery classification benchmarks (BNCI2014-001):

| Model        | Accuracy (%) | Parameters | Training Time | GPU Memory |
|--------------|--------------|------------|---------------|------------|
| SPDNet       | 70-75        | ~10K       | Fast          | Low        |
| TensorCSPNet | 75-82        | ~50K       | Medium        | Medium     |
| TSMNet       | 72-78        | ~30K       | Medium        | Medium     |
| EEGSPDNet    | 73-79        | ~40K       | Medium        | Medium     |
| MAtt         | 74-80        | ~60K       | Slow          | High       |
| Green        | 72-78        | ~20K       | Medium        | Low        |

Note: Performance varies significantly across subjects and datasets.

To reproduce these results and run your own benchmarks, please refer to the Benchmarking SPD Learn Models with MOABB and Hydra example.

Initialization#

spd_learn.init:

Initialization functions for SPD-aware neural networks.

This module provides functions to initialize tensors with SPD-specific methods, following PyTorch’s torch.nn.init pattern.

All functions operate in-place and return the modified tensor for convenience.

Functions#

stiefel_

Initialize tensor on the Stiefel manifold (orthonormal columns).

spd_identity_

Initialize tensor as identity matrix.

Examples

>>> import torch
>>> from spd_learn import init as spd_init
>>> W = torch.empty(8, 4)
>>> spd_init.stiefel_(W, seed=42)
>>> # W is now on the Stiefel manifold: W^T @ W = I
>>> torch.allclose(W.T @ W, torch.eye(4), atol=1e-5)
True


stiefel_(tensor[, seed])

Initialize tensor on the Stiefel manifold (in-place).

spd_identity_(tensor)

Initialize tensor as identity matrix (in-place).