Add Multiwavelet Neural Operator #636
shixuanli01 wants to merge 17 commits into neuraloperator:main
Conversation
dhpitt
left a comment
Thank you so much for your improved PR @Shixuan01 . This is great. I have a few stylistic notes to make this library-ready:
- Wherever possible, please add unit tests (e.g. for models, in `neuralop/models/tests/test_xx.py`) to ensure that data passes through models in the way that we expect. Where possible, if you can test the correctness of sub-operations like the transform or the basis, these will make great tests to prevent regressions when we change the code. Check `neuralop/layers/tests` for examples.
- All functions and classes should have Numpy-formatted docstrings, detailed here in our dev guide. This will make it easier for users to understand and for future maintainers to learn and maintain.
- The individual layers and utils belong in `neuralop/layers` to indicate that they are submodules that can be composed into an MWT model.
- Let's think of a slightly more descriptive name for the model that indicates that it's a full operator, perhaps MWNO?
Overall this is very solid and will make a great contribution to the library. Thank you!
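Since the review asks for shape tests under `neuralop/models/tests/`, here is a minimal sketch of that kind of test. The `TinyOperator` stand-in below is hypothetical (MWNO's final API was still settling at this point), so only the test pattern carries over: assert the output shape matches expectations, then check that gradients flow to every parameter.

```python
import torch
from torch import nn

class TinyOperator(nn.Module):
    """Hypothetical stand-in for MWNO: maps (batch, n_points, in_channels) to the same shape."""
    def __init__(self, in_channels=1, hidden=8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.GELU(),
            nn.Linear(hidden, in_channels),
        )

    def forward(self, x):
        return self.net(x)

def test_output_shape_and_grads():
    # Data should pass through the model with the expected shape...
    model = TinyOperator(in_channels=1)
    x = torch.randn(2, 64, 1)
    out = model(x)
    assert out.shape == x.shape
    # ...and every parameter should receive a gradient after backward.
    out.sum().backward()
    assert all(p.grad is not None for p in model.parameters())
```

The same skeleton works for sub-operations (the transform, the basis construction): compute a small known case and assert on the numerical result.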
from sympy import Poly, legendre, Symbol, chebyshevt

class WaveletUtils:
    """Unified wavelet utility class containing wavelet transform functions for all dimensions"""
Some docs here would be helpful. Please refer to the contributors guide for guidance on how to format in-line docstrings
return x

class SparseKernelFT(nn.Module):
These layers belong in neuralop/layers
return x

class MWT_CZ(nn.Module):
same comments re: docstring, and moving layers to the layers module
return x

class MWT(nn.Module):
This is great. Let's think of a more descriptive name than MWT, perhaps MultiWaveletNO or MWNO?
return x.squeeze(-1)
return x

"""Compatible MWT Model"""
We don't necessarily need these classes
We're glad you gave us such insightful advice. We've modified and supplemented the code according to your suggestions and added the test_mwno.py file. During testing, we discovered some issues with the 3D MWNO implementation, so we've temporarily removed this module and retained the working 1D and 2D modules.
We have debugged the 3D operator and added it to the latest code base.
dhpitt
left a comment
Looks really good, thank you so much for the improvements. Could you please:
- respond to the small nits re: docstrings
- refactor `neuralop.models.mwt` --> `neuralop.layers.mwt`
- add corresponding unit tests for the MWT and MWNO block
Once these changes are in this will be ready to go.
Thanks for your suggestion, we have added more descriptive text as well as a reference.
dhpitt
left a comment
These changes look good to me, thank you @Shixuan01 for responding to the feedback!
This is a great addition to the library @Shixuan01, thank you for the PR!
My main comments are on clarity and readability of the code: right now it is hard for a maintainer/user to directly use or modify the code. It would be great to clarify it by using descriptive names and adding more docstrings and comments!
- Scaling functions are normalized Chebyshev polynomials
- Wavelets have compact support on [0,0.5] and [0.5,1] respectively
"""
x = Symbol('x')
Do we need to go through Sympy for this? Couldn't we directly return the discretized evaluations, instead of generating a symbolic version and then evaluating?
Yes, we use numpy instead of Sympy to make the calculation faster.
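A sketch of the direct NumPy route the reviewer suggested, assuming the goal is tabulated basis values on a grid rather than symbolic polynomials. Only `numpy.polynomial.legendre.legval` is used; the library's actual code additionally normalizes the basis, which is omitted here.

```python
import numpy as np
from numpy.polynomial import legendre

# Evaluate the first k Legendre polynomials on a grid in [0, 1] directly,
# skipping any symbolic (sympy) construction step.
k = 4
x = np.linspace(0.0, 1.0, 5)
t = 2.0 * x - 1.0  # shift [0, 1] onto [-1, 1], the natural Legendre domain

# Row i holds P_i evaluated at every grid point; np.eye(k)[i] selects P_i.
vals = np.stack([legendre.legval(t, np.eye(k)[i]) for i in range(k)])
```

As a sanity check, P_0 is identically 1 and P_1(t) = t, so the first two rows are directly verifiable.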
2. Applies learnable transformations at each scale
3. Reconstructs output from all scales

Processing flow:
These are extremely helpful docstrings. I would suggest putting them down beneath the params heading so that the parameters section renders first in IDEs
Thank you! I have moved those sentences.
Let's move the description back to the top of the docstring, I think it's what we have been doing more consistently across the library. Sorry for the back and forth
Thank you @Shixuan01 for your hard work on this PR! I'm starting my review now
n_dim : int, optional
    Spatial dimensionality (1, 2, or 3).
    Only needed if using alpha parameter.
    Inferred from n_modes if n_modes is a tuple.
Maybe we can remove these since they don't seem to provide more functionality than what n_modes does?
These (alpha and n_dim) are not in the __init__ anymore. If that's intentional, we can remove them from the docstring so users are not confused.
Let's make sure the docstring reflects only the parameters that the user can specify in the __init__
def _build_lifting_layer(
    self,
    in_channels: int,
    channel_multiplier: int,
    lifting_channels: int
) -> nn.Module:
    """
    Build the lifting layer that embeds inputs into wavelet space.

    The lifting layer transforms from physical input space to the
    higher-dimensional wavelet coefficient space where MWNO operates.

    Parameters
    ----------
    in_channels : int
        Input feature dimension
    channel_multiplier : int
        Output dimension (c * k^n_dim)
    lifting_channels : int
        Hidden dimension (0 for direct linear projection)

    Returns
    -------
    nn.Module
        Lifting layer (Linear or Sequential MLP)
Could we use the existing ChannelMLP class (link)?
You can then create the lifting layer as in the init of FNO class:

self.lifting = ChannelMLP(
    in_channels=lifting_in_channels,
    out_channels=self.hidden_channels,
    hidden_channels=self.lifting_channels,
    n_layers=2,
    n_dim=self.n_dim,
    non_linearity=non_linearity,
)
if self.complex_data:
    self.lifting = ComplexValued(self.lifting)
def _build_projection_layer(
    self,
    channel_multiplier: int,
    out_channels: int,
    projection_channels: int
) -> nn.Module:
    """
    Build the projection layer that maps from wavelet space to output.

    The projection layer transforms from the wavelet coefficient space
    back to the physical output space.

    Parameters
    ----------
    channel_multiplier : int
        Input dimension (c * k^n_dim)
    out_channels : int
        Output feature dimension
    projection_channels : int
        Hidden dimension (0 uses default)

    Returns
    -------
    nn.Module
        Projection layer (two-layer MLP with ReLU)
Could we use the existing ChannelMLP class (link)?
You can then create the projection layer as in the init of FNO class:

## Projection layer
self.projection = ChannelMLP(
    in_channels=self.hidden_channels,
    out_channels=out_channels,
    hidden_channels=self.projection_channels,
    n_layers=2,
    n_dim=self.n_dim,
    non_linearity=non_linearity,
)
if self.complex_data:
    self.projection = ComplexValued(self.projection)
self.mwno_layers = nn.ModuleList([
    MWNO_CZ(
        k=k,
        alpha=alpha,
It seems that the current code only allows the same number of modes for every dimension. Is there a real constraint leading to this design? Otherwise it would be good to have the flexibility of choosing more or fewer modes along certain dimensions, since dynamics can have different scales along different dimensions, and the resolution could also differ across dimensions.
Thank you for your question; currently the code only allows the same number of modes for every dimension. We will improve this feature later.
def _validate_input_shape(self, x: torch.Tensor) -> None:
    """
    Validate input tensor shape and dimensions.
I don't think this function is used anywhere
If > 0: Two-layer MLP (in → lifting_channels → wavelet_space).
projection_channels : int, optional
    Hidden dimension for projection layer. Default: 128
    If 0: Uses default two-layer MLP with hidden_dim=128.
The projection_channels=0 option does not seem to be implemented. Maybe we can remove that option? Or simply use the ChannelMLP already in the library
Thanks — we now use ChannelMLP for projection (same pattern as FNO), with projection_channels=0 mapped to hidden width 128 so the documented default is actually applied; forward permutes to channels-first for the MLP and back.
If int: Creates (alpha,) for 1D, (alpha, alpha) for 2D, etc.
Must provide either n_modes or alpha (not both).
Maybe can remove mention of alpha since we will remove it
Might be good to remove alpha completely everywhere and just use n_modes to avoid having two variables for the same thing
Dropped redundant alpha / duplicate self.alpha; the public API is n_modes only (with optional int shorthand for 1D). MWNOBlock now takes n_modes as well; internal kernels still use the first entry as the mode cutoff, documented in Notes.
Main Parameters
---------------
n_modes : Tuple[int, ...] or int
With the current implementation, n_modes cannot really be a tuple. It seems the code only takes n_modes[0] and then uses it everywhere.
Either we allow the number of modes to differ along different dimensions, or we update the docstring and add an assertion at the beginning of the model definition to make it clear that n_modes must be a single integer.
Good point. The public API now requires n_modes to be one int plus explicit n_dim, with validation in init, and the docs no longer imply a tuple. Anisotropic mode counts would need deeper changes in SparseKernelFT; we can leave that for a later PR if desired.
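The validation described in this reply could look roughly like the sketch below. The helper name `validate_n_modes` is hypothetical (the real check lives in `__init__`); the rules mirrored here are the ones stated: a single positive int for n_modes, an explicit n_dim in {1, 2, 3}.

```python
def validate_n_modes(n_modes, n_dim):
    """Hypothetical sketch of the described __init__ validation."""
    if not isinstance(n_modes, int) or n_modes < 1:
        raise ValueError(
            f"n_modes must be a single positive int (got {n_modes!r}); "
            "anisotropic mode counts are not supported yet"
        )
    if n_dim not in (1, 2, 3):
        raise ValueError(f"n_dim must be 1, 2, or 3 (got {n_dim!r})")
    # Expand for internal use; every dimension shares the same mode cutoff.
    return (n_modes,) * n_dim
```

A tuple argument then fails loudly at construction time instead of silently using only its first entry.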
Input Shapes:
- 1D: (batch, n_points, in_channels) where n_points must be a power of 2
- 2D: (batch, height, width, in_channels) where height, width must be powers of 2
- 3D: (batch, height, width, time, in_channels) where height, width must be powers of 2
Definitely need to add an assert and throw an error if one of the spatial resolutions is not a power of 2
- 3D: (batch, height, width, time, in_channels)

Requirements:
- Spatial dimensions must be powers of 2
Definitely need to add an assert and throw an error if one of the spatial resolutions is not a power of 2
Added _validate_spatial_resolution at the start of forward: we require power-of-two sizes on the wavelet axes (1D: n_points; 2D: H,W; 3D: H,W only) and raise ValueError with an explicit message otherwise, plus a small unit test for the 2D case.
phi = [partial(cls.phi_, scaling_coeffs[i, :]) for i in range(k)]
psi1 = [partial(cls.phi_, np.zeros(k), lb=0, ub=0.5) for i in range(k)]
psi2 = [partial(cls.phi_, np.zeros(k), lb=0.5, ub=1) for i in range(k)]
phi does not take these lb and ub arguments anymore. Please update the names and make sure everything runs smoothly
phi_ uses lower_bound / upper_bound, not lb / ub. Updated the Chebyshev partial(...) calls accordingly so the wavelet callables match the current signature.
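The failure mode here is worth noting: `functools.partial` stores whatever keywords it is given without checking them, so stale `lb=`/`ub=` partials only blow up when the callable is finally invoked. A toy illustration (this `phi_` is a hypothetical stand-in for the library's evaluator, not its real implementation):

```python
import numpy as np
from functools import partial

def phi_(coeffs, x, lower_bound=0.0, upper_bound=1.0):
    """Toy polynomial evaluator with compact support on [lower_bound, upper_bound]."""
    x = np.asarray(x, dtype=float)
    support = (x >= lower_bound) & (x <= upper_bound)
    # np.polyval expects highest-degree-first coefficients, hence the reversal.
    return np.polyval(coeffs[::-1], x) * support

k = 3
const = np.array([1.0, 0.0, 0.0])  # the constant polynomial 1

# Stale keywords: partial() accepts this silently, TypeError only at call time.
broken = partial(phi_, const, lb=0.0, ub=0.5)
# Updated keywords matching the current signature:
psi1 = [partial(phi_, const, lower_bound=0.0, upper_bound=0.5) for _ in range(k)]
```

This is exactly why a tiny unit test that actually calls the wavelet callables would have caught the rename.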
coeffs_2x = poly_stretched.coef * np.sqrt(2) * 2 / np.sqrt(np.pi)
scaling_2x_coeffs[basis_idx, :len(coeffs_2x)] = coeffs_2x

# For Chebyshev, wavelets are handled differently (compact support)
It does not seem like the Chebyshev version runs correctly. With the below code, it seems that psi1 and psi2 are just zeros. Make sure to run the code with the Chebyshev version and make sure it works properly
You were right: Chebyshev psi1/psi2 were zero polynomials. We now convert the Chebyshev scaling rows to power basis (so the same Gram–Schmidt + monomial integrals as Legendre are valid), reuse that Gram–Schmidt for Chebyshev wavelet coefficients, and build psi1/psi2 with phi_ and lower_bound/upper_bound on [0,0.5] / [0.5,1]. Added tests to assert non-zero wavelets and a finite Chebyshev filter bank.
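The basis-conversion step mentioned in this reply can be done with NumPy directly. A small sketch of `cheb2poly` (the row values are made up for illustration):

```python
import numpy as np
from numpy.polynomial.chebyshev import cheb2poly

# A hypothetical "scaling row" holding Chebyshev-basis coefficients: 1*T_0 + 2*T_2.
cheb_row = np.array([1.0, 0.0, 2.0])

# Convert to the power (monomial) basis so the same Gram-Schmidt and
# monomial-integral machinery used for Legendre applies unchanged.
power_row = cheb2poly(cheb_row)

# Since T_2(x) = 2x^2 - 1, the polynomial is 1 + 2*(2x^2 - 1) = 4x^2 - 1,
# i.e. power-basis coefficients [-1, 0, 4] (lowest degree first).
```

Working in the power basis throughout is what makes the Legendre-style integrals valid for the Chebyshev filter bank as well.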
# CORRECT frequency indexing:
# Positive frequencies in x: [0, 1, ..., num_modes_x-1]
output_fft[:, :, :num_modes_x, :num_modes_y] = torch.einsum(
    "bixy,ioxy->boxy",
    x_fft[:, :, :num_modes_x, :num_modes_y],
    self.weights1[:, :, :num_modes_x, :num_modes_y]
)

# Negative frequencies in x: [-num_modes_x, ..., -1]
# In FFT layout, these are at indices [nx-num_modes_x:nx]
output_fft[:, :, -num_modes_x:, :num_modes_y] = torch.einsum(
    "bixy,ioxy->boxy",
    x_fft[:, :, -num_modes_x:, :num_modes_y],
    self.weights2[:, :, :num_modes_x, :num_modes_y]
)
These could be overlapping and problematic if alpha is too large. This is why I would recommend using the SpectralConv class directly, or starting from there
Thanks — we were indeed using m = min(α, nx//2+1) on the full FFT axis, so [:m] and [-m:] could overlap when 2m > nx. We now cap symmetric blocks at nx//2 per side (and treat nx≤1 as DC-only with no negative slice, since [-0:] aliases the whole axis). The last rfft dimension still uses n//2+1. We added a short note pointing to SpectralConv’s handling of FFT redundancy; wiring SparseKernelFT to SpectralConv directly would require a complete structural change to the API and weight shapes, so we suggest leaving that refactor for a later PR if desired.
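The overlap flagged by the reviewer is easy to reproduce with plain index arithmetic; a sketch under the stated assumptions (old cap `min(alpha, nx//2 + 1)`, new cap `nx//2` per side):

```python
nx, alpha = 8, 6

# Old cap: with m = 5 on a full FFT axis of length 8, the positive-frequency
# slice [0:m] and the negative-frequency slice [nx-m:nx] hit the same bins.
m = min(alpha, nx // 2 + 1)
pos = set(range(m))            # indices {0, 1, 2, 3, 4}
neg = set(range(nx - m, nx))   # indices {3, 4, 5, 6, 7}
overlap = pos & neg            # bins 3 and 4 are written twice

# New cap: at most nx//2 symmetric modes per side keeps the slices disjoint.
m_safe = min(alpha, nx // 2)
pos_safe = set(range(m_safe))
neg_safe = set(range(nx - m_safe, nx))
```

With the old cap, whichever einsum runs second silently overwrites the first on the overlapping bins, which is exactly why the reviewer points at SpectralConv's existing handling of FFT redundancy.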
L : int, optional
    Number of coarsest decomposition levels to skip. Default: 0
    Reduces computation by stopping wavelet decomposition early.
    L=0: Full decomposition to coarsest scale. L=1: Stop 1 level before coarsest.
Would be good to add an assert that L needs to be smaller than num_scales
Added a ValueError in MWNO.forward when L >= num_scales (with num_scales = floor(log2(spatial_size)) on the wavelet axis), documented the constraint on L for both MWNO and MWNOBlock, and a small unit test.
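The constraint on L described in this reply reduces to one comparison; a minimal sketch (hypothetical helper name, with `num_scales` defined as in the reply):

```python
import math

def check_decomposition_levels(L, spatial_size):
    """Raise if L (coarse levels to skip) reaches the number of available scales.

    num_scales = floor(log2(spatial_size)) on the wavelet axis.
    """
    num_scales = int(math.floor(math.log2(spatial_size)))
    if L >= num_scales:
        raise ValueError(
            f"L={L} must be smaller than num_scales={num_scales} "
            f"(floor(log2({spatial_size})))"
        )
    return num_scales
```

For a 64-point axis this gives 6 scales, so any L in {0, ..., 5} is accepted and L=6 raises.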
@pytest.mark.parametrize("n_modes", [[16], (12, 12), (8, 8, 8)])
@pytest.mark.parametrize("k", [4])
@pytest.mark.parametrize("c", [4, 16])
@pytest.mark.parametrize("n_layers", [3])
@pytest.mark.parametrize("L", [1])
@pytest.mark.parametrize("base", ["legendre"])
Need to test more cases, and definitely need to test chebyshev. If it becomes a number of configs that is too large, we can split it in 2-3 test functions
Thanks — I split the parametrized smoke test into test_mwno_legendre and test_mwno_chebyshev with the same grid (n_dim/n_modes, c, n_layers ∈ {1,3}, L ∈ {0,1,2}, k=4) and factored the forward/backward + grad checks into a small helper so we don’t duplicate code. If 72 parametrized runs is too heavy for CI we can trim (e.g. drop L=2 or one of n_layers).
Thank you for the updates @Shixuan01! I have added a few more comments. There seem to be a few remaining issues in the code, especially for the Chebyshev option, which currently does not work. It would also be nice to have an example script for the multiwavelet NO. That could also help make sure the code is working correctly, both for the Chebyshev and Legendre cases. You could add examples of (tuned) MWNO hyperparameters (maybe one for Chebyshev and one for Legendre) in the config files (here) and see if it works when loaded by the training script for the NS equations:
Thank you very much for your previous replies and suggestions, including refactoring the layers to be dimension-agnostic and to follow the modular design of the FNO currently in the library. In this version, we have integrated the previous mwt_utils.py into mwt.py (if that is not suitable, we can also move it into the layers submodule), and integrated the multidimensional MWT. All updates are now concentrated in /neuraloperator/neuralop/models/mwt.py, which hopefully makes the structure more concise! We will also update test_mwt.py and the related explanatory files later, and hope they will be helpful!