Skip to content

Commit 0c2571a

Browse files
author
Julian Kates-Harbeck
committed
Merge branch 'jdev' of https://github.com/PPPLDeepLearning/plasma-python into jdev
2 parents 638306b + a083b94 commit 0c2571a

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

examples/learn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
#####################################################
9494
#train(conf,shot_list_train,loader)
9595
if not only_predict:
96-
p = old_mp.Process(target = train,args=(conf,shot_list_train,shot_list_validate,loader))
96+
p = old_mp.Process(target = train,args=(conf,shot_list_train,shot_list_validate,loader,shot_list_test))
9797
p.start()
9898
p.join()
9999

plasma/models/runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
backend = conf['model']['backend']
2626

27-
def train(conf,shot_list_train,shot_list_validate,loader):
27+
def train(conf,shot_list_train,shot_list_validate,loader,shot_list_test=None):
2828
loader.set_inference_mode(False)
2929
np.random.seed(1)
3030

plasma/models/shallow_runner.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import time
1616
import datetime
1717
import os
18+
from copy import deepcopy
1819
from functools import partial
1920
import pathos.multiprocessing as mp
2021
from xgboost import XGBClassifier
@@ -284,7 +285,7 @@ def build_callbacks(conf):
284285
return cbks.CallbackList(callbacks)
285286

286287

287-
def train(conf,shot_list_train,shot_list_validate,loader):
288+
def train(conf,shot_list_train,shot_list_validate,loader,shot_list_test=None):
288289

289290
np.random.seed(1)
290291

@@ -367,6 +368,20 @@ def train(conf,shot_list_train,shot_list_validate,loader):
367368
Y_predv = model.predict(Xv)
368369
print("Validate")
369370
print(classification_report(Yv,Y_predv))
371+
372+
373+
if 'monitor_test' in conf['callbacks'].keys() and conf['callbacks']['monitor_test']:
374+
times = conf['callbacks']['monitor_times']
375+
roc_areas,losses = make_predictions_and_evaluate_multiple_times(conf,shot_list_validate,loader,times)
376+
for roc,t in zip(roc_areas,times):
377+
print('val_roc_{} = {}'.format(t,roc))
378+
if shot_list_test is not None:
379+
roc_areas,losses = make_predictions_and_evaluate_multiple_times(conf,shot_list_test,loader,times)
380+
for roc,t in zip(roc_areas,times):
381+
print('test_roc_{} = {}'.format(t,roc))
382+
383+
384+
370385
#print(confusion_matrix(Y,Y_pred))
371386
_,_,_,roc_area,loss = make_predictions_and_evaluate_gpu(conf,shot_list_validate,loader)
372387
# _,_,_,roc_area_train,loss_train = make_predictions_and_evaluate_gpu(conf,shot_list_train,loader)
@@ -378,6 +393,8 @@ def train(conf,shot_list_train,shot_list_validate,loader):
378393
epoch_logs['val_loss'] = loss
379394
# epoch_logs['train_roc'] = roc_area_train
380395
# epoch_logs['train_loss'] = loss_train
396+
397+
381398
callbacks.on_epoch_end(0, epoch_logs)
382399

383400

@@ -432,3 +449,20 @@ def make_predictions_and_evaluate_gpu(conf,shot_list,loader,custom_path = None):
432449
loss = get_loss_from_list(y_prime,y_gold,conf['data']['target'])
433450
return y_prime,y_gold,disruptive,roc_area,loss
434451

452+
def make_predictions_and_evaluate_multiple_times(conf,shot_list,loader,times,custom_path=None):
453+
y_prime,y_gold,disruptive = make_predictions(conf,shot_list,loader,custom_path)
454+
areas = []
455+
losses = []
456+
for T_min_curr in times:
457+
#if 'monitor_test' in conf['callbacks'].keys() and conf['callbacks']['monitor_test']:
458+
conf_curr = deepcopy(conf)
459+
T_min_warn_orig = conf['data']['T_min_warn']
460+
conf_curr['data']['T_min_warn'] = T_min_curr
461+
assert(conf['data']['T_min_warn'] == T_min_warn_orig)
462+
analyzer = PerformanceAnalyzer(conf=conf_curr)
463+
roc_area = analyzer.get_roc_area(y_prime,y_gold,disruptive)
464+
#shot_list.set_weights(analyzer.get_shot_difficulty(y_prime,y_gold,disruptive))
465+
loss = get_loss_from_list(y_prime,y_gold,conf['data']['target'])
466+
areas.append(roc_area)
467+
losses.append(loss)
468+
return areas,losses

0 commit comments

Comments
 (0)