本記事では、PRML第4章「線形識別モデル」の内容について全体像を整理し、Juliaを使った一部の図の実装を試みます。細かな式の導出よりも、流れやモチベーションを掴むことを第一に記述しています。最後に、Juliaコードと汎用性の高そうなポイントをまとめました。(プログラムやアルゴリズムの詳細な解説はしていませんが、変数や記法はPRMLと合わせているため、本と見比べれば何をやっているか分かり易いと思います。)
問題設定
全ての$D$次元ベクトルで表されたデータに対して、そのベクトルの値と、各データが所属するクラスが分かっていると仮定します。このとき、データ空間にクラスを分類するための境界を引くことが線形識別問題です。
より具体的に式で定式化します。ある関数$f$(これを活性化関数と呼びます)に対して
y(\boldsymbol{x}) = f(\boldsymbol{w}^T\boldsymbol{x}+w_0)
でモデルの「予測」を定義し、その出力$y$をクラス分類に対応づける「ルール」を定義します。これで入力$\boldsymbol{x}$にクラスを与えることができました。あとは分類の正答率が上がるように、パラメータ$w_0,\boldsymbol{w}$やモデル$f$をうまく調整すれば良いわけです。
「予測」や「ルール」とには、大きく分けて二種類あります。一つは「決定的モデル」と呼ばれる手法で、例えば$y(\boldsymbol{x})$が正ならばクラス1に、負ならばクラス2に属すると判断するというなどというルールが挙げられます。この方法はシンプルですが、「クラス1に入るか、入らないか」という二択の答えしか得られません。もう一つは「確率的モデル」であり、こちらは出力$y$をそのクラスに分類される確率として解釈できるという利点があります1。確率なので、ベイズの定理と非常に相性が良いです。
識別関数(決定的)によるアプローチ
多クラス分類への拡張
2クラス分類での手法は、問題設定のところで解説した通りです。これを他クラス分類の拡張することを考えてみます。$K$クラス分類では、$K$種類の決定関数から境界を決める方法がいくつかありますが、どれも問題点があります。
- 1対他分類器
どこにも属さない領域が存在してしまう - 1対1分類器(多数決法)
例えば3クラス分類などで、3すくみの関係になった領域では、クラス分類できない - $y_k$の値が最大になるクラス$k$を選ぶ方法
必ず境界線が$K$本生じてしまう
例えば下図のように、本来は境界線2本分類するのが適切だと思われるところを、無理やり3本で分類してしまっているため、特にクラス1(赤)の正答率が低くなってしまっています。
今度は、具体的にパラメータ$w_0,\boldsymbol{w}$を決める手順について考えてみましょう。
最小二乗法
y_k(\boldsymbol{x}) = \boldsymbol{w_k}^T\boldsymbol{x}+w_{k0}
で出力を定義し、$y_k$の値が最大になるクラス$k$を選ぶ方法を取るとします。このとき誤差関数を最小化するようにパラメータを選ぶと、パラメータは擬似逆行列を用いて簡単に記述できます。上のグラフも、このやり方を用いて境界を描いています。
フィッシャーの線形判別法
上で考えた方法は、もともと$D$次元ベクトルだったデータをあえて1次元に落として活用している、という見方をすることができます。この見方のもとでは、「各クラスごとの平均ベクトル$\boldsymbol{m_k}$は、射影した後にできるだけ離れた位置にあったほうがいい」という要請の元から、
\boldsymbol{w}^T(\boldsymbol{m_2-m_1})
を最小化することで$\boldsymbol{w}$を決めることが良さそうに思えます。しかし、この方法はいささか安直で、射影した後に2つのクラスの裾野の部分で重なりが生じてしまう可能性があります。具体的にみてみましょう。次の二次元データを一次元に射影することにより分類を試みます。
下図の左が、平均の差のみを使った2クラス分類の結果です。平均を遠くすることだけを考えて一次元に射影すると、判別不可能な重なり領域がだいぶ大きくなっていることがわかります。
そのため、この方法を改善して、「各クラス内の射影した後の分散はできるだけ小さいほうがいい」という要請のもと
J(\boldsymbol{w})=\frac{\boldsymbol{w}^T(\boldsymbol{m_2-m_1})}{s_1^2+s_2^2}
を最小化することで$\boldsymbol{w}$を決めることにします。この値を、フィッシャーの情報量基準と呼びます。実際にフィッシャー情報量基準を最小化するパラメータで同じデータを射影した結果が上手の右です。2クラスが完全に分離できていることがわかります。この量は突然出てきた感じもしますが、この決め方は、データのクラスの符号化$t_k$をうまく選んだ時の最小二乗法に根拠を見出すこともできます。
パーセプトロンアルゴリズム
今度は、出力をステップ関数で決める方針を考えます。つまり、$\theta$をヘビサイド関数として
y_k(\boldsymbol{x}) = \theta~ (\boldsymbol{w_k}^T\boldsymbol{x}+w_{k0})
として出力を定義します。同様にパラメータを誤差の最小化によって決定すれば良いのですが、自然に定義される誤差関数が「間違って分類されたデータ」からの和を取るため、$\boldsymbol{w}$について微分することがうまくできません。そのため、再急降下法で数値的に$\boldsymbol{w}$を決定することになります。
確率的アプローチ(最尤推定)
今までの決定的なやり方では、0か1かを答えることしかできず、データの位置からクラスに所属する確率を考えることができません。そこで確率的なアプローチの出番になります。手法は大きく分けて二つ挙げられます。
生成的アプローチ | 識別的アプローチ | |
---|---|---|
特徴 | クラスの事前分布とクラスの条件付き確率分布をモデル化し、そこから尤度関数を導出する | 尤度関数の関数形を直接モデル化する |
モデル | $p(\boldsymbol{x}\mid C_k)$と$p(C_k)$をモデル化 | $p(C_k\mid\boldsymbol{x})$をモデル化 |
メリット | $p(\boldsymbol{x})$が計算できるので、データを自分で生成できる | パラメータの数が少なく、最尤推定の計算が楽 |
まずはベイズの定理を適用せず、最尤推定によってパラメータを決定することを考えてみます。
生成的アプローチ
クラスに関する事前分布$p(C_k)$をベルヌーイ分布で設定し、クラスごとのデータの生成確率$p(\boldsymbol{x}\mid C_k)$を平均$~\boldsymbol{\mu_k}$, 分散$\Sigma$の正規分布でモデル化します。これらを使って尤度関数$p(D|\boldsymbol{\mu_1},\boldsymbol{\mu_2},\pi,\Sigma)$を計算し、微分をすることで4つのパラメータを導出します。やや技巧的な行列計算ののち、
\boldsymbol{\mu}_1=\boldsymbol{m}_1\\
\boldsymbol{\mu}_2=\boldsymbol{m}_2\\
\boldsymbol{\Sigma}=\boldsymbol{S}
という(至極当然な)結果が得られます。
より一般に、$p(\boldsymbol{x}\mid C_k)$を(特殊な形をした)指数型分布属でモデル化すると、尤度関数の関数形が線形関数のシグモイド関数
p(C_k\mid\boldsymbol{x})=\sigma(\boldsymbol{w}^T\boldsymbol{x})\\
で書けることが証明できます。尤度関数がこの関数形をしている時、この識別問題をロジスティック回帰と呼びます。(注意:回帰と言ってますが回帰ではなく識別問題です) 線形性がもたらす特徴として、決定境界が必ず直線になります。
識別的アプローチ
生成的アプローチでクラス事後確率を計算すると、多くの場合でロジスティック回帰に帰着されることがわかりました。それなら最初から尤度関数の関数形を線形関数のロジスティックだと仮定してパラメータを推定すれば話が早そうです。実際に尤度関数の関数形を
y(\boldsymbol{x})=p(C_k\mid\boldsymbol{x})=\sigma(\boldsymbol{w}^T\boldsymbol{x})\\
と仮定して、対数尤度にマイナスをつけたもの(交差エントロピー関数)を最小化する$w$を求めます。この計算は解析的には難しいので、IRLSというアルゴリズムを用います。
今回は、上の二つの識別問題をIRLSを使って解いてみます。左は直線で識別可能で、右は分離不能です2。結果は次のようになります。
上のグラフは、識別可能であるがゆえに、直線を挟んで確率が0か1かになっています。これは「分離可能な識別関数は過学習を引き起こす」ことを反映しています。一方分離不能な問題の方では、確率が滑らかに変化していることがわかります。これは確率的アプローチがうまくいっていることの表れです。
過学習の原因は、交差エントロピー関数が下のグラフのように、アルゴリズムの反復の中で0になってしまう(つまり、尤度関数が1になれてしまう)ためです。
モデルエビデンスの評価
モデルエビデンス$p(D \mid M)$は、そのモデルの「正当性」を表す量です。識別的アプローチの枠組みでこの量をラプラス近似を使って計算することができます。
p(D\mid M)= \int d\boldsymbol{\theta}~ p(D\mid \boldsymbol{\theta})~p(\boldsymbol{\theta})
のうち非積分関数$p(D\mid \boldsymbol{\theta})~p(\boldsymbol{\theta})$をラプラス近似すると、指数の肩を二次までで落としてガウス分布に近似されます。するとガウス分布の積分は容易であることからモデルエビデンスが計算できるという流れです。結果として、事後確率分布のモードでのパラメータの値についての尤度に加え、Occam係数と呼ばれる「複雑なモデルにペナルティーを課す量」が現れます。このOccam係数を特別な仮定のもとさらに近似したものがベイズ情報量基準になります。
(未解決の疑問)なぜ自然な流れでベイズ的に計算を進めていくと、複雑なモデルにペナルティーを課すようになるのでしょう?まるで自然が複雑なモデルを嫌っているかのようです。この辺りを、物理の自由エネルギーやエントロピーを関連付けて解釈できないか模索中です。
確率的アプローチ(ベイズ)
工事中
プログラムのポイント解説
- 正規分布に従う乱数生成
Distributionsパッケージで簡単に生成可能。 - 陰関数で与えられた領域を塗りつぶす
JuliaにはImplicitEquations.jlやInterpolations.jlなどの陰関数用のパッケージがあるが、解像度が悪くなるなどあまり使い勝手は良くなかった。特に不等式で指定した領域を塗りつぶすことは対応していない(?)ようである。そのため、pythonのブールインデックスに対応する操作をJuliaで利用する。具体的には、$f(x,y)>0$に色を塗りたければ、平面上に生成したグリッドの各点で$z=f(x,y)$を評価し、これが正になるところだけ抜き出せば良い。この際、グリッドを関数の入力にする場合、ドット演算になることに注意。 - 二つのヒストグラムを、重なりが見えるように表示する
関数histogramのパラメータalphaを調整すると、色が薄くなり、重なりが見えるようになる。 - ヒートマップを描きたい
ヒートマップは網目が荒く(?)、美しく表示されない場合があります。contour関数を使い、fill=trueとする方が手軽にきれいなグラフになります。(ただcontourが邪魔です)
記事中の図を作成するコード
最小二乗法による3クラス分類
まずデータを生成。
using Distributions, Plots, Random
Random.seed!(1234)
Σ₁ = [
1 2
2 5
]
Σ₂ = [
1 2
2 5
]
Σ₃ = [
1 2
2 5
]
μ₁ = [8, 20]
μ₂ = [7, 24]
μ₃ = [9, 10]
dist1 = MvNormal(μ₁, Σ₁)
dist2 = MvNormal(μ₂, Σ₂)
dist3 = MvNormal(μ₃, Σ₃)
x₁ = rand(dist1, 2^7)
x₂ = rand(dist2, 2^7)
x₃ = rand(dist3, 2^7)
scatter(x₁[1,:],x₁[2,:],label="class1",color="red")
scatter!(x₂[1,:],x₂[2,:],label="class2",color="green")
scatter!(x₃[1,:],x₃[2,:],label="class3",color="blue")
データを行列形式に整理し、公式に従って最適なパラメータを計算。
#データ集合の行列表現
N= 2^7*3
X̃ = zeros(N,3)
for i in 1:N
X̃[i,:] = hcat(1,hcat(x₁,x₂,x₃)[:,i]')
end
T = zeros(N,3)
for i in 1:N
if 1≤i≤N/3
T[i,:] = [1,0,0]
end
if N/3+1≤i≤2*N/3
T[i,:] = [0,1,0]
end
if 2*N/3+1≤i≤N
T[i,:] = [0,0,1]
end
end
W̃ = inv(X̃' * X̃) * X̃' * T
#linear discriminant
y₁(x,y) = (W̃' * [1,x,y])[1,:][1] #x,y is number
y₂(x,y) = (W̃' * [1,x,y])[2,:][1]
y₃(x,y) = (W̃' * [1,x,y])[3,:][1]
グラフで描画。グリッドの生成方法と、二項演算がドット演算になっていることに注意。
X = [n for n in range(4, 13, length = 700)]
Y = [n for n in range(4, 32, length = 700)]
x_grid, y_grid = [X for Y in Y, X in X], [Y for Y in Y, X in X]
Z₁ = y₁.(x_grid,y_grid)
Z₂ = y₂.(x_grid,y_grid)
Z₃ = y₃.(x_grid,y_grid)
plot(x_grid[(Z₁ .> Z₂) .&& (Z₁ .> Z₃)], y_grid[(Z₁ .> Z₂) .&& (Z₁ .> Z₃)], color="lightsalmon",xlims=[4,13],ylims=[4,32],label="")
plot!(x_grid[(Z₂ .> Z₁) .&& (Z₂ .> Z₃)], y_grid[(Z₂ .> Z₁) .&& (Z₂ .> Z₃)], color="darkseagreen1",label="")
plot!(x_grid[(Z₃ .> Z₂) .&& (Z₃ .> Z₁)], y_grid[(Z₃ .> Z₂) .&& (Z₃ .> Z₁)], color="lightskyblue1",label="")
scatter!(x₁[1,:],x₁[2,:],label="class1", color="red")
scatter!(x₂[1,:],x₂[2,:],label="class2", color="green")
scatter!(x₃[1,:],x₃[2,:],label="class2", color="blue")
フィッシャーの線形判別法
データを生成。
using Distributions, Plots, Random
Random.seed!(1234)
Σ₁ = [
1 2
2 5
]
Σ₂ = [
1 2
2 5
]
N₁ = 2^8
N₂ = 2^8
μ₁ = [8, 20]
μ₂ = [7, 24]
dist1 = MvNormal(μ₁, Σ₁)
dist2 = MvNormal(μ₂, Σ₂)
x₁ = rand(dist1, N₁)
x₂ = rand(dist2, N₂)
scatter(x₁[1,:],x₁[2,:],label="class1",color="red")
scatter!(x₂[1,:],x₂[2,:],label="class2",color="green")
公式に従ってパラメータを計算。正規化はnormalize()で計算できる。平均だけを考慮したナイーブな方法と、フィッシャーの線形判別法を比較する。
using LinearAlgebra
using StatsPlots
m₁ = [sum(x₁[1,:]), sum(x₁[2,:])] / N₁
m₂ = [sum(x₂[1,:]), sum(x₂[2,:])] / N₂
S_w = zeros(2,2)
for i in 1:N₁
S_w += (x₁[:,i] - m₁) * (x₁[:,i] - m₁)'
end
for i in 1:N₂
S_w += (x₂[:,i] - m₂) * (x₂[:,i] - m₂)'
end
w_pre = normalize(m₂-m₁)
w = normalize(inv(S_w) * (m₂-m₁))
y₁_pre = zeros(N₁)
for i in 1:N₁
y₁_pre[i] = w_pre' * x₁[:,i]
end
y₂_pre = zeros(N₂)
for i in 1:N₂
y₂_pre[i] = w_pre' * x₂[:,i]
end
#y₁ is a set of y in class 1
y₁ = zeros(N₁)
for i in 1:N₁
y₁[i] = w' * x₁[:,i]
end
y₂ = zeros(N₂)
for i in 1:N₂
y₂[i] = w' * x₂[:,i]
end
hist1 = histogram!(histogram(y₁_pre, color="red" ,bin=6:0.5:26,alpha=0.4,label="class1"),
y₂_pre,color="green", bin=6:0.5:30 ,alpha=0.4,label="class2")
hist2 = histogram!(histogram(y₁, color="red", bin=0:0.125:6,alpha=0.4,label="class1"),
y₂, color="green", bin=0:0.125:6,alpha=0.4,label="class2")
plot(hist1,hist2)
IRLSによるロジスティック回帰
まず例のごとく乱数生成。
using Distributions, Plots, Random, LinearAlgebra
Random.seed!(1234)
Σ₁ = [
1 2
2 5
]
Σ₂ = [
1 2
2 5
]
Σ₃ = [
1 0
0 4
]
Σ₄ = [
1 0
0 4
]
N₁ = 2^8
N₂ = 2^8
N₃ = 2^8
N₄ = 2^8
μ₁ = [7.5, 19]
μ₂ = [7, 24]
μ₃ = [7.5, 19]
μ₄ = [7, 24]
dist1 = MvNormal(μ₁, Σ₁)
dist2 = MvNormal(μ₂, Σ₂)
dist3 = MvNormal(μ₃, Σ₃)
dist4 = MvNormal(μ₄, Σ₄)
x₁ = rand(dist1, N₁)
x₂ = rand(dist2, N₂)
x₃ = rand(dist3, N₃)
x₄ = rand(dist4, N₄)
plot1 = scatter!(scatter(x₁[1,:],x₁[2,:],label="class1",color="red"),
x₂[1,:],x₂[2,:],label="class2",color="green")
plot2 = scatter!(scatter(x₃[1,:],x₃[2,:],label="class3",color="red"),
x₄[1,:],x₄[2,:],label="class4",color="green")
scatter(plot1,plot2)
データを行列に格納してアルゴリズムを適用。シグモイド関数の収束が早いので、すぐに情報落ちが生じてしまうことに注意。
#データ集合の行列表現
N= 2^8*2
#データの次元は M=2+1
Φ₁ = zeros(N,3) #入力データ
Φ₂ = zeros(N,3)
for i in 1:N
Φ₁[i,:] = hcat(1,hcat(x₁,x₂)[:,i]')
Φ₂[i,:] = hcat(1,hcat(x₃,x₄)[:,i]')
end
t = zeros(N) #出力データ
for i in 1:N
if 1≤i≤N/2
t[i] = 1
end
if N/2+1≤i≤N
t[i] = 0
end
end
w₁ = zeros(3) #初期条件
w₂ = zeros(3)
sigmoid(x) = 1 ./ (1 .+ exp.(-x))
R₁ = zeros(N,N)
R₂ = zeros(N,N)
E₁ = zeros(9)
E₂ = zeros(9)
for i in 1:9
#wを通じてyを更新
y₁ = (sigmoid(w₁' * Φ₁'))' #スカラーyₙたちを保存したベクトル
y₂ = (sigmoid(w₂' * Φ₂'))'
E₁[i] = sum([-t[n]*log(y₁[n]) - (1-t[n])*log(1-y₁[n]) for n in 1:N])
E₂[i] = sum([-t[n]*log(y₂[n]) - (1-t[n])*log(1-y₂[n]) for n in 1:N])
#yを通じて重み付け行列Rの更新
for j in 1:N
R₁[j,j] = y₁[j] * (1 - y₁[j])
R₂[j,j] = y₂[j] * (1 - y₂[j])
end
z₁ = Φ₁*w₁ - R₁\(y₁-t)
z₂ = Φ₂*w₂ - R₂\(y₂-t)
#wを更新
w₁ = (Φ₁'*R₁*Φ₁) \ Φ₁' * R₁ * z₁
w₂ = (Φ₂'*R₂*Φ₂) \ Φ₂' * R₂ * z₂
end
plot_E = plot!(plot(E₁), E₂)
f₁(x, y) = sigmoid(w₁' * [1 x y]')[1,1]
f₂(x, y) = sigmoid(w₂' * [1 x y]')[1,1]
plot(plot_E,xlabel="iteration",ylabel="E")
二つのヒートマップを描きます。
contour(4:0.01:11, #x
12:0.01:30, #y
f₁,
fill=(true ,cgrad([:darkseagreen1, :lightsalmon]))
)
scatter!(x₁[1,:],x₁[2,:],label="class1",color="red")
scatter!(x₂[1,:],x₂[2,:],label="class2",color="green")
contour(4:0.01:11, #x
12:0.01:30, #y
f₂,
fill=(true ,cgrad([:darkseagreen1, :lightsalmon]))
)
scatter!(x₃[1,:],x₃[2,:],label="class1",color="red")
scatter!(x₄[1,:],x₄[2,:],label="class2",color="green")