1212
Merge branch 'jdev' of https://github.com/PPPLDeepLearning/plasma-python · codeaudit/plasma-python@69063c0 · GitHub
Skip to content

Commit 69063c0

Browse files
Julian Kates-Harbeck
authored and committed
2 parents 8f4c7d1 + e22b32e commit 69063c0

File tree

7 files changed

+77
-26
lines changed

7 files changed

+77
-26
lines changed

data/signals.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,10 +237,15 @@ def fetch_nstx_data(signal_path,shot_num,c):
237237

238238
fully_defined_signals = {sig_name: sig for (sig_name, sig) in all_signals_restricted.items() if sig.is_defined_on_machines(all_machines)}
239239
fully_defined_signals_0D = {sig_name: sig for (sig_name, sig) in all_signals_restricted.items() if ( sig.is_defined_on_machines(all_machines) and sig.num_channels == 1) }
240+
fully_defined_signals_1D = {sig_name: sig for (sig_name, sig) in all_signals_restricted.items() if ( sig.is_defined_on_machines(all_machines) and sig.num_channels > 1) }
241+
240242
d3d_signals = {sig_name: sig for (sig_name, sig) in all_signals_restricted.items() if sig.is_defined_on_machine(d3d)}
243+
d3d_signals_0D = {sig_name: sig for (sig_name, sig) in all_signals_restricted.items() if (sig.is_defined_on_machine(d3d) and sig.num_channels == 1)}
244+
d3d_signals_1D = {sig_name: sig for (sig_name, sig) in all_signals_restricted.items() if (sig.is_defined_on_machine(d3d) and sig.num_channels > 1)}
245+
241246
jet_signals = {sig_name: sig for (sig_name, sig) in all_signals_restricted.items() if sig.is_defined_on_machine(jet)}
242247
jet_signals_0D = {sig_name: sig for (sig_name, sig) in all_signals_restricted.items() if (sig.is_defined_on_machine(jet) and sig.num_channels == 1)}
243-
248+
jet_signals_1D = {sig_name: sig for (sig_name, sig) in all_signals_restricted.items() if (sig.is_defined_on_machine(jet) and sig.num_channels > 1)}
244249

245250
#['pcechpwrf'] #Total ECH Power Not always on!
246251
### 0D EFIT signals ###

examples/conf.yaml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@ paths:
1010
signal_prepath: '/signal_data/' #/signal_data/jet/
1111
shot_list_dir: '/shot_lists/'
1212
tensorboard_save_path: '/Graph/'
13-
data: jet_data #'d3d_to_jet_data' #'d3d_to_jet_data' # 'jet_to_d3d_data' #jet_data
13+
data: jet_to_d3d_data #'d3d_to_jet_data' #'d3d_to_jet_data' # 'jet_to_d3d_data' #jet_data
1414
specific_signals: [] #['q95','li','ip','betan','energy','lm','pradcore','pradedge','pradtot','pin','torquein','tmamp1','tmamp2','tmfreq1','tmfreq2','pechin','energydt','ipdirect','etemp_profile','edens_profile'] #if left empty will use all valid signals defined on a machine. Only use if need a custom set
1515
executable: "mpi_learn.py"
1616
shallow_executable: "learn.py"
1717

1818
data:
19-
bleed_in: 0 #how many shots from the test sit to use in training?
19+
bleed_in: 5 #how many shots from the test sit to use in training?
20+
bleed_in_repeat_fac: 10
2021
bleed_in_remove_from_test: True
21-
bleed_in_equalize_sets: True
22+
bleed_in_equalize_sets: False
2223
signal_to_augment: None #'plasma current' #or None
2324
augmentation_mode: 'none'
2425
augment_during_training: False
@@ -52,10 +53,10 @@ data:
5253
floatx: 'float32'
5354

5455
model:
55-
shallow: False
56+
shallow: True
5657
shallow_model:
5758
num_samples: 1000000 #1000000 #the number of samples to use for training
58-
type: "mlp" #"xgboost" #"xgboost" #"random_forest" "xgboost"
59+
type: "xgboost" #"xgboost" #"xgboost" #"random_forest" "xgboost"
5960
n_estimators: 100 #for random forest
6061
max_depth: 3 #for random forest and xgboost (def = 3)
6162
C: 1.0 #for svm
@@ -89,8 +90,8 @@ model:
8990
#have not found a difference yet
9091
optimizer: 'adam'
9192
clipnorm: 10.0
92-
regularization: 0.0
93-
dense_regularization: 0.01
93+
regularization: 0.001
94+
dense_regularization: 0.001
9495
#1e-4 is too high, 5e-7 is too low. 5e-5 seems best at 256 batch size, full dataset and ~10 epochs, and lr decay of 0.90. 1e-4 also works well if we decay a lot (i.e ~0.7 or more)
9596
lr: 0.00002 #0.00001 #0.0005 #for adam plots 0.0000001 #0.00005 #0.00005 #0.00005
9697
lr_decay: 0.97 #0.98 #0.9

examples/tune_hyperparams.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
tunables = []
99
shallow = False
10-
num_nodes = 2
11-
num_trials = 50
10+
num_nodes = 1
11+
num_trials = 20
1212

1313
t_warn = CategoricalHyperparam(['data','T_warning'],[0.256,1.024,10.024])
1414
cut_ends = CategoricalHyperparam(['data','cut_shot_ends'],[False,True])
@@ -34,14 +34,20 @@
3434
lr_decay = CategoricalHyperparam(['model','lr_decay'],[0.97,0.985,1.0])
3535
fac = CategoricalHyperparam(['data','positive_example_penalty'],[1.0,4.0,16.0])
3636
target = CategoricalHyperparam(['target'],['maxhinge','hinge','ttdinv','ttd'])
37-
batch_size = CategoricalHyperparam(['training','batch_size'],[64,256,1024])
38-
dropout_prob = CategoricalHyperparam(['model','dropout_prob'],[0.1,0.3,0.5])
39-
conv_filters = CategoricalHyperparam(['model','num_conv_filters'],[5,10])
37+
#target = CategoricalHyperparam(['target'],['hinge','ttdinv','ttd'])
38+
batch_size = CategoricalHyperparam(['training','batch_size'],[128,256])
39+
dropout_prob = CategoricalHyperparam(['model','dropout_prob'],[0.01,0.05,0.1])
40+
conv_filters = CategoricalHyperparam(['model','num_conv_filters'],[128,256])
4041
conv_layers = IntegerHyperparam(['model','num_conv_layers'],2,4)
41-
rnn_layers = IntegerHyperparam(['model','rnn_layers'],1,4)
42-
rnn_size = CategoricalHyperparam(['model','rnn_size'],[100,200,300])
43-
tunables = [lr,lr_decay,fac,target,batch_size,dropout_prob]
44-
tunables += [conv_filters,conv_layers,rnn_layers,rnn_size]
42+
rnn_layers = IntegerHyperparam(['model','rnn_layers'],1,3)
43+
rnn_size = CategoricalHyperparam(['model','rnn_size'],[128,256])
44+
dense_size = CategoricalHyperparam(['model','dense_size'],[128,256])
45+
extra_dense_input = CategoricalHyperparam(['model','extra_dense_input'],[False,True])
46+
equalize_classes = CategoricalHyperparam(['data','equalize_classes'],[False,True])
47+
#rnn_length = CategoricalHyperparam(['model','length'],[32,128])
48+
#tunables = [lr,lr_decay,fac,target,batch_size,dropout_prob]
49+
tunables = [lr,lr_decay,fac,target,batch_size,equalize_classes,dropout_prob]
50+
tunables += [conv_filters,conv_layers,rnn_layers,rnn_size,dense_size,extra_dense_input]
4551
tunables += [cut_ends,t_warn]
4652

4753

plasma/conf_parser.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def parameters(input_file):
9393
params['paths']['shot_files'] = [jet_carbon_wall]
9494
params['paths']['shot_files_test'] = [jet_iterlike_wall]
9595
params['paths']['use_signals_dict'] = jet_signals_0D
96+
elif params['paths']['data'] == 'jet_data_1D':
97+
params['paths']['shot_files'] = [jet_carbon_wall]
98+
params['paths']['shot_files_test'] = [jet_iterlike_wall]
99+
params['paths']['use_signals_dict'] = jet_signals_1D
96100
elif params['paths']['data'] == 'jet_carbon_data':
97101
params['paths']['shot_files'] = [jet_carbon_wall]
98102
params['paths']['shot_files_test'] = []
@@ -105,6 +109,17 @@ def parameters(input_file):
105109
params['paths']['shot_files'] = [jenkins_jet_carbon_wall]
106110
params['paths']['shot_files_test'] = [jenkins_jet_iterlike_wall]
107111
params['paths']['use_signals_dict'] = jet_signals
112+
elif params['paths']['data'] == 'jet_data_fully_defined': #jet data but with fully defined signals
113+
params['paths']['shot_files'] = [jet_carbon_wall]
114+
params['paths']['shot_files_test'] = [jet_iterlike_wall]
115+
params['paths']['use_signals_dict'] = fully_defined_signals
116+
elif params['paths']['data'] == 'jet_data_fully_defined_0D': #jet data but with fully defined signals
117+
params['paths']['shot_files'] = [jet_carbon_wall]
118+
params['paths']['shot_files_test'] = [jet_iterlike_wall]
119+
params['paths']['use_signals_dict'] = fully_defined_signals_0D
120+
121+
122+
108123
elif params['paths']['data'] == 'd3d_data':
109124
params['paths']['shot_files'] = [d3d_full]
110125
params['paths']['shot_files_test'] = []
@@ -131,25 +146,40 @@ def parameters(input_file):
131146
params['paths']['shot_files_test'] = []
132147
params['paths']['use_signals_dict'] = {'q95':q95,'li':li,'ip':ip,'lm':lm,'betan':betan,'energy':energy,'dens':dens,'pradcore':pradcore,'pradedge':pradedge,'pin':pin,'torquein':torquein,'ipdirect':ipdirect,'iptarget':iptarget,'iperr':iperr,
133148
'etemp_profile':etemp_profile ,'edens_profile':edens_profile}
134-
149+
elif params['paths']['data'] == 'd3d_data_fully_defined': #jet data but with fully defined signals
150+
params['paths']['shot_files'] = [d3d_full]
151+
params['paths']['shot_files_test'] = []
152+
params['paths']['use_signals_dict'] = fully_defined_signals
153+
elif params['paths']['data'] == 'd3d_data_fully_defined_0D': #jet data but with fully defined signals
154+
params['paths']['shot_files'] = [d3d_full]
155+
params['paths']['shot_files_test'] = []
156+
params['paths']['use_signals_dict'] = fully_defined_signals_0D
135157

136158
#cross-machine
137159
elif params['paths']['data'] == 'jet_to_d3d_data':
138-
params['paths']['shot_files'] = [jet_carbon_wall]
160+
params['paths']['shot_files'] = [jet_full]
139161
params['paths']['shot_files_test'] = [d3d_full]
140162
params['paths']['use_signals_dict'] = fully_defined_signals
141163
elif params['paths']['data'] == 'd3d_to_jet_data':
142164
params['paths']['shot_files'] = [d3d_full]
143165
params['paths']['shot_files_test'] = [jet_iterlike_wall]
144166
params['paths']['use_signals_dict'] = fully_defined_signals
145167
elif params['paths']['data'] == 'jet_to_d3d_data_0D':
146-
params['paths']['shot_files'] = [jet_carbon_wall]
168+
params['paths']['shot_files'] = [jet_full]
147169
params['paths']['shot_files_test'] = [d3d_full]
148170
params['paths']['use_signals_dict'] = fully_defined_signals_0D
149171
elif params['paths']['data'] == 'd3d_to_jet_data_0D':
150172
params['paths']['shot_files'] = [d3d_full]
151173
params['paths']['shot_files_test'] = [jet_iterlike_wall]
152174
params['paths']['use_signals_dict'] = fully_defined_signals_0D
175+
elif params['paths']['data'] == 'jet_to_d3d_data_1D':
176+
params['paths']['shot_files'] = [jet_full]
177+
params['paths']['shot_files_test'] = [d3d_full]
178+
params['paths']['use_signals_dict'] = fully_defined_signals_1D
179+
elif params['paths']['data'] == 'd3d_to_jet_data_1D':
180+
params['paths']['shot_files'] = [d3d_full]
181+
params['paths']['shot_files_test'] = [jet_iterlike_wall]
182+
params['paths']['use_signals_dict'] = fully_defined_signals_1D
153183

154184

155185

plasma/preprocessor/preprocess.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def save_shotlists(self,shot_list_train,shot_list_validate,shot_list_test):
146146

147147

148148
def apply_bleed_in(conf,shot_list_train,shot_list_validate,shot_list_test):
149-
np.random.seed(1)
149+
np.random.seed(2)
150150
num = conf['data']['bleed_in']
151151
new_shots = []
152152
if num > 0:
@@ -168,19 +168,28 @@ def apply_bleed_in(conf,shot_list_train,shot_list_validate,shot_list_test):
168168
else:
169169
num_sampled_nd += 1
170170
print("Sampled {} shots, {} disruptive, {} nondisruptive".format(num_sampled_nd+num_sampled_d,num_sampled_d,num_sampled_nd))
171-
print("Before adding: training shots: {} validation shots: {}".format(len(shot_list_train,shot_list_validate)))
171+
print("Before adding: training shots: {} validation shots: {}".format(len(shot_list_train),len(shot_list_validate)))
172172
assert(num_sampled_d == num)
173-
num_to_sample = len(shot_list_bleed)
174173
if conf['data']['bleed_in_equalize_sets']:#add bleed-in shots to training and validation set repeatedly
174+
print("Applying equalized bleed in")
175175
for shot_list_curr in [shot_list_train,shot_list_validate]:
176176
for i in range(len(shot_list_curr)):
177177
s = shot_list_bleed.sample_shot()
178178
shot_list_curr.append(s)
179+
elif conf['data']['bleed_in_repeat_fac'] > 1:
180+
repeat_fac = conf['data']['bleed_in_repeat_fac']
181+
print("Applying bleed in with repeat factor {}".format(repeat_fac))
182+
num_to_sample = int(round(repeat_fac*len(shot_list_bleed)))
183+
for i in range(num_to_sample):
184+
s = shot_list_bleed.sample_shot()
185+
shot_list_train.append(s)
186+
shot_list_validate.append(s)
179187
else: #add each shot only once
188+
print("Applying bleed in without repetition")
180189
for s in shot_list_bleed:
181190
shot_list_train.append(s)
182191
shot_list_validate.append(s)
183-
print("After adding: training shots: {} validation shots: {}".format(len(shot_list_train,shot_list_validate)))
192+
print("After adding: training shots: {} validation shots: {}".format(len(shot_list_train),len(shot_list_validate)))
184193
print("Added bleed in shots to training and validation sets")
185194
# if num_d > 0:
186195
# for i in range(num):

plasma/primitives/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def load_data(self,prepath,shot,dtype='float32'):
227227
for i in range(timesteps):
228228
_,order = np.unique(mapping[i,:],return_index=True) #make sure the mapping is ordered and unique
229229
if sig[i,order].shape[0] > 2:
230-
f = UnivariateSpline(mapping[i,order],sig[i,order],s=0,k=1,ext=0)
230+
f = UnivariateSpline(mapping[i,order],sig[i,order],s=0,k=1,ext=3) #ext = 0 is extrapolation, ext = 3 is boundary value.
231231
sig_interp[i,:] = f(remapping)
232232
else:
233233
print('Signal {}, shot {} has not enough points for linear interpolation. dfitpack.error: (m>k) failed for hidden m: fpcurf0:m=1'.format(self.description,shot.number))

plasma/utils/batch_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def create_slurm_header(num_nodes,use_mpi,idx):
112112
assert(num_nodes == 1)
113113
lines = []
114114
lines.append('#!/bin/bash\n')
115-
lines.append('#SBATCH -t 06:00:00\n')
115+
lines.append('#SBATCH -t 20:00:00\n')
116116
lines.append('#SBATCH -N '+str(num_nodes)+'\n')
117117
if use_mpi:
118118
lines.append('#SBATCH --ntasks-per-node=4\n')

0 commit comments

Comments
 (0)