PRML(パターン認識と機械学習:Pattern Recognition and Machine Learning)は,機械学習の基礎知識を身につけるのに非常にオススメな書籍です.
本記事では,PRMLの第1章にて紹介されている**「正則化」について細部まで説明し,これをJulia**で実装してみます.
正則化
GitHubのリポジトリ→こちら
ソースコード(.jl,.ipynb)→こちら
概要
前回は,「多項式曲線フィッティング」について説明しました.
そこでは,フィッティングするn次関数のパラメータの数「n+1」が大きいほど,訓練データに対する二乗誤差は小さくなりやすいですが,ある一定のパラメータ数を超えると,テストデータに対する二乗誤差は逆に大きくなりやすい(過学習)ことがわかりました.
以下は,「y=sin(x)」にノイズを加えて生成したデータを用いて,近似した関数をグラフとしてプロットした例です.
9次関数は,観測データ点にかなり近い場所を通っていることがわかりますが,生成する際に用いた関数「y=sin(x)」からは遠い点を通っていることがわかります.
これは,新たに入力データxから観測データy(テストデータ)を得たときに,それに対する誤差は大きくなってしまうことを示しています.
以下の「train RMSE」は,もともと与えられた観測データ(訓練データ)であり,「test RMSE」は,0.0〜1.0の一様乱数により生成した入力データx(テストデータ)を10000個用意し,この出力yとn次の近似関数の平均二乗平方根誤差(RMSE)を示しています.
この結果から,テストデータに対してはn=5が最も良い結果が得られていることがわかります.
したがって,この観測データにおいては,9次関数ではなく5次関数による近似が最も有効であり,**パラメータ数が大きい場合に「モデル(近似関数)が訓練データのみにフィッティングしやすくなる」過学習を起こしてしまう傾向がある**ことがわかりました.
一方,1次関数と9次関数のときのグラフや二乗誤差の違いを見てもわかる通り,パラメータ数を増やすことによってモデルの柔軟性が増す場合があります.
ゆえに,モデルを過学習させないだけではなく,パラメータ数が十分にあることも必要になります.
また,そもそも近似関数を決めるための観測データyが多い場合には,近似関数により予測する点の補間の必要がなくなるため性能が向上することがわかっています.
しかし,訓練データの数やデータの種類の膨大な組み合わせを考えた場合,毎回適切なパラメータ数を決めるのは大変です.
そこで,少ないデータの数を保ったまま,あるいは膨大なパラメータの数を保ったまま,モデルを過学習させない方法があり,これを**正則化**といいます.
理論
理論に入る前に,以下の表を見てください.
以下は,概要にて紹介した観測データにおける,各n次関数のパラメータwです.
※小数第3位までの概数として表示しています.
n=1 | n=2 | n=3 | n=4 | n=5 | n=6 | n=7 | n=8 | n=9 | |
---|---|---|---|---|---|---|---|---|---|
w0 | 0.569 | 0.505 | -0.209 | -0.193 | -0.155 | -0.168 | -0.164 | -0.163 | -0.163 |
w1 | -1.183 | -0.748 | 11.573 | 10.914 | 6.602 | 11.703 | 3.481 | -35.207 | -20.313 |
w2 | -0.435 | -33.110 | -29.746 | 6.642 | -56.395 | 76.560 | 823.221 | 499.399 | |
w3 | 21.783 | 16.372 | -86.830 | 184.125 | -594.425 | -6069.006 | -3276.527 | ||
w4 | 2.706 | 121.120 | -402.065 | 1783.890 | 22075.154 | 9350.228 | |||
w5 | -47.366 | 417.884 | -2758.670 | -44564.133 | -10599.087 | ||||
w6 | -155.083 | 2147.088 | 50650.066 | -4253.415 | |||||
w7 | -657.763 | -30287.327 | 22582.880 | ||||||
w8 | 7407.391 | -20480.280 | |||||||
w9 | 6129.278 |
上記の表から,パラメータ数が多いほうがパラメータwの絶対値が大きい傾向にあることがわかります.
これは,モデルが訓練データにフィッティングしやすいように,柔軟にパラメータが決まってしまうからです.
正則化は,この性質を利用して過学習を抑えます.
つまり,**「モデルが過学習している=各パラメータの絶対値が大きいのでは?」と考えるのです.
したがって,各パラメータの絶対値を小さくすれば,モデルの過学習を抑えることができる**ということになります.
この思想に基づいて,以下の内容を説明していきます.
L2正則化(リッジ回帰)
正則化の一つとして,**「各パラメータの絶対値の二乗和」を罰則とするL2正則化(リッジ回帰)**があります.
これは,パラメータを二乗したことによって微分可能となり,簡単にパラメータを求めることができるため,よく使われます.
(1)式はノルムを一般化したLpノルム,(2)式はL2ノルム,(3)式は(2)式を用いて二乗和誤差にL2正則化項を追加した最終的な誤差関数です.
なお,係数「λ」は正則化パラメータといい,正則化項と二乗和誤差との相対的な重要度(λ≧0)を示しています.
||\boldsymbol{w}||_p = \sqrt[p]{|w_0|^p+|w_1|^p+|w_2|^p+\cdots+|w_n|^p} ...(1) \\
||\boldsymbol{w}||_2 = \sqrt{|w_0|^2+|w_1|^2+|w_2|^2+\cdots+|w_n|^2} ...(2) \\
E = \sum_{i=1}^{m}(y_i-f(x_i, \boldsymbol{w}))^2 + \lambda ||\boldsymbol{w}||_2^2 ...(3) \\
ただし,f(x_i, \boldsymbol{w}) = w_0 + w_1 x_i + w_2 x_i^2 + \cdots + w_n x_i^n = \sum_{j=0}^{n}w_j x_i^j
それでは,(3)式を最小とするwを求めていきます.
正則化項がない場合と同様に,任意のwを表す数l(0≦l≦n かつ 整数)に対してw_l^2の係数は正であるため,二乗和誤差Eが最小となる点は各パラメータを変数とする偏微分により求めることができます.
まず,(3)式を整理し(4)式にします.
\begin{eqnarray}
E &=& \sum_{i=1}^{m}(y_i-f(x_i, \boldsymbol{w}))^2 + \lambda ||\boldsymbol{w}||_2^2 \\
&=& \sum_{i=1}^{m}(y_i-w_0-w_1x_i-w_2x_i^2-\cdots-w_nx_i^n)^2 + \lambda (w_0^2+w_1^2+w_2^2+\cdots+w_n^2) ...(4)
\end{eqnarray}
(4)式を各パラメータwに対して偏微分します.
\left\{\begin{array}{ll}
\frac{\partial E}{\partial w_0} &=& \sum_{i=1}^{m}-2(y_i-w_0-w_1x_i-w_2x_i^2-\cdots-w_nx_i^n) &+& 2 \lambda w_0 &=& 0 \\
\frac{\partial E}{\partial w_1} &=& \sum_{i=1}^{m}-2(y_i-w_0-w_1x_i-w_2x_i^2-\cdots-w_nx_i^n)x_i &+& 2 \lambda w_1 &=& 0 \\
\frac{\partial E}{\partial w_2} &=& \sum_{i=1}^{m}-2(y_i-w_0-w_1x_i-w_2x_i^2-\cdots-w_nx_i^n)x_i^2 &+& 2 \lambda w_2 &=& 0 \\
&& \vdots \\
\frac{\partial E}{\partial w_n} &=& \sum_{i=1}^{m}-2(y_i-w_0-w_1x_i-w_2x_i^2-\cdots-w_nx_i^n)x_i^n &+& 2 \lambda w_n &=& 0
\end{array}\right.
\begin{eqnarray}
\sum_{i=1}^{m}w_0 + \sum_{i=1}^{m}w_1x_i + \sum_{i=1}^{m}w_2x_i^2 + \cdots + \sum_{i=1}^{m}w_nx_i^n +\lambda w_0 &=& \sum_{i=1}^{m}y_i \\
\sum_{i=1}^{m}w_0x_i + \sum_{i=1}^{m}w_1x_i^2 + \sum_{i=1}^{m}w_2x_i^3 + \cdots + \sum_{i=1}^{m}w_nx_i^{n+1} +\lambda w_1 &=& \sum_{i=1}^{m}y_ix_i \\
\sum_{i=1}^{m}w_0x_i^2 + \sum_{i=1}^{m}w_1x_i^3 + \sum_{i=1}^{m}w_2x_i^4 + \cdots + \sum_{i=1}^{m}w_nx_i^{n+2} +\lambda w_2 &=& \sum_{i=1}^{m}y_ix_i^2 \\
\vdots \\
\sum_{i=1}^{m}w_0x_i^n + \sum_{i=1}^{m}w_1x_i^{n+1} + \sum_{i=1}^{m}w_2x_i^{n+2} + \cdots + \sum_{i=1}^{m}w_nx_i^{2n} +\lambda w_n &=& \sum_{i=1}^{m}y_ix_i^n
\end{eqnarray}
これを行列にすると,
\begin{eqnarray}
\begin{pmatrix}
m & \sum_{i=1}^{m}x_i & \sum_{i=1}^{m}x_i^2 & \cdots & \sum_{i=1}^{m}x_i^{n} \\
\sum_{i=1}^{m}x_i & \sum_{i=1}^{m}x_i^2 & \sum_{i=1}^{m}x_i^3 & \cdots & \sum_{i=1}^{m}x_i^{n+1} \\
\sum_{i=1}^{m}x_i^2 & \sum_{i=1}^{m}x_i^3 & \sum_{i=1}^{m}x_i^4 & \cdots & \sum_{i=1}^{m}x_i^{n+2} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
\sum_{i=1}^{m}x_i^{n} & \sum_{i=1}^{m}x_i^{n+1} & \sum_{i=1}^{m}x_i^{n+2} & \cdots & \sum_{i=1}^{m}x_i^{2n} \\
\end{pmatrix}
\begin{pmatrix}
w_0 \\
w_1 \\
w_2 \\
\vdots \\
w_n
\end{pmatrix}
+
\lambda
\begin{pmatrix}
1 & 0 & 0 & \cdots & 0 \\
0 & 1 & 0 & \cdots & 0 \\
0 & 0 & 1 & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & 0 \\
0 & 0 & 0 & \cdots & 1
\end{pmatrix}
\begin{pmatrix}
w_0 \\
w_1 \\
w_2 \\
\vdots \\
w_n
\end{pmatrix}
&=&
\begin{pmatrix}
\sum_{i=1}^{m}y_i \\
\sum_{i=1}^{m}y_ix_i \\
\sum_{i=1}^{m}y_ix_i^2 \\
\vdots \\
\sum_{i=1}^{m}y_ix_i^n
\end{pmatrix} \\
\left\{
\begin{pmatrix}
1 & 1 & 1 & \cdots & 1 \\
x_1 & x_2 & x_3 & \cdots & x_m \\
x_1^2 & x_2^2 & x_3^2 & \cdots & x_m^2 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
x_1^n & x_2^n & x_3^n & \cdots & x_m^n
\end{pmatrix}
\begin{pmatrix}
1 & x_1 & x_1^2 & \cdots & x_1^n \\
1 & x_2 & x_2^2 & \cdots & x_2^n \\
1 & x_3 & x_3^2 & \cdots & x_3^n \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
1 & x_m & x_m^2 & \cdots & x_m^n
\end{pmatrix}
+
\lambda
\begin{pmatrix}
1 & 0 & 0 & \cdots & 0 \\
0 & 1 & 0 & \cdots & 0 \\
0 & 0 & 1 & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & 0 \\
0 & 0 & 0 & \cdots & 1
\end{pmatrix}
\right\}
\begin{pmatrix}
w_0 \\
w_1 \\
w_2 \\
\vdots \\
w_n
\end{pmatrix}
&=&
\begin{pmatrix}
1 & 1 & 1 & \cdots & 1 \\
x_1 & x_2 & x_3 & \cdots & x_m \\
x_1^2 & x_2^2 & x_3^2 & \cdots & x_m^2 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
x_1^n & x_2^n & x_3^n & \cdots & x_m^n
\end{pmatrix}
\begin{pmatrix}
y_1 \\
y_2 \\
y_3 \\
\vdots \\
y_m
\end{pmatrix}
\end{eqnarray}
ここで,
\boldsymbol{X} = \begin{pmatrix}
1 & x_1 & x_1^2 & \cdots & x_1^n \\
1 & x_2 & x_2^2 & \cdots & x_2^n \\
1 & x_3 & x_3^2 & \cdots & x_3^n \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
1 & x_m & x_m^2 & \cdots & x_m^n
\end{pmatrix}
,
\boldsymbol{I} = \begin{pmatrix}
1 & 0 & 0 & \cdots & 0 \\
0 & 1 & 0 & \cdots & 0 \\
0 & 0 & 1 & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & 0 \\
0 & 0 & 0 & \cdots & 1
\end{pmatrix}
,
\boldsymbol{w} = \begin{pmatrix}
w_0 \\
w_1 \\
w_2 \\
\vdots \\
w_n
\end{pmatrix}
,
\boldsymbol{y} = \begin{pmatrix}
y_1 \\
y_2 \\
y_3 \\
\vdots \\
y_m
\end{pmatrix}
とおくと,
\begin{eqnarray}
(\boldsymbol{X}^T \boldsymbol{X} + \lambda \boldsymbol{I}) \boldsymbol{w} &=& \boldsymbol{X}^T \boldsymbol{y} \\
\boldsymbol{w} &=& (\boldsymbol{X}^T \boldsymbol{X} + \lambda \boldsymbol{I})^{-1} \boldsymbol{X}^T \boldsymbol{y} ...(5)
\end{eqnarray}
(5)式は,L2正則化項を追加した最小二乗法の解(n次関数)の一般形(正規方程式)になります.
最適な正則化パラメータの決定
さて,(5)式に以下のデータを当てはめてみます.
- 表
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | |
---|---|---|---|---|---|---|---|---|---|
x | 0.000 | 0.125 | 0.250 | 0.375 | 0.500 | 0.625 | 0.750 | 0.875 | 1.000 |
y | 0.141 | 0.721 | 1.227 | 0.822 | -0.027 | -0.906 | -0.840 | -0.606 | -0.023 |
このデータに対し,あらゆる正則化パラメータλの正則化項を持つ9次関数によってフィッティングさせると,以下のようになります.
λ=0.0のときは,正則化項を追加していない場合と等しいので過学習してしまい,元の関数「y=sin(x)」からは遠い点を取っています.
これに対し,正則化の影響を強くしたとき(λ=1.5)は,かなり平坦なグラフになっており,過学習を抑えることができています.
しかし,これも元の関数「y=sin(x)」からは遠い点を取っているため,最適なパラメータとは言えません.
λの値を変えてグラフをプロットしてみると,「λ=0.001」や「λ=1.0e-5」がかなり良いことがわかります.
では,具体的にどのλの値が最も良いのかを知るために,(5)式をもとに誤差が最も小さいλを選定します.
選定する際には,0.0〜1.0の一様乱数により生成した入力データxを10000個を用意し,(5)式により得た10000個のyを利用します.
これら10000組のデータ(x, y)を検証データとして,最も誤差が小さいλを調査します.
まずは,λが0.0〜1.5のときの誤差をプロットします.
赤丸の部分に最小値があるようなので,λのスケールを小さくして,今後はλが0.0〜10^-5のときの誤差をプロットしてみます.
さらに,λのスケールを小さくして,今後はλが1.210^-6〜2.210^-6のときの誤差をプロットしてみます.
誤差の最小値がかなり見やすくなってきました.
この赤丸の点が最小値0.142となっており,その時のλは1.75×10^-6でした.
それでは,λ=1.75×10^-6としてwを求め,f(x, w)をプロットしてみます.
正則化を加えていないときや,正則化の影響を強くしすぎたときと比べ,かなり「y=sin(x)」に近い点を通っていることがわかります.
このように,二乗和誤差に正則化項を追加することによって,**訓練データの数が少ない場合でも,近似関数のパラメータ数が多い場合でも,テストデータに対して高精度に推定することができる**ようになるのです.
L1正則化とL2正則化
これまでは,L2正則化について説明してきました.
正則化にはさまざまな種類があり,L2正則化(リッジ回帰)以外にはL1正則化(ラッソ回帰)もあります.
それぞれの数式的な違いは以下の通りです.
(6)式がL1正則化,(7)式がL2正則化の誤差関数を示しています.
E_{L1} = \sum_{i=1}^{m}(y_i-f(x_i, \boldsymbol{w}))^2 + \lambda ||\boldsymbol{w}||_1 ...(6) \\
ただし,||\boldsymbol{w}||_1 = |w_0|+|w_1|+|w_2|+\cdots+|w_n| \\
E_{L2} = \sum_{i=1}^{m}(y_i-f(x_i, \boldsymbol{w}))^2 + \lambda ||\boldsymbol{w}||_2^2 ...(7) \\
ただし,||\boldsymbol{w}||_2 = \sqrt{|w_0|^2+|w_1|^2+|w_2|^2+\cdots+|w_n|^2}
数式での違いは,それぞれのパラメータwを何乗するかになりますが,性質的な違いは以下の通りです.
L1正則化→スパース性を持つ
L2正則化→過学習を防ぐ
L2正則化の「過学習を防ぐ」については,これまで説明してきた通りで,L1正則化と異なり各パラメータを二乗するので,より大きい値が出ないように罰則を加えることになります.
これが,L1正則化よりも過学習を抑制する効果を発揮しています.
では,L1正則化の「スパース性を持つ」について説明します.
まず,**「スパース性」とは各パラメータの一部の値がゼロになりやすい性質**のことです.
極端な例でいうと,以下の表のような感じです.
w_1 | w_2 | w_3 | w_4 | |
---|---|---|---|---|
非スパース | 0.5 | 0.6 | 0.4 | 0.5 |
スパース | 1.0 | 0.1 | 0.0 | 0.1 |
非スパースがL2正則化に相当し,こちらは全てのパラメータが同程度の絶対値をとっていることがわかります.
しかし,スパースであるL1正則化は,w_1が1.0であるのに対し,他のw_2,w_3,w_4は0.0に近い値を取っていることがわかります.
ここから,w_1のみが誤差関数に強い影響を与えているということもわかります.
では,なぜL1正則化はスパース性を持つのか,説明していきます.
以下の図は,各パラメータw_1,w_2と誤差関数Eの関係性を示しており,誤差関数が二乗和誤差のみの場合と正則化項のみの場合の誤差の等高線を記載しています.
合計の誤差関数の最小値は,二乗和誤差と正則化項の等高線のいずれかの交点に存在します.
ここでは,☆マークを合計の誤差関数の最小値とします.
なお,二乗和誤差の最小化点(極小値),正則化項の最小化点(原点)は,ともに位置を変えていません.
ここで,等高線の交点に注目すると,L2正則化のw_1,w_2はともに程よく値をとっているのに対し,L1正則化のw_1は0になっていることがわかります.
このように,L1正則化はスパース性を持つので,**変数選択(不要なパラメータを削除し次元を圧縮すること)**ができるようになります.
実装
今回は以下のパッケージを使用します.
using Distributions
using LinearAlgebra
using Plots
まず,生成する訓練データの総数と近似する関数の次数を設定します.
M = 9 # 訓練データの総数
N = 9 # 近似する関数の次数
次に,以下の数式に対応する正則化パラメータλを設定します.
E = \sum_{i=1}^{m}(y_i-f(x_i, \boldsymbol{w}))^2 + \lambda ||\boldsymbol{w}||_2^2
lambda_ = 1.75e-6 # 正則化パラメータ
次に,データを生成する関数を設定します.
dist = Normal(0.0, 0.1) # 平均0.0,標準偏差0.1の正規分布
base(x) = sin.(2pi*x) # 周期1の正弦波
gen(x) = base(x) + rand(dist, length(x)) # データを生成する関数
次に,以下のn次多項式行列Xを作成する関数を設計します.
\boldsymbol{X} = \begin{pmatrix}
1 & x_1 & x_1^2 & \cdots & x_1^n \\
1 & x_2 & x_2^2 & \cdots & x_2^n \\
1 & x_3 & x_3^2 & \cdots & x_3^n \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
1 & x_m & x_m^2 & \cdots & x_m^n
\end{pmatrix}
function polynomial_matrix(x, N)
m = length(x)
X = fill(1.0, m) # 要素が1の列ベクトルを生成
for j in 1:N
x_vec = x .^ j # xのj乗
X = hcat(X, x_vec) # 右側にベクトルを追加
end
return X
end
次に,以下のn次多項式の係数wを求める関数を設計します.
\boldsymbol{w} = (\boldsymbol{X}^T \boldsymbol{X} + \lambda \boldsymbol{I})^{-1} \boldsymbol{X}^T \boldsymbol{y}
function weight(x, y, N, lambda_)
X = polynomial_matrix(x, N)
w = (X' * X + lambda_ * Matrix{Float64}(I, N+1, N+1)) \ X' * y
return w
end
さらに,推定したn次多項式の係数wと,入力データxをもとにf(x)を出力する関数を設計します.
function f(x, w)
X = polynomial_matrix(x, length(w)-1)
y = X * w
return y
end
これで,関数の準備は整いました.
実際に,どのように近似されるか確認しましょう.
M個の入力データxを生成し,観測データyを得ます.
step = 1.0 / (M - 1)
x = collect(0.0:step:1.0)
y = gen(x)
println("x = ", x)
println("y = ", y)
次に,これをもとにwを推定します.
w = weight(x, y, N, lambda_)
println("w = ", w)
wを推定し,f(x)のパラメータを得ることができました.
それでは,グラフを描いてみます.
gr()
scatter(x, y, xlabel="x", ylabel="y", label="data")
graph_x = collect(0.0:0.01:1.0)
graph_y = f(graph_x, w)
plot!(graph_x, base(graph_x), label="base")
plot!(graph_x, graph_y, label=string("f(x, w)"))
以下は,グラフの出力結果の一例です.
以上で,正則化の理論と実装の説明を終わります.