feat: KGSpectralConv — physics-constrained spectral layer for wave-type PDEs #714

gpartin wants to merge 3 commits into neuraloperator:main
Conversation
A physics-constrained spectral convolution that applies the KG dispersion filter H(k) = cos(T * sqrt(c^2 |k|^2 + chi^2)). Three learnable parameters (T, c, chi) plus a channel-mixing matrix give 16-500x fewer parameters than a standard SpectralConv for wave-type PDEs while encoding the exact solution operator.

Mathematical connection: the KG Green's function is the Matérn kernel family (Whittle 1954); the RBF kernel is the nu -> inf limit.

Includes:
- neuralop/layers/kg_spectral_conv.py: KGSpectralConv layer
- neuralop/layers/tests/test_kg_spectral_conv.py: 14 tests
- examples/layers/plot_kg_spectral_filter.py: visualization example
Rewrote KGSpectralConv with:
- Complex propagator exp(-iT*omega) instead of a real cosine
- Proper mode truncation (only the first n_modes are kept; all others zeroed)
- Per-channel dispersion (T, c, chi) plus per-mode complex amplitudes
- Drop-in FNO compatibility via the conv_module parameter
- 2.3x fewer parameters than the standard SpectralConv

Benchmark results on the 1D Klein-Gordon equation (FNO vs KG-FNO):
- m=0 (wave): 8.08% -> 6.02% (-25% error)
- m=5 (KG): 7.86% -> 6.06% (-23% error)
- m=15 (heavy): 12.09% -> 6.75% (-44% error)

Updated tests (13/13 pass), black-formatted.
Pull request overview
Introduces KGSpectralConv, a new spectral convolution layer intended as a physics-constrained, parameter-efficient drop-in alternative to SpectralConv for wave-type PDEs, along with tests and example/benchmark scripts demonstrating usage and benefits.
Changes:
- Added neuralop/layers/kg_spectral_conv.py implementing KGSpectralConv with KG-dispersion-based spectral filtering.
- Added unit tests covering shapes, gradients, complex inputs, and basic properties.
- Added a Sphinx-gallery example and a standalone benchmark script comparing FNO vs KG-FNO on a 1D Klein–Gordon task.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| neuralop/layers/kg_spectral_conv.py | Implements the KG-constrained spectral layer, FFT pipeline, and parameterization. |
| neuralop/layers/tests/test_kg_spectral_conv.py | Adds unit tests for the new layer across dimensions and configurations. |
| examples/layers/plot_kg_spectral_filter.py | Visualization example for KG filter behavior and parameter efficiency. |
| examples/models/benchmark_fno_vs_kg_fno.py | Benchmark script comparing standard FNO vs KG-FNO on synthetic KG data. |
```python
# Determine how many modes to keep per dimension
kept_sizes = [min(nm, fs) for nm, fs in zip(self._n_modes, fft_sizes)]
mode_slices = tuple(slice(0, k) for k in kept_sizes)
full_slices = (slice(None), slice(None)) + mode_slices

# Extract low-frequency modes from input
x_low = x_hat[full_slices]  # (B, C_in, *kept_sizes)
```
```python
if isinstance(n_modes, int):
    n_modes = [n_modes]
self._n_modes = list(n_modes)
self.order = len(self._n_modes)

# Per-channel KG dispersion parameters (log-space for positivity)
self.log_T = nn.Parameter(
    torch.full((out_channels,), float(np.log(max(init_T, 1e-8))), device=device)
)
self.log_c = nn.Parameter(
    torch.full((out_channels,), float(np.log(max(init_c, 1e-8))), device=device)
)
self.log_chi = nn.Parameter(
    torch.full(
        (out_channels,), float(np.log(max(init_chi, 1e-8))), device=device
    )
)

# Per-mode learnable complex amplitude: (out_channels, *n_modes)
alpha_shape = (out_channels, *self._n_modes)
self.alpha_real = nn.Parameter(torch.ones(alpha_shape, device=device))
self.alpha_imag = nn.Parameter(torch.zeros(alpha_shape, device=device))
```
```python
@n_modes.setter
def n_modes(self, n_modes):
    if isinstance(n_modes, int):
        n_modes = [n_modes]
    self._n_modes = list(n_modes)
```
```python
def _compute_kg_filter(self, kept_sizes, spatial_sizes):
    """Compute the KG spectral filter on the truncated frequency grid.

    Parameters
    ----------
    kept_sizes : list of int
        Number of modes kept per dimension (after truncation).
    spatial_sizes : list of int
        Full spatial dimensions of the input.

    Returns
    -------
    H : torch.Tensor
        Complex spectral filter of shape ``(out_channels, *kept_sizes)``.
    """
    T = torch.exp(self.log_T)
    c = torch.exp(self.log_c)
    chi = torch.exp(self.log_chi)

    # Build wavenumber grid only for the kept low-frequency modes
    freq_components = []
    for kept, full_size in zip(kept_sizes, spatial_sizes):
        # Frequencies: 0, 2π/N, 2·2π/N, ..., (kept-1)·2π/N
        freqs = torch.arange(kept, device=T.device, dtype=torch.float32)
        freqs = freqs * (2 * np.pi / full_size)
        freq_components.append(freqs)

    grids = torch.meshgrid(*freq_components, indexing="ij")
    k_squared = sum(g**2 for g in grids)  # |k|^2

    # Reshape (out_channels,) -> (out_channels, 1, 1, ...)
    ndim = len(kept_sizes)
    shape = (-1,) + (1,) * ndim
    omega = torch.sqrt(
        c.view(shape) ** 2 * k_squared.unsqueeze(0) + chi.view(shape) ** 2
    )

    # Complex propagator: exp(-i T omega)
    phase = -T.view(shape) * omega
    H = torch.complex(torch.cos(phase), torch.sin(phase))

    # Modulate by per-mode learnable amplitude
    alpha = torch.complex(self.alpha_real, self.alpha_imag)
    # Truncate alpha to match actual kept sizes (may differ from n_modes)
    slices = tuple(slice(0, k) for k in kept_sizes)
    alpha_trunc = alpha[(slice(None),) + slices]
    H = alpha_trunc * H
```
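For intuition, the filter construction above can be mirrored in a few lines of standalone NumPy. This is a sketch with illustrative names and shapes, not the PR's API:

```python
import numpy as np

# H[c, k] = alpha[c, k] * exp(-i * T_c * sqrt(c_c^2 |k|^2 + chi_c^2))
def kg_filter(T, c, chi, alpha, kept_sizes, spatial_sizes):
    # Wavenumber grid for the kept low-frequency modes only
    freqs = [
        np.arange(kept) * (2 * np.pi / full)
        for kept, full in zip(kept_sizes, spatial_sizes)
    ]
    grids = np.meshgrid(*freqs, indexing="ij")
    k_sq = sum(g**2 for g in grids)          # |k|^2, shape kept_sizes
    shape = (-1,) + (1,) * len(kept_sizes)   # broadcast over channels
    omega = np.sqrt(c.reshape(shape) ** 2 * k_sq + chi.reshape(shape) ** 2)
    return alpha * np.exp(-1j * T.reshape(shape) * omega)

out_ch = 4
T = np.ones(out_ch)
c = np.ones(out_ch)
chi = 0.5 * np.ones(out_ch)
alpha = np.ones((out_ch, 9), dtype=complex)  # unit per-mode amplitude

H = kg_filter(T, c, chi, alpha, kept_sizes=[9], spatial_sizes=[64])
print(H.shape)  # (4, 9)
```

With unit amplitudes the filter is a pure phase, so |H(k)| = 1 for every mode, which is the "exact propagator" structure the layer encodes.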
```python
def transform(self, x, output_shape=None):
    """Transform input for skip connections (identity or resample).

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    output_shape : tuple of int or None
        Target spatial shape. If None or same as input, returns identity.
    """
    in_shape = list(x.shape[2:])
    if output_shape is None or list(output_shape) == in_shape:
        return x
    from .resample import resample

    return resample(x, 1.0, list(range(2, x.ndim)), output_shape=list(output_shape))
```
```python
# Accept (and ignore) FNOBlocks kwargs for drop-in compatibility
max_n_modes=None,
rank=None,
factorization=None,
implementation=None,
separable=None,
resolution_scaling_factor=None,
fno_block_precision=None,
fixed_rank_modes=None,
decomposition_kwargs=None,
init_std=None,
**kwargs,
```
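The accept-and-ignore pattern above can be sketched with a toy class (illustrative names, not the PR's code): extra SpectralConv-specific keyword arguments are absorbed by defaults or `**kwargs` so the layer slots in wherever FNOBlocks constructs its convolution module.

```python
# Toy sketch of the drop-in signature pattern
class FlexibleConv:
    def __init__(self, in_channels, out_channels, n_modes,
                 rank=None, factorization=None, **kwargs):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_modes = n_modes  # SpectralConv-specific kwargs are ignored

# Caller passes SpectralConv-style arguments; nothing breaks
conv = FlexibleConv(3, 3, (16,), rank=0.5, separable=False, init_std="auto")
print(conv.n_modes)  # (16,)
```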
```python
# High-frequency modes should be near zero
assert y_hat[:, :, 16:].abs().max() < 1e-5
```
…lity

- Fix mode truncation for 2D/3D: add fftshift/ifftshift with centered frequency windows matching SpectralConv behavior (Comment 1)
- Fix n_modes rFFT adjustment: auto-adjust the last dim to n//2+1 for real-valued data via the n_modes setter (Comment 2)
- Implement max_n_modes: allocate alpha at max_n_modes, with symmetric cropping at runtime for incremental training (Comment 3)
- Implement resolution_scaling_factor via validate_scaling_factor in transform() and forward() (Comment 4)
- Warn on unsupported fno_block_precision values (Comment 5)
- Add 4 new tests: 2D mode truncation, max_n_modes, precision warning, rFFT adjustment across 1D/2D/3D/complex (Comment 6)
- Update 4 existing tests for rFFT-adjusted mode counts

Note: The n_modes rFFT fix changes the internal mode count from 16 to 9 for n_modes=(16,), making the FNO vs KG-FNO comparison fair (both models now use the same number of spectral modes). Benchmark results change accordingly: KG-FNO now wins only at m=15 (-14%), where the physics prior matches, which is the scientifically expected behavior.

Tests: 17/17 pass
Addressing all 6 Copilot review comments

Thanks for the thorough automated review! All 6 issues have been addressed in commit 39d5a01.

What was fixed
Test results: 17/17 passing

Benchmark impact

The n_modes rFFT fix (Comment 2) corrects an apples-to-oranges comparison in the original benchmark. Previously KGSpectralConv retained 16 modes while SpectralConv retained 9 for the same n_modes setting.
Summary
Adds KGSpectralConv, a physics-constrained drop-in replacement for SpectralConv that parameterizes the spectral filter via the Klein-Gordon dispersion relation. It is particularly effective for wave-type (hyperbolic) PDEs, where the physics prior matches.

Motivation
The standard SpectralConv learns arbitrary complex weights for each Fourier mode — expressive but parameter-heavy. For wave-type PDEs, the exact solution operator has a known structure:

    H(k) = exp(-i T sqrt(c^2 |k|^2 + chi^2))

By encoding this structure directly, we can drastically cut the parameter count while building the exact solution operator into the layer.
The mathematical connection: the Green's function of the Klein-Gordon equation is the Matérn kernel family (Whittle 1954). The standard RBF kernel (and by analogy, unconstrained spectral weights) is the ν→∞ diffusion limit. Using finite ν encodes wave-like (hyperbolic) rather than diffusion-like (parabolic) behavior.
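This connection can be checked numerically in 1D. With c = 1, the static KG (screened-Poisson) Green's function is exp(-chi|x|)/(2 chi), whose Fourier transform is 1/(chi^2 + k^2), exactly the Matérn nu = 1/2 spectral density. A sketch under those assumptions (not code from this PR):

```python
import numpy as np

chi = 2.0
L, n = 40.0, 4096
x = np.linspace(-L / 2, L / 2, n, endpoint=False)
g = np.exp(-chi * np.abs(x)) / (2 * chi)  # static KG Green's function, c = 1

# Approximate the continuous Fourier transform with a DFT
dx = x[1] - x[0]
k = 2 * np.pi * np.fft.fftfreq(n, d=dx)            # angular wavenumbers
g_hat = np.real(np.fft.fft(np.fft.ifftshift(g)) * dx)

# Matern nu = 1/2 spectral density (up to normalization): 1 / (chi^2 + k^2)
analytic = 1.0 / (chi**2 + k**2)
assert np.allclose(g_hat, analytic, atol=1e-3)
```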
What's included
Layer: neuralop/layers/kg_spectral_conv.py
- Complex propagator exp(-i*omega*T) with per-channel dispersion (T, c, chi)
- Per-mode learnable complex amplitude alpha(k) for expressiveness
- Matches SpectralConv behavior: rFFT n_modes adjustment (last dim: n//2+1), fftshift for 2D/3D
- max_n_modes support for incremental training
- resolution_scaling_factor support for super-resolution
- Accepts FNOBlocks kwargs for seamless integration

Tests: neuralop/layers/tests/test_kg_spectral_conv.py

Example: examples/layers/plot_kg_spectral_filter.py

Benchmark: examples/models/benchmark_fno_vs_kg_fno.py

Benchmark results
Both models use n_modes=(16,), which internally retains 9 Fourier modes (rFFT: 16//2+1). Training: 80 epochs, hidden=32, 4 layers. Data: 500 train / 100 test, nx=64.

Parameters: FNO 49,953 → KG-FNO 19,873 (2.5x fewer)
Interpretation: The KG physics prior helps exactly where it should — at high mass where the dispersion relation is the dominant structure. At m=0 (pure wave equation, no mass term), the rigid KG structure is actively unhelpful. This is the expected behavior for a physics-informed layer: it outperforms where its assumptions match and underperforms where they don't.
Note on earlier results: An initial version of this layer did not apply the rFFT n//2+1 adjustment to n_modes, inadvertently retaining 16 modes vs SpectralConv's 9. Those inflated results have been corrected. The current benchmark is an apples-to-apples comparison.

Usage
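The rFFT adjustment described above can be sketched as a small helper (illustrative, not the PR's exact setter): for real-valued inputs the last FFT dimension is halved, so the kept mode count on that axis becomes n_modes[-1] // 2 + 1.

```python
def adjust_n_modes_for_rfft(n_modes):
    """Halve the last dimension's mode count to match rfftn's output size."""
    n_modes = list(n_modes)
    n_modes[-1] = n_modes[-1] // 2 + 1
    return n_modes

print(adjust_n_modes_for_rfft([16]))      # [9]
print(adjust_n_modes_for_rfft([16, 16]))  # [16, 9]
```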
```python
from neuralop.models import FNO
from neuralop.layers.kg_spectral_conv import KGSpectralConv

# Drop-in replacement: just pass conv_module
model = FNO(
    n_modes=(16,),
    in_channels=1,
    out_channels=1,
    hidden_channels=32,
    n_layers=4,
    conv_module=KGSpectralConv,  # only change needed
)
```
When to use KGSpectralConv
Checklist
- black-formatted

Happy to address any feedback. Thanks for building such a well-designed library — the conv_module abstraction made this a pleasure to implement!