Skip to content

Commit 374e069

Browse files
Julian Kates-HarbeckJulian Kates-Harbeck
authored andcommitted
not saving from within preprocessor class
2 parents 480e303 + 8ce98dc commit 374e069

File tree

4 files changed

+89
-76
lines changed

4 files changed

+89
-76
lines changed

examples/mpi_learn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@
8686
loader = Loader(conf,normalizer)
8787
print("...done")
8888

89+
#ensure training has a separate random seed for every worker
90+
np.random.seed(task_index)
91+
random.seed(task_index)
8992
if not only_predict:
9093
mpi_train(conf,shot_list_train,shot_list_validate,loader)
9194

plasma/models/mpi_runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import time
2121
import datetime
2222
import numpy as np
23+
import random
2324

2425
from functools import partial
2526
import socket
@@ -170,7 +171,8 @@ def get_val(self):
170171

171172
class MPIModel():
172173
def __init__(self,model,optimizer,comm,batch_iterator,batch_size,num_replicas=None,warmup_steps=1000,lr=0.01,num_batches_minimum=100):
173-
# random.seed(task_index)
174+
random.seed(task_index)
175+
np.random.seed(task_index)
174176
self.epoch = 0
175177
self.num_so_far = 0
176178
self.num_so_far_accum = 0
@@ -640,6 +642,7 @@ def mpi_make_predictions_and_evaluate(conf,shot_list,loader,custom_path=None):
640642

641643

642644
def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=None):
645+
643646
loader.set_inference_mode(False)
644647
conf['num_workers'] = comm.Get_size()
645648

plasma/preprocessor/normalize.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ def get_individual_shot_file(prepath,shot_num,ext='.txt'):
418418

419419
def apply_positivity(shot):
420420
for (i,sig) in enumerate(shot.signals):
421-
if sig.is_strictly_positive:
422-
#print ('Applying positivity constraint to {} signal'.format(sig.description))
423-
shot.signals_dict[sig]=np.clip(shot.signals_dict[sig],0,np.inf)
421+
if hasattr(sig,"is_strictly_positive"): #backwards compatibility when this attribute didn't exist
422+
if sig.is_strictly_positive:
423+
#print ('Applying positivity constraint to {} signal'.format(sig.description))
424+
shot.signals_dict[sig]=np.clip(shot.signals_dict[sig],0,np.inf)

plasma/utils/performance.py

Lines changed: 78 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from plasma.preprocessor.normalize import VarNormalizer as Normalizer
1414
from plasma.conf import conf
15-
from plasma.primitives.shots import Shot
15+
from plasma.primitives.shots import Shot,ShotList
1616

1717
class PerformanceAnalyzer():
1818
def __init__(self,results_dir=None,shots_dir=None,i = 0,T_min_warn = None,T_max_warn = None, verbose = False,pred_ttd=False,conf=None):
@@ -342,8 +342,8 @@ def load_ith_file(self):
342342
self.pred_test = dat['y_prime_test']
343343
self.truth_test = dat['y_gold_test']
344344
self.disruptive_test = dat['disruptive_test']
345-
self.shot_list_test = dat['shot_list_test'][()]
346-
self.shot_list_train = dat['shot_list_train'][()]
345+
self.shot_list_test = ShotList(dat['shot_list_test'][()])
346+
self.shot_list_train = ShotList(dat['shot_list_train'][()])
347347
self.saved_conf = dat['conf'][()]
348348
self.conf['data']['T_warning'] = self.saved_conf['data']['T_warning'] #all files must agree on T_warning due to output of truth vs. normalized shot ttd.
349349
for mode in ['test','train']:
@@ -656,78 +656,84 @@ def plot_shot(self,shot,save_fig=True,normalize=True,truth=None,prediction=None,
656656

657657
if(shot.previously_saved(self.shots_dir)):
658658
shot.restore(self.shots_dir)
659-
t_disrupt = shot.t_disrupt
660-
is_disruptive = shot.is_disruptive
661-
if normalize:
662-
self.normalizer.apply(shot)
659+
if shot.signals_dict is not None: #make sure shot was saved with data
660+
t_disrupt = shot.t_disrupt
661+
is_disruptive = shot.is_disruptive
662+
if normalize:
663+
self.normalizer.apply(shot)
663664

664-
use_signals = self.saved_conf['paths']['use_signals']
665-
fontsize= 15
666-
lower_lim = 0 #len(pred)
667-
plt.close()
668-
colors = ["b","k"]
669-
lss = ["-","--"]
670-
f,axarr = plt.subplots(len(use_signals)+1,1,sharex=True,figsize=(10,15))#, squeeze=False)
671-
plt.title(prediction_type)
672-
assert(np.all(shot.ttd.flatten() == truth.flatten()))
673-
xx = range(len(prediction)) #list(reversed(range(len(pred))))
674-
for i,sig in enumerate(use_signals):
675-
ax = axarr[i]
676-
num_channels = sig.num_channels
677-
sig_arr = shot.signals_dict[sig]
678-
if num_channels == 1:
679-
# if j == 0:
680-
ax.plot(xx,sig_arr[:,0],linewidth=2)#,linestyle=lss[j],color=colors[j])
681-
# else:
682-
# ax.plot(xx,sig_arr[:,0],linewidth=2)#,linestyle=lss[j],color=colors[j],label = labels[sig])
683-
ax.plot([],linestyle="none",label = sig.description)#labels[sig])
684-
if np.min(sig_arr[:,0]) < 0:
685-
ax.set_ylim([-6,6])
686-
ax.set_yticks([-5,0,5])
665+
use_signals = self.saved_conf['paths']['use_signals']
666+
fontsize= 15
667+
lower_lim = 0 #len(pred)
668+
plt.close()
669+
colors = ["b","k"]
670+
lss = ["-","--"]
671+
f,axarr = plt.subplots(len(use_signals)+1,1,sharex=True,figsize=(10,15))#, squeeze=False)
672+
plt.title(prediction_type)
673+
assert(np.all(shot.ttd.flatten() == truth.flatten()))
674+
xx = range(len(prediction)) #list(reversed(range(len(pred))))
675+
for i,sig in enumerate(use_signals):
676+
ax = axarr[i]
677+
num_channels = sig.num_channels
678+
sig_arr = shot.signals_dict[sig]
679+
if num_channels == 1:
680+
# if j == 0:
681+
ax.plot(xx,sig_arr[:,0],linewidth=2)#,linestyle=lss[j],color=colors[j])
682+
# else:
683+
# ax.plot(xx,sig_arr[:,0],linewidth=2)#,linestyle=lss[j],color=colors[j],label = labels[sig])
684+
ax.plot([],linestyle="none",label = sig.description)#labels[sig])
685+
if np.min(sig_arr[:,0]) < 0:
686+
ax.set_ylim([-6,6])
687+
ax.set_yticks([-5,0,5])
688+
# ax.plot(xx,sig_arr[:,0],linewidth=2)#,linestyle=lss[j],color=colors[j],label = labels[sig])
689+
ax.plot([],linestyle="none",label = sig.description)#labels[sig])
690+
if np.min(sig_arr[:,0]) < 0:
691+
ax.set_ylim([-6,6])
692+
ax.set_yticks([-5,0,5])
693+
else:
694+
ax.set_ylim([0,8])
695+
ax.set_yticks([0,5])
696+
# ax.set_ylabel(labels[sig],size=fontsize)
687697
else:
688-
ax.set_ylim([0,8])
689-
ax.set_yticks([0,5])
690-
# ax.set_ylabel(labels[sig],size=fontsize)
691-
else:
692-
ax.imshow(sig_arr[:,:].T, aspect='auto', label = sig.description,cmap="inferno" )
693-
ax.set_ylim([0,num_channels])
694-
ax.text(lower_lim+200, 45, sig.description, bbox={'facecolor': 'white', 'pad': 10},fontsize=fontsize-5)
695-
ax.set_yticks([0,num_channels/2])
696-
ax.set_yticklabels(["0","0.5"])
697-
ax.set_ylabel("$\\rho$",size=fontsize)
698-
ax.legend(loc="best",labelspacing=0.1,fontsize=fontsize,frameon=False)
699-
ax.axvline(len(truth)-self.T_min_warn,color='r',linewidth=0.5)
700-
plt.setp(ax.get_xticklabels(),visible=False)
698+
ax.imshow(sig_arr[:,:].T, aspect='auto', label = sig.description,cmap="inferno" )
699+
ax.set_ylim([0,num_channels])
700+
ax.text(lower_lim+200, 45, sig.description, bbox={'facecolor': 'white', 'pad': 10},fontsize=fontsize-5)
701+
ax.set_yticks([0,num_channels/2])
702+
ax.set_yticklabels(["0","0.5"])
703+
ax.set_ylabel("$\\rho$",size=fontsize)
704+
ax.legend(loc="best",labelspacing=0.1,fontsize=fontsize,frameon=False)
705+
ax.axvline(len(truth)-self.T_min_warn,color='r',linewidth=0.5)
706+
plt.setp(ax.get_xticklabels(),visible=False)
707+
plt.setp(ax.get_yticklabels(),fontsize=fontsize)
708+
f.subplots_adjust(hspace=0)
709+
#print(sig)
710+
#print('min: {}, max: {}'.format(np.min(sig_arr), np.max(sig_arr)))
711+
ax = axarr[-1]
712+
# ax.semilogy((-truth+0.0001),label='ground truth')
713+
# ax.plot(-prediction+0.0001,'g',label='neural net prediction')
714+
# ax.axhline(-P_thresh_opt,color='k',label='trigger threshold')
715+
# nn = np.min(pred)
716+
ax.plot(xx,truth,'g',label='target',linewidth=2)
717+
# ax.axhline(0.4,linestyle="--",color='k',label='threshold')
718+
ax.plot(xx,prediction,'b',label='RNN output',linewidth=2)
719+
ax.axhline(P_thresh_opt,linestyle="--",color='k',label='threshold')
720+
ax.set_ylim([-2,2])
721+
ax.set_yticks([-1,0,1])
722+
# if len(truth)-T_max_warn >= 0:
723+
# ax.axvline(len(truth)-T_max_warn,color='r')#,label='max warning time')
724+
ax.axvline(len(truth)-self.T_min_warn,color='r',linewidth=0.5)#,label='min warning time')
725+
ax.set_xlabel('T [ms]',size=fontsize)
726+
# ax.axvline(2400)
727+
ax.legend(loc = (0.5,0.7),fontsize=fontsize-5,labelspacing=0.1,frameon=False)
701728
plt.setp(ax.get_yticklabels(),fontsize=fontsize)
702-
f.subplots_adjust(hspace=0)
703-
#print(sig)
704-
#print('min: {}, max: {}'.format(np.min(sig_arr), np.max(sig_arr)))
705-
ax = axarr[-1]
706-
# ax.semilogy((-truth+0.0001),label='ground truth')
707-
# ax.plot(-prediction+0.0001,'g',label='neural net prediction')
708-
# ax.axhline(-P_thresh_opt,color='k',label='trigger threshold')
709-
# nn = np.min(pred)
710-
ax.plot(xx,truth,'g',label='target',linewidth=2)
711-
# ax.axhline(0.4,linestyle="--",color='k',label='threshold')
712-
ax.plot(xx,prediction,'b',label='RNN output',linewidth=2)
713-
ax.axhline(P_thresh_opt,linestyle="--",color='k',label='threshold')
714-
ax.set_ylim([-2,2])
715-
ax.set_yticks([-1,0,1])
716-
# if len(truth)-T_max_warn >= 0:
717-
# ax.axvline(len(truth)-T_max_warn,color='r')#,label='max warning time')
718-
ax.axvline(len(truth)-self.T_min_warn,color='r',linewidth=0.5)#,label='min warning time')
719-
ax.set_xlabel('T [ms]',size=fontsize)
720-
# ax.axvline(2400)
721-
ax.legend(loc = (0.5,0.7),fontsize=fontsize-5,labelspacing=0.1,frameon=False)
722-
plt.setp(ax.get_yticklabels(),fontsize=fontsize)
723-
plt.setp(ax.get_xticklabels(),fontsize=fontsize)
724-
# plt.xlim(0,200)
725-
plt.xlim([lower_lim,len(truth)])
726-
# plt.savefig("{}.png".format(num),dpi=200,bbox_inches="tight")
727-
if save_fig:
728-
plt.savefig('sig_fig_{}{}.png'.format(shot.number,extra_filename),bbox_inches='tight')
729-
np.savez('sig_{}{}.npz'.format(shot.number,extra_filename),shot=shot,T_min_warn=self.T_min_warn,T_max_warn=self.T_max_warn,prediction=prediction,truth=truth,use_signals=use_signals,P_thresh=P_thresh_opt)
730-
#plt.show()
729+
plt.setp(ax.get_xticklabels(),fontsize=fontsize)
730+
# plt.xlim(0,200)
731+
plt.xlim([lower_lim,len(truth)])
732+
# plt.savefig("{}.png".format(num),dpi=200,bbox_inches="tight")
733+
if save_fig:
734+
plt.savefig('sig_fig_{}{}.png'.format(shot.number,extra_filename),bbox_inches='tight')
735+
np.savez('sig_{}{}.npz'.format(shot.number,extra_filename),shot=shot,T_min_warn=self.T_min_warn,T_max_warn=self.T_max_warn,prediction=prediction,truth=truth,use_signals=use_signals,P_thresh=P_thresh_opt)
736+
#plt.show()
731737
else:
732738
print("Shot hasn't been processed")
733739

0 commit comments

Comments
 (0)