40
27

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

自作深層生成モデルライブラリの紹介

Last updated at Posted at 2017-01-18

今回は,深層生成モデルの自作ライブラリ「Tars」を紹介します.以前から僕が一人で勝手に実装して勝手に使っていたもので,これまで非公開としていましたが,今回この記事を書くにあたって公開することにしました.
https://github.com/masa-su/Tars
もし需要がありましたら,今後も何回かに分けて紹介していきたいと思います.

※ 本記事は,深層生成モデルについてある程度理解している人を前提に書いています.VAE等を全く知らない人は論文や他の記事を読んでください.

確率モデルとニューラルネットワーク

近年深層生成モデル(VAE[1],GAN[2],自己回帰モデル[3]など)が流行っていますが,こういったモデル(特にVAE)では,ネットワークを確率分布として,あるいはユニットを確率変数として扱っています.

こういった深層生成モデルの実装はたくさんあるのですが,問題点として,モデルを記述する統一的なフレームワークがないということがあります.GANのような確率モデルとして意識する必要のないモデルについては,あまり問題にならないのですが,VAEはどの確率モデルを使用するかを明確に定義する必要があります.つまりVAEに代表される深層生成モデルを実装するためには

1.どのような形のネットワークを使うか
2.どのような確率モデルを使うか

を設定する必要があります.さらに,近年は様々な深層生成モデルの手法(学習アルゴリズムも含む)が提案されています.例えば,多層化した確率分布[4]や,複雑な形のグラフィカルモデル等によって定義された確率モデルを学習するVAEが提案されています.さらにVAEをGAN[5]やPixelCNN[6]と組み合わせたりといった様々なモデルが提案されています.よって実装する際に決めるべきこととして,

3.どのような深層生成モデルの手法を使うか

が加わります.これまでは,ネットワークを明示的に確率分布として扱っていなかったため,1~3を書いた実装があっても汎用性がありませんでした.例えば,2に該当する部分を別のものに書き換えようと思ったら,最悪コード全部を書き換えるのと同じくらいの労力が必要になることもありました.そのため,過去の実装を使い回すのが難しいという状況がありました.これを解決するためには,1~3を完全に独立に扱う枠組みが必要です.

深層生成モデルフレームワークの提案

このような背景を踏まえて,本記事では深層生成モデルをより簡単に実装できるふフレームワーク「Tars」をつくりました.実装にはTheanoとLasagneを利用しています.

本フレームワークの大きな特徴は,1~3を完全に独立に記述できるということです.簡単にいうと,確率分布クラスというものが定義されており,それでニューラルネットワークを覆うことで,外部からは中のニューラルネットワークを意識する必要がなくなるということです.また,確率分布クラスにはガウス分布やベルヌーイ分布など様々な実装が用意されていますが,すべて同じ仕様で設計されているので,確率分布の形を意識せずに同じようにサンプリングや尤度計算ができます.具体的な実装例は下に書いてあります.

似たような(というか丸かぶり)思想のフレームワークとして,Edward(https://github.com/blei-lab/edward )があります.
こちらは天下のBlei研が作っただけに非常に出来がいいライブラリで,Tensorflowなどで実装したネットワークを変分推論など様々な手法で学習できます.
正直こちらでいいのではないかという話もあるのですが,本フレームワークではより深層生成モデルに特化したものになっています.具体的には本フレームワークには以下の特徴があります.

  • 様々な確率分布の実装(ガウス分布,ガンマ分布,ベータ/ディリクレ分布,ベルヌーイ/カテゴリ分布など)
    • 全てreparameterizaiton trick[1]によってサンプリングできる
    • ベルヌーイ/カテゴリ分布はGumbel-softmax[7][8],ガンマ分布,ベータ/ディリクレ分布はrejection sampling[9]を利用
  • 様々な深層生成モデル手法の実装
    • Autoencoder
    • VAE
      • Conditional VAE [10][11]
      • Importance weighted autoencoder [12]
      • Joint multimodal VAE [13]
    • VAE-GAN,conditional VAE-GAN [5]
    • GAN,Conditional GAN [14]
    • VAE-RNN
      • Variational RNN [15]
      • DRAW [16],Convolutional DRAW [17]
  • (上と関連して)様々な下界の実装
    • ELBO
    • Importance sampling lower bound [12]
    • Variational Renyi bound [18]

上記のうちいくつかの実装は現在非公開としていますが,今後随時アップデートしていく予定です.

実装例

本フレームワークでVAEを実装したい場合,例えば以下のように非常に簡単に書くことができます(インポート等は省略)

train.py
x = InputLayer((None,n_x))
q_0 = DenseLayer(x,num_units=512,nonlinearity=activation)
q_1 = DenseLayer(q_0,num_units=512,nonlinearity=activation)
q_mean = DenseLayer(q_1,num_units=n_z,nonlinearity=linear)
q_var = DenseLayer(q_1,num_units=n_z,nonlinearity=softplus)
q = Gaussian(q_mean,q_var,given=[x]) # q(z|x)

z = InputLayer((None,n_z))
p_0 = DenseLayer(z,num_units=512,nonlinearity=activation)
p_1 = DenseLayer(p_0,num_units=512,nonlinearity=activation)
p_mean = DenseLayer(p_1,num_units=n_x,nonlinearity=sigmoid)
p = Bernoulli(p_mean, given=[z]) # p(x|z)

model = VAE(q, p, n_batch=n_batch, optimizer=adam) # コンパイル

lower_bound_train = model.train([train_x]) # 訓練

$q$が近似分布(エンコーダ),$p$が生成分布(デコーダ)です.VAEではそれぞれの分布をニューラルネットワークで書きますが,この部分は自由に書き換えることができるので,様々な構造のネットワークで確率分布を設計できます.Tarsでは,このようにして設計したネットワークを確率分布クラスに渡します.上の実装では,GaussianやBernoulliといった確率分布クラスに渡していますが,それぞれガウス分布とベルヌーイ分布に設定したことになります.

この確率分布クラスを変更すれば,同じネットワーク構造でも別の確率分布を表現することができます.例えば,エンコーダをガンマ分布に変更したい場合は

gamma_q.py
x = InputLayer((None,n_x))
q_0 = DenseLayer(x,num_units=512,nonlinearity=activation)
q_1 = DenseLayer(q_0,num_units=512,nonlinearity=activation)
q_alpha = DenseLayer(q_1,num_units=n_z,nonlinearity= softplus)
q_beta = DenseLayer(q_1,num_units=n_z,nonlinearity=softplus)
q = Gamma(q_alpha,q_beta,given=[x]) #q(z|x)

とするだけでガンマ分布にすることができます.あとは,$q$と$p$をVAEクラスに渡すだけでVAEの訓練やテストができます.他にもGANクラスなどもあるので,GANを学習する場合は,それに確率分布やネットワークを渡せば学習できます.

このようにして定義した$q$と$p$は,VAEクラスの中では確率分布として扱われます.つまり,ニューラルネットワークで実装されていることを意識せずに,次のようにサンプリングや尤度を計算することができます(TheanoシンボルとNumpy形式の両方で計算できます).

sampling_likelihood_example.py

# サンプリング
samples = q.sample_given_x(x) # Theano
samples = q.np_sample_given_x(x) # bumpy

# 対数尤度計算
log_likelihood = q.log_likelihood_given_x([x, samples]) # Theano
log_likelihood = q.log_likelihood_given_x([x, samples]) # numpy

これらを使うと,VAEの実装も比較的簡単にできてしまいます(VAEクラスの下界の計算部分のみ抜粋,簡単のため一部書き換え).

vae.py
kl_divergence = analytical_kl(self.q, self.prior
                              given=[x, None],
                              deterministic=deterministic)        
[_, z] = self.q.sample_given_x(x, repeat=l,
                          deterministic=deterministic)
log_likelihood =\
	self.p.log_likelihood_given_x([z, x],
                                  deterministic=deterministic)	

loss = -T.mean(log_likelihood - kl_divergence)

q_params = self.q.get_params()
p_params = self.p.get_params()
params = q_params + p_params

この実装は,確率分布やネットワークに依存しないことに注意してください.そのため,確率分布やネットワークの形が変更されても同じ実装を使い回すことができます.

その他詳しい実装方法は,Tarsにあるexampleや今後更新する(かもしれない)記事を参照してください.

今後の方針

今回は,自作深層生成モデルライブラリ「Tars」を紹介しました.特に公開することを考えずに実装していたので,ライブラリとしては不完全なままです.今後の方針としては

  • ライブラリとしての整備
  • モデルの追加
  • Tensorflowへの移行

等を考えています.何か問題や間違いがありましたら,コメント等で指摘していただけると幸いです.

参考文献

[1] Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013).
[2] Goodfellow, Ian, et al. "Generative adversarial nets." Advances in Neural Information Processing Systems. 2014.
[3] van den Oord, Aaron, Nal Kalchbrenner, and Koray Kavukcuoglu. "Pixel Recurrent Neural Networks." arXiv preprint arXiv:1601.06759 (2016)
[4] Sønderby, Casper Kaae, et al. "Ladder variational autoencoders." Advances in Neural Information Processing Systems. 2016.
[5] Larsen, Anders Boesen Lindbo, Søren Kaae Sønderby, and Ole Winther. "Autoencoding beyond pixels using a learned similarity metric." arXiv preprint arXiv:1512.09300 (2015).
[6] Gulrajani, Ishaan, et al. "PixelVAE: A Latent Variable Model for Natural Images." arXiv preprint arXiv:1611.05013 (2016).
[7] Jang, Eric, Shixiang Gu, and Ben Poole. "Categorical Reparameterization with Gumbel-Softmax." arXiv preprint arXiv:1611.01144 (2016).
[8] Maddison, Chris J., Andriy Mnih, and Yee Whye Teh. "The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables." arXiv preprint arXiv:1611.00712 (2016).
[9] Naesseth, Christian A., et al. "Rejection Sampling Variational Inference." arXiv preprint arXiv:1610.05683 (2016).
[10] Kingma, Diederik P., et al. “Semi-supervised learning with deep generative models.” Advances in Neural Information Processing Systems. 2014.
[11] Sohn, Kihyuk, Honglak Lee, and Xinchen Yan. “Learning Structured Output Representation using Deep Conditional Generative Models.” Advances in Neural Information Processing Systems. 2015.
[12] Burda, Yuri, Roger Grosse, and Ruslan Salakhutdinov. "Importance weighted autoencoders." arXiv preprint arXiv:1509.00519 (2015).
[13] Suzuki, Masahiro, Kotaro Nakayama, and Yutaka Matsuo. "Joint Multimodal Learning with Deep Generative Models." arXiv preprint arXiv:1611.01891 (2016).
[14] Mirza, Mehdi, and Simon Osindero. "Conditional generative adversarial nets." arXiv preprint arXiv:1411.1784 (2014).
[15] Chung, Junyoung, et al. "A recurrent latent variable model for sequential data." Advances in neural information processing systems. 2015.
[16] Gregor, Karol, et al. "DRAW: A recurrent neural network for image generation." arXiv preprint arXiv:1502.04623 (2015).
[17] Mansimov, Elman, et al. "Generating images from captions with attention." arXiv preprint arXiv:1511.02793 (2015).
[18] Li, Yingzhen, and Richard E. Turner. "Rényi divergence variational inference." Advances in Neural Information Processing Systems. 2016.

40
27
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
40
27

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?