Comparing WAIC, DIC1, and DIC2 for a Normal Mixture Model

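Overview

  • Ported the MATLAB program published on Sumio Watanabe's "WAIC Experiments" page to R.
  • Computed WAIC, DIC1, and DIC2 for a normal mixture model.
  • The script evaluates the generalization error rather than the generalization loss, so every quantity is measured relative to the true distribution q(x) as a log-ratio log(q/p).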

Formulas for WAIC, DIC1, and DIC2

\newcommand\df{\mathrm{d}}
\begin{align}
p(x|a,b)&=a(2\pi)^{-\frac{g}{2}}|\Sigma|^{-\frac{1}{2}}\exp\left(-\frac{1}{2}(x-b)^t\Sigma^{-1}(x-b)\right)\notag\\
&+(1-a)(2\pi)^{-\frac{g}{2}}|\Sigma|^{-\frac{1}{2}}\exp\left(-\frac{1}{2}x^t\Sigma^{-1}x\right),
\Sigma=\mathrm{diag}(\sigma^2),\sigma=1,g=3\quad \text{(model)}\\
\phi(a)&=U(0,1)\quad \text{(prior)}\\
\phi(b)&=(2\pi)^{-\frac{g}{2}}|\Sigma_b|^{-\frac{1}{2}}\exp\left(-\frac{1}{2}b^t\Sigma_b^{-1}b\right),
\Sigma_b=\mathrm{diag}(\sigma_b^2),\sigma_b=10\quad \text{(prior)}\\
p(a,b|D)&\propto \phi(a)\phi(b)\prod_{i=1}^{n} p(X_i|a,b)\quad \text{(posterior)}\\
p(x)&=E_{ab}[p(x|a,b)]\quad \text{(predictive distribution)}\\
q(x)&=p(x|a,b)\quad \text{(true model)}\\
&a=0.5,b=(2,2,2)\quad \text{(regular case)}\\
&a=0.5,b=(0,0,0)\quad \text{(singular case)}\\
&a=0.5,b=(0.3,0.3,0.3)\quad \text{(delicate case)}\\
G_n&=E_X\left[\log \frac{q(X)}{p(X)}\right]\quad \text{(generalization error)}\\
T_n&=\frac{1}{n}\sum_{i=1}^{n}\log \frac{q(X_i)}{p(X_i)}\quad \text{(training error)}\\
\mathrm{WAIC}&=T_n+\frac{1}{n}\sum_{i=1}^{n}\left(E_{ab}[(\log p(X_i|a,b))^2]-E_{ab}[\log p(X_i|a,b)]^2\right)\\
\mathrm{DIC1}&=T_n+\frac{2}{n}\sum_{i=1}^{n}\left(-E_{ab}[\log p(X_i|a,b)]+\log p(X_i|E_{ab}[a,b])\right)\\
\mathrm{DIC2}&=T_n+\frac{2}{n}\left(E_{ab}\left[\left(\sum_{i=1}^{n}\log p(X_i|a,b)\right)^2\right]-E_{ab}\left[\sum_{i=1}^{n}\log p(X_i|a,b)\right]^2\right)
\end{align}
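
As a minimal sketch of the three correction terms added to T_n above, assume the posterior draws have already been turned into an S × n matrix of log-likelihoods. The helper `penalty_terms` and its arguments `loglik` and `loglik_at_mean` are hypothetical names introduced here for illustration; they do not appear in the ported script, which computes the same quantities with explicit loops.

```r
# Hypothetical helper (not part of the ported script).
# loglik:         S x n matrix, loglik[s, i] = log p(X_i | a_s, b_s) for posterior draw s
# loglik_at_mean: length-n vector, log p(X_i | posterior mean of (a, b))
penalty_terms <- function(loglik, loglik_at_mean) {
  n <- ncol(loglik)
  post_mean <- colMeans(loglik)       # E_ab[log p(X_i | a, b)] for each i
  post_mean_sq <- colMeans(loglik^2)  # E_ab[(log p(X_i | a, b))^2] for each i
  total <- rowSums(loglik)            # sum_i log p(X_i | a_s, b_s), one value per draw
  c(
    WAIC = sum(post_mean_sq - post_mean^2) / n,
    DIC1 = 2 * sum(loglik_at_mean - post_mean) / n,
    DIC2 = 2 * (mean(total^2) - mean(total)^2) / n
  )
}

# Toy input: 4 identical draws and 3 data points, so the posterior has no spread
loglik <- matrix(rep(c(-1, -2, -3), each = 4), nrow = 4)
pen <- penalty_terms(loglik, loglik_at_mean = c(-1, -2, -3))
print(pen)
```

With a posterior that has no spread, all three terms vanish and each criterion reduces to T_n. Adding the respective term to T_n gives WAIC, DIC1, or DIC2 as defined above.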

R script

#######################################################################
TRIAL_N=10 ### Independent Trial Number
NNN=200    ### Number of training samples ### NNN=20, 100, 200, ... 
TTT=5000   ### Number of testing samples
#######################################################################
Case=3
######## Regular Case
if(Case==1){
    ah=0.5   ### True mixture ratio
    bh=c(2,2,2)
    sprintf('Regular case. Sample=%g. Theoretical Gen Err=%f',NNN,(3+1)/(2*NNN))
}
######## Singular Case
if(Case==2){
    ah=0.5   ### True mixture ratio
    bh=c(0,0,0)
    sprintf('Singular case. Sample=%g. Theoretical Gen Err=%f',NNN,1/(2*NNN))
}
######## Delicate Case
if(Case==3){
    ah=0.5   ### True mixture ratio
    bh=c(0.3,0.3,0.3)
    sprintf('Delicate case. Sample=%g. Theoretical Gen Err = %f ~ %f',NNN,1/(2*NNN),2*(3+1)/(2*NNN))
}
#######################################################################
randn = function(n,m){array(rnorm(n*m),dim=c(n,m))}
zeros = function(n,m){array(0,dim=c(n,m))}
mod = function(a,b){a%%b}
#######################################################################
STD=1           ### standard deviation of each normal distribution
STDB=10         ### STD of prior of (b1,b2,b3)
BURNIN=20000    ### Burn-in Number in MCMC
SIZEM=500       ### Number of MCMC parameter samples
INTER=100       ### MCMC sampling interval
######## Note: Prior of ah is the uniform distribution on [0,1]
MCMC_A=0.3*sqrt(2/NNN)     ### Markov Chain Step for ah. 
MCMC_B=0.5*sqrt(2/NNN)     ### Markov Chain Step for (b1,b2,b3).
### If the number of samples is changed, then MCMC_A and MCMC_B 
### are set appropriately. 
######## Testing Samples generated ####################################
set.seed(1)
XT=STD*randn(3,TTT)
for(i in 1:TTT){
    if(runif(1)<ah){
        XT[,i]=XT[,i]+bh
    }
}
######## statistical model : normal mixture ###########################
##### pmodel1: (a,b,c,d) is a vector, (x,y,z) is a scalar.
# pmodel1=@(x,y,z,a,b,c,d)(ttt*((1-a)*exp(-sss*(x^2+y^2+z^2))...
# +a.*exp(-sss*((x-b).^2+(y-c).^2+(z-d).^2))))
pmodel1 = function(x,A,B){
    z=sqrt(2*pi*STD*STD)^3
    ((1-A)*exp(-sum(x*x)/2/STD^2)+A*exp(-colSums((x-B)*(x-B))/2/STD^2))/z
}
##### pmodel2: (x,y,z) is a vector, (a,b,c,d) is a scalar.
# pmodel2=@(x,y,z,a,b,c,d)(ttt*((1-a)*exp(-sss*(x.^2+y.^2+z.^2))...
# +a*exp(-sss*((x-b).^2+(y-c).^2+(z-d).^2))))
pmodel2 = function(X,a,b){
    z=sqrt(2*pi*STD^2)^3
    ((1-a)*exp(-colSums(X*X)/2/STD^2)+a*exp(-colSums((X-b)*(X-b))/2/STD^2))/z
}
######## - log likelihood - log prior #################################
# HHH=@(a,b,c,d,X,Y,Z)(PRIOR*(b^2+c^2+d^2)-sum(log(pmodel2(X,Y,Z,a,b,c,d))))
HHH = function(a,b,X){
    log(sqrt(2*pi*STDB^2))*3+sum(b*b)/2/STDB^2-sum(log(pmodel2(X,a,b)))
}
#######################################################################
sprintf('WAIC begin. Independent %g trials.',TRIAL_N)
######## Trial begins #################################################
t0 <- proc.time()
ret = t(sapply(1:TRIAL_N,function(trial){
    ######## training samples generated #######
    set.seed(trial)
    XN=STD*randn(3,NNN)
    for(i in 1:NNN){
        if(runif(1)<ah){
            XN[,i]=XN[,i]+bh
        }
    }
    ######## MCMC process begins (Metropolis Method) ##################
    ######## Initial parameter is set by the true parameter ###########
    ######## Note: In practical applications the true parameter is unknown, so this initialization cannot be used.
    a0=ah
    b0=bh
    h0=HHH(a0,b0,XN)
    k=0
    acceptance=0
    WA=zeros(1,SIZEM)
    WB=zeros(3,SIZEM)
    MCMC=BURNIN+SIZEM*INTER   ### Total MCMC trial number
    for(t in 1:MCMC){
        a1=a0+MCMC_A*rnorm(1) ### MCMC step
        b1=b0+MCMC_B*rnorm(3) ### MCMC step
        ### a should be in [0,1] 
        if(a1<0){
            a1=a1+1
        }
        if(a1>1){
            a1=a1-1
        }
        h1=HHH(a1,b1,XN)
        delta=h1-h0
        ### Metropolis probability
        if(exp(-delta)>runif(1)){
            a0=a1
            b0=b1
            h0=h1
            if(t>BURNIN){
                acceptance=acceptance+1
            }
        }
        ### MCMC parameters sampled
        if(mod(t,INTER)==0 && t>BURNIN){
            k=k+1
            WA[1,k]=a0
            WB[,k]=b0
        }
    }
    ### acceptance ratio
    accept=acceptance/(MCMC-BURNIN)
    ######## If the exchange rate (the acceptance ratio above) is not appropriate, the MCMC step size should be adjusted.
    ######## Recommended exchange rate: 0.05 < exchange_rate < 0.6
    ######## It seems that there is no theory about the best exchange rate for MCMC.
    ######## Calculation of Training Error ############################
    px=zeros(1,NNN)
    for(i in 1:NNN){
        px[i]=mean(pmodel1(XN[,i],WA,WB))
    }
    qx=pmodel2(XN,ah,bh)
    te=mean(log(qx/px))
    ######## WAIC  Functional Variance 
    power1=zeros(1,NNN)
    power2=zeros(1,NNN)
    for(i in 1:NNN){
        logpr=log(pmodel1(XN[,i],WA,WB))
        power1[i]=mean(logpr)
        power2[i]=mean(logpr*logpr)
    }
    vn=sum(power2-power1*power1)
    ######## likelihood at the mean parameter
    avwa=mean(WA)     ### DIC1 average parameter
    avwb=rowMeans(WB) ### DIC1 average parameter
    dic0=log(pmodel2(XN,avwa,avwb))
    ######## DIC1 effective number of parameters
    eff_num=2*sum(-power1+dic0)
    ### vn is the effective number of parameters in WAIC.
    ### This is not equal to the real log canonical threshold.
    ### eff_num is the effective number of parameters defined in DIC1.
    ######## Training error + effective number of parameters / training sample number
    dic1=te+eff_num/NNN
    ######## Training error + functional variance / training sample number
    waic=te+vn/NNN
    sumlog=zeros(1,SIZEM)
    for(k in 1:SIZEM){
        sumlog[k]=sum(log(pmodel2(XN,WA[k],WB[,k])))
    }
    power1=mean(sumlog)
    power2=mean(sumlog*sumlog)
    ### Training error + effective number / training sample number
    dic2=te+2*(power2-power1*power1)/NNN
    ######## Calculation of Generalization Error ######################
    px=zeros(1,TTT)
    for(i in 1:TTT){
        px[i]=mean(pmodel1(XT[,i],WA,WB))
    }
    qx=pmodel2(XT,ah,bh)
    ### Kullback-Leibler divergence: int q log(q/p) = int q (log(q/p) + p/q - 1)
    ge=mean(log(qx/px)+px/qx-1)
    c(TRIAL=trial,GE=ge,DIC1=dic1,DIC2=dic2,WAIC=waic,TE=te,ACCEPT=accept)
}))
proc.time() - t0
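
The generalization-error estimate at the end of the script averages log(q/p) + p/q − 1 over the test samples instead of log(q/p) alone. The short derivation below spells out why both have the same expectation, using the notation of the formulas section. It is added here as an explanatory note and relies only on p and q both integrating to one.

```latex
\begin{align}
E_X\left[\frac{p(X)}{q(X)}-1\right]
&=\int q(x)\left(\frac{p(x)}{q(x)}-1\right)\mathrm{d}x
=\int p(x)\,\mathrm{d}x-\int q(x)\,\mathrm{d}x=0,\\
G_n&=E_X\left[\log\frac{q(X)}{p(X)}\right]
=E_X\left[\log\frac{q(X)}{p(X)}+\frac{p(X)}{q(X)}-1\right].
\end{align}
```

Because log t ≤ t − 1 for every t > 0, each summand log(q/p) + p/q − 1 is nonnegative, so the test-sample average `ge` can never be negative, unlike a plain average of log(q/p).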

Results

Case=1 : Regular

     TRIAL          GE         DIC1         DIC2         WAIC           TE  ACCEPT
[1,]     1 0.004664188  0.014072870  0.012834114  0.014376591 -0.005096468 0.60146
[2,]     2 0.021629980 -0.001532818 -0.004340376 -0.001298821 -0.020923822 0.59596
[3,]     3 0.002679341  0.016576226  0.017991100  0.017883690 -0.002983547 0.60818
[4,]     4 0.007078676  0.014318154  0.011508826  0.013093491 -0.005254242 0.61414
[5,]     5 0.001315138  0.018015944  0.018630800  0.018681797 -0.001299642 0.61294
[6,]     6 0.020836343 -0.005283996 -0.007258766 -0.006429104 -0.024061321 0.59686

Case=2 : Singular

     TRIAL           GE          DIC1        DIC2        WAIC            TE  ACCEPT
[1,]     1 2.867942e-05  4.980253e-03 0.015693788 0.005055827  0.0050276830 0.13008
[2,]     2 2.466726e-03 -8.687872e-03 0.011318619 0.003293516 -0.0083514531 0.59642
[3,]     3 2.596135e-03 -1.124335e-02 0.013061751 0.003539433 -0.0107937262 0.53184
[4,]     4 8.194074e-04  5.524719e-03 0.008170343 0.004297762 -0.0002539411 0.37508
[5,]     5 1.738788e-03  1.229865e-05 0.004110315 0.003277511 -0.0081537885 0.47482
[6,]     6 2.399593e-03 -1.332184e-01 0.033614341 0.003009369 -0.0088403528 0.42338

Case=3 : Delicate

     TRIAL          GE         DIC1         DIC2         WAIC            TE  ACCEPT
[1,]     1 0.005676867  0.016453584  0.023234471  0.013994557 -0.0033313337 0.76964
[2,]     2 0.017098995  0.000662655 -0.002119495 -0.000657838 -0.0194740196 0.73148
[3,]     3 0.011137859  0.010451749  0.020259483  0.026125482 -0.0032883180 0.74270
[4,]     4 0.016149661  0.015894959  0.026178762  0.015647987  0.0010076113 0.69700
[5,]     5 0.003495657  0.018997899  0.025670537  0.020188652  0.0002758186 0.77104
[6,]     6 0.020805610 -0.001175075  0.021498243 -0.004007560 -0.0210332154 0.73556
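
To read the Regular-case table against theory, the following sketch averages the six GE and WAIC values listed for Case=1 and compares them with the theoretical generalization error (3+1)/(2n) that the script prints for that case. It assumes the table was produced with NNN=200, the value in the script as posted. The variable names are introduced here for illustration only.

```r
# GE and WAIC columns copied from the Case=1 (Regular) table
ge   <- c(0.004664188, 0.021629980, 0.002679341,
          0.007078676, 0.001315138, 0.020836343)
waic <- c(0.014376591, -0.001298821, 0.017883690,
          0.013093491, 0.018681797, -0.006429104)

n <- 200                     # assumption: NNN as in the posted script
theory <- (3 + 1) / (2 * n)  # theoretical generalization error printed for Case=1

# Averages over the six listed trials next to the theoretical value
summary_tab <- c(mean_GE = mean(ge), mean_WAIC = mean(waic), theory = theory)
print(round(summary_tab, 5))
```

Both averages come out close to the theoretical value of 0.01, although six trials are too few to draw firm conclusions. The same comparison can be repeated for the other cases with the theoretical values printed by the script.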