Help us understand the problem. What is going on with this article?

tensorflow+TPUでGANを学習させる

More than 1 year has passed since last update.

はじめに

TensorFlowやKerasでGANを訓練する例自体はいくつもあるのですが、そのままTPUで訓練しようとするとうまく動かなかったりアホみたいに遅かったりで実用に耐えません。そこで、TensorFlowの低レベルAPIを用いてTPUに対応したGANを実装したいと思います。

基本的にはTPUでCustom Loopを動かすためのTensorFlow公式チュートリアルを踏まえた話になりますが、自分が確認しただけでも3個のチュートリアルがあり[1][2][3]、それぞれ書き方が微妙に異なって動いたり動かなかったりします。なので、ここで一度まとめて整理しようってわけです。

今回はGANに限定して話しますが、紹介する書き方はGAN以外のどんなネットワークにも適用できると思います。

環境

  • Google Colabratory
  • TensorFlow 1.14.0

TPU対応の基本的な書き方

KerasでTPUに対応した書き方についてはすでにQiitaに記事が投稿されています。

TensorFlow1.14以降のTPUの取り扱い方について

strategy.scope()の後で今まで通りモデル定義や学習を行えばよい形になっています。KerasでModel.compile()してModel.fit()で学習、とかならこれだけでOKです。簡単ですね。
ただし、ちょっと外れた書き方をしたり、複雑なことをしようとすると途端にわけわからなくなります。

TensorFlow+TPUでGANを実装する時の問題点

strategy.scope()を使えば通常のCNNやらはTPUで学習できるのですが、GANの場合は具体的にどんな問題点があるのか、ここで一度まとめておきます。興味ない人は飛ばしてください。

問題点1. Keras+TPUではtrain_on_batchが使えない

GANではミニバッチ毎にDiscriminatorとGeneratorの2つのネットワークを交互に訓練していきます。
そのため、TensorFlowの高レベルAPIであるKerasではModel.train_on_batch()を使うのが定石ですが、2019年7月現在の最新版であるTensorFlow1.14.0ではKerasのModel.train_on_batch()はTPUに対応していません。
1.13.1ではModel.train_on_batch()が使えていましたが、1.14.0でTPU対応のモデルの書き方が変わり、暫定的に未対応となったのだと思います。
じゃあ通常のModel.fit()をミニバッチ毎に使えばいいじゃん!って思いつきますが、動きはするもののめちゃくちゃ遅くて使い物にはなりません。
今後、Model.train_on_batch()が再びTPU対応するとは思いますが、いつ対応するかはわからないので、今のところは別の書き方が必要になります。

問題点2. TensorFlowのTF-GANは使いにくい

これはどちらかと言うと個人的な理由になるかもしれません。

TensorFlowにはTF-GANというGANを手軽に試せるAPIが用意されています。TensorFlowの中レベルAPIであるEstimatorを使ったGANを作れるのですが、

  • コスト関数を自分で書けない
  • tf.contrib扱いなので、今後削除される可能性が小さくない
  • そもそも分かりにくい

といった感じであまり積極的に利用したくはないかなと思っています。特にGANの仕組みをコードを書きながら理解したい、自分で色々弄りたい、って人には向きません。

TPUに対応したGANの書き方

ここからが本題です。コード全体は少し長いので折り畳んでいます。githubには結果も載せているので合わせてどうぞ。

コード全体はこちら
import sys, os
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.models import Sequential

class GAN(object):
    def __init__(self):

        self.z_dim = 100 # 潜在変数の次元

        self.image_shape = (28, 28, 1) # 画像のサイズ
        self.noise_shape = (self.z_dim,) # ノイズのサイズ

        self.epochs = 100 # 学習回数
        self.batch_size = 512 # バッチサイズ

        # データセットのロード
        self.X_train = self.load_dataset()
        self.num_batches = self.X_train.shape[0] // self.batch_size # ミニバッチの数
        print('number of batches:', self.num_batches)

        # TPU対応のおまじない1
        tf.keras.backend.clear_session()
        tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
        self.tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
        self.strategy = tf.contrib.distribute.TPUStrategy(self.tpu_cluster_resolver)

        # ここからTPU対応のモデルやらを書いていく
        with self.strategy.scope():

            # Discriminatorの定義
            self.discriminator = self.build_discriminator()
            self.optimizer_disc = tf.train.AdamOptimizer(2.0e-4, 0.5) # Discriminator用のOptimizer
            self.var_disc = self.discriminator.trainable_variables # Discriminatorの重み

            # Generatorの定義
            self.generator = self.build_generator()
            self.optimizer_gen = tf.train.AdamOptimizer(2.0e-4, 0.5) # Generator用のOptimizer
            self.var_gen = self.generator.trainable_variables # Generatorの重み

            # データセットの入力用placeholder
            self.images_placeholder = tf.placeholder(tf.float32, [None, *self.image_shape])
            self.noise_placeholder = tf.placeholder(tf.float32, [None, *self.noise_shape])
            self.labels_placeholder = tf.placeholder(tf.float32, [None, 1])

            # Dataset APIで入力パイプラインを定義
            dataset = tf.data.Dataset.from_tensor_slices(
                (self.images_placeholder,
                 self.noise_placeholder,
                 self.labels_placeholder
                ))
            dataset = dataset.repeat()
            dataset = dataset.batch(self.batch_size, drop_remainder=True) # TPUではdrop_remainder=Trueが必須

            # DatasetをTPU用のDatasetに変換
            dist_dataset = self.strategy.experimental_distribute_dataset(dataset)

            # iteratorを定義
            input_iterator = dist_dataset.make_initializable_iterator()
            self.iterator_init = input_iterator.initialize()

            # 学習等のopsを定義
            inputs = input_iterator.get_next() # ネットワークの入力
            self.train_disc_ops = self.train_step_disc(inputs) # Discriminatorの学習
            self.train_gen_ops = self.train_step_gen(inputs) # Generatorの学習
            self.output_gen_ops = self.output_images_gen(inputs) # Generatorの出力

            # TPU対応のおまじない2
            tf.contrib.distribute.initialize_tpu_system(self.tpu_cluster_resolver)
            config = tf.ConfigProto()
            config.allow_soft_placement = True
            cluster_spec = self.tpu_cluster_resolver.cluster_spec()
            if cluster_spec:
                config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

            # Sessionの定義
            self.sess = tf.Session(
                target=self.tpu_cluster_resolver.master(),
                config=config
            )

            # 変数の初期化
            self.sess.run(tf.global_variables_initializer())

    def load_dataset(self):

        # mnistデータの読み込み
        (X_train, _), (_, _) = mnist.load_data()

        # 値を-1 to 1に規格化
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        return X_train

    def build_discriminator(self):
        # discriminatorモデル
        # kerasのSequentialを使っているが、Functional APIでもtensorflowの低レベルAPIでもたぶん大丈夫

        layers_disc = []
        layers_disc.append(
            Conv2D(16, (5, 5), strides=(2, 2), padding='same', input_shape=self.image_shape))
        layers_disc.append(LeakyReLU(alpha=0.2))

        layers_disc.append(
            Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
        layers_disc.append(LeakyReLU(alpha=0.2))

        layers_disc.append(Flatten())
        layers_disc.append(Dense(1))

        discriminator = Sequential(layers_disc)

        return discriminator

    def build_generator(self):
        # Generatorモデル
        # kerasのSequentialを使っているが、Functional APIでもtensorflowの低レベルAPIでもたぶん大丈夫

        layers_gen = []
        layers_gen.append(Dense(7 * 7 * 256, use_bias=False, input_shape=self.noise_shape))
        layers_gen.append(BatchNormalization(momentum=0.8))
        layers_gen.append(LeakyReLU(alpha=0.2))

        layers_gen.append(Reshape((7, 7, 256)))

        layers_gen.append(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
        layers_gen.append(BatchNormalization(momentum=0.8))
        layers_gen.append(LeakyReLU(alpha=0.2))

        layers_gen.append(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
        layers_gen.append(BatchNormalization(momentum=0.8))
        layers_gen.append(LeakyReLU(alpha=0.2))

        layers_gen.append(
            Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', activation='tanh'))

        generator = Sequential(layers_gen)

        return generator

    def train_step_disc(self, dist_inputs):
        # Discriminatorに対して
        # コストを計算して逆伝播法で重みを更新する

        def step_fn(inputs):
            features, _, labels = inputs # 入力データ
            logits = self.discriminator(features) # Discriminatorの出力

            # コスト関数と重み更新
            cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
            loss = tf.reduce_sum(cross_entropy) / self.batch_size # reduce_meanは使わない方がいい
            train_op_disc = self.optimizer_disc.minimize(loss, var_list=self.var_disc) # discriminatorの重みのみ更新する

            # 精度
            logits_bool = tf.cast(tf.greater_equal(logits, 0), tf.float32)
            acc = tf.reduce_sum(1.0 - tf.abs(labels - logits_bool)) / self.batch_size

            # 必ずtf.control_dependenciesを使うこと
            with tf.control_dependencies([train_op_disc]):
                return tf.identity(loss), tf.identity(acc)

        # TPUコア毎にstep_fnを実行して結果を出力
        per_replica_losses, per_replica_accs = self.strategy.experimental_run_v2(step_fn, args=(dist_inputs,))

        # TPUコア毎のコストと精度をまとめる
        # tf.distribute.ReduceOp.SUMはtf.reduce_sum
        # tf.distribute.ReduceOp.MEANはtf.reduce_meanに対応
        # MEANは正しい結果になっているかちょっと自信ないので、SUMにしている
        losses = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        accs = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_accs, axis=None)

        return losses, accs

    def output_images_gen(self, dist_inputs):
        # Generatorの出力画像を得る

        def step_fn(inputs):
            _, noises, _ = inputs # 入力データ
            return self.generator(noises, training=False) # GeneratorにBatchNormalizationを入れている場合はtraining=Falseを指定

        # TPUコア毎にstep_fnを実行して結果を出力
        gen_output = self.strategy.experimental_run_v2(step_fn, args=(dist_inputs,))

        # TPUコア毎の結果を連結
        gen_output = tf.concat(gen_output.values, axis=0)

        return gen_output

    def train_step_gen(self, dist_inputs):
        # Generatorに対して
        # コストを計算して逆伝播法で重みを更新する

        def step_fn(inputs):
            _, noises, labels = inputs # 入力データ
            features = self.generator(noises, training=True) # GeneratorにBatchNormalizationを入れている場合はtraining=Trueを指定
            logits = self.discriminator(features)

            # コスト関数と重み更新
            cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
            loss = tf.reduce_sum(cross_entropy) / self.batch_size
            train_op_gen = self.optimizer_gen.minimize(loss, var_list=self.var_gen) # Generatorの重みのみ更新

            # 精度
            logits_bool = tf.cast(tf.greater_equal(logits, 0), tf.float32)
            acc = tf.reduce_sum(1.0 - tf.abs(labels - logits_bool)) / self.batch_size

            # BatchNormalizationの平均と分散の更新
            # GeneratorにBatchNormalizationを入れている場合は必須
            update_ops = self.generator.get_updates_for(None) + self.generator.get_updates_for(noises)

            # 必ずtf.control_dependenciesを使うこと
            # BatchNormalizationを使っている場合はupdate_opsも一緒に入れる
            with tf.control_dependencies([train_op_gen, *update_ops]):
                return tf.identity(loss), tf.identity(acc)

        # TPUコア毎にstep_fnを実行して結果を出力
        per_replica_losses, per_replica_accs = self.strategy.experimental_run_v2(step_fn, args=(dist_inputs,))

        # TPUコア毎のコストと精度をまとめる
        # tf.distribute.ReduceOp.SUMはtf.reduce_sum
        # tf.distribute.ReduceOp.MEANはtf.reduce_meanに対応
        # MEANは正しい結果になっているかちょっと自信ないので、SUMにしている
        losses = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        accs = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_accs, axis=None)

        return losses, accs

    def fit(self):
        # TPU上でDiscriminatorとGeneratorを更新する

        with self.strategy.scope():
            start_fit = time.time()

            noise = np.random.normal(0, 1, (self.batch_size, self.z_dim)).astype(np.float32) # Generatorの入力
            image_real = self.X_train[:self.batch_size] # Discriminatorの入力
            label_real = np.ones((self.batch_size, 1), np.float32) # Discriminatorの出力ラベル

            # 入力パイプラインを初期化
            self.sess.run(
                self.iterator_init,
                feed_dict={
                    self.images_placeholder: image_real,
                    self.noise_placeholder: noise,
                    self.labels_placeholder: label_real
                })

            # 学習前のGeneratorの出力を確認
            image_fake = self.sess.run(self.output_gen_ops)
            self.show_images(image_fake, epoch=0)

            # 学習開始
            for epoch in range(self.epochs):

                # 各エポックのコストと精度
                d_loss_epoch = 0
                d_acc_epoch = 0
                g_loss_epoch = 0
                g_acc_epoch = 0

                start_epoch = time.time()

                # 各エポックの学習前に学習データをシャッフル
                np.random.shuffle(self.X_train)

                # ミニバッチ学習
                for iter in range(self.num_batches):

                    noise = np.random.normal(0, 1, (self.batch_size, self.z_dim)).astype(np.float32) # Generatorの入力
                    image_real = self.X_train[iter * self.batch_size:(iter + 1) * self.batch_size] # Discriminatorの入力(本物)
                    label_real = np.ones((self.batch_size, 1), np.float32) # Discriminatorの出力ラベル(本物) 
                    label_fake = np.zeros((self.batch_size, 1), np.float32) # Discriminatorの出力ラベル(偽物)

                    #---------------------
                    # Discriminatorの学習
                    #---------------------

                    # iteratorを初期化
                    self.sess.run(
                        self.iterator_init,
                        feed_dict={
                            self.images_placeholder: image_real, # Discriminatorの入力(本物)
                            self.noise_placeholder: noise, # Genratorの入力
                            self.labels_placeholder: label_real # Discriminatorの出力ラベル(本物)
                        })

                    # 偽物画像を生成
                    image_fake = self.sess.run(self.output_gen_ops)

                    # 本物画像でDiscriminatorを学習
                    d_loss_real, d_acc_real = self.sess.run(self.train_disc_ops)

                    # Discriminatorに偽物画像を与えるため
                    # iteratorを初期化
                    self.sess.run(
                        self.iterator_init,
                        feed_dict={
                            self.images_placeholder: image_fake, # Discriminatorの入力(偽物)
                            self.noise_placeholder: noise, # Genratorの入力(使わないのでなんでもいい)
                            self.labels_placeholder: label_fake # Discriminatorの出力ラベル(偽物)
                        })

                    # 偽物画像でDiscriminatorを学習
                    d_loss_fake, d_acc_fake = self.sess.run(self.train_disc_ops)

                    # 本物画像の結果と偽物画像の結果を平均
                    d_loss = 0.5 * (d_loss_real + d_loss_fake)
                    d_acc = 0.5 * (d_acc_real + d_acc_fake)

                    #---------------------
                    # Generatorの学習
                    #---------------------

                    # iteratorを初期化
                    self.sess.run(
                        self.iterator_init,
                        feed_dict={
                            self.images_placeholder: image_real, # Discriminatorの入力(使わないのでなんでもいい)
                            self.noise_placeholder: noise, # Genratorの入力
                            self.labels_placeholder: label_real # Discriminatorの出力ラベル(本物)
                        })

                    # 本物ラベルでGeneratorを学習
                    g_loss, g_acc = self.sess.run(self.train_gen_ops)

                    # エポック毎の結果
                    d_loss_epoch += d_loss
                    d_acc_epoch += d_acc
                    g_loss_epoch += g_loss
                    g_acc_epoch += g_acc

                    # 進捗の表示
                    sys.stdout.write(
                        '\repoch:{:d}  iter:{:d}   [D loss: {:f}, acc: {:.2f}%] [G loss: {:f}, acc: {:.2f}%]   '.format(
                            epoch + 1, iter + 1, d_loss, 100 * d_acc, g_loss, 100 * g_acc))
                    sys.stdout.flush()

                # ミニバッチ毎の結果を平均
                d_loss_epoch /= self.num_batches
                d_acc_epoch /= self.num_batches
                g_loss_epoch /= self.num_batches
                g_acc_epoch /= self.num_batches

                epoch_time = time.time() - start_epoch

                # エポックの結果を表示
                sys.stdout.write(
                    '\repoch:{:d}  iter:{:d}   [D loss: {:f}, acc: {:.2f}%] [G loss: {:f}, acc: {:.2f}%]   time: {:f}\n'.format(
                        epoch + 1, iter + 1, d_loss_epoch, 100 * d_acc_epoch, g_loss_epoch, 100 * g_acc_epoch, epoch_time))
                sys.stdout.flush()

                # Generatorの出力を確認
                if (epoch + 1) % 10 == 0:
                    noise = np.random.normal(0, 1, (self.batch_size, self.z_dim)).astype(np.float32)
                    self.sess.run(
                        self.iterator_init,
                        feed_dict={
                            self.images_placeholder: image_real, # Discriminatorの入力(使わないのでなんでもいい)
                            self.noise_placeholder: noise, # Genratorの入力
                            self.labels_placeholder: label_real # Discriminatorの出力ラベル(使わないのでなんでもいい)
                        })
                    image_fake = self.sess.run(self.output_gen_ops)
                    self.show_images(image_fake, epoch=epoch + 1)

    def show_images(self, images, epoch):
        # 出力画像を確認

        fig = plt.figure(figsize=(4, 4))
        for i in range(16):
          plt.subplot(4, 4, i + 1)
          plt.imshow(images[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
          plt.axis('off')

        fig.suptitle('epoch: {:}'.format(epoch))
        fig.savefig('mnist_epoch_{:}.png'.format(epoch))
        plt.show()

if __name__ == '__main__':
    G = GAN()
    G.fit()

TPUStrategyを定義

モデル定義の前にTPUStrategyを定義していきましょう。これはGANに関わらずTPU対応のためには必須です。

# TPU対応のおまじない1
tf.keras.backend.clear_session()
tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
self.tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
self.strategy = tf.contrib.distribute.TPUStrategy(self.tpu_cluster_resolver)

# ここからTPU対応のモデルやらを書いていく
with self.strategy.scope():
    # モデル定義やらモデル実行やら

with strategy.scope():の後にモデル定義やらモデル実行のコードを書いていきます。

モデル定義

with strategy.scope()の後にモデル定義を書いていきましょう。

with self.strategy.scope():

    # Discriminatorの定義
    self.discriminator = self.build_discriminator()
    self.optimizer_disc = tf.train.AdamOptimizer(2.0e-4, 0.5) # Discriminator用のOptimizer
    self.var_disc = discriminator.trainable_variables # Discriminatorの重み

    # Generatorの定義
    self.generator = self.build_generator()
    self.optimizer_gen = tf.train.AdamOptimizer(2.0e-4, 0.5) # Generator用のOptimizer
    self.var_gen = generator.trainable_variables # Generatorの重み

あとでDiscriminatorとGeneratorを個別に更新するために、それぞれの重みを集めておきます。
この部分はTPU未対応でも変わらないと思います。

DiscriminatorとGeneratorの構造はTensorFlow公式のDCGANチュートリアルをほぼそのまま使っています。
KerasのSequentialで定義しておきますが、KerasのFunctional APIでもTensorFlowの低レベルAPIでも大丈夫なはずです。

モデルの構造はこちら
def build_discriminator(self):
    # discriminatorモデル
    # kerasのSequentialを使っているが、Functional APIでもtensorflowの低レベルAPIでもたぶん大丈夫

    layers_disc = []
    layers_disc.append(
        Conv2D(16, (5, 5), strides=(2, 2), padding='same', input_shape=self.image_shape))
    layers_disc.append(LeakyReLU(alpha=0.2))

    layers_disc.append(
        Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
    layers_disc.append(LeakyReLU(alpha=0.2))

    layers_disc.append(Flatten())
    layers_disc.append(Dense(1))

    discriminator = Sequential(layers_disc)

    return discriminator

def build_generator(self):
    # Generatorモデル
    # kerasのSequentialを使っているが、Functional APIでもtensorflowの低レベルAPIでもたぶん大丈夫

    layers_gen = []
    layers_gen.append(Dense(7 * 7 * 256, use_bias=False, input_shape=self.noise_shape))
    layers_gen.append(BatchNormalization(momentum=0.8))
    layers_gen.append(LeakyReLU(alpha=0.2))

    layers_gen.append(Reshape((7, 7, 256)))

    layers_gen.append(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    layers_gen.append(BatchNormalization(momentum=0.8))
    layers_gen.append(LeakyReLU(alpha=0.2))

    layers_gen.append(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    layers_gen.append(BatchNormalization(momentum=0.8))
    layers_gen.append(LeakyReLU(alpha=0.2))

    layers_gen.append(
        Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', activation='tanh'))

    generator = Sequential(layers_gen)

    return generator

データセットの入力パイプラインを定義

今回はDataset APIを使ってデータを入力していきます。ここはモデル定義と前後しても大丈夫です。

with self.strategy.scope():
     # データセットの入力用placeholder
    self.images_placeholder = tf.placeholder(tf.float32, [None, *self.image_shape])
    self.noise_placeholder = tf.placeholder(tf.float32, [None, *self.noise_shape])
    self.labels_placeholder = tf.placeholder(tf.float32, [None, 1])

    # Dataset APIで入力パイプラインを定義
    dataset = tf.data.Dataset.from_tensor_slices(
        (self.images_placeholder,
         self.noise_placeholder,
         self.labels_placeholder
        ))
    dataset = dataset.repeat()
    dataset = dataset.batch(self.batch_size, drop_remainder=True) # TPUではdrop_remainder=Trueが必須

    # DatasetをTPU用のDatasetに変換
    dist_dataset = self.strategy.experimental_distribute_dataset(dataset)

    # iteratorを定義
    input_iterator = dist_dataset.make_initializable_iterator()
    self.iterator_init = input_iterator.initialize()

Discriminatorの入力を本物画像と偽物画像で入れ替えるために、placeholderからDatasetを作って後から入力を変更できるようにしておきます。

ここで重要なのは通常のDatasetを定義した後に、strategy.experimental_distribute_dataset()を使ってTPU用のDatasetに変換するところです。IteratorはこのTPU対応のDatasetから作っていきます。

コスト関数と重み更新を定義

コスト関数と重み更新を定義していきましょう。

with self.strategy.scope():
    # 学習等のopsを定義
    inputs = input_iterator.get_next() # ネットワークの入力
    self.train_disc_ops = self.train_step_disc(inputs) # Discriminatorの学習
    self.train_gen_ops = self.train_step_gen(inputs) # Generatorの学習
    self.output_gen_ops = self.output_images_gen(inputs) # Generatorの出力

それぞれ関数化していますが、中身はだいたい同じなのでtrain_step_gan()を例に見ていきます。

def train_step_gen(self, dist_inputs):
    # Generatorに対して
    # コストを計算して逆伝播法で重みを更新する

    def step_fn(inputs):
        _, noises, labels = inputs # 入力データ
        features = self.generator(noises, training=True) # GeneratorにBatchNormalizationを入れている場合はtraining=Trueを指定
        logits = self.discriminator(features)

        # コスト関数と重み更新
        cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
        loss = tf.reduce_sum(cross_entropy) / self.batch_size
        train_op_gen = self.optimizer_gen.minimize(loss, var_list=self.var_gen) # Generatorの重みのみ更新

        # 精度
        logits_bool = tf.cast(tf.greater_equal(logits, 0), tf.float32)
        acc = tf.reduce_sum(1.0 - tf.abs(labels - logits_bool)) / self.batch_size

        # BatchNormalizationの平均と分散の更新
        # GeneratorにBatchNormalizationを入れている場合は必須
        update_ops = self.generator.get_updates_for(None) + self.generator.get_updates_for(noises)

        # 必ずtf.control_dependenciesを使うこと
        # BatchNormalizationを使っている場合はupdate_opsも一緒に入れる
        with tf.control_dependencies([train_op_gen, *update_ops]):
            return tf.identity(loss), tf.identity(acc)

    # TPUコア毎にstep_fnを実行して結果を出力
    per_replica_losses, per_replica_accs = self.strategy.experimental_run_v2(step_fn, args=(dist_inputs,))

    # TPUコア毎のコストと精度をまとめる
    # tf.distribute.ReduceOp.SUMはtf.reduce_sum
    # tf.distribute.ReduceOp.MEANはtf.reduce_meanに対応
    # MEANは正しい結果になっているかちょっと自信ないので、SUMにしている
    losses = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    accs = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_accs, axis=None)

    return losses, accs

重要なのは、コスト関数と重み更新はstep_fn()と関数内で定義して、それをstrategy.experimental_run_V2()に渡して実行することです。こうすることでTPUコア毎の結果を集めて、まとめて重み更新できるようです。
返り値はPerReplicaオブジェクトなので、PerReplica.valuesで値を取り出すか、strategy.reduce()で合計か平均を算出します。

もうひとつ重要なことは、keras.layers.BatchNormalizationを使っている場合は、学習時に平均と分散の更新を行うために、optimizer.minimize()とは別にupdate_opsを定義して、同時に実行する必要があることです。
また、BatchNormalizationが入っているModelcallするときにtraining=Trueを指定します。
これらを除いてしまうとBatchNormalizationがうまく学習できません。一方で、推論時にはtraining=Falseとするのみでupdate_opsは必要ありません。

tf.Sessionを定義して初期化

モデル定義は終わりました。後はtf.Sessionを定義して、重みを初期化しましょう。

with self.strategy.scope():
    # TPU対応のおまじない2
    tf.contrib.distribute.initialize_tpu_system(self.tpu_cluster_resolver)
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    cluster_spec = self.tpu_cluster_resolver.cluster_spec()
    if cluster_spec:
        config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    # Sessionの定義
    self.sess = tf.Session(
        target=self.tpu_cluster_resolver.master(),
        config=config
    )

    # 変数の初期化
    self.sess.run(tf.global_variables_initializer())

ここで重要なことは、cluster_specの部分とtf.Session()の中でtarget=tpu_cluster_resolver.master()とすることです。
これがないとエラー吐いて実行できません。

学習を実行

準備は整いました。学習を行うコードを書いていきましょう。長いのと、大したポイントはないので折り畳んでいます。

学習用コードはこちら
def fit(self):
    # TPU上でDiscriminatorとGeneratorを更新する

    with self.strategy.scope():
        start_fit = time.time()

        noise = np.random.normal(0, 1, (self.batch_size, self.z_dim)).astype(np.float32) # Generatorの入力
        image_real = self.X_train[:self.batch_size] # Discriminatorの入力
        label_real = np.ones((self.batch_size, 1), np.float32) # Discriminatorの出力ラベル

        # 入力パイプラインを初期化
        self.sess.run(
            self.iterator_init,
            feed_dict={
                self.images_placeholder: image_real,
                self.noise_placeholder: noise,
                self.labels_placeholder: label_real
            })

        # 学習前のGeneratorの出力を確認
        image_fake = self.sess.run(self.output_gen_ops)
        self.show_images(image_fake)

        # 学習開始
        for epoch in range(self.epochs):

            # 各エポックのコストと精度
            d_loss_epoch = 0
            d_acc_epoch = 0
            g_loss_epoch = 0
            g_acc_epoch = 0

            start_epoch = time.time()

            # 各エポックの学習前に学習データをシャッフル
            np.random.shuffle(self.X_train)

            # ミニバッチ学習
            for iter in range(self.num_batches):

                noise = np.random.normal(0, 1, (self.batch_size, self.z_dim)).astype(np.float32) # Generatorの入力
                image_real = self.X_train[iter * self.batch_size:(iter + 1) * self.batch_size] # Discriminatorの入力(本物)
                label_real = np.ones((self.batch_size, 1), np.float32) # Discriminatorの出力ラベル(本物) 
                label_fake = np.zeros((self.batch_size, 1), np.float32) # Discriminatorの出力ラベル(偽物)

                #---------------------
                # Discriminatorの学習
                #---------------------

                # iteratorを初期化
                self.sess.run(
                    self.iterator_init,
                    feed_dict={
                        self.images_placeholder: image_real, # Discriminatorの入力(本物)
                        self.noise_placeholder: noise, # Genratorの入力
                        self.labels_placeholder: label_real # Discriminatorの出力ラベル(本物)
                    })

                # 偽物画像を生成
                image_fake = self.sess.run(self.output_gen_ops)

                # 本物画像でDiscriminatorを学習
                d_loss_real, d_acc_real = self.sess.run(self.train_disc_ops)

                # Discriminatorに偽物画像を与えるため
                # iteratorを初期化
                self.sess.run(
                    self.iterator_init,
                    feed_dict={
                        self.images_placeholder: image_fake, # Discriminatorの入力(偽物)
                        self.noise_placeholder: noise, # Genratorの入力(使わないのでなんでもいい)
                        self.labels_placeholder: label_fake # Discriminatorの出力ラベル(偽物)
                    })

                # 偽物画像でDiscriminatorを学習
                d_loss_fake, d_acc_fake = self.sess.run(self.train_disc_ops)

                # 本物画像の結果と偽物画像の結果を平均
                d_loss = 0.5 * (d_loss_real + d_loss_fake)
                d_acc = 0.5 * (d_acc_real + d_acc_fake)

                #---------------------
                # Generatorの学習
                #---------------------

                # iteratorを初期化
                self.sess.run(
                    self.iterator_init,
                    feed_dict={
                        self.images_placeholder: image_real, # Discriminatorの入力(使わないのでなんでもいい)
                        self.noise_placeholder: noise, # Genratorの入力
                        self.labels_placeholder: label_real # Discriminatorの出力ラベル(本物)
                    })

                # 本物ラベルでGeneratorを学習
                g_loss, g_acc = self.sess.run(self.train_gen_ops)

                # エポック毎の結果
                d_loss_epoch += d_loss
                d_acc_epoch += d_acc
                g_loss_epoch += g_loss
                g_acc_epoch += g_acc

                # 進捗の表示
                sys.stdout.write(
                    '\repoch:{:d}  iter:{:d}   [D loss: {:f}, acc: {:.2f}%] [G loss: {:f}, acc: {:.2f}%]   '.format(
                        epoch + 1, iter + 1, d_loss, 100 * d_acc, g_loss, 100 * g_acc))
                sys.stdout.flush()

            # ミニバッチ毎の結果を平均
            d_loss_epoch /= self.num_batches
            d_acc_epoch /= self.num_batches
            g_loss_epoch /= self.num_batches
            g_acc_epoch /= self.num_batches

            epoch_time = time.time() - start_epoch

            # エポックの結果を表示
            sys.stdout.write(
                '\repoch:{:d}  iter:{:d}   [D loss: {:f}, acc: {:.2f}%] [G loss: {:f}, acc: {:.2f}%]   time: {:f}\n'.format(
                    epoch + 1, iter + 1, d_loss_epoch, 100 * d_acc_epoch, g_loss_epoch, 100 * g_acc_epoch, epoch_time))
            sys.stdout.flush()

            # Generatorの出力を確認
            if (epoch + 1) % 10 == 0:
                noise = np.random.normal(0, 1, (self.batch_size, self.z_dim)).astype(np.float32)
                self.sess.run(
                    self.iterator_init,
                    feed_dict={
                        self.images_placeholder: image_real, # Discriminatorの入力(使わないのでなんでもいい)
                        self.noise_placeholder: noise, # Genratorの入力
                        self.labels_placeholder: label_real # Discriminatorの出力ラベル(使わないのでなんでもいい)
                    })
                image_fake = self.sess.run(self.output_gen_ops)
                self.show_images(image_fake)

with strategy.scope()の後に実行していく以外は通常の低レベルAPIの書き方と変わりません。
TensorFlowの低レベルAPIに慣れてない人にとっては長ったらしく感じますが、sess.run(iterator_init)sess.run(train_ops)の2つを合わせたものがModel.train_on_batch()に相当します。

また、GeneratorとDiscriminatorの学習を切り替える際にIteratorを初期化し、入力データと正解ラベルを指定します。これは入力データと正解ラベルをGeneratorの学習(本物)、Generatorの学習(偽物)、Discriminatorの学習で変更する必要があるためです。
GANのように入力データと正解ラベルを切り替える必要がない場合は、各エポックの学習前にIteratorを初期化しておくだけでOKです。Dataset APIを使った入力は色々なやり方があるので、調べてみるといいかもしれません。

実際に学習して結果を確認

Google Colabで実行して結果を確認してみます。

epoch:1  iter:117   [D loss: 0.699009, acc: 49.88%] [G loss: 0.513983, acc: 99.87%]   time: 15.332689
epoch:2  iter:117   [D loss: 0.719187, acc: 45.43%] [G loss: 0.630630, acc: 99.88%]   time: 13.916150
epoch:3  iter:117   [D loss: 0.711329, acc: 43.12%] [G loss: 0.658984, acc: 99.49%]   time: 14.093864
epoch:4  iter:117   [D loss: 0.707588, acc: 38.28%] [G loss: 0.673343, acc: 95.34%]   time: 13.517264
epoch:5  iter:117   [D loss: 0.705363, acc: 35.09%] [G loss: 0.679608, acc: 89.22%]   time: 13.370727
epoch:6  iter:117   [D loss: 0.704420, acc: 31.46%] [G loss: 0.683517, acc: 82.51%]   time: 12.869707
epoch:7  iter:117   [D loss: 0.703543, acc: 28.16%] [G loss: 0.686726, acc: 75.82%]   time: 12.502573
epoch:8  iter:117   [D loss: 0.702794, acc: 28.11%] [G loss: 0.686894, acc: 78.77%]   time: 12.119865
epoch:9  iter:117   [D loss: 0.702125, acc: 26.32%] [G loss: 0.688055, acc: 76.48%]   time: 12.938859
epoch:10  iter:117   [D loss: 0.701599, acc: 24.18%] [G loss: 0.689229, acc: 73.12%]   time: 13.130902
epoch:11  iter:117   [D loss: 0.701109, acc: 21.90%] [G loss: 0.690466, acc: 67.27%]   time: 12.030408
epoch:12  iter:117   [D loss: 0.700516, acc: 21.28%] [G loss: 0.690808, acc: 65.58%]   time: 12.857735
epoch:13  iter:117   [D loss: 0.700026, acc: 20.26%] [G loss: 0.691294, acc: 63.43%]   time: 12.776606
epoch:14  iter:117   [D loss: 0.699605, acc: 19.71%] [G loss: 0.691599, acc: 62.43%]   time: 12.512784
epoch:15  iter:117   [D loss: 0.699337, acc: 18.98%] [G loss: 0.692067, acc: 59.06%]   time: 12.459395
epoch:16  iter:117   [D loss: 0.699040, acc: 18.98%] [G loss: 0.691960, acc: 60.51%]   time: 12.932993
epoch:17  iter:117   [D loss: 0.698658, acc: 19.96%] [G loss: 0.691851, acc: 61.94%]   time: 12.733138
epoch:18  iter:117   [D loss: 0.698368, acc: 19.76%] [G loss: 0.691887, acc: 62.32%]   time: 12.300640
epoch:19  iter:117   [D loss: 0.698112, acc: 19.10%] [G loss: 0.692205, acc: 59.66%]   time: 13.295983
epoch:20  iter:117   [D loss: 0.697834, acc: 19.20%] [G loss: 0.692334, acc: 58.76%]   time: 12.488116

うまく実行できていそうです。
100エポック学習後のGeneratorの出力を見てみましょう。
mnist_epoch_100.png
学習エポックが少ないので十分学習できているとは言い難いですが、とりあえずうまく学習が進んでいそうです。

GPUと比較

GPUで実行した場合の結果も見ておきましょう。

epoch:1  iter:117   [D loss: 0.697059, acc: 49.96%] [G loss: 0.486593, acc: 99.92%]   time: 10.189113
epoch:2  iter:117   [D loss: 0.721437, acc: 49.37%] [G loss: 0.616008, acc: 100.00%]   time: 9.618804
epoch:3  iter:117   [D loss: 0.712062, acc: 49.01%] [G loss: 0.643090, acc: 100.00%]   time: 9.653123
epoch:4  iter:117   [D loss: 0.708121, acc: 47.38%] [G loss: 0.662188, acc: 99.95%]   time: 9.692954
epoch:5  iter:117   [D loss: 0.705938, acc: 46.76%] [G loss: 0.668273, acc: 99.72%]   time: 9.735072
epoch:6  iter:117   [D loss: 0.704378, acc: 43.46%] [G loss: 0.677067, acc: 98.02%]   time: 9.694915
epoch:7  iter:117   [D loss: 0.703626, acc: 41.36%] [G loss: 0.679792, acc: 96.66%]   time: 9.680238
epoch:8  iter:117   [D loss: 0.702828, acc: 37.18%] [G loss: 0.683268, acc: 92.13%]   time: 9.715913
epoch:9  iter:117   [D loss: 0.702190, acc: 33.08%] [G loss: 0.685683, acc: 87.84%]   time: 9.769355
epoch:10  iter:117   [D loss: 0.701563, acc: 31.81%] [G loss: 0.686774, acc: 87.01%]   time: 9.807986
epoch:11  iter:117   [D loss: 0.701012, acc: 30.47%] [G loss: 0.687479, acc: 86.72%]   time: 9.739087
epoch:12  iter:117   [D loss: 0.700570, acc: 29.01%] [G loss: 0.688042, acc: 85.67%]   time: 9.820838
epoch:13  iter:117   [D loss: 0.700140, acc: 28.27%] [G loss: 0.688463, acc: 85.74%]   time: 9.707165
epoch:14  iter:117   [D loss: 0.699699, acc: 27.34%] [G loss: 0.688918, acc: 83.37%]   time: 9.731362
epoch:15  iter:117   [D loss: 0.699300, acc: 26.80%] [G loss: 0.689223, acc: 84.06%]   time: 9.768835
epoch:16  iter:117   [D loss: 0.698917, acc: 26.54%] [G loss: 0.689530, acc: 83.29%]   time: 9.927150
epoch:17  iter:117   [D loss: 0.698575, acc: 25.80%] [G loss: 0.689942, acc: 81.63%]   time: 9.874165
epoch:18  iter:117   [D loss: 0.698268, acc: 25.14%] [G loss: 0.690354, acc: 78.29%]   time: 9.926621
epoch:19  iter:117   [D loss: 0.698021, acc: 24.95%] [G loss: 0.690568, acc: 78.01%]   time: 9.896298
epoch:20  iter:117   [D loss: 0.697783, acc: 24.95%] [G loss: 0.690692, acc: 78.20%]   time: 9.907461

あれ…? GPUの方が若干速くない…?

今回は画像サイズが28x28x1と小さく、モデルサイズもGeneratorが4層、Discriminatorが3層と比較的小さかったため、TPUの恩恵をあまり得られなかったのかもしれません。
もう少し本格的なモデルで検証してみる必要がありそうです。

こちらも100エポック学習後のGeneratorの出力を見てみます。
mnist_epoch_100.png
大丈夫そうですね。TPUとの差もなさそうです。

まとめ

今回はTPUでGANを学習させる時の実装例を見ていきました。この書き方を応用すれば、GANに限らず任意のネットワークをTPUで学習させることができそうです。
学習速度については、モデルサイズが小さかったこともあり、改めて検証してみる必要があります。

Google Colabを使えば、いくつかの制限はあるもののTPUを無料で手軽に試せます。ぜひこの機会に既存のモデルやらをTPU対応して試してみてください。

今回のコードや結果はgithubに載せているので、参考にしたり、コピペで試したり、ご自由にどうぞ。

参考

TPU関連
Custom training with TPUs | TensorFlow Core | TensorFlow
tf.distribute.Strategy with Training Loops | TensorFlow Core | TensorFlow
Distributed Training in TensorFlow | TensorFlow Core | TensorFlow
TensorFlow1.14以降のTPUの取り扱い方について

GAN関連
Deep Convolutional Generative Adversarial Network | TensorFlow Core | TensorFlow
今さら聞けないGAN(1) 基本構造の理解

mgmk2
音源分離×DeepLearningの研究してましたが画像に浮気中… 信号処理/機械学習/python/TensorFlow/Matlab/C++
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away