Skip to content

Commit 735e1c2

Browse files
author
Julian Kates-Harbeck
committed
remove bidirectional RNN since it violates causality. Add resilience for when batch_normalization is not defined in the conf
1 parent ac7df4e commit 735e1c2

File tree

1 file changed

+3
-8
lines changed

1 file changed

+3
-8
lines changed

plasma/models/builder.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,13 @@ def get_0D_1D_indices(self):
7272
def build_model(self,predict,custom_batch_size=None):
7373
conf = self.conf
7474
model_conf = conf['model']
75-
use_bidirectional = model_conf['use_bidirectional']
7675
rnn_size = model_conf['rnn_size']
7776
rnn_type = model_conf['rnn_type']
7877
regularization = model_conf['regularization']
7978
dense_regularization = model_conf['dense_regularization']
80-
use_batch_norm = model_conf['use_batch_norm']
79+
use_batch_norm = False
80+
if 'use_batch_norm' in model_conf:
81+
use_batch_norm = model_conf['use_batch_norm']
8182

8283
dropout_prob = model_conf['dropout_prob']
8384
length = model_conf['length']
@@ -187,12 +188,6 @@ def slicer_output_shape(input_shape,indices):
187188
x_input = Input(batch_shape = batch_input_shape)
188189
x_in = TimeDistributed(pre_rnn_model) (x_input)
189190

190-
if use_bidirectional:
191-
for _ in range(model_conf['rnn_layers']):
192-
x_in = Bidirectional(rnn_model(rnn_size, return_sequences=return_sequences,
193-
stateful=stateful,kernel_regularizer=l2(regularization),recurrent_regularizer=l2(regularization),
194-
bias_regularizer=l2(regularization),dropout=dropout_prob,recurrent_dropout=dropout_prob)) (x_in)
195-
x_in = Dropout(dropout_prob) (x_in)
196191
else:
197192
for _ in range(model_conf['rnn_layers']):
198193
x_in = rnn_model(rnn_size, return_sequences=return_sequences,#batch_input_shape=batch_input_shape,

0 commit comments

Comments
 (0)