Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
9fc5b52
Add support for deprecated argument (without alternative). Remove sel…
romainvo Dec 5, 2025
c8c3112
compressed sensing, phase retrieval
romainvo Dec 8, 2025
9187ae6
define device as a property
romainvo Dec 12, 2025
4ff12fe
Enforce keep buffer.device intact when updating value. Remove self.de…
romainvo Dec 16, 2025
8bdf3ff
Fix mri changes. Propagate y.device to every call to to_static
romainvo Dec 16, 2025
d48ba58
radio.py
romainvo Dec 16, 2025
0f32d46
single_pixel.py
romainvo Dec 16, 2025
27a9182
tomography
romainvo Dec 16, 2025
8f957c3
blur.py
romainvo Dec 16, 2025
a287a5f
Set self.device as property in PhysicsGenerator
romainvo Dec 16, 2025
36f802b
Add warning on device property
romainvo Dec 16, 2025
425af7d
Fix tests in test_physics.py
romainvo Dec 16, 2025
f1fa33e
black .
romainvo Dec 16, 2025
1469e22
Merge remote-tracking branch 'upstream/main' into remove_device_attri…
romainvo Dec 16, 2025
6813d91
changelog
romainvo Dec 16, 2025
af8f42f
propagate change to examples
romainvo Dec 17, 2025
ae6c59f
update super init: super().__init__(device=device, **kwargs)
romainvo Dec 17, 2025
15d51ab
black .
romainvo Dec 17, 2025
b29d864
fix to_static call in models/dynamic.py
romainvo Dec 18, 2025
5052554
fix TimeAveragingNet
romainvo Dec 18, 2025
6c01721
black .
romainvo Dec 18, 2025
abeeab4
add device setter and complete warning.
romainvo Dec 19, 2025
19cd689
black .
romainvo Dec 19, 2025
28c366e
Remove update_parameters method from BlurFFT which is redundant with …
romainvo Dec 19, 2025
ace81bf
authorize none buffers
romainvo Dec 28, 2025
dfb7e19
remove custom update_parameters from Downsampling physics
romainvo Dec 31, 2025
bdce975
add non tensor parameters update in update_parameters
romainvo Jan 9, 2026
61f6e26
minor fixes
romainvo Jan 9, 2026
3720387
Merge remote-tracking branch 'upstream/main' into remove_device_attri…
romainvo Jan 9, 2026
a2eeb9e
black .
romainvo Jan 9, 2026
139f82c
Revert changes: remove non tensor attribute update in update_parameters
romainvo Jan 19, 2026
f5abd9c
Revert changes: add back custom update_parameters in BlurFFT and Down…
romainvo Jan 19, 2026
5c374fa
Remove last self.device from Downsampling. Fix angle update in BlurFFT
romainvo Jan 19, 2026
5d4c5de
Fix mask check in update_parameters
romainvo Jan 19, 2026
5b46dd0
Merge remote-tracking branch 'upstream/main' into remove_device_attri…
romainvo Jan 20, 2026
34825b0
Add proper warnings to explain device precedence. Add device argument…
romainvo Jan 22, 2026
1600198
Fix PhysicsMultiScaler to ensure proper argument propagation during M…
romainvo Jan 22, 2026
a971c6b
Fix tests
romainvo Jan 22, 2026
17f6147
black .
romainvo Jan 22, 2026
80bff96
Remove comments
romainvo Jan 22, 2026
19cf35a
add from __future__ import annotations in wrappers.py
romainvo Jan 22, 2026
cb0308f
Fix warning in Physics.update_parameters. Add device argument to Down…
romainvo Jan 22, 2026
31dd972
Merge remote-tracking branch 'upstream/main' into remove_device_attri…
romainvo Jan 24, 2026
a155194
Merge branch 'main' into remove_device_attribute
romainvo Jan 30, 2026
a37cc6c
Detail contrib guidelines for physics
romainvo Feb 17, 2026
cebc9f7
Merge branch 'contrib_guidelines' into remove_device_attribute
romainvo Feb 17, 2026
25be4cd
black .
romainvo Feb 17, 2026
33264fd
Add from future
romainvo Feb 17, 2026
55c63a7
Merge branch 'main' into remove_device_attribute
romainvo Feb 24, 2026
02ed33a
trigger ci
jscanvic Feb 27, 2026
d157299
Include jscanvic comments
romainvo Feb 27, 2026
9a012a9
Fix warnings in doc
romainvo Feb 27, 2026
8f5aa00
Merge remote-tracking branch 'upstream/main' into remove_device_attri…
romainvo Feb 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion deepinv/models/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,15 @@ def forward(self, y, physics: TimeMixin, **kwargs):
:param y: measurements
:parameter physics: forward operator acting on dynamic inputs
"""

static_physics = (
physics.to_static(device=y.device)
if hasattr(physics, "to_static")
else physics
)

return self.backbone_net(
self.average(y, getattr(physics, "mask", None)),
getattr(physics, "to_static", lambda: physics)(),
static_physics,
**kwargs,
)
6 changes: 4 additions & 2 deletions deepinv/models/ram.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,9 @@ def forward(
y = y / rescale_val.view([y.shape[0]] + [1] * (y.ndim - 1))

if physics is None:
physics = dinv.physics.Denoising(noise_model=dinv.physics.ZeroNoise())
physics = dinv.physics.Denoising(
noise_model=dinv.physics.ZeroNoise(), device=y.device
)

if img_size is None:
if hasattr(physics, "img_shape") and physics.img_shape is not None:
Expand All @@ -354,7 +356,7 @@ def forward(

use_pad = False
if pad[0] != 0 or pad[1] != 0:
physics = PhysicsCropper(physics, pad)
physics = PhysicsCropper(physics, pad, device=y.device)
use_pad = True

x_in = physics.A_adjoint(y)
Expand Down
410 changes: 291 additions & 119 deletions deepinv/physics/blur.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion deepinv/physics/cassi.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
rng: torch.Generator = None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(device=device, **kwargs)

if len(img_size) != 3:
raise ValueError("img_size must be (C, H, W)")
Expand Down
5 changes: 2 additions & 3 deletions deepinv/physics/compressed_sensing.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,20 @@ def __init__(
rng: torch.Generator = None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(device=device, **kwargs)
self.name = f"CS_m{m}"
self.img_size = img_size
self.fast = fast
self.channelwise = channelwise
self.dtype = dtype
self.device = device

if rng is None:
self.rng = torch.Generator(device=device)
else:
# Make sure that the random generator is on the same device as the physic generator
assert rng.device == torch.device(
device
), f"The random generator is not on the same device as the Physics Generator. Got random generator on {rng.device} and the Physics Generator on {self.device}."
), f"The random generator is not on the same device as the Physics Generator. Got random generator on {rng.device} and the Physics Generator on {device}."
self.rng = rng
self.register_buffer("initial_random_state", self.rng.get_state())

Expand Down
74 changes: 69 additions & 5 deletions deepinv/physics/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,20 @@ def update_parameters(self, **kwargs):
and hasattr(self, key)
and isinstance(value, torch.Tensor)
):
self.register_buffer(key, value)

if isinstance(getattr(self, key), torch.Tensor):
if value.device.type != getattr(self, key).device.type:
warnings.warn(
f"The provided tensor for parameter '{key}' is on a different device ({value.device}) than the current parameter device ({getattr(self, key).device}). The current device will be used.",
stacklevel=2,
)

# Move `value` to the buffer's device before updating
# regardless of where the `value` tensor is located.
# Also performs type casting if necessary.
# If getattr(self, key) is None, torch.Tensor.to will
# ignore the call and just return the original tensor.
setattr(self, key, value.to(getattr(self, key)))

# NOTE: Physics instances can hold instances of torch.Generator as
# (possibly nested) attributes and they cannot be copied using deepcopy
Expand Down Expand Up @@ -373,6 +386,7 @@ class LinearPhysics(Physics):
is used for computing it, and this parameter fixes the relative tolerance of the least squares algorithm.
:param str solver: least squares solver to use. Choose between `'CG'`, `'lsqr'`, `'BiCGStab'` and `'minres'`. See :func:`deepinv.optim.linear.least_squares` for more details.
:param bool implicit_backward_solver: If `True`, uses implicit differentiation for computing gradients through the :meth:`deepinv.physics.LinearPhysics.A_dagger` and :meth:`deepinv.physics.LinearPhysics.prox_l2`, using :func:`deepinv.optim.linear.least_squares_implicit_backward` instead of :func:`deepinv.optim.linear.least_squares`. This can significantly reduce memory consumption, especially when using many iterations. If `False`, uses the standard autograd mechanism, which can be memory-intensive. Default is `True`.
:param torch.device, str device: cpu or cuda, every registered buffer and module parameters are recursively pushed onto the device during initialization.

|sep|

Expand Down Expand Up @@ -436,6 +450,7 @@ def __init__(
tol=1e-4,
solver="lsqr",
implicit_backward_solver: bool = True,
device: torch.device | str = "cpu",
**kwargs,
):
super().__init__(
Expand All @@ -457,6 +472,37 @@ def __init__(
"Using implicit_backward_solver with a low number of iterations may produce inaccurate gradients during the backward pass. If you are not doing backpropagation through `A_dagger` or `prox_l2`, ignore this message. If you are training unfolded models, consider increasing max_iter."
)

device_holder = torch.tensor(0.0, device=device)
self.register_buffer("_device_holder", device_holder, persistent=False)
# pushes all parameters/buffers to the specified device, including `noise_model`
self.to(device)

@property
def device(self) -> torch.device | str:
    r"""
    Returns the device where the physics parameters/buffers are stored.

    :return: device of the physics parameters.
    """
    # Deprecated accessor kept for backward compatibility with code that
    # read `physics.device` as a plain attribute.
    warnings.warn(
        "Following torch.nn.Module's design, the 'device' attribute is deprecated and will be removed in a future version. To move the module's buffers/parameters to a different device, use the `to()` method."
    )

    # `_device_holder` is a dummy, non-persistent scalar buffer registered
    # in __init__ so the module's current device can be queried after any
    # `.to()` call moves the buffers/parameters.
    return self._device_holder.device

@device.setter
def device(self, value: torch.device | str):
    r"""
    Sets the device where the physics parameters/buffers are stored.

    Deprecated setter kept for backward compatibility; assigning to
    ``physics.device`` simply forwards to ``to()``.

    :param value: device to which the physics parameters will be moved.
    """
    warnings.warn(
        "Following torch.nn.Module's design, the 'device' attribute is deprecated and will be removed in a future version, i.e. doing `physics.device = device` will no longer work and throw an `AttributeError`. Use `physics.to(device)` instead."
    )

    # Delegates to nn.Module.to, which recursively moves every registered
    # buffer and parameter (including `_device_holder`).
    self.to(value)

def A_adjoint(self, y, **kwargs):
r"""
Computes transpose of the forward operator :math:`\tilde{x} = A^{\top}y`.
Expand Down Expand Up @@ -964,7 +1010,8 @@ class DecomposablePhysics(LinearPhysics):
from the `V_adjoint` function and the `img_size` parameter.
This automatic adjoint is computed using automatic differentiation, which is slower than a closed form adjoint, and can
have a larger memory footprint. If you want to use the automatic adjoint, you should set the `img_size` parameter.
:param torch.nn.parameter.Parameter, float params: Singular values of the transform
:param torch.nn.parameter.Parameter, float mask: Singular values of the transform
:param torch.device, str device: cpu or cuda, every registered buffer and module parameters are recursively pushed onto the device during initialization.

|sep|

Expand Down Expand Up @@ -1002,9 +1049,10 @@ def __init__(
U_adjoint=None,
V=None,
mask=1.0,
device: torch.device | str = "cpu",
**kwargs,
):
super().__init__(**kwargs)
super().__init__(device=device, **kwargs)

assert not (
U is None and not (U_adjoint is None)
Expand All @@ -1027,6 +1075,8 @@ def __init__(
self.img_size = img_size
self.register_buffer("mask", mask)

self.to(device)

def A(self, x, mask=None, **kwargs) -> Tensor:
r"""
Applies the forward operator :math:`y = A(x)`.
Expand Down Expand Up @@ -1210,6 +1260,7 @@ class Denoising(DecomposablePhysics):
The linear operator is just the identity mapping :math:`A(x)=x`

:param torch.nn.Module noise: noise distribution, e.g., :class:`deepinv.physics.GaussianNoise`, or a user-defined torch.nn.Module. By default, it is set to Gaussian noise with a standard deviation of 0.1.
:param torch.device, str device: cpu or cuda, every registered buffer and module parameters are recursively pushed onto the device during initialization.

|sep|

Expand All @@ -1229,10 +1280,23 @@ class Denoising(DecomposablePhysics):

"""

def __init__(self, noise_model: NoiseModel | None = None, **kwargs):
def __init__(
self,
noise_model: NoiseModel | None = None,
device: str | torch.device = "cpu",
**kwargs,
):
if noise_model is None:
noise_model = GaussianNoise(sigma=0.1)
super().__init__(noise_model=noise_model, **kwargs)

if noise_model.rng is not None:
if noise_model.rng.device != device:
warnings.warn(
f"argument `device`={device} is different from the random generator device of the noise model, `noise_model.rng.device`={noise_model.rng.device}. This will likely lead to errors during execution. The device argument will be ignored in favor of `noise_model.rng.device`={noise_model.rng.device}."
)
device = noise_model.rng.device

super().__init__(noise_model=noise_model, device=device, **kwargs)


def adjoint_function(A, input_size, device="cpu", dtype=torch.float):
Expand Down
46 changes: 24 additions & 22 deletions deepinv/physics/functional/radon.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,8 @@ def deg2rad(x: int | float | torch.Tensor) -> torch.Tensor:


class AbstractFilter(nn.Module):
def __init__(self, device="cpu", dtype=torch.float):
def __init__(self, dtype: torch.dtype = torch.float32):
super().__init__()
self.device = device
self.dtype = dtype

def forward(self, x: torch.Tensor, dim: int = -2) -> torch.Tensor:
Expand All @@ -98,7 +97,7 @@ def forward(self, x: torch.Tensor, dim: int = -2) -> torch.Tensor:
)
pad_width = projection_size_padded - input_size

f = self._get_fourier_filter(projection_size_padded).to(x.device)
f = self._get_fourier_filter(projection_size_padded, device=x.device)
fourier_filter = self.create_filter(f)
if dim == 2 or dim == -2:
fourier_filter = fourier_filter.unsqueeze(-1)
Expand Down Expand Up @@ -149,12 +148,12 @@ def filter(
elif dim == 3 or dim == -1:
return result[:, :, :, :input_size]

def _get_fourier_filter(self, size):
def _get_fourier_filter(self, size, device: str | torch.device = "cpu"):
n = torch.cat(
[torch.arange(1, size / 2 + 1, 2), torch.arange(size / 2 - 1, 0, -2)]
)

f = torch.zeros(size, dtype=self.dtype, device=self.device)
f = torch.zeros(size, dtype=self.dtype, device=device)
f[0] = 0.25
f[1::2] = -1 / (torch.pi * n) ** 2

Expand Down Expand Up @@ -216,7 +215,7 @@ def __init__(
):
super().__init__()
self.circle = circle
theta = theta if theta is not None else torch.arange(180).to(self.device)
theta = theta if theta is not None else torch.arange(180, device=device)
self.register_buffer("theta", theta, persistent=False)
self.dtype = dtype
self.parallel_computation = parallel_computation
Expand All @@ -239,7 +238,7 @@ def __init__(
if not "detector_spacing" in self.fan_parameters.keys():
self.fan_parameters["detector_spacing"] = 0.077

all_grids = self._create_grids(self.theta, in_size, circle).to(device)
all_grids = self._create_grids(self.theta, in_size, circle, device=device)
if self.parallel_computation:
self.register_buffer(
"all_grids",
Expand Down Expand Up @@ -308,7 +307,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return out

def _create_grids(
self, angles: torch.Tensor, grid_size: int, circle: bool, device: str = "cpu"
self,
angles: torch.Tensor,
grid_size: int,
circle: bool,
device: str | torch.device = "cpu",
) -> torch.Tensor:
if not circle:
grid_size = int((SQRT2 * grid_size).ceil())
Expand Down Expand Up @@ -365,19 +368,18 @@ def __init__(
):
super().__init__()
self.circle = circle
self.device = device
theta = theta if theta is not None else torch.arange(180).to(self.device)
theta = theta if theta is not None else torch.arange(180, device=device)
self.register_buffer("theta", theta, persistent=False)
self.out_size = out_size
self.in_size = in_size
self.parallel_computation = parallel_computation
self.dtype = dtype

ygrid, xgrid = self._create_yxgrid(in_size, circle)
ygrid, xgrid = self._create_yxgrid(in_size, circle, device=device)
self.register_buffer("xgrid", xgrid, persistent=False)
self.register_buffer("ygrid", ygrid, persistent=False)

all_grids = self._create_grids(self.theta, in_size, circle).to(self.device)
all_grids = self._create_grids(self.theta, in_size, circle, device=device)
if self.parallel_computation:
self.register_buffer(
"all_grids",
Expand All @@ -387,11 +389,7 @@ def __init__(
else:
self.register_buffer("all_grids", all_grids, persistent=False)

self.filter = (
RampFilter(dtype=self.dtype, device=self.device)
if use_filter
else lambda x: x
)
self.filter = RampFilter(dtype=self.dtype) if use_filter else lambda x: x

def forward(self, x: torch.Tensor, filtering: bool = True) -> torch.Tensor:
r"""
Expand All @@ -414,7 +412,7 @@ def forward(self, x: torch.Tensor, filtering: bool = True) -> torch.Tensor:
ch_size,
it_size,
it_size,
device=self.device,
device=x.device,
dtype=self.dtype,
)
for i_theta in range(len(self.theta)):
Expand Down Expand Up @@ -450,11 +448,11 @@ def forward(self, x: torch.Tensor, filtering: bool = True) -> torch.Tensor:
return reco

def _create_yxgrid(
self, in_size: int, circle: bool
self, in_size: int, circle: bool, device: str | torch.device = "cpu"
) -> tuple[torch.Tensor, torch.Tensor]:
if not circle:
in_size = int((SQRT2 * in_size).ceil())
unitrange = torch.linspace(-1, 1, in_size, dtype=self.dtype, device=self.device)
unitrange = torch.linspace(-1, 1, in_size, dtype=self.dtype, device=device)
ygrid, xgrid = torch.meshgrid(unitrange, unitrange, indexing="ij")
return ygrid, xgrid

Expand All @@ -463,14 +461,18 @@ def _XYtoT(self, theta: int | float | torch.Tensor) -> torch.Tensor:
return T

def _create_grids(
self, angles: torch.Tensor, grid_size: int, circle: bool
self,
angles: torch.Tensor,
grid_size: int,
circle: bool,
device: str | torch.device = "cpu",
) -> torch.Tensor:
if not circle:
grid_size = int((SQRT2 * grid_size).ceil())
all_grids = []
for i_theta in range(len(angles)):
X = (
torch.ones(grid_size, dtype=self.dtype, device=self.device)
torch.ones(grid_size, dtype=self.dtype, device=device)
.view(-1, 1)
.repeat(1, grid_size)
* i_theta
Expand Down
7 changes: 5 additions & 2 deletions deepinv/physics/generator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,13 @@ def __init__(
self.step_func = step
self.kwargs = kwargs
self.factory_kwargs = {"device": device, "dtype": dtype}
self.device = device
if rng is None:
self.rng = torch.Generator(device=device)
else:
# Make sure that the random generator is on the same device as the physics generator
assert rng.device == torch.device(
device
), f"The random generator is not on the same device as the Physics Generator. Got random generator on {rng.device} and the Physics Generator named {self.__class__.__name__} on {self.device}."
), f"The random generator is not on the same device as the Physics Generator. Got random generator on {rng.device} and the Physics Generator named {self.__class__.__name__} on {device}."
self.rng = rng

# NOTE: There is no use in moving RNG states from one device to another
Expand All @@ -88,6 +87,10 @@ def __init__(
for k, v in kwargs.items():
setattr(self, k, v)

@property
def device(self) -> torch.device:
    r"""
    Device of the generator.

    Derived from ``self.rng``, the :class:`torch.Generator` that is
    validated at construction time to be on the requested device, making
    it the single source of truth for the generator's device.

    :return: device of the internal random number generator.
    """
    return self.rng.device

def step(self, batch_size: int = 1, seed: int = None, **kwargs) -> dict:
r"""
Generates a batch of parameters for the forward operator.
Expand Down
2 changes: 1 addition & 1 deletion deepinv/physics/inpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
rng: torch.Generator = None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(device=device, **kwargs)

if isinstance(mask, torch.Tensor):
mask = mask.to(device)
Expand Down
Loading
Loading