Skip to content

Commit 0aa2171

Browse files
author
xuming06
committed
add fasttext.
1 parent 24f6fa8 commit 0aa2171

5 files changed

Lines changed: 3816 additions & 0 deletions

File tree

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
@author:XuMing([email protected])
4+
@description:
5+
"""
6+
7+
from typing import Iterator, List, Dict
8+
import torch
9+
import torch.optim as optim
10+
import numpy as np
11+
from allennlp.data import Instance
12+
from allennlp.data.fields import TextField, SequenceLabelField
13+
from allennlp.data.dataset_readers import DatasetReader
14+
from allennlp.common.file_utils import cached_path
15+
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
16+
from allennlp.data.tokenizers import Token
17+
from allennlp.data.vocabulary import Vocabulary
18+
from allennlp.models import Model
19+
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
20+
from allennlp.modules.token_embedders import Embedding
21+
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
22+
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
23+
from allennlp.training.metrics import CategoricalAccuracy
24+
from allennlp.data.iterators import BucketIterator
25+
from allennlp.training.trainer import Trainer
26+
from allennlp.predictors import SentenceTaggerPredictor
27+
28+
# Seed PyTorch's random number generator with a fixed value.
torch.manual_seed(1)
29+
30+
31+
class PosDatasetReader(DatasetReader):
    """
    DatasetReader for PoS tagging data, one sentence per line, like

        The###DET dog###NN ate###V the###DET apple###NN

    Every whitespace-separated item on a line is a ``word###tag`` pair.
    Blank lines are ignored.
    """

    def __init__(self, token_indexers: Dict[str, TokenIndexer] = None) -> None:
        """
        :param token_indexers: indexers applied to the sentence tokens; when
            omitted (or empty) a single ``SingleIdTokenIndexer`` registered
            under the key ``"tokens"`` is used.
        """
        super().__init__(lazy=False)
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    def text_to_instance(self, tokens: List[Token], tags: List[str] = None) -> Instance:
        """
        Build an ``Instance`` from one tokenized sentence.

        :param tokens: the tokens of the sentence.
        :param tags: optional per-token tags; when given (and non-empty) they
            are attached as the ``"labels"`` field.
        :return: an ``Instance`` with a ``"sentence"`` field and, if tags were
            supplied, a ``"labels"`` field.
        """
        sentence_field = TextField(tokens, self.token_indexers)
        fields = {"sentence": sentence_field}

        if tags:
            # The label field is tied to the sentence field it annotates.
            fields["labels"] = SequenceLabelField(labels=tags, sequence_field=sentence_field)

        return Instance(fields)

    def _read(self, file_path: str) -> Iterator[Instance]:
        """
        Yield one ``Instance`` per non-blank line of ``file_path``.

        :param file_path: path of a text file in the ``word###tag`` format.
        """
        # Explicit encoding so the result does not depend on the platform default.
        with open(file_path, encoding="utf-8") as f:
            for line in f:
                pairs = line.split()
                if not pairs:
                    # A blank line would make the zip() below yield nothing and
                    # the two-name unpacking raise ValueError.
                    continue
                sentence, tags = zip(*(pair.split("###") for pair in pairs))
                # zip() produces tuples; hand over a list as the signature declares.
                yield self.text_to_instance([Token(word) for word in sentence], list(tags))
58+
59+
60+
class LstmTagger(Model):
    """
    Sequence tagger: embeds the sentence, runs it through a seq2seq encoder
    and projects every encoder output to one logit per label.
    """

    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary) -> None:
        """
        :param word_embeddings: embedder applied to the ``sentence`` input.
        :param encoder: seq2seq encoder run over the embeddings; its output
            dimension sizes the final linear layer.
        :param vocab: vocabulary; the size of its ``'labels'`` namespace is the
            number of output classes.
        """
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))
        self.accuracy = CategoricalAccuracy()

    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """
        Compute tag logits and, when gold labels are supplied, the loss.

        :param sentence: the indexed text field for a batch of sentences.
        :param labels: optional gold tag ids for the batch.
        :return: a dict with ``"tag_logits"`` and, if ``labels`` was given,
            ``"loss"``.
        """
        # The mask is passed to the encoder, the metric and the loss below.
        mask = get_text_field_mask(sentence)
        embeddings = self.word_embeddings(sentence)
        encoder_out = self.encoder(embeddings, mask)
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}
        if labels is not None:
            # Update the running accuracy as a side effect of the forward pass.
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Report the accumulated accuracy.

        :param reset: forwarded to the metric's ``get_metric``.
        :return: ``{"accuracy": <value>}``.
        """
        return {"accuracy": self.accuracy.get_metric(reset)}
88+
89+
90+
# --- Data -------------------------------------------------------------------
# Fetch the tutorial's training and validation files and read both through
# the same dataset reader.
reader = PosDatasetReader()
train_dataset = reader.read(
    cached_path(
        'https://raw.githubusercontent.com/allenai/allennlp'
        '/master/tutorials/tagger/training.txt'
    )
)
validation_dataset = reader.read(
    cached_path(
        'https://raw.githubusercontent.com/allenai/allennlp'
        '/master/tutorials/tagger/validation.txt'
    )
)
# The vocabulary is built from the instances of both splits.
vocab = Vocabulary.from_instances(train_dataset + validation_dataset)

# --- Model ------------------------------------------------------------------
EMBEDDING_DIM = 6
HIDDEN_DIM = 6
token_embedding = Embedding(
    num_embeddings=vocab.get_vocab_size('tokens'),
    embedding_dim=EMBEDDING_DIM,
)
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
lstm = PytorchSeq2SeqWrapper(
    torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True)
)
model = LstmTagger(word_embeddings, lstm, vocab)

# --- Training ---------------------------------------------------------------
optimizer = optim.SGD(model.parameters(), lr=0.1)
iterator = BucketIterator(
    batch_size=2,
    sorting_keys=[("sentence", "num_tokens")],
)
iterator.index_with(vocab)
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    patience=10,
    num_epochs=800,
)
trainer.train()

# --- Prediction -------------------------------------------------------------
# Tag one sentence and map the highest-scoring id at each position back to
# its label string.
predictor = SentenceTaggerPredictor(model, dataset_reader=reader)
tag_logits = predictor.predict("The dog ate the apple")['tag_logits']
tag_ids = np.argmax(tag_logits, axis=-1)
print([model.vocab.get_token_from_index(tag_id, 'labels') for tag_id in tag_ids])

31fasttext/classify.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
@author:XuMing([email protected])
4+
@description:
5+
"""
6+
import fasttext
7+
8+
# Train a supervised classifier on the sample file; labels use the
# '__label__' prefix.
classifier = fasttext.supervised(
    'train_sample.txt',
    'classify_model',
    label_prefix='__label__',
)

# Evaluate on the test sample and print the summary figures.
result = classifier.test('test_sample.txt')
print('P@1:', result.precision)
print('R@1:', result.recall)
print('Number of examples:', result.nexamples)

# Two pre-segmented (space-separated) example texts to classify.
texts = [
    '吃 什么 止泻 快 _ 宝宝 拉肚子 _ 酸味 重 _ 专题 解答 ',
    '增高 _ 正确 长高 方法 _ 刺激 骨骼 二次 生长发育 增高 精准 找到 长高 办法 , 有助 孩子 长高 的 方法 ,',
]

# Predicted labels with the default number of results per text.
labels = classifier.predict(texts)
print(labels)

# Same query, with probabilities attached.
labels = classifier.predict_proba(texts)
print(labels)

# Predicted labels, asking for k=3 results per text.
labels = classifier.predict(texts, k=3)
print(labels)

# k=3 results per text, with probabilities attached.
labels = classifier.predict_proba(texts, k=3)
print(labels)

0 commit comments

Comments
 (0)