2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

関西学院大学Advent Calendar 2017

Day 22

多クラスロジスティック回帰のIRLSバッチ処理アルゴリズム

Last updated at Posted at 2017-12-22

はじめに

この記事ではPRML4.3.4に記載されている、多クラスロジスティック回帰のIRLSバッチ処理アルゴリズムを紹介します。PRML本文には*Bishop and Nabney(2008)*にアルゴリズムの詳細があると書いてありますが、どうも見当たらないので実際に導出・実装してみました。

準備

まずは数学的な準備から始めます。詳細はPRML4章を参照してください!
今回はクラス数K、データ数N、入力ベクトル$\boldsymbol{x}$の次元数Mで考えていきます。
つまり、

  • 入力 $(M, 1)$: $\boldsymbol{x}=\{x_1, ...,x_M\}^T$
  • パラメータ $(M, K)$: $\boldsymbol{W}=\{\boldsymbol{w}_1, ..., \boldsymbol{w}_K\}$
  • 予測値 $(N, K)$: $\boldsymbol{Y}=\{\boldsymbol{y}_1, ...,\boldsymbol{y}_N\}^T \text{where}\ \boldsymbol{y_n}=\{y_{n1}, ...,y_{nK}\}$
  • 目的値 $(N, K)$: $\boldsymbol{T}=\{\boldsymbol{t}_1, ...,\boldsymbol{t}_N\}^T \text{where}\ \boldsymbol{t_n}=\{t_{n1}, ...,t_{nK}\}$

であり、$\boldsymbol{t} $はここでは1-of-K符号化法を使っています。
(※括弧はその行列(ベクトル)のサイズを示しています)

また各クラスの予測確率はSoftmax関数で定義します。
つまり$y_n$は
$$ y_n=p\left(C_k|\ a_n\right)=\frac{\exp{a_k}}{\sum_{n=1}^{N}\exp{a_n}} $$
ただし$ a_n=\boldsymbol{w}_k^T\boldsymbol{\phi\left(\boldsymbol{x_n}\right)} $、ここで$ \boldsymbol{\phi}()$は基底関数ベクトルです。
交差エントロピー誤差関数は$ E\left(\boldsymbol{w}_1, ..., \boldsymbol{w}_K\right)=-\sum_{n=1}^{N}\sum_{k=1}^{K}{t_{nk}\ln{y_{nk}}} $であり、この誤差関数を最小にするような$ \boldsymbol{W} $を求めるのが目標です。
ニュートンラフソン法を使うので勾配とヘッシアンも計算して

  • 勾配 $(M, 1)$: $ \nabla_{w_j}E\left(\boldsymbol{w}_1, ..., \boldsymbol{w}_K\right)=\sum_{n=1}^{N}\left(y_{nj}-t_{nj}\right)\boldsymbol{\phi}_n $
  • ヘッシアン $(M, M)$: $\boldsymbol{H}_{jk}=\nabla_{w_j}\nabla_{w_k}E\left(\boldsymbol{w}_1, ..., \boldsymbol{w}_K\right)=\sum_{n=1}^{N}y_{nk}\left(\delta_{kj}-y_{nj}\right)\boldsymbol{\phi}_n\boldsymbol{\phi}_n^T$

この形状のままIRLSアルゴリズムを使うことで最適な$\boldsymbol{w}_j$をそれぞれ求めることができます。しかし、ここで知りたいのは全ての$\boldsymbol{w}_k,\ k=(1, ..., K)$を一括で求めるバッチアルゴリズムなので、次のように行列を定義します。

IRLSバッチ処理アルゴリズムの構成

ここでは、先ほど記載した各パラメータ行列の構成が変わるので新しくそれらを定義します.

  • 入力 $(M, 1)$: $\boldsymbol{x}=\{x_1, ...,x_M\}^T$
  • パラメータ $(MK, 1)$: $\hat{\boldsymbol{w}}=\{\boldsymbol{w}_1^T, ..., \boldsymbol{w}_K^T\}^T$
  • 予測値 $(NK, 1)$: $\hat{\boldsymbol{Y}}=\{\boldsymbol{y}_{1}, ...,\boldsymbol{y}_{K}\}^T,\text{where}\ \boldsymbol{y_{k}}=\{y_{1k}, ...,y_{Nk}\}$
  • 目的値 $(NK, 1)$: $\hat{\boldsymbol{T}}=\{\boldsymbol{t}_1, ...,\boldsymbol{t}_N\}^T,\text{where}\ \boldsymbol{t_k}=\{t_{nk}, ...,t_{nk}\}$

このように構成することにより勾配とヘッシアンは

  • 勾配 $(MK, 1)$
    $$\nabla_{\hat{w}}E\left(\hat{\boldsymbol{w}}\right)=\hat{\boldsymbol{\Phi}}^T\left(\hat{\boldsymbol{Y}}-\hat{\boldsymbol{T}}\right) $$
  • ヘッシアン $(MK, MK)$
\begin{align}
\hat{\boldsymbol{H}}&=\nabla_{\hat{\boldsymbol{w}}}\nabla_{\hat{\boldsymbol{w}}}E\left(\boldsymbol{\hat{w}}\right)\\
&=\begin{pmatrix}
\boldsymbol{H}_{11} &... & \boldsymbol{H}_{1K} \\
... &...  &...\\
\boldsymbol{H}_{K1} &... & \boldsymbol{H}_{KK} 
\end{pmatrix}\\
&=\begin{pmatrix}
\boldsymbol{\Phi R_{11}\Phi} &... & \boldsymbol{\Phi R_{1K}\Phi} \\
... &...  &...\\
\boldsymbol{\Phi R_{K1}\Phi} &... & \boldsymbol{\Phi R_{KK}\Phi} 
\end{pmatrix}\\
&=\hat{\boldsymbol{\Phi}}^\top \hat{\boldsymbol{R}}\hat{\boldsymbol{\Phi}}
\end{align}

ただし、$\hat{\Phi}$は$(NK, MK)$行列で
$$ \hat{\boldsymbol{\Phi}}=diag(\boldsymbol{\Phi}) $$
$ \boldsymbol{\Phi} $は計画行列でサイズは$(N, M)$です。
また$\boldsymbol{\hat{R}}$は$(NK, NK)$行列で
$$\boldsymbol{\hat{R}}=\left(\begin{array}
\boldsymbol{R}_{11} &... & \boldsymbol{R}_{1K}\
... &... &...\
\boldsymbol{R}_{K1} &... & \boldsymbol{R}_{KK}
\end{array}\right) $$
$$\boldsymbol{R}_{jk}=diag(y_{nk}\left(\delta_{jk}-y_{nj}\right)) $$
$\boldsymbol{R}_{jk}$は$(N, N)$行列です。

IRLSアルゴリズムの式にこれらを代入すると以下の式となります。

\begin{align}
\boldsymbol{\hat{w}}^{(new)}&=\boldsymbol{\hat{w}}^{(old)}-\boldsymbol{\hat{H^{-1}\Phi^T\left(\boldsymbol{\hat{Y}}-\hat{T}\right)}}\\
&=\boldsymbol{\hat{w}}^{(old)}-\boldsymbol{\hat{\boldsymbol{\left(\hat{\Phi}^T\hat{R}\hat{\Phi}\right)}^{-1}\Phi^T\left(\boldsymbol{\hat{Y}}-\hat{T}\right)}}\\
\end{align}

 この式をpythonを使って実装した$※_1$結果がこちらです。
result

しっかりと、バッチ学習で3クラス分類できています。

おわりに

ブログでもその他、ちょっとした内容を公開しています。
よければご覧になってください。

備考

$※_1$: 人工知能に関する断創録で通常のIRLSアルゴリズムを実装されていたので、参考にしました。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?