Skip to content

Commit ad5de2f

Browse files
author
Julian Kates-Harbeck
committed
Added support for a fully convolutional PyTorch model, using both temporal and spatial convolutions. TODO: add multi-GPU support and make the model configurable via conf; parameters are currently hardcoded.
1 parent e97f251 commit ad5de2f

File tree

3 files changed

+120
-80
lines changed

3 files changed

+120
-80
lines changed

plasma/models/builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ def extract_id_and_epoch_from_filename(self,filename):
273273
def get_all_saved_files(self):
274274
self.ensure_save_directory()
275275
unique_id = self.get_unique_id()
276-
filenames = os.listdir(self.conf['paths']['model_save_path'])
276+
path = self.conf['paths']['model_save_path']
277+
filenames = [name for name in os.listdir(path) if os.path.isfile(os.path.join(path, name))]
277278
epochs = []
278279
for file in filenames:
279280
curr_id,epoch = self.extract_id_and_epoch_from_filename(file)

plasma/models/loader.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,13 @@ def shift_buffer(self,buff,length):
123123
buff[:,:-length,:] = buff[:,length:,:]
124124

125125

126-
def resize_buffer(self,buff,new_length):
126+
def resize_buffer(self,buff,new_length,dtype=None):
127+
if dtype == None:
128+
dtype = self.conf['data']['floatx']
127129
old_length = buff.shape[1]
128130
batch_size = buff.shape[0]
129131
num_signals = buff.shape[2]
130-
new_buff = np.empty((batch_size,new_length,num_signals),dtype=self.conf['data']['floatx'])
132+
new_buff = np.zeros((batch_size,new_length,num_signals),dtype=dtype)
131133
new_buff[:,:old_length,:] = buff
132134
#print("Resizing buffer to new length {}".format(new_length))
133135
return new_buff
@@ -149,18 +151,20 @@ def inference_batch_generator_full_shot(self,shot_list):
149151
- reset_states_now: boolean flag indicating when to reset state during stateful RNN training
150152
- num_so_far,num_total: number of samples generated so far and the total dataset size as per shot_list
151153
"""
152-
batch_size = self.conf['training']['pred_batch_size']
154+
batch_size = self.conf['model']['pred_batch_size']
153155
sig,res = self.get_signal_result_from_shot(shot_list.shots[0])
154-
Xbuff = np.empty((batch_size,) + sig.shape,dtype=self.conf['data']['floatx'])
155-
Ybuff = np.empty((batch_size,) + res.shape,dtype=self.conf['data']['floatx'])
156-
Maskbuff = np.empty((batch_size,) + res.shape,dtype=self.conf['data']['floatx'])
157-
disr = np.empty(batch_size,dtype=bool)
156+
Xbuff = np.zeros((batch_size,) + sig.shape,dtype=self.conf['data']['floatx'])
157+
Ybuff = np.zeros((batch_size,) + res.shape,dtype=self.conf['data']['floatx'])
158+
Maskbuff = np.zeros((batch_size,) + res.shape,dtype=self.conf['data']['floatx'])
159+
disr = np.zeros(batch_size,dtype=bool)
160+
lengths = np.zeros(batch_size,dtype=int)
158161
# epoch = 0
159162
num_total = len(shot_list)
160163
num_so_far = 0
161164
returned = False
162165
num_steps = 0
163166
batch_idx = 0
167+
np.seterr(all='raise')
164168
# warmup_steps = self.conf['training']['batch_generator_warmup_steps']
165169
# is_warmup_period = num_steps < warmup_steps
166170
# is_first_fill = num_steps < batch_size
@@ -178,15 +182,29 @@ def inference_batch_generator_full_shot(self,shot_list):
178182
Maskbuff = self.resize_buffer(Maskbuff,sig_len)
179183
Maskbuff[:,old_len:,:] = 0.0
180184

181-
Xbuff[batch_idx,:,:] = sig
182-
Ybuff[batch_idx,:,:] = res
185+
Xbuff[batch_idx,:,:] = 0.0
186+
Ybuff[batch_idx,:,:] = 0.0
187+
Maskbuff[batch_idx,:,:] = 0.0
188+
Xbuff[batch_idx,:sig_len,:] = sig
189+
Ybuff[batch_idx,:sig_len,:] = res
183190
Maskbuff[batch_idx,:sig_len,:] = 1.0
184-
Maskbuff[batch_idx,sig_len:,:] = 0.0
185191
disr[batch_idx] = shot.is_disruptive_shot()
192+
lengths[batch_idx] = res.shape[0]
186193
batch_idx += 1
187194
if batch_idx == batch_size:
188195
num_so_far += batch_size
189-
yield 1.0*Xbuff,1.0*Ybuff,1.0*Maskbuff,disr & True,num_so_far,num_total
196+
x1 = 1.0*Xbuff
197+
try:
198+
x2 = 1.0*Ybuff
199+
except:
200+
print(Ybuff[:100])
201+
print(Ybuff[-100:])
202+
print(Ybuff)
203+
x3 = 1.0*Maskbuff
204+
x4 = disr & True
205+
x5 = 1*lengths
206+
207+
yield x1,x2,x3,x4,x5,num_so_far,num_total
190208
batch_idx = 0
191209

192210

@@ -236,10 +254,13 @@ def training_batch_generator_full_shot_partial_reset(self,shot_list):
236254
Maskbuff = self.resize_buffer(Maskbuff,sig_len)
237255
Maskbuff[:,old_len:,:] = 0.0
238256

239-
Xbuff[batch_idx,:,:] = sig
240-
Ybuff[batch_idx,:,:] = res
257+
Xbuff[batch_idx,:,:] = 0.0
258+
Ybuff[batch_idx,:,:] = 0.0
259+
Maskbuff[batch_idx,:,:] = 0.0
260+
261+
Xbuff[batch_idx,:sig_len,:] = sig
262+
Ybuff[batch_idx,:sig_len,:] = res
241263
Maskbuff[batch_idx,:sig_len,:] = 1.0
242-
Maskbuff[batch_idx,sig_len:,:] = 0.0
243264
batch_idx += 1
244265
if batch_idx == batch_size:
245266
num_so_far += batch_size
@@ -735,8 +756,10 @@ def __init__(self,generator):
735756

736757
def fill_batch_queue(self):
737758
print("Starting process to fetch data")
759+
count = 0
738760
while True:
739761
self.queue.put(next(self.generator),True)
762+
count += 1
740763

741764
def __next__(self):
742765
return self.queue.get(True)

plasma/models/torch_runner.py

Lines changed: 81 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(self,n_scalars,n_profiles,profile_size,layer_sizes_spatial,
4545
num_channels_tcn,kernel_size_temporal,dropout=0.1):
4646
super(FTCN, self).__init__()
4747
self.lin = InputBlock(n_scalars, n_profiles,profile_size, layer_sizes_spatial, kernel_size_spatial, linear_size, dropout)
48-
self.input_layer = TimeDistributed(lin,batch_first=True)
48+
self.input_layer = TimeDistributed(self.lin,batch_first=True)
4949
self.tcn = TCN(linear_size, output_size, num_channels_tcn , kernel_size_temporal, dropout)
5050
self.model = nn.Sequential(self.input_layer,self.tcn)
5151

@@ -102,8 +102,8 @@ def forward(self, x):
102102
if self.n_scalars == 0:
103103
x_profiles = x
104104
else:
105-
x_scalars = x[:,:n_scalars]
106-
x_profiles = x[:,n_scalars:]
105+
x_scalars = x[:,:self.n_scalars]
106+
x_profiles = x[:,self.n_scalars:]
107107
x_profiles = x_profiles.contiguous().view(x.size(0),self.n_profiles,self.profile_size)
108108
profile_features = self.net(x_profiles).view(x.size(0),-1)
109109
if self.n_scalars == 0:
@@ -271,18 +271,18 @@ def build_torch_model(conf):
271271
# dim = 10
272272

273273
# lin = nn.Linear(input_size,intermediate_dim)
274-
n_scalars, n_profile, profile_size = get_signal_dimensions(conf)
274+
n_scalars, n_profiles, profile_size = get_signal_dimensions(conf)
275275
dim = n_scalars+n_profiles*profile_size
276276
input_size = dim
277277
output_size = 1
278278
# intermediate_dim = 15
279279

280-
layer_sizes_spatial = [40,20,20]
280+
layer_sizes_spatial = [6,3,3]#[40,20,20]
281281
kernel_size_spatial = 3
282-
linear_size = 10
282+
linear_size = 5
283283

284-
num_channels_tcn = [3]*5
285-
kernel_size_temporal = 3
284+
num_channels_tcn = [10,5,3,3]#[3]*5
285+
kernel_size_temporal = 3 #3
286286
model = FTCN(n_scalars,n_profiles,profile_size,layer_sizes_spatial,
287287
kernel_size_spatial,linear_size,output_size,num_channels_tcn,
288288
kernel_size_temporal,dropout)
@@ -300,16 +300,68 @@ def get_signal_dimensions(conf):
300300
num_channels = sig.num_channels
301301
if num_channels > 1:
302302
profile_size = num_channels
303-
num_1D += 1
303+
n_profiles += 1
304304
is_1D_region = True
305305
else:
306306
assert(not is_1D_region), "make sure all use_signals are ordered such that 1D signals come last!"
307307
assert(num_channels == 1)
308-
num_0D += 1
308+
n_scalars += 1
309309
is_1D_region = False
310310
return n_scalars,n_profiles,profile_size
311311

312-
def train_epoch(model,data_gen,loss_fn):
312+
def apply_model_to_np(model,x):
313+
# return model(Variable(torch.from_numpy(x).float()).unsqueeze(0)).squeeze(0).data.numpy()
314+
return model(Variable(torch.from_numpy(x).float())).data.numpy()
315+
316+
317+
318+
def make_predictions(conf,shot_list,loader,custom_path=None):
319+
generator = loader.inference_batch_generator_full_shot(shot_list)
320+
inference_model = build_torch_model(conf)
321+
322+
if custom_path == None:
323+
model_path = get_model_path(conf)
324+
else:
325+
model_path = custom_path
326+
inference_model.load_state_dict(torch.load(model_path))
327+
#shot_list = shot_list.random_sublist(10)
328+
329+
y_prime = []
330+
y_gold = []
331+
disruptive = []
332+
num_shots = len(shot_list)
333+
334+
pbar = Progbar(num_shots)
335+
while True:
336+
x,y,mask,disr,lengths,num_so_far,num_total = next(generator)
337+
#x, y, mask = Variable(torch.from_numpy(x_).float()), Variable(torch.from_numpy(y_).float()),Variable(torch.from_numpy(mask_).byte())
338+
output = apply_model_to_np(inference_model,x)
339+
for batch_idx in range(x.shape[0]):
340+
curr_length = lengths[batch_idx]
341+
y_prime += [output[batch_idx,:curr_length,0]]
342+
y_gold += [y[batch_idx,:curr_length,0]]
343+
disruptive += [disr[batch_idx]]
344+
pbar.add(1.0)
345+
if len(disruptive) >= num_shots:
346+
y_prime = y_prime[:num_shots]
347+
y_gold = y_gold[:num_shots]
348+
disruptive = disruptive[:num_shots]
349+
break
350+
return y_prime,y_gold,disruptive
351+
352+
def make_predictions_and_evaluate_gpu(conf,shot_list,loader,custom_path = None):
353+
y_prime,y_gold,disruptive = make_predictions(conf,shot_list,loader,custom_path)
354+
analyzer = PerformanceAnalyzer(conf=conf)
355+
roc_area = analyzer.get_roc_area(y_prime,y_gold,disruptive)
356+
loss = get_loss_from_list(y_prime,y_gold,conf['data']['target'])
357+
return y_prime,y_gold,disruptive,roc_area,loss
358+
359+
360+
def get_model_path(conf):
361+
return conf['paths']['model_save_path'] + 'torch/' + model_filename #save_prepath + model_filename
362+
363+
364+
def train_epoch(model,data_gen,optimizer,loss_fn):
313365
loss = 0
314366
total_loss = 0
315367
num_so_far = 0
@@ -335,17 +387,19 @@ def train_epoch(model,data_gen,loss_fn):
335387
loss.backward()
336388
optimizer.step()
337389
step += 1
390+
print("[{}] [{}/{}] loss: {:.3f}, ave_loss: {:.3f}".format(step,num_so_far-num_so_far_start,num_total,loss.data[0],total_loss/step))
338391
if num_so_far-num_so_far_start >= num_total:
339392
break
340-
x_,y_,mask_,num_so_far_start,num_total = next(data_gen)
341-
return step,loss,total_loss,num_so_far,1.0*num_so_far/num_total
393+
x_,y_,mask_,num_so_far,num_total = next(data_gen)
394+
return step,loss.data[0],total_loss,num_so_far,1.0*num_so_far/num_total
342395

343396

344397
def train(conf,shot_list_train,shot_list_validate,loader):
345398

346399
np.random.seed(1)
347400

348-
data_gen = ProcessGenerator(partial(loader.training_batch_generator_full_shot_partial_reset,shot_list=shot_list_train))
401+
#data_gen = ProcessGenerator(partial(loader.training_batch_generator_full_shot_partial_reset,shot_list=shot_list_train)())
402+
data_gen = partial(loader.training_batch_generator_full_shot_partial_reset,shot_list=shot_list_train)()
349403

350404
print('validate: {} shots, {} disruptive'.format(len(shot_list_validate),shot_list_validate.num_disruptive()))
351405
print('training: {} shots, {} disruptive'.format(len(shot_list_train),shot_list_train.num_disruptive()))
@@ -358,6 +412,7 @@ def train(conf,shot_list_train,shot_list_validate,loader):
358412
# e = specific_builder.load_model_weights(train_model)
359413

360414
num_epochs = conf['training']['num_epochs']
415+
patience = conf['callbacks']['patience']
361416
lr_decay = conf['model']['lr_decay']
362417
batch_size = conf['training']['batch_size']
363418
lr = conf['model']['lr']
@@ -385,23 +440,25 @@ def train(conf,shot_list_train,shot_list_validate,loader):
385440
else:
386441
best_so_far = np.inf
387442
cmp_fn = min
388-
optimizer = opt.Adam(model.parameters(),lr = lr)
389-
model.train()
443+
optimizer = opt.Adam(train_model.parameters(),lr = lr)
444+
scheduler = opt.lr_scheduler.ExponentialLR(optimizer,lr_decay)
445+
train_model.train()
390446
not_updated = 0
391447
total_loss = 0
392448
count = 0
393-
loss_fn = nn.MSELoss(size_average=False)
394-
model_path = conf['paths']['model_save_path'] + model_filename #save_prepath + model_filename
395-
makedirs_process_safe(conf['paths']['model_save_path'])
449+
loss_fn = nn.MSELoss(size_average=True)
450+
model_path = get_model_path(conf)
451+
makedirs_process_safe(os.path.dirname(model_path))
396452
while e < num_epochs-1:
397-
print_unique('\nEpoch {}/{}'.format(e,num_epochs))
398-
(step,ave_loss,curr_loss,num_so_far,effective_epochs) = train_epoch(model,data_gen,loss_fn)
453+
scheduler.step()
454+
print('\nEpoch {}/{}'.format(e,num_epochs))
455+
(step,ave_loss,curr_loss,num_so_far,effective_epochs) = train_epoch(train_model,data_gen,optimizer,loss_fn)
399456
e = effective_epochs
400457
loader.verbose=False #True during the first iteration
401458
# if task_index == 0:
402459
# specific_builder.save_model_weights(train_model,int(round(e)))
403-
model.save_state_dict(model_path)
404-
_,_,_,roc_area,loss = mpi_make_predictions_and_evaluate(conf,shot_list_validate,loader)
460+
torch.save(train_model.state_dict(),model_path)
461+
_,_,_,roc_area,loss = make_predictions_and_evaluate_gpu(conf,shot_list_validate,loader)
405462

406463
best_so_far = cmp_fn(roc_area,best_so_far)
407464

@@ -411,54 +468,13 @@ def train(conf,shot_list_train,shot_list_validate,loader):
411468
print('Validation Loss: {:.3e}'.format(loss))
412469
print('Validation ROC: {:.4f}'.format(roc_area))
413470

414-
if best_so_far != epoch_logs[conf['callbacks']['monitor']]: #only save model weights if quantity we are tracking is improving
471+
if best_so_far != roc_area: #only save model weights if quantity we are tracking is improving
415472
print("No improvement, still saving model")
416473
not_updated += 1
417474
else:
418475
print("Saving model")
419-
model.save_state_dict(model_path)
420476
# specific_builder.delete_model_weights(train_model,int(round(e)))
421477
if not_updated > patience:
422478
print("Stopping training due to early stopping")
423479
break
424480

425-
def make_predictions(conf,shot_list,loader,custom_path=None):
426-
generator = loader.inference_batch_generator_full_shot(shot_list)
427-
inference_model = build_torch_model(conf)
428-
429-
if custom_path == None:
430-
model_path = conf['paths']['model_save_path'] + model_filename#save_prepath + model_filename
431-
else:
432-
model_path = custom_path
433-
inference_model.load_state_dict(model_path)
434-
#shot_list = shot_list.random_sublist(10)
435-
436-
y_prime = []
437-
y_gold = []
438-
disruptive = []
439-
num_shots = len(shot_list)
440-
441-
pbar = Progbar(num_shots)
442-
while True:
443-
x_,y_,mask_,disr_,num_so_far,num_total = next(generator)
444-
x, y, mask = Variable(torch.from_numpy(x_).float()), Variable(torch.from_numpy(y_).float()),Variable(torch.from_numpy(mask_).byte())
445-
output = model(x)
446-
for batch_idx in range(x.shape[0])
447-
y_prime[batch_idx] += [output[batch_idx,:,:]]
448-
y_gold += [y_[batch_idx,:,:]]
449-
disruptive += [disr[batch_idx]]
450-
pbar.add(1.0)
451-
if len(disruptive) >= num_shots:
452-
y_prime = y_prime[:num_shots]
453-
y_gold = y_gold[:num_shots]
454-
disruptive = disruptive[:num_shots]
455-
break
456-
return y_prime,y_gold,disruptive
457-
458-
def make_predictions_and_evaluate_gpu(conf,shot_list,loader,custom_path = None):
459-
y_prime,y_gold,disruptive = make_predictions(conf,shot_list,loader,custom_path)
460-
analyzer = PerformanceAnalyzer(conf=conf)
461-
roc_area = analyzer.get_roc_area(y_prime,y_gold,disruptive)
462-
loss = get_loss_from_list(y_prime,y_gold,conf['data']['target'])
463-
return y_prime,y_gold,disruptive,roc_area,loss
464-

0 commit comments

Comments (0)