27
29

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

emceeでマルコフ連鎖モンテカルロ法(MCMC)によるサンプリング

Last updated at Posted at 2018-05-12

本稿ではemceeという、マルコフ連鎖モンテカルロ法(MCMC)によるサンプリングを行うモジュールの使い方を紹介します。

マルコフ連鎖モンテカルロ法(MCMC)について

すでに色んな記事に説明があるのでここでは説明せず、これを参考に
https://qiita.com/shogiai/items/bab2b915df2b8dd6f6f2
https://qiita.com/AIKI_SHI/items/0d2cd63d89a6433646f4
https://qiita.com/kaityo256/items/f05f9914eb0ad16afe05

numpyでの実装は
https://qiita.com/kenmatsu4/items/55e78cc7a5ae2756f9da
https://qiita.com/yadoyado128/items/9ff7fd6dad5c10259763

emceeとは

emceeはMCMCを簡単に実装できるpythonモジュールです。
公式サイト http://dfm.io/emcee/current

他にもpymc3やpystanがありますが、emceeは比較的にわかりやすいし、全部純pythonなのでインストールしやすいです。pymc3はtheanoが必要なのでインストールできないことが多いようです。pystanはstanのスクリップを書かなければならないので、pythonを書くだけでは使えないのは欠点。

emceeはpymc3のようにたくさん分布関数を準備しておいてあるのではなく、自分で分布関数を定義しなければならないが、色々手動で書くので自由度が高いです。機能はpymc3より少ないが、単なるサンプリングをしたいだけなら十分。むしろ機能が少ないから勉強しやすいし、余計なことを覚えなくても使えます。

emceeを作った人は天文学者なので天文学の研究によく貢献していますが、他の分野でも勿論よく使えます。

emceeのインストール

言うまでもないかもしれませんが、簡単にpipでインストールできます。

pip install emcee

emceeでサンプリングする

まずは例としてこのような分布関数を使います。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def fn(xy):
    x,y = xy
    return np.maximum(0,(10-np.abs(2+x)-np.abs(y-1)))**2

mx,my = np.meshgrid(np.linspace(-10,10,41),np.linspace(-10,10,41))
mz = fn([mx,my])
ax = plt.figure(figsize=[8,8]).add_axes([0,0,1,1],projection='3d')
ax.plot_surface(mx,my,mz,rstride=1,cstride=1,alpha=0.2,edgecolor='k',cmap='rainbow')
plt.show()

emceeを使ってこのような分布でサンプリングするにはこう書いたらいいです。

import emcee

def lnfn(xy):
    x,y = xy
    prop = np.maximum(1e-10,(10-np.abs(2+x)-np.abs(y-1)))**2
    return np.log(prop)

ndim = 2 # 鎖の数
nwalker = 6 # 次元の数
nstep = 4000 # 鎖の長さ
xy0 = np.random.uniform(-4,4,[nwalker,ndim]) # 初期のxとy
sampler = emcee.EnsembleSampler(nwalker,ndim,lnfn) # サンプラーを作る
sampler.run_mcmc(xy0,nstep) # サンプリング開始

xy = sampler.flatchain # サンプリングでできた結果を取得
x,y = xy[:,0],xy[:,1]

# 分布を描く関数
def bunpuplot(x,y):
    plt.figure(figsize=[8,8])
    plt.subplot(221,aspect=1)
    plt.scatter(x,y,alpha=0.002,marker='.')
    plt.subplot(222)
    plt.hist(y,bins=100,orientation='horizontal')
    plt.subplot(223)
    plt.hist(x,bins=100)
    plt.subplot(224,aspect=1)
    plt.hist2d(x,y,bins=50,cmap='rainbow')
    plt.colorbar()
    plt.tight_layout()
    plt.show()

bunpuplot(x,y)

結果はこうなります。

emceeを使う方法はまずemcee.EnsembleSamplerのオブジェクトを作って.run_mcmc()メソッドを使ってサンプリングを開始すのです。

sampler = emcee.EnsembleSampler(鎖の数,次元,分布の対数関数)
sampler.run_mcmc(初期値,鎖の長さ)

使う関数のは分布関数そのままではなくその分布の対数を使うので注意

鎖の数は次元の2倍以上でないとこういうこっぴどいエラーが出ます

AssertionError: The number of walkers needs to be more than twice the dimension of your parameter space...
unless you're crazy!

そしてサンプリング終了した後、結果は全部.chainと.flatchainに置かれています。.chainは鎖毎にわけられてサイズは(鎖の数,鎖の長さ,次元の数)ますが、.flatchainでは全ての鎖は合併してサイズは(鎖の数×鎖の長さ,次元の数)です。

print(sampler.chain.shape) # (6, 4000, 2)
print(sampler.flatchain.shape) # (24000, 2)

最初の部分を捨てる(バーンイン)

普通はサンプリングを始めたばかりの時に初期値によるバイアスがあるため、最初のいくつかのサンプリングは捨てられることが多い。

例として、各段階でのxの平均値の変化を描いてみます。

ndim = 2
nwalker = 100
nstep = 400
xy0 = np.random.uniform(-0.1,0.1,[nwalker,ndim])
sampler = emcee.EnsembleSampler(nwalker,ndim,lnfn)
sampler.run_mcmc(xy0,nstep)
xy = sampler.flatchain
x,y = xy[:,0],xy[:,1]

plt.figure()
plt.plot(sampler.chain[:,:,0].mean(0).T)
plt.show()

xの最大分布は-2ですが、初期値は0に集まったので、-2に収束するまでは時間がかかります。収束するまでの部分は分布にバイアスを残します。

最初の部分を捨てるにはこう書きます。例えば100捨てたい場合

sampler = emcee.EnsembleSampler(nwalker,ndim,lnfn)
xy0,_,_ = sampler.run_mcmc(xy0,100)
sampler.reset()
sampler.run_mcmc(xy0,nstep)

.run_mcmc()で実行した時に3つの値が返されます。それはすべての鎖の最後の

  • 位置
  • 確率
  • 状態
    ここでは位置だけを使います。そして.reset()でここまでの鎖を消して、その位置を初期値としてもう一度.run_mcmc()を。

関数がパラメータを含む場合

関数に他のパラメータが必要な場合、パラメータのリストをargsというキーワードに渡します。

例えば

def lnfn(xy,*p):
    return -((p[0]/2)**2-(xy**2).sum())**2/p[1]+xy[0]/p[2]

ndim = 2
nwalker = 10
nstep = 6000
p = 10.,200.,20.
xy0 = np.random.uniform(-5,5,[nwalker,ndim])
sampler = emcee.EnsembleSampler(nwalker,ndim,lnfn,args=p)
sampler.run_mcmc(xy0,nstep)
xy = sampler.flatchain
x,y = xy[:,0],xy[:,1]
bunpuplot(x,y)

制限値を指定する

サンプリングする途中、もし間違って極端な値まで走ったら戻れなくなることもあります。その場合は制限値を指定したらいいです。残念ながらemceeでは直接に範囲を指定することができませんが、関数で制限値を設定することは簡単です。

例えばこんな関数はどこでも0にはならない。このままサンプリングしたら遠くに行ってしまう可能性があります

def fn(xy):
    x,y = xy
    xy2 = x**2+y**2
    return np.sin(xy2)**2/xy2**0.5

こんな風に制限を指定できます

b = np.array([[-4,-2],[3,5]])
def lnfn(xy,b):
    if(np.any(xy<b[0])|np.any(xy>b[1])):
        return -np.inf
    x,y = xy
    xy2 = x**2+y**2
    return np.log(np.sin(xy2/2)**2/xy2**0.5)

ndim = 2
nwalker = 20
nstep = 10000
xy0 = np.random.uniform(-1,1,[nwalker,ndim])
sampler = emcee.EnsembleSampler(nwalker,ndim,lnfn,args=[b])
sampler.run_mcmc(xy0,nstep)
xy = sampler.flatchain
x,y = xy[:,0],xy[:,1]
bunpuplot(x,y)

qi6.png

制限値を超えたら値が-∞になるので、そんなところに行く可能性は完全に0になります。

emceeで関数の最適値を求める

MCMCはある関数の最適値を求めることによく使われます。ある関数のパラメータをランダムして関数を最大値にするということです。もし無闇に完全ランダムをしたら時間がかかりすぎますが、確率の高いところをよくランダムされるようにしたら早くなります。だからMCMCでパラメータをサンプリングすることで最大値を見つけやすいです。

最適値を求める方法は他にも色々あります。例えば、機械学習でよく使われているのは勾配降下法です。勾配降下法は局所的最大値に陥りやすいという欠点がありますが、MCMCなら最大値に歩いてもまだ周りに移る可能性があるので、局所的最大値から抜けやすい。

例として局所的最大値を複数持っているこの関数を使います。

def fn(xy):
    x,y = xy
    return np.exp(-(5**2-(x**2+y**2))**2/200 + xy[1]/20) * (6./5+np.sin(6*np.arctan2(x,y)))

その関数をそのまま分布関数に使って、サンプリングして最大値という峰を目指します。

def lnfn(xy):
    return np.log(max(1e-10,fn(xy)))

ndim = 2
nwalker = 6
xy0 = np.random.uniform(-5,5,[nwalker,ndim])
sampler = emcee.EnsembleSampler(nwalker,ndim,lnfn)
sampler.run_mcmc(xy0,4000)

xy = sampler.flatchain
x,y = xy[:,0],xy[:,1]
x_max,y_max = sampler.flatchain[sampler.flatlnprobability.argmax()]
plt.figure(figsize=[7,6])
plt.gca(aspect=1)
plt.scatter(x,y,alpha=0.1,c=sampler.flatlnprobability,marker='.',cmap='rainbow')
plt.colorbar()
plt.scatter(x_max,y_max,c='k') # 最大値の位置を描く
plt.show()

qi8.jpg

ここで.flatlnprobabilityはここまでサンプリングした各位置の確率値。鎖毎にflatしたくない場合は.lnprobability。そして.argmaxで最大値を与える位置を選んでその値を使います。最大値は黒い点で表示されます。

長くサンプリングするほど最大の値に近づけます。サンプリングでできた最大値の進歩を見たら、どれくらいサンプリングしたら充分か見込むことができます。

saidaichi = np.empty(4000)
lnprobmax = -np.inf
for i,a in enumerate(sampler.lnprobability.max(0)):
    if(a>lnprobmax):
        lnprobmax = a
    saidaichi[i] = lnprobmax

plt.ylabel(u'最大値',fontname='AppleGothic',size=16)
plt.xlabel(u'鎖の長さ',fontname='AppleGothic',size=16)
plt.plot(np.exp(saidaichi))
plt.loglog()
plt.show()

高次元の例

今までの例は2次元ばかりですが、何次元でも同様に使えます。

適当に4次元の関数を使った例を挙げます。

def lnfn(x):
    if(np.any(x**2>1)):
        return -np.inf
    p = (x[0]*np.sin(x[0]*3))**2*(1-x[1]**2)*(1-np.abs(x[2])-np.abs(x[3])/2-x[3]**2/2)
    return np.log(max(1e-10,p))

ndim = 4
nwalker = 20
nstep = 6000
xy0 = np.random.uniform(-0.5,0.5,[nwalker,ndim])
sampler = emcee.EnsembleSampler(nwalker,ndim,lnfn)
sampler.run_mcmc(xy0,nstep)
xy = sampler.flatchain
x,y = xy[:,0],xy[:,1]

plt.figure(figsize=[15,15])
for i in range(ndim):
    for j in range(i+1):
        plt.subplot(ndim,ndim,1+i*ndim+j)
        if(i==j):
            plt.hist(sampler.flatchain[:,i],50,color='#BB3300')
        else:
            plt.hist2d(sampler.flatchain[:,j],sampler.flatchain[:,i],bins=50,cmap='coolwarm')
plt.show()

qi10.png

終わりに

以上emceeの使い方です。随分便利だと思います。最近MCMCを使う必要があって勉強したので、ここでメモしました。他の人にも役に立てたらと思います。

--2018年7月2日に更新--
ガウス過程でemceeを使う例はこの記事で
https://qiita.com/phyblas/items/d756803ec932ab621c56

27
29
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
27
29

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?