Skip to content

Commit 6556b41

Browse files
author
Craig Michoski
committed
cerm
1 parent 1623c5d commit 6556b41

File tree

4 files changed

+18
-15
lines changed

4 files changed

+18
-15
lines changed

plasma/conf_parser.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import plasma.models.targets as t
21
from plasma.primitives.shots import ShotListFiles
32
from data.signals import *
43

@@ -10,6 +9,7 @@
109

1110
def parameters(input_file):
1211
"""Parse yaml file of configuration parameters."""
12+
from plasma.models.targets import HingeTarget, MaxHingeTarget, BinaryTarget, TTDTarget, TTDInvTarget, TTDLinearTarget
1313

1414
with open(input_file, 'r') as yaml_file:
1515
params = yaml.load(yaml_file)
@@ -42,18 +42,18 @@ def parameters(input_file):
4242

4343
#ensure shallow model has +1 -1 target.
4444
if params['model']['shallow'] or params['target'] == 'hinge':
45-
params['data']['target'] = t.HingeTarget
45+
params['data']['target'] = HingeTarget
4646
elif params['target'] == 'maxhinge':
47-
t.MaxHingeTarget.fac = params['data']['positive_example_penalty']
48-
params['data']['target'] = t.MaxHingeTarget
47+
MaxHingeTarget.fac = params['data']['positive_example_penalty']
48+
params['data']['target'] = MaxHingeTarget
4949
elif params['target'] == 'binary':
50-
params['data']['target'] = t.BinaryTarget
50+
params['data']['target'] = BinaryTarget
5151
elif params['target'] == 'ttd':
52-
params['data']['target'] = t.TTDTarget
52+
params['data']['target'] = TTDTarget
5353
elif params['target'] == 'ttdinv':
54-
params['data']['target'] = t.TTDInvTarget
54+
params['data']['target'] = TTDInvTarget
5555
elif params['target'] == 'ttdlinear':
56-
params['data']['target'] = t.TTDLinearTarget
56+
params['data']['target'] = TTDLinearTarget
5757
else:
5858
print('Unkown type of target. Exiting')
5959
exit(1)

plasma/models/mpi_runner.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def __init__(self,model,optimizer,comm,batch_iterator,batch_size,num_replicas=No
185185
self.num_workers = comm.Get_size()
186186
self.task_index = comm.Get_rank()
187187
self.history = cbks.History()
188+
self.model.stop_training = False
188189
if num_replicas is None or num_replicas < 1 or num_replicas > self.num_workers:
189190
self.num_replicas = self.num_workers
190191
else:
@@ -736,7 +737,6 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non
736737
epoch_logs['train_loss'] = ave_loss
737738
best_so_far = cmp_fn(epoch_logs[conf['callbacks']['monitor']],best_so_far)
738739

739-
stop_training = False
740740
if task_index == 0:
741741
print('=========Summary======== for epoch{}'.format(step))
742742
print('Training Loss numpy: {:.3e}'.format(ave_loss))
@@ -747,8 +747,6 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non
747747
print('Training ROC: {:.4f}'.format(roc_area_train))
748748

749749
callbacks.on_epoch_end(int(round(e)), epoch_logs)
750-
if hasattr(mpi_model.model,'stop_training'):
751-
stop_training = mpi_model.model.stop_training
752750
if best_so_far != epoch_logs[conf['callbacks']['monitor']]: #only save model weights if quantity we are tracking is improving
753751
print("Not saving model weights")
754752
specific_builder.delete_model_weights(train_model,int(round(e)))
@@ -759,7 +757,7 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non
759757
val_steps = 1
760758
tensorboard.on_epoch_end(val_generator,val_steps,int(round(e)),epoch_logs)
761759

762-
stop_training = comm.bcast(stop_training,root=0)
760+
stop_training = comm.bcast(mpi_model.model.stop_training,root=0)
763761
if stop_training:
764762
print("Stopping training due to early stopping")
765763
break

plasma/models/targets.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
from plasma.utils.evaluation import mae_np,mse_np,binary_crossentropy_np,hinge_np,squared_hinge_np
66
import keras.backend as K
77

8-
import plasma.conf
9-
108
#Requirement: larger value must mean disruption more likely.
119
class Target(object):
1210
activation = 'linear'
1311
loss = 'mse'
1412

1513
@abc.abstractmethod
1614
def loss_np(y_true,y_pred):
15+
from plasma.conf import conf
1716
return conf['model']['loss_scale_factor']*mse_np(y_true,y_pred)
1817

1918
@abc.abstractmethod
@@ -31,6 +30,7 @@ class BinaryTarget(Target):
3130

3231
@staticmethod
3332
def loss_np(y_true,y_pred):
33+
from plasma.conf import conf
3434
return conf['model']['loss_scale_factor']*binary_crossentropy_np(y_true,y_pred)
3535

3636
@staticmethod
@@ -53,6 +53,7 @@ class TTDTarget(Target):
5353

5454
@staticmethod
5555
def loss_np(y_true,y_pred):
56+
from plasma.conf import conf
5657
return conf['model']['loss_scale_factor']*mse_np(y_true,y_pred)
5758

5859
@staticmethod
@@ -94,6 +95,7 @@ class TTDLinearTarget(Target):
9495

9596
@staticmethod
9697
def loss_np(y_true,y_pred):
98+
from plasma.conf import conf
9799
return conf['model']['loss_scale_factor']*mse_np(y_true,y_pred)
98100

99101

@@ -118,6 +120,7 @@ class MaxHingeTarget(Target):
118120

119121
@staticmethod
120122
def loss(y_true, y_pred):
123+
from plasma.conf import conf
121124
fac = MaxHingeTarget.fac
122125
#overall_fac = np.prod(np.array(K.shape(y_pred)[1:]).astype(np.float32))
123126
overall_fac = K.prod(K.cast(K.shape(y_pred)[1:],K.floatx()))
@@ -133,6 +136,7 @@ def loss(y_true, y_pred):
133136

134137
@staticmethod
135138
def loss_np(y_true, y_pred):
139+
from plasma.conf import conf
136140
fac = MaxHingeTarget.fac
137141
#print(y_pred.shape)
138142
overall_fac = np.prod(np.array(y_pred.shape).astype(np.float32))
@@ -175,6 +179,7 @@ class HingeTarget(Target):
175179

176180
@staticmethod
177181
def loss_np(y_true, y_pred):
182+
from plasma.conf import conf
178183
return conf['model']['loss_scale_factor']*hinge_np(y_true,y_pred)
179184
#return squared_hinge_np(y_true,y_pred)
180185

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
download_url = "https://github.com/PPPLDeepLearning/plasma-python",
2626
#license = "Apache Software License v2",
2727
test_suite = "tests",
28-
install_requires = ['keras==2.0.6','pathos','matplotlib==2.0.2','hyperopt','mpi4py','xgboost'],
28+
install_requires = ['keras>2.0.8','pathos','matplotlib==2.0.2','hyperopt','mpi4py','xgboost'],
2929
tests_require = [],
3030
classifiers = ["Development Status :: 3 - Alpha",
3131
"Environment :: Console",

0 commit comments

Comments
 (0)