
feat: KGSpectralConv — physics-constrained spectral layer for wave-type PDEs#714

Open
gpartin wants to merge 3 commits into neuraloperator:main from gpartin:feature/kg-spectral-conv

Conversation


@gpartin gpartin commented Mar 15, 2026

Summary

Adds KGSpectralConv, a physics-constrained drop-in replacement for SpectralConv that parameterizes the spectral filter via the Klein-Gordon dispersion relation. Particularly effective for wave-type (hyperbolic) PDEs where the physics prior matches.

Motivation

The standard SpectralConv learns arbitrary complex weights for each Fourier mode — expressive but parameter-heavy. For wave-type PDEs, the exact solution operator has a known structure:

```
H(k) = exp(-i T sqrt(c^2 |k|^2 + chi^2))
```

By encoding this structure directly, we can:

  • Reduce parameters by 2.5x while maintaining competitive accuracy
  • Embed the correct dispersion relation as an inductive bias
  • Maintain full drop-in compatibility with existing FNO models

The mathematical connection: the Green's function of the Klein-Gordon equation is the Matérn kernel family (Whittle 1954). The standard RBF kernel (and by analogy, unconstrained spectral weights) is the ν→∞ diffusion limit. Using finite ν encodes wave-like (hyperbolic) rather than diffusion-like (parabolic) behavior.
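For intuition, the solution-operator symbol H(k) above can be sketched in a few lines of NumPy. The scalar T, c, chi values here are illustrative assumptions; the layer itself learns them per output channel.

```python
import numpy as np

# Sketch of the Klein-Gordon dispersion filter H(k) for 1D Fourier modes
# on an n_grid-point periodic domain. T, c, chi are hypothetical scalars
# here; the actual layer learns them per channel (not the PR's exact code).
def kg_filter(n_modes, n_grid, T=1.0, c=1.0, chi=0.5):
    k = 2 * np.pi * np.arange(n_modes) / n_grid  # kept low-frequency wavenumbers
    omega = np.sqrt(c**2 * k**2 + chi**2)        # dispersion relation
    return np.exp(-1j * T * omega)               # unitary propagator

H = kg_filter(n_modes=9, n_grid=64)
```

Because H is a pure phase, |H(k)| = 1 for every mode: the constrained filter alone can neither amplify nor damp any mode, mirroring the energy-preserving wave-type solution operator; the learnable alpha(k) amplitudes add that freedom back.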

What's included

Layer: neuralop/layers/kg_spectral_conv.py

  • Complex propagator exp(-i*omega*T) with per-channel dispersion (T, c, chi)
  • Per-mode learnable complex amplitudes alpha(k) for expressiveness
  • Proper mode truncation matching SpectralConv behavior:
    • rFFT-aware n_modes adjustment (last dim: n//2+1)
    • Centered frequency windows via fftshift for 2D/3D
    • max_n_modes support for incremental training
    • resolution_scaling_factor support for super-resolution
  • Accepts all FNOBlocks kwargs for seamless integration
  • Full NumPy-style docstrings with references
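The rFFT-aware adjustment in the list above amounts to a one-line rule, sketched here as a hypothetical helper (not the PR's exact code):

```python
# For real-valued inputs, torch.fft.rfftn stores only n//2 + 1 non-redundant
# frequencies along the last dimension, so the last entry of n_modes is
# clamped to match. Hypothetical helper mirroring SpectralConv's semantics.
def adjust_n_modes_for_rfft(n_modes):
    adjusted = list(n_modes)
    adjusted[-1] = adjusted[-1] // 2 + 1
    return tuple(adjusted)
```

With n_modes=(16,) this yields 9 retained modes, the count quoted in the benchmark section below.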

Tests: neuralop/layers/tests/test_kg_spectral_conv.py

  • 17 tests covering: 1D/2D/3D shapes, gradient flow, identity at T=0, complex data, parameter efficiency, output resize, 2D mode truncation, max_n_modes, precision warnings, rFFT adjustment, repr, and more
  • All passing

Example: examples/layers/plot_kg_spectral_filter.py

  • Sphinx-compatible gallery example visualizing the KG filter
  • Compares against GCN low-pass, Gaussian diffusion, and FNO truncation filters
  • Shows parameter efficiency table and signal filtering demo

Benchmark: examples/models/benchmark_fno_vs_kg_fno.py

  • Compares standard FNO vs KG-FNO on 1D Klein-Gordon time evolution
  • Tests across 3 mass regimes (wave, moderate KG, heavy KG)
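For context, 1D Klein-Gordon data of the kind this benchmark uses can be generated exactly in Fourier space. The sketch below (hypothetical names and parameter values, not the PR's script) evolves u_tt = c^2 u_xx - m^2 u from a zero-velocity initial condition:

```python
import numpy as np

# Exact spectral time evolution of the 1D Klein-Gordon equation
# u_tt = c^2 u_xx - m^2 u on a periodic domain of length L, assuming
# zero initial velocity: each rFFT mode evolves as cos(omega(k) * T).
def kg_evolve(u0, T=0.5, c=1.0, m=5.0, L=2 * np.pi):
    n = u0.shape[-1]
    k = 2 * np.pi * np.fft.rfftfreq(n, d=L / n)  # wavenumbers of rFFT modes
    omega = np.sqrt(c**2 * k**2 + m**2)          # KG dispersion relation
    return np.fft.irfft(np.cos(omega * T) * np.fft.rfft(u0), n=n)

x = np.linspace(0, 2 * np.pi, 64, endpoint=False)
u_T = kg_evolve(np.sin(3 * x))
```

Setting m=0 recovers the pure wave equation, so the same generator covers all three mass regimes by varying m.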

Benchmark results

Both models use n_modes=(16,) which internally retains 9 Fourier modes (rFFT: 16//2+1). Training: 80 epochs, hidden=32, 4 layers. Data: 500 train / 100 test, nx=64.

| Mass regime | FNO (test L2) | KG-FNO (test L2) | Delta | Params |
|---|---|---|---|---|
| m=0 (wave eq) | 9.26% | 20.75% | +124% | 2.5x fewer |
| m=5 (moderate KG) | 7.11% | 9.56% | +34% | 2.5x fewer |
| m=15 (heavy KG) | 8.87% | 7.67% | -14% | 2.5x fewer |

Parameters: FNO 49,953 → KG-FNO 19,873 (2.5x fewer)

Interpretation: The KG physics prior helps exactly where it should — at high mass where the dispersion relation is the dominant structure. At m=0 (pure wave equation, no mass term), the rigid KG structure is actively unhelpful. This is the expected behavior for a physics-informed layer: it outperforms where its assumptions match and underperforms where they don't.

Note on earlier results: An initial version of this layer did not apply the rFFT n//2+1 adjustment to n_modes, inadvertently retaining 16 modes vs SpectralConv's 9. Those inflated results have been corrected. The current benchmark is an apples-to-apples comparison.

Usage

```python
from neuralop.models import FNO
from neuralop.layers.kg_spectral_conv import KGSpectralConv

# Drop-in replacement — just pass conv_module
model = FNO(
    n_modes=(16,),
    in_channels=1,
    out_channels=1,
    hidden_channels=32,
    n_layers=4,
    conv_module=KGSpectralConv,  # only change needed
)
```

When to use KGSpectralConv

| Scenario | Recommendation |
|---|---|
| Hyperbolic PDEs with mass (KG, Dirac) | Use KG-FNO — physics prior matches |
| Pure wave equation (c only, no mass) | Use standard FNO |
| Parameter-constrained settings | Use KG-FNO — 2.5x fewer params |
| General-purpose PDE operator learning | Use standard FNO |

Checklist

  • Code follows PEP8 and is black-formatted
  • NumPy-style docstrings on all public methods
  • Comprehensive unit tests (17 tests)
  • SpectralConv compatibility (fftshift, rFFT, max_n_modes, resolution_scaling_factor)
  • Sphinx gallery example with visualization
  • Training benchmark with quantitative results
  • Type hints on public API
  • References to relevant papers in docstrings

Happy to address any feedback. Thanks for building such a well-designed library — the conv_module abstraction made this a pleasure to implement!

gpartin added 2 commits March 14, 2026 18:53
A physics-constrained spectral convolution that applies the KG
dispersion filter H(k) = cos(T * sqrt(c^2 |k|^2 + chi^2)).

3 learnable parameters (T, c, chi) + channel mixing matrix,
giving 16-500x fewer parameters than standard SpectralConv for
wave-type PDEs while encoding the exact solution operator.

Mathematical connection: the KG Green's function is the Matern
kernel family (Whittle 1954); the RBF kernel is the nu->inf limit.

Includes:
- neuralop/layers/kg_spectral_conv.py: KGSpectralConv layer
- neuralop/layers/tests/test_kg_spectral_conv.py: 14 tests
- examples/layers/plot_kg_spectral_filter.py: visualization example
Rewrote KGSpectralConv with:
- Complex propagator exp(-iT*omega) instead of real cos
- Proper mode truncation (only first n_modes, zeros elsewhere)
- Per-channel dispersion (T, c, chi) + per-mode complex amplitudes
- Drop-in FNO compatibility via conv_module parameter
- 2.3x fewer parameters than standard SpectralConv

Benchmark results on 1D Klein-Gordon equation (FNO vs KG-FNO):
- m=0  (wave):  8.08% -> 6.02% (-25% error)
- m=5  (KG):    7.86% -> 6.06% (-23% error)
- m=15 (heavy): 12.09% -> 6.75% (-44% error)

Updated tests (13/13 pass), black-formatted.
Copilot AI review requested due to automatic review settings March 15, 2026 02:32

Copilot AI left a comment


Pull request overview

Introduces KGSpectralConv, a new spectral convolution layer intended as a physics-constrained, parameter-efficient drop-in alternative to SpectralConv for wave-type PDEs, along with tests and example/benchmark scripts demonstrating usage and benefits.

Changes:

  • Added neuralop/layers/kg_spectral_conv.py implementing KGSpectralConv with KG-dispersion-based spectral filtering.
  • Added unit tests covering shapes, gradients, complex inputs, and basic properties.
  • Added Sphinx-gallery example and a standalone benchmark script comparing FNO vs KG-FNO on a 1D Klein–Gordon task.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.

File Description
neuralop/layers/kg_spectral_conv.py Implements the KG-constrained spectral layer, FFT pipeline, and parameterization.
neuralop/layers/tests/test_kg_spectral_conv.py Adds unit tests for the new layer across dimensions and configurations.
examples/layers/plot_kg_spectral_filter.py Visualization example for KG filter behavior and parameter efficiency.
examples/models/benchmark_fno_vs_kg_fno.py Benchmark script comparing standard FNO vs KG-FNO on synthetic KG data.


Comment thread neuralop/layers/kg_spectral_conv.py Outdated
Comment on lines +318 to +325
        # Determine how many modes to keep per dimension
        kept_sizes = [min(nm, fs) for nm, fs in zip(self._n_modes, fft_sizes)]
        mode_slices = tuple(slice(0, k) for k in kept_sizes)
        full_slices = (slice(None), slice(None)) + mode_slices

        # Extract low-frequency modes from input
        x_low = x_hat[full_slices]  # (B, C_in, *kept_sizes)

Comment on lines +146 to +168
        if isinstance(n_modes, int):
            n_modes = [n_modes]
        self._n_modes = list(n_modes)
        self.order = len(self._n_modes)

        # Per-channel KG dispersion parameters (log-space for positivity)
        self.log_T = nn.Parameter(
            torch.full((out_channels,), float(np.log(max(init_T, 1e-8))), device=device)
        )
        self.log_c = nn.Parameter(
            torch.full((out_channels,), float(np.log(max(init_c, 1e-8))), device=device)
        )
        self.log_chi = nn.Parameter(
            torch.full(
                (out_channels,), float(np.log(max(init_chi, 1e-8))), device=device
            )
        )

        # Per-mode learnable complex amplitude: (out_channels, *n_modes)
        alpha_shape = (out_channels, *self._n_modes)
        self.alpha_real = nn.Parameter(torch.ones(alpha_shape, device=device))
        self.alpha_imag = nn.Parameter(torch.zeros(alpha_shape, device=device))

Comment on lines +204 to +256
    @n_modes.setter
    def n_modes(self, n_modes):
        if isinstance(n_modes, int):
            n_modes = [n_modes]
        self._n_modes = list(n_modes)

    def _compute_kg_filter(self, kept_sizes, spatial_sizes):
        """Compute the KG spectral filter on the truncated frequency grid.

        Parameters
        ----------
        kept_sizes : list of int
            Number of modes kept per dimension (after truncation).
        spatial_sizes : list of int
            Full spatial dimensions of the input.

        Returns
        -------
        H : torch.Tensor
            Complex spectral filter of shape ``(out_channels, *kept_sizes)``.
        """
        T = torch.exp(self.log_T)
        c = torch.exp(self.log_c)
        chi = torch.exp(self.log_chi)

        # Build wavenumber grid only for the kept low-frequency modes
        freq_components = []
        for i, (kept, full_size) in enumerate(zip(kept_sizes, spatial_sizes)):
            # Frequencies: 0, 2π/N, 2·2π/N, ..., (kept-1)·2π/N
            freqs = torch.arange(kept, device=T.device, dtype=torch.float32)
            freqs = freqs * (2 * np.pi / full_size)
            freq_components.append(freqs)

        grids = torch.meshgrid(*freq_components, indexing="ij")
        k_squared = sum(g**2 for g in grids)  # |k|^2

        # Reshape (out_channels,) -> (out_channels, 1, 1, ...)
        ndim = len(kept_sizes)
        shape = (-1,) + (1,) * ndim
        omega = torch.sqrt(
            c.view(shape) ** 2 * k_squared.unsqueeze(0) + chi.view(shape) ** 2
        )

        # Complex propagator: exp(-i T omega)
        phase = -T.view(shape) * omega
        H = torch.complex(torch.cos(phase), torch.sin(phase))

        # Modulate by per-mode learnable amplitude
        alpha = torch.complex(self.alpha_real, self.alpha_imag)
        # Truncate alpha to match actual kept sizes (may differ from n_modes)
        slices = tuple(slice(0, k) for k in kept_sizes)
        alpha_trunc = alpha[(slice(None),) + slices]
        H = alpha_trunc * H
Comment on lines +183 to +199
    def transform(self, x, output_shape=None):
        """Transform input for skip connections (identity or resample).

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.
        output_shape : tuple of int or None
            Target spatial shape. If None or same as input, returns identity.
        """
        in_shape = list(x.shape[2:])
        if output_shape is None or list(output_shape) == in_shape:
            return x
        from .resample import resample

        return resample(x, 1.0, list(range(2, x.ndim)), output_shape=list(output_shape))

Comment thread neuralop/layers/kg_spectral_conv.py Outdated
Comment on lines +126 to +137
        # Accept (and ignore) FNOBlocks kwargs for drop-in compatibility
        max_n_modes=None,
        rank=None,
        factorization=None,
        implementation=None,
        separable=None,
        resolution_scaling_factor=None,
        fno_block_precision=None,
        fixed_rank_modes=None,
        decomposition_kwargs=None,
        init_std=None,
        **kwargs,

    # High-frequency modes should be near zero
    assert y_hat[:, :, 16:].abs().max() < 1e-5


…lity

- Fix mode truncation for 2D/3D: add fftshift/ifftshift with centered
  frequency windows matching SpectralConv behavior (Comment 1)
- Fix n_modes rFFT adjustment: auto-adjust last dim to n//2+1 for
  real-valued data via n_modes setter (Comment 2)
- Implement max_n_modes: allocate alpha at max_n_modes, symmetric
  cropping at runtime for incremental training (Comment 3)
- Implement resolution_scaling_factor via validate_scaling_factor
  in transform() and forward() (Comment 4)
- Warn on unsupported fno_block_precision values (Comment 5)
- Add 4 new tests: 2D mode truncation, max_n_modes, precision
  warning, rFFT adjustment across 1D/2D/3D/complex (Comment 6)
- Update 4 existing tests for rFFT-adjusted mode counts

Note: The n_modes rFFT fix changes internal mode count from 16 to 9
for n_modes=(16,), making the FNO vs KG-FNO comparison fair (both
models now use the same number of spectral modes). Benchmark results
change accordingly - KG-FNO now wins only at m=15 (-14%) where the
physics prior matches, which is the scientifically expected behavior.

Tests: 17/17 pass

gpartin commented Mar 15, 2026

Addressing all 6 Copilot Review Comments

Thanks for the thorough automated review! All 6 issues have been addressed in commit 39d5a01.

What was fixed

  1. Mode truncation for 2D/3D (Comment 1): Added fftshift/ifftshift with centered frequency windows, matching SpectralConv's multi-D behavior. The frequency window is symmetric around DC for all dims except the last (rFFT dim).

  2. n_modes rFFT adjustment (Comment 2): The `n_modes` setter now auto-adjusts the last dimension to `n//2+1` for real-valued data, matching SpectralConv's semantics.

  3. max_n_modes support (Comment 3): Alpha parameters are now allocated at max_n_modes size, with symmetric cropping at runtime when n_modes is smaller. This enables incremental training workflows.

  4. resolution_scaling_factor (Comment 4): Now stored via `validate_scaling_factor()` (from `neuralop.utils`) and applied in both `transform()` and `forward()` for super-resolution workflows.

  5. fno_block_precision warning (Comment 5): Non-default values now emit a UserWarning with a clear message explaining that only full precision is currently supported.

  6. 2D/3D mode-truncation tests (Comment 6): Added 4 new tests:

    • `test_kg_spectral_conv_2d_mode_truncation` — verifies centered frequency window for 2D
    • `test_kg_spectral_conv_max_n_modes` — verifies max_n_modes allocation and dynamic mode changes
    • `test_kg_spectral_conv_precision_warning` — verifies warning for non-default precision
    • `test_kg_spectral_conv_n_modes_rfft_adjustment` — verifies rFFT adjustment for 1D/2D/3D/complex

Test results: 17/17 passing
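A minimal sketch of the centered-window idea behind fix (1), using a hypothetical helper rather than the PR's code: after fftshift, a slice of width `kept` centered on DC retains the lowest-|k| modes.

```python
import numpy as np

# After fftshift the DC bin sits at index n//2, so a centered slice of
# width `kept` keeps the lowest-|frequency| modes symmetrically around DC
# (with one extra negative frequency when `kept` is even).
def centered_window(full_size, kept):
    start = (full_size - kept) // 2
    return slice(start, start + kept)

freqs = np.fft.fftshift(np.fft.fftfreq(16))
kept_freqs = freqs[centered_window(16, 8)]
```

This is applied per dimension except the last (rFFT) axis, where frequencies are one-sided and a plain `slice(0, kept)` suffices.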

Benchmark impact

The n_modes rFFT fix (Comment 2) corrects an apples-to-oranges comparison in the original benchmark. Previously KGSpectralConv retained 16 modes while SpectralConv retained 9 for the same `n_modes=(16,)`. With both using 9 modes, KG-FNO now wins only at m=15 (heavy KG regime: -14% error with 2.5x fewer params), which is the scientifically expected behavior — the physics prior helps exactly where the physics matches. The PR description has been updated with corrected numbers.
