2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Saliency MapをGANで生成するSalGANを実装してみた

Last updated at Posted at 2020-06-02

#はじめに
画像内の目立つ部分のことを"Salient"な部分と言う.
特に画像が与えられた時に目立つ部分の分布を表したmapのことをSaliency Mapと言う.

下の例は左の画像とそれに対応するSaliency Mapを並べて表したもので、人の部分が目立っていることが読み取れる.
sample.png
深層学習を用いてこのようなSaliency Mapを生成する手法がいくつか提案されている.
今回はその中でもGANを用いて生成を行うSalGANというモデルをKerasを使って実装し性能を評価してみることにした.

こちらは元論文と今回の実装のURL.
元論文 - SalGAN: Visual Saliency Prediction with Generative Adversarial Networks
GitHub - KerasでSalGANの実装

#データについて
今回は元論文に合わせてSALICON datasetを用いた.

  • train_data 10,000枚
  • validation_data 5,000枚
  • test_data 5,000枚
    からなるデータセット.

人間に画像を見せm見ている箇所をクリックしてもらうというタスクをクラウドソーシングやってもらって集めたらしい.
他にもいくつかデータセットがあるのですが、これが一番大きいため採用したと筆者は書いている.

#SalGANについて
SalGANは2017年に提案されたGANのアーキテクチャでSaliency Mapを生成するモデル.
スクリーンショット 2020-06-01 20.26.26.png
図. SalGANの構造(元論文より引用)

全体としてGeneratorの前半とDiscriminatorの後半に分かれる.

GeneratorはEncoder-Decoderモデルとなっていて、VGG16の構造で畳み込む前半部と、UpSamplingをしてSaliency Mapを生成する後半部にさらに分かれる.
前半3グループはVGG16の重みを固定してそのまま利用していた.

Discriminatorは元画像(3-channel) + Saliency Map(1-channel)を組み合わせて4-channelの入力とし、それが本物か偽物かを見分けるモデルになっている.

##Loss関数
Loss関数が少し特徴的だったので紹介.
GeneratorのLoss関数は次のように計算される.
スクリーンショット 2020-06-01 20.46.06.png
スクリーンショット 2020-06-01 20.46.03.png
通常のAdversarial Lossの他に、Binary Cross Entropyで計算されたSaliency Mapの生成に対する誤差項も加わる.

これら二つを重みαを使って足し合わせたものをGeneratorのLossに使っていた.
(αは論文中では0.005)

また生成誤差の方のBinary Cross Entropy Lossは1/4にダウンサンプリングしてから計算したほうが精度が高まったことが元論文で報告されていて、本実装でもそのようにLoss関数を実装した.
##他モデルとの性能比較
他モデルとの性能比較について論文から引用したのが次の表です.
スクリーンショット 2020-06-01 20.55.51.png
State-of-the-artとまでは行かないけど、そこそこの精度は出ているっぽい.
今回は比較的シンプルな構造だったのが採用した理由.
#実装
重要と思われる部分の実装をいくつか示す.
##Generator
Encoder-Decoderモデルを実装している.
Encoder部分ではKerasに組み込まれているVGG16の重みをLoadして利用している.
先ほど述べたように最初の3グループの重みは固定する.

model.py
class ModelBuilder():
    '''Construct model for salgan and BCE
    '''

    @staticmethod
    def build_encoder(img_width,img_height,l2_norm):
        input_tensor = Input(shape=(img_width, img_height, 3))
        # vgg16 = VGG16(include_top=False, weights='model/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5', input_tensor=input_tensor)
        vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)

        model_encoder = Sequential()
        model_encoder.add(InputLayer(input_shape=(img_height,img_width, 3)))

        for i,layer in enumerate(vgg16.layers[:-1]):
            if i <= 10:
                layer.trainable = False
            else:
                layer.kernel_regularizer=regularizers.l2(l2_norm)
            model_encoder.add(layer)
        return model_encoder

    @staticmethod
    def build_decoder(img_width,img_height,l2_norm):
        model_decoder = Sequential()

        model_decoder.add(Conv2D(512,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_decoder.add(Conv2D(512,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_decoder.add(Conv2D(512,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_decoder.add(UpSampling2D((2,2)))

        model_decoder.add(Conv2D(512,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_decoder.add(Conv2D(512,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_decoder.add(Conv2D(512,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_decoder.add(UpSampling2D((2,2)))

        model_decoder.add(Conv2D(256,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_decoder.add(Conv2D(256,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_decoder.add(Conv2D(256,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_decoder.add(UpSampling2D((2,2)))

        model_decoder.add(Conv2D(128,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_decoder.add(Conv2D(128,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_decoder.add(UpSampling2D((2,2)))

        model_decoder.add(Conv2D(64,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_decoder.add(Conv2D(64,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_decoder.add(Conv2D(1,1,activation='sigmoid'))

        return model_decoder

    def generator(self,img_width,img_height,l2_norm=0,load_model_path=None):
        model_encoder = self.build_encoder(img_width,img_height,l2_norm)
        model_decoder = self.build_decoder(img_width,img_height,l2_norm)

        model_generator = Model(input=model_encoder.input, output=model_decoder(model_encoder.output))

        if load_model_path != None:
            print('Loading model weights from {}'.format(load_model_path))
            model_generator.load_weights(load_model_path)
        
        model_generator.summary()

        return model_generator

##Discriminator
論文と同様に実装している.
論文中では明記されていませんでしたが、全層でl2の正則化を行っている.

model.py
class ModelBuilder():
    '''Construct model for salgan and BCE
    '''
    @staticmethod
    def discriminator(img_width,img_height,l2_norm):
        model_discriminator = Sequential()

        model_discriminator.add(Conv2D(3,1,input_shape=(img_height,img_width,4),activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_discriminator.add(Conv2D(32,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_discriminator.add(MaxPooling2D((2,2)))

        model_discriminator.add(Conv2D(64,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_discriminator.add(Conv2D(64,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_discriminator.add(MaxPooling2D((2,2)))

        model_discriminator.add(Conv2D(64,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_discriminator.add(Conv2D(64,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
        model_discriminator.add(MaxPooling2D((2,2)))

        model_discriminator.add(Flatten())
        model_discriminator.add(Dense(100,kernel_regularizer=regularizers.l2(l2_norm)))
        model_discriminator.add(Activation('tanh'))
        model_discriminator.add(Dense(2,kernel_regularizer=regularizers.l2(l2_norm)))
        model_discriminator.add(Activation('tanh'))
        model_discriminator.add(Dense(1,kernel_regularizer=regularizers.l2(l2_norm)))
        model_discriminator.add(Activation('sigmoid'))

        model_discriminator.summary()

        return model_discriminator

##学習
学習部分の実装ではKeras implementation of GANを参考にさせていただいた.
α(論文中では0.005)を用いてGeneratorDiscriminatorのLossを組み合わせて学習させていることが特徴.

ColabのGPUを使うと論文と同じ120epochを6時間程度で学習できた.

train.py
def train_salgan(args):
    #-- parse parameters --#
    model_name = args.model_name
    data_path = args.data_path
    l2_norm = args.l2_norm
    batch_size = args.batch_size
    num_epoch = args.num_epoch
    learning_rate = args.learning_rate
    img_width, img_height = args.image_size
    loss_alpha = args.loss_alpha
    model_save_ratio = args.model_save_ratio
    load_model_path = args.load_model_path
    #-- parse parameters --#

    X_train, Y_train = load_data(model_name,data_path)
    
    model_builder = model.ModelBuilder()

    model_generator = model_builder.generator(img_width=img_width,img_height=img_height,l2_norm=l2_norm,load_model_path=load_model_path)
    model_discriminator = model_builder.discriminator(img_width=img_width,img_height=img_height,l2_norm=l2_norm)

    output_true_batch, output_false_batch = np.ones((batch_size, 1)), np.zeros((batch_size, 1))

    model_combine = model_builder.build_combine(model_generator,model_discriminator,img_width=img_width,img_height=img_height)

    model_discriminator.trainable = True
    model_discriminator.compile(optimizer=optimizers.Adagrad(lr=learning_rate), loss="binary_crossentropy")
    model_discriminator.trainable = False
    loss = [model.LossFunction().binary_crossentropy_forth, "binary_crossentropy"]
    loss_weights = [loss_alpha, 1]
    model_combine.compile(optimizer=optimizers.Adagrad(lr=learning_rate), loss=loss, loss_weights=loss_weights)
    model_discriminator.trainable = True

    for epoch in range(1,num_epoch+1):
        print('epoch: {}/{}'.format(epoch, num_epoch))
        print('batches: {}'.format(int(X_train.shape[0] / batch_size)))
        
        permutated_indexes = np.random.permutation(X_train.shape[0])

        d_losses = []
        g_losses = []
        c_losses = []

        for index in range(int(X_train.shape[0] / batch_size)):
            batch_indexes = permutated_indexes[index*batch_size:(index+1)*batch_size]
            image_batch = X_train[batch_indexes]
            salmap_batch = Y_train[batch_indexes]

            generated_salmap = model_generator.predict(x=image_batch, batch_size=batch_size)

            d_loss_real = model_discriminator.train_on_batch(np.concatenate([image_batch,salmap_batch], 3), output_true_batch)
            d_loss_fake = model_discriminator.train_on_batch(np.concatenate([image_batch,generated_salmap], 3), output_false_batch)
            d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
            d_losses.append(d_loss)

            model_discriminator.trainable = False

            c_loss = model_combine.train_on_batch(image_batch, [salmap_batch,output_true_batch])
            c_losses.append(c_loss[0])
            g_losses.append(c_loss[1])
            
            model_discriminator.trainable = True

        if epoch % model_save_ratio == 0:
            save_all_weights(epoch, model_generator, model_discriminator)

        print("discriminator_loss", np.mean(d_losses), "combine_loss", np.mean(c_losses), "generator_loss", np.mean(g_losses))

##1/4にダウンスケールしたBCEの実装
学習の部分でも登場した1/4にダウンスケールしたBinary Cross EntropyのLoss関数の実装.
AveragePoolingで1/4にスケールしてBCEを計算している.

model.py
class LossFunction():
    '''Original BCE loss mentioned in paper
    1/4 downscaling using AveragePooling is conducted
    '''

    @staticmethod
    def binary_crossentropy_forth(y_true, y_pred):
        y_true_forth = AveragePooling2D(pool_size=(4, 4), padding='valid')(y_true)
        y_pred_forth = AveragePooling2D(pool_size=(4, 4), padding='valid')(y_pred)
        return K.mean(K.binary_crossentropy(y_true_forth, y_pred_forth), axis=-1)

#精度評価
##Saliency Mapの推定結果
Saliency Mapの推定結果の例をいくつか示す.
(左から元画像・予測されたMap・正解ラベル)
sample1.png
sample2.png
sample3.png
sample4.png

全体としてそこそこの精度で生成できていそう.
が、正解ラベルと比べると滑らかで詳細な構造は捉えられていない箇所もあることがわかる.
##指標を使ったモデル評価
論文内で使われている指標はいくつかあるが、その中でもSALICONデータセットの評価に使われている、

  • AUC_Borji
  • AUC_Shuffled
  • NSS
  • CC
    の4指標で今回は評価を行った.

評価指標の実装はSaliency_metricsを参考にさせていただいた.
Python2での実装だったためPython3版に修正して用いた.

矢印にあるように全ての指標で値が高い方が精度が高いことを表している.
CC以外著者実装より精度が高いという結果になっていて、しかも値がかなり離れた指標もあるが理由は不明.

著者が指標の実装を載せていないため、指標の実装で違いが出ている可能性が大.
(いずれにせよ、そこそこの精度は出ていると信じることに)

Model AUC_Borji↑ AUC_Shuffled↑ NSS↑ CC↑
著者実装 0.884 0.772 2.459 0.781
本実装 0.941 0.880 3.070 0.576

#まとめと考察
実装しながら思ったことをつらつらと書いていく.

  • Lossの設計の重要性
    • 論文通りに1/4のダウンスケールングをしたら明らかに精度が向上した(理由はよくわかっていない)
    • GANの学習でLossを重み付けして組み合わせているのが勉強になった
  • GANの学習の難しさ
    • 試すたびに結果が大きく変わったり、Lossから学習の様子を知るのが難しかったり...
    • Discriminatorの入力を画像と組み合わせた4-channelで入力してるのは参考になった
      • (Saliency Mapか否かではなく、その画像から生成されたMapか否かを分類するため)

以上になったが、オープンデータと論文から自分でこれだけのものを実装できることを知って結構感動した.
せっかく実装したので、今度はSalGANを使った分析を記事にして出していこうと思う.

#参考文献

  1. 元論文 - SalGAN: Visual Saliency Prediction with Generative Adversarial Networks
  2. GitHub - SalGAN: Visual Saliency Prediction with Generative Adversarial Networks
  3. GitHub - Keras implementation of GAN
  4. GitHub - Saliency_metrics
  5. SALICON
  6. MIT Saliency Benchmark
  7. Saliency Mapを使って画像を良い感じに切り抜くAIを作った
2
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?