Skip to content

nufft Error #2162

@nadav-nt-tao

Description

@nadav-nt-tao

When I try to compute and plot the effective ripple with `nufft_eps=0`, I still get the following error:

File ~/DESC/Nadav/eff_ripple.py:21, in plot_eff_ripple(eq)
     19 fig, ax = plt.subplots()
     20 theta = Bounce2D.compute_theta(eq, rho=np.unique(grid.nodes[: ,0]))
---> 21 res = eq.compute("effective ripple", grid=grid, theta=theta, nufft_eps=0)
     22 ax.plot(res["rho"], res["effective ripple"], '.')
     23 ax.set_yscale("log")

File ~/DESC/desc/equilibrium/equilibrium.py:1229, in Equilibrium.compute(self, names, grid, params, transforms, profiles, data, override_grid, **kwargs)
   1220     data1dz = {
   1221         key: grid.copy_data_from_other(
   1222             data1dz[key], grid1dz, surface_label="zeta"
   (...)   1225         if key in dep1dz and key not in data
   1226     }
   1227     data.update(data1dz)
-> 1229 data = compute_fun(
   1230     self,
   1231     names,
   1232     params=params,
   1233     transforms=transforms,
   1234     profiles=profiles,
   1235     data=data,
   1236     **kwargs,
   1237 )
   1238 return data

File ~/DESC/desc/compute/utils.py:137, in compute(parameterization, names, params, transforms, profiles, data, **kwargs)
    134 if data is None:
    135     data = {}
--> 137 data = _compute(
    138     p,
    139     names,
    140     params=params,
    141     transforms=transforms,
    142     profiles=profiles,
    143     data=data,
    144     **kwargs,
    145 )
    147 # convert data from default 'rpz' basis to 'xyz' basis, if requested by the user
    148 if basis == "xyz":

File ~/DESC/desc/compute/utils.py:195, in _compute(parameterization, names, params, transforms, profiles, data, **kwargs)
    190     continue
    191 if not has_data_dependencies(
    192     parameterization, name, data, transforms["grid"].axis.size
    193 ):
    194     # then compute the missing dependencies
--> 195     data = _compute(
    196         parameterization,
    197         data_index[parameterization][name]["dependencies"]["data"],
    198         params=params,
    199         transforms=transforms,
    200         profiles=profiles,
    201         data=data,
    202         **kwargs,
    203     )
    204     if transforms["grid"].axis.size:
    205         data = _compute(
    206             parameterization,
    207             data_index[parameterization][name]["dependencies"][
   (...)    214             **kwargs,
    215         )

File ~/DESC/desc/compute/utils.py:217, in _compute(parameterization, names, params, transforms, profiles, data, **kwargs)
    205             data = _compute(
    206                 parameterization,
    207                 data_index[parameterization][name]["dependencies"][
   (...)    214                 **kwargs,
    215             )
    216     # now compute the quantity
--> 217     data = data_index[parameterization][name]["fun"](
    218         params=params, transforms=transforms, profiles=profiles, data=data, **kwargs
    219     )
    220 return data

    [... skipping hidden 47 frame]

File ~/DESC/desc-env/lib/python3.14/site-packages/jax/_src/interpreters/mlir.py:2391, in lower_per_platform(ctx, description, platform_rules, default_rule, effects, *rule_args, **rule_kwargs)
   2389   rule = platform_rules.get(platforms[0], default_rule)
   2390   if rule is None:
-> 2391     raise NotImplementedError(
   2392       f"MLIR translation rule for primitive '{description}' not "
   2393       f"found for platform {platforms[0]}")
   2395 # Multi-platform lowering
   2396 kept_rules: list[LoweringRule] = []  # Only the rules for the platforms of interest

NotImplementedError: MLIR translation rule for primitive 'nufft2' not found for platform cuda

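It seems the code still dispatches to the NUFFT even when `nufft_eps=0`, which I expected to disable it. Because there is no MLIR lowering rule for the `nufft2` primitive on the CUDA platform, the computation fails on the GPU instead of falling back to a non-NUFFT path.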

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions