
Commit 286d4e7

Author: xuming06
Commit message: add cnn text classification. xuming 20171017
1 parent ff29516 commit 286d4e7

7 files changed

Lines changed: 11087 additions & 0 deletions

17tensorflow/4_cnn_text_classification/config.py

Lines changed: 34 additions & 0 deletions

# -*- coding: utf-8 -*-
# Author: XuMing <[email protected]>
# Date: 17/10/17
# Brief: configuration

config = {
    # data
    "dev_sample_percentage": 0.1,  # percentage of the training data used for validation
    "positive_data_file": "./data/en_polarity/pos.txt",  # positive data
    "negative_data_file": "./data/en_polarity/neg.txt",  # negative data

    # model
    "embedding_dim": 128,  # dimensionality of character embedding (default: 128)
    "filter_sizes": "3,4,5",  # comma-separated filter sizes (default: "3,4,5")
    "num_filters": 128,  # number of filters per filter size
    "dropout_keep_prob": 0.5,  # dropout keep probability
    "l2_reg_lambda": 0.0,  # L2 regularization lambda

    # train
    "batch_size": 64,  # batch size (default: 64)
    "num_epochs": 200,  # number of training epochs (default: 200)
    "evaluate_every": 100,  # evaluate model on dev set after this many steps (default: 100)
    "checkpoint_every": 100,  # save model after this many steps (default: 100)
    "num_checkpoints": 5,  # number of checkpoints to store

    # proto
    "allow_soft_placement": True,  # allow soft device placement
    "log_device_placement": False,  # log placement of ops on devices
}

evaluate = {
    "checkpoint_dir": "",  # checkpoint directory from the training run
    "eval_all_train_data": False,  # evaluate on all training data
}
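
For reference, a minimal sketch of how these two dictionaries are consumed by the other modules in this commit (they are read as config.config[...] and config.evaluate[...]); it assumes the file is importable as config, which matches the import used in the evaluation script below:

import config

# Model hyper-parameters come from the "config" dict; the comma-separated
# filter sizes are parsed into ints before building the network.
embedding_dim = config.config["embedding_dim"]                              # 128
filter_sizes = [int(s) for s in config.config["filter_sizes"].split(",")]   # [3, 4, 5]

# Evaluation settings live in the separate "evaluate" dict; checkpoint_dir is
# empty here and must be set to a real training-run checkpoints directory.
print(config.evaluate["checkpoint_dir"] or "checkpoint_dir not set")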

17tensorflow/4_cnn_text_classification/data/en_polarity/neg.txt

Lines changed: 5331 additions & 0 deletions
Large diffs are not rendered by default.

17tensorflow/4_cnn_text_classification/data/en_polarity/pos.txt

Lines changed: 5330 additions & 0 deletions
Large diffs are not rendered by default.
17tensorflow/4_cnn_text_classification/data_helpers.py

Lines changed: 78 additions & 0 deletions

# -*- coding: utf-8 -*-
# Author: XuMing <[email protected]>
# Date: 17/10/16
# Brief: data processing

import numpy as np
import re
import itertools
from collections import Counter


def clean_str(string):
    """
    Tokenization/string cleaning for the dataset.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    :param string: raw sentence
    :return: cleaned, lower-cased sentence
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()


def load_data_labels(positive_data_file, negative_data_file):
    """
    Load polarity data from files, split the sentences and generate labels.
    :param positive_data_file: path to the positive examples
    :param negative_data_file: path to the negative examples
    :return: [cleaned sentences, one-hot labels]
    """
    positive_data = list(open(positive_data_file, "r", encoding="utf-8").readlines())
    positive_data = [s.strip() for s in positive_data]
    negative_data = list(open(negative_data_file, "r", encoding="utf-8").readlines())
    negative_data = [s.strip() for s in negative_data]
    # split by words
    x_text = positive_data + negative_data
    x_text = [clean_str(sent) for sent in x_text]
    # generate labels
    positive_labels = [[0, 1] for _ in positive_data]
    negative_labels = [[1, 0] for _ in negative_data]
    y = np.concatenate([positive_labels, negative_labels], 0)
    return [x_text, y]


def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """
    Generate a batch iterator for the dataset.
    :param data: list or array of examples
    :param batch_size: number of examples per batch
    :param num_epochs: number of passes over the data
    :param shuffle: whether to reshuffle the data at each epoch
    :return: batch iterator
    """
    data = np.array(data)
    data_size = len(data)
    num_batches_per_epoch = int((len(data) - 1) / batch_size) + 1
    for epoch in range(num_epochs):
        # shuffle the data at each epoch
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffle_data = data[shuffle_indices]
        else:
            shuffle_data = data
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, data_size)
            yield shuffle_data[start_index:end_index]
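
A short usage sketch of these helpers; the data paths match the files added in this commit, and the sentence passed to clean_str is just an illustrative example:

import data_helpers

# Load and clean the two polarity files; y is a one-hot label matrix
# ([1, 0] = negative, [0, 1] = positive).
x_text, y = data_helpers.load_data_labels("./data/en_polarity/pos.txt",
                                          "./data/en_polarity/neg.txt")
print(len(x_text), y.shape)   # roughly 10661 sentences (5331 + 5330 lines added above), labels of shape (10661, 2)

# clean_str separates punctuation and contractions and lower-cases the text.
print(data_helpers.clean_str("It's GREAT, isn't it!"))   # it 's great , is n't it !

# batch_iter yields shuffled mini-batches for the requested number of epochs.
for batch in data_helpers.batch_iter(list(y), batch_size=64, num_epochs=1):
    print(batch.shape)   # (64, 2) for every batch except possibly the last
    break
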
Lines changed: 76 additions & 0 deletions

# -*- coding: utf-8 -*-
# Author: XuMing <[email protected]>
# Date: 17/10/16
# Brief: evaluate the trained CNN text classifier

import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_helpers
import csv
from tensorflow.contrib import learn
from text_cnn import TextCNN
import config

# params
print("\nEvaluation parameters:")
for k, v in config.evaluate.items():
    print("{}={}".format(k, v))

if config.evaluate["eval_all_train_data"]:
    x_raw, y_test = data_helpers.load_data_labels(config.config["positive_data_file"],
                                                  config.config["negative_data_file"])
    y_test = np.argmax(y_test, axis=1)
else:
    x_raw = ["In my opinion, this is Rembrandt's greatest work", "everything is off."]
    y_test = [1, 0]

# map data into vocabulary
checkpoint_dir = config.evaluate["checkpoint_dir"]
vocab_path = os.path.join(checkpoint_dir, "..", "vocab")
vocab_processor = learn.preprocessing.VocabularyProcessor.restore(vocab_path)
x_test = np.array(list(vocab_processor.transform(x_raw)))

print("\nEvaluating...\n")

checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
graph = tf.Graph()
with graph.as_default():
    session_conf = tf.ConfigProto(allow_soft_placement=config.config["allow_soft_placement"],
                                  log_device_placement=config.config["log_device_placement"])
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        # load the saved meta graph and restore variables
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        saver.restore(sess, checkpoint_file)

        # get the placeholders
        input_x = graph.get_operation_by_name("input_x").outputs[0]
        dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]

        # tensor holding the predictions
        predictions = graph.get_operation_by_name("output/predictions").outputs[0]

        # generate batches for one epoch
        batches = data_helpers.batch_iter(list(x_test), config.config["batch_size"], 1, shuffle=False)

        # collect the predictions
        all_predictions = []
        for x_test_batch in batches:
            batch_predictions = sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0})
            all_predictions = np.concatenate([all_predictions, batch_predictions])

# print accuracy if y_test is defined
if y_test is not None:
    correct_predictions = float(sum(all_predictions == y_test))
    print("Total number of test examples: {}".format(len(y_test)))
    print("Accuracy: {:g}".format(correct_predictions / float(len(y_test))))

# save the evaluation to csv
predictions_human_readable = np.column_stack((np.array(x_raw), all_predictions))
out_path = os.path.join(checkpoint_dir, "..", "prediction.csv")
print("Saving evaluation to {0}".format(out_path))
with open(out_path, "w") as f:
    csv.writer(f).writerows(predictions_human_readable)
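
The script writes one row per input sentence to prediction.csv, one directory above checkpoint_dir. A minimal sketch of reading that file back; the run-directory name below is a hypothetical placeholder, not a path produced by this commit:

import csv

# Hypothetical path: prediction.csv is written next to the checkpoints/ directory.
with open("./runs/1508200000/prediction.csv", "r") as f:
    for sentence, predicted in csv.reader(f):
        # "predicted" is the argmax class as a string, following the label order
        # in data_helpers.load_data_labels: "0.0" = negative, "1.0" = positive.
        print(predicted, sentence)
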
17tensorflow/4_cnn_text_classification/text_cnn.py

Lines changed: 73 additions & 0 deletions

# -*- coding: utf-8 -*-
# Author: XuMing <[email protected]>
# Date: 17/10/16
# Brief: TextCNN network structure
import tensorflow as tf
import numpy as np


class TextCNN:
    """
    CNN for text classification, sentiment analysis
    """

    def __init__(self, sequence_length, num_classes, vocab_size,
                 embedding_size, filter_sizes, num_filters, l2_reg_lambda=0.0):
        self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name="input_x")
        self.input_y = tf.placeholder(tf.float32, [None, num_classes], name="input_y")
        self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")

        # L2 regularization loss
        l2_loss = tf.constant(0.0)

        # embedding layer
        with tf.device("/cpu:0"), tf.name_scope("embedding"):
            self.W = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0), name='W')
            self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
            self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

        # create a convolution + maxpool layer for each filter size
        pooled_outputs = []
        for i, filter_size in enumerate(filter_sizes):
            with tf.name_scope("conv-maxpool-%s" % filter_size):
                # convolution layer
                filter_shape = [filter_size, embedding_size, 1, num_filters]
                W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
                b = tf.Variable(tf.constant(0.1, shape=[num_filters]), name="b")
                conv = tf.nn.conv2d(self.embedded_chars_expanded, W, strides=[1, 1, 1, 1],
                                    padding="VALID", name="conv")

                # apply nonlinearity
                h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
                # maxpool over the outputs
                pooled = tf.nn.max_pool(h, ksize=[1, sequence_length - filter_size + 1, 1, 1],
                                        strides=[1, 1, 1, 1], padding="VALID", name="pool")
                pooled_outputs.append(pooled)

        # combine all pooled features
        num_filters_total = num_filters * len(filter_sizes)
        self.h_pool = tf.concat(pooled_outputs, 3)
        self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total])

        # add dropout
        with tf.name_scope("dropout"):
            self.h_drop = tf.nn.dropout(self.h_pool_flat, self.dropout_keep_prob)

        # final (unnormalized) scores and predictions
        with tf.name_scope("output"):
            W = tf.get_variable("W", shape=[num_filters_total, num_classes],
                                initializer=tf.contrib.layers.xavier_initializer())
            b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
            l2_loss += tf.nn.l2_loss(W)
            l2_loss += tf.nn.l2_loss(b)
            self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name="scores")
            self.predictions = tf.argmax(self.scores, 1, name="predictions")

        # calculate mean cross-entropy loss
        with tf.name_scope("loss"):
            losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.input_y)
            self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss

        # accuracy
        with tf.name_scope("accuracy"):
            correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")
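
For reference, a minimal sketch of constructing this model with the hyper-parameters from config.py; sequence_length and vocab_size are placeholders here, since in practice they come from the fitted VocabularyProcessor at training time:

import tensorflow as tf

import config
from text_cnn import TextCNN

with tf.Graph().as_default():
    cnn = TextCNN(
        sequence_length=56,        # placeholder: length sentences are padded to
        num_classes=2,             # [negative, positive], as in data_helpers.load_data_labels
        vocab_size=20000,          # placeholder: size of the fitted vocabulary
        embedding_size=config.config["embedding_dim"],
        filter_sizes=list(map(int, config.config["filter_sizes"].split(","))),
        num_filters=config.config["num_filters"],
        l2_reg_lambda=config.config["l2_reg_lambda"])

    # The tensor names the evaluation script looks up come from this graph:
    # cnn.input_x           -> "input_x"
    # cnn.dropout_keep_prob -> "dropout_keep_prob"
    # cnn.predictions       -> "output/predictions"
    print(cnn.predictions.name)   # output/predictions:0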

0 commit comments
