Skip to content

Commit 1093c07

Browse files
author
xuming06
committed
update tensorflow to multi classification. xuming 20180226
1 parent df951d2 commit 1093c07

12 files changed

Lines changed: 35292 additions & 35257 deletions

File tree

07keras/09fasttext_multi_classification.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ def get_corpus(data_dir):
2424
for file_name in os.listdir(data_dir):
2525
with open(os.path.join(data_dir, file_name), mode='r', encoding='utf-8') as f:
2626
for line in f:
27-
parts = line.rstrip().split(',')
27+
# label in first sep
28+
parts = line.rstrip().split(',', 1)
2829
if parts and len(parts) > 1:
2930
# keras categorical label start with 0
3031
lbl = int(parts[0]) - 1

17tensorflow/4_cnn_text_classification/config.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,36 @@
22
# Author: XuMing <[email protected]>
33
# Data: 17/10/17
44
# Brief: 配置
5+
import os
56

6-
config = {
7-
# data
8-
"dev_sample_percentage": 0.1, # percentage of the training data for validation
9-
"positive_data_file": "./data/zh_polarity/pos.txt", # positive data
10-
"negative_data_file": "./data/zh_polarity/neg.txt", # negative data
117

12-
# model
13-
"embedding_dim": 128, # dimensionality of character embedding (default: 128)
14-
"filter_sizes": "3,4,5", # comma-separated filter size (default: "3,4,5")
15-
"num_filters": 128, # number of filters per filter size
16-
"dropout_keep_prob": 0.5, # dropout keep probability
17-
"l2_reg_lambda": 0.0,  # l2 regularization lambda
8+
# data
9+
dev_sample_percentage = 0.1 # percentage of the training data for validation
10+
data_dir = "./data/zh_polarity" # data file path
1811

19-
# train
20-
"batch_size": 64, # batch size (default: 64)
21-
"num_epochs": 200, # number of training epochs (default: 200)
22-
"evaluate_every": 100, # evaluate model on dev set after this many steps (default: 100)
23-
"checkpoint_every": 100, # save model after this many steps (default: 100)
24-
"num_checkpoints": 5, # number of checkpoints to store
12+
# model
13+
embedding_dim = 128 # dimensionality of character embedding (default: 128)
14+
filter_sizes = "3,4,5" # comma-separated filter size (default: "3,4,5")
15+
num_filters = 128 # number of filters per filter size
16+
dropout_keep_prob = 0.5 # dropout keep probability
17+
l2_reg_lambda = 0.0  # l2 regularization lambda
2518

26-
# proto
27-
"allow_soft_placement": True,  # allow soft device placement
28-
"log_device_placement": False, # log placement of ops on devices
29-
}
19+
# train
20+
batch_size = 64 # batch size (default: 64)
21+
num_epochs = 200 # number of training epochs (default: 200)
22+
evaluate_every = 100 # evaluate model on dev set after this many steps (default: 100)
23+
checkpoint_every = 100  # save model after this many steps (default: 100)
24+
num_checkpoints = 5 # number of checkpoints to store
3025

31-
evaluate = {
32-
"infer_data": "./data/input_data.txt", # infer data
33-
"checkpoint_dir": "runs/20171020-1508503142/checkpoints", # checkpoint directory from training run
34-
"eval_all_train_data": False, # evaluate on all training data
35-
}
26+
# proto
27+
allow_soft_placement = True  # allow soft device placement
28+
log_device_placement = False # log placement of ops on devices
29+
30+
infer_data_path = "./data/input_data.txt" # infer data
31+
checkpoint_dir = "./models/checkpoints" # checkpoint directory from training run
32+
eval_all_train_data = False # evaluate on all training data
33+
34+
# directory to save the trained model
35+
# create a new directory if the dir does not exist
36+
if not os.path.exists(checkpoint_dir):
37+
os.mkdir(checkpoint_dir)

17tensorflow/4_cnn_text_classification/data/zh_polarity/neg.txt

Lines changed: 18576 additions & 18576 deletions
Large diffs are not rendered by default.

17tensorflow/4_cnn_text_classification/data/zh_polarity/neg_sample.txt

Lines changed: 0 additions & 13 deletions
This file was deleted.

0 commit comments

Comments
 (0)