1515import time
1616import datetime
1717import os
18+ from copy import deepcopy
1819from functools import partial
1920import pathos .multiprocessing as mp
2021from xgboost import XGBClassifier
@@ -284,7 +285,7 @@ def build_callbacks(conf):
284285 return cbks .CallbackList (callbacks )
285286
286287
287- def train (conf ,shot_list_train ,shot_list_validate ,loader ):
288+ def train (conf ,shot_list_train ,shot_list_validate ,loader , shot_list_test = None ):
288289
289290 np .random .seed (1 )
290291
@@ -367,6 +368,20 @@ def train(conf,shot_list_train,shot_list_validate,loader):
367368 Y_predv = model .predict (Xv )
368369 print ("Validate" )
369370 print (classification_report (Yv ,Y_predv ))
371+
372+
373+ if 'monitor_test' in conf ['callbacks' ].keys () and conf ['callbacks' ]['monitor_test' ]:
374+ times = conf ['callbacks' ]['monitor_times' ]
375+ roc_areas ,losses = make_predictions_and_evaluate_multiple_times (conf ,shot_list_validate ,loader ,times )
376+ for roc ,t in zip (roc_areas ,times ):
377+ print ('val_roc_{} = {}' .format (t ,roc ))
378+ if shot_list_test is not None :
379+ roc_areas ,losses = make_predictions_and_evaluate_multiple_times (conf ,shot_list_test ,loader ,times )
380+ for roc ,t in zip (roc_areas ,times ):
381+ print ('test_roc_{} = {}' .format (t ,roc ))
382+
383+
384+
370385 #print(confusion_matrix(Y,Y_pred))
371386 _ ,_ ,_ ,roc_area ,loss = make_predictions_and_evaluate_gpu (conf ,shot_list_validate ,loader )
372387 # _,_,_,roc_area_train,loss_train = make_predictions_and_evaluate_gpu(conf,shot_list_train,loader)
@@ -378,6 +393,8 @@ def train(conf,shot_list_train,shot_list_validate,loader):
378393 epoch_logs ['val_loss' ] = loss
379394 # epoch_logs['train_roc'] = roc_area_train
380395 # epoch_logs['train_loss'] = loss_train
396+
397+
381398 callbacks .on_epoch_end (0 , epoch_logs )
382399
383400
@@ -432,3 +449,20 @@ def make_predictions_and_evaluate_gpu(conf,shot_list,loader,custom_path = None):
432449 loss = get_loss_from_list (y_prime ,y_gold ,conf ['data' ]['target' ])
433450 return y_prime ,y_gold ,disruptive ,roc_area ,loss
434451
452+ def make_predictions_and_evaluate_multiple_times (conf ,shot_list ,loader ,times ,custom_path = None ):
453+ y_prime ,y_gold ,disruptive = make_predictions (conf ,shot_list ,loader ,custom_path )
454+ areas = []
455+ losses = []
456+ for T_min_curr in times :
457+ #if 'monitor_test' in conf['callbacks'].keys() and conf['callbacks']['monitor_test']:
458+ conf_curr = deepcopy (conf )
459+ T_min_warn_orig = conf ['data' ]['T_min_warn' ]
460+ conf_curr ['data' ]['T_min_warn' ] = T_min_curr
461+ assert (conf ['data' ]['T_min_warn' ] == T_min_warn_orig )
462+ analyzer = PerformanceAnalyzer (conf = conf_curr )
463+ roc_area = analyzer .get_roc_area (y_prime ,y_gold ,disruptive )
464+ #shot_list.set_weights(analyzer.get_shot_difficulty(y_prime,y_gold,disruptive))
465+ loss = get_loss_from_list (y_prime ,y_gold ,conf ['data' ]['target' ])
466+ areas .append (roc_area )
467+ losses .append (loss )
468+ return areas ,losses
0 commit comments