LoginSignup
0
1

【Python】サンプリングアルゴリズム - 棄却サンプリング

Last updated at Posted at 2023-01-19

サンプリングアルゴリズムについていくつか紹介していく. 第二回目は 棄却サンプリング について.

確率密度関数$p(x)$に基づく乱数を生成したいが, python 上で実装されていない状況を考える. さらに次のような前提条件を設定する.
前提条件:

  • 確率密度関数は $p(x) = \frac{1}{Z_{p}}f(x)$ であり, $f(x)$すぐに計算できるが規格化定数$Z$がわかっていない.
  • 確率密度関数$q(x)$からのサンプリングが可能である.

棄却サンプリングはこのような状況において $p(x)$からのサンプリングを行い規格化定数を求める方法である.
用語整理:

  • $f(x)$: 目的関数
  • $p(x)$: 目的分布
  • $q(x)$: 提案分布
  • $Z$: 規格化定数, 分配関数

棄却サンプリングにおけるアルゴリズムは次のとおり:
手順:

  1. 任意の $x$ に対して $f(x)\leq Mq(x)$ を満たす$M$を求める
  2. 提案分布 $x \sim q(z)$ に従う乱数を生成
  3. $[0, M q(x)]$ の区間の一様分布から $u$ を生成する
  4. $u \leq f(x) $ を満たす $x$ だけを受理する
  5. 2~4を多数回行う

受理した $x$がちょうど$p(x) = \frac{1}{Z_{p}}f(x)$からのサンプリングとなる.

例1: とりあえずPython による実装

ベータ分布を例に棄却サンプリングを行ってみる. 今回は

  • 提案分布:$[0,1]$区間の一様分布$U(0,1)$
  • 目的分布:$ベータ分布$

必要なライブラリのインポート

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import beta, uniform, norm
from scipy import optimize

パラメータの設定と目的関数設定:

# ベータ分布のパラメータ
a, b = 1.5, 2
#提案分布からのサンプリング数
N=50000
# シードの固定
np.random.seed(0)
# ベータ分布の取りうる範囲
x = np.linspace(beta.ppf(0.001,a,b),beta.ppf(0.999,a,b),1000)
# 目的関数の設定
# ベータ分布の確率密度関数の定義
p = beta(a,b).pdf

(手順:1)任意の$x$に対して$f(x)\leq Mq(x)$を満たす$M$を求める
今回の場合は提案分布が一様分布であるため, ベータ分布の頂点での値を求めれば良い.

#  ベータ分布の確率密度関数が最大となる xの値を求める.
res = optimize.fmin(lambda x :-p(x),0.3)
# ベータ分布の最大値を格納
y_max = p(res)
print(res,y_max,)

提案分布からのサンプリング(手順:2)
ベータ分布の取りうる値の範囲が$[0,1]$区間であることも考慮して, 今回は$[0,1]$区間の一様分布$q(x) = U(0,1)$を利用してみる

# 提案分布からのサンプリング (x軸方向のサンプリング)
# 今回は(0~1の区間の一様分布)
x_mcs = uniform.rvs(size = N)

$[0, M q(x)]$区間の一様分布からのサンプリング(手順:3)

# y軸方向のサンプリング(0~y_maxの区間でサンプリング)
r = uniform.rvs(size = N) * y_max

受容されるものだけを取り出す(手順:4)

accept = x_mcs[r<= p(x_mcs)]

結果を可視化

# アクセプトされたものをヒストグラム化
plt.hist(accept,density=True,bins=30,rwidth=0.8,label='rejection sampling')
# 正解の確率密度関数をプロット
plt.plot(x,beta.pdf(x,a,b),label='Target')
plt.legend()
plt.grid()
plt.show()

ダウンロード (1).png
ベータ分布に従ったヒストグラムができている事がわかる. なぜ以上のアルゴリズムでうまくいくのかを考える.

棄却サンプリング概説

まず最も良い提案分布というものを考えてみると, 理想的には目的分布そのものである. しかしライブラリに用意されていないためそれができない. 直感的には目的分布に似ている提案分布を代替として利用することが思いつく.問題は2つの分布の差分をどのようにして埋めるかである.

下のグラフは提案分布が一様分布である状況である.提案分布と目的分布の違いとなる部分はグレーで塗りつぶしている. この違いの部分を手順4の条件で埋めている. その差は手順2で生成された $x$ における比率で確認できる.
サンプリング.png
グラフを見れば一目瞭然だが, $x$という値が目的分布からサンプリングされる件数と提案分布からサンプリングされる件数比が$ f(x):M q(x) $ で有る. よって$f(x)/M q(x)$ の確率で $x$ を受理すれば提案分布と目的分布の差分を埋める事ができる. このサンプリング件数に比率の違いを埋めるために手順4では, $[0, M q(x)]$の一様分布から生成された $u$が $f(x)$よりも小さい場合のみ受理するということを行う.

実際に提案分布が目的分布に一致している $(q(x) = f(x) = p(x), M = 1)$である場合,手順4において全てのサンプルが受理される. 提案分布と目的分布に違いが存在していないためである.

次に手順2について考える. これは提案分布が一様分布だとわかりにくいが, 例えば提案分布を正規分布に変更した場合と比較するとわかりやすい. 先程の例のように提案分布が一様分布だと $x$が $[0.9,1.0]$ 近傍の値を取るサンプルは全体の1割低度存在しているが, グレーの部分が大きいため手順4によりそのほとんどが棄却され, 無駄なサンプルになってしまう.

しかし, 正規分布の場合はその件数がぐっと下がる. これは2つの分布の違いが小さくなっているためである. 正規分布における $[0.9,1.0]$ 近傍のサンプリング数は最初から少ないため無駄が減る. 逆にベータ分布の峰の部分の近くは正規分布から得られるサンプル数も多くなっているためこの部分でも無駄が減る. 2つの分布が親しいほどサンプリングに無駄が減る.
サンプリング (2).png
ただし, 正規分布の取りうる値の範囲は$[-\infty,\infty]$であり $[0,1]$区間以外のサンプルは生成される件数は非常に少ないが, 全て棄却されてしまうため無駄になってしまうことには注意してほしい. 正規分布によるサンプリングは例3で行う.

以上からうまいサンプリングを行うためにはできる限り目的分布に似ている提案分布からのサンプリングが良いと考えられる.

例2: 規格化されていないベータ分布

例1では規格化定数が $Z_{p} = 1 $ であるとわかっている状況を考えた. ここでは規格化定数も含めて求まることを確認したい.
ベータ分布の確率密度関数は

p(x) = \frac{x^{\alpha - 1}(1-x)^{\beta -1}}{B(\alpha,\beta)}

である. ここで$B(\alpha,\beta)$はベータ関数であり, $\int_{0}^{1}dx f(x) = 1$ を満たすための規格化定数である.

  • 提案分布:$[0,1]$区間の一様分布$U(0,1)$
  • 目的関数:$f(x) = x^{\alpha - 1}(1-x)^{\beta -1}$

とする.

まず目的関数の設定

from functools import partial
# 関数の部分適用を行い規格化されていないベータ分布を考えてみる.
def test_beta(x,alpha,beta):
    return (x ** (alpha - 1)) * ((1-x) ** (beta - 1))
# 規格化されていないベータ関数を保存
f = partial(test_beta, alpha=a, beta=b)

(手順:1)$f(x) \leq Mq(x)$を満たす$M$を求める

#  ベータ分布の確率密度関数が最大となる xの値を求める.
res = optimize.fmin(lambda x :-f(x),0.3)
# ベータ分布の最大値を格納
y_max = f(res)
print(res,y_max,)

(手順:2)提案分布からのサンプリング(手順:3)$[0, M q(x)]$区間の一様分布からのサンプリング

#x軸方向のサンプリング
x_mcs = uniform.rvs(size = N)
#y軸方向のサンプリング(0~y_maxの区間でサンプリング)
r = uniform.rvs(size = N) * y_max

(手順:4)受容されるものだけを取り出す

# 条件を満たすものだけを取り出す
accept = x_mcs[r<= f(x_mcs)]

# アクセプトされたものをヒストグラム化
plt.hist(accept,density=True,bins=30,rwidth=0.8,label='rejection sampling')
# 正解の確率密度関数をプロット
plt.plot(x,beta.pdf(x,a,b),label='Target')
plt.legend()
plt.grid()
plt.show()

ダウンロード (2).png
ヒストグラムはベータ分布に従っている事がわかる.

更に規格化定数は

p(\text{受理}) = \int\! dx \, \, \frac{f(x)}{k q(x)} q(x)
 = \frac{1}{k} \int \! dx \, \,  f(x) = \frac{Z}{k}

であるため規格化定数$Z = k p(\text{受理})$で与えられる. 実際に確認すると一致していることがわかる.

import scipy
# 受理率
print("受理率:",len(accept)/N ) 
# 受理率: 0.69066
# ベータ関数の変数交換
print("規格化定数の正解値:",scipy.special.beta(a,b), "  サンプリングの結果:",(len(accept)/N ) * y_max[0])
# 規格化定数の正解値: 0.2666666666666666   サンプリングの結果: 0.2658351579076435

例3: 提案分布を正規分布とする

次に先程の規格化されていないベータ分布を正規分布でサンプリングしてみよう.
正規分布の平均値はベータ分布の平均と分散が$\mu = \frac{a}{a+b}, \sigma = 0.38$ を用いる

np.random.seed(0)
# 規格化されていないベータ関数を保存
p = partial(test_beta, alpha=a, beta=b)
mu = a / (a + b)
sigma = 0.38

(手順:1)$f(x)\leq Mq(x)$を満たす$M$を求める

res = optimize.fmin(lambda x : - p(x) / norm.pdf(x, mu, sigma) , 0)
# ベータ分布の最大値を格納
y_max = f(res) / norm.pdf(res, mu, sigma)
print(res,y_max,)

# 拡大された提案分布が目的分布をカバーしているかを確認
X = np.arange(-0.5,1.5, 0.01)
x = np.linspace(beta.ppf(0.00,a,b),beta.ppf(1,a,b),1000)
plt.plot(x, p(x), label = 'Target')
plt.plot(X, norm.pdf(X, mu, sigma) * y_max , label = 'Proposal')
# 指定範囲の区間を塗りつぶす
r1 = np.arange(-0.5,0.01, 0.01)
r2 = np.arange(0,1.01, 0.01)
r3 = np.arange(1.0,1.5, 0.01)
plt.fill_between(r1, 0, norm.pdf(r1, mu, sigma) * y_max, facecolor='gray', alpha=0.5)
plt.fill_between(r2,  p(r2), norm.pdf(r2, mu, sigma) * y_max, facecolor='gray', alpha=0.5)
plt.fill_between(r3, 0, norm.pdf(r3, mu, sigma) * y_max, facecolor='gray', alpha=0.5)
plt.legend()
plt.grid()
plt.show()

ダウンロード.png
しっかりとカバーできている.

提案分布からのサンプリング(手順:2)$[0, M q(x)]$区間の一様分布からのサンプリング(手順:3)

#x軸方向のサンプリング
x_mcs = norm.rvs(loc=mu, scale=sigma, size=N)
#y軸方向のサンプリング(0~y_maxの区間でサンプリング)
r = [uniform.rvs() * num for num in norm.pdf(x_mcs, loc=mu, scale=sigma) * y_max]

受容されるものだけを取り出す(手順:4)

# 条件を満たすものだけを取り出す
accept = x_mcs[r<= p(x_mcs)]

# アクセプトされたものをヒストグラム化
plt.hist(accept,density=True,bins=30,rwidth=0.8,label='rejection sampling')
# 正解の確率密度関数をプロット
plt.plot(x,beta.pdf(x,a,b),label='Target')
plt.legend()
plt.grid()
plt.show()

ダウンロード (1).png
ヒストグラムはベータ分布に従っている. しかし受理率は先程よりも下がっている.

import scipy
# 受理率
print("受理率:",len(accept)/N ) 
# 受理率: 0.6167
# ベータ関数の変数交換
print("規格化定数の正解値:",scipy.special.beta(a,b), "  サンプリングの結果:",(len(accept)/N ) * y_max[0])
# 規格化定数の正解値: 0.2666666666666666   サンプリングの結果: 0.2661521086929165

これは先程の注意が理由である. つまり, ベータ分布の取りうる値は$[0,1]$の値しか取らないが, 提案分布として選んだ正規分布は理論上$[-\infty,\infty]$の値を取りうる. よって正規分布から得られる$[0,1]$以外の値は全てサンプリングにおいて無駄になる.

例4 受理率が低い提案分布からのサンプリング

最後に提案分布が目的関数とかけ離れている場合どのような結果が得られるのかを確認する.
今回の提案分布は峰が2つある明らかにベータ分布とは異なる分布を使ってみる.

  • 提案分布: $x \sim \frac{1}{\sqrt{2\pi}} \left(e^{-2(x-1)^2} + e^{-2(x+1)^2}\right)$
  • 目的関数: $f(x) = e^{-2 x^{2}}$

提案分布と目的関数を設定し$f(x)\leq Mq(x)$を満たす$M$を求め, 拡大された提案分布が目的関数を覆っているかを確認

np.random.seed(0)
mu = 1
sigma = 0.5
# 目的分布
def p(x):
    return np.exp(- 2 * x**2)
# 提案分布
def prop(x):
    return norm.pdf(x,mu, sigma) /2 +norm.pdf(x, -mu, sigma) /2

def g(x):
    return  p(x) / prop(x)

# 手順1
res = optimize.fmin(lambda x : - g(x) , 0)
y_max = g(res)
print(res,y_max,)

# 可視化
X = np.arange(-2,2, 0.01)
plt.plot(X, p(X), label = 'Target')
plt.plot(X, prop(X) * y_max, label = 'Proposal')

plt.fill_between(X, p(X), prop(X) * y_max, facecolor='gray', alpha=0.5)
plt.legend()
plt.grid()
plt.show()

プロットを行うときちんとカバーできていることがわかる.しかし提案分布が目的関数と異なりすぎているためグレーの部分が非常に大きくなっている.
ダウンロード (2).png

手順2~4を行う
2つの正規分布の混合分布なのでこれらを1:1の割合で生成し一つのリストにしておけば良い.

#x軸方向のサンプリング
x_mcs = np.concatenate([norm.rvs(loc = mu, scale = sigma, size = int(N /2)), 
                norm.rvs(loc = -mu, scale = sigma, size = int(N /2))])
#y軸方向のサンプリング(0~y_maxの区間でサンプリング)
r = [uniform.rvs() * num for num in prop(x_mcs) * y_max]
# 条件を満たすものだけを取り出す
accept = x_mcs[r<= p(x_mcs)]

# アクセプトされたものをヒストグラム化
plt.hist(accept,density=True,bins=30,rwidth=0.8,label='rejection sampling')
# 正解の確率密度関数をプロット

plt.plot(X , norm.pdf(X ,0, 1/2),label='Target')
plt.legend()
plt.grid()
plt.show()

# 受理率
print("受理率:",len(accept)/N ) 
# 受理率: 0.13268

サンプリングはうまく言っている用に見えるが受理率は13.2% とかなり低いことがわかる.
ダウンロード (3).png

参考

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1