-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathnetworks.py
More file actions
70 lines (58 loc) · 2.16 KB
/
networks.py
File metadata and controls
70 lines (58 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from __future__ import unicode_literals, print_function, division
import torch
import torch.nn as nn
import torch.nn.functional as nnf
class LSTM(nn.Module):
    """Many-to-one LSTM classifier.

    An ``nn.LSTM`` stack is followed by a linear layer applied to the last
    time step's output, then a log-softmax over the output dimension.

    The recurrent state lives on the instance (``self.hidden``) and is
    carried over from one ``forward`` call to the next. To reset it, assign
    ``self.hidden = self.init_hidden()``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size=1, num_layers=1, device="cpu"):
        """
        :param input_size: number of features in each input time step
        :param hidden_size: number of hidden units in each LSTM layer
        :param output_size: size of the output (log-softmax is taken over it)
        :param batch_size: batch size used to allocate the initial hidden state
        :param num_layers: number of stacked LSTM layers
        :param device: device on which the parameters and hidden state are placed
        """
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.device = device
        # nn.LSTM defaults to batch_first=False: inputs are (seq_len, batch, input_size).
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers).to(device)
        self.fc = nn.Linear(hidden_size, output_size).to(device)
        self.hidden = self.init_hidden()

    def forward(self, x):
        """Run the LSTM over ``x`` and classify from the last time step.

        :param x: tensor of shape (seq_len, batch_size, input_size), the
            nn.LSTM default layout, since batch_first is not set
        :return: log-probabilities of shape (batch_size, output_size)
        """
        # The final (h, c) is written back to the instance, so the next call
        # continues from it. NOTE(review): the stored state is not detached,
        # so gradients of a later call can flow into the previous call's graph
        # unless the caller resets self.hidden between batches.
        out, self.hidden = self.lstm(x, self.hidden)
        # out[-1] is the output at the last time step: (batch_size, hidden_size).
        output = self.fc(out[-1])  # many to one
        return nnf.log_softmax(output, dim=1)

    def init_hidden(self):
        """Return a zero ``(h_0, c_0)`` pair.

        Each tensor has shape (num_layers, batch_size, hidden_size) and is
        placed on ``self.device``.
        """
        shape = (self.num_layers, self.batch_size, self.hidden_size)
        # Allocate on the target device directly rather than on CPU followed by a copy.
        return (torch.zeros(shape, device=self.device),
                torch.zeros(shape, device=self.device))

    def _lstm_param(self, kind, layer):
        """Return nn.LSTM parameter ``<kind>_l<layer>``; raise a clear AttributeError if it is absent."""
        name = "%s_l%d" % (kind, layer)
        try:
            return getattr(self.lstm, name)
        except AttributeError:
            raise AttributeError(
                "%s is not available: this LSTM has num_layers=%d (layers are 0-indexed)"
                % (name, self.num_layers))

    def _split_gates(self, param, prefix):
        """Split a stacked LSTM gate parameter into a dict of four per-gate slices.

        PyTorch stacks the gate blocks of nn.LSTM parameters along dim 0 in the
        order input (i), forget (f), cell (g), output (o). The parameter is
        therefore reshaped to (4, hidden_size, -1) and unpacked. Keys are
        ``prefix`` + "I", "F", "G", "O".
        """
        gate_i, gate_f, gate_g, gate_o = param.reshape(4, self.hidden_size, -1)
        return {prefix + "I": gate_i,
                prefix + "F": gate_f,
                prefix + "G": gate_g,
                prefix + "O": gate_o}

    def get_wl(self, layer=0):
        """Return the input-to-hidden weights (``weight_ih_l<layer>``) of ``layer``, split per gate.

        :param layer: 0-indexed LSTM layer
        :return: dict with keys "WI", "WF", "WG", "WO". Each value has shape
            (hidden_size, in_features), where in_features is input_size for
            layer 0 and hidden_size for deeper layers.
        :raises AttributeError: if ``layer`` does not exist in this LSTM
        """
        return self._split_gates(self._lstm_param("weight_ih", layer), "W")

    def get_bl(self, layer=0):
        """Return the input-to-hidden bias (``bias_ih_l<layer>``) of ``layer``, split per gate.

        Only ``bias_ih`` is returned; the separate ``bias_hh_l<layer>`` of
        nn.LSTM is not included.

        :param layer: 0-indexed LSTM layer
        :return: dict with keys "BI", "BF", "BG", "BO". Each value has shape
            (hidden_size, 1), because the reshape keeps a trailing singleton dim.
        :raises AttributeError: if ``layer`` does not exist in this LSTM
        """
        return self._split_gates(self._lstm_param("bias_ih", layer), "B")

    def get_wl0(self):
        """Per-gate input-to-hidden weights of layer 0 (see :meth:`get_wl`)."""
        return self.get_wl(0)

    def get_wl1(self):
        """Per-gate input-to-hidden weights of layer 1 (see :meth:`get_wl`); requires num_layers >= 2."""
        return self.get_wl(1)

    def get_bl0(self):
        """Per-gate input-to-hidden bias of layer 0 (see :meth:`get_bl`)."""
        return self.get_bl(0)

    def get_bl1(self):
        """Per-gate input-to-hidden bias of layer 1 (see :meth:`get_bl`); requires num_layers >= 2."""
        return self.get_bl(1)