LoginSignup
17
16

More than 5 years have passed since last update.

ギブスサンプリング実装とJITコンパイラによる高速化

Posted at

はじめに

マルコフ連鎖モンテカルロ法の勉強でギブスサンプリングを実装していたのですが、偶々同じ時期に、pythonコードをJITコンパイラのライブラリ"Numba"で高速化する記事を見つけたので組み合わせてみました。

ギブスサンプリングとは

マルコフ連鎖モンテカルロ法(以下 MCMC: Markov Chain Monte Carlo)はパラメータ事後分布から効率的にサンプルを生成する方法です。ギブスサンプリングはMCMCの1つであり、多変量の事後分布に対して、各々の確率変数の条件付き分布から交互にサンプリングします。例えば2変量であれば、$p(x_t|y_{t-1})$ → $p(y_t|x_t)$ → $p(x_{t+1}|y_t)$ → ... というように次々とサンプルを生成していくことになります。
ただしこれは条件付き確率分布から簡単にサンプリングしてこれる場合に限ります。もともと、求めたい同時パラメータ分布 $p(x, y)$ からのサンプリングが困難であるために、代わりに $p(x|y)$, $p(y|x)$ から交互にサンプリングしてこようという発想だからです。

MCMCは高次元でもサンプリングの効率を高く保ち、"次元の呪い"を克服することができます。

MCMCに関する条件や性質として

  • 詳細釣り合い条件
  • エルゴード性
  • 既約性

があるようなのですが、ここらへんはまだ追いきれていないので勉強しないとですね(T_T)

JIT(Just in Time)コンパイラとは

JITコンパイラとは、プログラムの実行時に、あらかじめ用意された(実行環境に依存しない汎用的な)中間コードを、プログラムの実行時点でプロセッサが実行可能な機械語(ネイティブコード)にコンパイルすることである。
-- IT用語辞典バイナリ --

Pythonはスクリプト言語であるため、1行ごとにコンパイルをしながらコードを実行します。そのため、forループのようなコードを書くとその度にコンパイルを実施しなければならず、事前コンパイルでループ処理を最適化するC++などに比べて大変大きな処理時間が掛かってしまいます。そのためPythonではNumpyなどのライブラリが用意されており、行列演算を駆使しながら高速化を図っています。しかしNumbaを使えば、forループを使ったコードや関数を事前にコンパイルし最適化してくれるので、CやC++のように高速に動作させることができます!

Numbaのインストールはこちら

実装

ここからpythonによる実装に入ります。
コードはここを参考に書かせてもらいました。

必要なもの

使用したpythonのバージョンやライブラリは以下です。

  • python 3.5
  • numpy 1.12.1
  • numba 0.31.0

コード

今回は確率分布 $p(x,y) = exp(-\frac{x^{2} -xy + y^{2}}{2})$ からギブスサンプリングすることを想定します。

import numpy as np
import time

def gibbs(N, iter_num, burn_in):
    sample = []
    x = 0.0
    y = 0.0
    for i in range(N):
        for j in range(iter_num):
            x = np.random.normal(1/2 * y, 1)
            y = np.random.normal(1/2 * x, 1)
        sample.append((x, y))
    return np.array(sample[int(N * burn_in):])


start = time.time()
plots = gibbs(3000, 10000, 0.3)
end = time.time()
print('elapsed_time: {0}[s]'.format(end - start))
elapsed_time: 193.93689107894897[s]

ここでは194秒かかっています。サンプルされた点をプロットすると下のようになります。

次にライブラリNumbaを通してJITコンパイラを適用してみようと思います。
このライブラリはとても便利で、numbaからjitモジュールをインポートした後、最適化コンパイラしたい関数に@jitとデコレータを付けるだけで済みます!

import numpy as np
import time
from numba import jit

@jit
def gibbs_with_jit(N, iter_num, burn_in):
    sample = []
    x = 0.0
    y = 0.0
    for i in range(N):
        for j in range(iter_num):
            x = np.random.normal(1/2 * y, 1)
            y = np.random.normal(1/2 * x, 1)
        sample.append((x1, x2))
    return np.array(sample[int(N * burn_in):])


start = time.time()
plots_jit = gibbs_with_jit(3000, 10000, 0.3)
end = time.time()
print('elapsed_time: {0}[s]'.format(end - start))
elapsed_time: 2.623775005340576[s]

かかった時間は3秒ほど。
コードをほとんど変えることなく、実行時間が大幅に減少したのがわかるかと思います。

まとめ

今までpythonやNode.jsなどのスクリプト言語しか触ったことがなかったのですが、今回JITコンパイラを勉強したこと通してスクリプト言語とコンパイル言語の仕組みの違いなどを改めて学ぶことができました。Numbaにはまだいろいろと機能があるみたいなので、もう少し触ってみたいと思います。

17
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
16