#1. はじめに
学習サンプル数が多い時にミニバッチ学習でパラメータをMAP推定する手法としてPreconditioned Stochastic Gradient Langevin Dynamics(pSGLD)があります。
Tensorflow Probability版pSGLDの使い方で参考になるサンプルコードがネット上に見つからなかったので、混合正規分布によるクラスタリングへの適用を通してpSGLDの使い方を調べてみました。
この記事の対象としている人
- ベイズ推定をなんとなく知っている。
- Tensorflow ProbabilityでHMCを使って隠れ変数の推定をしたことがある。
追記
- 2019/01/12 サンプルのクラスタ割り当てを事後分布を使う実装に変更しました(Colaboratory)。
- 2018/12/30 理解不足、間違っている箇所あったらご指摘ください。
#2. 最適化するパラメータ(隠れ変数)の定義
HMCといった通常のサンプリング手法を使う場合と比べて、pSGLDでは最適化するパラメータ(隠れ変数)の定義方法が違います。通常の方法としてHMCで解く場合を公開されているColaboratoryの混合正規分布によるクラスタリングのデモで見てみます。
initial_state = [
tf.fill([components],
value=np.array(1. / components, dtype),
name='mix_probs'),
tf.constant(np.array([[-2, -2],
[0, 0],
[2, 2]], dtype),
name='loc'),
tf.eye(dims, batch_shape=[components], dtype=dtype, name='chol_precision'),
]
HMCを使う場合は上記のようにtrainableでない変数として定義しています。
pSGLDを使う場合はSGDと同じ要領でパラメータを最適化するのため、trainableな変数として定義する必要があります。以下例です。
mix_probs = tf.nn.softmax(tf.get_variable(
'mix_probs',
initializer=
np.ones([components], dtype)
*(1. / components)))
loc = tf.get_variable(
'loc',
initializer=
np.array([[-1, -1],
[0.2, 0.5],
[1, 1]], dtype))
precision = tf.nn.softplus(tf.get_variable(
'precision',
initializer=
np.ones([components, dims], dtype=dtype)))
# pack them as one variable for convinience
training_vals = [mix_probs, loc, precision]
tf.nn.softmax()とtf.nn.softplus()を付けている理由はmix_probsの合計が1になること、precisionが正の実数であることを保証するためです。
#3. pSGLD optimizerの定義
実装上はoptimizerより先に事後分布を定義する必要があります。しかし、optimizerの引数の意味を理解しないと正しく書けない部分が事後分布の中にはあるので、先にoptimizerについて書きます。
コードを見てみます。
optimizer_kernel = tfp.optimizer.StochasticGradientLangevinDynamics(
learning_rate=learning_rate,
preconditioner_decay_rate=0.99,
burnin=500,
data_size=num_samples)
train_op = optimizer_kernel.minimize(-unnormalized_posterior_log_prob())
最適化には事後分布の負の値を誤差関数としてoptimizerに与えます。unnormalized_posterior_log_probはlog変換された事後分布(正確には同時分布?)です。
StochasticGradientLangevinDynamicsの引数を見てみます。data_sizeにnum_samples(データ数)を与えているのがポイントです。この引数の意味を理解するためにpSGLDのAPIを参照すると、
data_size: Scalar int-like Tensor. The effective number of points in the data set. Assumes that the loss is taken as the mean over a minibatch. Otherwise if the sum was taken, divide this number by the batch size. If a prior is included in the loss function, it should be normalized by data_size. Default value: 1.
"The effective number of points in the data set."の一文からdata_sizeにはデータ数を与えることが分かります。後に続く文章は「誤差関数をミニバッチ数で割ってない場合はdata_sizeをミニバッチ数で割ること。誤差関数にpriorも含まれている場合はdata_sizeで割ること。」と言っています。どうらや事後分布(誤差関数)を定義する際に気を付けることがありそうです。
上記の意味を理解するために、論文を確認します。パラメータの更新式を見ると、
\begin{aligned} \Delta \boldsymbol { \theta } _ { t } & \sim \frac { \epsilon _ { t } } { 2 } \left[ G \left( \boldsymbol { \theta } _ { t } \right) \left( \nabla _ { \boldsymbol { \theta } } \log p \left( \boldsymbol { \theta } _ { t } \right) \right. \right. \\ & + \frac { N } { n } \sum _ { i = 1 } ^ { n } \nabla _ { \boldsymbol { \theta } } \log p \left( \boldsymbol { d } _ { t _ { i } } | \boldsymbol { \theta } _ { t } \right) ) + \Gamma \left( \boldsymbol { \theta } _ { t } \right) ] + G ^ { \frac { 1 } { 2 } } \left( \boldsymbol { \theta } _ { t } \right) \mathcal { N } \left( 0 , \epsilon _ { t } \mathbf { I } \right) \end{aligned}
関係のある変数だけ説明すると、$N$はデータ数、$n$はミニバッチ数、$\boldsymbol { d } _ { t _ { i } }$はデータ点、$\boldsymbol { \theta } _ { t } $は最適化するパラメータをそれぞれ表しています。上記式から、尤度の微分値$\sum _ { i = 1 } ^ { n } \nabla _ { \boldsymbol { \theta } } \log p \left( \boldsymbol { d } _ { t _ { i } } | \boldsymbol { \theta } _ { t } \right) )$はデータ数とミニバッチ数の比を掛ける、事前分布の微分値 $\nabla _ { \boldsymbol { \theta } } \log p \left( \boldsymbol { \theta } _ { t } \right)$はそのまま使うことが分かりました。
次にpSGLDの実装中で上記式の$ {\frac { \epsilon _ { t } } { 2 } \left[ G \left( \boldsymbol { \theta } _ { t } \right) \left( \nabla _ { \boldsymbol { \theta } } \log p \left( \boldsymbol { \theta } _ { t } \right) \right. \right. } \ { + \frac { N } { n } \sum _ { i = 1 } ^ { n } \nabla _ { \boldsymbol { \theta } } \log p \left( \boldsymbol { d } _ { t _ { i } } | \boldsymbol { \theta } _ { t } \right) ) + \Gamma \left( \boldsymbol { \theta } _ { t } \right) ] } $を実装している箇所を見てみます。
mean = 0.5 * (preconditioner * grad *
tf.cast(self._data_size, grad.dtype)
- preconditioner_grads[0])
gradは尤度と事前分布の合計に対する微分値(≒事後分布の微分値)を表しています。先に見た論文の更新式と実装より事後分布を定義する時には、
- 尤度の微分値にデータ数は掛けられているため、後はバッチ数で割る必要がある。
- そのままだとdata_sizeは事前分布にも掛けられてしまっているため、事前分布を定義する時には予めdata_sizeで割っておく必要がある。
上記二つの処理が必要であることが分かりました。
#3. 事後分布の定義
すぐ上で述べたように、尤度はミニバッチ数で割る、事前分布はデータ数で割る必要があります。実装は以下になります。
# Define joint log probabilities of likelihood and priors,
# used as unnormalized posterior and "loss" to maximize this time.
def unnormalized_posterior_log_prob():
# Define mixture model as likelihood
rv_observations = tfd.MixtureSameFamily(
mixture_distribution=tfd.Categorical(probs=mix_probs),
components_distribution=tfd.MultivariateNormalDiag(
loc=loc,
scale_diag=precision))
# Log probabilities of priors and likelihood.
# Notice that each prior is divided by num_samples and
# likelihood is divided by batch_size for pSGLD optimization.
log_prob_parts = [
rv_mix_probs.log_prob(mix_probs)[..., tf.newaxis]/num_samples,
rv_loc.log_prob(loc)/num_samples,
rv_precision.log_prob(precision)/num_samples,
rv_observations.log_prob(observations_tensor)/batch_size
]
sum_log_prob = tf.reduce_sum(tf.concat(log_prob_parts, axis=-1), axis=-1)
return sum_log_prob
上記実装は尤度と事前分布をミニバッチ数とデータ数でそれぞれ割っている以外はColaboratoryの混合正規分布によるクラスタリングのデモと基本同じです。
#4. パラメータ推定結果の取得
HMCでは事後分布からのパラメータのサンプリング結果が指定したサンプリング数分、一度に返ってきます。対して、pSGLDでは一回のイテレーションが一回のサンプリングに相当するようです(pSGLDのAPI中のサンプルコードを参照)。下記の実装ではパラメータ推定結果を取得する方法として、イテレーション毎のパラメータを配列に保存して最後に平均を取る代わりに、逐次平均を計算しています。
mean_mix_probs_ = mean_loc_ = mean_precision_ = 0
for it in range(training_steps):
[
mix_probs_,
loc_,
precision_,
_
] = sess.run([
*training_vals,
train_op
], feed_dict={observations_tensor: sess.run(next_batch)})
iterative_mean_ratio=1./(it+1)
mean_mix_probs_ = iterative_mean_ratio*mix_probs_+(1-iterative_mean_ratio)*mean_mix_probs_
mean_loc_ = iterative_mean_ratio*loc_+(1-iterative_mean_ratio)*mean_loc_
mean_precision_ = iterative_mean_ratio*precision_+(1-iterative_mean_ratio)*mean_precision_
5. pSGLDと混合正規分布でクラスタリング
下記グラフのような共分散行列が対角行列で値が1、平均がそれぞれ[-4, 4], [0, 0], [4, 4]の2次元正規分布から5000点をサンプリングし、pSGLDと混合正規分布でクラスタリングします。
下記グラフはクラスタリング結果です。正しくクラスタリングできているように見えます。
真の平均(True loc)、推定された平均(Estimated loc)、推定された共分散行列(Estimated precision)を見てみます。
True loc:
[[-4. -4.]
[ 0. 0.]
[ 4. 4.]]
Estimated loc:
[[-3.92828434e+00 -3.95637671e+00]
[-1.22592043e-02 2.30253367e-03]
[ 3.94211339e+00 3.93120071e+00]]
Estimated mix probability:
[0.34901182 0.32463811 0.32635007]
Estimated precision:
[[1.03245309 1.08452091]
[1.03612861 1.05535136]
[1.09015031 1.0572301 ]]
正しく平均、共分散行列が推定できいそうです。
上記のグラフ出力を除いた実装全体は以下になります。
# -*- coding: utf-8 -*-
from progressbar import ProgressBar, Percentage, Bar
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
sess = tf.InteractiveSession(config=config)
dtype = np.float64
# Dimension of 2-D coordinate, which is two
dims = 2
# Number of latent state
components = 3
# Number of training samples
num_samples = 5000
# Mini-batch size used for a single update in SGD
batch_size = 100
# Number of training epochs
epoch = 50
# Automatically set number of training steps
training_steps = epoch * num_samples // batch_size
# True loc values which we will be going to estimte later on
true_loc = np.array([[-4, -4],
[0, 0],
[4, 4]], dtype)
random = np.random.RandomState(seed=42)
# Generate training samples from true loc
true_hidden_component = random.randint(0, components, num_samples)
observations = (true_loc[true_hidden_component] +
random.randn(num_samples, dims).astype(dtype))
# Variables to train. We want variable "loc" to be something close to "true_loc" defined above.
mix_probs = tf.nn.softmax(tf.get_variable(
'mix_probs',
initializer=
np.ones([components], dtype)
*(1. / components)))
loc = tf.get_variable(
'loc',
initializer=
np.array([[-1, -1],
[0.2, 0.5],
[1, 1]], dtype))
precision = tf.nn.softplus(tf.get_variable(
'precision',
initializer=
np.ones([components, dims], dtype=dtype)))
training_vals = [mix_probs, loc, precision]
# prior probabilities
rv_mix_probs = tfd.Dirichlet(
concentration=np.ones(components, dtype) / components,
name='rv_mix_probs')
rv_loc = tfd.Independent(
tfd.Normal(
loc=np.stack([
-np.ones(dims, dtype),
np.zeros(dims, dtype),
np.ones(dims, dtype),
]),
scale=tf.ones([components, dims], dtype)),
reinterpreted_batch_ndims=1,
name='rv_loc')
rv_precision = tfd.Independent(
tfp.distributions.InverseGamma(
concentration=np.ones([components, dims]),
rate=np.ones([components, dims])
),
reinterpreted_batch_ndims=1,
name='rv_precision'
)
# Placeholder for mini-batch
observations_tensor = tf.placeholder(tf.float64, shape=[batch_size, dims])
# Define joint log probabilities of likelihood and priors,
# used as unnormalized posterior and "loss" to maximize this time.
def unnormalized_posterior_log_prob():
# Define mixture model as likelihood
rv_observations = tfd.MixtureSameFamily(
mixture_distribution=tfd.Categorical(probs=mix_probs),
components_distribution=tfd.MultivariateNormalDiag(
loc=loc,
scale_diag=precision))
# Log probabilities of priors.
# Notice that each prior should be divided by num_samples and
# likelihood is divided by batch_size for pSGLD optimization.
log_prob_parts = [
rv_mix_probs.log_prob(mix_probs)[..., tf.newaxis]/num_samples,
rv_loc.log_prob(loc)/num_samples,
rv_precision.log_prob(precision)/num_samples,
rv_observations.log_prob(observations_tensor)/batch_size
]
sum_log_prob = tf.reduce_sum(tf.concat(log_prob_parts, axis=-1), axis=-1)
return sum_log_prob
# make mini-batch generator
dx = tf.data.Dataset.from_tensor_slices(observations).shuffle(500).repeat().batch(batch_size)
iterator = dx.make_one_shot_iterator()
next_batch = iterator.get_next()
# Define learning rate scheduling
global_step = tf.Variable(0, trainable=False)
starter_learning_rate = 1e-4
end_learning_rate = 1e-8
decay_steps = 1e4
learning_rate = tf.train.polynomial_decay(starter_learning_rate,
global_step, decay_steps,
end_learning_rate, power=1.)
# Set up the optimizer. Don't forget to set data_size=num_samples, or the result will be much worse.
optimizer_kernel = tfp.optimizer.StochasticGradientLangevinDynamics(
learning_rate=learning_rate,
preconditioner_decay_rate=0.99,
burnin=500,
data_size=num_samples)
train_op = optimizer_kernel.minimize(-unnormalized_posterior_log_prob())
init = tf.global_variables_initializer()
sess.run(init)
# To know how soon optimization ends..
p = ProgressBar(widgets=[Percentage(), Bar()], max_value=training_steps).start()
# Prepare variables for iterative mean calculation.
# For standard sampling methods, we store each sampling result in an array and
# take mean over it at last to get an estimation value.
# However, we don't want to make an array to store samples this time to save memory,
# because the number of sample, which is number of iteration, is quite large.
# Instead, we iteratively calculate mean for each iteration to reduce memory usage.
mean_mix_probs_ = mean_loc_ = mean_precision_ = 0
for it in range(training_steps):
p.update(it+1)
[
mix_probs_,
loc_,
precision_,
_
] = sess.run([
*training_vals,
train_op
], feed_dict={observations_tensor: sess.run(next_batch)})
iterative_mean_ratio=1./(it+1)
mean_mix_probs_ = iterative_mean_ratio*mix_probs_+(1-iterative_mean_ratio)*mean_mix_probs_
mean_loc_ = iterative_mean_ratio*loc_+(1-iterative_mean_ratio)*mean_loc_
mean_precision_ = iterative_mean_ratio*precision_+(1-iterative_mean_ratio)*mean_precision_
# Compare estimated loc and true loc
print('\nTrue loc:\n', true_loc)
print('\nEstimated loc:\n', mean_loc_)
# Show other estimated latent variables
print('\nEstimated mix probability:\n', mean_mix_probs_)
print('\nEstimated precision:\n', mean_precision_)
6. まとめ
pSGLDを使ったベイズ推定の実装方法を混合正規分布によるクラスタリングを例に見て行きました。Tensorflow ProbabilityはAPI以外のドキュメントやサンプルコードが少なく、敷居が高い&勿体無い印象です。EdwardではpSGLDのサンプルコードがすぐ見つかりました。