3
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

VAEでMNISTデータをAugmentationして分類機を学習する

Last updated at Posted at 2020-05-31

この記事について

VAE(Variational Autoencoder)を勉強しているとData augmentationに利用したらどうなる?と疑問に思ったが,調べてもあまりでてこなかったので実際にやってみた.

・・・という記事です.勉強がてらやったものとして気軽に御覧ください.ご指摘などあればぜひお願いします!

コード(ipynbファイル)はここを参照.

環境

Google Colab.時々インスタンス切られつつなんとか動作.

はじめに - VAE(変分オートエンコーダー)とオートエンコーダー

##オートエンコーダー
inputとoutputが同じ.
中間層で次元削減し特徴量抽出 -> 教師なし学習として用いられる
AutoEncoder.jpg

VAE

inputとoutputが同じなのは共通.
異なるのは潜在変数を直接学習するのではなく,潜在変数を事後分布 $q(z_i|x_i)$として学習(具体的にはガウス分布の平均と分散)する.
潜在変数空間からサンプリングされたベクトルを用いてoutputを「生成」する.要は生成モデル.
目的関数としては入力xに対し,$\log p_{\theta}(x)$の変分下限を最大化するアプローチになります.理論的な背景としては原著論文(1)をご参照ください.日本語の解説記事も多くありますが,(2)もわかりやすかったです.
(参考)
(1) Kingma DP, et.al. Semi-supervised learning with deep generative models. Adv Neural Inf Process Syst 2014;4:3581–3589.
(2)【超初心者向け】VAEの分かりやすい説明とPyTorchの実装

VAE.jpg

方法

####分類器
VGG16 likeなものにBatchNormalization層を追加
コードは下記(冗長ですみません

CNNmodel.py
#! /usr/bin/env python
# -*- coding: utf-8 -*-
import keras
from keras import layers

def base_layer(output_dim):
  '''CNN block for function(baseConvModel)'''
  initializer = keras.initializers.he_normal()
  def f(input_tensor):
    x = layers.Conv2D(output_dim, (3, 3), padding='same', kernel_initializer=initializer)(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(output_dim, (3, 3), padding='same', kernel_initializer=initializer)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    out = layers.MaxPooling2D((2,2),strides=(2,2),padding='same')(x)
    return out
  return f


def baseConvModel(lr=0.0002,clipvalue=1.0,beta_1=0.9,beta_2=0.99):
  '''
  light VGG-like CNN model for MNIST.
  (28,28,1) - > (14, 14, 16) -> (7, 7, 64) -> (3136, ) -> (512, ) -> (10, )
  #######CAUTION######
  Return -> COMPILED model
  ####################
  Args:
    lr: learning rate of adam-optimizer
    beta_1, beta_2, clip_value : for adam_optimizer
  Returns : COMPILED model
  '''
  initializer = keras.initializers.he_normal()
  inputs = layers.Input(shape=(28,28,1))
  '''small VGG-like network'''
  x = base_layer(16)(inputs)
  x = base_layer(64)(x)

  '''FL'''
  x = layers.Flatten()(x)# 7x7x64 = 3136
  x = layers.core.Dropout(0.3)(x)
  x = layers.Dense(512,activation='relu',kernel_initializer=initializer)(x)
  x = layers.core.Dropout(0.5)(x)
  predict = layers.Dense(10,activation='softmax',kernel_initializer=keras.initializers.glorot_normal())(x)

  model = keras.models.Model(inputs=inputs,outputs=predict)
  '''Compilie'''
  adam = keras.optimizers.Adam(lr=lr,beta_1=beta_1,beta_2=beta_2,decay=1e-8,clipvalue=clipvalue)
  model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])
  #model.summary()
  return model

####データ拡張
VAEを用いて,
・ことなるオリジナル画像数で(100,200,400,800)
・異なる学習具合(lossで表現)
・ことなる増幅数(x10, x50, x250)
で画像を増幅.

コードは下記参照.

↓VAE

vae.py
#! /usr/bin/env python
# -*- coding: utf-8 -*-
import keras
from keras import layers
from keras import backend as K
import numpy as np

class CustomVariationalLayer(keras.layers.Layer):
  '''Custom loss functiond(dummy layer)
  Define loss function: image_loss + KL divergense'''
  def vae_loss(self, x, z_decoded,z_mean,z_log_var):
    input_dim = np.prod(K.int_shape(x)[1:])
    x = K.flatten(x)
    z_decoded = K.flatten(z_decoded)
    reconst_loss = input_dim * keras.metrics.binary_crossentropy(x, z_decoded)
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return K.mean(reconst_loss + kl_loss)

  def call(self, inputs):
    x, z_decoded, z_mean, z_log_var = inputs
    loss = self.vae_loss(*inputs)
    self.add_loss(loss, inputs=inputs)
    return x

def sampling(arg):
  ''' Samping z from encoded values for decoder(For VAE training) '''
  z_mean, z_log_var = arg

  ep = K.random_normal(shape=(K.int_shape(z_mean)[1], ),# K.int_shape(z_mean)[0],
                       mean = 0,stddev = 1)
  return z_mean + ep * K.exp(0.5 * z_log_var)


def base_layers_enc(output_dim, strides=(1,1)):
  initializer = keras.initializers.he_normal()
  def f(input_tensor):
    x = layers.Conv2D(output_dim, (3, 3), padding='same',strides=strides, kernel_initializer=initializer)(input_tensor)
    x = layers.BatchNormalization()(x)
    out = layers.LeakyReLU(0.2)(x)
    return out
  return f

def set_vae(img_shape=(28,28,1),latent_dim=20):
  '''
  Define VAE, VAE-encoder, VAE-decoder
  (28, 28, 1) -> (28, 28, 32) -> (14, 14, 64) -> (14, 14, 64)[decoder] - >(12544, ) -> (32, ) -> (latent_dim)

  '''
  #for Encoder###########################################
  input_img_vae = keras.Input(shape=img_shape)
  initializer = keras.initializers.he_normal()
  out_dims = [[32, (1,1)], [64, (2,2)], [64, (1,1)]]
  x = input_img_vae

  for _d, _st in out_dims:
    x = base_layers_enc(_d, _st)(x)

  x = layers.Conv2D(64,(3,3),padding='same',kernel_initializer=initializer)(x)

  cache_shape = K.int_shape(x)# for decoder_dim

  x = layers.BatchNormalization()(x)
  x = layers.LeakyReLU(0.2)(x)
  x = layers.Flatten()(x)
  #x = layers.Dropout(0.3)(x)
  x = layers.Dense(32,kernel_initializer=initializer)(x)
  x = layers.BatchNormalization()(x)
  x = layers.LeakyReLU(0.2)(x)

  z_mean = layers.Dense(latent_dim)(x)
  z_log_var = layers.Dense(latent_dim)(x)

  VaeEnc = keras.models.Model(input_img_vae, [z_mean, z_log_var])
  #VaeEnc.summary()

  # samping
  z = layers.Lambda(sampling, output_shape = (latent_dim, ))([z_mean, z_log_var])

  # for decoder################################################
  '''
  (laten_dim) -> (12544, ) -> (14, 14, 64) -> (28, 28, 32) -> (28, 28, 1)
  '''
  decoder_input_vae = keras.Input(shape=(latent_dim,))
  fl_dim = np.prod(cache_shape[1:])
  x = layers.Dense(fl_dim,kernel_initializer=initializer)(decoder_input_vae)
  x = layers.BatchNormalization()(x)
  x = layers.LeakyReLU(0.2)(x)
  x = layers.Reshape(cache_shape[1:])(x)
  x = layers.Conv2DTranspose(32,(3,3),strides=(2,2),padding='same',kernel_initializer=initializer)(x)
  x = layers.BatchNormalization()(x)
  x = layers.LeakyReLU(0.2)(x)
  x = layers.Conv2D(1,(3,3),padding='same',kernel_initializer=keras.initializers.glorot_normal())(x)
  x = layers.BatchNormalization()(x)
  x = layers.Activation('sigmoid')(x)

  VaeDec = keras.models.Model(decoder_input_vae,x)
  #VaeDec.summary()

  # define encoder-to-decoder
  z_decoded = VaeDec(z)

  # define loss-layer
  y = CustomVariationalLayer()([input_img_vae, z_decoded,z_mean,z_log_var])
  Vae = keras.models.Model(input_img_vae,y)

  Vae.compile(optimizer='rmsprop', loss=None)
  #Vae.summary()
  return VaeEnc, VaeDec, Vae


def dec_sampling(arg,n):
  '''Sampling z from encoded values(For decoder test)'''
  z_mean, z_log_var = arg
  ep = np.random.normal(size=(n,z_mean.shape[-1]),loc=0,scale=1)
  return z_mean + ep * np.exp(0.5 * z_log_var)
cnn = baseConvModel()
VaeEnc,VaeDec,Vae = set_vae()

でモデル(Compile後)を定義しています.

ここからはノートブックを参照ください.

1.VAE学習->一定の間隔で止める, データ増幅・保存
を繰り返したのち
2.CNNで学習,プロットを保存
を繰り返しています.(Colabでやるのは刻まないといけないので大変だった)

結果

まず,100個のデータのみを用いた時の学習結果を下に示します.
validationに対しては80%弱の予測となっており,トレーニングデータが不足していると考えられます.

オリジナルデータのみ(100枚用いた場合)

VAE_MNIST__ori100.png

下に示すのがVAEで増幅した画像で学習したときの結果です.データは
プラトーに達した段階でのvalidation-accuracyをプロットに用いました.

n_train: オリジナルの画像数
aug: 増幅画像数
破線:オリジナル画像のみを用いた時のAccuracy

VAE増幅->学習結果

VAE_MNIST_results.png

どうやら,特にオリジナル画像が少ない際は分類器にいい影響を与えそうです.
VAE-Lossが小さくなりすぎないほうが元画像から若干離れた画像を生成できていることのなるので学習に良い影響を与えるかもしれないと考えていましたが,VAE-Lossは低ければ低いほうが(≒オリジナルにより似た画像を生成できている方が)よいようです.

これが他のAugmentation手段と同時に用いることでさらにいい結果を期待できるか,他のデータセットでも同様の結果が得られるか,などはやっていないので,また時間ができたら試しにやってみようかななどと思っています.

結論

VAEはMNISTデータセットの増幅には有効.

3
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?