LoginSignup
3
2

More than 5 years have passed since last update.

Pythonで多項式フィッティング

Posted at

多項式フィッティング

重回帰分析する前にこっちやるべきかなと思いましたまる
PRMLを理解するためにも。。

  • 偏微分して最小値を求めるフィッティング
  • 最尤推定によるフィッティング
  • ベイズ推定によるフィッテイング

を行なっていこうかと。今回は偏微分編ということで

偏微分によるフィッティング

y(x,\boldsymbol{w})=w_0+w_1x+w_2x^2+...+w_Mx^M=\sum_{j=0}^{M}w_jx^j -①\\
E(\boldsymbol{w})=\frac{1}{2}\sum_{n=1}^{N}\Bigl(y(x_n,\boldsymbol{w})-t_n\Bigr)^2 -②\\

N個の観測点xを並べた$$\boldsymbol{x}=(x_1,x_2,...,x_N)^T$$とそれに対応する観測値tをN個並べた$$\boldsymbol{t}=(t_1,t_2,...,t_N)^T$$
が与えられた時、多項式の係数をまとめたベクトル$$\boldsymbol{w}=(w_0,w_1,...w_M)^T$$を用いて多項式①の線形モデルを使ってフィッティングを行う。このフィッティングはもっとも単純なモデルである最小二乗法を用いて行う。すなわち誤差関数E(w)を最小化するwを求める。

誤差関数はただのwの二次関数なので、その解は閉じた形で求められる。
数式弱者なので書き下していきます。。。

\begin{align}
\frac{\partial E(\boldsymbol{w})}{\partial w_i}&=0 -③を求めれば良い。
①を②に代入して\\
E(\boldsymbol{w})&=\frac{1}{2}\sum_{n=1}^{N}\bigl(\sum_{j=0}^{M}w_jx_n^j-t_n\bigr)^2\\
&=\frac{1}{2}\sum_{n=1}^{N}\bigl(w_0+w_1x_n+w_2x_2^2+...+w_Mx_n^M-t_n\bigr)^2\\
\therefore \frac{\partial E(\boldsymbol{w})}{\partial w_i}&=\frac{1}{2}\times 2\sum_{n=1}^{N}\bigl(w_0+w_1x_n+w_2x_2^2+...+w_Mx_n^M-t_n\bigr)x_n^i\\
&=\sum_{n=1}^{N}\bigl(\sum_{j=0}^{M}w_jx_n^j-t_n\bigr)x_n^i\\
③より\\
0&=\sum_{n=1}^{N}\bigl(\sum_{j=0}^{M}w_jx_n^j-t_n\bigr)x_n^i\\
\sum_{n=1}^{N}t_nx_n^i&=\sum_{n=1}^{N}\bigl(\sum_{j=0}^{M}w_jx_n^j\bigr)x_n^i\\ 
&=\sum_{n=1}^{N}x_n^i\bigl(w_0x_n^0+w_1x_n^1+w_2x_n^2+...+w_Mx_n^M\bigr)\\
&=\sum_{n=1}^{N}\bigl(w_0x_n^{0+i}+w_1x_n^{1+i}+w_2x_n^{2+i}+...+w_Mx_n^{M+i}\bigr)\\
&=\quad w_0x_1^{0+i}+w_1x_1^{1+i}+w_2x_1^{2+i}+...+w_Mx_1^{M+i}\\
&\quad +w_0x_2^{0+i}+w_1x_2^{1+i}+w_2x_2^{2+i}+...+w_Mx_2^{M+i}\\
&\quad \vdots\\
&\quad +w_0x_N^{0+i}+w_1x_N^{1+i}+w_2x_N^{2+i}+...+w_Mx_N^{M+i}\\
&=\quad w_0\bigl(x_1^{0+i}+x_2^{0+i}+...+x_N^{0+i}\bigr)\\
&\quad +w_1\bigl(x_1^{1+i}+x_2^{1+i}+...+x_N^{1+i}\bigr)\\
&\quad \vdots\\
&\quad +w_M\bigl(x_1^{M+i}+x_2^{M+i}+...+x_N^{M+i}\bigr)\\
&=\sum_{n=1}^{N}w_0x_n^{0+i}+\sum_{n=1}^{N}w_1x_n^{1+i}+...+\sum_{n=1}^{N}w_Mx_n^{M+i}\\
&=\sum_{j=0}^{M}\sum_{n=1}^{N}w_jx_n^{i+j}\\
\therefore \sum_{n=1}^{N}t_nx_n^i&=\sum_{j=0}^{M}\sum_{n=1}^{N}w_jx_n^{i+j}\\
\sum_{n=1}^{N}t_nx_n^i&=T_i \quad \sum_{n=1}^{N}x_n^{i+j}=a_{ij} とすると\\
T_i &= \sum_{j=0}^{M}a_{ij}w_j\\
&Aをa_{ij}を成分として持つM \times M行列とし\\
A&=\left(
\begin{array}{ccccc}
a_{11} & \cdots & a_{1j} & \cdots & a_{1M}\\
\vdots & \ddots & & & \vdots \\
a_{i1} & & a_{ij} & & a_{iM} \\
\vdots & & & \ddots & \vdots \\
a_{M1} & \cdots & a_{Mj} & \cdots & a_{MM}
\end{array}
\right)\\
\\
\boldsymbol{T}&=(T_0,T_1,...,T_M)^Tとすると\\
\\
\boldsymbol{w}&=A^{-1}\boldsymbol{T}\\
&と置くことができる。
\end{align}

ということでAの逆行列を求めればwを求めることができます。長かったけど多分ベクトルとか行列の微分上手く使えばもっと簡単にできる気がします。あと

f_i(x_n)=x_n^i\\
f_j(x_n)=x_n^j

で置き換えられるので、多項式じゃなくても他の累乗とか確率密度関数とかでも近似できる気がします。logとかとらないとダメそうでめんどいそうですが。

あとは過学習が起きるので以下のように罰則項を設けて正則化します。

\begin{align}
E(\boldsymbol{w})&=\frac{1}{2}\sum_{n=1}^{N}\bigl(\sum_{j=0}^{M}w_jx_n^j-t_n\bigr)^2+\frac{\lambda}{2}\|w\|^2 \\
\frac{\partial \|w\|^2 }{\partial w_i}&=2\times \sum_{j=0}^{M}w_i -④\\
\end{align}

結局④をAに足したものになるので、それは

\begin{align}
\boldsymbol{T} &= A\boldsymbol{w}+\lambda\sum_{j=0}^{M}w_i\\
\boldsymbol{T} &= \bigl(A+diag(\lambda)\bigr)\boldsymbol{w}
\end{align}

とかける。つまりAにλを成分とするM×Mの対角行列を足したものをフィッティングの際に使うことで過学習を抑えられる。たしかリッジ回帰ってやつですね。

実装は次でやります。

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2