#はじめに
画像内の目立つ部分のことを"Salient"な部分と言う.
特に画像が与えられた時に目立つ部分の分布を表したmapのことをSaliency Mapと言う.
下の例は左の画像とそれに対応するSaliency Mapを並べて表したもので、人の部分が目立っていることが読み取れる.
深層学習を用いてこのようなSaliency Mapを生成する手法がいくつか提案されている.
今回はその中でもGANを用いて生成を行うSalGANというモデルをKerasを使って実装し性能を評価してみることにした.
こちらは元論文と今回の実装のURL.
元論文 - SalGAN: Visual Saliency Prediction with Generative Adversarial Networks
GitHub - KerasでSalGANの実装
#データについて
今回は元論文に合わせてSALICON datasetを用いた.
- train_data 10,000枚
- validation_data 5,000枚
- test_data 5,000枚
からなるデータセット.
人間に画像を見せm見ている箇所をクリックしてもらうというタスクをクラウドソーシングやってもらって集めたらしい.
他にもいくつかデータセットがあるのですが、これが一番大きいため採用したと筆者は書いている.
#SalGANについて
SalGANは2017年に提案されたGANのアーキテクチャでSaliency Mapを生成するモデル.
図. SalGANの構造(元論文より引用)
全体としてGeneratorの前半とDiscriminatorの後半に分かれる.
GeneratorはEncoder-Decoderモデルとなっていて、VGG16の構造で畳み込む前半部と、UpSamplingをしてSaliency Mapを生成する後半部にさらに分かれる.
前半3グループはVGG16の重みを固定してそのまま利用していた.
Discriminatorは元画像(3-channel) + Saliency Map(1-channel)を組み合わせて4-channelの入力とし、それが本物か偽物かを見分けるモデルになっている.
##Loss関数
Loss関数が少し特徴的だったので紹介.
GeneratorのLoss関数は次のように計算される.
通常のAdversarial Lossの他に、Binary Cross Entropyで計算されたSaliency Mapの生成に対する誤差項も加わる.
これら二つを重みαを使って足し合わせたものをGeneratorのLossに使っていた.
(αは論文中では0.005)
また生成誤差の方のBinary Cross Entropy Lossは1/4にダウンサンプリングしてから計算したほうが精度が高まったことが元論文で報告されていて、本実装でもそのようにLoss関数を実装した.
##他モデルとの性能比較
他モデルとの性能比較について論文から引用したのが次の表です.
State-of-the-artとまでは行かないけど、そこそこの精度は出ているっぽい.
今回は比較的シンプルな構造だったのが採用した理由.
#実装
重要と思われる部分の実装をいくつか示す.
##Generator
Encoder-Decoderモデルを実装している.
Encoder部分ではKerasに組み込まれているVGG16の重みをLoadして利用している.
先ほど述べたように最初の3グループの重みは固定する.
class ModelBuilder():
'''Construct model for salgan and BCE
'''
@staticmethod
def build_encoder(img_width,img_height,l2_norm):
input_tensor = Input(shape=(img_width, img_height, 3))
# vgg16 = VGG16(include_top=False, weights='model/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5', input_tensor=input_tensor)
vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)
model_encoder = Sequential()
model_encoder.add(InputLayer(input_shape=(img_height,img_width, 3)))
for i,layer in enumerate(vgg16.layers[:-1]):
if i <= 10:
layer.trainable = False
else:
layer.kernel_regularizer=regularizers.l2(l2_norm)
model_encoder.add(layer)
return model_encoder
@staticmethod
def build_decoder(img_width,img_height,l2_norm):
model_decoder = Sequential()
model_decoder.add(Conv2D(512,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_decoder.add(Conv2D(512,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_decoder.add(Conv2D(512,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_decoder.add(UpSampling2D((2,2)))
model_decoder.add(Conv2D(512,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_decoder.add(Conv2D(512,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_decoder.add(Conv2D(512,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_decoder.add(UpSampling2D((2,2)))
model_decoder.add(Conv2D(256,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_decoder.add(Conv2D(256,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_decoder.add(Conv2D(256,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_decoder.add(UpSampling2D((2,2)))
model_decoder.add(Conv2D(128,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_decoder.add(Conv2D(128,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_decoder.add(UpSampling2D((2,2)))
model_decoder.add(Conv2D(64,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_decoder.add(Conv2D(64,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_decoder.add(Conv2D(1,1,activation='sigmoid'))
return model_decoder
def generator(self,img_width,img_height,l2_norm=0,load_model_path=None):
model_encoder = self.build_encoder(img_width,img_height,l2_norm)
model_decoder = self.build_decoder(img_width,img_height,l2_norm)
model_generator = Model(input=model_encoder.input, output=model_decoder(model_encoder.output))
if load_model_path != None:
print('Loading model weights from {}'.format(load_model_path))
model_generator.load_weights(load_model_path)
model_generator.summary()
return model_generator
##Discriminator
論文と同様に実装している.
論文中では明記されていませんでしたが、全層でl2の正則化を行っている.
class ModelBuilder():
'''Construct model for salgan and BCE
'''
@staticmethod
def discriminator(img_width,img_height,l2_norm):
model_discriminator = Sequential()
model_discriminator.add(Conv2D(3,1,input_shape=(img_height,img_width,4),activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_discriminator.add(Conv2D(32,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_discriminator.add(MaxPooling2D((2,2)))
model_discriminator.add(Conv2D(64,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_discriminator.add(Conv2D(64,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_discriminator.add(MaxPooling2D((2,2)))
model_discriminator.add(Conv2D(64,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_discriminator.add(Conv2D(64,3,activation='relu',padding='same',kernel_regularizer=regularizers.l2(l2_norm)))
model_discriminator.add(MaxPooling2D((2,2)))
model_discriminator.add(Flatten())
model_discriminator.add(Dense(100,kernel_regularizer=regularizers.l2(l2_norm)))
model_discriminator.add(Activation('tanh'))
model_discriminator.add(Dense(2,kernel_regularizer=regularizers.l2(l2_norm)))
model_discriminator.add(Activation('tanh'))
model_discriminator.add(Dense(1,kernel_regularizer=regularizers.l2(l2_norm)))
model_discriminator.add(Activation('sigmoid'))
model_discriminator.summary()
return model_discriminator
##学習
学習部分の実装ではKeras implementation of GANを参考にさせていただいた.
α(論文中では0.005)を用いてGeneratorとDiscriminatorのLossを組み合わせて学習させていることが特徴.
ColabのGPUを使うと論文と同じ120epochを6時間程度で学習できた.
def train_salgan(args):
#-- parse parameters --#
model_name = args.model_name
data_path = args.data_path
l2_norm = args.l2_norm
batch_size = args.batch_size
num_epoch = args.num_epoch
learning_rate = args.learning_rate
img_width, img_height = args.image_size
loss_alpha = args.loss_alpha
model_save_ratio = args.model_save_ratio
load_model_path = args.load_model_path
#-- parse parameters --#
X_train, Y_train = load_data(model_name,data_path)
model_builder = model.ModelBuilder()
model_generator = model_builder.generator(img_width=img_width,img_height=img_height,l2_norm=l2_norm,load_model_path=load_model_path)
model_discriminator = model_builder.discriminator(img_width=img_width,img_height=img_height,l2_norm=l2_norm)
output_true_batch, output_false_batch = np.ones((batch_size, 1)), np.zeros((batch_size, 1))
model_combine = model_builder.build_combine(model_generator,model_discriminator,img_width=img_width,img_height=img_height)
model_discriminator.trainable = True
model_discriminator.compile(optimizer=optimizers.Adagrad(lr=learning_rate), loss="binary_crossentropy")
model_discriminator.trainable = False
loss = [model.LossFunction().binary_crossentropy_forth, "binary_crossentropy"]
loss_weights = [loss_alpha, 1]
model_combine.compile(optimizer=optimizers.Adagrad(lr=learning_rate), loss=loss, loss_weights=loss_weights)
model_discriminator.trainable = True
for epoch in range(1,num_epoch+1):
print('epoch: {}/{}'.format(epoch, num_epoch))
print('batches: {}'.format(int(X_train.shape[0] / batch_size)))
permutated_indexes = np.random.permutation(X_train.shape[0])
d_losses = []
g_losses = []
c_losses = []
for index in range(int(X_train.shape[0] / batch_size)):
batch_indexes = permutated_indexes[index*batch_size:(index+1)*batch_size]
image_batch = X_train[batch_indexes]
salmap_batch = Y_train[batch_indexes]
generated_salmap = model_generator.predict(x=image_batch, batch_size=batch_size)
d_loss_real = model_discriminator.train_on_batch(np.concatenate([image_batch,salmap_batch], 3), output_true_batch)
d_loss_fake = model_discriminator.train_on_batch(np.concatenate([image_batch,generated_salmap], 3), output_false_batch)
d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
d_losses.append(d_loss)
model_discriminator.trainable = False
c_loss = model_combine.train_on_batch(image_batch, [salmap_batch,output_true_batch])
c_losses.append(c_loss[0])
g_losses.append(c_loss[1])
model_discriminator.trainable = True
if epoch % model_save_ratio == 0:
save_all_weights(epoch, model_generator, model_discriminator)
print("discriminator_loss", np.mean(d_losses), "combine_loss", np.mean(c_losses), "generator_loss", np.mean(g_losses))
##1/4にダウンスケールしたBCEの実装
学習の部分でも登場した1/4にダウンスケールしたBinary Cross EntropyのLoss関数の実装.
AveragePoolingで1/4にスケールしてBCEを計算している.
class LossFunction():
'''Original BCE loss mentioned in paper
1/4 downscaling using AveragePooling is conducted
'''
@staticmethod
def binary_crossentropy_forth(y_true, y_pred):
y_true_forth = AveragePooling2D(pool_size=(4, 4), padding='valid')(y_true)
y_pred_forth = AveragePooling2D(pool_size=(4, 4), padding='valid')(y_pred)
return K.mean(K.binary_crossentropy(y_true_forth, y_pred_forth), axis=-1)
#精度評価
##Saliency Mapの推定結果
Saliency Mapの推定結果の例をいくつか示す.
(左から元画像・予測されたMap・正解ラベル)
全体としてそこそこの精度で生成できていそう.
が、正解ラベルと比べると滑らかで詳細な構造は捉えられていない箇所もあることがわかる.
##指標を使ったモデル評価
論文内で使われている指標はいくつかあるが、その中でもSALICONデータセットの評価に使われている、
- AUC_Borji
- AUC_Shuffled
- NSS
- CC
の4指標で今回は評価を行った.
評価指標の実装はSaliency_metricsを参考にさせていただいた.
Python2での実装だったためPython3版に修正して用いた.
矢印にあるように全ての指標で値が高い方が精度が高いことを表している.
CC以外著者実装より精度が高いという結果になっていて、しかも値がかなり離れた指標もあるが理由は不明.
著者が指標の実装を載せていないため、指標の実装で違いが出ている可能性が大.
(いずれにせよ、そこそこの精度は出ていると信じることに)
Model | AUC_Borji↑ | AUC_Shuffled↑ | NSS↑ | CC↑ |
---|---|---|---|---|
著者実装 | 0.884 | 0.772 | 2.459 | 0.781 |
本実装 | 0.941 | 0.880 | 3.070 | 0.576 |
#まとめと考察
実装しながら思ったことをつらつらと書いていく.
- Lossの設計の重要性
- 論文通りに1/4のダウンスケールングをしたら明らかに精度が向上した(理由はよくわかっていない)
- GANの学習でLossを重み付けして組み合わせているのが勉強になった
- GANの学習の難しさ
- 試すたびに結果が大きく変わったり、Lossから学習の様子を知るのが難しかったり...
- Discriminatorの入力を画像と組み合わせた4-channelで入力してるのは参考になった
- (Saliency Mapか否かではなく、その画像から生成されたMapか否かを分類するため)
以上になったが、オープンデータと論文から自分でこれだけのものを実装できることを知って結構感動した.
せっかく実装したので、今度はSalGANを使った分析を記事にして出していこうと思う.
#参考文献