Skip to content

Commit d521e93

Browse files
Author: xuming06 (committed)
update seq2seq model. xuming 20180301
1 parent c2d9099 commit d521e93

2 files changed

Lines changed: 15 additions & 13 deletions

File tree

07keras/10seq2seq_trans.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
decoder_input_data = np.zeros((len(input_texts), max_decoder_seq_len, num_decoder_tokens), dtype='float32')
5252
decoder_target_data = np.zeros((len(input_texts), max_decoder_seq_len, num_decoder_tokens), dtype='float32')
5353

54+
# one hot representation
5455
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
5556
for t, char in enumerate(input_text):
5657
encoder_input_data[i, t, input_token_index[char]] = 1.0
@@ -138,7 +139,7 @@ def decode_sequence(input_seq):
138139
return decoded_sentence
139140

140141

141-
for seq_index in range(100):
142+
for seq_index in range(10):
142143
# take one sequence (part of the training set) for decoding.
143144
input_seq = encoder_input_data[seq_index:seq_index + 1]
144145
decoded_sentence = decode_sequence(input_seq)

17tensorflow/2_network/basic_seq2seq.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ def decoding_layer(target_char_indices, decoding_embedding_size, num_layers, rnn
120120
:param decoder_input:
121121
:return:
122122
"""
123-
print('build model...')
124123
# embedding
125124
target_vocab_size = len(target_char_indices)
126125
decoder_embeddings = tf.Variable(tf.random_uniform([target_vocab_size, decoding_embedding_size]))
@@ -185,6 +184,7 @@ def seq2seq(input_data, targets, lr, target_sequence_len,
185184
:param num_layers:
186185
:return:
187186
"""
187+
print('build model...')
188188
# get state output of encoder
189189
_, encoder_state = get_encoder_layer(input_data,
190190
rnn_size, num_layers,
@@ -282,13 +282,14 @@ def train():
282282
train_graph = tf.Graph()
283283
with train_graph.as_default():
284284
# get inputs
285-
input_data, targets, lr, target_sequence_len, target_sequence_maxlen, source_sequence_len = get_input()
286-
training_decoder_output, predicting_decoder_output = seq2seq(input_data, targets, lr, target_sequence_len,
285+
input_data, targets, learning_rate, target_sequence_len, target_sequence_maxlen, source_sequence_len = get_input()
286+
training_decoder_output, predicting_decoder_output = seq2seq(input_data, targets,
287+
learning_rate, target_sequence_len,
287288
target_sequence_maxlen, source_sequence_len,
288289
len(source_char_indices), len(target_char_indices),
289290
encoding_embedding_size, decoding_embedding_size,
290-
rnn_size, num_layers, target_char_indices,
291-
batch_size)
291+
rnn_size, num_layers,
292+
target_char_indices, batch_size)
292293
training_logits = tf.identity(training_decoder_output.rnn_output, 'logits')
293294
predicting_logits = tf.identity(predicting_decoder_output.sample_id, name='predictions')
294295

@@ -297,7 +298,7 @@ def train():
297298
# loss
298299
cost = tf.contrib.seq2seq.sequence_loss(training_logits, targets, masks)
299300
# optimizer
300-
optimizer = tf.train.AdamOptimizer(lr)
301+
optimizer = tf.train.AdamOptimizer(learning_rate)
301302
# gradient clipping
302303
gradients = optimizer.compute_gradients(cost)
303304
capped_gradients = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in gradients if grad is not None]
@@ -312,14 +313,14 @@ def train():
312313
_, loss = sess.run([train_op, cost],
313314
{input_data: sources_batch,
314315
targets: targets_batch,
315-
lr: learning_rate,
316+
learning_rate: learning_rate,
316317
target_sequence_len: targets_length,
317318
source_sequence_len: sources_lengths})
318319
if batch_i % display_step == 0:
319320
validation_loss = sess.run([cost],
320321
{input_data: valid_sources_batch,
321322
targets: valid_targets_batch,
322-
lr: learning_rate,
323+
learning_rate: learning_rate,
323324
target_sequence_len: valid_targets_lengths,
324325
source_sequence_len: valid_sources_lengths})
325326
print('Epoch {:>3}/{} Batch {:>4}/{} - Training Loss: {:>6.3f} - Validation Loss: {:>6.3f}'.format(
@@ -356,8 +357,8 @@ def infer():
356357
source_indices_char, source_char_indices = extract_char_vocab(source_data)
357358
target_indices_char, target_char_indices = extract_char_vocab(target_data)
358359

359-
input_word = 'common'
360-
text = source_2_seq(input_word)
360+
input_word = 'hello'
361+
text = source_2_seq(input_word, source_char_indices)
361362
loaded_graph = tf.Graph()
362363
with tf.Session(graph=loaded_graph) as sess:
363364
loader = tf.train.import_meta_graph(checkpoint + '.meta')
@@ -382,5 +383,5 @@ def infer():
382383

383384

384385
if __name__ == '__main__':
385-
train()
386-
# infer()
386+
# train()
387+
infer()

0 commit comments

Comments (0)