はじめに
スタンフォード大学 Andrew Ng氏による機械学習の講義動画を各週社内メンバーで持ち回りで受講し、学んだことを発表、Qiitaで共有しています。
第一週は以下の内容でした。
- Introduction
- Linear Regression with One Variable
- Linear Algebra Review
この中の Linear Regression with One Variable (線形単回帰)について共有します。
目標
以下の数式を通して、線形単回帰について学んでいきます。
仮定関数
h_{\theta}(x) = \theta_0 + \theta_1x
誤差関数
J(\theta_0, \theta_1) = \frac{1}{2m}\sum_{i=1}^{m} (h_{\theta(x_i)}-y_i)^2
最急降下法
minimize\ J(\theta_0,\theta_1)
repeat\ until\ convergence\{\\
\theta_j := \theta_j-\alpha\frac{\partial}{\partial\theta_j}J(\theta_0,\theta_1)\\
(for\ j=0\ and\ j=1)\}
いい感じの直線を引こう! (仮定関数について)
まずは1つ目の数式です。
例えば、このように説明変数 $x$ と 目的変数 $y$ のペアデータがグラフ上にプロットされているとします。
この説明変数 $x$ と 目的変数 $y$ には、直線的な関係があるように見えます。
実際、このような直線を引くことでこれらデータに適した関係を表せそうな気がします。
試しにグラフ上に直線 $y=x+1$ を引いてみました。
確かにうまくデータの関係を表せてそうな直線ですね!しかし、この直線はテキトーに引いた直線 $y=x+1$ です。もしかすると他の人が引いたオレンジの直線 $y=1.4x+1.3$ の方がデータに適した関係をうまく表せているかもしれません。テキトーに線を引くだけでは意見が割れてしまいます。
どのような直線が「いい感じの直線」なのでしょう?
「直線」を表すための関数が、仮定関数 です。
h_{\theta}(x) = \theta_0 + \theta_1x
この関数 $h_{\theta}(x)$ は2次元のグラフの直線の方程式であり、$\theta_1$ が直線の傾き、$\theta_0$ が直線のy切片を表しています。この2つの $\theta$ によって直線がどのように引かれるかが決まります。
これでまず、「いい感じの直線」が引ける、$(\theta_0,\theta_1)$ の値を見つけに行こう!という方針ができました。
それでは次に、「いい感じ」とは何なのか!?について考えていきます。
補足
中学高校数学では直線の方程式は $\theta$ を使わず、以下のように表していたかと思います。
y = ax + b\\
しかし、今後複数のパラメータを使っていく可能性があることを考慮すると、 $a, b, c \cdots$ では不便です。以下のように、$z$まで使い果たすと表現がややこしくなります。
y = ax_1 + bx_2 + c\\
y = ax_1 + bx_2 + cx_3 + d\\
y = ax_1 + bx_2 + cx_3 + dx_4 + \cdots +z\\
そのため、$\theta$ の右下に数字を添えることでパラメータを扱いやすくしています。(本記事は線形単回帰についてなので、パラメータは最大2つです)
「いい感じ」具合を測ろう!(誤差関数について)
最適なパラメータ $\theta$ を探索するためには基準となる指標が必要です。機械学習ではこの指標のことを誤差関数または損失関数と呼びます。誤差関数は仮定関数の性能の悪さを示す指標です。
下のグラフのように、仮定関数 $h_{\theta}(x_i)$ と本当のデータ $y_i$ には誤差があります。
誤差は $h_{\theta}(x_i) -y_i $ で測ることができます。
各データの誤差を二乗した総和 $J(\theta_0, \theta_1)$ によって性能の悪さを測ることができます。この誤差関数は2乗和誤差関数と呼ばれます。
J(\theta_0, \theta_1) = \frac{1}{2m}\sum_{i=1}^{m} (h_{\theta(x_i)}-y_i)^2
仮定関数の性能の良さを最大化するということはつまり、2乗和誤差関数 $J(\theta_0, \theta_1)$ を最小化する = 性能の悪さを最小化するということと同じです。
これで「いい感じ」具合を測定可能になりました!最後にやるべきことは、誤差関数が最小になるパラメータ $\theta$ を探すことです!
補足
なぜ誤差を2乗するのか?なぜ 1/2 するのか?
のちの計算を楽にするためです。
こちらの記事にて、詳しく説明してくださっています。
その他誤差関数
誤差関数には色々な種類があり、場合に応じて使い分ける必要があります。
こちらの記事にて、誤差関数について詳しく説明してくださっています。
- [機械学習で抑えておくべき損失関数(回帰編)]
(https://www.hellocybernetics.tech/entry/2017/06/19/084210#%E8%89%B2%E3%80%85%E3%81%AA%E6%90%8D%E5%A4%B1%E9%96%A2%E6%95%B0)
「いい感じの点」を探そう!(最急降下法について)
簡単のため、誤差関数を $\theta_1$ のみの一変数の関数にします。
J(\theta_1) =\frac{1}{2m}\sum_{i=1}^{m} (\theta_1x_i-y_i)^2
グラフは 下に凸な $\theta_1$ の二次関数になります。最小点は微分して $0$ になるところですね。簡単に求まりそうです。
しかし、一般に誤差関数はもっと複雑です。今回の2次関数のように最小点が初めから検討がつくということは稀であり、どこに最小点を取るのか検討がつきません。最小と思える箇所が複数あったりします。
そこで、任意のパラメータからスタートし、関数の値を減らす方向 = 勾配を利用して、できるだけ最小でありそうな値を探そう!という方針を取る必要があります。
話を一変数の二次関数に戻しましょう。
一変数の場合、勾配とはグラフの接線の傾きそのものになります。よって関数 $J(\theta_1)$ を微分した値
\frac{d}{d\theta_1}J(\theta_1)
が勾配です。最小点よりも右側の $\theta_1$ の勾配は正の値に、左側は負の値になります。任意に指定したパラメータが最小点を示していない場合、その点の勾配分最小点の方向に降りていくアルゴリズムを作ります。
\theta_1 := \theta_1-\alpha\frac{d}{d\theta_j}J(\theta_1)
$:=$ は代入演算子と呼ばれ、右辺を左辺に代入し、値を更新します。
$\alpha$ は学習率と呼ばれ、一回の学習でどれだけパラメータを更新するか、を制御します。$\alpha$ の値はあらかじめ0.01や0.001など、前もって決めておきます。大きすぎても小さすぎても、学習がうまくいかないので、正しく学習できているかを確かめながら学習率を調整する必要があります。
誤差関数を最小にするという制約のもと、この計算を更新がされなくなるまで、つまり収束(convergence)するまで繰り返します。
minimize\ J(\theta_1)
repeat\ until\ convergence\{\\
\theta_1 := \theta_1-\alpha\frac{d}{d\theta_1}J(\theta_1)\}
こうして得られたパラメータ $\theta_1$ こそが、誤差関数を最小化するものだということがわかりました。
そして、パラメータ2つ以上の場合の勾配は偏微分をしてやれば良いので、変数が増えても同じような式によって更新されます。
minimize\ J(\theta_0,\theta_1)
repeat\ until\ convergence\{\\
\theta_j := \theta_j-\alpha\frac{\partial}{\partial\theta_j}J(\theta_0,\theta_1)\\
(for\ j=0\ and\ j=1)\}
これで最もシンプルな機械学習アルゴリズムである最急降下法の説明が終わりました!
補足
ちなみに、誤差関数の偏微分を解くことで以下の数式を得ることができます。
\frac{\partial}{\partial\theta_j}J(\theta_0,\theta_1) = 0
\theta_1 = \frac{\sigma_{xy}}{\sigma^2_x}
\theta_0 = \overline{y}-\theta_1\overline{x}
- 共分散を変数 $x$ の分散で割って回帰直線の傾きを得る
- 2つの変数の平均値と傾きから、回帰直線の $y$ 切片を得る
統計学の教科書にあるこの公式を丸暗記してもよく理解できなかったが、このような過程で導出されたと知り、納得ができました。
こちらの記事にて、詳しく説明してくださっています。
まとめ
Coursera の Andrew Ng 先生の講義を元に、最もシンプルな線形単回帰の学習アルゴリズムについて学びました。
- データに当てはまりそうな「直線」の数式 (仮定関数)
- 直線の性能の悪さを測定する数式 (誤差関数)
- 性能の悪さを最小化する数式 (最急降下法)
次回、Coursera Machine Learning Week2 線形重回帰について学んでいきます。