# -*- coding: utf-8 -*-
# Author: XuMing <[email protected]>
# Brief: This example demonstrates the use of fastText for text classification.
# Bi-gram: 0.9056 test accuracy after 5 epochs.
import os

import numpy as np
from keras.layers import Dense
from keras.layers import Embedding
from keras.layers import GlobalAveragePooling1D
from keras.models import Sequential
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical


def get_corpus(data_dir):
    """
    Load labeled sentences from every file in a directory.
    Each line is expected to look like "label,token token token ...".
    :param data_dir: directory containing the corpus files
    :return: (list of token lists, list of label strings)
    """
    words = []
    labels = []
    for file_name in os.listdir(data_dir):
        with open(os.path.join(data_dir, file_name), mode='r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split(',')
                if len(parts) > 1:
                    labels.append(parts[0])
                    words.append(parts[1].split())
    return words, labels


def vectorize_words(words, word_idx):
    """
    Map each list of word tokens to a list of integer indices.
    Index 0 is reserved for padding, so word indices start at 1.
    Padding happens later, after the n-gram features are appended.
    """
    return [[word_idx[w] for w in sent] for sent in words]


def create_ngram_set(input_list, ngram_value=2):
    """
    Create a set of n-grams from a list of tokens.
    :param input_list: [1, 2, 3, 4, 9]
    :param ngram_value: 2
    :return: {(1, 2), (2, 3), (3, 4), (4, 9)}
    """
    return set(zip(*[input_list[i:] for i in range(ngram_value)]))


def add_ngram(sequences, token_indice, ngram_range=2):
    """
    Augment each input sequence by appending the indices of its n-grams.
    :param sequences: list of token index sequences
    :param token_indice: dict mapping n-gram tuples to integer indices
    :param ngram_range: largest n-gram order to add
    :return: augmented list of index sequences
    """
    new_seq = []
    for seq in sequences:
        new_list = seq[:]
        for i in range(len(new_list) - ngram_range + 1):
            for ngram_value in range(2, ngram_range + 1):
                ngram = tuple(new_list[i:i + ngram_value])
                if ngram in token_indice:
                    new_list.append(token_indice[ngram])
        new_seq.append(new_list)
    return new_seq

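# A quick sketch of how the two n-gram helpers behave (the token indices and
# the token_indice mapping here are made-up values for illustration):
#   create_ngram_set([1, 3, 4, 5], ngram_value=2) -> {(1, 3), (3, 4), (4, 5)}
#   add_ngram([[1, 3, 4, 5]], {(1, 3): 1337, (4, 5): 2017}, ngram_range=2)
#     -> [[1, 3, 4, 5, 1337, 2017]]
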
ngram_range = 2
max_len = 400
batch_size = 32
embedding_dims = 50
epochs = 5
SAVE_MODEL_PATH = 'fasttext_multi_classification_model.h5'
pwd_path = os.path.abspath(os.path.dirname(__file__))
print('pwd_path:', pwd_path)
train_data_dir = os.path.join(pwd_path, '../data/sogou_classifier_data/train')
test_data_dir = os.path.join(pwd_path, '../data/sogou_classifier_data/test')
print('data_dir path:', train_data_dir)

print('Loading data...')
x_train, y_train = get_corpus(train_data_dir)
x_test, y_test = get_corpus(test_data_dir)

sent_maxlen = max(map(len, x_train + x_test))

print('-')
print('Sentence max length:', sent_maxlen, 'words')
print('Number of training samples:', len(x_train))
print('Number of test samples:', len(x_test))
print('-')
print('Here\'s what a sample looks like (label, tokens):')
print(y_train[0], x_train[0])
print('-')
print('Average train sequence length: {}'.format(np.mean(list(map(len, x_train)), dtype=int)))
print('Average test sequence length: {}'.format(np.mean(list(map(len, x_test)), dtype=int)))

# Build the word vocabulary from the raw tokens only (the labels are encoded
# separately below). Index 0 is reserved for masking via pad_sequences.
vocab = set()
for sent in x_train + x_test:
    vocab |= set(sent)
vocab = sorted(vocab)
vocab_size = len(vocab) + 1
print('Vocab size:', vocab_size, 'unique words')
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
idx_2_word = dict((value, key) for key, value in word_idx.items())

print('Vectorizing the word sequences...')
x_train = vectorize_words(x_train, word_idx)
x_test = vectorize_words(x_test, word_idx)
# N-gram indices are assigned above the highest word index, so derive the
# start index from the actual vocabulary rather than a fixed cap.
max_features = vocab_size

if ngram_range > 1:
    print('Adding {}-gram features'.format(ngram_range))
    # Collect the set of n-grams seen in the training data.
    ngram_set = set()
    for input_list in x_train:
        for i in range(2, ngram_range + 1):
            ngram_set.update(create_ngram_set(input_list, ngram_value=i))
    # Map each n-gram to an integer index above the existing word indices.
    start_index = max_features + 1
    token_indice = {v: k + start_index for k, v in enumerate(ngram_set)}
    indice_token = {token_indice[k]: k for k in token_indice}

    max_features = np.max(list(indice_token.keys())) + 1
    # Augment x_train and x_test with the n-gram features.
    x_train = add_ngram(x_train, token_indice, ngram_range)
    x_test = add_ngram(x_test, token_indice, ngram_range)

    print('Average train sequence length: {}'.format(np.mean(list(map(len, x_train)), dtype=int)))
    print('Average test sequence length: {}'.format(np.mean(list(map(len, x_test)), dtype=int)))

print('Pad sequences (samples x time)')
x_train = pad_sequences(x_train, maxlen=max_len)
x_test = pad_sequences(x_test, maxlen=max_len)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)

# Encode the string labels as one-hot vectors, as required by the
# categorical_crossentropy loss used below.
label_idx = dict((label, i) for i, label in enumerate(sorted(set(y_train + y_test))))
num_classes = len(label_idx)
y_train = to_categorical([label_idx[y] for y in y_train], num_classes=num_classes)
y_test = to_categorical([label_idx[y] for y in y_test], num_classes=num_classes)
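
# For illustration only: with three hypothetical classes mapped as
# {'finance': 0, 'sports': 1, 'tech': 2}, a 'sports' sample becomes
# the one-hot row [0., 1., 0.].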

print('Build model...')
model = Sequential()

# The embedding layer maps each word/n-gram index into embedding_dims dimensions.
model.add(Embedding(max_features, embedding_dims, input_length=max_len))
# Average the embeddings of all tokens in the document: the fastText
# bag-of-features trick.
model.add(GlobalAveragePooling1D())
# Project onto a probability distribution over the classes.
model.add(Dense(num_classes, activation='softmax'))

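# Print layer shapes and parameter counts as a quick sanity check.
model.summary()
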
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(x_test, y_test))
model.save(SAVE_MODEL_PATH)
print('save model:', SAVE_MODEL_PATH)
probs = model.predict(x_test, batch_size=batch_size)
assert len(probs) == len(y_test)
for answer, prob in zip(y_test, probs):
    print('answer_index:%s\tpred_index:%s\tprob:%s' % (answer.argmax(), prob.argmax(), prob.max()))
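
# Optional sanity check, a minimal sketch: report the overall test accuracy and
# map the first prediction back to its label name via the label_idx built above.
idx_2_label = dict((value, key) for key, value in label_idx.items())
loss, acc = model.evaluate(x_test, y_test, batch_size=batch_size)
print('test loss: %.4f, test accuracy: %.4f' % (loss, acc))
print('first predicted label:', idx_2_label[int(probs[0].argmax())])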