-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtest_plotting.py
More file actions
497 lines (400 loc) · 19.9 KB
/
test_plotting.py
File metadata and controls
497 lines (400 loc) · 19.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
import os
import pickle
from math import ceil, sqrt
from unittest.mock import MagicMock, patch
import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib.collections import PolyCollection, QuadMesh
from matplotlib.patches import Rectangle
import ratapi.utils.plotting as RATplot
from ratapi.events import notify
from ratapi.rat_core import EventTypes, PlotEventData
TEST_DIR_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_data")
def data() -> PlotEventData:
    """Build a ``PlotEventData`` populated from the pickled plotting test data.

    The fields are read from ``plotting_data.pickle`` inside ``TEST_DIR_PATH``;
    the contrast names are set to fixed values rather than loaded from the file.
    """
    pickle_path = os.path.join(TEST_DIR_PATH, "plotting_data.pickle")
    with open(pickle_path, "rb") as pickle_file:
        stored = pickle.load(pickle_file)

    event_data = PlotEventData()
    # copy each stored field across under the same attribute name
    for field in (
        "modelType",
        "dataPresent",
        "subRoughs",
        "resample",
        "resampledLayers",
        "reflectivity",
        "shiftedData",
        "sldProfiles",
    ):
        setattr(event_data, field, stored[field])
    event_data.contrastNames = ["D2O", "SMW", "H2O"]
    return event_data
def domains_data() -> PlotEventData:
    """Build fake domains data from the standard test data.

    For every entry of ``sldProfiles`` the first profile is appended to that
    entry again, so each entry ends up with one extra profile.
    """
    plot_data = data()
    for profiles in plot_data.sldProfiles:
        first_profile = profiles[0]
        profiles.append(first_profile)
    return plot_data
@pytest.fixture(params=[False])
def fig(request) -> plt.figure:
    """Provide a freshly plotted 1x2 reflectivity/SLD figure.

    A truthy fixture parameter selects the domains data; otherwise the standard
    data is plotted. All existing figures are closed first.
    """
    plt.close("all")
    figure, _ = plt.subplots(1, 2)
    plot_data = domains_data() if request.param else data()
    RATplot.plot_ref_sld_helper(fig=figure, data=plot_data)
    return figure
@pytest.fixture
def bayes_fig(request) -> plt.figure:
    """Provide a 1x2 reflectivity/SLD figure plotted with confidence intervals.

    The intervals are synthetic: +/-50% of the second column of each
    reflectivity curve and +/-10% of the second column of each SLD profile.
    """
    plt.close("all")
    figure, _ = plt.subplots(1, 2)
    dat = data()

    def interval(curve, fraction):
        """Return the (lower, upper) band around the second column of ``curve``."""
        return (curve[:, 1] - curve[:, 1] * fraction, curve[:, 1] + curve[:, 1] * fraction)

    reflectivity_intervals = [interval(curve, 0.5) for curve in dat.reflectivity]
    sld_intervals = [[interval(curve, 0.1) for curve in sld] for sld in dat.sldProfiles]
    confidence_intervals = {"reflectivity": reflectivity_intervals, "sld": sld_intervals}

    RATplot.plot_ref_sld_helper(data=dat, fig=figure, confidence_intervals=confidence_intervals)
    return figure
@pytest.mark.parametrize("fig", [False, True], indirect=True)
def test_figure_axis_formatting(fig: plt.figure) -> None:
    """Check grid geometry, axis labels, axis scales and legend entries of the figure."""
    ref_plot = fig.axes[0]
    sld_plot = fig.axes[1]

    assert fig.axes[0].get_subplotspec().get_gridspec().get_geometry() == (1, 2)
    assert len(fig.axes) == 2

    # reflectivity panel
    assert ref_plot.get_xlabel() == "$Q_{z} (\u00c5^{-1})$"
    assert ref_plot.get_xscale() == "log"
    assert ref_plot.get_ylabel() == "Reflectivity"
    assert ref_plot.get_yscale() == "log"
    ref_legend_labels = [text.get_text() for text in ref_plot.get_legend().texts]
    assert ref_legend_labels == ["D2O", "SMW", "H2O"]

    # SLD panel
    assert sld_plot.get_xlabel() == "$Z (\u00c5)$"
    assert sld_plot.get_xscale() == "linear"
    assert sld_plot.get_ylabel() == "$SLD (\u00c5^{-2})$"
    assert sld_plot.get_yscale() == "linear"

    sld_legend_labels = [text.get_text() for text in sld_plot.get_legend().texts]
    # three labels means one profile per contrast; otherwise expect two domains per contrast
    expected_labels = (
        ["D2O", "SMW", "H2O"]
        if len(sld_legend_labels) == 3
        else [
            "D2O Domain 1",
            "D2O Domain 2",
            "SMW Domain 1",
            "SMW Domain 2",
            "H2O Domain 1",
            "H2O Domain 2",
        ]
    )
    assert sld_legend_labels == expected_labels
def test_ref_sld_color_formatting(fig: plt.figure) -> None:
    """Check that paired artists on the reflectivity and SLD plots share colours."""
    ref_plot = fig.axes[0]
    sld_plot = fig.axes[1]

    assert len(ref_plot.get_lines()) == 6
    assert len(sld_plot.get_lines()) == 6

    ref_lines = ref_plot.get_lines()
    sld_lines = sld_plot.get_lines()
    for container_index, line_index in enumerate(range(0, len(ref_lines), 2)):
        # the errorbar artist of each container must match the colour of its line on the ref_plot
        errorbar_artist = ref_plot.containers[container_index][2][0]
        assert errorbar_artist._original_edgecolor == ref_lines[line_index].get_color()

        # the sld line and the line following it (resampled sld) must match on the sld_plot
        assert sld_lines[line_index].get_color() == sld_lines[line_index + 1].get_color()
def test_ref_sld_bayes(fig, bayes_fig):
    """Test that shading is only present on the figure when confidence intervals are supplied."""

    def has_shading(axes) -> bool:
        """Return whether any child artist of ``axes`` is a PolyCollection (the shading type)."""
        return any(isinstance(child, PolyCollection) for child in axes.get_children())

    for axes in fig.axes:
        assert not has_shading(axes)

    for axes in bayes_fig.axes:
        assert has_shading(axes)
@patch("ratapi.utils.plotting.makeSLDProfile")
def test_sld_profile_function_call(mock: MagicMock) -> None:
    """Check that makeSLDProfile is called three times with the expected positional arguments."""
    RATplot.plot_ref_sld_helper(data(), plt.subplots(1, 2)[0])

    assert mock.call_count == 3

    # only the second positional argument differs between the three calls
    expected_second_args = (6.28e-06, 1.83e-06, -5.87e-07)
    for recorded_call, second_arg in zip(mock.call_args_list, expected_second_args):
        positional = recorded_call.args
        assert positional[0] == 2.07e-06
        assert positional[1] == second_arg
        assert positional[3] == 0.0
        assert positional[4] == 1
@patch("ratapi.utils.plotting.makeSLDProfile")
def test_live_plot(mock: MagicMock) -> None:
    """Check LivePlot's reaction to plot events before and after its figure is closed.

    A plot event is sent, the figure is closed, and a second event is sent;
    makeSLDProfile is expected to have been called three times in total.
    """
    plot_data = data()

    with RATplot.LivePlot() as figure:
        assert len(figure.axes) == 2
        notify(EventTypes.Plot, plot_data)
        plt.close(figure)
        notify(EventTypes.Plot, plot_data)

    assert mock.call_count == 3

    # only the second positional argument differs between the three calls
    expected_second_args = (6.28e-06, 1.83e-06, -5.87e-07)
    for recorded_call, second_arg in zip(mock.call_args_list, expected_second_args):
        positional = recorded_call.args
        assert positional[0] == 2.07e-06
        assert positional[1] == second_arg
        assert positional[3] == 0.0
        assert positional[4] == 1
@patch("ratapi.utils.plotting.plot_ref_sld_helper")
def test_plot_ref_sld(mock: MagicMock, input_project, reflectivity_calculation_results) -> None:
    """Check the data and figure that plot_ref_sld passes to the (mocked) plotting helper."""
    results = reflectivity_calculation_results
    RATplot.plot_ref_sld(input_project, results)

    mock.assert_called_once()

    # the helper receives the plot data and the figure as its first two positional arguments
    helper_args = mock.call_args[0]
    plot_data = helper_args[0]
    figure = helper_args[1]

    assert figure.axes[0].get_subplotspec().get_gridspec().get_geometry() == (1, 2)
    assert len(figure.axes) == 2

    for plotted, calculated in zip(plot_data.reflectivity, results.reflectivity, strict=False):
        assert (plotted == calculated).all()

    for plotted_profiles, calculated_profiles in zip(plot_data.sldProfiles, results.sldProfiles, strict=False):
        for plotted, calculated in zip(plotted_profiles, calculated_profiles, strict=False):
            assert (plotted == calculated).all()

    assert plot_data.modelType == input_project.model
    assert plot_data.shiftedData == results.shiftedData
    assert plot_data.resampledLayers == results.resampledLayers
    assert plot_data.dataPresent.size == 0
    assert (plot_data.subRoughs == results.contrastParams.subRoughs).all()
    assert plot_data.resample.size == 0
    assert len(plot_data.contrastNames) == 0
def test_ref_sld_subplot_correction():
    """Test that a figure with the wrong number of subplots is corrected by the figure helper."""
    three_panel_fig, _ = plt.subplots(1, 3)
    RATplot.plot_ref_sld_helper(data=data(), fig=three_panel_fig)

    grid_geometry = three_panel_fig.axes[0].get_subplotspec().get_gridspec().get_geometry()
    assert grid_geometry == (1, 2)
    assert len(three_panel_fig.axes) == 2
@patch("ratapi.utils.plotting.plot_ref_sld_helper")
def test_plot_ref_sld_bayes_validation(mock, input_project, reflectivity_calculation_results, dream_results):
    """Test that plot_ref_sld accepts valid Bayesian input and raises ValueError for bad input."""
    # these calls must not raise
    for bayes_kwargs in ({}, {"bayes": 65}, {"bayes": 95}):
        RATplot.plot_ref_sld(input_project, dream_results, **bayes_kwargs)

    # a bayes value combined with the non-DREAM results fixture
    with pytest.raises(ValueError):
        RATplot.plot_ref_sld(input_project, reflectivity_calculation_results, bayes=65)

    # a bayes value of 15 with the DREAM results fixture
    with pytest.raises(ValueError):
        RATplot.plot_ref_sld(input_project, dream_results, bayes=15)
def test_assert_bayesian(dream_results, reflectivity_calculation_results):
    """Test that the `assert_bayesian` decorator accepts DREAM results and rejects other results."""

    @RATplot.assert_bayesian("test")
    def test_plot(results):
        """Do nothing; only the decorator's validation is exercised."""

    # must not raise for the DREAM results fixture
    test_plot(dream_results)

    expected_message = r"test plots are only available for the results of Bayesian analysis \(NS or DREAM\)"
    with pytest.raises(ValueError, match=expected_message):
        test_plot(reflectivity_calculation_results)
@pytest.mark.parametrize("indices", [[0, 1, 2, 3, 4], [2, 5, 11], [8]])
def test_panel_helper(indices):
    """Test that the panel plot helper creates a panel with the expected subplots."""

    def plot_func(axes, k):
        """Draw k single-point lines on the given Axes."""
        for value in range(k):
            axes.plot([value], [value])

    nplots = len(indices)
    fig = RATplot.panel_plot_helper(plot_func, indices)

    # the grid is expected to be ceil(sqrt(n)) by round(sqrt(n))
    side_length = sqrt(nplots)
    expected_num_axes = ceil(side_length) * round(side_length)
    assert len(fig.axes) == expected_num_axes

    # every requested plot is visible and holds the requested number of lines
    for position, line_count in enumerate(indices):
        used_axes = fig.axes[position]
        assert len(used_axes.get_lines()) == line_count
        assert used_axes.get_visible()

    # any surplus axes in the grid are hidden
    for position in range(nplots, expected_num_axes):
        assert fig.axes[position].get_visible() is False

    plt.close(fig)
@pytest.mark.parametrize("param", ["CW Thickness", "D2O", 5])
@pytest.mark.parametrize("hist_settings", [{}, {"bins": 18}, {"density": False, "range": (0, 5)}])
@pytest.mark.parametrize("est_dens", [None, "normal", "lognor", "kernel"])
def test_hist(dream_results, param, hist_settings, est_dens):
    """Test the formatting of the single-parameter histogram plot.

    Checks the number of bins, whether an estimated density line is drawn, the
    left-aligned title and the x-axis bounds for each combination of parameter,
    histogram settings and estimated density.
    """
    fig: plt.Figure = RATplot.plot_one_hist(
        dream_results, param, estimated_density=est_dens, **hist_settings, return_fig=True
    )
    ax = fig.axes[0]
    components = ax.get_children()

    # assert expected number of bins including default
    # this ensures default hist_settings are overwritten correctly
    # +1 rectangle because the bounds of the plot is a rectangle
    expected_bins = hist_settings.get("bins", 25) + 1
    assert len([c for c in components if isinstance(c, Rectangle)]) == expected_bins

    # assert line is only drawn if estimated density given
    assert len(ax.get_lines()) == (0 if est_dens is None else 1)

    # assert title is as expected
    # also tests string to index conversion
    # the expected title is computed first: a bare `assert a == b if c else d` parses as
    # `assert (a == b) if c else d`, which would only assert a truthy string for named params
    expected_title = dream_results.fitNames[param] if isinstance(param, int) else param
    assert ax.get_title(loc="left") == expected_title

    # assert range is default, unless given
    # this tests non-default hist_settings propagates correctly
    try:
        expected_range = hist_settings["range"]
    except KeyError:  # if no range given, compute the automatic range
        param_index = dream_results.fitNames.index(param) if isinstance(param, str) else param
        param_chain = dream_results.chain[:, param_index]
        expected_range = (param_chain.min(), param_chain.max())

    assert ax.get_xbound() == expected_range

    plt.close(fig)
@pytest.mark.parametrize(["x_param", "y_param"], [["CW Thickness", "D2O"], ["Bilayer Heads Thickness", 5], [2, 7]])
@pytest.mark.parametrize("hist2d_settings", [{}, {"bins": 15}, {"range": [(-50.0, 50.0), (-50.0, 200.0)]}])
def test_contour(dream_results, x_param, y_param, hist2d_settings):
    """Test the formatting of the contour plot.

    Checks the mesh dimensions, the axis labels and the axis bounds for each
    combination of parameter pair and 2D histogram settings.
    """
    fig: plt.Figure = RATplot.plot_contour(dream_results, x_param, y_param, return_fig=True, **hist2d_settings)
    ax = fig.axes[0]
    components = ax.get_children()

    # assert expected number of bins including default
    # this ensures default hist2d_settings are overwritten correctly
    # +1 as we are counting edges in this case
    expected_bins = hist2d_settings.get("bins", 25) + 1
    quad_mesh = [c for c in components if isinstance(c, QuadMesh)][0]
    assert quad_mesh._coordinates.shape == (expected_bins, expected_bins, 2)

    # assert correct axis labels
    # this ensures string to index conversion works
    assert ax.get_xlabel() == (dream_results.fitNames[x_param] if isinstance(x_param, int) else x_param)
    assert ax.get_ylabel() == (dream_results.fitNames[y_param] if isinstance(y_param, int) else y_param)

    # assert range is default, unless given
    # this tests non-default hist2d_settings propagates correctly
    try:
        x_expected_range, y_expected_range = hist2d_settings["range"]
    except KeyError:  # if no range given, compute the automatic range
        x_param_index = dream_results.fitNames.index(x_param) if isinstance(x_param, str) else x_param
        y_param_index = dream_results.fitNames.index(y_param) if isinstance(y_param, str) else y_param
        x_param_chain = dream_results.chain[:, x_param_index]
        y_param_chain = dream_results.chain[:, y_param_index]
        x_expected_range = (x_param_chain.min(), x_param_chain.max())
        y_expected_range = (y_param_chain.min(), y_param_chain.max())

    assert ax.get_xbound() == x_expected_range
    assert ax.get_ybound() == y_expected_range

    # close the figure so parametrised runs do not accumulate open figures
    plt.close(fig)
@pytest.mark.parametrize(
    "params",
    [
        None,
        ["Bilayer Heads Thickness", "Bilayer Heads Hydration", "D2O"],
        ["Bilayer Heads Thickness", 2, 3, "D2O", 5],
        [1, 2, 3, 4, 5],
    ],
)
def test_corner(dream_results, params):
    """Test that corner plots are formatted correctly.

    The upper triangle must be hidden, the lower-triangle plots must share
    bounds along each row and column, and each diagonal plot must be titled
    with its parameter name.
    """
    # no use testing hist_settings and hist2d_settings here as they're tested above
    fig: plt.Figure = RATplot.plot_corner(dream_results, params, return_fig=True)
    axes = fig.axes

    if params is None:
        params = range(0, len(dream_results.fitNames))

    assert len(axes) == len(params) ** 2

    # annoyingly, fig.axes doesn't preserve the grid shape from plt.subplots... reconstruct grid
    axes = np.array([axes[i : i + len(params)] for i in range(0, len(axes), len(params))])

    for i in range(0, len(params)):
        for j in range(0, len(params)):
            current_axes = axes[i][j]
            # ensure upper triangle is invisible
            if i < j:
                assert current_axes.get_visible() is False
            elif i > j:
                # check axes are the same along each row and column for contours
                assert current_axes.get_ybound() == axes[i][0].get_ybound()
                assert current_axes.get_xbound() == axes[-1][j].get_xbound()
            elif i == j:
                # check title is correct
                # the expected title is computed first: `a == b if c else d` parses as
                # `(a == b) if c else d`, which would only assert a truthy string for named params
                expected_title = dream_results.fitNames[params[i]] if isinstance(params[i], int) else params[i]
                assert current_axes.get_title(loc="left") == expected_title

    plt.close(fig)
@pytest.mark.parametrize(
    "params", [None, [2, 3], [1, 5, "D2O"], ["Bilayer Heads Thickness", "Bilayer Heads Hydration", "D2O"]]
)
@patch("ratapi.plotting.panel_plot_helper")
def test_hist_panel(mock_panel_helper: MagicMock, params, dream_results):
    """Test hist panel name-to-index conversion (the panel helper has already been tested)."""
    fig = RATplot.plot_hists(dream_results, params, return_fig=True)
    plt.close(fig)

    if params is None:
        params = range(0, len(dream_results.fitNames))

    # the parameters handed to the (mocked) panel helper are its second positional argument
    passed_params = mock_panel_helper.call_args.args[1]
    assert len(passed_params) == len(params)

    # compare each passed value against the *requested* parameter, not against itself,
    # so that a wrong name-to-index mapping is detected
    for passed, requested in zip(passed_params, params, strict=True):
        expected_index = dream_results.fitNames.index(requested) if isinstance(requested, str) else requested
        assert passed == expected_index
@pytest.mark.parametrize(
    ["input", "expected_dict"],
    [
        (None, "NONEDICT"),
        ({"D2O": "kernel"}, "D2O_DICT"),
        ({"default": "lognor"}, "DEFAULTDICT"),  # workaround as we need to access the fixture attrs
        ("lognor", "DEFAULTDICT"),
        ({"default": "normal", 1: "kernel"}, "DEFAULT_WITH_1CHANGE_DICT"),
    ],
)
@patch("ratapi.plotting.plot_one_hist")
def test_standardise_est_dens(mock_plot_hist: MagicMock, input, expected_dict, dream_results):
    """Test that the estimated density input is standardised to one entry per parameter index."""
    _ = RATplot.plot_hists(dream_results, estimated_density=input, return_fig=True)

    # the parametrised string selects one of these dicts, which need the fixture to be built
    param_indices = range(0, len(dream_results.fitNames))
    candidates = {
        "NONEDICT": dict.fromkeys(param_indices),
        "D2O_DICT": {**dict.fromkeys(param_indices), 16: "kernel"},
        "DEFAULTDICT": dict.fromkeys(param_indices, "lognor"),
        "DEFAULT_WITH_1CHANGE_DICT": {**dict.fromkeys(param_indices, "normal"), 1: "kernel"},
    }
    expected = candidates[expected_dict]

    # map the second positional argument of each plot_one_hist call to its estimated density
    actual = {
        recorded_call[0][1]: recorded_call[1]["estimated_density"] for recorded_call in mock_plot_hist.call_args_list
    }

    assert expected == actual

    plt.close("all")
@pytest.mark.parametrize("input", [{250: "lognor"}, {"Oxide Quickness": "normal"}, {"D2O": "Rattian"}, {-5: "lognor"}])
def test_est_dens_error(dream_results, input):
    """Ensure that an invalid estimated density input raises a ValueError or IndexError."""
    # the error message contains the phrase "Parameter {key}" or "Index {key}", so match on that
    # to make sure we are not catching an unrelated ValueError
    # NOTE(review): as a regex this reads as "Parameter" OR "Index {key}" - confirm whether
    # the alternation was meant to be grouped
    bad_key = list(input.keys())[0]
    expected_message = f"Parameter|Index {bad_key}"

    with pytest.raises((ValueError, IndexError), match=expected_message):
        RATplot.plot_hists(dream_results, estimated_density=input)
@pytest.mark.parametrize(
    "params", [None, [2, 3], [1, 5, "D2O"], ["Bilayer Heads Thickness", "Bilayer Heads Hydration", "D2O"]]
)
@patch("ratapi.plotting.panel_plot_helper")
def test_chain_panel(mock_panel_helper: MagicMock, params, dream_results):
    """Test chain panel name-to-index conversion (the panel helper has already been tested)."""
    # return fig just to avoid plt.show() being called
    fig = RATplot.plot_chain(dream_results, params, return_fig=True)
    plt.close(fig)

    if params is None:
        params = range(0, len(dream_results.fitNames))

    # read the recorded call via `call_args.args`; *calling* `call_args()` builds a new, empty
    # call object, so the loop below would never execute
    # assumes the parameters are the helper's second positional argument - TODO confirm
    passed_params = mock_panel_helper.call_args.args[1]

    # compare each passed value against the *requested* parameter; strict zip also
    # fails the test if the number of passed parameters differs
    for passed, requested in zip(passed_params, params, strict=True):
        expected_index = dream_results.fitNames.index(requested) if isinstance(requested, str) else requested
        assert passed == expected_index
@patch("ratapi.plotting.plot_ref_sld")
@patch("ratapi.plotting.plot_hists")
@patch("ratapi.plotting.plot_corner")
def test_bayes_calls(
    mock_corner: MagicMock, mock_hists: MagicMock, mock_ref_sld: MagicMock, input_project, dream_results
):
    """Test that plot_bayes invokes each of the (mocked) plotting subroutines.

    plot_ref_sld is expected to be called twice; plot_hists and plot_corner once each.
    """
    RATplot.plot_bayes(input_project, dream_results)

    ref_sld_calls = mock_ref_sld.call_count
    assert ref_sld_calls == 2
    mock_hists.assert_called_once()
    mock_corner.assert_called_once()
def test_bayes_validation(input_project, reflectivity_calculation_results):
    """Ensure that plot_bayes raises a ValueError when given non-Bayesian results."""
    expected_message = r"Bayes plots are only available for the results of Bayesian analysis \(NS or DREAM\)"
    with pytest.raises(ValueError, match=expected_message):
        RATplot.plot_bayes(input_project, reflectivity_calculation_results)
@pytest.mark.parametrize("data", [data(), domains_data()])
def test_extract_plot_data(data) -> None:
    """Check the extracted plot data lengths and the validation of the shift value."""
    plot_data = RATplot._extract_plot_data(data, False, True, 50)

    assert len(plot_data["ref"]) == len(data.reflectivity)
    assert len(plot_data["sld"]) == len(data.shiftedData)

    # a shift value just below 0 and one just above 100 must both be rejected
    for bad_shift_value in (-0.1, 100.5):
        with pytest.raises(ValueError, match=r"Parameter `shift_value` must be between 0 and 100"):
            RATplot._extract_plot_data(data, False, True, bad_shift_value)