@@ -94,18 +94,20 @@ def parameters(input_file):
9494 elif params ['paths' ]['data' ] == 'd3d_data' :
9595 params ['paths' ]['shot_files' ] = [d3d_full ]
9696 params ['paths' ]['shot_files_test' ] = []
97- params ['paths' ]['use_signals_dict' ] = {'q95' :q95 ,'li' :li ,'ip' :ip ,'lm' :lm ,'betan' :betan ,'energy' :energy ,'dens' :dens ,'pradcore' :pradcore ,
98- 'pradedge' :pradedge ,'pin' :pin ,'torquein' :torquein ,'ipdirect' :ipdirect ,'iptarget' :iptarget ,'iperr' :iperr ,
99- 'etemp_profile' :etemp_profile ,'edens_profile' :edens_profile }
97+ params ['paths' ]['use_signals_dict' ] = {'q95' :q95 ,'li' :li ,'ip' :ip ,'lm' :lm ,'betan' :betan ,'energy' :energy ,'dens' :dens ,'pradcore' :pradcore ,'pradedge' :pradedge ,'pin' :pin ,'torquein' :torquein ,'ipdirect' :ipdirect ,'iptarget' :iptarget ,'iperr' :iperr ,
98+ 'etemp_profile' :etemp_profile ,'edens_profile' :edens_profile }
10099 elif params ['paths' ]['data' ] == 'd3d_data_1D' :
101100 params ['paths' ]['shot_files' ] = [d3d_full ]
102101 params ['paths' ]['shot_files_test' ] = []
103102 params ['paths' ]['use_signals_dict' ] = {'ipdirect' :ipdirect ,'etemp_profile' :etemp_profile ,'edens_profile' :edens_profile }
103+ elif params ['paths' ]['data' ] == 'd3d_data_all_profiles' :
104+ params ['paths' ]['shot_files' ] = [d3d_full ]
105+ params ['paths' ]['shot_files_test' ] = []
106+ params ['paths' ]['use_signals_dict' ] = {'ipdirect' :ipdirect ,'etemp_profile' :etemp_profile ,'edens_profile' :edens_profile ,'itemp_profile' :itemp_profile ,'zdens_profile' :zdens_profile ,'trot_profile' :trot_profile ,'pthm_profile' :pthm_profile ,'neut_profile' :neut_profile ,'q_profile' :q_profile ,'bootstrap_current_profile' :bootstrap_current_profile ,'q_psi_profile' :q_psi_profile }
104107 elif params ['paths' ]['data' ] == 'd3d_data_0D' :
105108 params ['paths' ]['shot_files' ] = [d3d_full ]
106109 params ['paths' ]['shot_files_test' ] = []
107- params ['paths' ]['use_signals_dict' ] = {'q95' :q95 ,'li' :li ,'ip' :ip ,'lm' :lm ,'betan' :betan ,'energy' :energy ,'dens' :dens ,'pradcore' :pradcore ,
108- 'pradedge' :pradedge ,'pin' :pin ,'torquein' :torquein ,'ipdirect' :ipdirect ,'iptarget' :iptarget ,'iperr' :iperr }
110+ params ['paths' ]['use_signals_dict' ] = {'q95' :q95 ,'li' :li ,'ip' :ip ,'lm' :lm ,'betan' :betan ,'energy' :energy ,'dens' :dens ,'pradcore' :pradcore ,'pradedge' :pradedge ,'pin' :pin ,'torquein' :torquein ,'ipdirect' :ipdirect ,'iptarget' :iptarget ,'iperr' :iperr }
109111 elif params ['paths' ]['data' ] == 'd3d_data_all' :
110112 params ['paths' ]['shot_files' ] = [d3d_full ]
111113 params ['paths' ]['shot_files_test' ] = []
@@ -130,14 +132,13 @@ def parameters(input_file):
130132 print ("Signal {} is not fully defined for {} machine. Skipping..." .format (sig ,params ['paths' ]['data' ].split ("_" )[0 ]))
131133 params ['paths' ]['specific_signals' ] = list (filter (lambda x : x in params ['paths' ]['use_signals_dict' ].keys (), params ['paths' ]['specific_signals' ]))
132134 selected_signals = {k : params ['paths' ]['use_signals_dict' ][k ] for k in params ['paths' ]['specific_signals' ]}
133- params ['paths' ]['use_signals' ] = list (selected_signals .values ())
135+ params ['paths' ]['use_signals' ] = sort_by_channels ( list (selected_signals .values () ))
134136
135- selected_signals = {k : fully_defined_signals [k ] for k in params ['paths' ]['specific_signals' ]}
136- params ['paths' ]['all_signals' ] = list (selected_signals .values ())
137137 else :
138138 #default case
139- params ['paths' ]['use_signals' ] = list (params ['paths' ]['use_signals_dict' ].values ())
140- params ['paths' ]['all_signals' ] = list (fully_defined_signals .values ())
139+ params ['paths' ]['use_signals' ] = sort_by_channels (list (params ['paths' ]['use_signals_dict' ].values ()))
140+
141+ params ['paths' ]['all_signals' ] = sort_by_channels (list (params ['paths' ]['all_signals_dict' ].values ()))
141142
142143 print ("Selected signals (determines which signals training is run on):\n {}" .format (params ['paths' ]['use_signals' ]))
143144
@@ -153,3 +154,7 @@ def parameters(input_file):
153154def get_unique_signal_hash (signals ):
154155 return hash (tuple (sorted (signals )))
155156
157+ #make sure 1D signals come last! This is necessary for model builder.
158+ def sort_by_channels (list_of_signals ):
159+ return sorted (list_of_signals ,key = lambda x : x .num_channels )
160+
0 commit comments