Skip to content

Commit 4fcc039

Browse files
author
xuming06
committed
update keras with fasttext network for multi classification. xuming 20180226
1 parent aeffa20 commit 4fcc039

3 files changed

Lines changed: 44 additions & 29 deletions

File tree

07keras/09fasttext_multi_classification.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
# Bi-gram : 0.9056 test accuracy after 5 epochs.
55
import os
66

7+
import keras
78
import numpy as np
89
from keras.layers import Dense
910
from keras.layers import Embedding
1011
from keras.layers import GlobalAveragePooling1D
1112
from keras.models import Sequential
1213
from keras.preprocessing import sequence
13-
from keras.preprocessing.sequence import pad_sequences
1414

1515

1616
def get_corpus(data_dir):
@@ -24,21 +24,22 @@ def get_corpus(data_dir):
2424
for file_name in os.listdir(data_dir):
2525
with open(os.path.join(data_dir, file_name), mode='r', encoding='utf-8') as f:
2626
for line in f:
27-
parts = line.strip().split(',')
27+
parts = line.rstrip().split(',')
2828
if parts and len(parts) > 1:
29-
lbl = parts[0]
29+
# keras categorical label start with 0
30+
lbl = int(parts[0]) - 1
3031
sent = parts[1]
3132
sent_split = sent.split()
3233
words.append(sent_split)
3334
labels.append(lbl)
3435
return words, labels
3536

3637

37-
def vectorize_words(words, word_idx, maxlen):
38+
def vectorize_words(words, word_idx):
3839
inputs = []
3940
for word in words:
4041
inputs.append([word_idx[w] for w in word])
41-
return pad_sequences(inputs, maxlen=maxlen)
42+
return inputs
4243

4344

4445
def create_ngram_set(input_list, ngram_value=2):
@@ -58,6 +59,11 @@ def add_ngram(sequences, token_indice, ngram_range=2):
5859
:param token_indice:
5960
:param ngram_range:
6061
:return:
62+
Example: adding bi-gram
63+
>>> sequences = [[1, 3, 4, 5], [1, 3, 7, 9, 2]]
64+
>>> token_indice = {(1, 3): 1337, (9, 2): 42, (4, 5): 2017}
65+
>>> add_ngram(sequences, token_indice, ngram_range=2)
66+
[[1, 3, 4, 5, 1337, 2017], [1, 3, 7, 9, 2, 1337, 42]]
6167
"""
6268
new_seq = []
6369
for input in sequences:
@@ -72,11 +78,12 @@ def add_ngram(sequences, token_indice, ngram_range=2):
7278

7379

7480
ngram_range = 2
81+
num_classes = 3
7582
max_features = 20000
7683
max_len = 400
7784
batch_size = 32
78-
embedding_dims = 50
79-
epochs = 5
85+
embedding_dims = 200
86+
epochs = 10
8087
SAVE_MODEL_PATH = 'fasttext_multi_classification_model.h5'
8188
pwd_path = os.path.abspath(os.path.dirname(__file__))
8289
print('pwd_path:', pwd_path)
@@ -87,11 +94,10 @@ def add_ngram(sequences, token_indice, ngram_range=2):
8794
print('loading data...')
8895
x_train, y_train = get_corpus(train_data_dir)
8996
x_test, y_test = get_corpus(test_data_dir)
90-
91-
# Reserve 0 for masking via pad_sequences
97+
y_train = keras.utils.to_categorical(y_train)
98+
y_test = keras.utils.to_categorical(y_test)
9299

93100
sent_maxlen = max(map(len, (x for x in x_train + x_test)))
94-
95101
print('-')
96102
print('Sentence max length:', sent_maxlen, 'words')
97103
print('Number of training data:', len(x_train))
@@ -102,11 +108,21 @@ def add_ngram(sequences, token_indice, ngram_range=2):
102108
print('-')
103109
print('Vectorizing the word sequences...')
104110

105-
print(len(x_train), 'train seq')
106-
print(len(x_test), 'test seq')
107111
print('Average train sequence length: {}'.format(np.mean(list(map(len, x_train)), dtype=int)))
108112
print('Average test sequence length: {}'.format(np.mean(list(map(len, x_test)), dtype=int)))
109113

114+
vocab = set()
115+
for w in x_train + x_test:
116+
vocab |= set(w)
117+
vocab = sorted(vocab)
118+
vocab_size = len(vocab) + 1
119+
print('Vocab size:', vocab_size, 'unique words')
120+
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
121+
ids_2_word = dict((value, key) for key, value in word_idx.items())
122+
123+
x_train = vectorize_words(x_train, word_idx)
124+
x_test = vectorize_words(x_test, word_idx)
125+
110126
if ngram_range > 1:
111127
print('Adding {}-gram features'.format(ngram_range))
112128
# n-gram set from train data
@@ -130,22 +146,9 @@ def add_ngram(sequences, token_indice, ngram_range=2):
130146
print('Average train sequence length: {}'.format(train_mean_len))
131147
print('Average test sequence length: {}'.format(test_mean_len))
132148

133-
vocab = set()
134-
for w in x_train + x_test + y_test:
135-
vocab |= set(w)
136-
vocab = sorted(vocab)
137-
vocab_size = len(vocab) + 1
138-
print('Vocab size:', vocab_size, 'unique words')
139-
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
140-
ids_2_word = dict((value, key) for key, value in word_idx.items())
141-
142149
print('pad sequences (samples x time)')
143-
# x_train = sequence.pad_sequences(x_train, maxlen=max_len)
144-
# x_test = sequence.pad_sequences(x_test, maxlen=max_len)
145-
x_train = vectorize_words(x_train, word_idx, max_len)
146-
x_test = vectorize_words(x_test, word_idx, max_len)
147-
print('x_train shape:', x_train.shape)
148-
print('x_test shape:', x_test.shape)
150+
x_train = sequence.pad_sequences(x_train, maxlen=max_len)
151+
x_test = sequence.pad_sequences(x_test, maxlen=max_len)
149152

150153
print('build model...')
151154
model = Sequential()
@@ -166,5 +169,5 @@ def add_ngram(sequences, token_indice, ngram_range=2):
166169
print('save model:', SAVE_MODEL_PATH)
167170
probs = model.predict(x_test, batch_size=batch_size)
168171
assert len(probs) == len(y_test)
169-
for answer, prob in zip(y_test, probs):
170-
print('answer_test_index:%s\tprob_index:%s\tprob:%s' % (answer, prob.argmax(), prob.max()))
172+
for label, prob in zip(y_test, probs):
173+
print('label_test_index:%s\tprob_index:%s\tprob:%s' % (label.argmax(), prob.argmax(), prob.max()))
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
1,本报记者 发自 上海 国外 媒体 昨日 报道 澳大利亚 银行 acq arie 预计 推出 中国 人民币 10 亿元 商业 住房 抵押 贷款 资产 证券化 计划 有关部门 批准 将是 海外 资金 首次 此项 计划 市场分析 人士 计划 预计 中国 监管部门 阻力 考虑到 交易 相关 高昂 固定成本 人民币 10 亿元 可能是 最低 金额 银行 原本 计划 2006 年初 中国 推出 macquarie anda 房地产 投资信托 计划 香港特区 证监会 否决 该银行 中国 房地产投资 基金 首席 投资 执行官 此前 开发商 行列 竟是 金融机构 项目 投融资 资本运作 才是 特长
2+
2,复旦 新浪 本报记者 杨国强 1984年 相貌端正 复旦大学 新闻系 大学 同学 回忆说 内向 做事 很有 生活 学习 很有 计划性 大学毕业 上海 电视台 当了 两年 记者 赴美 求学 先在 奥克拉荷 大学 拿了 新闻学 硕士 再到 德州 奥斯汀 大学 拿了 财务 专业 硕士 转入 企业界 早就 美国 会计师 协会 美国 注册会计师 1993 1999 普华永道 工作 负责 硅谷 地区 高科技公司 提供 审计 服务 商业 咨询 在此期间 参与 多家 高科技公司 上市 1999 2000 财务 副总裁 身份 加盟 新浪 运作 新浪 美国 上市 参与 设计 中国 互联网 公司 海外 上市 结构 新浪 余家 中国概念股 上市 提供 借鉴 2001年 担任 新浪 cfo 2000 2001 推动 新浪 变了 照搬 美国 网络广告 销售 方式 改为 符合 中国 广告主 需求 时段 流量 模式 广告 主和 客户 肯定 这一 举措 新浪 互联网 广告 市场 领先地位 奠定 基础 2003年 主持 谈判 两次 并购 新浪 无线 市场 后来居上 稳定的 利润 2004年 6月 兼任 新浪 联席 营长 负责 网站 运营 广告 销售 市场 广告 销售 部门 重组 进了 系统化 销售 管理体系 新浪 2005年 广告 销售 业绩 增长率 年来 首次 超过 竞争对手 推动 博客 发展计划 赢得了 新浪博客 成功 2005 年度 中国 杰出 cfo 2005 年度 中国 广告 影响力 人物 荣誉 2005年 9月 升任 新浪 裁并 兼任 首席 财务 2006年 5月 10日 担任 新浪 ceo
3+
2,美国 太空 网站 4月 27日 报道 5月 12日 14日 之间 73p 瓦斯 3号 彗星 30 碎片 史无前例 地球 对此 美国 宇航局 科学家 反驳 碎片 撞击 地球 更不 会引起 大规模 海啸 生物 灭绝 灾难 美国 宇航局 科学家 5月 12日 5月 28日 之间 即便是 73p 瓦斯 3号 彗星 最接近 地球 轨道 距离 地球 碎片 地球 月球 距离 20 多倍 不会有 危险 科学家 提醒 利用 会对 彗星 观察 科学家 预计 碎片 中最 明亮 碎片 双筒望远镜 肉眼 观察到 n101
4+
3,化妆品 改善 皮肤 状况 表皮 角质化 过程 所需 时间 化妆品 三个月 会把 理想 皮肤 安全地 改善 预期 短期 化妆品 都是 加了 违禁 原料 皮肤 虽然在 天内 改善 很可能 导致 皮肤病 2. 植物 绿色 化妆品 作成 形态 装在 瓶子 出售 化妆品 不可能 不含 防腐剂 化学成分 迷信 化妆品 植物 纯天然 宣传 3. 化妆品 质量 越好 对照 成分 也许 发现 的产品 便宜 的产品 相差无几 的产品 配方 便宜 选购 化妆品 简单 办法 尝试 检测 合格 品牌 选择 不良反应 感觉 最舒服 那一

0 commit comments

Comments
 (0)