PRML(パターン認識と機械学習:Pattern Recognition and Machine Learning)は,機械学習の基礎知識を身につけるのに非常にオススメな書籍です.
本記事では,PRMLの第1章にて紹介されている**「多項式曲線フィッティング」について細部まで説明し,これをJulia**で実装してみます.
多項式曲線フィッティング
GitHubのリポジトリ→こちら
ソースコード(.jl,.ipynb)→こちら
概要
パターン認識のゴールは,**観測されたデータをもとに,今後観測されるであろうデータを予測することです.
その際に,観測されるデータがある関数によって生成されていると仮定し,その関数を多項式で近似することを考える**のが多項式曲線フィッティングです.
以下に,例を示します.
x=0のときに0.07,0.125のときに0.32,0.25のときに0.19...のようにデータが得られ,以下のようにデータ点がプロットできたとしましょう.
これらの点がある関数によって生成されていると仮定したとき,その関数の一つとして「f(x)=x」を予想することができます.
さて,実際に,関数「f(x)=x」がこれらのデータに対して近似できているかどうかを試すために,関数「f(x)=x」をプロットしてみます.
データ点の大半が**「f(x)=x」に近い値を取っていることがわかります.
ここで,データにはノイズが存在していることに注意しましょう.
観測対象が持っている真実の値は真値ですが,私達が観測して得た値の多くは何らかのノイズが入っており,真値+ノイズ**となっています.
したがって,ノイズが入っていることを考慮したとして,関数**「f(x)=x」**は大方予測できていると考えられます.
実際に,このデータは「f(x)=x」にノイズを加えて生成したものです.
つまり,データを生成する関数を正しく予測できたことがわかります.
本例では,1次関数による近似を試みましたが,実際に自然界に存在するデータはこれほどまでに直線的で単純なデータ点を取るものは非常に少ないため,多項式による適切な曲線フィッティングが求められます.
理論
概要では,生成したデータが**「f(x)=x」**で近似できるときの話をしました.
しかし,その関数は予想して選んだ関数です.
数学的な根拠に基づいて傾きと切片を選んだのではなく,予想して選んだその関数は,たまたま近似できていたに過ぎなく,別の観測データ群を近似する際にも適用できるとは限りません.
そこで,ある基準(客観的な指標)を設定し,その基準に基づき,数学的に関数を近似することを考えます.
例として,データ点と近似関数との**「二乗誤差の最小化」**がよく使われます.
データ点と近似関数の誤差(残差)が大きいほど二乗誤差が大きく,誤差(残差)が小さいほど二乗誤差が小さくなるので,二乗誤差の最小化は近似関数が正しいことを示す一つの指標として適切であることがわかります.
以下,二乗誤差の最小化をもとに,近似関数の設計方法を説明していきます.
1次関数による近似
さて,まずは一番単純な関数「1次関数」による近似を試みます.
以下のようなデータが与えられたとしましょう.
※小数第3位までの概数として表示しています.
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | |
---|---|---|---|---|---|---|---|---|---|
x | 0.000 | 0.125 | 0.250 | 0.375 | 0.500 | 0.625 | 0.750 | 0.875 | 1.000 |
y | 0.099 | -0.077 | 0.260 | 0.446 | 0.416 | 0.616 | 0.890 | 1.035 | 1.050 |
グラフで表すと以下のようになります.
近似する1次関数を「f(x)=ax+b」,入力データをx,xに対する観測データをy,観測データ数をm,二乗和誤差をEとすると,二乗和誤差Eは以下の式で表すことができます.
E=\sum_{i=1}^{m}\{y_i-f(x_i)\}^2 ...(1)
最終的に求めたいのは,Eが最小となるf(x)のパラメータ(a, b)です.
したがって,まずはEを最小化することを考えます.
ここで,(1)式のf(x)に値を代入すると(2)式になります.
E=\sum_{i=1}^{m}(y_i-ax_i-b)^2 ...(2)
ただし,入力データxと観測データyは既知であり,定数aとbは未知であることを確認しておきましょう.
ゆえに,二乗和誤差Eは,aとbを変数とする2変数関数E(a, b)となります.
まず,問題を簡単化するために,変数bを固定(b=0)し,二乗和誤差Eが変数aの関数E(a)であると仮定します.
すると,以下のようなグラフを描くことができます.
グラフより,E(a)は**下に凸の2次関数となることがわかります.
つまり,グラフの頂点(極小値)がE(a)の最小値**となります.
E(a)が下に凸の2次関数であることは,数式で示すこともできます.
\begin{eqnarray}
E &=& \sum_{i=1}^{m}(y_i-ax_i-b)^2 \\
&=& \sum_{i=1}^{m}\{x_i^2a^2+2x_i(b-y_i)a+(b-y_i)^2\} \\
&=& (x_1^2+x_2^2+\cdots+x_m^2)a^2+\sum_{i=1}^{m}\{2x_i(b-y_i)a+(b-y_i)^2\} ...(3)
\end{eqnarray}
(3)式より,a^2の係数が実数の二乗和となることから,a^2の係数は正であることがわかります.
つまり,E(a)は下に凸の2次関数です.
二乗和誤差Eが変数aの関数E(a)である場合において,**グラフの頂点(極小値)がE(a)の最小値**であることはわかりました.
では,二乗和誤差Eが変数bの関数E(b)である場合はどうでしょうか.これも同様です.
\begin{eqnarray}
E &=& \sum_{i=1}^{m}(y_i-ax_i-b)^2 \\
&=& \sum_{i=1}^{m}\{b^2+2(x_ia-y_i)b+(x_ia-y_i)^2\} \\
&=& mb^2+\sum_{i=1}^{m}\{2(x_ia-y_i)b+(x_ia-y_i)^2\} ...(4)
\end{eqnarray}
(4)式より,b^2の係数はm(>0)であることから,E(b)も下に凸の2次関数であることがわかります.
したがって,二乗和誤差Eが変数bの関数E(b)である場合において,**グラフの頂点(極小値)がE(b)の最小値**であることがわかりました.
さて,本来の目的は,aとbを変数とする2変数関数E(a, b)の最小値を求めることでした.
E(a, b)は,a, bの各々について下に凸の2次関数であることから,以下のように,大域的に見てグラフの頂点(極小値)が存在し,それがE(a, b)の最小値になります.
つまり,**E(a, b)の最小値を求めること=E(a, b)の頂点(極小値)を求めること**です.
さて,ここでは,**E(a, b)の極小値を求めることに注目**します.
多変数関数の極小値は,各変数についての微分が0となる位置に存在します.
つまり,E(a, b)のaについての偏微分,bについての偏微分がともに0の点です.
これを計算式で表すと,(5), (6)式のようになります.
E = \sum_{i=1}^{m}(y_i-ax_i-b)^2 \\
\left\{\begin{array}{ll}
\frac{\partial E}{\partial a} &=& \sum_{i=1}^{m}-2(y_i-ax_i-b)x_i &=& 0 ...(5) \\
\frac{\partial E}{\partial b} &=& \sum_{i=1}^{m}-2(y_i-ax_i-b) &=& 0 ...(6)
\end{array}\right.
さらに,(5), (6)式を変形すると,
\left\{\begin{array}{ll}
(\sum_{i=1}^{m}x_i^2)a &+&(\sum_{i=1}^{m}x_i)b &=& \sum_{i=1}^{m}y_ix_i \\
(\sum_{i=1}^{m}x_i^2)a &+&mb &=& \sum_{i=1}^{m}y_i
\end{array}\right.
これを行列にすると,
\begin{eqnarray}
\begin{pmatrix}
\sum_{i=1}^{m}x_i^2 & \sum_{i=1}^{m}x_i \\
\sum_{i=1}^{m}x_i & m
\end{pmatrix}
\begin{pmatrix}
a \\
b
\end{pmatrix}
&=&
\begin{pmatrix}
\sum_{i=1}^{m}y_ix_i \\
\sum_{i=1}^{m}y_i
\end{pmatrix} \\
\begin{pmatrix}
x_1 & x_2 & \cdots & x_m \\
1 & 1 & \cdots & 1
\end{pmatrix}
\begin{pmatrix}
x_1 & 1 \\
x_2 & 1 \\
\vdots & \vdots \\
x_m & 1
\end{pmatrix}
\begin{pmatrix}
a \\
b
\end{pmatrix}
&=&
\begin{pmatrix}
x_1 & x_2 & \cdots & x_m \\
1 & 1 & \cdots & 1
\end{pmatrix}
\begin{pmatrix}
y_1 \\
y_2 \\
\vdots \\
y_m
\end{pmatrix}
\end{eqnarray}
ここで,
\boldsymbol{X} = \begin{pmatrix}
x_1 & 1 \\
x_2 & 1 \\
\vdots & \vdots \\
x_m & 1
\end{pmatrix}
,
\boldsymbol{w} = \begin{pmatrix}
a \\
b
\end{pmatrix}
,
\boldsymbol{y} = \begin{pmatrix}
y_1 \\
y_2 \\
\vdots \\
y_m
\end{pmatrix}
とおくと,
\boldsymbol{X}^T \boldsymbol{X} \boldsymbol{w} = \boldsymbol{X}^T \boldsymbol{y} ...(7)
となります.
さらに,X^T×Xは,入力データ数に問わず正方行列(正則行列であると仮定)であるため,逆行列(X^T×X)^-1を持ちます.
つまり,(7)式の両辺に,左から(X^T×X)^-1をかけることによって,(8)式に変形させることができます.
\boldsymbol{w} = (\boldsymbol{X}^T \boldsymbol{X})^{-1} \boldsymbol{X}^T \boldsymbol{y} ...(8)
あとは右辺の式に値を代入するだけで,E(a, b)の最小値をとる一次関数のパラメータa, bを求めることができます.
このように,残差平方和を利用して近似関数を求める方法を**最小二乗法**と言います.
さて,(8)式に先程のデータに当てはめてみましょう.
すると,以下のような解が得られます.
\begin{eqnarray}
\begin{pmatrix}
a \\
b
\end{pmatrix}
&=&
\begin{pmatrix}
\begin{pmatrix}
0.000 & 0.125 & \cdots & 1.000 \\
1 & 1 & \cdots & 1
\end{pmatrix}
\begin{pmatrix}
0.000 & 1 \\
0.125 & 1 \\
\vdots & 1 \\
1.000 & 1
\end{pmatrix}
\end{pmatrix}
^{-1}
\begin{pmatrix}
0.000 & 0.125 & \cdots & 1.000 \\
1 & 1 & \cdots & 1
\end{pmatrix}
\begin{pmatrix}
0.099 \\
-0.077 \\
\vdots \\
1.050
\end{pmatrix} \\
&=&
\begin{pmatrix}
1.143 \\
-0.045
\end{pmatrix}
\end{eqnarray}
ゆえに,近似関数は以下のように定まります.
f(x) = 1.143x - 0.045
これを観測データと比較してみましょう.
ほとんど誤差がないことがわかります.
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | |
---|---|---|---|---|---|---|---|---|---|
x | 0.000 | 0.125 | 0.250 | 0.375 | 0.500 | 0.625 | 0.750 | 0.875 | 1.000 |
y | 0.099 | -0.077 | 0.260 | 0.446 | 0.416 | 0.616 | 0.890 | 1.035 | 1.050 |
f(x) | -0.045 | 0.097 | 0.240 | 0.383 | 0.526 | 0.669 | 0.812 | 0.955 | 1.098 |
次にグラフです.
なんとなく,バランス良くデータの中央を通っていることがわかります.
もし,新たに入力データxが与えられた場合,xをf(x)に入力することによりその予測値yを得ることができるようになりました.
n次関数による近似
さて,1次関数による近似を一般化し,今度はn次関数による近似を試みます.
f(x) = w_0 + w_1x + w_2x^2 + \cdots + w_nx^n = \sum_{j=0}^{n}w_jx^j
観測データyとf(x)の二乗和誤差Eは,以下のように表すことができます.
\begin{eqnarray}
E &=& \sum_{i=1}^{m}(y_i-w_0-w_1x_i-w_2x_i^2-\cdots-w_nx_i^n)^2 \\
&=& \sum_{i=1}^{m}(x_i^{2l} w_l^2 + Bw_l + C) \\
&=& (x_1^{2l} + x_2^{2l} + \cdots + x_m^{2l})w_l^2 + \sum_{i=1}^{m}(Bw_l + C) \\
\end{eqnarray}
ただし,l(0≦l≦n かつ 整数)は任意のwを表す数,BまたはCはw_lを含まない定数とします.
この式からわかるように,x^2lは正になるため,それらの和であるw_l^2の係数は正になります.
このことから,二乗和誤差Eが変数w_lの関数E(w_l)である場合において,グラフの頂点(極小値)がE(w_l)の最小値であることがわかり,これが全てのwに適用できることがわかります.
したがって,大域的に見てグラフの頂点(極小値)が存在し,それがE(w)の最小値になります.
これで,最小二乗法により近似関数を求めることができます.
1次関数の場合と同様に,極小値が0となる点を求め,wを決定します.
\left\{\begin{array}{ll}
\frac{\partial E}{\partial w_0} &=& \sum_{i=1}^{m}-2(y_i-w_0-w_1x_i-w_2x_i^2-\cdots-w_nx_i^n) &=& 0 \\
\frac{\partial E}{\partial w_1} &=& \sum_{i=1}^{m}-2(y_i-w_0-w_1x_i-w_2x_i^2-\cdots-w_nx_i^n)x_i &=& 0 \\
\frac{\partial E}{\partial w_2} &=& \sum_{i=1}^{m}-2(y_i-w_0-w_1x_i-w_2x_i^2-\cdots-w_nx_i^n)x_i^2 &=& 0 \\
&& \vdots \\
\frac{\partial E}{\partial w_n} &=& \sum_{i=1}^{m}-2(y_i-w_0-w_1x_i-w_2x_i^2-\cdots-w_nx_i^n)x_i^n &=& 0
\end{array}\right.
\begin{eqnarray}
\sum_{i=1}^{m}w_0 + \sum_{i=1}^{m}w_1x_i + \sum_{i=1}^{m}w_2x_i^2 + \cdots + \sum_{i=1}^{m}w_nx_i^n &=& \sum_{i=1}^{m}y_i \\
\sum_{i=1}^{m}w_0x_i + \sum_{i=1}^{m}w_1x_i^2 + \sum_{i=1}^{m}w_2x_i^3 + \cdots + \sum_{i=1}^{m}w_nx_i^{n+1} &=& \sum_{i=1}^{m}y_ix_i \\
\sum_{i=1}^{m}w_0x_i^2 + \sum_{i=1}^{m}w_1x_i^3 + \sum_{i=1}^{m}w_2x_i^4 + \cdots + \sum_{i=1}^{m}w_nx_i^{n+2} &=& \sum_{i=1}^{m}y_ix_i^2 \\
\vdots \\
\sum_{i=1}^{m}w_0x_i^n + \sum_{i=1}^{m}w_1x_i^{n+1} + \sum_{i=1}^{m}w_2x_i^{n+2} + \cdots + \sum_{i=1}^{m}w_nx_i^{2n} &=& \sum_{i=1}^{m}y_ix_i^n
\end{eqnarray}
これを行列にすると,
\begin{eqnarray}
\begin{pmatrix}
m & \sum_{i=1}^{m}x_i & \sum_{i=1}^{m}x_i^2 & \cdots & \sum_{i=1}^{m}x_i^{n} \\
\sum_{i=1}^{m}x_i & \sum_{i=1}^{m}x_i^2 & \sum_{i=1}^{m}x_i^3 & \cdots & \sum_{i=1}^{m}x_i^{n+1} \\
\sum_{i=1}^{m}x_i^2 & \sum_{i=1}^{m}x_i^3 & \sum_{i=1}^{m}x_i^4 & \cdots & \sum_{i=1}^{m}x_i^{n+2} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
\sum_{i=1}^{m}x_i^{n} & \sum_{i=1}^{m}x_i^{n+1} & \sum_{i=1}^{m}x_i^{n+2} & \cdots & \sum_{i=1}^{m}x_i^{2n} \\
\end{pmatrix}
\begin{pmatrix}
w_0 \\
w_1 \\
w_2 \\
\vdots \\
w_n
\end{pmatrix}
&=&
\begin{pmatrix}
\sum_{i=1}^{m}y_i \\
\sum_{i=1}^{m}y_ix_i \\
\sum_{i=1}^{m}y_ix_i^2 \\
\vdots \\
\sum_{i=1}^{m}y_ix_i^n
\end{pmatrix} \\
\begin{pmatrix}
1 & 1 & 1 & \cdots & 1 \\
x_1 & x_2 & x_3 & \cdots & x_m \\
x_1^2 & x_2^2 & x_3^2 & \cdots & x_m^2 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
x_1^n & x_2^n & x_3^n & \cdots & x_m^n
\end{pmatrix}
\begin{pmatrix}
1 & x_1 & x_1^2 & \cdots & x_1^n \\
1 & x_2 & x_2^2 & \cdots & x_2^n \\
1 & x_3 & x_3^2 & \cdots & x_3^n \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
1 & x_m & x_m^2 & \cdots & x_m^n
\end{pmatrix}
\begin{pmatrix}
w_0 \\
w_1 \\
w_2 \\
\vdots \\
w_n
\end{pmatrix}
&=&
\begin{pmatrix}
1 & 1 & 1 & \cdots & 1 \\
x_1 & x_2 & x_3 & \cdots & x_m \\
x_1^2 & x_2^2 & x_3^2 & \cdots & x_m^2 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
x_1^n & x_2^n & x_3^n & \cdots & x_m^n
\end{pmatrix}
\begin{pmatrix}
y_1 \\
y_2 \\
y_3 \\
\vdots \\
y_m
\end{pmatrix}
\end{eqnarray}
ここで,
\boldsymbol{X} = \begin{pmatrix}
1 & x_1 & x_1^2 & \cdots & x_1^n \\
1 & x_2 & x_2^2 & \cdots & x_2^n \\
1 & x_3 & x_3^2 & \cdots & x_3^n \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
1 & x_m & x_m^2 & \cdots & x_m^n
\end{pmatrix}
,
\boldsymbol{w} = \begin{pmatrix}
w_0 \\
w_1 \\
w_2 \\
\vdots \\
w_n
\end{pmatrix}
,
\boldsymbol{y} = \begin{pmatrix}
y_1 \\
y_2 \\
y_3 \\
\vdots \\
y_m
\end{pmatrix}
とおくと,
\begin{eqnarray}
\boldsymbol{X}^T \boldsymbol{X} \boldsymbol{w} &=& \boldsymbol{X}^T \boldsymbol{y} \\
\boldsymbol{w} &=& (\boldsymbol{X}^T \boldsymbol{X})^{-1} \boldsymbol{X}^T \boldsymbol{y} ...(9)
\end{eqnarray}
(9)式は,最小二乗法の解(n次関数)の一般形(正規方程式)になります.
最適なn次関数の決定
さて,(9)式に以下のデータを当てはめてみます.
- 表
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | |
---|---|---|---|---|---|---|---|---|---|
x | 0.000 | 0.125 | 0.250 | 0.375 | 0.500 | 0.625 | 0.750 | 0.875 | 1.000 |
y | 0.063 | 0.792 | 1.053 | 0.580 | 0.051 | -0.711 | -1.129 | -0.654 | -0.155 |
これをn次関数(1≦n≦9)によってフィッティングさせると,以下のようになります.
1次関数の場合,x=0.50のときに観測データに近い値をとっているだけで,他はほとんど近い値をとっていません.
しかしながら,9次関数の場合,全ての観測データにおいて近い値をとっていることがわかります.
このことから,大きいnを用いると,形状がより複雑なデータの特徴を捉えることができることがわかります.
(ただし,必ずしもnが大きい方が良いというわけではありません.)
以下は,n次関数のnと平均二乗平方根誤差(RMSE)の関係を示しています.
このグラフから,nが大きい方が誤差が小さい傾向にあることがわかります.
しかしながら,本当にnが大きい方が絶対に良いのでしょうか?
いいえ,違います.
以下のグラフを見てください.
実は,先程のデータは,y=sin(x)にノイズを加えて生成したものです.
観測データに対してフィッティングがうまく行ったとしても,今後入力データxに対して出力データyを推定することを考えた場合,9次関数は適切な近似関数とは言えません.
以下の「test RMSE」は,0.0〜1.0の一様乱数により生成した入力データxを10000個用意し,この出力yとn次の近似関数の平均二乗平方根誤差(RMSE)を示しています.
この結果から,テストデータに対してはn=5が最も良い結果が得られていることがわかります.
したがって,この観測データにおいては,9次関数ではなく5次関数による近似が最も有効であることがわかりました.
実装
今回は以下のパッケージを使用します.
using Distributions
using Plots
まず,生成する訓練データの総数と近似する関数の次数を設定します.
M = 9 # 訓練データの総数
N = 9 # 近似する関数の次数
次に,データを生成する関数を設定します.
dist = Normal(0.0, 0.1) # 平均0.0,標準偏差0.1の正規分布
base(x) = sin.(2pi*x) # 周期1の正弦波
gen(x) = base(x) + rand(dist, length(x)) # データを生成する関数
次に,以下のn次多項式行列Xを作成する関数を設計します.
\boldsymbol{X} = \begin{pmatrix}
1 & x_1 & x_1^2 & \cdots & x_1^n \\
1 & x_2 & x_2^2 & \cdots & x_2^n \\
1 & x_3 & x_3^2 & \cdots & x_3^n \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
1 & x_m & x_m^2 & \cdots & x_m^n
\end{pmatrix}
function polynomial_matrix(x, N)
m = length(x)
X = fill(1.0, m) # 要素が1の列ベクトルを生成
for j in 1:N
x_vec = x .^ j # xのj乗
X = hcat(X, x_vec) # 右側にベクトルを追加
end
return X
end
次に,以下のn次多項式の係数wを求める関数を設計します.
\boldsymbol{w} = (\boldsymbol{X}^T \boldsymbol{X})^{-1} \boldsymbol{X}^T \boldsymbol{y}
function weight(x, y, N)
X = polynomial_matrix(x, N)
w = (X' * X) \ X' * y
return w
end
さらに,推定したn次多項式の係数wと,入力データxをもとにf(x)を出力する関数を設計します.
function f(x, w)
X = polynomial_matrix(x, length(w)-1)
y = X * w
return y
end
これで,関数の準備は整いました.
実際に,どのように近似されるか確認しましょう.
M個の入力データxを生成し,観測データyを得ます.
step = 1.0 / (M - 1)
x = collect(0.0:step:1.0)
y = gen(x)
println("x = ", x)
println("y = ", y)
次に,これをもとにwを推定します.
w = weight(x, y, N)
println("w = ", w)
wを推定し,f(x)のパラメータを得ることができました.
それでは,グラフを描いてみます.
gr()
scatter(x, y, xlabel="x", ylabel="y", label="data")
graph_x = collect(0.0:0.01:1.0)
graph_y = f(graph_x, w)
plot!(graph_x, base(graph_x), label="base")
plot!(graph_x, graph_y, label=string(N, "th degree"))
以下は,グラフの出力結果の一例です.
以上で,多項式曲線フィッティングの理論と実装の説明を終わります.