# -*- coding: utf-8 -*-
# Author: XuMing <[email protected]>
# Date: 17/8/10
# Brief: Averaged perceptron part-of-speech tagger (training and test script)

import os
import random
from collections import defaultdict
import pickle
import logging

from AveragePerceptron import AveragePerceptron
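# Note: AveragePerceptron is assumed (from its use below) to expose `weights`
# and `classes` attributes plus predict(features), update(truth, guess, features)
# and average_weights() methods; other details of that class are not shown here.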

PICKLE = "../data/bp/trontagger-0.1.pkg"
TRAIN_FILE_PATH = "../data/bp/train.txt"
TEST_FILE_PATH = "../data/bp/test.txt"


class PerceptronTagger(object):
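    """Greedy averaged-perceptron POS tagger.

    Frequent, unambiguous words are tagged from a lookup table (tagdict);
    everything else is scored by the averaged perceptron over the features
    built in _get_features().
    """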
    START = ['-START-', '-START2-']
    END = ['-END-', '-END2-']
    AP_MODEL_LOC = os.path.join(os.path.dirname(__file__), PICKLE)

    def __init__(self, load=True):
        self.model = AveragePerceptron()
        self.tagdict = {}
        self.classes = set()
        if load:
            self.load(self.AP_MODEL_LOC)

    def tag(self, corpus):
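        """Tag a newline/whitespace tokenized corpus string.

        Returns a flat list of (word, tag) tuples; known unambiguous words
        come from the tagdict, everything else from the perceptron.
        """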
        s_split = lambda t: t.split('\n')
        w_split = lambda s: s.split()

        def split_sents(corpus):
            for s in s_split(corpus):
                yield w_split(s)

        prev, prev2 = self.START
        tokens = []
        for words in split_sents(corpus):
            context = self.START + [self._normalize(w) for w in words] + self.END
            for i, word in enumerate(words):
                tag = self.tagdict.get(word)
                if not tag:
                    features = self._get_features(i, word, context, prev, prev2)
                    tag = self.model.predict(features)
                tokens.append((word, tag))
                prev2 = prev
                prev = tag
        return tokens

    def load(self, loc):
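        """Load a pickled (weights, tagdict, classes) triple from `loc`."""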
        try:
            with open(loc, 'rb') as f:
                w_td_c = pickle.load(f)
        except IOError:
            raise IOError("Missing model file: %s" % loc)
        self.model.weights, self.tagdict, self.classes = w_td_c
        self.model.classes = self.classes
        return None

    def _normalize(self, word):
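        """Map rare word shapes to coarse classes before feature extraction:
        hyphenated words -> !HYPHEN, 4-digit numbers -> !YEAR, other
        digit-initial tokens -> !DIGITS; everything else is lowercased."""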
        if '-' in word and word[0] != '-':
            return '!HYPHEN'
        elif word.isdigit() and len(word) == 4:
            return '!YEAR'
        elif word[0].isdigit():
            return '!DIGITS'
        else:
            return word.lower()

    def _get_features(self, i, word, context, prev, prev2):
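        """Map a token and its context to a sparse {feature name: count} dict.

        Feature names are plain strings (e.g. 'i suffix ing'), so the model's
        weight table can be a simple dict keyed by those strings.
        """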
        i += len(self.START)
        features = defaultdict(int)

        def add(name, *args):
            features[' '.join((name,) + tuple(args))] += 1

        # constant feature
        add('bias')
        add('i suffix', word[-3:])
        add('i pref1', word[0])
        add('i-1 tag', prev)
        add('i-2 tag', prev2)
        add('i tag+i-2 tag', prev, prev2)
        add('i word', context[i])
        add('i-1 tag+i word', prev, context[i])
        add('i-1 word', context[i - 1])
        add('i-1 suffix', context[i - 1][-3:])
        add('i-2 word', context[i - 2])
        add('i+1 word', context[i + 1])
        add('i+1 suffix', context[i + 1][-3:])
        add('i+2 word', context[i + 2])
        return features

    def _make_tagdict(self, sentences):
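        """Build a lookup table of frequent, unambiguous word->tag pairs.

        A word is added only if it was seen at least `freq_thresh` times and
        one tag accounts for at least `ambiguity_thresh` of its occurrences;
        such words skip the perceptron entirely at tag time.
        """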
        counts = defaultdict(lambda: defaultdict(int))
        for words, tags in sentences:
            for word, tag in zip(words, tags):
                counts[word][tag] += 1
                self.classes.add(tag)
        freq_thresh = 20
        ambiguity_thresh = 0.97
        for word, tag_freqs in counts.items():
            tag, mode = max(tag_freqs.items(), key=lambda item: item[1])
            n = sum(tag_freqs.values())
            if n >= freq_thresh and (float(mode) / n) >= ambiguity_thresh:
                self.tagdict[word] = tag

    def _pc(self, n, d):
        return (float(n) / d) * 100

    def train(self, sentences, save_loc=None, nr_iter=5):
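        """Train for `nr_iter` passes of greedy, per-token perceptron updates.

        Sentences are shuffled between passes; after training the weights are
        averaged and (optionally) pickled to `save_loc` together with the
        tagdict and tag set.
        """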
        self._make_tagdict(sentences)
        self.model.classes = self.classes
        for iter_ in range(nr_iter):
            c = 0
            n = 0
            for words, tags in sentences:
                prev, prev2 = self.START
                context = self.START + [self._normalize(w) for w in words] + self.END
                for i, word in enumerate(words):
                    guess = self.tagdict.get(word)
                    if not guess:
                        feats = self._get_features(i, word, context, prev, prev2)
                        guess = self.model.predict(feats)
                        self.model.update(tags[i], guess, feats)
                    prev2 = prev
                    prev = guess
                    c += guess == tags[i]
                    n += 1
            random.shuffle(sentences)
            logging.info("Iter {0}: {1}/{2}={3}".format(iter_, c, n, self._pc(c, n)))
        self.model.average_weights()
        if save_loc is not None:
            pickle.dump((self.model.weights, self.tagdict, self.classes),
                        open(save_loc, 'wb'), -1)
        return None


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    tagger = PerceptronTagger(load=False)
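    # If a saved model exists it is loaded and evaluated on TEST_FILE_PATH;
    # otherwise the IOError from load() falls through to the except branch,
    # which trains a new model on TRAIN_FILE_PATH and saves it. Both corpus
    # files are assumed to hold one "word<sep>tag" pair per line, with "." as
    # the sentence-final token (whitespace-separated in the test file,
    # tab-separated in the training file).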
    try:
        tagger.load(PICKLE)
        print(tagger.tag("how are you ?"))
        logging.info("Start testing...")
        right = 0.0
        total = 0.0
        sentence = ([], [])
        for line in open(TEST_FILE_PATH):
            params = line.split()
            if len(params) != 2:
                continue
            sentence[0].append(params[0])
            sentence[1].append(params[1])
            if params[0] == ".":
                words = sentence[0]
                tags = sentence[1]
                text = " ".join(words)
                outputs = tagger.tag(text)
                assert len(tags) == len(outputs)
                total += len(tags)
                for o, t in zip(outputs, tags):
                    if o[1].strip() == t:
                        right += 1
                sentence = ([], [])
        logging.info("Tagging accuracy: %f", right / total)
    except IOError:
        logging.info("Reading corpus...")
        training_data = []
        sentence = ([], [])
        for line in open(TRAIN_FILE_PATH):
            params = line.strip().split('\t')
            if len(params) != 2:
                continue
            sentence[0].append(params[0])
            sentence[1].append(params[1])
            if params[0] == ".":
                training_data.append(sentence)
                sentence = ([], [])
        logging.info("training corpus size: %d", len(training_data))
        logging.info("Start training...")
        tagger.train(training_data, save_loc=PICKLE)
        logging.info("training end.")