Skip to content

Commit df951d2

Browse files
author
xuming06
committed
update keras with text generation. xuming 20180226
1 parent 89e3138 commit df951d2

3 files changed

Lines changed: 10043 additions & 0 deletions

File tree

07keras/11lstm_text_generation.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# -*- coding: utf-8 -*-
# Author: XuMing <[email protected]>
# Brief: character-level text generation with an LSTM (corpus: nietzsche.txt)

import os
import random
import sys

import numpy as np
from keras.callbacks import LambdaCallback
from keras.layers import Dense, Activation
from keras.layers import LSTM
from keras.models import Sequential
from keras.optimizers import RMSprop

# Where the trained model is persisted after training finishes.
SAVE_MODEL_PATH = 'text_generation_model.h5'
# Resolve data paths relative to this script so it runs from any CWD.
pwd_path = os.path.abspath(os.path.dirname(__file__))
print('pwd_path:', pwd_path)
data_path = os.path.join(pwd_path, '../data/nietzsche.txt')
print('data path:', data_path)
21+
22+
23+
def get_corpus(data_path):
    """Read the UTF-8 text file at *data_path* and return its contents lower-cased."""
    with open(data_path, 'r', encoding='utf-8') as corpus_file:
        return corpus_file.read().lower()
27+
28+
29+
# Load the corpus and build the character vocabulary.
text = get_corpus(data_path)
print('corpus length:', len(text))

chars = sorted(list(set(text)))
print('total chars:', len(chars))

# Bidirectional lookup tables between characters and integer ids.
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))

# Cut the corpus into overlapping sequences of maxlen chars, sliding the
# window forward `step` chars each time; the char right after each window
# is the prediction target.
maxlen = 40
step = 3
sentences = []
next_chars = []

for i in range(0, len(text) - maxlen, step):
    sentences.append(text[i:i + maxlen])
    next_chars.append(text[i + maxlen])
print('num sentences:', len(sentences))

print('vector...')
# One-hot encode: x[i, t, c] is True iff sentence i has char id c at
# position t; y[i, c] is True iff the char following sentence i has id c.
# Use the builtin `bool` dtype: `np.bool` was deprecated in NumPy 1.20 and
# removed in 1.24, and was only ever an alias for `bool`.
x = np.zeros((len(sentences), maxlen, len(chars)), dtype=bool)
y = np.zeros((len(sentences), len(chars)), dtype=bool)
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        x[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1
56+
57+
# build LSTM model: a single 128-unit LSTM over one-hot char windows,
# followed by a softmax over the character vocabulary; trained with
# RMSprop on categorical cross-entropy.
print('build model...')
model = Sequential([
    LSTM(128, input_shape=(maxlen, len(chars))),
    Dense(len(chars)),
    Activation('softmax'),
])
model.compile(optimizer=RMSprop(lr=0.01), loss='categorical_crossentropy')
65+
66+
67+
def sample(preds, temperature=1.0):
    """Draw one character index from the probability vector *preds*.

    Reweights the distribution by *temperature*: values < 1.0 sharpen it
    (more conservative picks), values > 1.0 flatten it (more surprising
    picks). Returns an int index into the vocabulary.
    """
    preds = np.asarray(preds).astype('float64')
    # Tiny epsilon guards against log(0) -> -inf (and the RuntimeWarning it
    # triggers) when a class has zero predicted probability.
    preds = np.log(preds + 1e-12) / temperature
    exp_preds = np.exp(preds)
    # Renormalize so the vector is a valid pvals argument for multinomial.
    preds = exp_preds / np.sum(exp_preds)
    probs = np.random.multinomial(1, preds, 1)
    return np.argmax(probs)
74+
75+
76+
def on_epoch_end(epoch, logs=None):
    """LambdaCallback hook: print sample generations after each epoch.

    Keras invokes ``on_epoch_end(epoch, logs)``, so the ``logs`` parameter
    is required (the original single-argument signature raised TypeError at
    the end of the first epoch). ``logs`` is unused here.

    Picks a random seed window from the corpus and, for several sampling
    temperatures, generates 400 characters by repeatedly one-hot encoding
    the current window, predicting the next-char distribution, and sampling
    from it.
    """
    # print generated text
    print('\n--- Generating text each epoch: %d' % epoch)
    start_index = random.randint(0, len(text) - maxlen - 1)
    for diversity in [0.2, 0.5, 1.0, 1.2]:
        print('--- diversity:', diversity)
        generated = ''
        sentence = text[start_index:start_index + maxlen]
        # Fix: seed with the string `sentence`, not the module-level list
        # `sentences` (str += list raised TypeError).
        generated += sentence
        print('--- generating with:', sentence)
        sys.stdout.write(generated)

        for i in range(400):
            # Fix: np.zeros (np.zero does not exist).
            x_pred = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(sentence):
                x_pred[0, t, char_indices[char]] = 1.0
            preds = model.predict(x_pred, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = indices_char[next_index]

            # Slide the window one char forward and emit the sampled char.
            generated += next_char
            sentence = sentence[1:] + next_char
            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()
101+
102+
103+
# Hook the per-epoch text sampling into training, then persist the model.
print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
model.fit(
    x,
    y,
    batch_size=128,
    epochs=60,
    callbacks=[print_callback],
)
model.save(SAVE_MODEL_PATH)

0 commit comments

Comments
 (0)