勾配降下法についての個人的まとめです。
なるべく丁寧にまとめたつもりですが、お気付きの点があったら是非ご指摘ください。
勾配降下法とは
ネットに転がっていた、様々な説明を抜き出してみます。
・最適化探索の方法の一つ。コスト関数の勾配と逆方向に重みを調整することで、コスト関数を最小化させる方法。
・誤差関数の勾配の下に降下していく方法。
・目的関数を最小化させるための方法で、目的関数の勾配の逆方向にパラメータを更新していくことで実現される。
パッと見た感じ、あれですね。難しい言葉が使われている上に用語が統一されてませんね。
ただ一貫して言われることを抜き出せば、こんなところでしょうか。
「ある関数について、その値が最小となるようなパラメータ(係数)を見つけ出す方法。関数が為す谷(勾配)の谷底方向に進むようにパラメータを更新していくことで目的パラメータを発見する。」
ある関数の最小値を知りたいときに使用する概念のようですね。
勾配降下法が必要なときとは
この勾配降下法、いつ使うんでしょうね。
「この関数の最小値が知りたくてたまらないー!」
なんて局所的な欲求があれば別ですけど、そうでもなければいまいち使いどころがわからない気もします。
調べてみました。例えばこんな時に使うようです。
例えば回帰直線を引きたいとき
ある分散したデータが存在するとして、その傾向から回帰直線を引いて未知のデータについて予測をしたい、なんてことがあったとします。(あるとしますよ?)
この時どうやって回帰直線を引くかというと、方法の一つとして誤差の最小を求めるという方法があります。
すなわち、
直線 y=ax+b と各分散値(x, y)との距離が最小となるような y=ax+b の直線
を生成することが出来れば、求めているモデルが得られるじゃないか
という考え方です。
上で述べたような、ある関数と入力値の差を「誤差」といい、その誤差の和を求める式は以下のようになります。
E = \sum_{i=1}^n(y^{(i)} - f(x^{(i)}))^2
目的の直線を表す関数を f とした時、入力値(x, y)における y の値から、目的の直線である f に x を引数として渡した時の返り値を引くことで誤差がわかる訳です。
前提として入力値は複数ある訳ですから誤差も複数存在する。そこで全ての誤差を足し合わせるためにΣを使用する。
こうして、ここで表現される E という数値は、
「分散した値と関数fの誤差の和」
ということができる訳です。
関数の最小を求めたいという動機まとめ
ということで、最初に考えていた「関数の最小を求めたい」という動機が生まれましたね。整理しましょう。
-1.分散したデータから回帰直線を作りたい!
-2.分散したデータとの距離(誤差)という概念で、直線の生成にアプローチできる!
-3.それぞれのデータとの誤差の和が最小になる関数を見つけられたら、それが求めている直線になるらしい!
-4.じゃあ関数の最小値を知りたい!
では、どうやって関数の最小値を求めるのでしょうか。
そこで出てくるのが、最初にあった「関数が為す谷(勾配)の谷底方向に進むようにパラメータを更新していく」という考え方です。
プログラムによる実装
誤差関数のプロットから回帰直線を求める
勾配降下法を用いる動機は先に述べたように、ある関数の最小値を求めるということです。
今回の例のように回帰直線を生成したい場合、目的関数 f(x) と誤差関数 E という概念が必要になります。
目的関数は一般的な一次関数とするとして、後者の誤差関数についてもう一度おさらいしておきましょう。
E = \sum_{i=1}^n(y^{(i)} - f(x^{(i)}))^2
xとyの値は分散データの値ですから固定値であるとすると、ここで誤差関数の値の増減に関係してくるのは f(x) の中身ということになります。
求める関数を一次関数としたとき f(x) = a + bx と置くことができますね。
これらのことを元にして、段階的にプログラムを組んでいきましょう。
なお今回は理解を簡単にするため、 a=0 、つまり切片は存在しないものとしてアルゴリズムを組んでいきます。
まずは今回の分散データの可視化から。
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
# 分散データを用意
data = pd.read_csv('0304.csv')
data_x = data.iloc[:, 0]
data_y = data.iloc[:, 1]
# 分散データを図示
plt.plot(data_x, data_y , 'o', label='data')
plt.legend(loc='lower right')
plt.show()
仮に パラメータ a = 0 ,b = 3 とすると、このような直線になりますね。
一見してわかるように、これは分散に対する適切な回帰直線とは言えません。
ここでは何が間違っているかというと、パラメータして設定した a=0 b=3 が誤っているわけです。
では何が適切なパラメータなのか。それを探すために誤差関数を作っていきます。
上でも述べましたが、今回、目的関数 f(x) には切片は存在しないものとします。
# 目的関数 f(x)=ax を定義
def f(a, x):
return a * x
# 誤差関数の定義
def E(a, data_x, data_y):
sum = 0
for x, y in zip (data_x, data_y):
tmp = (y - f(a, x)) ** 2
sum = sum + tmp
return sum
# パラメータの取りうる値
a_array = np.arange(-10,10)
# パラメータを変化させながら誤差を算出する関数
def gradiate_error(data_x, data_y):
E_array = np.array([])
for a in a_array:
result = E(a, data_x, data_y)
E_array = np.append(E_array, result)
return E_array
ここでは3つの関数を定義しています。
1つは関数f(x)です。これは現在のパラメータを受け取り、xとかけ合わせた値を返します。
1つは誤差関数です。パラメータに基づき算出した値とyとの誤差の和を返します。この値を最小にすることが目的でした。
1つはパラメータを変更しながら誤差関数を実行する関数です。何が正解のパラメータかが分からないので、様々なパラメータを設定して誤差関数を実行する必要があります。今回はパラメータの範囲として -10 ~ 10 を設定しています。
この結果を可視化すると、このようになります。
plt.plot(a_array, gradiate_error(data_x, data_y))
plt.xlabel('a value')
plt.ylabel('E value')
plt.legend(loc='lower right')
plt.show()
x軸がパラメータの値、y軸がその時の誤差の和となります。
グラフをみてみると、パラメータが1付近の時に誤差和が最小になっていることがわかりますね。
ではパラメータに1を設定して、もう一度回帰直線を引いてみましょう。
これは分散との相関が大きい、適切な回帰直線が引けたと言えそうですね。
このように、誤差関数E の和が最小になるようにパラメータを設定すると、分散に対して適切な回帰直線を引くことができます。
勾配降下法
あれ、勾配降下法、使ってませんね。
先の例では、誤差関数のプロットをみて、「パラメータが1くらいの時に最小になっていそうだ」という判断に基づいてパラメータを設定しました。
しかし、厳密に言えば誤差関数が最小になるのはパラメータが1、なんていう綺麗な数字のときではないかもしれません。
誤差関数のプロットから、厳密に最小のパラメータを見つけ出すアルゴリズム、それが勾配降下法です。
ここで「誤差関数の値が最小となるパラメータ」を発見するために見ていくのは、「それぞれのパラメータにおけるグラフの傾き」です。傾きが最も小さい接線を引けるポイント、そこにおけるパラメータが、厳密の意味での目的パラメータとなります。パッと見で決めてはいけないんですね。
例えば、このパラメータが 5 の時の接線のグラフはこのようになります。
def make_tangent(a,data_x,data_y):
slope = get_slope(a, data_x, data_y)
segment = get_segment(a, data_x, data_y)
result = slope * a_array + segment
return result
def get_slope(a, data_x, data_y):
h = 0.000001
sumA = E(a + h, data_x, data_y)
sumB = E(a, data_x, data_y)
return (sumA - sumB) / h
def get_segment(a, data_x, data_y):
E_value = E(a, data_x, data_y)
slope = get_slope(a, data_x, data_y)
segment = E_value - (slope * a)
return segment
plt.plot(a_array, gradiate_error(data_x, data_y))
plt.plot(a_array, make_tangent(5,data_x,data_y))
plt.xlabel('a value')
plt.ylabel('E value')
plt.legend(loc='lower right')
plt.show()
print('パラメータ5の時の接線の傾き : {}'.format(get_slope(5, data_x, data_y)))
この時の接線の傾きは、 302920.0385790318 となります。よく分からない数字ですね。
何はともあれ、これが最も小さい数になれば良いわけです。
そして、この「勾配(傾き)が0に近づくように(降下するように)パラメータを更新していくというアルゴリズムが、「勾配降下法」になります。ここまで長かったですね。
勾配降下法の実装
ここまでの内容を整理して見ましょう。
- 誤差関数 E の最小の値を求めたい
- そのため、それを実現するような目的関数 f(x) のパラメータを求めたい
- パラメータを変化させてそれぞれに対応した誤差和をプロットした時、そのグラフにおける各パラメータについての接線の傾きが最小となる時のパラメータが、求めているパラメータである
接線の傾きが最小となるパラメータの探し方ですが、考え方としては、誤差関数グラフを形成している谷を下へ下へと下っていくイメージになります。求めているのは結局、谷のようなグラフの谷底地点におけるパラメータですからね。
では、実装します。
# パラメータの更新方向を定義する関数
def get_direct(param, data_x, data_y):
paramA = param
paramB = param + 0.1
resultA = get_slope(paramA, data_x, data_y)
resultB = get_slope(paramB, data_x, data_y)
result = resultA**2 - resultB**2
if result < 0:
return 1
else :
return 0
# 更新回数を定義する
count = 1000
# パラメータの更新率
step = 0.01
# 勾配が最小になるパラメータを探索する
def search_min_param(data_x, data_y):
# 探索開始点をランダムに決定する
param = np.random.randint(-10,10)
# パラメータの更新方向を取得
direction = get_direct(param, data_x, data_y)
# 指定回数、パラメータの更新を行う
for _ in range(count):
# 勾配の向きに応じてパラメータの更新方向を変更する
if get_direct(param, data_x, data_y) == 1:
param = param - step
else :
param = param + step
# 発見したパラメータを返却する
return param
print('パラメータ探索1回目 param : {}'.format(search_min_param(data_x, data_y)))
print('パラメータ探索2回目 param : {}'.format(search_min_param(data_x, data_y)))
print('パラメータ探索3回目 param : {}'.format(search_min_param(data_x, data_y)))
#出力結果
パラメータ探索1回目 param : 1.020000000000127
パラメータ探索2回目 param : 1.0199999999998743
パラメータ探索3回目 param : 1.0199999999999596
各関数の意味はコード内のコメントを参照してください。
ざっくり言うと、
- 1.パラメータの更新方向を求める(勾配は右肩上がりの時と右肩下がりの時があるため)
- 2.パラメータの更新回数を決める(回数を指定しないと無限に更新してしまうため)
- 3.パラメータの更新率を決める(一更新あたりにどれくらいパラメータをいじるのか)
- 4.パラメータを更新する
という流れになります。
パラメータはランダムな値で初期化しており、そこから指定回数の更新をしているので、複数回実行すると結果は微妙に変わります。ただ、回数を十分に行なっていれば、だいたい同じような結果に収束するはずです。
以上より、今回の分散データにおける回帰直線の最適パラメータは 約1.02 ということがわかりました。
最後に、適切な回帰直線を可視化して終わりましょう。
plt.plot(data_x, data_y , 'o', label='data')
plt.plot(data_x, f(param,data_x),linestyle='solid', label='f(x)')
plt.legend(loc='lower right')
plt.show()
終わりに
以上が勾配降下法のまとめでした。
関数などは独自に作っているので最適化されていないものがほとんどだと思いますが、理解の一助になれば幸いです。
以上。