pytorch06_grad_descent_optimizer.py
import torch
import matplotlib.pyplot as plt

f = lambda x: 0.1*x**3 - 0.8*x**2 - 1.5*x + 5.4
viz_range = torch.FloatTensor([-6, 12])
learn_rate = 0.1
max_iter = 100
min_tol = 1e-6
x_init = 12.

# Prepare visualization
xs = torch.linspace(*viz_range, 100)
plt.plot(xs, f(xs), 'r-', label='f(x)', linewidth=2)
plt.plot(x_init, f(x_init), 'b.', label='Each step', markersize=12)
plt.axis((*viz_range, *f(viz_range)))
plt.legend()

x = torch.tensor(x_init, requires_grad=True)
optimizer = torch.optim.SGD([x], lr=learn_rate)
for i in range(max_iter):
    # Run gradient descent with the optimizer
    optimizer.zero_grad()    # Reset the accumulated gradient
    y = f(x)                 # Evaluate the function (forward pass)
    y.backward()             # Compute the gradient (backward pass)
    xp = x.clone().detach()  # Snapshot of 'x' before the update
    optimizer.step()         # Update 'x'

    # Update visualization for each iteration
    print(f'Iter: {i}, x = {xp:.3f} to {x:.3f}, f(x) = {f(xp):.3f} to {f(x):.3f} (f\'(x) = {x.grad:.3f})')
    lcolor = torch.rand(3).tolist()
    approx = x.grad*(xs - xp) + f(xp)  # Tangent line at 'xp' (1st-order Taylor approximation)
    plt.plot(xs, approx, '-', linewidth=1, color=lcolor, alpha=0.5)
    xc = x.clone().detach()  # Copy 'x' for plotting
    plt.plot(xc, f(xc), '.', color=lcolor, markersize=12)

    # Check the termination condition
    if abs(x - xp) < min_tol:
        break
plt.show()
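
For comparison, torch.optim.SGD with this setup performs exactly the update x <- x - lr * f'(x) per step, so the optimizer can be replaced by a manual update. The following is a minimal sketch, not part of the original file, assuming the same f, learning rate (0.1), initial point (12.0), and tolerance (1e-6) as above:

import torch

f = lambda x: 0.1*x**3 - 0.8*x**2 - 1.5*x + 5.4

x = torch.tensor(12., requires_grad=True)
for i in range(100):
    y = f(x)
    y.backward()           # Accumulate df/dx into x.grad
    with torch.no_grad():  # Exclude the update itself from autograd tracking
        step = 0.1 * x.grad
        x -= step          # Plain gradient descent: x <- x - lr * f'(x)
    x.grad.zero_()         # Clear the gradient for the next iteration
    if abs(step) < 1e-6:   # Stop when the update becomes negligible
        break
print(f'x = {x.item():.3f}, f(x) = {f(x).item():.3f}')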