-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathworker.py
More file actions
31 lines (26 loc) · 795 Bytes
/
worker.py
File metadata and controls
31 lines (26 loc) · 795 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn as nn
import torch.nn.functional as F
#import torch.optim as optim

from my_net import *
class Worker(nn.Module):
    """One worker in a parameter-server training setup.

    Holds a local copy of the model, pulls fresh weights from the server,
    and pushes back gradients computed on its own data batches.
    """

    def __init__(self):
        super().__init__()
        self.model = MyNet()
        if torch.cuda.is_available():
            self.input_device = torch.device("cuda:0")
        else:
            self.input_device = torch.device("cpu")
        # The model must live on the same device as the inputs it will
        # receive in push_gradients; without this, a CUDA machine would
        # raise a device-mismatch error on the forward pass.
        self.model.to(self.input_device)

    def pull_weights(self, model_params):
        """Replace local model weights with *model_params* (a state_dict)."""
        self.model.load_state_dict(model_params)

    def push_gradients(self, batch_idx, data, target):
        """Run one forward/backward pass and return the parameter gradients.

        Args:
            batch_idx: batch index, used only for the progress log line.
            data: input batch tensor; moved to ``self.input_device``.
            target: class-index targets for ``F.nll_loss``; the model is
                expected to output log-probabilities (e.g. log_softmax).

        Returns:
            A list of gradient tensors, one per parameter, in
            ``self.parameters()`` order.
        """
        data, target = data.to(self.input_device), target.to(self.input_device)
        # Clear stale gradients so successive calls don't accumulate —
        # otherwise every batch after the first returns summed gradients.
        self.model.zero_grad()
        output = self.model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        grads = [param.grad for param in self.parameters()]
        print(f"batch {batch_idx} training :: loss {loss.item()}")
        return grads