21
21

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Variational Autoencoderを使った画像の異常検知 後編 (塩尻MLもくもく会#7)

Last updated at Posted at 2018-08-18

前編では、ヒートマップを使って異常個所を可視化しました。

後編では、従来手法と提案手法を数値的に比較してみます。
使うツールは、ROC曲線です。

#ROC曲線
ROC曲線は、ラベル間のデータ数に差があるときに使われます。
ROC曲線について、詳しく知りたい方はこちら↓を参考にしてください。
http://www.randpy.tokyo/entry/roc_auc

今回は、ラベル間のデータ数に大きな差はありませんが、論文にならって
ROC曲線を描画します。最終的に、AUCを算出して、これが大きい方が優れた
異常検知器といえます。

#コードの解説
コードの一部を解説します。

##異常スコアの算出
論文によると、異常判定の基準は以下のとおりです。

このパッチ全てに関して異常度を算出し、少なくとも1枚が閾値を
超えてる場合、テストデータは異常であると判断した.

パッチというのは、前編でお伝えした小窓に相当します。
前編では、下記のヒートマップを作成するにあたり、元の画像(サイズ28×28)に
対し、小窓(サイズ8×8)を上下左右に走らせました。

new_.png

このとき、1枚の画像から小窓のスコアは21×21=441個出てきます。
論文に従うと、441個の小窓のスコアの中から、その最大値(スコアが高いほど異常度が
高いです。)をその画像の異常スコアとします。

そして、閾値を決め、閾値より異常スコアが高いものを「異常画像」、
低いものを「正常画像」と判定します。

まとめると、以下のとおりです。

1枚の画像 → 小窓を走らせる → 441個のスコア → スコアの最大値 = 異常スコア → 閾値と比較

pythonのコードは以下のとおりです。関数には、テストデータを渡して、
小窓を走らせながら、画像1枚1枚に異常スコアを付けています。
nameの名前で、従来手法と提案手法を切り替えて評価しています。

#最大異常値の計算
def result_score(model, x, name, height=8, width=8, move=2):
    score = []
    
    for k in range(len(x)):
        max_score = -1000000000
        if k%100 == 0:
          print(k)
        
        for i in range(int((x.shape[1]-height)/move)):
            for j in range(int((x.shape[2]-width)/move)):
                x_sub = x[k, i*move:i*move+height, j*move:j*move+width, 0]
                x_sub = x_sub.reshape(1, height, width, 1)
            
                #従来手法
                if name == "old_":
                    #スコア
                    temp_score = model.evaluate(x_sub, batch_size=1, verbose=0)
                    if temp_score > max_score:
                        max_score = temp_score
                
                #提案手法
                else:
                    #スコア
                    mu, sigma = model.predict(x_sub, batch_size=1, verbose=0)
                    loss = 0
                    for o in range(height):
                        for l in range(width):
                            loss += 0.5 * (x_sub[0,o,l,0] - mu[0,o,l,0])**2 / sigma[0,o,l,0]
                    if loss > max_score:
                       max_score = loss
        
        score.append(max_score)
        
    return(score)

##ROC曲線の描画
ROC曲線の描画は、以下の記事を参考にしました。
https://qiita.com/9pid/items/53946c3ec4b2489e7cb2

scikit-learnを使えば一発です。

#MNISTを使った結果
※11/21修正 コードにバグがあった関係で、図と文章を全面的に修正しました。

Colaboratoryの場合、計算時間は学習で10分、テストデータの評価で2時間ほど
かかります。

ROC曲線は以下のとおりです。
ROC curve.png

グラフの中のareaがAUCに相当します。そして、AUCが高いほうが優秀な異常検知器です。

予想に反して、従来手法の方が良い結果となりました。
提案手法の精度が落ちた理由は、後で考察します。

#Fashion-MNISTを使った結果
※11/21修正 コードにバグがあった関係で、図と文章を全面的に修正しました。

計算時間はMNISTと同じくらいでした。

ROC curve.png

こちらはMNISTと逆で、提案手法の方が良い結果となりました。
この理由も後で考察します。

#考察 11/21追記

MNISTとFashion-MNISTで結果が変わった理由を考察してみると、
MNISTでは以下の図のような学習結果になっていると思われます。

無題.png

そして、テストの際に1と9を入力すると、直線部分の画像は右の山に
位置してきます。一方、異常判別の材料となる画像(1の折れ曲がり部や
9の曲線部分の画像)は左の山に位置してきます。そして、1の折れ曲
がり部は左の山の「中央」付近に、9の曲線部分は左の山の「端」に
位置することになり、異常判別ができます。

ここで重要なのは、右の山は無視して良いということです。つまり、
異常か正常かを判断する材料は左の山だけで十分です。

ところが、提案手法では右の山を下げて、左の山と同じ土俵で勝負させます。
よって、提案手法では、本来無視してよい右の山(直線部分の画像)も判別し
異常と認識してしまっている可能性が高く、精度を下げる一因になっていると
思われます。

一方、Fashion-MNISTはブーツの かかと は複雑なため左の山に、
ブーツの口は単純なため右の山に属していると思われます。従って、
どちらの山も判断材料になるため、左右の山を同じ土俵で勝負させる
提案手法の方が精度が出たと思われます。

まとめますと、「複雑な」画像にしか異常が出ない場合は、従来手法
方が優れていると思われます。工業製品のように「単純な」ところにも
「複雑な」ところにも異常が出る場合は、提案手法の方が優れていると
思われます。

ちなみに、Pythonでコードを組む場合、計算速度については、従来手法の方が
速く(model.evaluate()を使っているため)現場向きかと思われます。
一方、提案手法はfor文を駆使しているため、速度が遅くなります。

#まとめ
※11/21修正 コードにバグがあった関係で、文章を修正しました。

前編と後編に分けて、VAEによる画像の異常検知を実装してきました。
見てお分かりのとおり、提案手法は視覚的にも能力的にも優れていること
が分かりました。

能力の面では、考察に示したとおり、異常検知の対象となる画像の性質を
考えて使い分けした方が良さそうです。

余裕があったら、本手法を使って画像以外の異常検知もやってみたいと思います。

2018/1/8追記
本手法より精度が良い論文の記事を書きました。
https://qiita.com/shinmura0/items/cfb51f66b2d172f2403b

#コード全文
9/15 追記:コードにバグがありました。お詫びを申し上げ、修正させていただきます。
12/18 コード修正 (@gungiven さんご指摘ありがとうございます。)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from keras.layers import Lambda, Input, Dense, Reshape
from keras.models import Model
from keras.datasets import mnist
from keras.datasets import fashion_mnist
from keras.losses import mse
from keras.utils import plot_model
from keras import backend as K
from keras.layers import BatchNormalization, Activation, Flatten
from keras.layers.convolutional import Conv2DTranspose, Conv2D

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import os
from sklearn import metrics

#最大異常値の計算
def result_score(model, x, name, height=8, width=8, move=2):
    score = []
    
    for k in range(len(x)):
        max_score = -1000000000
        if k%100 == 0:
          print(k)
        
        for i in range(int((x.shape[1]-height)/move)+1):
            for j in range(int((x.shape[2]-width)/move)+1):
                x_sub = x[k, i*move:i*move+height, j*move:j*move+width, 0]
                x_sub = x_sub.reshape(1, height, width, 1)
            
                #従来手法
                if name == "old_":
                    #スコア
                    temp_score = model.evaluate(x_sub, batch_size=1, verbose=0)
                    if temp_score > max_score:
                        max_score = temp_score
                
                #提案手法
                else:
                    #スコア
                    mu, sigma = model.predict(x_sub, batch_size=1, verbose=0)
                    loss = 0
                    for o in range(height):
                        for l in range(width):
                            loss += 0.5 * (x_sub[0,o,l,0] - mu[0,o,l,0])**2 / sigma[0,o,l,0]
                    if loss > max_score:
                       max_score = loss
        
        score.append(max_score)
        
    return(score)
            
#8×8のサイズに切り出す
def cut_img(x, number, height=8, width=8):
    print("cutting images ...")
    x_out = []
    x_shape = x.shape
    
    for i in range(number):
        shape_0 = np.random.randint(0,x_shape[0])
        shape_1 = np.random.randint(0,x_shape[1]-height)
        shape_2 = np.random.randint(0,x_shape[2]-width)
        temp = x[shape_0, shape_1:shape_1+height, shape_2:shape_2+width, 0]
        x_out.append(temp.reshape((height, width, x_shape[3])))
    
    print("Complete.")
    x_out = np.array(x_out)
    
    return x_out

# reparameterization trick
# instead of sampling from Q(z|X), sample eps = N(0,I)
# z = z_mean + sqrt(var)*eps
def sampling(args):
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean=0 and std=1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon
    
# dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
#(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

#1と9のデータ抽出
x_train_1 = []
x_test_1 = []
x_test_9 = []

x_train_shape = x_train.shape

for i in range(len(x_train)):
  if y_train[i] == 1:#スニーカーは7
    temp = x_train[i,:,:,:]
    x_train_1.append(temp.reshape((x_train_shape[1],x_train_shape[2],x_train_shape[3])))
    
x_train_1 = np.array(x_train_1)
x_train_1 = cut_img(x_train_1, 100000)
print("train data:",len(x_train_1))

for i in range(len(x_test)):
  if y_test[i] == 1:#スニーカーは7
    temp = x_test[i,:,:,:]
    x_test_1.append(temp.reshape((x_train_shape[1],x_train_shape[2],x_train_shape[3])))
    
  if y_test[i] == 9:
    temp = x_test[i,:,:,:]
    x_test_9.append(temp.reshape((x_train_shape[1],x_train_shape[2],x_train_shape[3])))
    
x_test_1 = np.array(x_test_1)
x_test_9 = np.array(x_test_9)

# network parameters
input_shape=(8, 8, 1)
batch_size = 128
latent_dim = 2
epochs = 10
Nc = 16

# build encoder model
inputs = Input(shape=input_shape, name='encoder_input')
x = Conv2D(Nc, kernel_size=2, strides=2)(inputs)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2D(2*Nc, kernel_size=2, strides=2)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Flatten()(x)

z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
#encoder.summary()

# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(2*2)(latent_inputs)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Reshape((2,2,1))(x)
x = Conv2DTranspose(2*Nc, kernel_size=2, strides=2, padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2DTranspose(Nc, kernel_size=2, strides=2, padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)

x1 = Conv2DTranspose(1, kernel_size=4, padding='same')(x)
x1 = BatchNormalization()(x1)
out1 = Activation('sigmoid')(x1)#out.shape=(n,28,28,1)

x2 = Conv2DTranspose(1, kernel_size=4, padding='same')(x)
x2 = BatchNormalization()(x2)
out2 = Activation('sigmoid')(x2)#out.shape=(n,28,28,1)

decoder = Model(latent_inputs, [out1, out2], name='decoder')
#decoder.summary()

# build VAE model
outputs_mu, outputs_sigma_2 = decoder(encoder(inputs)[2])
vae = Model(inputs, [outputs_mu, outputs_sigma_2], name='vae_mlp')

# VAE loss
m_vae_loss = (K.flatten(inputs) - K.flatten(outputs_mu))**2 / K.flatten(outputs_sigma_2)
m_vae_loss = 0.5 * K.sum(m_vae_loss)
    
a_vae_loss = K.log(2 * 3.14 * K.flatten(outputs_sigma_2))
a_vae_loss = 0.5 * K.sum(a_vae_loss)
        
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
    
vae_loss = K.mean(kl_loss + m_vae_loss + a_vae_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')

# train the autoencoder
vae.fit(x_train_1,
        epochs=epochs,
        batch_size=batch_size)
        #validation_data=(x_test, None))
vae.save_weights('vae_mlp_mnist.h5')

#正常/異常のテストデータ
test_normal = x_test_1
test_anomaly = x_test_9
    
#従来手法の評価
print("normal test data:",len(test_normal))
old_score_normal = result_score(vae, test_normal, "old_")
print("anomaly test data:",len(test_anomaly))
old_score_anomaly = result_score(vae, test_anomaly, "old_")

#提案手法の評価
print("normal test data:",len(test_normal))
new_score_normal = result_score(vae, test_normal, "new_")
print("anomaly test data:",len(test_anomaly))
new_score_anomaly = result_score(vae, test_anomaly, "new_")

#新旧手法のスコア可視化
path = 'images/'
if not os.path.exists(path):
      os.mkdir(path)
    
plt.figure()
plt.plot(old_score_normal,label="normal")
plt.plot(old_score_anomaly,label="anomaly",c="red")
plt.title("Old method")
plt.xlabel("Test No")
plt.ylabel("Score")
plt.legend()
plt.savefig(path + "Old method.png")
plt.show

plt.figure()
plt.plot(new_score_normal,label="normal")
plt.plot(new_score_anomaly,label="anomaly",c="red")
plt.title("New method")
plt.xlabel("Test No")
plt.ylabel("Score")
plt.legend()
plt.savefig(path + "New method.png")
plt.show()

#ROC曲線の描画
y_true = np.zeros(len(test_normal)+len(test_anomaly))
y_true[len(test_normal):] = 1#0:正常、1:異常
old_score = np.array(old_score_normal)
old_score = np.hstack((old_score, np.array(old_score_anomaly)))
new_score = np.array(new_score_normal)
new_score = np.hstack((new_score, np.array(new_score_anomaly)))

# FPR, TPR(, しきい値) を算出
fpr_old, tpr_old, _ = metrics.roc_curve(y_true, old_score)
fpr_new, tpr_new, _ = metrics.roc_curve(y_true, new_score)

# AUC
auc_old = metrics.auc(fpr_old, tpr_old)
auc_new = metrics.auc(fpr_new, tpr_new)
    
# ROC曲線をプロット
plt.figure
plt.plot(fpr_old, tpr_old, label='Old method (area = %.2f)'%auc_old)
plt.plot(fpr_new, tpr_new, label='New method (area = %.2f)'%auc_new, c="r")
plt.legend()
plt.title('ROC curve')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.grid(True)
plt.savefig(path + "ROC curve.png")
plt.show()

21
21
5

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
21

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?