Skip to content

Commit a4762d7

Browse files
Jammy2211claude
authored andcommitted
refactor: move transform state to grid property and add rotate_back
The transform decorator now tracks is_transformed on the grid object itself instead of passing it through kwargs. Added rotate_back parameter for automatic back-rotation of deflection vectors. Co-Authored-By: Claude Opus 4.6 <[email protected]>
1 parent 36fc333 commit a4762d7

2 files changed

Lines changed: 104 additions & 67 deletions

File tree

autoarray/abstract_ndarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,14 @@ def __init__(self, array, xp=np):
7676

7777
self.use_jax = xp is not np
7878

79+
@property
80+
def is_transformed(self) -> bool:
81+
return self._is_transformed
82+
83+
@is_transformed.setter
84+
def is_transformed(self, value: bool):
85+
self._is_transformed = value
86+
7987
@property
8088
def _xp(self):
8189
if self.use_jax:
Lines changed: 96 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,96 @@
1-
from functools import wraps
2-
import numpy as np
3-
from typing import Union
4-
5-
from autoarray.structures.grids.uniform_1d import Grid1D
6-
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
7-
from autoarray.structures.grids.uniform_2d import Grid2D
8-
9-
10-
def transform(func):
11-
"""
12-
Checks whether the input Grid2D of (y,x) coordinates have previously been transformed. If they have not
13-
been transformed then they are transformed.
14-
15-
Parameters
16-
----------
17-
func
18-
A function where the input grid is the grid whose coordinates are transformed.
19-
20-
Returns
21-
-------
22-
A function that can accept cartesian or transformed coordinates
23-
"""
24-
25-
@wraps(func)
26-
def wrapper(
27-
obj: object,
28-
grid: Union[np.ndarray, Grid2D, Grid2DIrregular, Grid1D],
29-
xp=np,
30-
*args,
31-
**kwargs,
32-
) -> Union[np.ndarray, Grid2D, Grid2DIrregular]:
33-
"""
34-
This decorator checks whether the input grid has been transformed to the reference frame of the class
35-
that owns the function. If it has not been transformed, it is transformed.
36-
37-
A function call which uses this decorator often has many subsequent function calls which also use the
38-
decorator. To ensure the grid is only transformed once, the `is_transformed` keyword is used to track
39-
whether the grid has been transformed.
40-
41-
Parameters
42-
----------
43-
obj
44-
An object whose function uses grid_like inputs to compute quantities at every coordinate on the grid.
45-
grid
46-
The (y, x) coordinates in the original reference frame of the grid.
47-
48-
Returns
49-
-------
50-
A grid_like object whose coordinates may be transformed.
51-
"""
52-
53-
if not kwargs.get("is_transformed"):
54-
kwargs["is_transformed"] = True
55-
56-
transformed_grid = obj.transformed_to_reference_frame_grid_from(
57-
grid, xp, **kwargs
58-
)
59-
60-
result = func(obj, transformed_grid, xp, *args, **kwargs)
61-
62-
else:
63-
result = func(obj, grid, xp, *args, **kwargs)
64-
65-
return result
66-
67-
return wrapper
1+
from functools import wraps
2+
import numpy as np
3+
from typing import Union
4+
5+
from autoarray.structures.grids.uniform_1d import Grid1D
6+
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
7+
from autoarray.structures.grids.uniform_2d import Grid2D
8+
9+
10+
def transform(func=None, *, rotate_back=False):
11+
"""
12+
Checks whether the input Grid2D of (y,x) coordinates have previously been transformed. If they have not
13+
been transformed then they are transformed.
14+
15+
Can be used with or without arguments::
16+
17+
@transform
18+
def convergence_2d_from(self, grid, xp=np, **kwargs): ...
19+
20+
@transform(rotate_back=True)
21+
def deflections_yx_2d_from(self, grid, xp=np, **kwargs): ...
22+
23+
When ``rotate_back=True``, after the decorated function returns its result the decorator automatically
24+
rotates the output vector back from the profile's reference frame to the original observer frame.
25+
This eliminates the need for deflection methods to manually call
26+
``self.rotated_grid_from_reference_frame_from``.
27+
28+
Parameters
29+
----------
30+
func
31+
A function where the input grid is the grid whose coordinates are transformed.
32+
rotate_back
33+
If ``True``, the result is rotated back from the profile's reference frame after evaluation.
34+
Use this for functions that return vector quantities (e.g. deflection angles) computed in the
35+
profile's rotated frame.
36+
37+
Returns
38+
-------
39+
A function that can accept cartesian or transformed coordinates
40+
"""
41+
42+
def decorator(func):
43+
@wraps(func)
44+
def wrapper(
45+
obj: object,
46+
grid: Union[np.ndarray, Grid2D, Grid2DIrregular, Grid1D],
47+
xp=np,
48+
*args,
49+
**kwargs,
50+
) -> Union[np.ndarray, Grid2D, Grid2DIrregular]:
51+
"""
52+
This decorator checks whether the input grid has been transformed to the reference frame of the class
53+
that owns the function. If it has not been transformed, it is transformed.
54+
55+
The transform state is tracked via the ``is_transformed`` property on the grid object itself.
56+
When a decorated function calls another decorated function with the same (already-transformed)
57+
grid, the flag prevents the grid from being transformed a second time.
58+
59+
Parameters
60+
----------
61+
obj
62+
An object whose function uses grid_like inputs to compute quantities at every coordinate on the grid.
63+
grid
64+
The (y, x) coordinates in the original reference frame of the grid.
65+
66+
Returns
67+
-------
68+
A grid_like object whose coordinates may be transformed.
69+
"""
70+
71+
if not getattr(grid, "is_transformed", False):
72+
transformed_grid = obj.transformed_to_reference_frame_grid_from(
73+
grid, xp, **kwargs
74+
)
75+
transformed_grid.is_transformed = True
76+
77+
result = func(obj, transformed_grid, xp, *args, **kwargs)
78+
79+
else:
80+
result = func(obj, grid, xp, *args, **kwargs)
81+
82+
if rotate_back:
83+
result = obj.rotated_grid_from_reference_frame_from(
84+
grid=result, xp=xp
85+
)
86+
87+
return result
88+
89+
return wrapper
90+
91+
if func is not None:
92+
# Called without arguments: @transform
93+
return decorator(func)
94+
95+
# Called with arguments: @transform(rotate_back=True)
96+
return decorator

0 commit comments

Comments
 (0)