※以下の企画です
前回に続いてゼロつくやっていきます。
今回も4章 -ニューラルネットワークの学習-に入ります。
いよいよ機械学習らしくなってくる章です。たのしみ〜
それでは頑張っていきます〜
4章 ニューラルネットワークの学習
損失関数
そもそもニューラルネットワークにおける「学習」とは、パラメータ(重み$w$や、バイアス$b$など)を教師データを基に自分で算出することを言う。
その際に、正解データと出力結果を比較してどれだけ間違っているかを表す指標として「損失関数」というものを用いる。
今回は損失関数の種類について学んでいく〜
二乗和誤差
二乗和誤差は、ニューラルネットワークの出力結果と正解データとの誤差を二乗して、それらをすべて足し合わせたもの。
以下の式で表される。
$$
E = \frac{1}{2} \sum_{k=1}^{K} (y_k - t_k)^2
$$
- $( y_{k} )$: ニューラルネットワークの出力(予測値)
- $( t_k )$: 正解データ
- $( K )$: データの次元数
先述の通り、損失関数はそのニューラルネットワークの精度の悪さを表すので、予測値と正解データの差分が大きいほど二乗和誤差の答えも大きくなる。
一応コードにも起こしておく。
def sum_square_error(y, t):
return 0.5 * np.sum((y-t)**2)
予測が合っているときと外しているときで、ちゃんと誤差が大きくなるかをテストしてみる。
データとしては、t
が正解データでなり、y
がニューラルネットワークからの出力だと仮定する。
やりたいことは、t
の配列のうち、どれが"1"になっているかを当てる。y
の出力結果は、前回学んだ"ソフトマックス関数"によって得られた確率値である。
y = [0.1, 0.05, 0.6, 0.05, 0.1, 0.1]
t = [0, 0, 1, 0, 0, 0]
sum_square_error(np.array(y), np.array(t))
0.09750000000000003
y = [0.1, 0.05, 0.1, 0.05, 0.1, 0.6]
t = [0, 0, 1, 0, 0, 0]
sum_square_error(np.array(y), np.array(t))
0.5974999999999999
ちゃんと外れているときの出力値の方が大きくなっていることがわかる。
交差エントロピー誤差
交差エントロピー誤差は、主に分類問題で使われる損失関数。出力結果が確率として解釈される場合に有効っぽい。
以下の式で表される。
$$
E = - \sum_{k=1}^{K} t_k \log y_k
$$
- $( t_k )$: 正解ラベル(通常「1(正解)」または「0(不正解)」)
- $( y_k )$: ニューラルネットワークの出力(ソフトマックス関数を適用した確率値)
実質的には「出力値の自然対数を計算するだけ」の関数。
出力が1になれば$log1 = 0$になって、損失関数が0になるという仕組み。
def cross_entropy_error(y, t):
delta = 1e-7
return -np.sum(t * np.log(y + delta))
コードもほぼそのまんまだが、一点だけ注意点があった。
それはdelta = 1e-7
の部分で、めちゃくちゃ小さい数字をy
に加えている。
理由としては、y=0
になったときno.log(0) = -inf
という爆発を起こしてしまうため、その防止策として加える必要があるみたい。なるほど。
せっかくなので二乗和誤差と同様のテストを行ってみる。
y = [0.1, 0.05, 0.6, 0.05, 0.1, 0.1]
t = [0, 0, 1, 0, 0, 0]
cross_entropy_error(np.array(y), np.array(t))
0.09750000000000003
y = [0.1, 0.05, 0.1, 0.05, 0.1, 0.6]
t = [0, 0, 1, 0, 0, 0]
cross_entropy_error(np.array(y), np.array(t))
2.302584092994546
こちらもちゃんと想定通りの結果になった。
ミニバッチ学習
ミニバッチ学習とは?
機械学習では通常、大量のデータを使ってモデルを学習させる。しかし、すべてのデータを一度に計算するのは計算コストが高くなる。そのため、データをいくつかの「小さなグループ(ミニバッチ)」に分けて順番に学習する方法を「ミニバッチ学習」と呼ぶとのこと。
ミニバッチ学習の流れ
- 学習データをランダムにシャッフルする
- 一定サイズのバッチに分割する(例: バッチサイズ = 32とか)
- 各バッチで損失関数を計算し、パラメータを更新する
これは実践的に学んだほうがガッテン行きそうなので一旦頭にいれておく・・・
まとめ
今回は損失関数の代表的な2つを学んだ。
活性化関数とか損失関数とか、なぜ存在していて、どんな種類があるのか とかはちゃんと説明できるように復習していきたい。