forked from vpj/python_autocomplete
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimple_model.py
More file actions
25 lines (20 loc) · 835 Bytes
/
simple_model.py
File metadata and controls
25 lines (20 loc) · 835 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch.nn
class SimpleLstmModel(torch.nn.Module):
    """Embedding -> LSTM -> linear model over a vocabulary of token ids.

    Integer token ids are embedded, run through a (possibly multi-layer)
    LSTM, and every LSTM output step is projected back onto
    ``encoding_size`` classes.
    """

    def __init__(self, *,
                 encoding_size,
                 embedding_size,
                 lstm_size,
                 lstm_layers):
        """Build the layers.

        Args:
            encoding_size: Number of distinct token ids; used both as the
                embedding-table size and as the output (logit) width.
            embedding_size: Width of each token embedding, i.e. the LSTM
                input size.
            lstm_size: Hidden-state width of the LSTM.
            lstm_layers: Number of stacked LSTM layers.
        """
        super().__init__()
        self.embedding = torch.nn.Embedding(encoding_size, embedding_size)
        # batch_first is left at its default (False), so the LSTM expects
        # sequence-major input: [seq, batch, feature].
        self.lstm = torch.nn.LSTM(input_size=embedding_size,
                                  hidden_size=lstm_size,
                                  num_layers=lstm_layers)
        self.fc = torch.nn.Linear(lstm_size, encoding_size)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x, h0=None, c0=None):
        """Run the model over a sequence of token ids.

        Args:
            x: Integer token ids, shape ``[seq, batch]`` (sequence-major,
                because the LSTM is not ``batch_first``). The feature
                dimension is added by the embedding layer.
            h0: Initial hidden state, shape
                ``[lstm_layers, batch, lstm_size]``.
            c0: Initial cell state, same shape as ``h0``. When both ``h0``
                and ``c0`` are ``None`` the LSTM starts from a zero state.

        Returns:
            ``(probabilities, logits, (hn, cn))``: ``logits`` has shape
            ``[seq, batch, encoding_size]``; ``probabilities`` is the
            softmax of ``logits`` over the last dimension; ``(hn, cn)`` is
            the final LSTM state, which can be passed back as ``h0``/``c0``
            to continue the sequence.

        Raises:
            ValueError: If exactly one of ``h0`` and ``c0`` is given.
        """
        # An LSTM state is a (hidden, cell) pair; a half-specified state is
        # almost certainly a caller mistake, so reject it explicitly.
        if (h0 is None) != (c0 is None):
            raise ValueError("h0 and c0 must be given together or both omitted")
        # nn.LSTM treats hx=None as an all-zeros initial state.
        state = None if h0 is None else (h0, c0)

        # [seq, batch] token ids -> [seq, batch, embedding_size]
        embedded = self.embedding(x)
        out, (hn, cn) = self.lstm(embedded, state)
        # Project every time step onto the output classes.
        logits = self.fc(out)
        return self.softmax(logits), logits, (hn, cn)