Help us understand the problem. What is going on with this article?

TensorFlow2.0 + 無料のColab TPUでDCGANを実装した

TensorFlow2.0とGoogle Colaboratoryの無料TPUを使って、DCGANを実装しました。

訓練経過の様子

何をやったか

  • Google ColabのTPU+TF2.0でCelebA(約20万枚)をDCGANで生成
  • TF1.X系のTPUでは、同時に実行可能なグラフは1個の制約があったため、GANの訓練が容易ではなかった(こちらの記事にある通り、不可能であったわけではない。しかし、低レベルAPIが必須で決して容易ではなかった)。TF2.X系のTPUでは、もっと容易にGANを実装できた
  • DCGANの論文通りのモデル(パラメーター数:G=12.7M, D=11.0M)で。64x64の画像20万枚を、1エポックを40秒程度で訓練可能。100エポック回しても1時間程度。
  • 大きなバッチサイズ(BS=1024)で訓練できた。BigGANの論文にもあるように、バッチサイズを大きくすることは生成画像の質・安定性の向上ともに重要。
  • 1時間程度で以下のような顔画像が生成できた

Colab Notebook

こちらから遊べます
https://colab.research.google.com/drive/1rSm01ZiAFgxsWOcQ48cdtgLiVnwwZNfx

TF2.0のTPUの使い方について

基本的な使い方は、MNISTサンプルの記事を書いたのでこちらを参照ください。

TensorFlow2.0 with KerasでいろいろなMNIST(TPU対応)

訓練ループの中身

ポイントをかいつまんで解説していきます。

GANの訓練ループをどう書いているかについて。1バッチ単位での訓練は次のようにします。

@distrtibuted(Reduction.SUM, Reduction.SUM, Reduction.CONCAT)
def train_on_batch(real_img1, real_img2):
    # concat x1, x2
    real_img = tf.concat([real_img1, real_img2], axis=0)
    # generate fake images
    with tf.GradientTape() as d_tape, tf.GradientTape() as g_tape:
        z = tf.random.normal(shape=(real_img.shape[0], 1, 1, 100))
        fake_img = model_G(z)
    # train discriminator
    with d_tape:
        fake_out = model_D(fake_img)
        real_out = model_D(real_img)
        d_loss = (loss_func(tf.zeros(shape=(z.shape[0], 1)), fake_out)
                    + loss_func(tf.ones(shape=(z.shape[0], 1)), real_out)) / 2.0
        d_loss = tf.reduce_sum(d_loss) * (1.0 / batch_size)
    gradients = d_tape.gradient(d_loss, model_D.trainable_weights)
    param_D.apply_gradients(zip(gradients, model_D.trainable_weights))
    # train generator
    with g_tape:
        fake_out = model_D(fake_img)
        g_loss = loss_func(tf.ones(shape=(z.shape[0], 1)), fake_out)
        g_loss = tf.reduce_sum(g_loss) * (1.0 / batch_size)
    gradients = g_tape.gradient(g_loss, model_G.trainable_weights)
    param_G.apply_gradients(zip(gradients, model_G.trainable_weights))
    return d_loss, g_loss, fake_img 

distributedデコレーター

これはTensorFlowの組み込みではなく、自分で実装したデコレーターです。

TPUの訓練では(複数GPUと同様)、複数のTPUデバイスにデータをMapしたあと、個々の計算結果(損失値や生成画像)をReduceする必要があります。実装する際、特に意識しないといけないのがReduceで、Reduce用の関数はTFの組み込みでもいくつか用意されています。

しかし、組み込み関数だけでは単体のtrain_on_batch関数を定義したあと、distributed対応の関数を別に書かないといけず、デザイン的に少し野暮ったいのです。そこで、分散訓練の対応+Reduceをいい感じにやってくれるデコレータを自分で実装しました。詳しくはこちら。

TensorFlow2.0でDistributed Trainingをいい感じにやるためのデコレーターを作った

デコレーターの引数は各返り値に対するReduceの方法です。d_lossg_lossはSUM、fake_imgはaxis=0でCONCATしています(Concatは組み込みのreduce関数で非対応なので、デコレーターを実装する際は少し工夫がいる)。

2つのGradient Tape

GとDという2つのモデルを訓練(偏微分を計算)しないといけないので、Gradient Tapeは2個用意します。これはtape.gradient()で偏微分を計算してしまうと、それまでの計算グラフがリセットされてしまうからです。これはPyTorchでも同じです。

TF2.0では、Gradient Tape以下の計算は自動微分可能です。最初のfake_imgの生成では2個のTapeがあるので、どちらのTapeでもGは微分可能ということになります(もう少し賢い書き方あるかもしれません)。

そして、DのTapeとGのTapeで敵対的学習するようなロスを定義し、学習ステップを進めています。

ここでは生成画像fake_imgをDとGの間で使いまわししています。DとGの順番に注意しましょう。Dを訓練した後はGの係数は変わりませんが、Gを訓練した後は当然Gの係数が変わります。したがって、Gの訓練→Dの訓練での生成画像の使いまわしはできません。D→Gなら使いまわしできます

ちなみにGradient Tapeを2個使って、2階微分の計算なんかもできたりします。WGAN-GPをやりたいときに便利ではないでしょうか。

データの読み込みについて

GANの訓練よりも実はここが一番のポイントだったりします。いくつかトラップがあります。

TPU+tf.dataではローカルファイルから読み込めない(らしい)

例えば、tf.data.Dataset.list_filesを使って、「ファイルパス→tf.data内で画像を読み込んでテンソルを返す」という処理は、CPUやGPUでもできてもTPUでは現状はできないようです。

これはTPUのトラブルシューティングにも書いてあります(同じのエラーが出ます)。

ローカル ファイルシステムを使用できない
エラー メッセージ
InvalidArgumentError: Unimplemented: File system scheme '[local]' not implemented

詳細
すべての入力ファイルとモデル ディレクトリは Cloud Storage バケットパス(gs://bucket-name/...)を使用する必要があり、このバケットは TPU サーバーからアクセス可能である必要があります。すべてのデータ処理とモデル チェックポインティングは、ローカルマシンではなく TPU サーバー上で実行されることに注意してください

これを読む限り、Cloud Storageでないと無理みたいですね。一応、TensorFlowのソースを読んでいくと、StreamingFilesDatasetというのもあり、コメントを読んでいる限りなんかローカルファイルでも使えそうな気がします。残念ながらドキュメントがなく、どう使うのかはよくわかりませんでした。

そこで今回は、一度全ての画像を解像度が64x64のNumpy配列に格納し、それをfrom_tensor_slicesでtf.dataに読み込ませることで解決を図りました。

TensorFlowのオブジェクトの2GBの壁

Numpy配列化すれば全て解決というわけではありません。CelebAをNumpy配列化すると、uint8型でも202599枚 × 64px × 64px × 3ch = 2.31GBと大容量になってしまいます。

TensorFlowのオブジェクトには1個あたり2GBの制約があります。from_tensor_slicesでは、内部的に一度Numpy配列を定数のテンソルに置き換えている(と思われる)ので、2GBをオーバーするとエラーが出ます。公式ドキュメントにも以下のような記載があります。

Note that if tensors contains a NumPy array, and eager execution is not enabled, the values will be embedded in the graph as one or more tf.constant operations. For large datasets (> 1 GB), this can waste memory and run into byte limits of graph serialization.

そこで、今回はデータを半分に分割して、あたかも2つの画像テンソルが流れてくるようなデータセットとみなすようにしました。全体のテンソルは2.31GBなので、半分に分割すれば2GBの制約はクリアできます。前半をX1、後半をX2とします。擬似コードですが、

for X1, X2 in dataset:
    # TPUデバイス内で
    X = tf.concat([X1, X2], axis=0)

とすれば、データを半分に分割しても、個々のTPUデバイス側で結合することができます(だいぶ頭おかしい解決方法)。これでちゃんと訓練できました。動いてしまえば正義ですね。

ちなみにNumpy配列化したときに、そこまで容量が大きくなければ(2GBを超えなければ)このような心配をする必要はありません。

訓練経過

1エポック

25エポック

そこそこ形にはなってきました。
epoch_0024.png

50エポック

相当綺麗です。バッチサイズが1024と大きいおかげで、勾配の信頼性が高く、DCGANでも安定しやすいのでしょう。ここまで30分程度です。
epoch_0050.png

100エポック

1時間ちょいでこのようになりました。ただのDCGANでここまでいけるのはすごい。

まとめ

この記事では、TensorFlow2.0+Colab TPUを使って、CelebA20万枚をDCGANで1時間程度で訓練する方法を紹介しました。

今までは、GANをColab上で訓練することは厳しかったが実情でした。なぜなら、TPUでは訓練コードを書くことが容易ではなかったですし、GPUでは実行時間12時間の制約にかかりやすく、小さい解像度でしか訓練できなかったからです。GANでは自前のGPUが推奨で、遊ぶためのハードルが高いのが現実問題としてありました。

しかし、TensorFlow2.0を使うと、TPUでもGANが容易な形で実装可能(データの読み込みなど多少面倒なところはありますが)となったため、GANのハードルがぐっと下がったといえるでしょう。

さらに、TPUの計算はfloat32でも非常に高速で、同じ内容をGPUで行う場合、1ポック40秒程度で回すのは相当なGPU数がいると思われます(まずバッチサイズ1024で訓練するのが相当大変)。それだけの計算資源を無料で得られるわけですから、やはりTPUはすごい。

記事の冒頭にNotebookを公開しているので、ぜひ遊んでみてください。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした