Skip to content

Commit 1743838

Browse files
author
xuming06
committed
Add RNN network to learn addition and multiplication. xuming 20180227
1 parent 1093c07 commit 1743838

3 files changed

Lines changed: 312 additions & 2 deletions

File tree

07keras/12rnn_num_add.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# -*- coding: utf-8 -*-
2+
# Author: XuMing <[email protected]>
3+
# Brief: 深度网络学习加法运算
4+
5+
# input '100+100'
6+
# output '200'
7+
8+
import numpy as np
9+
from keras import layers
10+
from keras.models import Sequential
11+
from six.moves import range
12+
13+
14+
class CharTable(object):
    """Bidirectional mapping between characters and one-hot vectors.

    Given a set of chars:
    * encode chars to a one-hot integer representation
    * decode a one-hot integer representation back to its char output
    * decode a vector of probabilities to its char output
    """

    def __init__(self, chars):
        # Sorted, de-duplicated vocabulary gives every char a stable index.
        self.chars = sorted(set(chars))
        self.char_indices = {c: i for i, c in enumerate(self.chars)}
        self.indices_char = {i: c for i, c in enumerate(self.chars)}

    def encode(self, C, num_rows):
        """One-hot encode string *C* into a (num_rows, vocab_size) array.

        :param C: string to encode
        :param num_rows: number of rows in the returned one-hot encoding;
                         rows beyond len(C) stay all-zero padding.
        :return: numpy array of shape (num_rows, len(self.chars))
        """
        encoded = np.zeros((num_rows, len(self.chars)))
        for row, ch in enumerate(C):
            encoded[row, self.char_indices[ch]] = 1
        return encoded

    def decode(self, x, calc_argmax=True):
        """Map a one-hot/probability matrix (or an index vector) to a string."""
        indices = x.argmax(axis=-1) if calc_argmax else x
        return ''.join(self.indices_char[i] for i in indices)
43+
44+
45+
# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
TRAINING_SIZE = 50000  # number of unique question/answer pairs to generate
DIGITS = 3             # max operand length; max output is 999+999=1998
INVERT = True          # reverse the input string (known to help seq2seq LSTMs)

# longest possible question, e.g. '999+999'
MAXLEN = DIGITS + 1 + DIGITS
chars = '0123456789+ '
ctable = CharTable(chars)

questions = []
expected = []
seen = set()
print('make data...')
while len(questions) < TRAINING_SIZE:
    # random non-negative int with 1..DIGITS digits
    f = lambda: int(''.join(np.random.choice(list('0123456789'))
                            for i in range(np.random.randint(1, DIGITS + 1))))
    a, b = f(), f()
    # skip questions we have already seen, including any
    # commutative duplicate such as 'a+b' vs 'b+a'
    key = tuple(sorted((a, b)))
    if key in seen:
        continue
    seen.add(key)
    # pad the data with spaces such that it is always MAXLEN
    q = '{}+{}'.format(a, b)
    query = q + ' ' * (MAXLEN - len(q))
    ans = str(a + b)
    # answers: max size of DIGITS + 1 (999+999=1998 has 4 chars)
    ans += ' ' * (DIGITS + 1 - len(ans))
    if INVERT:
        query = query[::-1]
    questions.append(query)
    expected.append(ans)
print('TOTAL questions:', len(questions))

print('vector...')
# One-hot encode questions and answers.
# FIX: use the builtin `bool` — the `np.bool` alias was deprecated in
# NumPy 1.20 and removed in NumPy 1.24 (raises AttributeError there).
x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=bool)
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=bool)
for i, sentence in enumerate(questions):
    x[i] = ctable.encode(sentence, MAXLEN)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, DIGITS + 1)

# shuffle (x, y) in unison
indices = np.arange(len(y))
np.random.shuffle(indices)
x = x[indices]
y = y[indices]

# explicitly set apart 10% for validation
split_at = len(x) - len(x) // 10
(x_train, x_val) = x[:split_at], x[split_at:]
(y_train, y_val) = y[:split_at], y[split_at:]

print('training data:')
print(x_train.shape)
print(y_train.shape)
print('val data:')
print(x_val.shape)
print(y_val.shape)

# RNN cell; may be replaced by layers.GRU or layers.SimpleRNN
RNN = layers.LSTM
HIDDEN_SIZE = 128
BATCH_SIZE = 128
LAYERS = 1

print('Build model...')
model = Sequential()
# Encoder: read the whole input sentence into one hidden state.
model.add(RNN(HIDDEN_SIZE, input_shape=(MAXLEN, len(chars))))
# Decoder input: repeatedly provide the encoder's last hidden state for
# each output time step — 'DIGITS + 1' steps, the max length of the answer.
model.add(layers.RepeatVector(DIGITS + 1))
for i in range(LAYERS):
    # return sequences of (num_samples, timesteps, output_dim)
    model.add(RNN(HIDDEN_SIZE, return_sequences=True))

# apply the same dense softmax classifier to every temporal slice
model.add(layers.TimeDistributed(layers.Dense(len(chars))))
model.add(layers.Activation('softmax'))
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

# train: 19 rounds of one epoch each, inspecting samples after every round
for iteration in range(1, 20):
    print()
    print('-')
    print('iteration', iteration)
    model.fit(x_train, y_train,
              batch_size=BATCH_SIZE,
              epochs=1,
              validation_data=(x_val, y_val))
    # select 10 samples from the validation set at random to visualize errors
    for i in range(10):
        ind = np.random.randint(0, len(x_val))
        rowx, rowy = x_val[np.array([ind])], y_val[np.array([ind])]
        # FIX: Sequential.predict_classes() was removed in TensorFlow 2.6;
        # taking argmax over predict() output yields the same class indices.
        preds = np.argmax(model.predict(rowx, verbose=0), axis=-1)
        q = ctable.decode(rowx[0])
        correct = ctable.decode(rowy[0])
        guess = ctable.decode(preds[0], calc_argmax=False)
        print('Q', q[::-1] if INVERT else q, end=' ')
        print('T', correct, end=' ')
        if correct == guess:
            print('\033[92m' + '☑' + '\033[0m', end=' ')
        else:
            print('\033[91m' + '☒' + '\033[0m', end=' ')
        print(guess)
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# -*- coding: utf-8 -*-
2+
# Author: XuMing <[email protected]>
3+
# Brief: 深度网络学习乘法运算
4+
# 本来觉得这个学习能力萌萌哒,想到深度网络优势是解决异或问题,呵呵哒,杀鸡用牛刀了
5+
6+
# input '100*100'
7+
# output '10000'
8+
9+
import numpy as np
10+
from keras import layers
11+
from keras.models import Sequential
12+
from six.moves import range
13+
14+
15+
class CharTable(object):
    """Two-way char <-> one-hot translation table.

    Given a set of chars:
    * encode chars to a one-hot integer representation
    * decode a one-hot integer representation back to its char output
    * decode a vector of probabilities to its char output
    """

    def __init__(self, chars):
        # Stable per-char index via a sorted, de-duplicated vocabulary.
        self.chars = sorted(set(chars))
        self.char_indices = {ch: idx for idx, ch in enumerate(self.chars)}
        self.indices_char = {idx: ch for idx, ch in enumerate(self.chars)}

    def encode(self, C, num_rows):
        """One-hot encode string *C*.

        :param C: string to encode
        :param num_rows: number of rows in the returned one-hot encoding;
                         trailing rows past len(C) remain all zeros.
        :return: numpy array of shape (num_rows, len(self.chars))
        """
        out = np.zeros((num_rows, len(self.chars)))
        for pos, ch in enumerate(C):
            out[pos, self.char_indices[ch]] = 1
        return out

    def decode(self, x, calc_argmax=True):
        """Turn a one-hot/probability matrix (or index vector) into a string."""
        if calc_argmax:
            x = x.argmax(axis=-1)
        return ''.join(self.indices_char[idx] for idx in x)
44+
45+
46+
# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
TRAINING_SIZE = 50000  # number of unique question/answer pairs to generate
DIGITS = 3             # max operand length; max output is 999*999=998001
INVERT = True          # reverse the input string (known to help seq2seq LSTMs)

MAXLEN = DIGITS + 1 + DIGITS  # len('999*999') = 7
chars = '0123456789* '
ctable = CharTable(chars)

questions = []
expected = []
seen = set()
print('make data...')
while len(questions) < TRAINING_SIZE:
    # random non-negative int with 1..DIGITS digits
    f = lambda: int(''.join(np.random.choice(list('0123456789'))
                            for i in range(np.random.randint(1, DIGITS + 1))))
    a, b = f(), f()
    # skip questions we have already seen, including any
    # commutative duplicate such as 'a*b' vs 'b*a'
    key = tuple(sorted((a, b)))
    if key in seen:
        continue
    seen.add(key)
    # pad the data with spaces such that it is always MAXLEN
    q = '{}*{}'.format(a, b)
    query = q + ' ' * (MAXLEN - len(q))
    ans = str(a * b)
    # answers: max size DIGITS * 2 (999*999=998001 has 6 chars)
    ans += ' ' * (DIGITS * 2 - len(ans))
    if INVERT:
        query = query[::-1]
    questions.append(query)
    expected.append(ans)
print('TOTAL questions:', len(questions))

print('vector...')
# One-hot encode questions and answers.
# FIX: use the builtin `bool` — the `np.bool` alias was deprecated in
# NumPy 1.20 and removed in NumPy 1.24 (raises AttributeError there).
x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=bool)
y = np.zeros((len(questions), DIGITS * 2, len(chars)), dtype=bool)
for i, sentence in enumerate(questions):
    x[i] = ctable.encode(sentence, MAXLEN)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, DIGITS * 2)  # len('998001') = 6

# shuffle (x, y) in unison
indices = np.arange(len(y))
np.random.shuffle(indices)
x = x[indices]
y = y[indices]

# explicitly set apart 10% for validation
split_at = len(x) - len(x) // 10
(x_train, x_val) = x[:split_at], x[split_at:]
(y_train, y_val) = y[:split_at], y[split_at:]

print('training data:')
print(x_train.shape)
print(y_train.shape)
print('val data:')
print(x_val.shape)
print(y_val.shape)

# RNN cell; may be replaced by layers.GRU or layers.SimpleRNN
RNN = layers.LSTM
HIDDEN_SIZE = 128
BATCH_SIZE = 128
LAYERS = 1

print('Build model...')
model = Sequential()
# Encoder: read the whole input sentence into one hidden state.
model.add(RNN(HIDDEN_SIZE, input_shape=(MAXLEN, len(chars))))
# Decoder input: repeatedly provide the encoder's last hidden state for
# each output time step — 'DIGITS * 2' steps, the max length of the answer.
model.add(layers.RepeatVector(DIGITS * 2))
for i in range(LAYERS):
    # return sequences of (num_samples, timesteps, output_dim)
    model.add(RNN(HIDDEN_SIZE, return_sequences=True))

# apply the same dense softmax classifier to every temporal slice
model.add(layers.TimeDistributed(layers.Dense(len(chars))))
model.add(layers.Activation('softmax'))
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

# train: 19 rounds of one epoch each, inspecting samples after every round
for iteration in range(1, 20):
    print()
    print('-')
    print('iteration', iteration)
    model.fit(x_train, y_train,
              batch_size=BATCH_SIZE,
              epochs=1,
              validation_data=(x_val, y_val))
    # select 10 samples from the validation set at random to visualize errors
    for i in range(10):
        ind = np.random.randint(0, len(x_val))
        rowx, rowy = x_val[np.array([ind])], y_val[np.array([ind])]
        # FIX: Sequential.predict_classes() was removed in TensorFlow 2.6;
        # taking argmax over predict() output yields the same class indices.
        preds = np.argmax(model.predict(rowx, verbose=0), axis=-1)
        q = ctable.decode(rowx[0])
        correct = ctable.decode(rowy[0])
        guess = ctable.decode(preds[0], calc_argmax=False)
        print('Question', q[::-1] if INVERT else q, end=' ')
        print('Answer', correct, end=' ')
        if correct == guess:
            print('\033[92m' + '☑' + '\033[0m', end=' ')
        else:
            print('\033[91m' + '☒' + '\033[0m', end=' ')
        print(guess)

17tensorflow/4_cnn_text_classification/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
# train
2020
batch_size = 64 # batch size (default: 64)
21-
num_epochs = 200 # number of training epochs (default: 200)
21+
num_epochs = 5 # number of training epochs (default: 5)
2222
evaluate_every = 100 # evaluate model on dev set after this many steps (default: 100)
23-
checkpoint_every = 100 # save model after this many epochs (default: 100)
23+
checkpoint_every = 100 # save model after this many steps (default: 100)
2424
num_checkpoints = 5 # number of checkpoints to store
2525

2626
# proto

0 commit comments

Comments
 (0)