Help us understand the problem. What is going on with this article?

ガンマ分布のパラメータ指定方法まとめ

ガンマ分布

ガンマ分布は正の連続値をとる値のモデリングに使われる。
パラメータは2つあるが、ライブラリによって微妙に定義が違ったり、平均値、分散がどの程度になるか分かりづらく、いつも調べている気がするので、まとめておく。

確率密度関数

\begin{align}
f(x) &= \frac{1}{\Gamma(k) \theta^k} x^{k-1} e^{- \frac{x}{\theta}} \\
&= \frac{\lambda^k}{\Gamma(k)} x^{k-1} e^{- \lambda x}
\end{align}

$k$: shape parameter
$\theta$: scale parameter
いずれも正の値をもつパラメータである。
ただし、ライブラリによっては、$\lambda = \frac{1}{\theta}$を使って表す場合もある。

いくつかのパラメータでプロットするとこんな感じ。
gamma_dists.png

統計量

平均 $\mu= k \theta = \frac{k}{\lambda}$
分散 $v = k \theta^2 = \frac{k}{\lambda^2}$

逆に、平均、分散からパラメータを決めたい場合は、こちらを用いる。
$\theta = \frac{v}{\mu} \ (\lambda = \frac{\mu}{v})$
$k = \frac{\mu^2}{v}$

各ライブラリでのパラメータ指定方法

まとめるとこんな感じ。この表が書きたかった!

ライブラリ shape parameter scale parameter
numpy.random.gamma $k$ $\theta$
scipy.stats.gamma $a$ $\theta$
PyMC3(pm.Gamma) $\alpha$ $1 / \beta$
TensorFlow Probability (tfp.Gamma)) concentration 1/rate
Stan (gamma) $\alpha$ $1 / \beta$
R (rgamma) shape scale, 1/rate

numpy, scipyはscale parameter $\theta$ を採用しているが、PyMC3, Stan TFPなどいわゆる確率的プログラミング言語では $\theta$ の逆数での指定となっている。

$\theta$の逆数$\lambda$はrate parameterと呼ばれ、パラメータを$\alpha, \beta$と呼んでいるライブラリでは、$\lambda$による定義を採用しているようだ。

なお、PyMC3では平均(mu), 標準偏差(sigma)でガンマ分布を指定することも可能。
また、Rではshape, rateのいずれでも指定できるようだ。

実装の確認

上記のパラメータ一覧が正しいか確認するため、各ライブラリで $Gamma(2, 2)$ から1万個の乱数を取得、確率密度関数を推定して比較してみた。

import numpy as np
import scipy as sp
import pymc3 as pm
import tensorflow_probability as tfp
import pystan
import matplotlib.pyplot as plt
import seaborn as sns

shape = 2
scale = 2
rate = 1 / scale
n_sample = 10000

xx = np.linspace(0, 20)
# ground truth
gamma_pdf = sp.stats.gamma(a=shape, scale=scale).pdf

s_np = np.random.gamma(shape=shape, scale=scale, size=n_sample)
s_sp = sp.stats.gamma(a=shape, scale=scale).rvs(size=n_sample)
s_tfp = tfp.distributions.Gamma(concentration=shape, rate=rate).sample(n_sample).numpy()
s_pm = pm.Gamma.dist(alpha=shape, beta=rate).random(size=n_sample)

fig, ax = plt.subplots()
ax.plot(xx, gamma_pdf(xx), label='truth', lw=2, c='k')
sns.kdeplot(s_np, ax=ax, label='numpy', alpha=0.8)
sns.kdeplot(s_sp, ax=ax, label='scipy', alpha=0.8)
sns.kdeplot(s_tfp, ax=ax, label='TFP', alpha=0.8)
sns.kdeplot(s_pm, ax=ax, label='PyMC3', alpha=0.8)

結果は下図の通りで、どのライブラリでも正しく実装できていることが確認できた。

gamma_pdf.png

Stanだけは確率分布から直接乱数を得る方法が分からなかったので、代わりに上記で得た乱数からガンマ分布のパラメータを推定してみた。

stan_code = '''

data {
  int N;
  vector<lower=0>[N] Y;
}

parameters {
  real<lower=0> shape;
  real<lower=0> rate;
}

model {
  Y ~ gamma(shape, rate);
}

'''

data = dict(N=n_sample, Y=s_np)
stan_model = pystan.StanModel(model_code=stan_code)
fit = stan_model.sampling(data=data)
print(fit)

shape, rateパラメータの推定値はそれぞれ、1.98, 0.49となり、こちらも期待通りの結果となっている。

Inference for Stan model: anon_model_6a5d60bed963727c801dc434b96a49a1.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

        mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
shape   1.98  8.3e-4   0.03   1.93   1.97   1.98    2.0   2.04   1016    1.0
rate    0.49  2.3e-4 7.4e-3   0.47   0.48   0.49   0.49    0.5   1020    1.0
lp__  -2.3e4    0.03   1.01 -2.3e4 -2.3e4 -2.3e4 -2.3e4 -2.3e4   1192    1.0

平均、標準偏差による分布の指定

平均、標準偏差から、該当するガンマ分布のshape, scaleを計算する関数を用意しておくと便利。

def calc_gamma_param(mu, sigma):
    return (mu / sigma)**2, sigma**2 / mu

mu, sigma = 4, 2
shape, scale = calc_gamma_param(mu, sigma)

def plot_gamma(xx, shape, scale):
    plt.plot(xx, sp.stats.gamma(a=shape, scale=scale).pdf(xx), label=f'shape={shape}, scale={scale}')

xx = np.linspace(0, 10)
plot_gamma(xx, shape, scale)
plt.legend()

平均4、標準偏差2のガンマ分布。右に裾が長い分布なので、最頻値(最も確率が高い値)は平均値より小さくなることに注意。

gamma_mu_sigma.png

おまけ:他の確率分布との関係

$k = 1$の時、ガンマ分布はパラメータ$\theta$の指数分布と一致する。
$k = \frac{n}{2}(n=1,2,\dots),\ \theta=2$ の時、ガンマ分布は自由度$n$のカイ二乗分布と一致する。

gamma_exp_chi2.png

コードはこちら。

from scipy import stats

xx = np.linspace(0, 10)
fig, ax = plt.subplots(1, 2, figsize=(10, 3))

shape, scale = 1, 3
ax[0].plot(xx, stats.gamma(a=shape, scale=scale).pdf(xx), label=f'Gamma({shape}, {scale})')
ax[0].plot(xx, stats.expon(scale=scale).pdf(xx), label=f'Exp({scale})')
ax[0].legend()

shape, scale = 3/2, 2 
ax[1].plot(xx, stats.gamma(a=shape, scale=scale).pdf(xx), label=f'Gamma({shape}, {scale})')
ax[1].plot(xx, stats.chi2(df=2*shape).pdf(xx), label=f'Chi2({int(2*shape)})')
ax[1].legend()

plt.savefig('gamma_exp_chi2.png')
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした