forked from shibing624/python-tutorial
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAveragePerceptron.py
More file actions
81 lines (69 loc) · 2.61 KB
/
AveragePerceptron.py
File metadata and controls
81 lines (69 loc) · 2.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# -*- coding: utf-8 -*-
# Author: XuMing <[email protected]>
# Date: 17/8/10
# Brief: Averaged perceptron (平均感知机)
from collections import defaultdict
import pickle
import random
class AveragePerceptron:
    """Multi-class averaged perceptron over sparse feature dictionaries.

    Weights are stored per feature and per label. During training the model
    lazily accumulates, for every (feature, label) pair, the sum of the
    weight values over all update steps so that ``average_weights`` can
    replace each weight by its mean over the whole training run.

    ``classes`` is not filled in by this class; the caller must assign the
    set of candidate labels before calling ``predict``.
    """

    def __init__(self):
        # feature -> {label: weight}
        self.weights = {}
        # Candidate labels that predict() chooses among.
        self.classes = set()
        # (feature, label) -> sum of the weight over the steps it was in effect.
        self._totals = defaultdict(int)
        # (feature, label) -> step counter value when the weight last changed.
        self._tstamps = defaultdict(int)
        # Number of update() calls seen so far.
        self.i = 0

    def predict(self, features):
        """Return the best-scoring label for ``features``.

        Args:
            features: Mapping of feature name to numeric value. Features with
                value 0 or without learned weights are ignored.

        Returns:
            The member of ``self.classes`` with the highest score. Ties are
            broken in favour of the label that compares greatest.

        Raises:
            ValueError: If ``self.classes`` is empty.
        """
        scores = defaultdict(float)
        for feat, value in features.items():
            if feat not in self.weights or value == 0:
                continue
            for label, weight in self.weights[feat].items():
                scores[label] += value * weight
        # Labels never touched above score 0.0 thanks to the defaultdict.
        return max(self.classes, key=lambda label: (scores[label], label))

    def update(self, truth, guess, features):
        """Advance the step counter and, on a mistake, adjust the weights.

        Args:
            truth: The correct label.
            guess: The label the model predicted.
            features: Iterable of feature names active in the example.

        Returns:
            None.
        """
        def update_feat(c, f, w, v):
            param = (f, c)
            # Credit the old weight for every step it stayed unchanged
            # before overwriting it.
            self._totals[param] += (self.i - self._tstamps[param]) * w
            self._tstamps[param] = self.i
            self.weights[f][c] = w + v

        # The counter advances even for correct guesses so that the average
        # is taken over every example seen.
        self.i += 1
        if truth == guess:
            return None
        for f in features:
            weights = self.weights.setdefault(f, {})
            update_feat(truth, f, weights.get(truth, 0.0), 1.0)
            update_feat(guess, f, weights.get(guess, 0.0), -1.0)
        return None

    def average_weights(self):
        """Replace every weight by its average over all update steps.

        Averages are rounded to 3 decimals; entries that round to zero are
        dropped from the weight table.

        Returns:
            None.
        """
        if not self.i:
            # No update has been recorded (e.g. weights were just loaded),
            # so there is nothing to average and dividing would fail.
            return None
        steps = float(self.i)
        for feat, weights in self.weights.items():
            new_feat_weights = {}
            for clas, weight in weights.items():
                param = (feat, clas)
                total = self._totals[param]
                # Include the steps since the weight was last modified.
                total += (self.i - self._tstamps[param]) * weight
                averaged = round(total / steps, 3)
                if averaged:
                    new_feat_weights[clas] = averaged
            # Rebinding an existing key is safe while iterating the dict.
            self.weights[feat] = new_feat_weights
        return None

    def save(self, path):
        """Pickle the weight table to ``path``.

        Only ``weights`` is stored; ``classes`` and the averaging state are
        not.

        Args:
            path: Destination file path.

        Returns:
            None.
        """
        # pickle emits bytes, so the file must be opened in binary mode.
        with open(path, 'wb') as handle:
            pickle.dump(dict(self.weights), handle)
        return None

    def load(self, path):
        """Load a weight table previously written by ``save``.

        Args:
            path: Source file path.

        Returns:
            None.
        """
        # NOTE: unpickling can execute arbitrary code; only load files from
        # a trusted source.
        with open(path, 'rb') as handle:
            self.weights = pickle.load(handle)
        return None
def train(nr_iter, examples):
    """Train an averaged perceptron on labelled examples.

    Args:
        nr_iter: Number of passes to make over ``examples``.
        examples: Iterable of ``(features, label)`` pairs, where ``features``
            is a mapping of feature name to numeric value.

    Returns:
        The trained ``AveragePerceptron`` with its weights averaged.
    """
    # Work on a private list so the caller's sequence is not reordered and
    # any iterable is accepted.
    examples = list(examples)
    model = AveragePerceptron()
    # predict() picks among model.classes, so register every label up front.
    model.classes = set(label for _, label in examples)
    for _ in range(nr_iter):
        random.shuffle(examples)
        for features, clazz in examples:
            # predict() returns the winning label itself, not a score table.
            guess = model.predict(features)
            # Called on every example: update() leaves the weights alone when
            # the guess is right but still advances the step counter that the
            # final averaging divides by.
            model.update(clazz, guess, features)
    model.average_weights()
    return model