|
12 | 12 |
|
13 | 13 | from plasma.preprocessor.normalize import VarNormalizer as Normalizer |
14 | 14 | from plasma.conf import conf |
15 | | -from plasma.primitives.shots import Shot |
| 15 | +from plasma.primitives.shots import Shot,ShotList |
16 | 16 |
|
17 | 17 | class PerformanceAnalyzer(): |
18 | 18 | def __init__(self,results_dir=None,shots_dir=None,i = 0,T_min_warn = None,T_max_warn = None, verbose = False,pred_ttd=False,conf=None): |
@@ -342,8 +342,8 @@ def load_ith_file(self): |
342 | 342 | self.pred_test = dat['y_prime_test'] |
343 | 343 | self.truth_test = dat['y_gold_test'] |
344 | 344 | self.disruptive_test = dat['disruptive_test'] |
345 | | - self.shot_list_test = dat['shot_list_test'][()] |
346 | | - self.shot_list_train = dat['shot_list_train'][()] |
| 345 | + self.shot_list_test = ShotList(dat['shot_list_test'][()]) |
| 346 | + self.shot_list_train = ShotList(dat['shot_list_train'][()]) |
347 | 347 | self.saved_conf = dat['conf'][()] |
348 | 348 | self.conf['data']['T_warning'] = self.saved_conf['data']['T_warning'] #all files must agree on T_warning due to output of truth vs. normalized shot ttd. |
349 | 349 | for mode in ['test','train']: |
@@ -656,78 +656,84 @@ def plot_shot(self,shot,save_fig=True,normalize=True,truth=None,prediction=None, |
656 | 656 |
|
657 | 657 | if(shot.previously_saved(self.shots_dir)): |
658 | 658 | shot.restore(self.shots_dir) |
659 | | - t_disrupt = shot.t_disrupt |
660 | | - is_disruptive = shot.is_disruptive |
661 | | - if normalize: |
662 | | - self.normalizer.apply(shot) |
| 659 | + if shot.signals_dict is not None: #make sure shot was saved with data |
| 660 | + t_disrupt = shot.t_disrupt |
| 661 | + is_disruptive = shot.is_disruptive |
| 662 | + if normalize: |
| 663 | + self.normalizer.apply(shot) |
663 | 664 |
|
664 | | - use_signals = self.saved_conf['paths']['use_signals'] |
665 | | - fontsize= 15 |
666 | | - lower_lim = 0 #len(pred) |
667 | | - plt.close() |
668 | | - colors = ["b","k"] |
669 | | - lss = ["-","--"] |
670 | | - f,axarr = plt.subplots(len(use_signals)+1,1,sharex=True,figsize=(10,15))#, squeeze=False) |
671 | | - plt.title(prediction_type) |
672 | | - assert(np.all(shot.ttd.flatten() == truth.flatten())) |
673 | | - xx = range(len(prediction)) #list(reversed(range(len(pred)))) |
674 | | - for i,sig in enumerate(use_signals): |
675 | | - ax = axarr[i] |
676 | | - num_channels = sig.num_channels |
677 | | - sig_arr = shot.signals_dict[sig] |
678 | | - if num_channels == 1: |
679 | | - # if j == 0: |
680 | | - ax.plot(xx,sig_arr[:,0],linewidth=2)#,linestyle=lss[j],color=colors[j]) |
681 | | - # else: |
682 | | - # ax.plot(xx,sig_arr[:,0],linewidth=2)#,linestyle=lss[j],color=colors[j],label = labels[sig]) |
683 | | - ax.plot([],linestyle="none",label = sig.description)#labels[sig]) |
684 | | - if np.min(sig_arr[:,0]) < 0: |
685 | | - ax.set_ylim([-6,6]) |
686 | | - ax.set_yticks([-5,0,5]) |
| 665 | + use_signals = self.saved_conf['paths']['use_signals'] |
| 666 | + fontsize= 15 |
| 667 | + lower_lim = 0 #len(pred) |
| 668 | + plt.close() |
| 669 | + colors = ["b","k"] |
| 670 | + lss = ["-","--"] |
| 671 | + f,axarr = plt.subplots(len(use_signals)+1,1,sharex=True,figsize=(10,15))#, squeeze=False) |
| 672 | + plt.title(prediction_type) |
| 673 | + assert(np.all(shot.ttd.flatten() == truth.flatten())) |
| 674 | + xx = range(len(prediction)) #list(reversed(range(len(pred)))) |
| 675 | + for i,sig in enumerate(use_signals): |
| 676 | + ax = axarr[i] |
| 677 | + num_channels = sig.num_channels |
| 678 | + sig_arr = shot.signals_dict[sig] |
| 679 | + if num_channels == 1: |
| 680 | + # if j == 0: |
| 681 | + ax.plot(xx,sig_arr[:,0],linewidth=2)#,linestyle=lss[j],color=colors[j]) |
| 682 | + # else: |
| 683 | + # ax.plot(xx,sig_arr[:,0],linewidth=2)#,linestyle=lss[j],color=colors[j],label = labels[sig]) |
| 684 | + ax.plot([],linestyle="none",label = sig.description)#labels[sig]) |
| 685 | + if np.min(sig_arr[:,0]) < 0: |
| 686 | + ax.set_ylim([-6,6]) |
| 687 | + ax.set_yticks([-5,0,5]) |
| 688 | + # ax.plot(xx,sig_arr[:,0],linewidth=2)#,linestyle=lss[j],color=colors[j],label = labels[sig]) |
| 689 | + ax.plot([],linestyle="none",label = sig.description)#labels[sig]) |
| 690 | + if np.min(sig_arr[:,0]) < 0: |
| 691 | + ax.set_ylim([-6,6]) |
| 692 | + ax.set_yticks([-5,0,5]) |
| 693 | + else: |
| 694 | + ax.set_ylim([0,8]) |
| 695 | + ax.set_yticks([0,5]) |
| 696 | + # ax.set_ylabel(labels[sig],size=fontsize) |
687 | 697 | else: |
688 | | - ax.set_ylim([0,8]) |
689 | | - ax.set_yticks([0,5]) |
690 | | - # ax.set_ylabel(labels[sig],size=fontsize) |
691 | | - else: |
692 | | - ax.imshow(sig_arr[:,:].T, aspect='auto', label = sig.description,cmap="inferno" ) |
693 | | - ax.set_ylim([0,num_channels]) |
694 | | - ax.text(lower_lim+200, 45, sig.description, bbox={'facecolor': 'white', 'pad': 10},fontsize=fontsize-5) |
695 | | - ax.set_yticks([0,num_channels/2]) |
696 | | - ax.set_yticklabels(["0","0.5"]) |
697 | | - ax.set_ylabel("$\\rho$",size=fontsize) |
698 | | - ax.legend(loc="best",labelspacing=0.1,fontsize=fontsize,frameon=False) |
699 | | - ax.axvline(len(truth)-self.T_min_warn,color='r',linewidth=0.5) |
700 | | - plt.setp(ax.get_xticklabels(),visible=False) |
| 698 | + ax.imshow(sig_arr[:,:].T, aspect='auto', label = sig.description,cmap="inferno" ) |
| 699 | + ax.set_ylim([0,num_channels]) |
| 700 | + ax.text(lower_lim+200, 45, sig.description, bbox={'facecolor': 'white', 'pad': 10},fontsize=fontsize-5) |
| 701 | + ax.set_yticks([0,num_channels/2]) |
| 702 | + ax.set_yticklabels(["0","0.5"]) |
| 703 | + ax.set_ylabel("$\\rho$",size=fontsize) |
| 704 | + ax.legend(loc="best",labelspacing=0.1,fontsize=fontsize,frameon=False) |
| 705 | + ax.axvline(len(truth)-self.T_min_warn,color='r',linewidth=0.5) |
| 706 | + plt.setp(ax.get_xticklabels(),visible=False) |
| 707 | + plt.setp(ax.get_yticklabels(),fontsize=fontsize) |
| 708 | + f.subplots_adjust(hspace=0) |
| 709 | + #print(sig) |
| 710 | + #print('min: {}, max: {}'.format(np.min(sig_arr), np.max(sig_arr))) |
| 711 | + ax = axarr[-1] |
| 712 | + # ax.semilogy((-truth+0.0001),label='ground truth') |
| 713 | + # ax.plot(-prediction+0.0001,'g',label='neural net prediction') |
| 714 | + # ax.axhline(-P_thresh_opt,color='k',label='trigger threshold') |
| 715 | + # nn = np.min(pred) |
| 716 | + ax.plot(xx,truth,'g',label='target',linewidth=2) |
| 717 | + # ax.axhline(0.4,linestyle="--",color='k',label='threshold') |
| 718 | + ax.plot(xx,prediction,'b',label='RNN output',linewidth=2) |
| 719 | + ax.axhline(P_thresh_opt,linestyle="--",color='k',label='threshold') |
| 720 | + ax.set_ylim([-2,2]) |
| 721 | + ax.set_yticks([-1,0,1]) |
| 722 | + # if len(truth)-T_max_warn >= 0: |
| 723 | + # ax.axvline(len(truth)-T_max_warn,color='r')#,label='max warning time') |
| 724 | + ax.axvline(len(truth)-self.T_min_warn,color='r',linewidth=0.5)#,label='min warning time') |
| 725 | + ax.set_xlabel('T [ms]',size=fontsize) |
| 726 | + # ax.axvline(2400) |
| 727 | + ax.legend(loc = (0.5,0.7),fontsize=fontsize-5,labelspacing=0.1,frameon=False) |
701 | 728 | plt.setp(ax.get_yticklabels(),fontsize=fontsize) |
702 | | - f.subplots_adjust(hspace=0) |
703 | | - #print(sig) |
704 | | - #print('min: {}, max: {}'.format(np.min(sig_arr), np.max(sig_arr))) |
705 | | - ax = axarr[-1] |
706 | | - # ax.semilogy((-truth+0.0001),label='ground truth') |
707 | | - # ax.plot(-prediction+0.0001,'g',label='neural net prediction') |
708 | | - # ax.axhline(-P_thresh_opt,color='k',label='trigger threshold') |
709 | | - # nn = np.min(pred) |
710 | | - ax.plot(xx,truth,'g',label='target',linewidth=2) |
711 | | - # ax.axhline(0.4,linestyle="--",color='k',label='threshold') |
712 | | - ax.plot(xx,prediction,'b',label='RNN output',linewidth=2) |
713 | | - ax.axhline(P_thresh_opt,linestyle="--",color='k',label='threshold') |
714 | | - ax.set_ylim([-2,2]) |
715 | | - ax.set_yticks([-1,0,1]) |
716 | | - # if len(truth)-T_max_warn >= 0: |
717 | | - # ax.axvline(len(truth)-T_max_warn,color='r')#,label='max warning time') |
718 | | - ax.axvline(len(truth)-self.T_min_warn,color='r',linewidth=0.5)#,label='min warning time') |
719 | | - ax.set_xlabel('T [ms]',size=fontsize) |
720 | | - # ax.axvline(2400) |
721 | | - ax.legend(loc = (0.5,0.7),fontsize=fontsize-5,labelspacing=0.1,frameon=False) |
722 | | - plt.setp(ax.get_yticklabels(),fontsize=fontsize) |
723 | | - plt.setp(ax.get_xticklabels(),fontsize=fontsize) |
724 | | - # plt.xlim(0,200) |
725 | | - plt.xlim([lower_lim,len(truth)]) |
726 | | - # plt.savefig("{}.png".format(num),dpi=200,bbox_inches="tight") |
727 | | - if save_fig: |
728 | | - plt.savefig('sig_fig_{}{}.png'.format(shot.number,extra_filename),bbox_inches='tight') |
729 | | - np.savez('sig_{}{}.npz'.format(shot.number,extra_filename),shot=shot,T_min_warn=self.T_min_warn,T_max_warn=self.T_max_warn,prediction=prediction,truth=truth,use_signals=use_signals,P_thresh=P_thresh_opt) |
730 | | - #plt.show() |
| 729 | + plt.setp(ax.get_xticklabels(),fontsize=fontsize) |
| 730 | + # plt.xlim(0,200) |
| 731 | + plt.xlim([lower_lim,len(truth)]) |
| 732 | + # plt.savefig("{}.png".format(num),dpi=200,bbox_inches="tight") |
| 733 | + if save_fig: |
| 734 | + plt.savefig('sig_fig_{}{}.png'.format(shot.number,extra_filename),bbox_inches='tight') |
| 735 | + np.savez('sig_{}{}.npz'.format(shot.number,extra_filename),shot=shot,T_min_warn=self.T_min_warn,T_max_warn=self.T_max_warn,prediction=prediction,truth=truth,use_signals=use_signals,P_thresh=P_thresh_opt) |
| 736 | + #plt.show() |
731 | 737 | else: |
732 | 738 | print("Shot hasn't been processed") |
733 | 739 |
|
|
0 commit comments