4
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【最新版】2021/02/17時点でGoogle colaboratoryのTPU分散学習でGANの訓練を成功させる方法

Last updated at Posted at 2020-07-31

はじめに

Google colaboratoryでTPUを使えば計算時間が短縮できるので、GPUを自前で用意できない人には強力なツールになります。しかし、TPUを扱っている記事は少なく、フレームワーク自体の変化が激しいので、調べた記事の内容に沿ってコーディングしても動かせなかったりします。一々Tensorflowのリリースノートを見てコーディングをし直すことは面倒な上、TPUの使用は控えめに言って意味不明なので、今回はとりあえず2020/09/10時点でTPUの訓練を行える方法を記します。

前準備

ColabのTPUへ接続する

import tensorflow as tf
import os
print(tf.__version__)

tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)

TPUを使う際のおまじないのようなコードです。ここでエラーが出たら、タイプミスか、自力ではどうしようもない何かが原因となっている可能性が高いです。
現在(Tensorflow2.4.1)では、TPUstrategyexperimentalが取れています。

分散学習に便利なデコレータ

以下に示すコードは、「TensorFlow2.0でDistributed Trainingをいい感じにやるためのデコレーターを作った」から引用させて頂いたコードをtensorflow 2.4.1に対応させたものです。詳細はリンク先のサイトを見てください。このコードで改変した点はstrategy.experimental_run_v2strategy.runに変更したというだけです。

from enum import Enum

class Reduction(Enum):
    NONE = 0
    SUM = 1
    MEAN = 2
    CONCAT = 3

def distributed(*reduction_flags):
    def _decorator(fun):
        def per_replica_reduction(z, flag):
            if flag == Reduction.NONE:
                return z
            elif flag == Reduction.SUM:
                return strategy.reduce(tf.distribute.ReduceOp.SUM, z, axis=None)
            elif flag == Reduction.MEAN:
                return strategy.reduce(tf.distribute.ReduceOp.MEAN, z, axis=None)
            elif flag == Reduction.CONCAT:
                z_list = strategy.experimental_local_results(z)
                return tf.concat(z_list, axis=0)
            else:
                raise NotImplementedError()
        
        @tf.function
        def _decorated_fun(*args, **kwargs):
            fun_result = strategy.run(fun, args=args, kwargs=kwargs)
            if len(reduction_flags) == 0:
                assert fun_result is None
                return
            elif len(reduction_flags) == 1:
                assert type(fun_result) is not tuple and fun_result is not None
                return per_replica_reduction(fun_result, *reduction_flags)
            else:
                assert type(fun_result) is tuple
                return tuple((per_replica_reduction(fr, rf) for fr, rf in zip(fun_result, reduction_flags)))
        return _decorated_fun
    return _decorator

いろいろなモジュールのインポート

from PIL import Image
import glob
import pickle
import matplotlib.pyplot as plt
import time

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K

データローダの設定

今回はcifar10を用いるので、cifar10のdataloderを作ります。celebAをやってみたい人は、「TensorFlow2.0 + 無料のColab TPUでDCGANを実装した」を参考にしてください。
これは余談ですが、cifar10はデータ数が少なく、機械の画像や動物の画像があったりと画像ごとの違いが大きいので、頑張らないとまともな画像が生成されません。**この記事ではGoogle colaboratoryのTPU分散学習をすることに重点を置いているので、生成画像はお粗末なものになっています。**そこらへんはご容赦ください。
コードは以下のようになります。BATCH_SIZE=1024と非常に大きい値となっていますが、バッチサイズを大きくし、生成画像の品質が上げられるのもTPUの強みです。

def load_cifar10(batch_size):
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    images = tf.concat([x_train, x_test], axis=0)
    labels = tf.concat([y_train, y_test], axis=0)
    labels = tf.keras.utils.to_categorical(labels)
    def preprocess(img):
        x = tf.cast(img, tf.float32) / 127.5 - 1.0
        return x
    
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.map(
        lambda img, label: (preprocess(img), tf.cast(label, tf.float32))
    ).shuffle(4096).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

BATCH_SIZE = 1024
dataset = load_cifar10(BATCH_SIZE)

画像を表示する関数の作成

生成画像を確認するための画像を表示する関数を以下に示します。関数make_gridは「TensorFlow2.0 + 無料のColab TPUでDCGANを実装した」を引用しました。
plot_imagesに各引数(imagesはtensorflowのEagerTensorを入れる)を代入すれば使えます。

def make_grid(imgs, nrow, padding=0):
    assert imgs.ndim == 4 and nrow > 0
    batch, height, width, ch = imgs.shape
    n = nrow * (batch // nrow + np.sign(batch % nrow))
    ncol = n // nrow
    pad = np.zeros((n-batch, height, width, ch), imgs.dtype)
    x = np.concatenate([imgs, pad], axis=0)
    if padding > 0:
        x = np.pad(x, ((0, 0), (0, padding), (0, padding), (0, 0)), 
                   'constant', constant_values=(0, 0))
        height += padding
        width += padding
    x = x.reshape(ncol, nrow, height, width, ch)
    x = x.transpose([0, 2, 1, 3, 4])
    x = x.reshape(height*ncol, width*nrow, ch)
    if padding > 0:
        x = x[:(height*ncol - padding), :(width*nrow - padding), :]
    return x

def plot_images(images, nrow=10, padding=0, img_name='sample_img.png', plotting=True):
    imgs = images.numpy()
    grid = make_grid(imgs, nrow, padding)
    grid = ((grid+1)*127.5).astype(np.uint8)
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.imshow(grid)
    plt.savefig(img_name, bbox_inches='tight', pad_inches=0.0)
    if plotting: plt.show()
    plt.close('all')

モデルの構築

今回はGoogle colaboratoryのTPU分散学習をすることに重点を置いているので、**モデルは適当に組みました。**せっかくなので拙作「Colab TPUでもtensorflow.kerasでBilinear法のアップサンプリングを行う方法」のコードを用いてアップサンプリングを行いました。どれくらいの効果が合ったのかは知りません。

各関数の定義

BATCH_SIZE = 1024
Z_DIM = 512 # 潜在変数の次元

def upsampling2d_bilinear(inputs, scale=2):
    w, h = inputs.shape[1], inputs.shape[2]
    w *= scale; h *= scale
    return tf.compat.v1.image.resize_bilinear(inputs, (w, h), align_corners=True)

def ch(res):
    c = 1024*4 // res
    if c > 1024: c = 1024
    if c < 128: c = 128
    return c

Discriminatorの定義

def d_block(x, res):
    a = Conv2D(ch(res//2), kernel_size=1, use_bias=False)(x)
    a = AveragePooling2D()(a)

    x = ReLU()(x)
    x = Conv2D(ch(res), kernel_size=4, strides=2, padding='same', use_bias=False)(x)
    x = ReLU()(x)
    x = Conv2D(ch(res//2), kernel_size=4, strides=1, padding='same', use_bias=False)(x)

    x = Add()([x, a])
    return x

def create_D():
    inputs = Input((32,32,3))
    x = inputs
    for i, res in enumerate([32, 16, 8]):
        x = d_block(x, res)
    x = ReLU()(x)
    out = Conv2D(1, kernel_size=1, use_bias=False)(x)
    return Model(inputs, out)

なんとなくResidual blockを使っています。outputのサイズは(batch_size, 4, 4, 1)です。これは、Patch GANのようなものです。このあと定義するLossもそれに対応したものにしています。

Generatorの定義

def g_block(x, res):
    a = Lambda(upsampling2d_bilinear, arguments={'scale': 2})(x)
    a = Conv2D(ch(res), kernel_size=1, use_bias=False)(a)

    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2DTranspose(ch(res//2), kernel_size=4, strides=2, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(ch(res), kernel_size=4, strides=1, padding='same', use_bias=False)(x)

    x = Add()([x, a])
    return x

def create_G():
    inputs = Input((Z_DIM,))
    x = Reshape((1, 1, Z_DIM))(inputs)
    for i, res in enumerate([8, 16, 32]):
        if i==0:
            x = Conv2DTranspose(ch(res//2), kernel_size=4, strides=1, padding='valid', use_bias=False)(x)
        else:
            x = g_block(x, res)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2DTranspose(ch(32), kernel_size=4, strides=2, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(3, kernel_size=4, strides=1, padding='same', use_bias=False)(x)
    out = Activation('tanh')(x)
    return Model(inputs, out)

学習

各定数の設定

K.clear_session()

BATCH_SIZE = 1024
Z_DIM = 512
STEPS_PER_EPOCH = 60000 // BATCH_SIZE + 1

out_dir = 'out'
os.makedirs(out_dir, exist_ok=True)

ネットワーク等の定義

with strategy.scope():
    netD = create_D()
    netG = create_G()
    optD = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9)
    optG = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9)
    dataset = load_cifar10(BATCH_SIZE)
    dataset = strategy.experimental_distribute_dataset(dataset)

strategy.experimental_distribute_datasetによってdatasetEagerTensorを吐き出すものからPerReplicaを吐き出すものに変わります。PerReplicaEagerTensorのように処理できなるので注意です。

Loss関数の定義

with strategy.scope():
    class Losses:
        @staticmethod
        def hinge_loss(logits, loss_type):
            assert loss_type in ['gen', 'dis_real', 'dis_fake']
            if loss_type == 'gen':
                return -tf.reduce_mean(logits, axis=(1,2,3))
            elif loss_type == 'dis_real':
                minval = tf.minimum(logits-1, tf.zeros(logits.shape, dtype=logits.dtype))
                return -tf.reduce_mean(minval, axis=(1,2,3))
            else:
                minval = tf.minimum(-logits-1, tf.zeros(logits.shape, dtype=logits.dtype))
                return -tf.reduce_mean(minval, axis=(1,2,3))
        @staticmethod
        def cross_entropy_loss(logits, loss_type):
            p = tf.math.sigmoid(logits)
            assert loss_type in ['gen', 'dis_real', 'dis_fake']
            if loss_type == 'gen':
                loss = -tf.math.log(p)
                return tf.reduce_mean(loss, axis=(1,2,3))
            elif loss_type == 'dis_real':
                loss = -tf.math.log(p)
                return tf.reduce_mean(loss, axis=(1,2,3))
            else:
                loss = -tf.math.log(1.0-p)
                return tf.reduce_mean(loss, axis=(1,2,3))
        @staticmethod
        def generator_loss(logits_fake):
            return Losses.hinge_loss(logits_fake, 'gen')
        @staticmethod
        def discriminator_loss(logits_real, logits_fake):
            return Losses.hinge_loss(logits_real, 'dis_real') + Losses.hinge_loss(logits_fake, 'dis_fake')

このLossはPatch GANに対応しているので、logitsは4次元テンソルである必要があります。
また、このLoss関数はhinge_losscross_entropy_lossに対応しています。変えたいときはgenerator_lossdiscriminator_lossを変更してください。

学習する関数の定義

with strategy.scope():
    @distributed(Reduction.SUM, Reduction.SUM, Reduction.CONCAT)
    def train_on_batch(real):
        b_size = real.shape[0]
        z = tf.random.normal((b_size, Z_DIM))
        with tf.GradientTape() as tape_D, tf.GradientTape() as tape_G:
            fake = netG(z, training=True)

        d_loss = 0.0
        for c in range(1):
            with tape_D:
                real_out = netD(real, training=True)
                fake_out = netD(fake, training=True)
                d_loss = tf.reduce_sum(Losses.discriminator_loss(real_out, fake_out)) / BATCH_SIZE
            grad_D = tape_D.gradient(d_loss, netD.trainable_weights)
            optD.apply_gradients(zip(grad_D, netD.trainable_weights))
        
        with tape_G:
            fake_out = netD(fake, training=False)
            g_loss = tf.reduce_sum(Losses.generator_loss(fake_out)) / BATCH_SIZE
        grad_G = tape_G.gradient(g_loss, netG.trainable_weights)
        optG.apply_gradients(zip(grad_G, netG.trainable_weights))

        return d_loss, g_loss, fake

Discriminatorの訓練の部分をfor文で囲んでいるのはSpectral NormalizationやGradient Penaltyを適用する際にDiscriminatorの訓練回数の比率を増やす必要があるから、それに適応しやすくするためです。ここで、訓練ループの前にd_loss = 0.0を入れておかないとエラーが出ます。
次に、@distributedについて**Reductionの種類は、分散学習で各TPUから値が戻ってくるときに、返り値にどうして欲しいかによって決まります。**例えば、d_lossg_lossの場合、既に関数の中でBATCH_SIZEで割っているので、それらを足して欲しくなります。fakeの場合、TPUから帰ってくるのは(batch_size/TPUの数, 32, 32, 3)の画像データなので、ときはbatch_sizeのところにconcatenateして欲しいです。また、何らかの評価指標の場合、TPUからの返り値を平均したいのでReduction.MEANを使います。

訓練ループ

with strategy.scope():
    losses = []
    total_step = 0
    t0 = time.time()
    EPOCH = 50
    for epoch in range(EPOCH):
        t1 = time.time()
        print(f'Epoch: {str(epoch).zfill(3)}')
        for step, (real, label) in enumerate(dataset):
            # 最後のstepでNonetypeをTensorに変換しようとしやがるのでcontinueさせる。
            if step==(STEPS_PER_EPOCH-1): continue
            d_loss, g_loss, fake = train_on_batch(real)
            if step%(STEPS_PER_EPOCH//2)==0:
                print(f'\tStep {str(total_step).zfill(7)}, d_loss: {d_loss.numpy():f}, g_loss: {g_loss.numpy():f}')
            losses.append([d_loss.numpy(), g_loss.numpy()])
            total_step += 1
        plot_images(fake[:100], img_name=f'./{out_dir}/epoch_{str(epoch).zfill(3)}.png', plotting=(epoch%2==0))
        print(f'{time.time()-t1:f}s is elapsed for epoch {str(epoch).zfill(3)}.')
    with open(f'losses.pkl', 'wb') as fp: pickle.dump(losses, fp)
    print(f'Train time: {(time.time()-t0)/60:f}min.')

if step==(STEPS_PER_EPOCH-1): continueについて、**これがないとtensorflowおなじみのNoneTypeをTensorに変換しようとしたことにエラーが出ます。**おそらくEpochの最後でバッチサイズに足りない分をNonetypeで送っているからだと思われます。ガバガバですね。

結果

所要時間は9.427246minでした。早いですね。

1Epoch

epoch_000.png
1Epoch目で真っ黒とか真緑とかそういう画像が生成されているときは、ほぼ訓練が失敗したと考えて良いでしょう。

10Epoch

epoch_009.png

35Epoch

epoch_035.png

50Epoch

epoch_049.png

注意点

結果を見ると35Epochのほうが50Epochより画像が綺麗だと思うかも知れません。学習を進めると更に生成画像が変な画像になってきます。このことについて、以前に勾配のノルムを平均して可視化したところ、学習を進めると勾配が爆発的に大きくなるようなピークが発生しており、それによって画像の品質が劣化していると考えられます。この問題を解決するには勾配の大きさを制限するGradient Penaltyとかが非常に有効だと思われます。

追記

2020/07/31までの情報

ColabのTPUへ接続する

import tensorflow as tf
import os
print(tf.__version__)

tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
strategy = tf.distribute.experimental.TPUStrategy(tpu_cluster_resolver)

これは、他のサイトでもよく見られるColabのTPUに接続するためのコードですが、注意点があります。まず、tensorflowのバージョンは2.2.0を用いる必要があります。2020/07/31時点で、tensorflowの最新バージョンは2.3.0ですが、以下のようなエラーが出てTPUに接続できません。これは、tensorflow 2.3.0ではTPUStrategyの前のexperimentalが外れることも考慮した上での結果です。

長いので折りたたんでいます。
INFO:tensorflow:Initializing the TPU system: grpc://10.112.235.34:8470
INFO:tensorflow:Initializing the TPU system: grpc://10.112.235.34:8470
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Clearing out eager caches
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-3-95d3adc1bbd6> in <module>()
      5 tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
      6 tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
----> 7 tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
      8 strategy = tf.distribute.experimental.TPUStrategy(tpu_cluster_resolver)

3 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/tpu/tpu_strategy_util.py in initialize_tpu_system(cluster_resolver)
    109     context.context()._clear_caches()  # pylint: disable=protected-access
    110 
--> 111     serialized_topology = output.numpy()
    112 
    113     # TODO(b/134094971): Remove this when lazy tensor copy in multi-device

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in numpy(self)
   1061     """
   1062     # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
-> 1063     maybe_arr = self._numpy()  # pylint: disable=protected-access
   1064     return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr
   1065 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _numpy(self)
   1029       return self._numpy_internal()
   1030     except core._NotOkStatusException as e:  # pylint: disable=protected-access
-> 1031       six.raise_from(core._status_to_exception(e.code, e.message), None)  # pylint: disable=protected-access
   1032 
   1033   @property

/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: NodeDef expected inputs 'string' do not match 0 inputs specified; Op<name=_Send; signature=tensor:T -> ; attr=T:type; attr=tensor_name:string; attr=send_device:string; attr=send_device_incarnation:int; attr=recv_device:string; attr=client_terminated:bool,default=false; is_stateful=true>; NodeDef: {{node _Send}}
このエラーからもTPUの意味分からなさが分かりますね。
4
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?