@@ -120,7 +120,6 @@ def decoding_layer(target_char_indices, decoding_embedding_size, num_layers, rnn
     :param decoder_input:
     :return:
     """
-    print('build model...')
     # embedding
     target_vocab_size = len(target_char_indices)
     decoder_embeddings = tf.Variable(tf.random_uniform([target_vocab_size, decoding_embedding_size]))
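Note on this hunk: `decoder_embeddings` is the trainable lookup table that the decoder input indices (the `decoder_input` named in the docstring) are mapped through before reaching the decoder cell. A minimal sketch of that pattern, assuming TF 1.x; the `decoder_input` placeholder and the sizes here are stand-ins, not the script's actual values:

    import tensorflow as tf

    target_vocab_size = 30        # stands in for len(target_char_indices)
    decoding_embedding_size = 15  # assumed hyperparameter

    # trainable embedding matrix: one row per target character
    decoder_embeddings = tf.Variable(
        tf.random_uniform([target_vocab_size, decoding_embedding_size]))

    # int indices [batch, time] -> embeddings [batch, time, embed_size]
    decoder_input = tf.placeholder(tf.int32, [None, None], name='decoder_input')
    decoder_embed_input = tf.nn.embedding_lookup(decoder_embeddings, decoder_input)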
@@ -185,6 +184,7 @@ def seq2seq(input_data, targets, lr, target_sequence_len,
     :param num_layers:
     :return:
    """
+    print('build model...')
    # get state output of encoder
    _, encoder_state = get_encoder_layer(input_data,
                                         rnn_size, num_layers,
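`get_encoder_layer` itself is outside this diff. For orientation, a rough sketch of what such an encoder typically looks like in TF 1.x; the layer choices below are assumptions, and only the returned `(output, state)` pair is implied by the call site above:

    def get_encoder_layer(input_data, rnn_size, num_layers,
                          source_sequence_len, source_vocab_size,
                          encoding_embedding_size):
        # embed the source indices, then run a stacked LSTM over the sequence
        encoder_embed = tf.contrib.layers.embed_sequence(
            input_data, source_vocab_size, encoding_embedding_size)
        cells = tf.contrib.rnn.MultiRNNCell(
            [tf.contrib.rnn.LSTMCell(rnn_size) for _ in range(num_layers)])
        # the second return value is the encoder_state that seq2seq() unpacks
        return tf.nn.dynamic_rnn(cells, encoder_embed,
                                 sequence_length=source_sequence_len,
                                 dtype=tf.float32)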
@@ -282,13 +282,14 @@ def train():
     train_graph = tf.Graph()
     with train_graph.as_default():
         # get inputs
-        input_data, targets, lr, target_sequence_len, target_sequence_maxlen, source_sequence_len = get_input()
-        training_decoder_output, predicting_decoder_output = seq2seq(input_data, targets, lr, target_sequence_len,
+        input_data, targets, learning_rate_ph, target_sequence_len, target_sequence_maxlen, source_sequence_len = get_input()
+        training_decoder_output, predicting_decoder_output = seq2seq(input_data, targets,
+                                                                     learning_rate_ph, target_sequence_len,
                                                                      target_sequence_maxlen, source_sequence_len,
                                                                      len(source_char_indices), len(target_char_indices),
                                                                      encoding_embedding_size, decoding_embedding_size,
-                                                                     rnn_size, num_layers, target_char_indices,
-                                                                     batch_size)
+                                                                     rnn_size, num_layers,
+                                                                     target_char_indices, batch_size)
         training_logits = tf.identity(training_decoder_output.rnn_output, 'logits')
         predicting_logits = tf.identity(predicting_decoder_output.sample_id, name='predictions')

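Side note: the two `tf.identity` wrappers exist purely to pin stable names on the decoder outputs ('logits' and 'predictions'), which is what lets `infer()` pull them back out of the restored graph by name, e.g. `loaded_graph.get_tensor_by_name('predictions:0')` (the ':0' selects the op's first output).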
@@ -297,7 +298,7 @@ def train():
         # loss
         cost = tf.contrib.seq2seq.sequence_loss(training_logits, targets, masks)
         # optimizer
-        optimizer = tf.train.AdamOptimizer(lr)
+        optimizer = tf.train.AdamOptimizer(learning_rate_ph)
         # gradient clipping
         gradients = optimizer.compute_gradients(cost)
         capped_gradients = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in gradients if grad is not None]
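`masks` is consumed by `sequence_loss` here but defined just above the hunk. In scripts of this shape it is typically built with `tf.sequence_mask`, so padded target steps contribute nothing to the loss; a sketch of that assumption:

    # 1.0 for real target positions, 0.0 for padding
    # (assumed construction, not visible in this diff)
    masks = tf.sequence_mask(target_sequence_len, target_sequence_maxlen,
                             dtype=tf.float32, name='masks')

The `grad is not None` filter in the clipping step is needed because `compute_gradients` returns `None` for any variable that does not affect `cost`.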
@@ -312,14 +313,14 @@ def train():
                 _, loss = sess.run([train_op, cost],
                                    {input_data: sources_batch,
                                     targets: targets_batch,
-                                    lr: learning_rate,
+                                    learning_rate_ph: learning_rate,
                                     target_sequence_len: targets_length,
                                     source_sequence_len: sources_lengths})
                 if batch_i % display_step == 0:
                     validation_loss = sess.run([cost],
                                                {input_data: valid_sources_batch,
                                                 targets: valid_targets_batch,
-                                                lr: learning_rate,
+                                                learning_rate_ph: learning_rate,
                                                 target_sequence_len: valid_targets_lengths,
                                                 source_sequence_len: valid_sources_lengths})
                     print('Epoch {:>3}/{} Batch {:>4}/{} - Training Loss: {:>6.3f} - Validation Loss: {:>6.3f}'.format(
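The placeholder keeps the `_ph` suffix because inside `train()` a local variable literally named `learning_rate` would shadow the scalar `learning_rate` hyperparameter: the feed dict would collapse into `learning_rate: learning_rate`, feeding the placeholder to itself, which TensorFlow rejects at run time (a feed value cannot be a `tf.Tensor`). A minimal runnable illustration of the working pattern (values assumed):

    import tensorflow as tf

    learning_rate = 0.001  # scalar hyperparameter (assumed value)
    # keep a distinct name for the graph input so the scalar stays reachable
    learning_rate_ph = tf.placeholder(tf.float32, [], name='lr')

    loss = tf.Variable(1.0) * 2.0
    train_op = tf.train.AdamOptimizer(learning_rate_ph).minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # placeholder as the key, Python float as the value
        sess.run(train_op, {learning_rate_ph: learning_rate})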
@@ -356,8 +357,8 @@ def infer():
     source_indices_char, source_char_indices = extract_char_vocab(source_data)
     target_indices_char, target_char_indices = extract_char_vocab(target_data)

-    input_word = 'common '
-    text = source_2_seq(input_word)
+    input_word = 'hello '
+    text = source_2_seq(input_word, source_char_indices)
     loaded_graph = tf.Graph()
     with tf.Session(graph=loaded_graph) as sess:
         loader = tf.train.import_meta_graph(checkpoint + '.meta')
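For reference, the restore-and-run pattern `infer()` relies on, sketched with assumed tensor names: 'predictions:0' follows from the `tf.identity` call above, while 'inputs:0' and 'source_sequence_len:0' are guesses at what `get_input()` named its placeholders; `checkpoint`, `text`, and `batch_size` come from the script:

    loaded_graph = tf.Graph()
    with tf.Session(graph=loaded_graph) as sess:
        loader = tf.train.import_meta_graph(checkpoint + '.meta')
        loader.restore(sess, checkpoint)
        input_data = loaded_graph.get_tensor_by_name('inputs:0')            # assumed name
        seq_len = loaded_graph.get_tensor_by_name('source_sequence_len:0')  # assumed name
        predictions = loaded_graph.get_tensor_by_name('predictions:0')
        answer = sess.run(predictions, {input_data: [text] * batch_size,
                                        seq_len: [len(text)] * batch_size})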
@@ -382,5 +383,5 @@ def infer():


 if __name__ == '__main__':
-    train()
-    # infer()
+    # train()
+    infer()