はじめに
TensorFlowやKerasでGANを訓練する例自体はいくつもあるのですが、そのままTPUで訓練しようとするとうまく動かなかったりアホみたいに遅かったりで実用に耐えません。そこで、TensorFlowの低レベルAPIを用いてTPUに対応したGANを実装したいと思います。
基本的にはTPUでCustom Loopを動かすためのTensorFlow公式チュートリアルを踏まえた話になりますが、自分が確認しただけでも3個のチュートリアルがあり[1][2][3]、それぞれ書き方が微妙に異なって動いたり動かなかったりします。なので、ここで一度まとめて整理しようってわけです。
今回はGANに限定して話しますが、紹介する書き方はGAN以外のどんなネットワークにも適用できると思います。
環境
- Google Colabratory
- TensorFlow 1.14.0
TPU対応の基本的な書き方
KerasでTPUに対応した書き方についてはすでにQiitaに記事が投稿されています。
TensorFlow1.14以降のTPUの取り扱い方について
strategy.scope()
の後で今まで通りモデル定義や学習を行えばよい形になっています。KerasでModel.compile()
してModel.fit()
で学習、とかならこれだけでOKです。簡単ですね。
ただし、ちょっと外れた書き方をしたり、複雑なことをしようとすると途端にわけわからなくなります。
TensorFlow+TPUでGANを実装する時の問題点
strategy.scope()
を使えば通常のCNNやらはTPUで学習できるのですが、GANの場合は具体的にどんな問題点があるのか、ここで一度まとめておきます。興味ない人は飛ばしてください。
問題点1. Keras+TPUではtrain_on_batchが使えない
GANではミニバッチ毎にDiscriminatorとGeneratorの2つのネットワークを交互に訓練していきます。
そのため、TensorFlowの高レベルAPIであるKerasではModel.train_on_batch()
を使うのが定石ですが、2019年7月現在の最新版であるTensorFlow1.14.0ではKerasのModel.train_on_batch()
はTPUに対応していません。
1.13.1ではModel.train_on_batch()
が使えていましたが、1.14.0でTPU対応のモデルの書き方が変わり、暫定的に未対応となったのだと思います。
じゃあ通常のModel.fit()
をミニバッチ毎に使えばいいじゃん!って思いつきますが、動きはするもののめちゃくちゃ遅くて使い物にはなりません。
今後、Model.train_on_batch()
が再びTPU対応するとは思いますが、いつ対応するかはわからないので、今のところは別の書き方が必要になります。
問題点2. TensorFlowのTF-GANは使いにくい
これはどちらかと言うと個人的な理由になるかもしれません。
TensorFlowにはTF-GANというGANを手軽に試せるAPIが用意されています。TensorFlowの中レベルAPIであるEstimatorを使ったGANを作れるのですが、
- コスト関数を自分で書けない
-
tf.contrib
扱いなので、今後削除される可能性が小さくない - そもそも分かりにくい
といった感じであまり積極的に利用したくはないかなと思っています。特にGANの仕組みをコードを書きながら理解したい、自分で色々弄りたい、って人には向きません。
TPUに対応したGANの書き方
ここからが本題です。コード全体は少し長いので折り畳んでいます。githubには結果も載せているので合わせてどうぞ。
コード全体はこちら
import sys, os
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.models import Sequential
class GAN(object):
def __init__(self):
self.z_dim = 100 # 潜在変数の次元
self.image_shape = (28, 28, 1) # 画像のサイズ
self.noise_shape = (self.z_dim,) # ノイズのサイズ
self.epochs = 100 # 学習回数
self.batch_size = 512 # バッチサイズ
# データセットのロード
self.X_train = self.load_dataset()
self.num_batches = self.X_train.shape[0] // self.batch_size # ミニバッチの数
print('number of batches:', self.num_batches)
# TPU対応のおまじない1
tf.keras.backend.clear_session()
tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
self.tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
self.strategy = tf.contrib.distribute.TPUStrategy(self.tpu_cluster_resolver)
# ここからTPU対応のモデルやらを書いていく
with self.strategy.scope():
# Discriminatorの定義
self.discriminator = self.build_discriminator()
self.optimizer_disc = tf.train.AdamOptimizer(2.0e-4, 0.5) # Discriminator用のOptimizer
self.var_disc = self.discriminator.trainable_variables # Discriminatorの重み
# Generatorの定義
self.generator = self.build_generator()
self.optimizer_gen = tf.train.AdamOptimizer(2.0e-4, 0.5) # Generator用のOptimizer
self.var_gen = self.generator.trainable_variables # Generatorの重み
# データセットの入力用placeholder
self.images_placeholder = tf.placeholder(tf.float32, [None, *self.image_shape])
self.noise_placeholder = tf.placeholder(tf.float32, [None, *self.noise_shape])
self.labels_placeholder = tf.placeholder(tf.float32, [None, 1])
# Dataset APIで入力パイプラインを定義
dataset = tf.data.Dataset.from_tensor_slices(
(self.images_placeholder,
self.noise_placeholder,
self.labels_placeholder
))
dataset = dataset.repeat()
dataset = dataset.batch(self.batch_size, drop_remainder=True) # TPUではdrop_remainder=Trueが必須
# DatasetをTPU用のDatasetに変換
dist_dataset = self.strategy.experimental_distribute_dataset(dataset)
# iteratorを定義
input_iterator = dist_dataset.make_initializable_iterator()
self.iterator_init = input_iterator.initialize()
# 学習等のopsを定義
inputs = input_iterator.get_next() # ネットワークの入力
self.train_disc_ops = self.train_step_disc(inputs) # Discriminatorの学習
self.train_gen_ops = self.train_step_gen(inputs) # Generatorの学習
self.output_gen_ops = self.output_images_gen(inputs) # Generatorの出力
# TPU対応のおまじない2
tf.contrib.distribute.initialize_tpu_system(self.tpu_cluster_resolver)
config = tf.ConfigProto()
config.allow_soft_placement = True
cluster_spec = self.tpu_cluster_resolver.cluster_spec()
if cluster_spec:
config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
# Sessionの定義
self.sess = tf.Session(
target=self.tpu_cluster_resolver.master(),
config=config
)
# 変数の初期化
self.sess.run(tf.global_variables_initializer())
def load_dataset(self):
# mnistデータの読み込み
(X_train, _), (_, _) = mnist.load_data()
# 値を-1 to 1に規格化
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
return X_train
def build_discriminator(self):
# discriminatorモデル
# kerasのSequentialを使っているが、Functional APIでもtensorflowの低レベルAPIでもたぶん大丈夫
layers_disc = []
layers_disc.append(
Conv2D(16, (5, 5), strides=(2, 2), padding='same', input_shape=self.image_shape))
layers_disc.append(LeakyReLU(alpha=0.2))
layers_disc.append(
Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
layers_disc.append(LeakyReLU(alpha=0.2))
layers_disc.append(Flatten())
layers_disc.append(Dense(1))
discriminator = Sequential(layers_disc)
return discriminator
def build_generator(self):
# Generatorモデル
# kerasのSequentialを使っているが、Functional APIでもtensorflowの低レベルAPIでもたぶん大丈夫
layers_gen = []
layers_gen.append(Dense(7 * 7 * 256, use_bias=False, input_shape=self.noise_shape))
layers_gen.append(BatchNormalization(momentum=0.8))
layers_gen.append(LeakyReLU(alpha=0.2))
layers_gen.append(Reshape((7, 7, 256)))
layers_gen.append(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
layers_gen.append(BatchNormalization(momentum=0.8))
layers_gen.append(LeakyReLU(alpha=0.2))
layers_gen.append(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
layers_gen.append(BatchNormalization(momentum=0.8))
layers_gen.append(LeakyReLU(alpha=0.2))
layers_gen.append(
Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', activation='tanh'))
generator = Sequential(layers_gen)
return generator
def train_step_disc(self, dist_inputs):
# Discriminatorに対して
# コストを計算して逆伝播法で重みを更新する
def step_fn(inputs):
features, _, labels = inputs # 入力データ
logits = self.discriminator(features) # Discriminatorの出力
# コスト関数と重み更新
cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
loss = tf.reduce_sum(cross_entropy) / self.batch_size # reduce_meanは使わない方がいい
train_op_disc = self.optimizer_disc.minimize(loss, var_list=self.var_disc) # discriminatorの重みのみ更新する
# 精度
logits_bool = tf.cast(tf.greater_equal(logits, 0), tf.float32)
acc = tf.reduce_sum(1.0 - tf.abs(labels - logits_bool)) / self.batch_size
# 必ずtf.control_dependenciesを使うこと
with tf.control_dependencies([train_op_disc]):
return tf.identity(loss), tf.identity(acc)
# TPUコア毎にstep_fnを実行して結果を出力
per_replica_losses, per_replica_accs = self.strategy.experimental_run_v2(step_fn, args=(dist_inputs,))
# TPUコア毎のコストと精度をまとめる
# tf.distribute.ReduceOp.SUMはtf.reduce_sum
# tf.distribute.ReduceOp.MEANはtf.reduce_meanに対応
# MEANは正しい結果になっているかちょっと自信ないので、SUMにしている
losses = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
accs = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_accs, axis=None)
return losses, accs
def output_images_gen(self, dist_inputs):
# Generatorの出力画像を得る
def step_fn(inputs):
_, noises, _ = inputs # 入力データ
return self.generator(noises, training=False) # GeneratorにBatchNormalizationを入れている場合はtraining=Falseを指定
# TPUコア毎にstep_fnを実行して結果を出力
gen_output = self.strategy.experimental_run_v2(step_fn, args=(dist_inputs,))
# TPUコア毎の結果を連結
gen_output = tf.concat(gen_output.values, axis=0)
return gen_output
def train_step_gen(self, dist_inputs):
# Generatorに対して
# コストを計算して逆伝播法で重みを更新する
def step_fn(inputs):
_, noises, labels = inputs # 入力データ
features = self.generator(noises, training=True) # GeneratorにBatchNormalizationを入れている場合はtraining=Trueを指定
logits = self.discriminator(features)
# コスト関数と重み更新
cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
loss = tf.reduce_sum(cross_entropy) / self.batch_size
train_op_gen = self.optimizer_gen.minimize(loss, var_list=self.var_gen) # Generatorの重みのみ更新
# 精度
logits_bool = tf.cast(tf.greater_equal(logits, 0), tf.float32)
acc = tf.reduce_sum(1.0 - tf.abs(labels - logits_bool)) / self.batch_size
# BatchNormalizationの平均と分散の更新
# GeneratorにBatchNormalizationを入れている場合は必須
update_ops = self.generator.get_updates_for(None) + self.generator.get_updates_for(noises)
# 必ずtf.control_dependenciesを使うこと
# BatchNormalizationを使っている場合はupdate_opsも一緒に入れる
with tf.control_dependencies([train_op_gen, *update_ops]):
return tf.identity(loss), tf.identity(acc)
# TPUコア毎にstep_fnを実行して結果を出力
per_replica_losses, per_replica_accs = self.strategy.experimental_run_v2(step_fn, args=(dist_inputs,))
# TPUコア毎のコストと精度をまとめる
# tf.distribute.ReduceOp.SUMはtf.reduce_sum
# tf.distribute.ReduceOp.MEANはtf.reduce_meanに対応
# MEANは正しい結果になっているかちょっと自信ないので、SUMにしている
losses = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
accs = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_accs, axis=None)
return losses, accs
def fit(self):
# TPU上でDiscriminatorとGeneratorを更新する
with self.strategy.scope():
start_fit = time.time()
noise = np.random.normal(0, 1, (self.batch_size, self.z_dim)).astype(np.float32) # Generatorの入力
image_real = self.X_train[:self.batch_size] # Discriminatorの入力
label_real = np.ones((self.batch_size, 1), np.float32) # Discriminatorの出力ラベル
# 入力パイプラインを初期化
self.sess.run(
self.iterator_init,
feed_dict={
self.images_placeholder: image_real,
self.noise_placeholder: noise,
self.labels_placeholder: label_real
})
# 学習前のGeneratorの出力を確認
image_fake = self.sess.run(self.output_gen_ops)
self.show_images(image_fake, epoch=0)
# 学習開始
for epoch in range(self.epochs):
# 各エポックのコストと精度
d_loss_epoch = 0
d_acc_epoch = 0
g_loss_epoch = 0
g_acc_epoch = 0
start_epoch = time.time()
# 各エポックの学習前に学習データをシャッフル
np.random.shuffle(self.X_train)
# ミニバッチ学習
for iter in range(self.num_batches):
noise = np.random.normal(0, 1, (self.batch_size, self.z_dim)).astype(np.float32) # Generatorの入力
image_real = self.X_train[iter * self.batch_size:(iter + 1) * self.batch_size] # Discriminatorの入力(本物)
label_real = np.ones((self.batch_size, 1), np.float32) # Discriminatorの出力ラベル(本物)
label_fake = np.zeros((self.batch_size, 1), np.float32) # Discriminatorの出力ラベル(偽物)
#---------------------
# Discriminatorの学習
#---------------------
# iteratorを初期化
self.sess.run(
self.iterator_init,
feed_dict={
self.images_placeholder: image_real, # Discriminatorの入力(本物)
self.noise_placeholder: noise, # Genratorの入力
self.labels_placeholder: label_real # Discriminatorの出力ラベル(本物)
})
# 偽物画像を生成
image_fake = self.sess.run(self.output_gen_ops)
# 本物画像でDiscriminatorを学習
d_loss_real, d_acc_real = self.sess.run(self.train_disc_ops)
# Discriminatorに偽物画像を与えるため
# iteratorを初期化
self.sess.run(
self.iterator_init,
feed_dict={
self.images_placeholder: image_fake, # Discriminatorの入力(偽物)
self.noise_placeholder: noise, # Genratorの入力(使わないのでなんでもいい)
self.labels_placeholder: label_fake # Discriminatorの出力ラベル(偽物)
})
# 偽物画像でDiscriminatorを学習
d_loss_fake, d_acc_fake = self.sess.run(self.train_disc_ops)
# 本物画像の結果と偽物画像の結果を平均
d_loss = 0.5 * (d_loss_real + d_loss_fake)
d_acc = 0.5 * (d_acc_real + d_acc_fake)
#---------------------
# Generatorの学習
#---------------------
# iteratorを初期化
self.sess.run(
self.iterator_init,
feed_dict={
self.images_placeholder: image_real, # Discriminatorの入力(使わないのでなんでもいい)
self.noise_placeholder: noise, # Genratorの入力
self.labels_placeholder: label_real # Discriminatorの出力ラベル(本物)
})
# 本物ラベルでGeneratorを学習
g_loss, g_acc = self.sess.run(self.train_gen_ops)
# エポック毎の結果
d_loss_epoch += d_loss
d_acc_epoch += d_acc
g_loss_epoch += g_loss
g_acc_epoch += g_acc
# 進捗の表示
sys.stdout.write(
'\repoch:{:d} iter:{:d} [D loss: {:f}, acc: {:.2f}%] [G loss: {:f}, acc: {:.2f}%] '.format(
epoch + 1, iter + 1, d_loss, 100 * d_acc, g_loss, 100 * g_acc))
sys.stdout.flush()
# ミニバッチ毎の結果を平均
d_loss_epoch /= self.num_batches
d_acc_epoch /= self.num_batches
g_loss_epoch /= self.num_batches
g_acc_epoch /= self.num_batches
epoch_time = time.time() - start_epoch
# エポックの結果を表示
sys.stdout.write(
'\repoch:{:d} iter:{:d} [D loss: {:f}, acc: {:.2f}%] [G loss: {:f}, acc: {:.2f}%] time: {:f}\n'.format(
epoch + 1, iter + 1, d_loss_epoch, 100 * d_acc_epoch, g_loss_epoch, 100 * g_acc_epoch, epoch_time))
sys.stdout.flush()
# Generatorの出力を確認
if (epoch + 1) % 10 == 0:
noise = np.random.normal(0, 1, (self.batch_size, self.z_dim)).astype(np.float32)
self.sess.run(
self.iterator_init,
feed_dict={
self.images_placeholder: image_real, # Discriminatorの入力(使わないのでなんでもいい)
self.noise_placeholder: noise, # Genratorの入力
self.labels_placeholder: label_real # Discriminatorの出力ラベル(使わないのでなんでもいい)
})
image_fake = self.sess.run(self.output_gen_ops)
self.show_images(image_fake, epoch=epoch + 1)
def show_images(self, images, epoch):
# 出力画像を確認
fig = plt.figure(figsize=(4, 4))
for i in range(16):
plt.subplot(4, 4, i + 1)
plt.imshow(images[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.axis('off')
fig.suptitle('epoch: {:}'.format(epoch))
fig.savefig('mnist_epoch_{:}.png'.format(epoch))
plt.show()
if __name__ == '__main__':
G = GAN()
G.fit()
TPUStrategyを定義
モデル定義の前にTPUStrategyを定義していきましょう。これはGANに関わらずTPU対応のためには必須です。
# TPU対応のおまじない1
tf.keras.backend.clear_session()
tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
self.tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
self.strategy = tf.contrib.distribute.TPUStrategy(self.tpu_cluster_resolver)
# ここからTPU対応のモデルやらを書いていく
with self.strategy.scope():
# モデル定義やらモデル実行やら
with strategy.scope():
の後にモデル定義やらモデル実行のコードを書いていきます。
モデル定義
with strategy.scope()
の後にモデル定義を書いていきましょう。
with self.strategy.scope():
# Discriminatorの定義
self.discriminator = self.build_discriminator()
self.optimizer_disc = tf.train.AdamOptimizer(2.0e-4, 0.5) # Discriminator用のOptimizer
self.var_disc = discriminator.trainable_variables # Discriminatorの重み
# Generatorの定義
self.generator = self.build_generator()
self.optimizer_gen = tf.train.AdamOptimizer(2.0e-4, 0.5) # Generator用のOptimizer
self.var_gen = generator.trainable_variables # Generatorの重み
あとでDiscriminatorとGeneratorを個別に更新するために、それぞれの重みを集めておきます。
この部分はTPU未対応でも変わらないと思います。
DiscriminatorとGeneratorの構造はTensorFlow公式のDCGANチュートリアルをほぼそのまま使っています。
KerasのSequentialで定義しておきますが、KerasのFunctional APIでもTensorFlowの低レベルAPIでも大丈夫なはずです。
モデルの構造はこちら
def build_discriminator(self):
# discriminatorモデル
# kerasのSequentialを使っているが、Functional APIでもtensorflowの低レベルAPIでもたぶん大丈夫
layers_disc = []
layers_disc.append(
Conv2D(16, (5, 5), strides=(2, 2), padding='same', input_shape=self.image_shape))
layers_disc.append(LeakyReLU(alpha=0.2))
layers_disc.append(
Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
layers_disc.append(LeakyReLU(alpha=0.2))
layers_disc.append(Flatten())
layers_disc.append(Dense(1))
discriminator = Sequential(layers_disc)
return discriminator
def build_generator(self):
# Generatorモデル
# kerasのSequentialを使っているが、Functional APIでもtensorflowの低レベルAPIでもたぶん大丈夫
layers_gen = []
layers_gen.append(Dense(7 * 7 * 256, use_bias=False, input_shape=self.noise_shape))
layers_gen.append(BatchNormalization(momentum=0.8))
layers_gen.append(LeakyReLU(alpha=0.2))
layers_gen.append(Reshape((7, 7, 256)))
layers_gen.append(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
layers_gen.append(BatchNormalization(momentum=0.8))
layers_gen.append(LeakyReLU(alpha=0.2))
layers_gen.append(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
layers_gen.append(BatchNormalization(momentum=0.8))
layers_gen.append(LeakyReLU(alpha=0.2))
layers_gen.append(
Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', activation='tanh'))
generator = Sequential(layers_gen)
return generator
データセットの入力パイプラインを定義
今回はDataset APIを使ってデータを入力していきます。ここはモデル定義と前後しても大丈夫です。
with self.strategy.scope():
# データセットの入力用placeholder
self.images_placeholder = tf.placeholder(tf.float32, [None, *self.image_shape])
self.noise_placeholder = tf.placeholder(tf.float32, [None, *self.noise_shape])
self.labels_placeholder = tf.placeholder(tf.float32, [None, 1])
# Dataset APIで入力パイプラインを定義
dataset = tf.data.Dataset.from_tensor_slices(
(self.images_placeholder,
self.noise_placeholder,
self.labels_placeholder
))
dataset = dataset.repeat()
dataset = dataset.batch(self.batch_size, drop_remainder=True) # TPUではdrop_remainder=Trueが必須
# DatasetをTPU用のDatasetに変換
dist_dataset = self.strategy.experimental_distribute_dataset(dataset)
# iteratorを定義
input_iterator = dist_dataset.make_initializable_iterator()
self.iterator_init = input_iterator.initialize()
Discriminatorの入力を本物画像と偽物画像で入れ替えるために、placeholderからDatasetを作って後から入力を変更できるようにしておきます。
ここで重要なのは通常のDatasetを定義した後に、strategy.experimental_distribute_dataset()
を使ってTPU用のDatasetに変換するところです。Iterator
はこのTPU対応のDatasetから作っていきます。
コスト関数と重み更新を定義
コスト関数と重み更新を定義していきましょう。
with self.strategy.scope():
# 学習等のopsを定義
inputs = input_iterator.get_next() # ネットワークの入力
self.train_disc_ops = self.train_step_disc(inputs) # Discriminatorの学習
self.train_gen_ops = self.train_step_gen(inputs) # Generatorの学習
self.output_gen_ops = self.output_images_gen(inputs) # Generatorの出力
それぞれ関数化していますが、中身はだいたい同じなのでtrain_step_gan()
を例に見ていきます。
def train_step_gen(self, dist_inputs):
# Generatorに対して
# コストを計算して逆伝播法で重みを更新する
def step_fn(inputs):
_, noises, labels = inputs # 入力データ
features = self.generator(noises, training=True) # GeneratorにBatchNormalizationを入れている場合はtraining=Trueを指定
logits = self.discriminator(features)
# コスト関数と重み更新
cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
loss = tf.reduce_sum(cross_entropy) / self.batch_size
train_op_gen = self.optimizer_gen.minimize(loss, var_list=self.var_gen) # Generatorの重みのみ更新
# 精度
logits_bool = tf.cast(tf.greater_equal(logits, 0), tf.float32)
acc = tf.reduce_sum(1.0 - tf.abs(labels - logits_bool)) / self.batch_size
# BatchNormalizationの平均と分散の更新
# GeneratorにBatchNormalizationを入れている場合は必須
update_ops = self.generator.get_updates_for(None) + self.generator.get_updates_for(noises)
# 必ずtf.control_dependenciesを使うこと
# BatchNormalizationを使っている場合はupdate_opsも一緒に入れる
with tf.control_dependencies([train_op_gen, *update_ops]):
return tf.identity(loss), tf.identity(acc)
# TPUコア毎にstep_fnを実行して結果を出力
per_replica_losses, per_replica_accs = self.strategy.experimental_run_v2(step_fn, args=(dist_inputs,))
# TPUコア毎のコストと精度をまとめる
# tf.distribute.ReduceOp.SUMはtf.reduce_sum
# tf.distribute.ReduceOp.MEANはtf.reduce_meanに対応
# MEANは正しい結果になっているかちょっと自信ないので、SUMにしている
losses = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
accs = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_accs, axis=None)
return losses, accs
重要なのは、コスト関数と重み更新はstep_fn()
と関数内で定義して、それをstrategy.experimental_run_V2()
に渡して実行することです。こうすることでTPUコア毎の結果を集めて、まとめて重み更新できるようです。
返り値はPerReplica
オブジェクトなので、PerReplica.values
で値を取り出すか、strategy.reduce()
で合計か平均を算出します。
もうひとつ重要なことは、keras.layers.BatchNormalization
を使っている場合は、学習時に平均と分散の更新を行うために、optimizer.minimize()
とは別にupdate_ops
を定義して、同時に実行する必要があることです。
また、BatchNormalization
が入っているModel
をcall
するときにtraining=True
を指定します。
これらを除いてしまうとBatchNormalization
がうまく学習できません。一方で、推論時にはtraining=False
とするのみでupdate_ops
は必要ありません。
tf.Session
を定義して初期化
モデル定義は終わりました。後はtf.Session
を定義して、重みを初期化しましょう。
with self.strategy.scope():
# TPU対応のおまじない2
tf.contrib.distribute.initialize_tpu_system(self.tpu_cluster_resolver)
config = tf.ConfigProto()
config.allow_soft_placement = True
cluster_spec = self.tpu_cluster_resolver.cluster_spec()
if cluster_spec:
config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
# Sessionの定義
self.sess = tf.Session(
target=self.tpu_cluster_resolver.master(),
config=config
)
# 変数の初期化
self.sess.run(tf.global_variables_initializer())
ここで重要なことは、cluster_spec
の部分とtf.Session()
の中でtarget=tpu_cluster_resolver.master()
とすることです。
これがないとエラー吐いて実行できません。
学習を実行
準備は整いました。学習を行うコードを書いていきましょう。長いのと、大したポイントはないので折り畳んでいます。
学習用コードはこちら
def fit(self):
# TPU上でDiscriminatorとGeneratorを更新する
with self.strategy.scope():
start_fit = time.time()
noise = np.random.normal(0, 1, (self.batch_size, self.z_dim)).astype(np.float32) # Generatorの入力
image_real = self.X_train[:self.batch_size] # Discriminatorの入力
label_real = np.ones((self.batch_size, 1), np.float32) # Discriminatorの出力ラベル
# 入力パイプラインを初期化
self.sess.run(
self.iterator_init,
feed_dict={
self.images_placeholder: image_real,
self.noise_placeholder: noise,
self.labels_placeholder: label_real
})
# 学習前のGeneratorの出力を確認
image_fake = self.sess.run(self.output_gen_ops)
self.show_images(image_fake)
# 学習開始
for epoch in range(self.epochs):
# 各エポックのコストと精度
d_loss_epoch = 0
d_acc_epoch = 0
g_loss_epoch = 0
g_acc_epoch = 0
start_epoch = time.time()
# 各エポックの学習前に学習データをシャッフル
np.random.shuffle(self.X_train)
# ミニバッチ学習
for iter in range(self.num_batches):
noise = np.random.normal(0, 1, (self.batch_size, self.z_dim)).astype(np.float32) # Generatorの入力
image_real = self.X_train[iter * self.batch_size:(iter + 1) * self.batch_size] # Discriminatorの入力(本物)
label_real = np.ones((self.batch_size, 1), np.float32) # Discriminatorの出力ラベル(本物)
label_fake = np.zeros((self.batch_size, 1), np.float32) # Discriminatorの出力ラベル(偽物)
#---------------------
# Discriminatorの学習
#---------------------
# iteratorを初期化
self.sess.run(
self.iterator_init,
feed_dict={
self.images_placeholder: image_real, # Discriminatorの入力(本物)
self.noise_placeholder: noise, # Genratorの入力
self.labels_placeholder: label_real # Discriminatorの出力ラベル(本物)
})
# 偽物画像を生成
image_fake = self.sess.run(self.output_gen_ops)
# 本物画像でDiscriminatorを学習
d_loss_real, d_acc_real = self.sess.run(self.train_disc_ops)
# Discriminatorに偽物画像を与えるため
# iteratorを初期化
self.sess.run(
self.iterator_init,
feed_dict={
self.images_placeholder: image_fake, # Discriminatorの入力(偽物)
self.noise_placeholder: noise, # Genratorの入力(使わないのでなんでもいい)
self.labels_placeholder: label_fake # Discriminatorの出力ラベル(偽物)
})
# 偽物画像でDiscriminatorを学習
d_loss_fake, d_acc_fake = self.sess.run(self.train_disc_ops)
# 本物画像の結果と偽物画像の結果を平均
d_loss = 0.5 * (d_loss_real + d_loss_fake)
d_acc = 0.5 * (d_acc_real + d_acc_fake)
#---------------------
# Generatorの学習
#---------------------
# iteratorを初期化
self.sess.run(
self.iterator_init,
feed_dict={
self.images_placeholder: image_real, # Discriminatorの入力(使わないのでなんでもいい)
self.noise_placeholder: noise, # Genratorの入力
self.labels_placeholder: label_real # Discriminatorの出力ラベル(本物)
})
# 本物ラベルでGeneratorを学習
g_loss, g_acc = self.sess.run(self.train_gen_ops)
# エポック毎の結果
d_loss_epoch += d_loss
d_acc_epoch += d_acc
g_loss_epoch += g_loss
g_acc_epoch += g_acc
# 進捗の表示
sys.stdout.write(
'\repoch:{:d} iter:{:d} [D loss: {:f}, acc: {:.2f}%] [G loss: {:f}, acc: {:.2f}%] '.format(
epoch + 1, iter + 1, d_loss, 100 * d_acc, g_loss, 100 * g_acc))
sys.stdout.flush()
# ミニバッチ毎の結果を平均
d_loss_epoch /= self.num_batches
d_acc_epoch /= self.num_batches
g_loss_epoch /= self.num_batches
g_acc_epoch /= self.num_batches
epoch_time = time.time() - start_epoch
# エポックの結果を表示
sys.stdout.write(
'\repoch:{:d} iter:{:d} [D loss: {:f}, acc: {:.2f}%] [G loss: {:f}, acc: {:.2f}%] time: {:f}\n'.format(
epoch + 1, iter + 1, d_loss_epoch, 100 * d_acc_epoch, g_loss_epoch, 100 * g_acc_epoch, epoch_time))
sys.stdout.flush()
# Generatorの出力を確認
if (epoch + 1) % 10 == 0:
noise = np.random.normal(0, 1, (self.batch_size, self.z_dim)).astype(np.float32)
self.sess.run(
self.iterator_init,
feed_dict={
self.images_placeholder: image_real, # Discriminatorの入力(使わないのでなんでもいい)
self.noise_placeholder: noise, # Genratorの入力
self.labels_placeholder: label_real # Discriminatorの出力ラベル(使わないのでなんでもいい)
})
image_fake = self.sess.run(self.output_gen_ops)
self.show_images(image_fake)
with strategy.scope()
の後に実行していく以外は通常の低レベルAPIの書き方と変わりません。
TensorFlowの低レベルAPIに慣れてない人にとっては長ったらしく感じますが、sess.run(iterator_init)
とsess.run(train_ops)
の2つを合わせたものがModel.train_on_batch()
に相当します。
また、GeneratorとDiscriminatorの学習を切り替える際にIterator
を初期化し、入力データと正解ラベルを指定します。これは入力データと正解ラベルをGeneratorの学習(本物)、Generatorの学習(偽物)、Discriminatorの学習で変更する必要があるためです。
GANのように入力データと正解ラベルを切り替える必要がない場合は、各エポックの学習前にIterator
を初期化しておくだけでOKです。Dataset APIを使った入力は色々なやり方があるので、調べてみるといいかもしれません。
実際に学習して結果を確認
Google Colabで実行して結果を確認してみます。
epoch:1 iter:117 [D loss: 0.699009, acc: 49.88%] [G loss: 0.513983, acc: 99.87%] time: 15.332689
epoch:2 iter:117 [D loss: 0.719187, acc: 45.43%] [G loss: 0.630630, acc: 99.88%] time: 13.916150
epoch:3 iter:117 [D loss: 0.711329, acc: 43.12%] [G loss: 0.658984, acc: 99.49%] time: 14.093864
epoch:4 iter:117 [D loss: 0.707588, acc: 38.28%] [G loss: 0.673343, acc: 95.34%] time: 13.517264
epoch:5 iter:117 [D loss: 0.705363, acc: 35.09%] [G loss: 0.679608, acc: 89.22%] time: 13.370727
epoch:6 iter:117 [D loss: 0.704420, acc: 31.46%] [G loss: 0.683517, acc: 82.51%] time: 12.869707
epoch:7 iter:117 [D loss: 0.703543, acc: 28.16%] [G loss: 0.686726, acc: 75.82%] time: 12.502573
epoch:8 iter:117 [D loss: 0.702794, acc: 28.11%] [G loss: 0.686894, acc: 78.77%] time: 12.119865
epoch:9 iter:117 [D loss: 0.702125, acc: 26.32%] [G loss: 0.688055, acc: 76.48%] time: 12.938859
epoch:10 iter:117 [D loss: 0.701599, acc: 24.18%] [G loss: 0.689229, acc: 73.12%] time: 13.130902
epoch:11 iter:117 [D loss: 0.701109, acc: 21.90%] [G loss: 0.690466, acc: 67.27%] time: 12.030408
epoch:12 iter:117 [D loss: 0.700516, acc: 21.28%] [G loss: 0.690808, acc: 65.58%] time: 12.857735
epoch:13 iter:117 [D loss: 0.700026, acc: 20.26%] [G loss: 0.691294, acc: 63.43%] time: 12.776606
epoch:14 iter:117 [D loss: 0.699605, acc: 19.71%] [G loss: 0.691599, acc: 62.43%] time: 12.512784
epoch:15 iter:117 [D loss: 0.699337, acc: 18.98%] [G loss: 0.692067, acc: 59.06%] time: 12.459395
epoch:16 iter:117 [D loss: 0.699040, acc: 18.98%] [G loss: 0.691960, acc: 60.51%] time: 12.932993
epoch:17 iter:117 [D loss: 0.698658, acc: 19.96%] [G loss: 0.691851, acc: 61.94%] time: 12.733138
epoch:18 iter:117 [D loss: 0.698368, acc: 19.76%] [G loss: 0.691887, acc: 62.32%] time: 12.300640
epoch:19 iter:117 [D loss: 0.698112, acc: 19.10%] [G loss: 0.692205, acc: 59.66%] time: 13.295983
epoch:20 iter:117 [D loss: 0.697834, acc: 19.20%] [G loss: 0.692334, acc: 58.76%] time: 12.488116
うまく実行できていそうです。
100エポック学習後のGeneratorの出力を見てみましょう。
学習エポックが少ないので十分学習できているとは言い難いですが、とりあえずうまく学習が進んでいそうです。
GPUと比較
GPUで実行した場合の結果も見ておきましょう。
epoch:1 iter:117 [D loss: 0.697059, acc: 49.96%] [G loss: 0.486593, acc: 99.92%] time: 10.189113
epoch:2 iter:117 [D loss: 0.721437, acc: 49.37%] [G loss: 0.616008, acc: 100.00%] time: 9.618804
epoch:3 iter:117 [D loss: 0.712062, acc: 49.01%] [G loss: 0.643090, acc: 100.00%] time: 9.653123
epoch:4 iter:117 [D loss: 0.708121, acc: 47.38%] [G loss: 0.662188, acc: 99.95%] time: 9.692954
epoch:5 iter:117 [D loss: 0.705938, acc: 46.76%] [G loss: 0.668273, acc: 99.72%] time: 9.735072
epoch:6 iter:117 [D loss: 0.704378, acc: 43.46%] [G loss: 0.677067, acc: 98.02%] time: 9.694915
epoch:7 iter:117 [D loss: 0.703626, acc: 41.36%] [G loss: 0.679792, acc: 96.66%] time: 9.680238
epoch:8 iter:117 [D loss: 0.702828, acc: 37.18%] [G loss: 0.683268, acc: 92.13%] time: 9.715913
epoch:9 iter:117 [D loss: 0.702190, acc: 33.08%] [G loss: 0.685683, acc: 87.84%] time: 9.769355
epoch:10 iter:117 [D loss: 0.701563, acc: 31.81%] [G loss: 0.686774, acc: 87.01%] time: 9.807986
epoch:11 iter:117 [D loss: 0.701012, acc: 30.47%] [G loss: 0.687479, acc: 86.72%] time: 9.739087
epoch:12 iter:117 [D loss: 0.700570, acc: 29.01%] [G loss: 0.688042, acc: 85.67%] time: 9.820838
epoch:13 iter:117 [D loss: 0.700140, acc: 28.27%] [G loss: 0.688463, acc: 85.74%] time: 9.707165
epoch:14 iter:117 [D loss: 0.699699, acc: 27.34%] [G loss: 0.688918, acc: 83.37%] time: 9.731362
epoch:15 iter:117 [D loss: 0.699300, acc: 26.80%] [G loss: 0.689223, acc: 84.06%] time: 9.768835
epoch:16 iter:117 [D loss: 0.698917, acc: 26.54%] [G loss: 0.689530, acc: 83.29%] time: 9.927150
epoch:17 iter:117 [D loss: 0.698575, acc: 25.80%] [G loss: 0.689942, acc: 81.63%] time: 9.874165
epoch:18 iter:117 [D loss: 0.698268, acc: 25.14%] [G loss: 0.690354, acc: 78.29%] time: 9.926621
epoch:19 iter:117 [D loss: 0.698021, acc: 24.95%] [G loss: 0.690568, acc: 78.01%] time: 9.896298
epoch:20 iter:117 [D loss: 0.697783, acc: 24.95%] [G loss: 0.690692, acc: 78.20%] time: 9.907461
あれ…? GPUの方が若干速くない…?
今回は画像サイズが28x28x1と小さく、モデルサイズもGeneratorが4層、Discriminatorが3層と比較的小さかったため、TPUの恩恵をあまり得られなかったのかもしれません。
もう少し本格的なモデルで検証してみる必要がありそうです。
こちらも100エポック学習後のGeneratorの出力を見てみます。
大丈夫そうですね。TPUとの差もなさそうです。
まとめ
今回はTPUでGANを学習させる時の実装例を見ていきました。この書き方を応用すれば、GANに限らず任意のネットワークをTPUで学習させることができそうです。
学習速度については、モデルサイズが小さかったこともあり、改めて検証してみる必要があります。
Google Colabを使えば、いくつかの制限はあるもののTPUを無料で手軽に試せます。ぜひこの機会に既存のモデルやらをTPU対応して試してみてください。
今回のコードや結果はgithubに載せているので、参考にしたり、コピペで試したり、ご自由にどうぞ。
参考
TPU関連
Custom training with TPUs | TensorFlow Core | TensorFlow
tf.distribute.Strategy with Training Loops | TensorFlow Core | TensorFlow
Distributed Training in TensorFlow | TensorFlow Core | TensorFlow
TensorFlow1.14以降のTPUの取り扱い方について
GAN関連
Deep Convolutional Generative Adversarial Network | TensorFlow Core | TensorFlow
今さら聞けないGAN(1) 基本構造の理解