確率分布を$x \in[a,b]$に制限してサンプリングを行う方法のメモ。
例えば tf.random.truncated_normalみたいなやつを使いたい時、他のフレームワークには実装が無かったりして困る。
元にする分布の確率密度関数を$f$その累積分布関数を $F$ とするとこれを $[a,b]$ で切断した分布の密度関数$g$は
$$
g(x) = \frac{f(x)}{F(b)-F(a)} \quad (a \leq x \leq b)
$$
となる。したがってこれの累積分布関数 $a \leq x \leq b$ で
$$
G(x) = \frac{F(x)-F(a)}{F(b)-F(a)}
$$
となる。よってこれの逆関数は $0 \leq x \leq 1$ で
$$
G^{-1}(x) = F^{-1}\left((F(b)-F(a))x+F(a)\right)
$$
となる。よって 逆関数法 を用いて $[0,1]$ での一様乱数を $U$ とすると
$$
F^{-1}\left((F(b)-F(a))U+F(a)\right)
$$
と変換すれば求める分布が得られる。
Truncated Normal Distribution
標準正規分布の場合で計算してみる。$\mathcal{N}(0,1)$の累積分布関数は
$$
F(x) = \frac{1}{2}\left[1 + \mathrm{erf}\left(\frac{x}{\sqrt{2}}\right)\right]
$$
と書くことが出来るので、この逆関数を求めると
$$
F^{-1}(x) = \sqrt{2}\mathrm{erf}^{-1}(2x-1)
$$
となる。PyTorchでの実装例は以下。
from math import sqrt
import torch
def truncated_normal(shape, a, b):
U = torch.distributions.uniform.Uniform(0, 1)
u = U.sample(shape)
Fa = 0.5 * (1 + torch.erf(a/sqrt(2)))
Fb = 0.5 * (1 + torch.erf(b/sqrt(2)))
return sqrt(2)*torch.erfinv(2 *((Fb - Fa) * u + Fa) - 1)
実験
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
N = 10000
a = -1.
b = 2.
samples = truncated_normal((N,), torch.Tensor([a]), torch.Tensor([b]))
plt.hist(samples, bins=100, density=True)
x = np.linspace(-3, 3)
plt.plot(x, norm.pdf(x)/(norm.cdf(b)-norm.cdf(a)))
plt.savefig('truncated_norm.png')
良さそう。