LoginSignup
0
0

More than 1 year has passed since last update.

多変量正規分布を二つのブロックに分けてgibbs sampling(ギブスサンプリング)する。

Last updated at Posted at 2022-05-24

以下の多変量正規分布を考えます。

\begin{pmatrix}
x_1 \\ 
x_2
\end{pmatrix} \sim N\left( \mu, \Sigma   \right) \\

\mu = \begin{pmatrix} \mu_1 \\ \mu_2 \end{pmatrix}、
\Sigma = \begin{pmatrix} \Sigma_{11} & \Sigma_{12} \\ 
\Sigma_{21} & \Sigma_{22}  \end{pmatrix}\\

このとき以下の条件付き確率がなりたちます( 以下に詳しい導出があります。https://qiita.com/m1t0/items/12b58f3583eab6e1e05c )。

p(x_1 | x_2 ) = N( \mu_{1|2}, \Sigma_{1|2} ) \\
\mu_{1|2} = \mu_1 + \Sigma_{12} \Sigma_{22}^{-1} (x_2 - \mu_2) \\
\Sigma_{1|2} = \Sigma_{11}- \Sigma_{12}  \Sigma_{22}^{-1}\Sigma_{21} \\

以上の条件付き確率を用いて、以下のgibbs sampling(ギブスサンプリング)を行えます。

x_1 \sim p(x_1 | x_2 ) \\
x_2 \sim p(x_2 | x_1 )

matlabでコードしたものを以下に示します。

%
% 正規分布を二つのブロックに分割し、gibbs samplingします。
%
rng(4);
close all; clear all;

N = 10^5;


% 共分散行列を乱数で生成します。
X = randn( 4, 4 );
X = X'*X;

% 共分散行列をブロックに分割します。       
x11 = X(1:2,1:2);
x12 = X(1:2,3:4);
x22 = X(3:4,3:4);
x21 = x12';

% 散布図をプロットしてみます。
test1( X, x11, x12, x21, x22, N )

% gibbsサンプリングを行います。
xx = zeros( 1, 4 ); % 初期値
for i = 1:N
b = xx(end,3:4);
a = cd(  b, x11, x12, x21, x22 ); % 上半分をサンプリング
b = cd(  a, x22, x21, x12, x11 ); % 下半分をサンプリング
xx(end+1,:) = [ a b ];
end

subplot( 3, 3, 7 );
plot( xx(:,1), xx(:,2), '.' );
subplot( 3, 3, 8 );
plot( xx(:,3), xx(:,4), '.' );
subplot( 3, 3, 9 );
plot( xx(:,2), xx(:,3), '.' );

muhatx = mean( xx );
covhatx = cov( xx );

%
% 条件付き確率の計算
%
function a = cd( b, x11, x12, x21, x22 )
mu = x12 /( x22 )*(b');
cv = x11 - x12 /( x22 ) * x21;
a = mvnrnd( mu, cv );
end

%
% 多変量正規分布の散布図を部分的にグラフ表示
%
function test1(X, x11, x12, x21, x22, N)

mu = zeros( 4, 1 );
xs = [];
xs = mvnrnd( mu, X, N );
figure( 1 );
subplot( 3, 3, 1 );
plot( xs(:,1), xs(:,2), '.' );
subplot( 3, 3, 2 );
plot( xs(:,3), xs(:,4), '.' );
subplot( 3, 3, 3 );
plot( xs(:,2), xs(:,3), '.' );


muhat = mean( xs );
covhat = cov( xs );

%
% ためしに上半分と下半分を入れ替えてグラフ表示
%
X2 = [ x22 x21; x12 x11];
xt = mvnrnd( mu, X2, N );
muhat2 = mean( xt );
covhat2 = cov( xt );

subplot( 3, 3, 4 );
plot( xt(:,1), xt(:,2), '.' );
subplot( 3, 3, 5 );
plot( xt(:,3), xt(:,4), '.' );
subplot( 3, 3, 6 );
plot( xt(:,2), xt(:,3), '.' );

end

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0