Skip to content

Commit 0de301e

Browse files
author
Enrico Bothmann
committed
Port improvements for resolve_data_object from master
1 parent 8b52918 commit 0de301e

1 file changed

Lines changed: 36 additions & 20 deletions

File tree

heppyplotlib/yodaplot.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,16 @@ def data_object_names(filename):
201201

202202
def resolve_data_object(filename_or_data_object, name,
203203
divide_by=None,
204-
use_correlated_division=False,
204+
multiply_by=None,
205+
assume_correlated=False,
206+
use_correlated_division=None, # this is only for backwards-compatibility
205207
rebin_count=1,
206208
rebin_begin=0):
207209
"""Take passed data object or loads a data object from a YODA file,
208-
and return it after dividing by divide_by."""
210+
and return it after dividing (or multiplying) by divide_by (multiply_by)."""
211+
if use_correlated_division is not None:
212+
assume_correlated = use_correlated_division
213+
print("Heppyplotlib deprecation warning: Use assume_correlated instead of use_correlated_division")
209214
if isinstance(filename_or_data_object, basestring):
210215
data_object = yoda.readYODA(filename_or_data_object)[name]
211216
else:
@@ -241,39 +246,50 @@ def resolve_data_object(filename_or_data_object, name,
241246
new_points.append(yoda.Point2D(x=new_x, y=new_y, xerrs=new_xerrs, yerrs=0.0))
242247
i = i + 1
243248
new_points.extend(data_object.points[first_index+rebin_count:])
244-
print data_object.points
245-
print new_points
246-
print data_object.path
247-
print data_object.title
248249
data_object = yoda.Scatter2D(path=data_object.path, title=data_object.title)
249250
for point in new_points:
250251
data_object.addPoint(point)
251-
if divide_by is not None:
252+
if divide_by is not None or multiply_by is not None:
252253
data_object = yoda.mkScatter(data_object)
253-
if isinstance(divide_by, float):
254+
if isinstance(divide_by, float) or isinstance(multiply_by, float):
254255
for point in data_object.points:
255-
new_y = point.y / divide_by
256-
new_y_errs = [y_err / divide_by for y_err in point.yErrs]
256+
if divide_by is not None:
257+
new_y = point.y / divide_by
258+
new_y_errs = [y_err / divide_by for y_err in point.yErrs]
259+
else:
260+
new_y = point.y * multiply_by
261+
new_y_errs = [y_err * multiply_by for y_err in point.yErrs]
257262
point.y = new_y
258263
point.yErrs = new_y_errs
259264
else:
260-
divide_by = resolve_data_object(divide_by, name).mkScatter()
261-
for point, denominator_point in zip(data_object.points, divide_by.points):
262-
if denominator_point.y == 0.0:
263-
new_y = 1.0
265+
if divide_by is not None:
266+
operand = resolve_data_object(divide_by, name).mkScatter()
267+
else:
268+
operand = resolve_data_object(multiply_by, name).mkScatter()
269+
for point, operand_point in zip(data_object.points, operand.points):
270+
if operand_point.y == 0.0:
271+
if divide_by is not None:
272+
new_y = 1.0
273+
else:
274+
new_y = 0.0
264275
new_y_errs = [0.0, 0.0]
265276
else:
266-
new_y = point.y / denominator_point.y
267-
if use_correlated_division:
268-
new_y_errs = [y_err / denominator_point.y for y_err in point.yErrs]
277+
if divide_by is not None:
278+
new_y = point.y / operand_point.y
279+
if assume_correlated:
280+
new_y_errs = [y_err / operand_point.y for y_err in point.yErrs]
269281
else:
270-
# assume that we divide through an independent data set, use error propagation
282+
new_y = point.y * operand_point.y
283+
if assume_correlated:
284+
new_y_errs = [y_err * operand_point.y for y_err in point.yErrs]
285+
if not assume_correlated:
286+
# assume that we divide/multiply through an independent data set, use error propagation
271287
rel_y_errs = []
272-
for y_err, den_y_err in zip(point.yErrs, denominator_point.yErrs):
288+
for y_err, operand_y_err in zip(point.yErrs, operand_point.yErrs):
273289
err2 = 0.0
274290
if point.y != 0.0:
275291
err2 += (y_err / point.y)**2
276-
err2 += (den_y_err / denominator_point.y)**2
292+
err2 += (operand_y_err / operand_point.y)**2
277293
rel_y_errs.append(np.sqrt(err2))
278294
new_y_errs = [rel_y_err * new_y for rel_y_err in rel_y_errs]
279295
point.y = new_y

0 commit comments

Comments
 (0)