これはTensorFlow Advent Calendar 2017の22日目の記事です。
12/12にGoogleからTFGANがリリースされた。TFGANはTensorFlowでGenerative Adversarial Networks (GAN)を手軽に使えるライブラリ。さっそく触ってみたので、お手軽に試す手順を紹介したい。おそらく30〜60分ほどでこんなふうにFashion MNISTの画像が徐々に生成されていく様子が確認できるはず。
GANって何?
GANについては、アイドル顔画像生成やいらすとや画像生成などの事例で目にしたことのある人も多いはず。いわゆる生成モデルに分類される技法で、既存のデータを投入して学習すると、そのデータの特徴を捉えた新しいデータの生成を行える。2014年にIan Goodfellow他が考案したモデルだ。
最近ではNVIDIAからGANによる高精細のセレブ画像生成事例が紹介され、もはやリアル画像か生成画像か見分けが付かないクオリティに達している。
GANで生成された非実在セレブの皆さん(From: [Progressive Growing of GANs for Improved Quality, Stability, and Variation](https://youtu.be/XOxxPcy5Gr4?t=38s))単純な画像生成にかぎらず、動画のウマをシマウマに変えちゃったり(CycleGANによるスタイル変換)、
From: [Turning a horse video into a zebra video (by CycleGAN)](https://www.youtube.com/watch?v=9reHvktowLY)線画やラベル画像、文章からそれっぽい画像を生成したり(pix2pix/StackGAN)、
From: Image-to-Image Translation with Conditional Adversarial Nets
その他、低解像画像から高解像画像を生成する超解像(super resolution)、データ圧縮等々、幅広い応用が可能だ。
応用の対象は画像に限定されるわけではなく、今後は画像以外の様々な用途でGANが応用されていくはず。例えば、開発現場でのテストデータの自動生成やシミュレーションなどがすぐに思いつく。音楽のコード進行を与えるとそれっぽい旋律を生成する事例(MidiNet)もある。
From: GANで音楽生成
GANの仕組み
GANの仕組みについては、日本語による良質な解説記事がいくつか公開されているので、詳しくはそちらを見ていただきたい。
- はじめてのGAN
- Generative Adversarial Networks (GAN) の学習方法進展・画像生成・教師なし画像変換
- Generative Adversarial Networks(GAN)を勉強して、kerasで手書き文字生成する
- タカハシ春の GAN 祭り!〜 一日一GAN(๑•̀ㅂ•́)و✧ 〜
- NIPS 2016 Tutorial: Generative Adversarial Networks:Ian Goodfellowによるチュートリアル
ここでは、サンプルコードの理解に必要となるおおざっぱな仕組みだけかいつまんで説明しよう。
GeneratorとDiscriminator
GANの基本的なアイディアはそれほど複雑ではない。GANでは、GeneratorとDiscriminatorと呼ばれる2つのニューラルネットワークを組み合わせて学習を行う。
From: A Short Introduction to Generative Adversarial Networks
それぞれ、以下の目標に向けて学習が行われる。
- Discriminatorの目標:本物データと生成データを見分けること
- Generatorの目標:Discriminatorが本物データと見間違うような生成データを作ること
この両者を同時に学習させていくと、Discriminatorはデータのより細かな特徴を見つけて本物データと生成データの違いを判別可能になる。一方、Generatorは、Discriminatorをだませるよう、より本物に近い特徴を備えたデータを生成できるようになる。
TFGANのGANEstimatorでお手軽W-GAN
TFGANでは、このGANを簡単に構成できるAPIとして、GANEstimator
を提供しており、サンプルコードでは以下のようにさくっとGANが簡潔に記述されている。
gan_estimator = tfgan.estimator.GANEstimator(
generator_fn=generator_fn, # Generatorの指定
discriminator_fn=discriminator_fn, # Distriminatorの指定
generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
generator_optimizer=tf.train.AdamOptimizer(0.001, 0.5),
discriminator_optimizer=tf.train.AdamOptimizer(0.0001, 0.5),
add_summaries=tfgan.estimator.SummaryType.IMAGES)
ここで、generator_fn
とdiscriminator_fn
にはそれぞれGeneratorとDiscriminatorのCNN構成を返す関数を渡す。generator_loss_fn
とdiscriminator_loss_fn
(GeneratorとDiscriminatorの学習に用いる損失関数)としては、Wasserstein metric(後述)による損失関数を指定している。またオプティマイザーとしてはAdamを利用する。
これまでGANを試されたことのある方ならお分かりの通り、従来のGANは学習が安定しないという問題があった。確かに、ネット上にはGANの学習がうまくいかない〜ってお悩みのブログが散見される。GANの学習を成功させるべく、オプティマイザーにはAdamを使うべし、活性化関数にはLeakyReLUを使うべし等々、GANをうまく学習するためのテクニックが集約されてきた。
そうしたGAN研究の最新成果のひとつがWasserstein GAN (W-GAN)だ。正直、W-GANの解説記事を読んでも私には理解し難いのだが、とりあえずWasserstein metricsを損失関数として用いると、その低次元多様体間距離やらリプシッツ連続性やらのよーわからんがいい感じの数学的特性によってGANの学習が安定するという。
このようにTFGANでは、W-GANの詳細を知らなくても簡単にそのメリットを享受できる。つまり、これまでのGAN研究のノウハウに基づいたAPIであるため、つまづきにくく初心者にやさしいGANライブラリと言える。サンプルコードをいろいろ試してる分には安定した学習が可能になっている。
Cloud DatalabでTFGANを動かす
では実際にTFGANを動かしてみよう。今回は、Google Cloudに統合されたJupyter Notebook環境であるCloud Datalab上でTFGANに付属のサンプルコードを試してみる。
クラウド上のJupyterと言えばGoogle Researchが公開するColaboratory (Colab)もお手軽だが、Colabは無償サービスのためGPUが使えない。重いGANの学習にGPUを使いたくなったときのことを考え、今回は商用サービスであるDatalabを選択した。Datalabであれば、Tesla K80やP100を使った学習も可能だ。ちなみに、DatalabでのGPU利用方法についてはこちらを参考のこと。
まずは、DatalabのQuickstartを参考に、Datalabインスタンスを作成する。ただしGANの計算は重いので、デフォルトのn1-standard-1
ではサンプルコードの学習でも長時間を要してしまうだろう。お財布に余裕があれば、この記述を参考に、4個や8個のCPUを割り当てておきたい。インスタンスのコストは4個のvCPUで1時間$0.19だ(2017年12月現在)。
TensorFlowを1.4にアップデート
DatalabのデフォルトではTensorFlow 1.2がインストールされるので、まずはTensorFlowを1.4にアップデートしよう。Datalabをブラウザで開いたらnotebooks
フォルダを開き、Notebook
をクリックして新規Notebookを作成する。そこで、以下のコマンドを入力して実行する(GPU利用時はtensorflow
の代わりにtensorflow-GPU
を指定)。
!pip install -U tensorflow
つづいて、画面右上にあるReset Session
をクリックし、Datalabセッションをリセットする。その後、以下のコードを実行してTensorFlowのバージョンを確認する。
import tensorflow as tf
print(tf.__version__)
この時、以下のように1.4以降のバージョン番号が表示されればOKだ。
modelsリポジトリをclone
つづいて、TFGANが含まれるmodelsリポジトリ全体をDatalabにcloneする。
!git clone https://github.com/tensorflow/models.git
するとリポジトリがcloneされ、notebooks
フォルダにはmodels
フォルダが作成される。
Fashion MNISTのURLを指定
TFGANのサンプルコードは、notebooks/models/research/gan
フォルダにあるtutorial.ipynb
だ。これをクリックして開く。
このままNotebookをポチポチしていけばすぐにMNIST画像の生成サンプルが動くのだが、今回はFashion MNISTを代わりに使ってみよう。
Fashion MNISTはMNISTとフォーマット互換の学習データで、手書き数字の代わりにファッションアイテムが並んでいる。ファイルフォーマットやファイル名すべてMNISTと同じにしてあり、世の中にたくさんあるMNISTベースのサンプルコードで動作できるように工夫されている。
このFashion MNISTでTFGANを試すには、models/research/slim/datasets
フォルダにあるdownload_and_convert_mnist.py
をクリックして開き、_DATA_URL
を以下のURLで差し替える。
_DATA_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
これでOK。ファイルを保存したら、tutorial.ipynb
に戻って上からポチポチ実行していこう。Fashion MNISTのデータファイルをダウンロードする部分で時間がかかり、DatalabのUIがフリーズしてしまうが、じっとこらえて放置しておけば2〜3分でダウンロードと変換が完了する。
TFGANの学習
ダウンロードと変換が完了したあとは、さらにポチポチしていくことでTFGANによる学習を開始できる。こんな感じで、学習の過程が確認できる。
サンプルコードにある800ステップの学習が完了するまでには、4個のvCPUを持つDatalabインスタンスでおよそ10分がかかる。サンプルコードに記されたコメントによると、GPUインスタンスにすることで、この学習が5倍の速さで完了するとのこと。GANEstimatorでW-GANによる2000ステップの学習を行った場合は、以下のようなクオリティの画像が生成される。
お手軽さの割には、なかなか悪くない! 学習も安定している様子。
というわけで、私のように「GANってなんだか大変そうだな〜」と敬遠しがちだった皆さま、TFGANでGANデビューしてみてはいかが。
Disclaimer この記事は個人的なものです。ここで述べられていることは私の個人的な意見に基づくものであり、私の雇用者には関係はありません。