# -*- coding: utf-8 -*-
# Author: XuMing <[email protected]>
# Brief: bAbi QA task (two supporting facts) with a Keras RNN

import os
import re
import tarfile
from functools import reduce

import keras
import numpy as np
from keras.models import Model
from keras.preprocessing.sequence import pad_sequences


def tokenize(sentence):
    """
    Split an English sentence into word and punctuation tokens.

    :param sentence: raw sentence string
    :return: list of non-empty stripped tokens,
        e.g. 'Bob went.' -> ['Bob', 'went', '.']
    """
    # The pattern must not be nullable: '(\W+)?' lets re.split match the empty
    # string, which changes results on Python 3.7+ (and warned before that).
    # '(\W+)' keeps the punctuation as captured separator tokens.
    return [tok.strip() for tok in re.split(r'(\W+)', sentence) if tok.strip()]


def parse_stroes(lines, only_supporting=False):
    """
    Parse stories in the bAbi task format.

    Each line starts with a 1-based sentence id; id 1 marks the start of a
    new story.  A question line holds question, answer and supporting-fact
    ids separated by tabs; every other line is a story sentence.

    :param lines: iterable of raw lines (bytes, as yielded by tarfile, or str)
    :param only_supporting: if True, keep only the sentences referenced as
        supporting facts for each question instead of the whole story so far
    :return: list of (substory_sentences, question_tokens, answer) tuples
    """
    data = []
    story = []
    for line in lines:
        # tarfile.extractfile yields bytes; accept plain str too
        if isinstance(line, bytes):
            line = line.decode('utf-8')
        line = line.strip()
        # renamed from `id` to avoid shadowing the builtin
        sent_id, line = line.split(' ', 1)
        sent_id = int(sent_id)
        if sent_id == 1:
            # sentence id 1 starts a fresh story
            story = []
        if '\t' in line:
            # question line: question \t answer \t supporting-fact ids
            q, a, support = line.split('\t')
            q = tokenize(q)
            if only_supporting:
                # keep only the sentences referenced by the supporting ids
                support = map(int, support.split(' '))
                substory = [story[i - 1] for i in support]
            else:
                # keep every real (non-placeholder) sentence seen so far
                substory = [x for x in story if x]
            data.append((substory, q, a))
            # empty placeholder keeps sentence ids aligned with list indices
            story.append('')
        else:
            story.append(tokenize(line))
    return data


def read_lines(path):
    """
    Read a UTF-8 text file and return its non-blank lines.

    :param path: path of the file to read
    :return: list of right-stripped, non-empty lines
    """
    with open(path, mode='r', encoding='utf-8') as fin:
        # rstrip each line, then drop the ones that become empty
        return list(filter(None, (raw.rstrip() for raw in fin)))


def get_stories(f, only_supporting=False, max_len=None):
    """
    Read stories from a file object and merge each story's sentences
    into a single flat token list.

    :param f: file-like object yielding bAbi-format lines
    :param only_supporting: forwarded to parse_stroes
    :param max_len: if given, drop stories whose flattened token count
        is not strictly below max_len
    :return: list of (flat_story_tokens, question_tokens, answer) tuples
    """
    parsed = parse_stroes(f.readlines(), only_supporting=only_supporting)

    def flatten(sentences):
        # concatenate per-sentence token lists into one token list
        return reduce(lambda x, y: x + y, sentences)

    data = []
    for story, q, a in parsed:
        # flatten once per story (the original flattened twice: filter + output)
        flat = flatten(story)
        if not max_len or len(flat) < max_len:
            data.append((flat, q, a))
    return data


def vectorize_stories(data, word_idx, story_maxlen, query_maxlen):
    """
    Turn parsed (story, query, answer) triples into padded index arrays.

    :param data: list of (story_tokens, query_tokens, answer) tuples
    :param word_idx: token -> 1-based index mapping (0 is the padding index)
    :param story_maxlen: pad/truncate length for story sequences
    :param query_maxlen: pad/truncate length for query sequences
    :return: (padded stories, padded queries, one-hot answers) arrays
    """
    stories, queries, answers = [], [], []
    for story, query, answer in data:
        stories.append([word_idx[w] for w in story])
        queries.append([word_idx[w] for w in query])
        # one-hot vector over the vocabulary (+1 slot for padding index 0)
        one_hot = np.zeros(len(word_idx) + 1)
        one_hot[word_idx[answer]] = 1
        answers.append(one_hot)
    return (pad_sequences(stories, maxlen=story_maxlen),
            pad_sequences(queries, maxlen=query_maxlen),
            np.array(answers))


# ---------------------------------------------------------------------------
# Script body: train a two-branch RNN on the bAbi QA2 task.
# (Removed stray diff residue: a dead `get_file` import and a dead
#  download try/except block left over from an earlier revision.)
# ---------------------------------------------------------------------------

# Hyper-parameters
RNN = keras.layers.LSTM  # `keras.layers.recurrent.LSTM` is a deprecated path
EMBED_HIDDEN_SIZE = 50
SENT_HIDDEN_SIZE = 100
QUERY_HIDDEN_SIZE = 100
BATCH_SIZE = 32
EPOCH = 40
print("RNN,Embed,Sent,Query={},{},{},{}".format(RNN, EMBED_HIDDEN_SIZE, SENT_HIDDEN_SIZE, QUERY_HIDDEN_SIZE))

# Load train/test stories straight from the local bAbi archive
challenge = 'tasks_1-20_v1-2/en/qa2_two-supporting-facts_{}.txt'
pwd_path = os.path.abspath(os.path.dirname(__file__))
print('pwd_path:', pwd_path)
path = os.path.join(pwd_path, '../data/babi_tasks_1-20_v1-2.tar.gz')
print('path:', path)
with tarfile.open(path) as tar:
    train = get_stories(tar.extractfile(challenge.format('train')))
    test = get_stories(tar.extractfile(challenge.format('test')))

# Build the vocabulary from every token in train + test
vocab = set()
for story, q, a in train + test:
    vocab |= set(story + q + [a])
vocab = sorted(vocab)

vocab_size = len(vocab) + 1  # +1: index 0 is reserved for sequence padding
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
story_maxlen = max(map(len, (x for x, _, _ in train + test)))
query_maxlen = max(map(len, (x for _, x, _ in train + test)))

idx_story, idx_query, idx_answer = vectorize_stories(train, word_idx, story_maxlen, query_maxlen)
test_idx_story, test_idx_query, test_idx_answer = vectorize_stories(test, word_idx, story_maxlen, query_maxlen)
print('vocab:', vocab)
print('idx_story.shape:', idx_story.shape)
print('idx_query.shape:', idx_query.shape)
print('idx_answer.shape:', idx_answer.shape)
print('story max len:', story_maxlen)
print('query max len:', query_maxlen)

print('build model...')

# Story branch: embed the story token indices
sentence = keras.layers.Input(shape=(story_maxlen,), dtype='int32')
encoded_sentence = keras.layers.Embedding(vocab_size, EMBED_HIDDEN_SIZE)(sentence)
encoded_sentence = keras.layers.Dropout(0.3)(encoded_sentence)

# Question branch: embed, encode with an RNN, repeat to story length
question = keras.layers.Input(shape=(query_maxlen,), dtype='int32')
encoded_question = keras.layers.Embedding(vocab_size, EMBED_HIDDEN_SIZE)(question)
encoded_question = keras.layers.Dropout(0.3)(encoded_question)
encoded_question = RNN(EMBED_HIDDEN_SIZE)(encoded_question)
encoded_question = keras.layers.RepeatVector(story_maxlen)(encoded_question)

# Merge both branches, encode, and predict the answer word over the vocab
merged = keras.layers.add([encoded_sentence, encoded_question])
merged = RNN(EMBED_HIDDEN_SIZE)(merged)
merged = keras.layers.Dropout(0.3)(merged)
preds = keras.layers.Dense(vocab_size, activation='softmax')(merged)

model = Model([sentence, question], preds)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

print('training')
model.fit([idx_story, idx_query], idx_answer, batch_size=BATCH_SIZE, epochs=EPOCH, validation_split=0.05)
loss, acc = model.evaluate([test_idx_story, test_idx_query], test_idx_answer, batch_size=BATCH_SIZE)
print('Test loss / test accuracy= {:.4f} / {:.4f}'.format(loss, acc))
# loss: 1.6114 - acc: 0.3758 - val_loss: 1.6661 - val_acc: 0.3800
# Test loss / test accuracy= 1.6762 / 0.3050