LoginSignup
2
2

More than 1 year has passed since last update.

混合ベルヌーイモデルで手書き数字解析

Last updated at Posted at 2023-01-10

問題設定

手書きの数字データ(MNIST)一つ一つは、巨大なベクトルであるとみなすことができる。そのベクトルの各成分が、10個のベルヌーイ分布の混合分布から生成していると仮定する。つまり、画像ベクトル$ \boldsymbol{x}$の生成分布が

\begin{align}

p(\boldsymbol{x}\mid\boldsymbol{\mu},\boldsymbol{\pi})&=\sum_{k=1}^{10} \pi_k ~p( \boldsymbol{x} \mid \boldsymbol{\mu}_k) \\
p( \boldsymbol{x} \mid \boldsymbol{\mu}_k)~&〜~ Bernoulli(\boldsymbol{\mu}_k)
\end{align}

に従うとする。このとき、画像ベクトルのみからベクトル$\boldsymbol{\mu}_k$と$\boldsymbol{\pi}$を推定することを考える。もしこれができれば、画像から数字の特徴を学習して適切に再現することが可能になる。

本稿では最尤推定をEMアルゴリズムを用いて行う。詳しい理論はPRML9.3章を参照のこと。本稿ではPRML図9.10を再現することを目標とする。

EMアルゴリズムの実装

準備

juliaでMNISTを使用するには、MLDatasets.jlというパッケージから拝借する。画像データが行列として入っているので、それを画像化するためにImages.jlを使用する。

実際にいくつか画像データを表示してみる。

using MLDatasets: MNIST
using Images
using Distributions
using Random
using Plots
dataset = MNIST(:train)

# 表示するときはデータセットを転置する必要があることに注意
for i in 1:5
    img_gray_copy = Gray.(dataset[i].features)' 
    display(img_gray_copy)
end

ベルヌーイ分布は二値変数なので、まずは画像ベクトルを二値ベクトルに変換する。

D = 28*28
N = 1000
X = zeros(N, D)
K = 10
seed = 123

for i in 1:N
    X[i,:] = vec(dataset[i].features)
end

#データベクトルの二値変数化
for i in 1:N
    for j in 1:D
        if X[i,j]>0.5
            X[i,j] = 1 
        else
            X[i,j] = 0
        end
    end
end


function visualize_vector(x)
    A = zeros(28,28)
    for i in 1:28
        for j in 1:28
            A[i,j] = x[28*(j-1)+i]
        end
    end
    return Gray.(A)' 
end

visualize_vector(X[1,:])

上の画像のようにグレーの部分はなくなり黒か白かで表示されるようになる。

最後に、多次元ベルヌーイ分布の対数確率密度関数を計算する関数を実装しておく。

#多次元ベルヌーイ分布のpdfを計算
function log_multi_Bernoulli(x, μ)
    D = length(x)
    if length(x) != length(μ)
        error("dimension mismatching")
    end
    return sum([logpdf(Bernoulli(μ[i]),x[i]) for i in 1:D ])
end

パラメータの初期化

推定すべきパラメータを適当に初期化する。

function parameter_initialize(K, D)
    Random.seed!(seed)
    Π = fill(1/K, K)
    μ = rand(K, D)
    return Π , μ
end

Π , μ = parameter_initialize(10, D)

Eステップ

Mステップでのパラメータ更新に用いるために、次の負担率

\gamma_{nk}=\dfrac{\pi_k p\left(  \boldsymbol{x}_{n}\mid  \boldsymbol{\mu}_k\right) }{\sum ^{K}_{j=1}\pi _{j}p\left(  \boldsymbol{x}_{n}\mid \boldsymbol{\mu}_{j}\right) }

を行列の形で計算する。

#負担率を計算
function E_step(X, Π, μ)
    γ = zeros(N,K)
    #分母と分子が計算できるほど大きい数になるようにoffsetで調整
    offset = maximum([log_multi_Bernoulli(X[n,:], μ[k,:])  for n in 1:N , k in 1:K])
    for n in 1:N
        for k in 1:K
            log_γ = log(Π[k]) + log_multi_Bernoulli(X[n,:], μ[k,:]) - offset - log(sum([ Π[j]*exp(log_multi_Bernoulli(X[n,:], μ[j,:]) - offset)  for j in 1:K ] ))
            γ[n, k] = exp(log_γ)
        end
    end
    return γ
end

行列$\gamma$の代わりにその対数を計算しようとすることは、情報落ちを抑えるための基本的な技術である。その上で計算上特に注意すべきは、対数負担率の分母が「対数の中に和」が入っていることである。今回は扱っているデータ次元が非常に大きいため、pdfは数値計算上扱いにくく、log-pdfをできるだけ使用したい。しかし対数の中に和が入っているとlog-pdfのまま計算を進めることができない。これを解決するために、分母と分子を適当な数で割り算し、扱いやすい領域にスケール変換することを試みた。

Mステップ

次の更新式に従って、パラメータを更新する。

N_k = \sum_{n=1}^N \gamma_{nk}\\
\boldsymbol{\mu}_k = \frac{1}{N_k} \sum_{n=1}^N \gamma_{nk} \boldsymbol{x_n}\\
\pi_k =\frac{N_k}{N}

負担率の総和(つまり、$N_k$の総和)は、理論的には$N$になっているはずであることに注意する。(数値計算の都合上、厳密に$N$にはならない。)さらに対数尤度

 \sum_{n=1}^N \sum_{k=1}^K \gamma_{nk} \left( \ln{\pi_k} + \sum_{i=1}^D (x_{ni}\ln{\mu_{ki}} + (1-x_{ni})\ln{(1-\mu_{ki})}\right)

を計算し、収束していることを確認する。

function M_step(X, Π, μ)
    Nk = zeros(K)
    γ = E_step(X, Π, μ)
    for k in 1:K
        Nk[k] = sum(γ[:,k])
        Π[k] =  Nk[k]/N
        for i in 1:D
            μ[k,i] = sum(γ[:,k] .* X[:,i]) / Nk[k]
        end
    end
    #println(sum(Nk)) 総和はNになっていればOK
    return μ, Π
end

function log_liklihood(X, Π, μ)
    γ = E_step(X, Π, μ)
    log_like = 0.
    for n in 1:N
        for k in 1:K
            log_like += (log(Π[k]+10^(-200)) + sum([X[n,i]*log(μ[k,i] + 10^(-200))+(1-X[n,i])*log(1-μ[k,i] + 10^(-200)) for i in 1:D])) * γ[n, k]
        end
    end
    return log_like
end
ite = 50
liklihood_history = zeros(ite)
for i in 1:ite
    μ, Π = M_step(X, Π, μ)
    liklihood_history[i] = log_liklihood(X, Π, μ)
end

plot(liklihood_history, xlabel="iteration", ylabel="log-liklihood", label="",title="")

スクリーンショット 2023-01-10 20.33.35.png

この結果を見るに、対数尤度は十分収束して何処かの極大解に落ち着いたことがわかる。

最後に、得られた各クラスの平均ベクトル$\boldsymbol{\mu}_k$を画像化してみる。

for i in 1:10
    display(visualize_vector(μ[i,:]))
end

1.png
2.png
3.png
4.png
5.png
6.png
7.png
8.png
9.png
10.png

大量にある極大解の何処かに収束しているため、やや怪しい(「3」らしいものが2個ある、上から2つ目の数字が謎)ところもあるが、非常に簡単なモデルからは想像できないほど概ね数字が再現されているといえよう。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2