1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

prml 10章 実装

Last updated at Posted at 2025-04-23

はじめに

prml(パターン認識と機械学習)の10章で紹介されている図をRd実装しましたので、こちらにその結果をまとめます。順次更新予定です。

図10.2の左 正規分布の平均場近似

library( mvnfast )

# 近似される正規分布の平均と精度
mu <- c( 1, 3 )
Sigma <- matrix( c( 4, 0.9*3*2, 0.9*3*2, 9 ), ncol = 2 )
Lambda <- solve( Sigma )

# サンプリングによる表示
# サンプリング数の指定とプロット
n <- 5000
result <- mvnfast::rmvn( n, mu, Sigma )
plot( result[,1], result[,2] )

####################################

# 初期値の設定
mu1 <- 0
mu2 <- 0
# mu1とmu2の変化を記録する用のベクトル
mu1_store <- mu1
mu2_store <- mu2

# 反復による推定
m <- 1 # カウンター
crite <- 1e-4 # 収束の判断基準
dif1 <- dif2 <- 1 # dif1,dif2は前の推定値との差分の絶対値を記録。初期値として1を代入
while ( ! ( dif1 < crite & dif2 < crite ) ) {
  # 値の更新
  mu1 <- mu[ 1 ] - Lambda[ 1, 1 ]^( -1 ) * Lambda[ 1, 2 ] * ( mu2 - mu[ 2 ] )
  mu2 <- mu[ 2 ] - Lambda[ 2, 2 ]^( -1 ) * Lambda[ 2, 1 ] * ( mu1 - mu[ 1 ] )

  # 更新値の記録
  mu1_store <- c( mu1_store, mu1 )
  mu2_store <- c( mu2_store, mu2 )
  
  # 差分の絶対値の計算
  dif1 <- abs( mu1_store[ m + 1 ] - mu1_store[ m ] )
  dif2 <- abs( mu2_store[ m + 1 ] - mu2_store[ m ] )
  
  # カウンターの更新
  m <- m + 1
}

# mu1, mu2の変化のプロット
plot( mu1_store, type="l" )
plot( mu2_store, type="l" )

# 近似された分布からのサンプリング
result1 <- mvnfast::rmvn( n, mu1[ length( mu1 ) ], Lambda[ 1, 1 ]^( -1 ) )
result2 <- mvnfast::rmvn( n, mu2[ length( mu2 ) ], Lambda[ 2, 2 ]^( -1 ) )

# 近似される分布のサンプリング
plot( result[ ,1 ], result[ , 2 ], col = 1, xlab = "", ylab = "" )
# 平均場近似によるサンプリング。赤で示している。
points( result1, result2, col = 2 )

mu1の変化過程
image.png

mu2の変化過程
image.png

近似前と近似後のサンプリング結果の違い(黒:近似前、赤:近似後)
image.png

図10.2の右 逆のKLダイバージェンスによる近似

library( mvnfast )

# 近似される正規分布の平均と精度
mu <- c( 1, 3 )
Sigma <- matrix( c( 4, 0.9*3*2, 0.9*3*2, 9 ), ncol = 2 )
Lambda <- solve( Sigma )

# サンプリングによる表示
# サンプリング数の指定とプロット
n <- 5000
result <- mvnfast::rmvn( n, mu, Sigma )
plot( result[ , 1 ], result[ , 2 ] )

#######################################

# 近似された分布からのサンプリング
result1 <- mvnfast::rmvn( n, mu[ 1 ], Sigma[ 1, 1 ] )
result2 <- mvnfast::rmvn( n, mu[ 2 ], Sigma[ 2, 2 ] )

# 近似される分布のサンプリング
plot( result[ , 1 ], result[ , 2 ], col = 1, xlab = "", ylab = "" )
# 平均場近似によるサンプリング。赤で示している。
points( result1, result2, col = 2 )

image.png

図10.4の(a)

特にパラメータは書かれていなかったので適当にきんじしt

n_samp <- 10000 # number of sampling

#####################################################
# True distribution

# definition of parameters
shape_t = 4
scale_t = 1/5 
mu0_t <- 0
lambda0_t <- 2

# sampling
a_t <-rgamma( n_samp, shape = shape_t, scale = scale_t ) 
b_t <- numeric( length( a_t ) )
for( i in 1:length( a_t ) ){
  tau_t <- a_t[ i ]
  var_t <- ( lambda0_t * tau_t )^( -1 )
  b_t[ i ] <- rnorm( 1, mu0_t, var_t )
}

#####################################################
# The approximated distribution with initial values

# definition of parameters
shape_init = 150
scale_init = 1/100
mu_init <- 1
var_init <- 0.2

# sampling independently
a_init <- rgamma( n_samp, shape = shape_init, scale = scale_init ) 
b_init <- rnorm( n_samp, mu_init, var_init )

#####################################################
# True distribution
plot( b_t, a_t, xlim = c( -2, 2 ), ylim = c( 0, 2 ), xlab = "mu", ylab = "tau" )
# The approximated distribution with initial values
points( b_init, a_init, col = 2 )

image.png

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?