forked from PPPLDeepLearning/plasma-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path compare_performance.py
More file actions
62 lines (50 loc) · 2.25 KB
/
compare_performance.py
File metadata and controls
62 lines (50 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os,sys
import numpy as np
from plasma.utils.performance import *
from plasma.conf import conf
# Compare several trained models on the same set of shots.
#
# Usage: python compare_performance.py RESULTS_DIR [RESULTS_DIR ...]
#
# One PerformanceAnalyzer is built per results directory given on the command
# line. Each analyzer loads its results file and computes a threshold via
# compute_tradeoffs_and_print_from_training(). The test-set ROC area of every
# model is printed. Then, for every shot in the first model's test list that
# every model has in its test or train list, the per-model prediction types
# are compared. Shots where the first model gives 'TP' and the second 'late'
# are printed and saved by each analyzer.

#mode = 'test'
file_num = 0  # passed to each analyzer as `i` (presumably which results file load_ith_file() reads)
save_figure = True  # only used by the commented-out plotting calls below
pred_ttd = False
# cut_shot_ends = conf['data']['cut_shot_ends']
# dt = conf['data']['dt']
# T_max_warn = int(round(conf['data']['T_warning']/dt))
# T_min_warn = conf['data']['T_min_warn']#int(round(conf['data']['T_min_warn']/dt))
# if cut_shot_ends:
# T_max_warn = T_max_warn-T_min_warn
# T_min_warn = 0
# Passed to every analyzer; the original note suggests None means "take the
# value from conf" -- TODO confirm against PerformanceAnalyzer.
T_min_warn = 30
verbose = False

# At least one results directory is required. An explicit check is used
# instead of `assert` so it is not stripped under `python -O`.
if len(sys.argv) < 2:
    sys.exit('usage: {} RESULTS_DIR [RESULTS_DIR ...]'.format(sys.argv[0]))
results_dirs = sys.argv[1:]

shots_dir = conf['paths']['processed_prepath']
analyzers = [
    PerformanceAnalyzer(conf=conf, results_dir=results_dir,
                        shots_dir=shots_dir, i=file_num,
                        T_min_warn=T_min_warn, verbose=verbose,
                        pred_ttd=pred_ttd)
    for results_dir in results_dirs
]
for analyzer in analyzers:
    analyzer.load_ith_file()
    analyzer.verbose = False

# One threshold per model, later passed to save_shot() as P_thresh_opt.
P_threshs = [analyzer.compute_tradeoffs_and_print_from_training()
             for analyzer in analyzers]

print('Test ROC:')
for analyzer in analyzers:
    print(analyzer.get_roc_area_by_mode('test'))

#P_thresh_opt = 0.566#0.566#0.92# analyzer.compute_tradeoffs_and_print_from_training()
linestyle = "-"  # only used by the commented-out plotting calls below
#analyzer.compute_tradeoffs_and_plot('test',save_figure=save_figure,plot_string='_test',linestyle=linestyle)
#analyzer.compute_tradeoffs_and_plot('train',save_figure=save_figure,plot_string='_train',linestyle=linestyle)
#analyzer.summarize_shot_prediction_stats_by_mode(P_thresh_opt,'test')

# The per-shot comparison below is written for exactly two models.
# pred_types holds one entry per analyzer, so it can only equal this
# two-element pattern when two results directories were given. Each model's
# saved file is then tagged with the matching label.
target_types = ['TP', 'late']
extra_filenames = ['1D', '0D']

shots = analyzers[0].shot_list_test
for shot in shots:
    # Skip shots that some model has in neither its test nor its train list
    # (per the original note, these are presumably validation shots).
    if not all(shot in analyzer.shot_list_test or shot in analyzer.shot_list_train
               for analyzer in analyzers):
        continue
    pred_types = [
        analyzer.get_prediction_type_for_individual_shot(P_thresh, shot, mode='test')
        for analyzer, P_thresh in zip(analyzers, P_threshs)
    ]
    #if len(set(pred_types)) > 1:
    if pred_types != target_types:
        continue
    # Report which split the second model had this shot in.
    if shot in analyzers[1].shot_list_test:
        print("TEST")
    else:
        print("TRAIN")
    print(shot.number)
    print(pred_types)
    for analyzer, P_thresh, extra_filename in zip(analyzers, P_threshs, extra_filenames):
        analyzer.save_shot(shot, P_thresh_opt=P_thresh, extra_filename=extra_filename)