Gradient Descent with PyTorch

I wrote some code that performs gradient descent with PyTorch.

The function to optimize

def rosenbrock(x0, x1):
    # a Rosenbrock-type function; its minimum value is 0 at (x0, x1) = (1, 1)
    y = 10 * (x1 - x0 ** 2) ** 2 + (x0 - 1) ** 2
    return y
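
A quick sanity check (not in the original post): at (x0, x1) = (1, 1) both terms vanish, so the value should be exactly 0.

print(rosenbrock(1.0, 1.0))  # -> 0.0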

Visualizing the function

import numpy as np

h = 0.01
x_min = -2
y_min = -3
x_max = 2
y_max = 5

X = np.arange(x_min, x_max, h)
Y = np.arange(y_min, y_max, h)

xx, yy = np.meshgrid(X, Y)
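
Both grids have the shape (number of Y points, number of X points), so evaluating the function on them gives one value per grid point. A quick check, not in the original post:

print(xx.shape, yy.shape)  # both are (len(Y), len(X))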

The minimum should be around here

# brute-force search over the grid for the smallest value and its location
matrix = rosenbrock(xx, yy)
minimum = None
min_x = None
min_y = None
for i in range(matrix.shape[0]):      # rows index Y
    for j in range(matrix.shape[1]):  # columns index X
        if minimum is None or minimum > matrix[i][j]:
            minimum = matrix[i][j]
            min_y = Y[i]
            min_x = X[j]

print(min_x, min_y, minimum)
1.0000000000000027 0.9999999999999147 8.208018832734106e-26
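
The same search can be done without the explicit double loop by using NumPy's argmin (a small sketch, not part of the original post):

idx = np.unravel_index(np.argmin(matrix), matrix.shape)  # (row, column) of the smallest entry
print(X[idx[1]], Y[idx[0]], matrix[idx])                 # same result as the loop above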

The dot marks the location of the minimum.

import matplotlib.pyplot as plt
# plot the square root of the function values to compress the dynamic range of the contours
plt.contourf(xx, yy, np.sqrt(rosenbrock(xx, yy)), alpha=0.5)
plt.scatter(min_x, min_y, c="k")
plt.colorbar()
plt.grid()
plt.show()

(Figure: contour plot of the function with the minimum marked by a black dot)

Gradient descent

import numpy as np
import torch

# starting point
x0 = torch.tensor(0.0, requires_grad=True)
x1 = torch.tensor(4.0, requires_grad=True)

lr = 0.001
iters = 10000

history = []
for i in range(iters):
    # record the current position as plain floats before updating
    history.append(np.array([x0.item(), x1.item()]))
    y = rosenbrock(x0, x1)
    y.backward()  # compute dy/dx0 and dy/dx1

    with torch.no_grad():
        # gradient descent step
        x0 -= lr * x0.grad
        x1 -= lr * x1.grad

        # reset gradients so they do not accumulate across iterations
        x0.grad.zero_()
        x1.grad.zero_()
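
For reference, the same update rule can be written with PyTorch's built-in optimizer; a minimal sketch using torch.optim.SGD (not part of the original post):

import torch

x0 = torch.tensor(0.0, requires_grad=True)
x1 = torch.tensor(4.0, requires_grad=True)
optimizer = torch.optim.SGD([x0, x1], lr=0.001)

for i in range(10000):
    optimizer.zero_grad()  # clear old gradients
    y = rosenbrock(x0, x1)
    y.backward()           # compute gradients
    optimizer.step()       # x <- x - lr * grad

print(x0.item(), x1.item())  # should approach (1, 1)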

Plotting the result

import matplotlib.pyplot as plt
plt.contourf(xx, yy, np.sqrt(rosenbrock(xx, yy)), alpha=0.5)
plt.scatter([p[0] for p in history], [p[1] for p in history])
plt.scatter(min_x, min_y, c="k")
plt.colorbar()
plt.grid()
plt.show()

(Figure: gradient descent trajectory overlaid on the contour plot)

Starting from multiple points

import numpy as np
import torch

# four starting points, optimized in parallel
x0 = torch.tensor([0.0, -0.5, -1.5, 1.5], requires_grad=True)
x1 = torch.tensor([4.0, 4.0, -1.0, -2.0], requires_grad=True)

lr = 0.001
iters = 10000

history = []
for i in range(iters):
    history.append([x0.detach().clone(), x1.detach().clone()])
    y = rosenbrock(x0, x1)
    # y is now a vector, so reduce it to a scalar before calling backward()
    s = torch.sum(y)
    s.backward()

    with torch.no_grad():
        x0 -= lr * x0.grad
        x1 -= lr * x1.grad

        x0.grad.zero_()
        x1.grad.zero_()
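
After the loop, x0 and x1 hold the final position of each of the four runs (not shown in the original post):

print(x0.detach().numpy())
print(x1.detach().numpy())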

Plotting the result

import matplotlib.pyplot as plt
plt.contourf(xx, yy, np.sqrt(rosenbrock(xx, yy)), alpha=0.5)
for i in range(4):
    plt.scatter([p[0][i] for p in history], [p[1][i] for p in history])
plt.colorbar()
plt.scatter(min_x, min_y, c="k")
plt.grid()
plt.show()

(Figure: trajectories from the four starting points on the contour plot)
