Commit 2c7d423

xuming06 committed: add keras babi demo. xuming 20180223
1 parent e5ffca8 commit 2c7d423

2 files changed

Lines changed: 153 additions & 9 deletions

File tree

07keras/babi_rnn.py

Lines changed: 153 additions & 9 deletions
@@ -1,15 +1,159 @@
 # -*- coding: utf-8 -*-
 # Author: XuMing <[email protected]>
 # Brief:
+
+import os
+import re
+import tarfile
+from functools import reduce
+
+import keras
 import numpy as np
+from keras.models import Model
+from keras.preprocessing.sequence import pad_sequences
+
+
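+# e.g. tokenize('Mary moved to the bathroom.')
+#      -> ['Mary', 'moved', 'to', 'the', 'bathroom', '.']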
+def tokenize(sentence):
+    """
+    English tokenization: split a sentence into word and punctuation tokens
+    :param sentence:
+    :return:
+    """
+    return [x.strip() for x in re.split(r'(\W+)', sentence) if x.strip()]
+
+
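+# bAbI lines look like "<id> <sentence>" or "<id> <question>\t<answer>\t<supporting ids>";
+# the id restarts at 1 whenever a new story begins.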
+def parse_stories(lines, only_supporting=False):
+    """
+    Parse stories in the bAbI task format
+    :param lines:
+    :param only_supporting:
+    :return:
+    """
+    data = []
+    story = []
+    for line in lines:
+        line = line.decode('utf-8').strip()
+        nid, line = line.split(' ', 1)
+        nid = int(nid)
+        if nid == 1:
+            story = []
+        if '\t' in line:
+            q, a, support = line.split('\t')
+            q = tokenize(q)
+            substory = None
+            if only_supporting:
+                # only select the related substory
+                support = map(int, support.split(' '))
+                substory = [story[i - 1] for i in support]
+            else:
+                # get all the substory
+                substory = [x for x in story if x]
+            data.append((substory, q, a))
+            story.append('')
+        else:
+            sent = tokenize(line)
+            story.append(sent)
+    return data
+
+
+def read_lines(path):
+    lines = []
+    with open(path, mode='r', encoding='utf-8') as f:
+        for line in f:
+            line = line.rstrip()
+            if line:
+                lines.append(line)
+    return lines
+
+
+def get_stories(f, only_supporting=False, max_len=None):
+    """
+    Retrieve the stories and flatten each story's sentences into a single token list
+    :param f:
+    :param only_supporting:
+    :param max_len:
+    :return:
+    """
+    data = parse_stories(f.readlines(), only_supporting=only_supporting)
+    flatten = lambda data: reduce(lambda x, y: x + y, data)
+    data = [(flatten(story), q, a) for story, q, a in data if not max_len or len(flatten(story)) < max_len]
+    return data
+
+
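+# Stories and queries become padded index sequences; each answer becomes a
+# one-hot vector over the vocabulary (index 0 is reserved for padding).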
+def vectorize_stories(data, word_idx, story_maxlen, query_maxlen):
+    idx_story = []
+    idx_query = []
+    idx_answer = []
+    for story, query, answer in data:
+        s = [word_idx[w] for w in story]
+        q = [word_idx[w] for w in query]
+        a = np.zeros(len(word_idx) + 1)
+        a[word_idx[answer]] = 1
+        idx_story.append(s)
+        idx_query.append(q)
+        idx_answer.append(a)
+    return pad_sequences(idx_story, maxlen=story_maxlen), pad_sequences(idx_query, maxlen=query_maxlen), np.array(idx_answer)
+
+
+RNN = keras.layers.LSTM
+EMBED_HIDDEN_SIZE = 50
+SENT_HIDDEN_SIZE = 100
+QUERY_HIDDEN_SIZE = 100
+BATCH_SIZE = 32
+EPOCH = 40
+print("RNN,Embed,Sent,Query={},{},{},{}".format(RNN, EMBED_HIDDEN_SIZE, SENT_HIDDEN_SIZE, QUERY_HIDDEN_SIZE))
+
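+# Task qa2 ("two supporting facts") needs two story sentences combined to answer each question.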
+challenge = 'tasks_1-20_v1-2/en/qa2_two-supporting-facts_{}.txt'
+pwd_path = os.path.abspath(os.path.dirname(__file__))
+print('pwd_path:', pwd_path)
+path = os.path.join(pwd_path, '../data/babi_tasks_1-20_v1-2.tar.gz')
+print('path:', path)
+with tarfile.open(path) as tar:
+    train = get_stories(tar.extractfile(challenge.format('train')))
+    test = get_stories(tar.extractfile(challenge.format('test')))
+
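+# Build the vocabulary over train + test so every token (including answers) has an index.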
+vocab = set()
+for story, q, a in train + test:
+    vocab |= set(story + q + [a])
+vocab = sorted(vocab)
+
+vocab_size = len(vocab) + 1
+word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
+story_maxlen = max(map(len, (x for x, _, _ in train + test)))
+query_maxlen = max(map(len, (x for _, x, _ in train + test)))
+
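+# Vectorize both splits with the shared word index and the common padded lengths.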
+idx_story, idx_query, idx_answer = vectorize_stories(train, word_idx, story_maxlen, query_maxlen)
+test_idx_story, test_idx_query, test_idx_answer = vectorize_stories(test, word_idx, story_maxlen, query_maxlen)
+print('vocab:', vocab)
+print('idx_story.shape:', idx_story.shape)
+print('idx_query.shape:', idx_query.shape)
+print('idx_answer.shape:', idx_answer.shape)
+print('story max len:', story_maxlen)
+print('query max len:', query_maxlen)
+
+print('build model...')
+
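+# Story branch: embed the story word indices; dropout regularizes the embeddings.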
+sentence = keras.layers.Input(shape=(story_maxlen,), dtype='int32')
+encoded_sentence = keras.layers.Embedding(vocab_size, EMBED_HIDDEN_SIZE)(sentence)
+encoded_sentence = keras.layers.Dropout(0.3)(encoded_sentence)
+
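+# Question branch: encode the question to a single LSTM state, then repeat it
+# story_maxlen times so it can be added to the story embedding at every timestep.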
+question = keras.layers.Input(shape=(query_maxlen,), dtype='int32')
+encoded_question = keras.layers.Embedding(vocab_size, EMBED_HIDDEN_SIZE)(question)
+encoded_question = keras.layers.Dropout(0.3)(encoded_question)
+encoded_question = RNN(EMBED_HIDDEN_SIZE)(encoded_question)
+encoded_question = keras.layers.RepeatVector(story_maxlen)(encoded_question)
+
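+# Merge by elementwise addition, read out with a second LSTM, and classify the
+# answer word with a softmax over the vocabulary.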
+merged = keras.layers.add([encoded_sentence, encoded_question])
+merged = RNN(EMBED_HIDDEN_SIZE)(merged)
+merged = keras.layers.Dropout(0.3)(merged)
+preds = keras.layers.Dense(vocab_size, activation='softmax')(merged)
 
-from keras.utils.data_utils import get_file
+model = Model([sentence, question], preds)
+model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
 
-try:
-    path = get_file('babi-tasks-v1-2.tar.gz',
-                    origin='https://s3.amazonaws.com/text-datasets/babi_tasks_1-20_v1-2.tar.gz')
-except:
-    print('Error downloading dataset, please download it manually:\n'
-          '$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz\n'
-          '$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz')
-    raise
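+# Train with 5% of the training stories held out for validation, then evaluate on the test split.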
+print('training')
+model.fit([idx_story, idx_query], idx_answer, batch_size=BATCH_SIZE, epochs=EPOCH, validation_split=0.05)
+loss, acc = model.evaluate([test_idx_story, test_idx_query], test_idx_answer, batch_size=BATCH_SIZE)
+print('Test loss / test accuracy= {:.4f} / {:.4f}'.format(loss, acc))
+# loss: 1.6114 - acc: 0.3758 - val_loss: 1.6661 - val_acc: 0.3800
+# Test loss / test accuracy= 1.6762 / 0.3050

data/babi_tasks_1-20_v1-2.tar.gz

11.2 MB
Binary file not shown.
