Edited at

TFGANでFashion MNISTのGANをさくっと試す

More than 1 year has passed since last update.

これはTensorFlow Advent Calendar 2017の22日目の記事です。

12/12にGoogleからTFGANがリリースされた。TFGANはTensorFlowでGenerative Adversarial Networks (GAN)を手軽に使えるライブラリ。さっそく触ってみたので、お手軽に試す手順を紹介したい。おそらく30〜60分ほどでこんなふうにFashion MNISTの画像が徐々に生成されていく様子が確認できるはず。

GAN_with_fashion_MNIST.gif

TFGANで生成したFashion MNIST画像


GANって何?

GANについては、アイドル顔画像生成いらすとや画像生成などの事例で目にしたことのある人も多いはず。いわゆる生成モデルに分類される技法で、既存のデータを投入して学習すると、そのデータの特徴を捉えた新しいデータの生成を行える。2014年にIan Goodfellow他が考案したモデルだ。

最近ではNVIDIAからGANによる高精細のセレブ画像生成事例が紹介され、もはやリアル画像か生成画像か見分けが付かないクオリティに達している。

Screen Shot 2017-12-22 at 10.54.28 AM.png

GANで生成された非実在セレブの皆さん(From: Progressive Growing of GANs for Improved Quality, Stability, and Variation

単純な画像生成にかぎらず、動画のウマをシマウマに変えちゃったり(CycleGANによるスタイル変換)、

Screen Shot 2017-12-22 at 11.01.17 AM.png

From: Turning a horse video into a zebra video (by CycleGAN)

線画やラベル画像、文章からそれっぽい画像を生成したり(pix2pix/StackGAN)、



From: Image-to-Image Translation with Conditional Adversarial Nets

その他、低解像画像から高解像画像を生成する超解像(super resolution)、データ圧縮等々、幅広い応用が可能だ。

応用の対象は画像に限定されるわけではなく、今後は画像以外の様々な用途でGANが応用されていくはず。例えば、開発現場でのテストデータの自動生成やシミュレーションなどがすぐに思いつく。音楽のコード進行を与えるとそれっぽい旋律を生成する事例(MidiNet)もある。

From: GANで音楽生成


GANの仕組み

GANの仕組みについては、日本語による良質な解説記事がいくつか公開されているので、詳しくはそちらを見ていただきたい。

ここでは、サンプルコードの理解に必要となるおおざっぱな仕組みだけかいつまんで説明しよう。


GeneratorとDiscriminator

GANの基本的なアイディアはそれほど複雑ではない。GANでは、GeneratorDiscriminatorと呼ばれる2つのニューラルネットワークを組み合わせて学習を行う。

From: A Short Introduction to Generative Adversarial Networks

それぞれ、以下の目標に向けて学習が行われる。


  • Discriminatorの目標:本物データと生成データを見分けること

  • Generatorの目標:Discriminatorが本物データと見間違うような生成データを作ること

この両者を同時に学習させていくと、Discriminatorはデータのより細かな特徴を見つけて本物データと生成データの違いを判別可能になる。一方、Generatorは、Discriminatorをだませるよう、より本物に近い特徴を備えたデータを生成できるようになる。


TFGANのGANEstimatorでお手軽W-GAN

TFGANでは、このGANを簡単に構成できるAPIとして、GANEstimatorを提供しており、サンプルコードでは以下のようにさくっとGANが簡潔に記述されている。

gan_estimator = tfgan.estimator.GANEstimator(

generator_fn=generator_fn, # Generatorの指定
discriminator_fn=discriminator_fn, # Distriminatorの指定
generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
generator_optimizer=tf.train.AdamOptimizer(0.001, 0.5),
discriminator_optimizer=tf.train.AdamOptimizer(0.0001, 0.5),
add_summaries=tfgan.estimator.SummaryType.IMAGES)

ここで、generator_fndiscriminator_fnにはそれぞれGeneratorとDiscriminatorのCNN構成を返す関数を渡す。generator_loss_fndiscriminator_loss_fn(GeneratorとDiscriminatorの学習に用いる損失関数)としては、Wasserstein metric(後述)による損失関数を指定している。またオプティマイザーとしてはAdamを利用する。

これまでGANを試されたことのある方ならお分かりの通り、従来のGANは学習が安定しないという問題があった。確かに、ネット上にはGANの学習がうまくいかない〜ってお悩みのブログが散見される。GANの学習を成功させるべく、オプティマイザーにはAdamを使うべし、活性化関数にはLeakyReLUを使うべし等々、GANをうまく学習するためのテクニックが集約されてきた。

そうしたGAN研究の最新成果のひとつがWasserstein GAN (W-GAN)だ。正直、W-GANの解説記事を読んでも私には理解し難いのだが、とりあえずWasserstein metricsを損失関数として用いると、その低次元多様体間距離やらリプシッツ連続性やらのよーわからんがいい感じの数学的特性によってGANの学習が安定するという。

このようにTFGANでは、W-GANの詳細を知らなくても簡単にそのメリットを享受できる。つまり、これまでのGAN研究のノウハウに基づいたAPIであるため、つまづきにくく初心者にやさしいGANライブラリと言える。サンプルコードをいろいろ試してる分には安定した学習が可能になっている。


Cloud DatalabでTFGANを動かす

では実際にTFGANを動かしてみよう。今回は、Google Cloudに統合されたJupyter Notebook環境であるCloud Datalab上でTFGANに付属のサンプルコードを試してみる。

クラウド上のJupyterと言えばGoogle Researchが公開するColaboratory (Colab)もお手軽だが、Colabは無償サービスのためGPUが使えない。重いGANの学習にGPUを使いたくなったときのことを考え、今回は商用サービスであるDatalabを選択した。Datalabであれば、Tesla K80やP100を使った学習も可能だ。ちなみに、DatalabでのGPU利用方法についてはこちらを参考のこと。

Screen Shot 2017-12-22 at 6.42.45 PM.png

まずは、DatalabのQuickstartを参考に、Datalabインスタンスを作成する。ただしGANの計算は重いので、デフォルトのn1-standard-1ではサンプルコードの学習でも長時間を要してしまうだろう。お財布に余裕があれば、この記述を参考に、4個や8個のCPUを割り当てておきたい。インスタンスのコストは4個のvCPUで1時間$0.19だ(2017年12月現在)。


TensorFlowを1.4にアップデート

DatalabのデフォルトではTensorFlow 1.2がインストールされるので、まずはTensorFlowを1.4にアップデートしよう。Datalabをブラウザで開いたらnotebooksフォルダを開き、Notebookをクリックして新規Notebookを作成する。そこで、以下のコマンドを入力して実行する(GPU利用時はtensorflowの代わりにtensorflow-GPUを指定)。

!pip install -U tensorflow

つづいて、画面右上にあるReset Sessionをクリックし、Datalabセッションをリセットする。その後、以下のコードを実行してTensorFlowのバージョンを確認する。

import tensorflow as tf

print(tf.__version__)

この時、以下のように1.4以降のバージョン番号が表示されればOKだ。

Screen Shot 2017-12-22 at 6.41.08 PM.png


modelsリポジトリをclone

つづいて、TFGANが含まれるmodelsリポジトリ全体をDatalabにcloneする。

!git clone https://github.com/tensorflow/models.git

するとリポジトリがcloneされ、notebooksフォルダにはmodelsフォルダが作成される。


Fashion MNISTのURLを指定

TFGANのサンプルコードは、notebooks/models/research/ganフォルダにあるtutorial.ipynbだ。これをクリックして開く。

このままNotebookをポチポチしていけばすぐにMNIST画像の生成サンプルが動くのだが、今回はFashion MNISTを代わりに使ってみよう。

Fashion MNISTはMNISTとフォーマット互換の学習データで、手書き数字の代わりにファッションアイテムが並んでいる。ファイルフォーマットやファイル名すべてMNISTと同じにしてあり、世の中にたくさんあるMNISTベースのサンプルコードで動作できるように工夫されている。

image.png

このFashion MNISTでTFGANを試すには、models/research/slim/datasetsフォルダにあるdownload_and_convert_mnist.pyをクリックして開き、_DATA_URLを以下のURLで差し替える。

_DATA_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'

これでOK。ファイルを保存したら、tutorial.ipynbに戻って上からポチポチ実行していこう。Fashion MNISTのデータファイルをダウンロードする部分で時間がかかり、DatalabのUIがフリーズしてしまうが、じっとこらえて放置しておけば2〜3分でダウンロードと変換が完了する。


TFGANの学習

ダウンロードと変換が完了したあとは、さらにポチポチしていくことでTFGANによる学習を開始できる。こんな感じで、学習の過程が確認できる。

Screen Shot 2017-12-22 at 7.09.05 PM.png

サンプルコードにある800ステップの学習が完了するまでには、4個のvCPUを持つDatalabインスタンスでおよそ10分がかかる。サンプルコードに記されたコメントによると、GPUインスタンスにすることで、この学習が5倍の速さで完了するとのこと。GANEstimatorでW-GANによる2000ステップの学習を行った場合は、以下のようなクオリティの画像が生成される。

gan_final.png

お手軽さの割には、なかなか悪くない! 学習も安定している様子。

というわけで、私のように「GANってなんだか大変そうだな〜」と敬遠しがちだった皆さま、TFGANでGANデビューしてみてはいかが。


Disclaimer この記事は個人的なものです。ここで述べられていることは私の個人的な意見に基づくものであり、私の雇用者には関係はありません。