PyTorchに深層学習を勉強するためのTutorialに例題が豊富にあって大変有り難いのですが、初っ端から何をやっているのか全然わからなかったので、解説を付け加えました。
今回解らなかったのは勾配降下法を説明したページです。
そもそも勾配降下法とは
深層学習や機械学習では、人工的な関数(モデル)を作って、それをできるだけ自然界に近似させることをしているのですが、コンピュータは現時点の関数が目標に近づいているのか遠のいているのかを知らないといけません。それを知る方法というのが勾配降下法です。目標までの距離を図るのが目的関数あるいは損失関数と呼び、損失が最小になるように、モデルを最適化します。
例えば下図で、青い線を損失関数とすると、その最小値に到達するようにモデルを最適化しますが、現時点から左右のどっちに行けば最小値に近づくかは、その場所での傾き(勾配)から分かります。例えば、勾配がプラスなら(オレンジの線)最小値の右側にいることがわかるので、これから左側に向かうよう最適化しなければいけません。逆に勾配がマイナスなら(緑の線)最小値の左側にいるので、これから右に行く必要があります。
また、勾配の絶対値の大きさは最小値から離れれば離れるほど大きくなり、逆に最小値の部分での傾きは0になります。なので、勾配の符号と反対側に向かって、勾配が0になるように最適化していくのが、勾配降下法です。ちなみに、勾配は損失関数を微分することで得られます。
numpyによる実装例
PyTorch Tutorialでは手始めにnumpyによる実装例をしてしています。
三次関数
$y_{pred} = a + bx + cx^2 + dx^3$
をsin関数
$y_{target} = \sin(x)$
に近似させるというのがタスクです。
三次関数は $a, b, c, d$ の4つのパラメータあるので、これらを上手く調節していくのがここでの仕事です。
勾配降下法を使って、$a, b, c, d$ を少しづつ変えて、三次関数をsin関数に近づけて行きます。完全に近似することはできないので、ある区間だけ近似することになります。
では実際のコードを見てみましょう。
import numpy as np
import math
# xとターゲットとなるsin関数を準備します。
x = np.linspace(-math.pi, math.pi, 2000)
y = np.sin(x)
# a, b, c, dに適当に初期値を割り当てます。
a = np.random.randn()
b = np.random.randn()
c = np.random.randn()
d = np.random.randn()
# ここから勾配降下法。
learning_rate = 1e-6
for t in range(2000):
# 初期値のy_predを計算します。
# y = a + b x + c x^2 + d x^3
y_pred = a + b * x + c * x ** 2 + d * x ** 3
# 損失関数は差の2乗と定義します。
loss = np.square(y_pred - y).sum()
if t % 100 == 99:
print(t, loss)
#この部分が解らなかった。
# Backprop to compute gradients of a, b, c, d with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_a = grad_y_pred.sum()
grad_b = (grad_y_pred * x).sum()
grad_c = (grad_y_pred * x ** 2).sum()
grad_d = (grad_y_pred * x ** 3).sum()
# 勾配の大きさに応じて各パラメータを調整していきます。
a -= learning_rate * grad_a
b -= learning_rate * grad_b
c -= learning_rate * grad_c
d -= learning_rate * grad_d
print(f'Result: y = {a} + {b} x + {c} x^2 + {d} x^3')
勾配降下法の中の、勾配を計算している部分がなぜそうなるのか解らなかったのでここで説明しておきます。
勾配は何を表しているの?
ここでしたいのは、$a, b, c, d$ を調節して損失関数の値を最小にすることです。ならば、損失関数を $a, b, c, d$ に対して微分して勾配を求め、それを小さくしていけばよいのです。
例えば、損失関数を(1)のように定義して、これを$a$で微分してみます。
$Loss = (y_{pred} - y_{target})^2$ (1)
$\frac{\partial Loss}{\partial a} =
\frac{\partial Loss}{\partial y_{pred}} \cdot \frac{\partial y_{pred}}{\partial a} =
\frac{\partial Loss}{\partial y_{pred}} \cdot \frac{\partial (a + bx + cx^2 +
dx^3)}{\partial a}$
ここで注目したいのは、変数は$a$なので、$a$以外は定数扱いです。いつも$x$が変数に使われているので、大変違和感があるかもしれませんが、あくまで$a$が変数です。
$\frac{\partial (a + bx + cx^2 + dx^3)}{\partial a} =
\frac{\partial a}{\partial a} + \frac{\partial (bx + cx^2 + dx^3)}{\partial a} =
1 + 0$
よって、損失関数の$a$による微分(変数$a$に対する勾配)はこのようになりました。
$\frac{\partial Loss}{\partial a} =
\frac{\partial Loss}{\partial y_{pred}} \cdot (1 + 0) =
\frac{\partial Loss}{\partial y_{pred}}$ (2)
同じような手法で残りの$b, c, d$も微分していきます。
$\frac{\partial Loss}{\partial b} =
\frac{\partial Loss}{\partial y_{pred}} \cdot \frac{\partial y_{pred}}{\partial b} =
\frac{\partial Loss}{\partial y_{pred}} \cdot \frac{\partial (a + bx + cx^2 +
dx^3)}{\partial b}$
$\frac{\partial Loss}{\partial b} =
\frac{\partial Loss}{\partial y_{pred}} \cdot (\frac{x \partial b}{\partial b} +
\frac{\partial (a + cx^2 + dx^3)}{\partial b}) =
\frac{\partial Loss}{\partial y_{pred}} \cdot (x + 0) =
\frac{\partial Loss}{\partial y_{pred}}x$ (3)
$\frac{\partial Loss}{\partial c} =
\frac{\partial Loss}{\partial y_{pred}} \cdot \frac{\partial y_{pred}}{\partial c} =
\frac{\partial Loss}{\partial y_{pred}} \cdot \frac{\partial (a + bx + cx^2 +
dx^3)}{\partial c}$
$\frac{\partial Loss}{\partial c} =
\frac{\partial Loss}{\partial y_{pred}} \cdot (\frac{x^2 \partial c}{\partial c} +
\frac{\partial (a + bx + dx^3)}{\partial c}) =
\frac{\partial Loss}{\partial y_{pred}} \cdot (x^2 + 0) = \frac{\partial
Loss}{\partial y_{pred}}x^2$ (4)
$\frac{\partial Loss}{\partial d} =
\frac{\partial Loss}{\partial y_{pred}} \cdot \frac{\partial y_{pred}}{\partial d} =
\frac{\partial Loss}{\partial y_{pred}} \cdot \frac{\partial (a + bx + cx^2 +
dx^3)}{\partial d}$
$\frac{\partial Loss}{\partial d} =
\frac{\partial Loss}{\partial y_{pred}} \cdot (\frac{x^3 \partial d}{\partial d} +
\frac{\partial (a + bx + cx^2)}{\partial d}) =
\frac{\partial Loss}{\partial y_{pred}} \cdot (x^3 + 0) = \frac{\partial
Loss}{\partial y_{pred}}x^3$ (5)
ところで、同じように損失関数を$y_{pred}$で微分すると
$\frac{\partial Loss}{\partial y_{pred}} = 2(y_{pred} - y_{target})$
が得られるので、これを(2)〜(5)に代入すると
$\frac{\partial Loss}{\partial a} = 2(y_{pred} - y_{target})$
$\frac{\partial Loss}{\partial b} = 2(y_{pred} - y_{target})x$
$\frac{\partial Loss}{\partial c} = 2(y_{pred} - y_{target})x^2$
$\frac{\partial Loss}{\partial d} = 2(y_{pred} - y_{target})x^3$
が得られ、コードの中の解らなかった部分と完全に一致しました。
最後に
勾配降下法によって三次関数がどんどんsin関数に近づいていく過程を可視化しました。
ソースコードはこちらに載せてあります。