LoginSignup
36
25

More than 5 years have passed since last update.

Conditional GANをchainerで実装した

Last updated at Posted at 2017-07-11

はじめに

Conditional GANをchainerで実装しました。
Github - https://github.com/lyakaap/GAN/tree/master/ConditionalGAN

実装は以下の記事を参考にしました。
http://qiita.com/ysasaki6023/items/55b8545c296ce32ac135

解説

「Conditional GAN」はGANの一種で、従来のGANが生成される画像をコントロール出来なかったのに対して、ラベルを指定することで生成される画像を任意のクラスのものに出来るという素晴らしいモデルです。
また、指定する条件はラベル以外でも可能で、色々な応用が考えられます。
主なGANとの違いは、Generator及びDiscriminatorの入力に対応するラベルデータも渡してあげるという点です。

実装

実装する際に考える事として、「どのようにネットワークにラベルを渡してあげるか」が重要になってきます。
具体的にクラス数10のMNISTを例に挙げて説明します。

Generator

  • 入力ノイズ z:100次元の乱数のベクトル
  • ラベル l:10次元のone-hotのベクトル

をchainerのconcat()を使ってくっつけます。

具体的にそれぞれのshapeを書き出してみると、

  • z => (ミニバッチのサイズ, 100, 1)
  • l => (ミニバッチのサイズ, 10, 1)
  • input => (ミニバッチのサイズ, 100+10, 1)

となっています。上記3つのテンソルに共通して三次元目に付いてる「1」はチャンネル数です(RGB画像を扱うときは当然3になる)。

inputは最終的にGeneratorの入力となるテンソルを表していて、input = F.concat((z, l), axis=1) のように結合しています。

Discriminator

DCGANベースだと、ネットワークがConvolution層になっているので、そのままではone-hotのラベルを渡すことが出来ません。
なので、下に示すようにしてDiscriminatorの入力を定義しました。

スクリーンショット 2017-07-11 20.01.38.png

上の図の例では入力画像が「2」であるので、対応する2チャンネル目が塗りつぶされています。このようにone-hotラベルを10チャンネルの画像のように扱うことで、Convolutionにも対応できるようになっています。
最終的に、ラベルと入力画像を結合したものは、11チャンネルの28x28画像(MNISTの画像サイズ)としてDiscriminatorの入力層に渡されることになります。
具体的な入力のshapeは
(ミニバッチのサイズ, 10+1チャンネル, 28, 28)
のようになります。
結合にはGeneratorと同じくconcat()を使っています。

結果

生成する画像のラベルを上から行ごとに指定して、並べたものを可視化しました。

学習の様子です

ConditionalGAN.gif

ちゃんと指定されたラベルに対応するような学習をしていることが分かります。

また、おまけとして入力ノイズ空間の原点を入力とした際の、300epoch後の生成画像を載せておきます。(同じ数字の出力は当然ながら同一のものです)

image180000.png

生成画像を見ると、どの数字の画像も、ノイズ空間の原点が綺麗で中性的な画像を生成するように対応付けされていることが分かるので楽しい。

任意の数字を引数として入力すると対応する画像を表示してくれるスクリプトも作ってみました。同じ数字でも色々な筆跡の手書き文字を出力されます。

$ python digit_generator.py --digits 20170710

スクリーンショット 2017-07-12 17.28.22.png

スクリーンショット 2017-07-12 17.28.02.png

ひらがなのデータセットとかでやると面白そう。

その他

tensorflowにあるようなone-hotラベル化する関数がchainerに無いのが意外でした。(見落としてるだけかもしれない)
scikit-learn使えってことなんですかね…

因みに自分の実装はこんな感じです。

def to_onehot(label, class_num):
    return numpy.eye(class_num)[label]

ラベルとクラス数を渡すと単位行列を基にone-hotなラベルを出力してくれます。

36
25
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
36
25