-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathyodaplot.py
More file actions
363 lines (347 loc) · 16.4 KB
/
yodaplot.py
File metadata and controls
363 lines (347 loc) · 16.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
"""Functions for plotting data objects within YODA files."""
import numbers

import matplotlib.pyplot as plt
import numpy as np
import yoda
def plot(filename_or_data_object, data_object_name,
         errors_enabled=True, rebin_count=1, visible=True,
         **kwargs):
    """Resolve a data object (optionally reading it from a YODA file) and plot it.

    Args:
        filename_or_data_object: path of a YODA file, or a YODA data object.
        data_object_name: name of the data object to read from the file.
        errors_enabled: whether y errors are drawn.
        rebin_count: number of adjacent bins/points to merge before plotting.
        visible: forwarded to the matplotlib artists.
        **kwargs: forwarded to the type-specific plotting function.

    Returns:
        Whatever plot_data_object returns for the resolved object.
    """
    resolved = resolve_data_object(
        filename_or_data_object,
        data_object_name,
        rebin_count=rebin_count,
    )
    return plot_data_object(resolved, errors_enabled, visible, **kwargs)
def plot_data_object(data_object,
                     errors_enabled=True, visible=True,
                     **kwargs):
    """Plot a YODA data object by dispatching on its type.

    Scatter2D objects go to plot_scatter2d, Histo1D objects to plot_histo1d.

    Raises:
        Exception: if data_object is neither a Scatter2D nor a Histo1D.
    """
    dispatch = (
        (yoda.Scatter2D, plot_scatter2d),
        (yoda.Histo1D, plot_histo1d),
    )
    for yoda_type, plotter in dispatch:
        if not isinstance(data_object, yoda_type):
            continue
        return plotter(data_object, errors_enabled, visible, **kwargs)
    raise Exception('Unknown type of YODA data object: ', data_object)
def get_y_coords(yoda_data_object):
    """Return the y coordinates of a Scatter2D or Histo1D object.

    Returns None for any other type.
    """
    dispatch = (
        (yoda.Scatter2D, get_scatter2d_y_coords),
        (yoda.Histo1D, get_histo1d_y_coords),
    )
    for yoda_type, getter in dispatch:
        if isinstance(yoda_data_object, yoda_type):
            return getter(yoda_data_object)
    return None
def plot_scatter2d(scatter, errors_enabled=True, visible=True, **kwargs):
    """Plot a YODA Scatter2D object.

    If the points' x error bars tile the x axis without gaps, the scatter is
    drawn as a step histogram with an error envelope; otherwise it is drawn
    as markers with error bars.

    Args:
        scatter: the Scatter2D to plot.
        errors_enabled: whether y errors are drawn.
        visible: forwarded to the matplotlib artists. When False, no error
            bars are computed and the plain errorbar path is used.
        **kwargs: forwarded to matplotlib. The special key ``xmin`` drops all
            points with x < xmin and is never forwarded.

    Returns:
        The result of plt.errorbar, or the (step, errorrects) tuple returned
        by step_with_errorbar_using_points.
    """
    points = scatter.points()
    x_coords = [point.x() for point in points]
    y_coords = get_scatter2d_y_coords(scatter)
    # xmin is consumed here; matplotlib does not accept it as a keyword.
    xmin = kwargs.pop("xmin", None)
    if visible:
        x_errs = [[point.xErrs()[0] for point in points],
                  [point.xErrs()[1] for point in points]]
        bins_are_adjacent = are_points_with_errors_adjacent(x_coords, x_errs)
        if errors_enabled:
            y_errs = [[point.yErrs()[0] for point in points],
                      [point.yErrs()[1] for point in points]]
        else:
            y_errs = None
    else:
        bins_are_adjacent = False
        x_errs = None
        y_errs = None
    if xmin is not None:
        # Index of the first point at or above xmin; len(x_coords) (i.e. drop
        # everything) if there is none.
        start = next((i for i, x in enumerate(x_coords) if x >= xmin),
                     len(x_coords))
        x_coords = x_coords[start:]
        y_coords = y_coords[start:]
        if x_errs is not None:
            x_errs = [errs[start:] for errs in x_errs]
        if y_errs is not None:
            y_errs = [errs[start:] for errs in y_errs]
    if not x_coords:
        # The step path needs at least one point to close the last step.
        bins_are_adjacent = False
    if not bins_are_adjacent:
        return plt.errorbar(x_coords, y_coords,
                            fmt='o', xerr=x_errs, yerr=y_errs, visible=visible, **kwargs)
    return step_with_errorbar_using_points(x_coords, x_errs, y_coords, y_errs,
                                           errors_enabled=errors_enabled, visible=visible,
                                           **kwargs)
def get_scatter2d_y_coords(scatter):
    """Return the y value of every point of a Scatter2D, in point order."""
    scatter_points = scatter.points()
    return [pt.y() for pt in scatter_points]
def plot_histo1d(histo, errors_enabled=True, visible=True, **kwargs):
    """Plot a YODA Histo1D object by plotting its bins.

    All arguments are forwarded to plot_histo1d_bins, whose result is returned.
    """
    histo_bins = histo.bins()
    return plot_histo1d_bins(histo_bins,
                             errors_enabled=errors_enabled,
                             visible=visible,
                             **kwargs)
def plot_histo1d_bins(bins, errors_enabled=True, visible=True, **kwargs):
    """Plot YODA Histo1D bins.

    Adjacent bins are drawn as a step histogram with an error envelope;
    otherwise each bin is drawn as a bar spanning its x edges.

    Args:
        bins: iterable of Histo1D bins (objects with xEdges(), height() and
            heightErr()).
        errors_enabled: whether y errors are drawn.
        visible: forwarded to the matplotlib artists.
        **kwargs: forwarded to matplotlib. The special key ``xmin`` drops all
            bins whose left edge is < xmin and is never forwarded.

    Returns:
        The result of plt.bar, or the (step, errorrects) tuple returned by
        plot_step_with_errorbar.
    """
    bins = list(bins)
    x_lefts = [histo_bin.xEdges()[0] for histo_bin in bins]
    widths = [histo_bin.xEdges()[1] - histo_bin.xEdges()[0] for histo_bin in bins]
    bins_are_adjacent = are_bins_adjacent(x_lefts, widths)
    y_coord = get_histo1d_y_coords(bins)
    y_errs = [histo_bin.heightErr() for histo_bin in bins]
    # xmin is consumed here; matplotlib does not accept it as a keyword.
    xmin = kwargs.pop("xmin", None)
    if xmin is not None:
        # Index of the first bin starting at or above xmin; len(x_lefts)
        # (i.e. drop everything) if there is none.
        start = next((i for i, x_left in enumerate(x_lefts) if x_left >= xmin),
                     len(x_lefts))
        x_lefts = x_lefts[start:]
        widths = widths[start:]
        y_coord = y_coord[start:]
        y_errs = y_errs[start:]
    if not x_lefts:
        # The step path needs at least one bin to close the last step.
        bins_are_adjacent = False
    if not bins_are_adjacent:
        # x_lefts are left edges, but matplotlib >= 2.0 centres bars on x by
        # default; anchor them at the edge unless the caller chose otherwise.
        kwargs.setdefault("align", "edge")
        return plt.bar(x_lefts, y_coord, width=widths,
                       yerr=y_errs if errors_enabled else None,
                       visible=visible, **kwargs)
    return plot_step_with_errorbar(x_lefts, widths, y_coord, y_errs,
                                   errors_enabled=errors_enabled, visible=visible, **kwargs)
def get_histo1d_y_coords(histo_or_bins):
    """Return the bin heights of a Histo1D object or of a sequence of its bins."""
    bins = (histo_or_bins.bins()
            if isinstance(histo_or_bins, yoda.Histo1D)
            else histo_or_bins)
    return [b.height() for b in bins]
def are_points_with_errors_adjacent(points, errs):
    """Return whether points tile the x axis without gaps, given their x errors.

    Args:
        points: x coordinates, in increasing order.
        errs: two sequences ``[left_errs, right_errs]``, the layout built from
            YODA ``xErrs()`` elsewhere in this module. Point i covers
            ``[points[i] - errs[0][i], points[i] + errs[1][i]]``.

    Returns:
        True if the right edge of every point matches the left edge of the
        next point to within 1% of the sum of the two error bars involved,
        else False. Empty and single-point inputs count as adjacent.
    """
    left_errs, right_errs = errs[0], errs[1]
    for point, next_point, err_right, next_err_left in zip(
            points[:-1], points[1:], right_errs[:-1], left_errs[1:]):
        right_edge = point + err_right
        left_edge = next_point - next_err_left
        if abs(left_edge - right_edge) > (err_right + next_err_left) / 100.0:
            return False
    return True
def are_bins_adjacent(lefts, widths):
    """Return whether each bin ends where the next one starts.

    The gap is measured relative to the next bin's left edge (tolerance 1e-4);
    when that edge is exactly 0 the absolute gap is used instead.
    """
    def _gap_is_negligible(left, width, next_left):
        right_edge = left + width
        if next_left == 0:
            return not abs(right_edge) > 1e-4
        return not abs((right_edge - next_left) / next_left) > 1e-4

    return all(_gap_is_negligible(left, width, next_left)
               for left, width, next_left in zip(lefts[:-1], widths[:-1], lefts[1:]))
def step_with_errorbar_using_points(x_coords, x_errs, y_coords, y_errs,
                                    errors_enabled=True, **kwargs):
    """Draw a step plot with an error envelope from points with x errors.

    Each point's bin runs from ``x - x_errs[0][i]`` to ``x + x_errs[1][i]``.
    The result of plot_step_with_errorbar is returned.
    """
    left_errs, right_errs = x_errs[0], x_errs[1]
    bin_lefts = [centre - lo for centre, lo in zip(x_coords, left_errs)]
    bin_widths = [lo + hi for lo, hi in zip(left_errs, right_errs)]
    return plot_step_with_errorbar(bin_lefts, bin_widths, y_coords, y_errs,
                                   errors_enabled, **kwargs)
def plot_step_with_errorbar(lefts, widths, y_coords, y_errs,
                            errors_enabled=True, use_errorrects_for_legend=False, **kwargs):
    """Draw a step histogram and, optionally, a y-error envelope around it.

    Args:
        lefts: left edge of each bin (any sequence; not modified).
        widths: bin widths; only the last one is used, to close the final step.
        y_coords: bin heights (any sequence; not modified).
        y_errs: y errors, passed on to plot_errorrects.
        errors_enabled: whether the error envelope is drawn.
        use_errorrects_for_legend: if True, a ``label`` is attached to the
            error envelope instead of the step line.
        **kwargs: forwarded to plt.step and to plot_errorrects. ``hatch`` only
            goes to the envelope; ``color`` (used as the envelope colour),
            ``marker`` and ``zorder`` only go to the step line.

    Returns:
        Tuple (step_result, errorrects_result); errorrects_result is None
        when errors are disabled.

    Raises:
        IndexError: if lefts or y_coords is empty.
    """
    # Append the right edge of the last bin and repeat its height so the
    # final step is drawn. Build new lists instead of appending in place so
    # the caller's sequences are never modified.
    lefts = list(lefts) + [lefts[-1] + widths[-1]]
    y_coords = list(y_coords) + [y_coords[-1]]
    step_kwargs = dict(kwargs)
    rect_kwargs = dict(kwargs)
    # Give the label to only one artist, otherwise each data set would get
    # two legend entries.
    if errors_enabled and "label" in kwargs:
        if use_errorrects_for_legend:
            del step_kwargs["label"]
        else:
            del rect_kwargs["label"]
    # plt.step does not accept hatch; it only applies to the envelope.
    step_kwargs.pop("hatch", None)
    step_result = plt.step(lefts, y_coords, where='post', **step_kwargs)
    if not errors_enabled:
        return step_result, None
    step_line = plt.gca().lines[-1]
    if "color" in rect_kwargs:
        ecolor = rect_kwargs.pop("color")
    else:
        # Reuse the step line's colour instead of advancing the colour cycle.
        ecolor = step_line.get_color()
    rect_kwargs.pop("marker", None)
    rect_kwargs.pop("zorder", None)
    # Draw the envelope just below the step line.
    zorder = step_line.get_zorder() - 1
    errorrects_result = plot_errorrects(lefts, y_coords, y_errs, ecolor, zorder, **rect_kwargs)
    return step_result, errorrects_result
def _errorrect_bounds(y_coords, y_errs, n_bins):
    """Return arrays (lower, upper) of y -/+ error for the first n_bins bins.

    y_errs may be a flat sequence of symmetric errors, a pair of sequences
    ``[down_errs, up_errs]`` (matplotlib's ``yerr`` layout, as built by
    plot_scatter2d), or one ``(down, up)`` pair per bin. With exactly two bins
    the last two layouts look the same; ``[down_errs, up_errs]`` is assumed.
    """
    if y_errs is None:
        raise TypeError("y_errs must not be None when drawing error rectangles")
    errs = np.asarray(y_errs, dtype=float)
    if errs.shape == (n_bins,):
        err_down = err_up = errs
    elif errs.shape == (2, n_bins):
        err_down, err_up = errs
    elif errs.shape == (n_bins, 2):
        err_down, err_up = errs[:, 0], errs[:, 1]
    else:
        raise ValueError("There are less y errors than points.")
    # y_coords may carry an extra trailing entry (the closing step); ignore it.
    y_values = np.asarray(y_coords, dtype=float)[:n_bins]
    return y_values - err_down, y_values + err_up


def plot_errorrects(lefts, y_coords, y_errs, color, zorder=1, **kwargs):
    """Draw the y errors as an envelope around a step plot.

    Args:
        lefts: bin edges, i.e. the left edge of every bin followed by the
            right edge of the last bin (len(lefts) == number of bins + 1).
        y_coords: bin heights; entries beyond the number of bins are ignored.
        y_errs: symmetric errors (one per bin), ``[down_errs, up_errs]``, or
            one ``(down, up)`` pair per bin.
        color: colour of the envelope.
        zorder: matplotlib z-order of the envelope.
        **kwargs: forwarded to matplotlib. With ``hatch`` the envelope is a
            hatched outline; with ``linewidth`` its upper and lower edges are
            drawn as lines; otherwise a translucent band is filled (``alpha``
            defaults to 0.3).

    Returns:
        The plt.fill_between result, or a tuple (upper, lower) of plt.plot
        results when ``linewidth`` is given.

    Raises:
        ValueError: if y_errs does not provide one error per bin.
        TypeError: if y_errs is None.
    """
    n_bins = len(lefts) - 1
    y_down, y_up = _errorrect_bounds(y_coords, y_errs, n_bins)
    # Every bin contributes its left and right edge, so the envelope is flat
    # across the bin: [l0, l1, l1, l2, ...] paired with [v0, v0, v1, v1, ...].
    x_edges = np.ravel(list(zip(lefts[:-1], lefts[1:])))
    y_down = np.repeat(y_down, 2)
    y_up = np.repeat(y_up, 2)
    if 'hatch' in kwargs:
        return plt.fill_between(x_edges, y_up, y_down,
                                color='none',
                                edgecolor=color,
                                alpha=1.0,
                                zorder=zorder, **kwargs)
    if 'linewidth' in kwargs:
        up = plt.plot(x_edges, y_up,
                      color=color,
                      zorder=zorder, **kwargs)
        down = plt.plot(x_edges, y_down,
                        color=color,
                        zorder=zorder, **kwargs)
        return (up, down)
    kwargs.setdefault('alpha', 0.3)
    return plt.fill_between(list(x_edges), list(y_up), list(y_down),
                            color=[color],
                            linewidth=0.0,
                            zorder=int(zorder), **kwargs)
def data_object_names(filename):
    """Return the names of all plottable data objects in a YODA file.

    Objects whose type is 'Counter' or 'Scatter1D' are left out.
    """
    excluded_types = ('Counter', 'Scatter1D')
    names = []
    for key, obj in yoda.readYODA(filename).items():
        if obj.type in excluded_types:
            continue
        names.append(key)
    return names
def _rebin_scatter2d(scatter, rebin_count, rebin_begin):
    """Merge every rebin_count adjacent points of a Scatter2D, treating it as a histogram."""
    if rebin_count < 1:
        # A non-positive step would never advance through the points.
        raise ValueError("rebin_count must be a positive integer, got %r" % (rebin_count,))
    print("WARNING: Will assume statistical errors for rebinning a scatter plot")
    points = scatter.points()
    x_coords = [point.x() for point in points]
    x_errs = [[point.xErrs()[0] for point in points],
              [point.xErrs()[1] for point in points]]
    if not are_points_with_errors_adjacent(x_coords, x_errs):
        raise Exception("Points must be adjacent for interpreting the scatter plots as a histogram")
    # Points before rebin_begin are kept unchanged.
    new_points = list(points[:rebin_begin])
    # Every remaining point is merged into some group; a trailing group with
    # fewer than rebin_count points is merged too, so no point is dropped.
    for first_index in range(rebin_begin, len(points), rebin_count):
        group = points[first_index:first_index + rebin_count]
        left_edge = group[0].x() - group[0].xErrs()[0]
        right_edge = group[-1].x() + group[-1].xErrs()[1]
        length = right_edge - left_edge
        new_y = 0.0
        new_yerrs = np.array([0.0, 0.0])
        for point in group:
            width = (point.x() + point.xErrs()[1]) - (point.x() - point.xErrs()[0])
            # Width-weighted mean height; errors added in quadrature
            # (assumed uncorrelated, i.e. statistical).
            new_y += width * point.y()
            new_yerrs += (width * np.array(point.yErrs()))**2
        new_y /= length
        new_yerrs = np.sqrt(new_yerrs) / length
        new_points.append(yoda.Point2D(x=left_edge + length / 2.0, y=new_y,
                                       xerrs=length / 2.0, yerrs=new_yerrs))
    rebinned = yoda.Scatter2D(path=scatter.path, title=scatter.title)
    for point in new_points:
        rebinned.addPoint(point)
    return rebinned


def _subtract_data_objects(data_object, subtract_by, name, assume_correlated):
    """Return data_object minus subtract_by (resolved with name) as a scatter."""
    scatter = yoda.mkScatter(data_object)
    operand = resolve_data_object(subtract_by, name).mkScatter()
    for point, operand_point in zip(scatter.points(), operand.points()):
        new_y = point.y() - operand_point.y()
        if assume_correlated:
            # The operand is treated as exact, as in the correlated
            # divide/multiply case: only this object's errors remain.
            new_y_errs = list(point.yErrs())
        else:
            # Independent data sets: add absolute errors in quadrature.
            new_y_errs = [np.sqrt(y_err**2 + operand_y_err**2)
                          for y_err, operand_y_err in zip(point.yErrs(), operand_point.yErrs())]
        point.setY(new_y)
        point.setYErrs(new_y_errs)
    return scatter


def _scale_data_object(data_object, name, divide_by, multiply_by, assume_correlated):
    """Return data_object divided by divide_by (or multiplied by multiply_by) as a scatter."""
    scatter = yoda.mkScatter(data_object)
    dividing = divide_by is not None
    if isinstance(divide_by, numbers.Real) or isinstance(multiply_by, numbers.Real):
        # Plain numeric factor: scale values and errors. abs() keeps errors
        # non-negative when the factor is negative.
        for point in scatter.points():
            if dividing:
                new_y = point.y() / divide_by
                new_y_errs = [abs(y_err / divide_by) for y_err in point.yErrs()]
            else:
                new_y = point.y() * multiply_by
                new_y_errs = [abs(y_err * multiply_by) for y_err in point.yErrs()]
            point.setY(new_y)
            point.setYErrs(new_y_errs)
        return scatter
    operand = resolve_data_object(divide_by if dividing else multiply_by, name).mkScatter()
    for point, operand_point in zip(scatter.points(), operand.points()):
        value, other = point.y(), operand_point.y()
        if other == 0.0:
            # Zero operand: the result is set to 1 (divide) or 0 (multiply)
            # with no error.
            new_y = 1.0 if dividing else 0.0
            new_y_errs = [0.0, 0.0]
        elif dividing:
            new_y = value / other
            if assume_correlated:
                new_y_errs = [abs(y_err / other) for y_err in point.yErrs()]
            else:
                # sigma(a/b)^2 = (sigma_a/b)^2 + (a*sigma_b/b^2)^2; unlike the
                # relative-error form this stays valid when a == 0.
                new_y_errs = [np.sqrt((y_err / other)**2 + (value * other_err / other**2)**2)
                              for y_err, other_err in zip(point.yErrs(), operand_point.yErrs())]
        else:
            new_y = value * other
            if assume_correlated:
                new_y_errs = [abs(y_err * other) for y_err in point.yErrs()]
            else:
                # sigma(a*b)^2 = (b*sigma_a)^2 + (a*sigma_b)^2
                new_y_errs = [np.sqrt((other * y_err)**2 + (value * other_err)**2)
                              for y_err, other_err in zip(point.yErrs(), operand_point.yErrs())]
        point.setY(new_y)
        point.setYErrs(new_y_errs)
    return scatter


def _deviation_in_sigmas(data_object, deviate_from, name):
    """Return (data_object - deviate_from) / combined error as a scatter with zero y errors."""
    scatter = yoda.mkScatter(data_object)
    operand = resolve_data_object(deviate_from, name).mkScatter()
    is_equal = True
    for point, operand_point in zip(scatter.points(), operand.points()):
        difference = point.y() - operand_point.y()
        if difference != 0.0:
            is_equal = False
        combined_errs = [np.sqrt(y_err**2 + operand_y_err**2)
                         for y_err, operand_y_err in zip(point.yErrs(), operand_point.yErrs())]
        # Normalise by the first (minus) component of the combined error.
        point.setY(difference / combined_errs[0])
        point.setYErrs(0)
    if is_equal:
        # Identical inputs: give every point a unit error band.
        for point in scatter.points():
            point.setYErrs(1)
    return scatter


def resolve_data_object(filename_or_data_object, name,
                        divide_by=None,
                        multiply_by=None,
                        subtract_by=None,
                        deviate_from=None,
                        assume_correlated=False,
                        rebin_count=1,
                        rebin_begin=0):
    """Load or copy a YODA data object, then optionally rebin and transform it.

    The steps are applied in this order: rebin, subtract, divide/multiply,
    deviation.

    Args:
        filename_or_data_object: path of a YODA file (str) to read ``name``
            from, or a data object, which is cloned so the caller's object is
            left untouched.
        name: name of the data object in the file; also used to resolve the
            operands below when they are given as filenames.
        divide_by: a number, or a filename/data object resolved recursively.
            Takes precedence over multiply_by if both are given.
        multiply_by: a number, or a filename/data object resolved recursively.
        subtract_by: filename or data object to subtract.
        deviate_from: filename or data object; each y becomes the difference
            divided by the combined (quadrature) first error component, and
            y errors are set to 0 (or 1 if all differences are zero).
        assume_correlated: if True, the operand of subtract/divide/multiply
            is treated as exact and only this object's errors are propagated;
            otherwise independent errors are combined in quadrature.
        rebin_count: number of adjacent bins/points to merge (1 = none).
        rebin_begin: index of the first bin/point taking part in rebinning.

    Returns:
        The resolved data object. After any subtract/divide/multiply/deviate
        step it is the output of yoda.mkScatter / .mkScatter().

    Raises:
        Exception: if scatter points are not adjacent when rebinning a
            scatter, or if assume_correlated is combined with deviate_from.
        ValueError: if rebin_count < 1 when rebinning a scatter.
    """
    if isinstance(filename_or_data_object, str):
        data_object = yoda.readYODA(filename_or_data_object)[name]
    else:
        data_object = filename_or_data_object.clone()
    if rebin_count != 1:
        if data_object.type == "Histo1D":
            data_object.rebin(rebin_count, begin=rebin_begin)
        else:
            data_object = _rebin_scatter2d(data_object, rebin_count, rebin_begin)
    if subtract_by is not None:
        data_object = _subtract_data_objects(data_object, subtract_by, name, assume_correlated)
    if divide_by is not None or multiply_by is not None:
        data_object = _scale_data_object(data_object, name, divide_by, multiply_by,
                                         assume_correlated)
    if deviate_from is not None:
        if assume_correlated:
            raise Exception("You can not use assume_correlated and deviate_from at the same time.")
        data_object = _deviation_in_sigmas(data_object, deviate_from, name)
    return data_object