Skip to content

Commit bbfa51b

Browse files
author
Julian Kates-Harbeck
committed
fix bugs related to T_min warn becoming too large. performance computation was wrong when T_max_warn was smaller. Shots can't be cut if they are shorter than T_min_warn
1 parent ad5de2f commit bbfa51b

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

examples/tune_hyperparams.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
tunables = []
99
shallow = False
1010
num_nodes = 1
11-
num_trials = 20
11+
num_trials = 30
1212

13-
t_warn = CategoricalHyperparam(['data','T_warning'],[0.256,1.024,10.024])
13+
t_warn = CategoricalHyperparam(['data','T_warning'],[0.256,1.024,4.096,10.024])
1414
cut_ends = CategoricalHyperparam(['data','cut_shot_ends'],[False,True])
1515
#for shallow
1616
if shallow:
@@ -35,19 +35,21 @@
3535
fac = CategoricalHyperparam(['data','positive_example_penalty'],[1.0,4.0,16.0])
3636
target = CategoricalHyperparam(['target'],['maxhinge','hinge','ttdinv','ttd'])
3737
#target = CategoricalHyperparam(['target'],['hinge','ttdinv','ttd'])
38-
batch_size = CategoricalHyperparam(['training','batch_size'],[128,256])
38+
batch_size = CategoricalHyperparam(['training','batch_size'],[64,128])
3939
dropout_prob = CategoricalHyperparam(['model','dropout_prob'],[0.01,0.05,0.1])
40-
conv_filters = CategoricalHyperparam(['model','num_conv_filters'],[128,256])
40+
conv_filters = CategoricalHyperparam(['model','num_conv_filters'],[64,128,256])
4141
conv_layers = IntegerHyperparam(['model','num_conv_layers'],2,4)
4242
rnn_layers = IntegerHyperparam(['model','rnn_layers'],1,3)
4343
rnn_size = CategoricalHyperparam(['model','rnn_size'],[128,256])
4444
dense_size = CategoricalHyperparam(['model','dense_size'],[128,256])
4545
extra_dense_input = CategoricalHyperparam(['model','extra_dense_input'],[False,True])
4646
equalize_classes = CategoricalHyperparam(['data','equalize_classes'],[False,True])
47+
t_min_warn = CategoricalHyperparam(['data','T_min_warn'],[30,70,200,500,1000])
4748
#rnn_length = CategoricalHyperparam(['model','length'],[32,128])
4849
#tunables = [lr,lr_decay,fac,target,batch_size,dropout_prob]
4950
tunables = [lr,lr_decay,fac,target,batch_size,equalize_classes,dropout_prob]
5051
tunables += [conv_filters,conv_layers,rnn_layers,rnn_size,dense_size,extra_dense_input]
52+
tunables += [t_min_warn]
5153
tunables += [cut_ends,t_warn]
5254

5355

plasma/preprocessor/normalize.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ def train_on_files(self,shot_files,use_shots,all_machines):
141141
def cut_end_of_shot(self,shot):
142142
cut_shot_ends = self.conf['data']['cut_shot_ends']
143143
if not self.inference_mode and cut_shot_ends: #only cut shots during training
144+
if shot.ttd.shape[0] <= T_min_warn:
145+
print("not cutting shot since T_min_warn is larger than length of shot")
146+
return
144147
T_min_warn = self.conf['data']['T_min_warn']
145148
for key in shot.signals_dict:
146149
shot.signals_dict[key] = shot.signals_dict[key][:-T_min_warn,:]

plasma/utils/performance.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ def __init__(self,results_dir=None,shots_dir=None,i = 0,T_min_warn = None,T_max_
2525
self.T_min_warn = T_min_warn_def
2626
if T_max_warn == None:
2727
self.T_max_warn = T_max_warn_def
28+
if self.T_max_warn < self.T_min_warn:
29+
print("T max warn is too small: need to increase artificially.") #computation of statistics is only correct if T_max_warn is larger than T_min_warn
30+
self.T_max_warn = self.T_min_warn + 1
2831
self.verbose = verbose
2932
self.results_dir = results_dir
3033
self.shots_dir = shots_dir
@@ -293,6 +296,7 @@ def create_acceptable_region(self,truth,mode):
293296
else:
294297
print('Error Invalid Mode for acceptable region')
295298
exit(1)
299+
assert(self.T_max_warn > self.T_min_warn)
296300

297301
acceptable = np.zeros_like(truth,dtype=bool)
298302
if acceptable_timesteps > 0:

0 commit comments

Comments
 (0)