LoginSignup
0
1

More than 5 years have passed since last update.

ベイズ学習を書いてみた④

Last updated at Posted at 2018-06-14

はじめに

前回は、多次元ガウス分布のベイズ学習について勉強しました。

今回は、多項分布のベイズ学習について勉強します。
今までは、「ベイズ推論による機械学習」の内容に基づき勉強しましたが、今回はこれが「どういう風に使えるのか」を考えて、サンプルデータを使って勉強します。

問題設定

あるマーク式回答試験を考えます。
回答は全てマーク式で「A・B・C」の何れかです。

一郎君は明日にその試験を控えています。

試験で良い点を取りたいのですが、勉強が苦手でどうすべきか悩んでいました。
そんな時、ネットの情報を調べるとマーク式試験に関するこんな情報を見つけました。

3択式の回答の一般的な割合は「2:5:3」である

一郎君は良いことを知ったと思いながらも本当にそうなのかと悩みました。
そんなとき、彼はあることを思いつきました。
その先生の過去問を見ればマークの傾向が分かるのではと考えました。
実際に5年分の過去問を調べると以下のようなパターンでした。

A B C
3 4 1
3 1 2
2 0 3
3 2 2
4 1 1

計算すると、合計数は「15:8:9」だと分かりました。
これらの情報に基づき、一郎君は明日の試験の対策を考えました。
さて、彼はどのように回答すればよい成績を取れる確率が高いでしょうか。

  • ランダムにA・B・Cをマークするのが良いでしょうか。
  • ネットの情報を信じて、Bを多めにマークすべきでしょうか。
  • 過去問に基づき、Aを多めにマークすべきでしょうか。

ベイズ学習

問題設定に少々無理がありますが、一郎君の取るべき戦略をベイズ学習で考えます。

まず、これは3つの値を取る多項分布で考えることができます。
数式は以下の通りです。

\rm{Mult(m|\pi,M)}=M!\prod\frac{\pi_k^{m_k}}{m_k!}

今回の場合、$M$ は問題数です。
求めたいのは $\sum_{k=1}^K\pi_k=1$ のパラメータです。
これがA・B・Cの出やすさを制御しています。
これが分かれば $m_k=[0,1,2]$ (A・B・Cの出現回数)が分かります。

(1) 事前分布

パラメータ $\pi$ の事前分布について考えます。
無情報を仮定しても良いのですが、ここでは「ネットの情報」が利用します。
多項分布では共役事前分布として「ディレクレ分布」が利用可能です。

p(\pi)=\rm{Dir(\pi|\alpha)}

$\alpha$ には正の実数値を設定します。
今回は割合を整数値として、$\alpha=[2,5,3]$ と設定しておきます。

(2) 尤度関数

パラメータ $\pi$ が与えれてた条件におけるデータの尤もらしさを計算します。
これは、先述しましたが、以下のような多項分布とするのが良いです。

p(m|\pi)=\rm{Mult(m|\pi)}

これがデータ数分($\rm{M}$)あるので、以下のようになります。

p(\rm{M}|\pi)=\prod_{n=1}^N\rm{Mult(\rm{M}|\pi)}

対数尤度は以下の通りです。

\log p(\rm{M}|\pi)=\sum_{n=1}^N\rm{\log Mult(\rm{M}|\pi)}

今回は、過去問のデータを用います。

(3) 事後分布

事前分布と尤度関数が仮定できると、ベイズの定理より事後分布は下記のように計算できます。

\begin{eqnarray}
p(\pi|\rm{M})&\propto& p(\rm{M}|\pi)p(\pi)\\
\end{eqnarray}

計算のために対数を取ります。

\begin{eqnarray}
\log p(\pi|\rm{M})&=&\sum_{n=1}^N\rm{\log Mult(\rm{M}|\pi)}+\log\rm{Dir(\pi|\alpha)}+\rm{const.}\\
\end{eqnarray}

これを計算していくと、事後分布は以下のようなディレクレ分布で記述することが出来ます。

p(\pi|\rm{M})=\rm{Dir(\pi|\hat\alpha)}

ただし、

\hat\alpha=\sum_{n=1}^Nm_{n,k}+\alpha_k

予測分布はパラメータ $\pi$ を周辺化消去することで求めることが出来ます。
(詳細は「ベイズ推論による機械学習」に載っています。)

\begin{eqnarray}
p(m_*)&=&\int p(m_*|\pi)p(\pi)\rm{d}\pi\\
&=&\rm{Cat}\left(m_*|(\frac{\alpha_k}{\sum_{i=1}^K\alpha_i})\right)
\end{eqnarray}

Rで計算する

前回まではpythonでやっていましたが、気分を変えてRで計算します。

①ネットの情報

(alpha <- c(2,5,3))
[1] 2 5 3

②過去問データ

(m <- rbind(c(3,4,1), c(3,1,2), c(2,0,3), c(3,2,2), c(4,1,1)))
     [,1] [,2] [,3]
[1,]    3    4    1
[2,]    3    1    2
[3,]    2    0    3
[4,]    3    2    2
[5,]    4    1    1

③事後分布のパラメータの計算

(alpha_hat <- colSums(m) + alpha)
[1] 17 13 12

計算はこれだけです。

④予測分布
全10問をどのようにマークすると確率が最も高いか計算しました。

戦略 A B C
ランダム 3 4 3
ネット情報 2 5 3
過去問 5 2 3
dmultinom(c(3,4,3), prob = alpha_hat)
[1] 0.05962389
dmultinom(c(2,5,3), prob = alpha_hat)
[1] 0.02735684
dmultinom(c(5,2,3), prob = alpha_hat)
[1] 0.06117623

過去問の情報を頼りにマークするのが最も確率が高そうです。
では、ネットの情報を全く無視すべきなのでしょうか。
そんなことはありません。
それぞれの割合を「5:3:2」と変更した方が確率は上がりました。
これは過去問とネットの情報の両方を考慮した結果です。

dmultinom(c(5,3,2), prob = alpha_hat)
[1] 0.06627425

一郎君は「5:3:2」でマークすると良い点が取れる確率が最も高いということが分かりました。

さいごに

簡単な例でベイズ学習を確認しました。
あと、当たり前ですが一郎君はベイズ学習に頼る前に勉強を頑張るのが最も良い得点を取る方法です(笑)。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1