LoginSignup
0
1

More than 3 years have passed since last update.

多次元ガウス分布 : 多次元ガウス分布 (ベイズ推定)

Last updated at Posted at 2020-11-05

記事の目的

多次元ガウス分布と、平均が未知のときの共役事前分布である多次元ガウス分布を使用し、Rを使ってベイズ推定します。
商品Aの購買個数と、商品Bの購買個数の平均を事後分布で推定します。
参考:ベイズ推論による機械学習入門

目次

0. モデルの説明
1. ライブラリ
2. 推定する分布
3. 事前分布
4. 事後分布
5. 予測分布

0. モデルの説明

IMG_0186.jpeg

1. ライブラリ

library(dplyr)
library(mvtnorm)
library(ggplot2)
library(gganimate)
set.seed(100)

2. 推定する分布

商品Aの購買個数と、商品Bの購買個数は、それぞれ平均が50, 170で、標準偏差が10, 相関係数が0.7の多次元ガウス分布に従います(適当)。
この、真の分布の平均の50, 170を事後分布で推定します。

sigma.true <- matrix(c(10^2, 0.7*10^2, 0.7*10^2, 10^2), ncol=2)
X.true <- rmvnorm(1000, mean = c(50, 170), sigma=sigma.true)
NULL %>% ggplot(aes(x=X.true[,1], y=X.true[,2])) + geom_point() +
  labs(x="商品Aの購買個数", y="商品Bの購買個数", title="推定する分布")

image.png

3. 事前分布

事前分布として、多次元ガウス分布の平均が未知の時の共役事前分布である多次元ガウス分布を指定します。

m0 <- c(0, 0)
lambda.mu0 <- diag(2)/100^2
X.pre <- rmvnorm(1000, mean = m0, sigma=solve(lambda.mu0))
NULL %>% ggplot(aes(x=X.pre[,1], y=X.pre[,2])) + geom_point() +
  labs(x="商品Aの平均購買個数", y="商品Bの平均購買個数", title="推定する分布")

image.png

4. 事後分布

真の分布から1000個のサンプルを取って事後分布を推定します。
事後分布から1000個サンプルを取って、学習の過程を可視化しました。
下図から、うまく平均の50, 170を推定できていることが分かります。

lambda <- solve(sigma.true)
Data <- rmvnorm(1000, mean = m0, sigma=solve(lambda.mu0))
Data <- data.frame(Data, iter=rep(0, 1000))
for(i in 1:3){
  X.sample <- rmvnorm(10^i, mean = c(50, 170), sigma=diag(2)*10)
  lambda.mu <- nrow(X.sample)*lambda + lambda.mu0
  m <- solve(lambda.mu) %*% (lambda %*% apply(X.sample, 2, sum)+lambda.mu0 %*% m0)
  X.post <- rmvnorm(1000, mean = m, sigma=solve(lambda.mu))
  X.post <- data.frame(X.post, iter=rep(i, 1000))
  Data <- rbind(Data, X.post)
}
Data %>% ggplot(aes(x=X1, y=X2)) +
  geom_point(col="green") +
  labs(x="商品Aの平均購買個数", y="商品Bの平均購買個数", title="推定する分布") +
  transition_states(iter, transition_length = 2, state_length = 1)

b1.gif

5. 予測分布

予測分布から1000個サンプルを取って、学習の仮定を可視化しました。
下図から、相関関係もうまくを推定できていることが分かります。

m.predict <- m0
lambda.predict <- solve(solve(lambda)+solve(lambda.mu0))
Data.predict <- rmvnorm(1000, mean = m0, sigma=solve(lambda.mu0))
Data.predict <- data.frame(Data.predict, iter=rep(0, 1000))
for(i in 1:3){
  X.sample <- rmvnorm(10^i, mean = c(50, 170), sigma=diag(2)*10)
  lambda.mu <- nrow(X.sample)*lambda + lambda.mu0
  lambda.predict <- solve(solve(lambda)+solve(lambda.mu))
  m <- solve(lambda.mu) %*% (lambda %*% apply(X.sample, 2, sum)+lambda.mu0 %*% m0)
  m.predict <- m
  X.predict <- rmvnorm(1000, mean = m.predict, sigma=solve(lambda.predict))
  X.predict <- data.frame(X.predict, iter=rep(i, 1000))
  Data.predict <- rbind(Data.predict, X.predict)
}
Data.predict %>% ggplot(aes(x=X1, y=X2))+
  geom_point(col="blue")+
  labs(x="商品Aの購買個数", y="商品Bの購買個数", title="推定する分布")+
  transition_states(iter, transition_length = 2, state_length = 1)

b2.gif

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1