-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathregression.py
More file actions
210 lines (152 loc) · 8.91 KB
/
regression.py
File metadata and controls
210 lines (152 loc) · 8.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import re
import sys
from typing import Optional, Union
import numpy as np
from shunting_yard import MismatchedBracketsError, shunting_yard
from tqdm import tqdm
from chplot.functions import FUNCTIONS
from chplot.plot.plot_parameters import PlotParameters
from chplot.plot.utils import _round as round
from chplot.plot.utils import Graph, GraphType
from chplot.plot.utils import LOGGER
from chplot.rpn import compute_rpn_list, get_rpn_errors
# Matches one regression parameter token in an RPN string: `_r`, then a letter or underscore, then any
# letters/digits/underscores, delimited by the start/end of the string or by spaces.
# Groups: 1 = leading delimiter ('' or ' '), 2 = parameter name, 3 = trailing delimiter (' ' or '').
# Because the trailing space is consumed, adjacent parameters ("_ra _rb") overlap: a single
# findall/sub pass only sees every other one.
REGRESSION_PARAMETERS_REGEX = r'(^| )(_r[a-zA-Z_][0-9a-zA-Z_]*)( |$)'


def _get_unique_regression_parameters(rpn: str) -> list[str]:
    """Return the regression parameter names found in `rpn`, without duplicates, in order of first appearance.

    Adjacent parameters share their separating space, so each search resumes at the end of the
    previous parameter name (i.e. on that shared space) instead of after the consumed delimiter.
    The `pos` argument of a compiled pattern is used so that `^` still only matches at the real
    start of `rpn`, never at the resume position.
    """
    pattern = re.compile(REGRESSION_PARAMETERS_REGEX)
    # A dict keeps insertion order: it acts as an ordered set of names
    names: dict[str, None] = {}
    pos = 0
    while (match := pattern.search(rpn, pos)) is not None:
        names.setdefault(match.group(2))
        # end(2) > start(2) >= pos, so the loop always progresses
        pos = match.end(2)
    return list(names)
def _check_regression_expression(parameters: PlotParameters) -> Optional[str]:
    """Check if the regression expression is valid and convert it to RPN.

    The expression (`parameters.regression_expression`) is parsed with `shunting_yard`. It must
    contain at least one regression parameter (a name starting with `_r`), and the RPN obtained by
    replacing every regression parameter with `0` must pass `get_rpn_errors`. Every failure is
    logged through LOGGER.

    Returns:
        The RPN string of the regression expression if it is valid, None otherwise.
    """
    expression = parameters.regression_expression
    try:
        rpn = shunting_yard(
            expression=expression,
            case_sensitive=True,
            variable=parameters.variable,
            convert_scientific_notation=not parameters.disable_scientific_notation,
        )
    except MismatchedBracketsError:
        LOGGER.error("mismatched brackets in the regression expression '%s'.", expression)
        return None
    except Exception as exc:  # pylint: disable=broad-except
        # Boundary for arbitrary user input: report the failure, but keep its cause in the log
        LOGGER.error("unknown error in the regression expression '%s': %s", expression, exc)
        return None

    # Check that there are regression parameters in the expression
    if not re.search(REGRESSION_PARAMETERS_REGEX, rpn):
        LOGGER.error("no regression parameters (names starting with '_r') in the regression expression '%s'.", expression)
        return None

    # Replace every regression parameter with 0 to check if the rest of the RPN is valid.
    # Adjacent parameters (e.g. "_ra _rb") share their separating space, so one re.sub pass only
    # replaces every other one: repeat until none is left.
    rpn_check = rpn
    while re.search(REGRESSION_PARAMETERS_REGEX, rpn_check):
        rpn_check = re.sub(REGRESSION_PARAMETERS_REGEX, r'\g<1>0\g<3>', rpn_check)
    if (error := get_rpn_errors(rpn_check, variable=parameters.variable)) is not None:
        LOGGER.error("error in the regression expression '%s' : %s", expression, error)
        return None
    return rpn
def _get_fit_rpn(rpn: str, parameters_names: list[str], parameters_values: list[float]) -> str:
    """Return `rpn` with the regression parameters replaced by their values, so that it can be computed normally later.

    Only whole space-delimited tokens are replaced (`_ra` does not touch `_rab` or `x_ra`), and every
    occurrence is replaced, including adjacent ones such as "_ra _ra *". A negative value is written as
    its absolute value followed by the unary subtraction operator `-u`.
    """
    for param_name, param_value in zip(parameters_names, parameters_values):
        # abs() also turns -0.0 (which passes the >= 0 test) into 0.0
        if param_value >= 0:
            value_token = f'{abs(param_value)}'
        # if the value is negative, we need to add a unary subtraction in the rpn
        else:
            value_token = f'{abs(param_value)} -u'
        # Lookarounds instead of consuming the delimiting spaces: consuming them made adjacent
        # occurrences overlap, so every second occurrence was left unreplaced.
        # A function replacement avoids any backslash/group interpretation of the value text.
        rpn = re.sub(
            rf'(?<![^ ]){re.escape(param_name)}(?![^ ])',
            lambda _match, token=value_token: token,
            rpn,
        )
    return rpn
def _get_fit_expression(expression: str, parameters_names: list[str], parameters_values: list[float], brackets: bool = True) -> str:
    """Return `expression` with the regression parameters replaced by their values.

    Parameters are matched as whole words (regex word boundaries). By default each value is wrapped
    in brackets to force correct parsing by other software (e.g. negative values). All parameters are
    substituted in a single pass, so a substituted value is never re-matched as another parameter
    name. If a name is given twice, its first value is used.
    """
    replacements: dict[str, str] = {}
    for param_name, param_value in zip(parameters_names, parameters_values):
        replacements.setdefault(param_name, f'({param_value})' if brackets else str(param_value))
    if not replacements:
        # An empty alternation would match the empty string at every word boundary
        return expression
    # Longest names first, so the alternation prefers the longest candidate at a given position
    alternatives = '|'.join(re.escape(name) for name in sorted(replacements, key=len, reverse=True))
    return re.sub(rf'\b(?:{alternatives})\b', lambda match: replacements[match.group(0)], expression)
def _remove_nan(arr1: Union[list[float], np.ndarray], arr2: Union[list[float], np.ndarray]) -> tuple[np.ndarray, np.ndarray]:
    """Filter two parallel arrays, keeping only the positions where neither of them is NaN.

    Both inputs are converted to new NumPy arrays; the returned arrays are filtered copies.
    """
    first, second = np.array(arr1), np.array(arr2)
    # A position is kept only if it is a number in both arrays
    keep = np.logical_and(~np.isnan(first), ~np.isnan(second))
    return first[keep], second[keep]
# Ref : https://stackoverflow.com/questions/19189362/getting-the-r-squared-value-using-curve-fit
def _compute_r_squared_and_errors(ydata: np.ndarray, yfit: np.ndarray) -> tuple[float, float, float]:
    """Return the accuracy of a fit: R², maximum absolute error and maximum relative error.

    Args:
        ydata: observed values.
        yfit: fitted values at the same points as `ydata`.

    Returns:
        `(r2, max_error, max_rel_error)`.
        - `r2` is 1.0 when `ydata` is constant (total sum of squares is 0).
        - The relative error ignores points where `ydata` is close to 0, and is NaN when no other
          point remains (e.g. fitting the zero function).
        - All three are NaN for empty input.
    """
    nan = float('nan')
    if ydata.size == 0:
        return (nan, nan, nan)
    residuals = ydata - yfit
    sum_sq_res = np.sum(residuals ** 2)
    sum_sq_tot = np.sum((ydata - np.mean(ydata)) ** 2)
    max_error = np.max(np.abs(residuals))
    # The relative error is undefined where the data is (close to) 0; np.max would raise on an empty selection
    non_zero_indices = ~np.isclose(ydata, 0)
    if non_zero_indices.any():
        max_rel_error = np.max(np.abs(residuals[non_zero_indices] / ydata[non_zero_indices]))
    else:
        max_rel_error = nan
    if sum_sq_tot == 0:
        # Constant data: R² is undefined, a perfect score is reported by convention
        return (1.0, max_error, max_rel_error)
    return (1 - sum_sq_res / sum_sq_tot, max_error, max_rel_error)
def compute_regressions(parameters: PlotParameters, graphs: list[Graph]) -> list[Graph]:
    """Fit the regression expression to every graph and return the fitted curves.

    For each graph, the regression parameters (names starting with `_r`) of
    `parameters.regression_expression` are fitted with `scipy.optimize.curve_fit` on the graph's
    non-NaN points. Coefficients, R², maximum absolute/relative errors and a copyable expression
    are written to stdout. A graph whose fit fails is skipped with a logged error.

    Side effect: the regression parameters are stored in the global FUNCTIONS table and are left
    there with the values of the last evaluation.

    Args:
        parameters: plot parameters; `regression_expression`, `variable` and `n_points` are read here
            (plus whatever `_check_regression_expression` reads).
        graphs: graphs to fit.

    Returns:
        One `GraphType.REGRESSION` graph per successfully fitted input graph, evaluated on
        `parameters.n_points` evenly spaced points over the range of the input graph.
    """
    # Import in this function, so it is not imported if no regression is computed
    from scipy.optimize import curve_fit, OptimizeWarning

    if len(graphs) == 0:
        return []
    if (rpn := _check_regression_expression(parameters)) is None:
        return []

    file = sys.stdout
    parameters_names = _get_unique_regression_parameters(rpn)
    parameters_names_without_prefix = [param_name[2:] for param_name in parameters_names]
    n_parameters = len(parameters_names)

    def _regression_function(xdata: np.ndarray, *regression_parameters: float):
        """Model passed to curve_fit: evaluate the regression RPN on `xdata` for the given parameter values."""
        # Presumably compute_rpn_list resolves the `_r...` tokens through the global FUNCTIONS table;
        # entry format (0, value) kept as-is — confirm its meaning in chplot.functions
        for param_name, param_value in zip(parameters_names, regression_parameters):
            FUNCTIONS[param_name] = (0, param_value)
        # `pbar` is the progress bar of the graph being fitted, bound by the `with` in the loop below;
        # updating it after it is closed (final evaluations) does nothing
        pbar.update(1)
        return compute_rpn_list(rpn, xdata, parameters.variable, progress_bar=False)

    regression_graphs: list[Graph] = []
    file.write('\n===== REGRESSION COEFFICIENTS OF THE FUNCTIONS =====\n\n')
    file.write(f'Regression function: reg(x) = {_get_fit_expression(parameters.regression_expression, parameters_names, parameters_names_without_prefix, brackets=False)}\n\n')

    for graph in graphs:
        # Remove all nan values for the curve_fit computation
        inputs_without_nan, values_without_nan = _remove_nan(graph.inputs, graph.values)
        if inputs_without_nan.size < n_parameters:
            LOGGER.error(
                "not enough non-nan input points on graph '%s' to compute specified regression ('%s' found, at least '%s' needed)",
                graph.expression, inputs_without_nan.size, n_parameters
            )
            continue

        try:
            # Total = default max number of function evaluations of curve_fit (leastsq): 200 * (N + 1).
            # The `with` closes the bar before any error below is logged, whatever curve_fit raises.
            with tqdm(total=200 * (n_parameters + 1), leave=False) as pbar:
                parameters_values, _ = curve_fit(
                    f=_regression_function,
                    xdata=inputs_without_nan,
                    ydata=values_without_nan,
                    p0=[1.0] * n_parameters,
                )
        # ValueError: curve_fit rejects non-finite data (inf values are kept by _remove_nan).
        # OptimizeWarning only reaches this handler if warnings are configured as errors.
        except (OptimizeWarning, RuntimeError, ValueError):
            LOGGER.error("error while computing regression of '%s', try reducing the number of parameters or simplifying the expression", graph.expression)
            continue

        if np.isnan(parameters_values).any():
            LOGGER.error("error while computing regression of '%s', try changing the number of parameters or simplifying the expression", graph.expression)
            continue

        custom_inputs = np.linspace(graph.inputs.min(), graph.inputs.max(), parameters.n_points, endpoint=True)
        fitted_values = _regression_function(inputs_without_nan, *parameters_values)
        r2, max_error, max_rel_error = _compute_r_squared_and_errors(*_remove_nan(values_without_nan, fitted_values))
        regression_graphs.append(Graph(
            inputs=custom_inputs,
            type=GraphType.REGRESSION,
            expression=f'Regression [{graph.expression}]',
            rpn=_get_fit_rpn(rpn, parameters_names, parameters_values),
            values=_regression_function(custom_inputs, *parameters_values),
        ))

        file.write(f'- Function f(x) = {graph.expression}\n')
        file.write(' Coefficients:\n')
        for param_name, param_value in zip(parameters_names, parameters_values):
            file.write(f' {param_name[2:]} = {round(param_value, 5)} (exact {param_value})\n')
        file.write(f'\n Accuracy on [{graph.inputs.min():.3f} ; {graph.inputs.max():.3f}]:\n')
        file.write(f' R2 = {r2}\n')
        file.write(f' |err| <= {max_error}\n')
        file.write(f' |rel err| <= {max_rel_error}\n')
        file.write(f'\n Copyable expression:\n f(x) = {_get_fit_expression(parameters.regression_expression, parameters_names, parameters_values)}\n\n\n')

    return regression_graphs