LoginSignup
22
16

More than 3 years have passed since last update.

ガンマ分布のパラメータ指定方法まとめ

Last updated at Posted at 2019-12-12

ガンマ分布

ガンマ分布は正の連続値をとる値のモデリングに使われる。
パラメータは2つあるが、ライブラリによって微妙に定義が違ったり、平均値、分散がどの程度になるか分かりづらく、いつも調べている気がするので、まとめておく。

確率密度関数

\begin{align}
f(x) &= \frac{1}{\Gamma(k) \theta^k} x^{k-1} e^{- \frac{x}{\theta}} \\
&= \frac{\lambda^k}{\Gamma(k)} x^{k-1} e^{- \lambda x}
\end{align}

$k$: shape parameter
$\theta$: scale parameter
いずれも正の値をもつパラメータである。
ただし、ライブラリによっては、$\lambda = \frac{1}{\theta}$を使って表す場合もある。

いくつかのパラメータでプロットするとこんな感じ。
gamma_dists.png

統計量

平均 $\mu= k \theta = \frac{k}{\lambda}$
分散 $v = k \theta^2 = \frac{k}{\lambda^2}$

逆に、平均、分散からパラメータを決めたい場合は、こちらを用いる。
$\theta = \frac{v}{\mu} \ (\lambda = \frac{\mu}{v})$
$k = \frac{\mu^2}{v}$

各ライブラリでのパラメータ指定方法

まとめるとこんな感じ。この表が書きたかった!

ライブラリ shape parameter scale parameter
numpy.random.gamma $k$ $\theta$
scipy.stats.gamma $a$ $\theta$
PyMC3(pm.Gamma) $\alpha$ $1 / \beta$
TensorFlow Probability (tfp.Gamma)) concentration 1/rate
Stan (gamma) $\alpha$ $1 / \beta$
R (rgamma) shape scale, 1/rate

numpy, scipyはscale parameter $\theta$ を採用しているが、PyMC3, Stan TFPなどいわゆる確率的プログラミング言語では $\theta$ の逆数での指定となっている。

$\theta$の逆数$\lambda$はrate parameterと呼ばれ、パラメータを$\alpha, \beta$と呼んでいるライブラリでは、$\lambda$による定義を採用しているようだ。

なお、PyMC3では平均(mu), 標準偏差(sigma)でガンマ分布を指定することも可能。
また、Rではshape, rateのいずれでも指定できるようだ。

実装の確認

上記のパラメータ一覧が正しいか確認するため、各ライブラリで $Gamma(2, 2)$ から1万個の乱数を取得、確率密度関数を推定して比較してみた。

import numpy as np
import scipy as sp
import pymc3 as pm
import tensorflow_probability as tfp
import pystan
import matplotlib.pyplot as plt
import seaborn as sns

shape = 2
scale = 2
rate = 1 / scale
n_sample = 10000

xx = np.linspace(0, 20)
# ground truth
gamma_pdf = sp.stats.gamma(a=shape, scale=scale).pdf

s_np = np.random.gamma(shape=shape, scale=scale, size=n_sample)
s_sp = sp.stats.gamma(a=shape, scale=scale).rvs(size=n_sample)
s_tfp = tfp.distributions.Gamma(concentration=shape, rate=rate).sample(n_sample).numpy()
s_pm = pm.Gamma.dist(alpha=shape, beta=rate).random(size=n_sample)

fig, ax = plt.subplots()
ax.plot(xx, gamma_pdf(xx), label='truth', lw=2, c='k')
sns.kdeplot(s_np, ax=ax, label='numpy', alpha=0.8)
sns.kdeplot(s_sp, ax=ax, label='scipy', alpha=0.8)
sns.kdeplot(s_tfp, ax=ax, label='TFP', alpha=0.8)
sns.kdeplot(s_pm, ax=ax, label='PyMC3', alpha=0.8)

結果は下図の通りで、どのライブラリでも正しく実装できていることが確認できた。

gamma_pdf.png

Stanだけは確率分布から直接乱数を得る方法が分からなかったので、代わりに上記で得た乱数からガンマ分布のパラメータを推定してみた。

stan_code = '''

data {
  int N;
  vector<lower=0>[N] Y;
}

parameters {
  real<lower=0> shape;
  real<lower=0> rate;
}

model {
  Y ~ gamma(shape, rate);
}

'''

data = dict(N=n_sample, Y=s_np)
stan_model = pystan.StanModel(model_code=stan_code)
fit = stan_model.sampling(data=data)
print(fit)

shape, rateパラメータの推定値はそれぞれ、1.98, 0.49となり、こちらも期待通りの結果となっている。

Inference for Stan model: anon_model_6a5d60bed963727c801dc434b96a49a1.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

        mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
shape   1.98  8.3e-4   0.03   1.93   1.97   1.98    2.0   2.04   1016    1.0
rate    0.49  2.3e-4 7.4e-3   0.47   0.48   0.49   0.49    0.5   1020    1.0
lp__  -2.3e4    0.03   1.01 -2.3e4 -2.3e4 -2.3e4 -2.3e4 -2.3e4   1192    1.0

平均、標準偏差による分布の指定

平均、標準偏差から、該当するガンマ分布のshape, scaleを計算する関数を用意しておくと便利。

def calc_gamma_param(mu, sigma):
    return (mu / sigma)**2, sigma**2 / mu

mu, sigma = 4, 2
shape, scale = calc_gamma_param(mu, sigma)

def plot_gamma(xx, shape, scale):
    plt.plot(xx, sp.stats.gamma(a=shape, scale=scale).pdf(xx), label=f'shape={shape}, scale={scale}')

xx = np.linspace(0, 10)
plot_gamma(xx, shape, scale)
plt.legend()

平均4、標準偏差2のガンマ分布。右に裾が長い分布なので、最頻値(最も確率が高い値)は平均値より小さくなることに注意。

gamma_mu_sigma.png

おまけ:他の確率分布との関係

$k = 1$の時、ガンマ分布はパラメータ$\theta$の指数分布と一致する。
$k = \frac{n}{2}(n=1,2,\dots),\ \theta=2$ の時、ガンマ分布は自由度$n$のカイ二乗分布と一致する。

gamma_exp_chi2.png

コードはこちら。

from scipy import stats

xx = np.linspace(0, 10)
fig, ax = plt.subplots(1, 2, figsize=(10, 3))

shape, scale = 1, 3
ax[0].plot(xx, stats.gamma(a=shape, scale=scale).pdf(xx), label=f'Gamma({shape}, {scale})')
ax[0].plot(xx, stats.expon(scale=scale).pdf(xx), label=f'Exp({scale})')
ax[0].legend()

shape, scale = 3/2, 2 
ax[1].plot(xx, stats.gamma(a=shape, scale=scale).pdf(xx), label=f'Gamma({shape}, {scale})')
ax[1].plot(xx, stats.chi2(df=2*shape).pdf(xx), label=f'Chi2({int(2*shape)})')
ax[1].legend()

plt.savefig('gamma_exp_chi2.png')
22
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
22
16