|
1 | | -from functools import wraps |
2 | | -import numpy as np |
3 | | -from typing import Union |
4 | | - |
5 | | -from autoarray.structures.grids.uniform_1d import Grid1D |
6 | | -from autoarray.structures.grids.irregular_2d import Grid2DIrregular |
7 | | -from autoarray.structures.grids.uniform_2d import Grid2D |
8 | | - |
9 | | - |
10 | | -def transform(func): |
11 | | - """ |
12 | | - Checks whether the input Grid2D of (y,x) coordinates have previously been transformed. If they have not |
13 | | - been transformed then they are transformed. |
14 | | -
|
15 | | - Parameters |
16 | | - ---------- |
17 | | - func |
18 | | - A function where the input grid is the grid whose coordinates are transformed. |
19 | | -
|
20 | | - Returns |
21 | | - ------- |
22 | | - A function that can accept cartesian or transformed coordinates |
23 | | - """ |
24 | | - |
25 | | - @wraps(func) |
26 | | - def wrapper( |
27 | | - obj: object, |
28 | | - grid: Union[np.ndarray, Grid2D, Grid2DIrregular, Grid1D], |
29 | | - xp=np, |
30 | | - *args, |
31 | | - **kwargs, |
32 | | - ) -> Union[np.ndarray, Grid2D, Grid2DIrregular]: |
33 | | - """ |
34 | | - This decorator checks whether the input grid has been transformed to the reference frame of the class |
35 | | - that owns the function. If it has not been transformed, it is transformed. |
36 | | -
|
37 | | - A function call which uses this decorator often has many subsequent function calls which also use the |
38 | | - decorator. To ensure the grid is only transformed once, the `is_transformed` keyword is used to track |
39 | | - whether the grid has been transformed. |
40 | | -
|
41 | | - Parameters |
42 | | - ---------- |
43 | | - obj |
44 | | - An object whose function uses grid_like inputs to compute quantities at every coordinate on the grid. |
45 | | - grid |
46 | | - The (y, x) coordinates in the original reference frame of the grid. |
47 | | -
|
48 | | - Returns |
49 | | - ------- |
50 | | - A grid_like object whose coordinates may be transformed. |
51 | | - """ |
52 | | - |
53 | | - if not kwargs.get("is_transformed"): |
54 | | - kwargs["is_transformed"] = True |
55 | | - |
56 | | - transformed_grid = obj.transformed_to_reference_frame_grid_from( |
57 | | - grid, xp, **kwargs |
58 | | - ) |
59 | | - |
60 | | - result = func(obj, transformed_grid, xp, *args, **kwargs) |
61 | | - |
62 | | - else: |
63 | | - result = func(obj, grid, xp, *args, **kwargs) |
64 | | - |
65 | | - return result |
66 | | - |
67 | | - return wrapper |
| 1 | +from functools import wraps |
| 2 | +import numpy as np |
| 3 | +from typing import Union |
| 4 | + |
| 5 | +from autoarray.structures.grids.uniform_1d import Grid1D |
| 6 | +from autoarray.structures.grids.irregular_2d import Grid2DIrregular |
| 7 | +from autoarray.structures.grids.uniform_2d import Grid2D |
| 8 | + |
| 9 | + |
| 10 | +def transform(func=None, *, rotate_back=False): |
| 11 | + """ |
| 12 | + Checks whether the input Grid2D of (y,x) coordinates have previously been transformed. If they have not |
| 13 | + been transformed then they are transformed. |
| 14 | +
|
| 15 | + Can be used with or without arguments:: |
| 16 | +
|
| 17 | + @transform |
| 18 | + def convergence_2d_from(self, grid, xp=np, **kwargs): ... |
| 19 | +
|
| 20 | + @transform(rotate_back=True) |
| 21 | + def deflections_yx_2d_from(self, grid, xp=np, **kwargs): ... |
| 22 | +
|
| 23 | + When ``rotate_back=True``, after the decorated function returns its result the decorator automatically |
| 24 | + rotates the output vector back from the profile's reference frame to the original observer frame. |
| 25 | + This eliminates the need for deflection methods to manually call |
| 26 | + ``self.rotated_grid_from_reference_frame_from``. |
| 27 | +
|
| 28 | + Parameters |
| 29 | + ---------- |
| 30 | + func |
| 31 | + A function where the input grid is the grid whose coordinates are transformed. |
| 32 | + rotate_back |
| 33 | + If ``True``, the result is rotated back from the profile's reference frame after evaluation. |
| 34 | + Use this for functions that return vector quantities (e.g. deflection angles) computed in the |
| 35 | + profile's rotated frame. |
| 36 | +
|
| 37 | + Returns |
| 38 | + ------- |
| 39 | + A function that can accept cartesian or transformed coordinates |
| 40 | + """ |
| 41 | + |
| 42 | + def decorator(func): |
| 43 | + @wraps(func) |
| 44 | + def wrapper( |
| 45 | + obj: object, |
| 46 | + grid: Union[np.ndarray, Grid2D, Grid2DIrregular, Grid1D], |
| 47 | + xp=np, |
| 48 | + *args, |
| 49 | + **kwargs, |
| 50 | + ) -> Union[np.ndarray, Grid2D, Grid2DIrregular]: |
| 51 | + """ |
| 52 | + This decorator checks whether the input grid has been transformed to the reference frame of the class |
| 53 | + that owns the function. If it has not been transformed, it is transformed. |
| 54 | +
|
| 55 | + The transform state is tracked via the ``is_transformed`` property on the grid object itself. |
| 56 | + When a decorated function calls another decorated function with the same (already-transformed) |
| 57 | + grid, the flag prevents the grid from being transformed a second time. |
| 58 | +
|
| 59 | + Parameters |
| 60 | + ---------- |
| 61 | + obj |
| 62 | + An object whose function uses grid_like inputs to compute quantities at every coordinate on the grid. |
| 63 | + grid |
| 64 | + The (y, x) coordinates in the original reference frame of the grid. |
| 65 | +
|
| 66 | + Returns |
| 67 | + ------- |
| 68 | + A grid_like object whose coordinates may be transformed. |
| 69 | + """ |
| 70 | + |
| 71 | + if not getattr(grid, "is_transformed", False): |
| 72 | + transformed_grid = obj.transformed_to_reference_frame_grid_from( |
| 73 | + grid, xp, **kwargs |
| 74 | + ) |
| 75 | + transformed_grid.is_transformed = True |
| 76 | + |
| 77 | + result = func(obj, transformed_grid, xp, *args, **kwargs) |
| 78 | + |
| 79 | + else: |
| 80 | + result = func(obj, grid, xp, *args, **kwargs) |
| 81 | + |
| 82 | + if rotate_back: |
| 83 | + result = obj.rotated_grid_from_reference_frame_from( |
| 84 | + grid=result, xp=xp |
| 85 | + ) |
| 86 | + |
| 87 | + return result |
| 88 | + |
| 89 | + return wrapper |
| 90 | + |
| 91 | + if func is not None: |
| 92 | + # Called without arguments: @transform |
| 93 | + return decorator(func) |
| 94 | + |
| 95 | + # Called with arguments: @transform(rotate_back=True) |
| 96 | + return decorator |
0 commit comments