Skip to content

Commit 775df49

Browse files
author
xuming06
committed
add pytorch cifar.
1 parent 93985c4 commit 775df49

24 files changed

Lines changed: 156945 additions & 13 deletions

20pytorch/03.neural_network.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Author: XuMing <[email protected]>
33
# Brief: http://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html#define-the-network
44

5-
# 1. define network
5+
# define network
66
import torch
77
from torch.autograd import Variable
88
import torch.nn as nn
@@ -56,3 +56,31 @@ def num_flat_features(self, x):
5656
criterion = nn.MSELoss()
5757
loss = criterion(output, target)
5858
print('loss:', loss)
59+
60+
print(loss.grad_fn)
61+
print(loss.grad_fn.next_functions[0][0])
62+
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])
63+
64+
network.zero_grad()
65+
print(network.conv1.bias.grad)
66+
67+
loss.backward()
68+
print(network.conv1.bias.grad)
69+
70+
# update weights
71+
learning_rate = 0.1
72+
for i in network.parameters():
73+
i.data.sub_(i.grad.data * learning_rate)
74+
75+
print(i.data)
76+
77+
import torch.optim as optim
78+
79+
optimizer = optim.SGD(network.parameters(), lr=0.01)
80+
optimizer.zero_grad()
81+
output = network(input)
82+
loss = criterion(output, target)
83+
loss.backward()
84+
optimizer.step() # does the update
85+
86+
print(loss)

20pytorch/04.cifar10.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# -*- coding: utf-8 -*-
2+
# Author: XuMing <[email protected]>
3+
# Brief: http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
4+
import torch
5+
import torchvision
6+
import torchvision.transforms as transforms
7+
import os
8+
9+
data_dir = './data'
10+
if not os.path.exists(data_dir):
11+
os.makedirs(data_dir)
12+
transform = transforms.Compose([transforms.ToTensor(),
13+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
14+
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
15+
download=True, transform=transform)
16+
train_loader = torch.utils.data.DataLoader(train_set, batch_size=4,
17+
shuffle=True, num_workers=2)
18+
test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
19+
download=True, transform=transform)
20+
test_loader = torch.utils.data.DataLoader(test_set, batch_size=4,
21+
shuffle=False, num_workers=2)
22+
classes = ('plane', 'car', 'bird' 'cat',
23+
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
24+
25+
import matplotlib.pyplot as plt
26+
import numpy as np
27+
28+
29+
def imshow(img):
30+
img = img / 2 + 0.5
31+
npimg = img.numpy()
32+
plt.imshow(np.transpose(npimg, (1, 2, 0)))
33+
plt.show()
34+
35+
36+
data_iter = iter(train_loader)
37+
images, labels = data_iter.next()
38+
39+
# show images
40+
imshow(torchvision.utils.make_grid(images))
41+
print(' '.join('%s' % classes[labels[i]] for i in range(4)))
42+
43+
# CNN
44+
from torch.autograd import Variable
45+
import torch.nn as nn
46+
import torch.nn.functional as F
47+
import torch.optim as optim
48+
49+
50+
class Network(nn.Module):
51+
def __init__(self):
52+
super(Network, self).__init__()
53+
self.conv1 = nn.Conv2d(3, 6, 5)
54+
self.pool = nn.MaxPool2d(2, 2)
55+
self.conv2 = nn.Conv2d(6, 16, 5)
56+
self.fc1 = nn.Linear(16 * 5 * 5, 120)
57+
self.fc2 = nn.Linear(120, 84)
58+
self.fc3 = nn.Linear(84, 10)
59+
60+
def forward(self, x):
61+
x = self.pool(F.relu(self.conv1(x)))
62+
x = self.pool(F.relu(self.conv2(x)))
63+
x = x.view(-1, 16 * 5 * 5)
64+
x = F.relu(self.fc1(x))
65+
x = F.relu(self.fc2(x))
66+
x = self.fc3(x)
67+
return x
68+
69+
70+
criterion = nn.CrossEntropyLoss()
71+
network = Network()
72+
optimizer = optim.SGD(network.parameters(), lr=0.001, momentum=0.9)
73+
# train
74+
for epoch in range(1):
75+
running_loss = 0.0
76+
for i, data in enumerate(train_loader, 0):
77+
inputs, labels = data
78+
inputs, labels = Variable(inputs), Variable(labels)
79+
optimizer.zero_grad()
80+
outputs = network(inputs)
81+
loss = criterion(outputs, labels)
82+
loss.backward()
83+
optimizer.step()
84+
running_loss += loss.data[0]
85+
if i % 2000 == 1999:
86+
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
87+
running_loss = 0.0
88+
print('training done.')
89+
90+
data_iter = iter(test_loader)
91+
images, labels = data_iter.next()
92+
imshow(torchvision.utils.make_grid(images))
93+
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
94+
95+
# predict
96+
outputs = network(Variable(images))
97+
_, predicted = torch.max(outputs.data, 1)
98+
print('predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
99+
100+
correct = 0
101+
total = 0
102+
for data in test_loader:
103+
images, labels = data
104+
outputs = network(Variable(images))
105+
_, predicted = torch.max(outputs.data, 1)
106+
total += labels.size(0)
107+
correct += (predicted == labels).sum()
108+
print('acc of 10000 test set: %f ' % (correct / total))

0 commit comments

Comments
 (0)