LoginSignup
0
0

More than 1 year has passed since last update.

SSIMオートエンコーダ(全結合層)による異常検知

Last updated at Posted at 2021-11-05

背景

全結合層でオートエンコーダし、画像の異常検知を行う。
loss関数にSSIMを使用すれば構造的類似性の観点からより精度の良いオートエンコーダ結果が得られることから、MNISTデータを用いて検証する。
また、異常度の出し方を画像の輝度差分ではなくSSIMで出すことでどういう値になるかを確認する。
本ページはコードの記録のみとする。検証結果は別紙参照のこと。

やったこと

MNIST数字データで1のみを学習データに使用して全結合層オートエンコーダモデルを構築。1と9の混じったテストデータで9を異常と検知するか確認する。
異常度を算出し、ヒートマップで異常位置を可視化する。

#必要ライブラリ
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import tensorflow.keras

#MNIST読み込み
(x_train, y_train), (x_test, y_test) = mnist.load_data()

#学習データを1のみにする
x1=[]
for i in range(len(x_train)):
    if y_train[i] == 1:
        x1.append(x_train[i])
x_train = np.array(x1)

#テストデータを1と9にする
x2, y = [], []
for i in range(len(x_test)):
    if y_test[i] == 1 or y_test[i] == 9:
        x2.append(x_test[i])
        y.append(y_test[i])
x_test = np.array(x2)
y = np.array(y)

#形状確認
print("x_train.shape;", x_train.shape)#学習データ1
print("x_test.shape;", x_test.shape)#テストデータ1 or 9
print("y.shape:", y.shape)#テストデータのラベル

#スプリットし検証用データ生成
from sklearn.model_selection import train_test_split
train_data, val_data = train_test_split(x_train, test_size=0.2, random_state=0)
print("train_data.shape;", train_data.shape)#学習データ1
print("val_data.shape;", val_data.shape)#検証データ1
print("x_test.shape;", x_test.shape)#テストデータ1 or 9

#SSIMは4次元で受け付けるので4次元へ変更
train_data = np.expand_dims(train_data, 3)
val_data = np.expand_dims(val_data, 3)
x_test = np.expand_dims(x_test, 3)
print("train_data.shape;", train_data.shape)#学習データ1
print("val_data.shape;", val_data.shape)#検証データ1
print("x_test.shape;", x_test.shape)#テストデータ1 or 9

#float32へ変更
train_data = train_data.astype("float32")
val_data = val_data.astype("float32")
x_test = x_test.astype("float32")

#SSIM関数定義
def ssim_loss(y_true, y_pred):
    return 1-tf.reduce_mean(tf.image.ssim(y_true, y_pred, 
                                          max_val = 1.0,filter_size=11,
                                        filter_sigma=1.5, k1=0.01, k2=0.03 ))


#学習モデル
from tensorflow.keras.layers import Input, Dense, Flatten, Reshape
from tensorflow.keras.models import Model

INPUT_SHAPE = (train_data[0].shape[0], train_data[0].shape[1], 1)
FLATTENED_SHAPE = INPUT_SHAPE[0]*INPUT_SHAPE[1]

input_img = Input(shape=(INPUT_SHAPE))
encoded = Flatten()(input_img)
encoded = Dense(256, activation="relu")(encoded)
encoded = Dense(64, activation="relu")(encoded)
decoded = Dense(256, activation="relu")(encoded)
decoded = Dense(FLATTENED_SHAPE, activation="relu")(decoded)
decoded = Reshape(INPUT_SHAPE)(decoded)
autoencoder = Model(inputs = input_img, outputs=decoded)

#コンパイル
autoencoder.compile(optimizer="adam", loss=ssim_loss)
autoencoder.summary()


#学習
fit_record = autoencoder.fit(train_data, train_data,
                            epochs=50,
                            batch_size=64,
                            shuffle=True,
                            validation_data=(val_data, val_data))


#可視化グラフ
def plot_loss_accuracy_graph(fit_record):
    plt.plot(fit_record.history["loss"], "-D", color="blue", label="train_loss", linewidth=2)
    plt.plot(fit_record.history["val_loss"], "-D", color="black", label="val_loss", linewidth=2)
    plt.title("LOSS")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend(loc="upper right")
    plt.show()
#グラフ化
plot_loss_accuracy_graph(fit_record)

#予測
pred = autoencoder.predict(x_test)

#可視化(異常度は画像の輝度差分)
scores = []

n=10
plt.figure(figsize=(18,9))
for i in range(n):
    #テスト画像を表示
    ax = plt.subplot(3, n, i+1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    #出力画像を表示
    ax = plt.subplot(3, n, i+1+n)
    plt.imshow(pred[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    #入出力の差分を計算
    diff_img = x_test[i].reshape(28, 28) - pred[i].reshape(28, 28)

    #入出力の差分数値(異常度)を計算
    diff = np.sum(np.abs(x_test[i].reshape(28, 28) - pred[i].reshape(28, 28)))
    scores.append(diff)

    #差分画像と差分数値(異常度)を表示
    ax = plt.subplot(3, n, i+1+n*2)
    plt.imshow(diff_img, cmap="jet")

    #plt.gray()
    ax.get_xaxis().set_visible(True)
    ax.get_yaxis().set_visible(True)
    ax.set_xlabel("score= "+str(diff))

plt.savefig("result.png")
plt.show()
plt.close()


#可視化(異常度はssimから算出)_ 画像にするため2次元にreshape必要
n = 6  
plt.figure(figsize=(20, 14), dpi=100)
plt.subplots_adjust( wspace=0.1, hspace=0.07)
plt_a=1
for i in range(n):
    # 検証データ vs 検証データ
    ax = plt.subplot(3, n, plt_a)
    plt.imshow(x_test[i].reshape(28,28))
    ax.get_xaxis().set_visible(True)
    ax.get_yaxis().set_visible(False)
    value_a = ssim_loss(x_test[i], x_test[i]) #同一のものの類似度
    ax.set_title("Original Image")
    label = 'SSIM Loss value: {:.3f}'
    ax.set_xlabel(label.format(value_a) )

    # 再構築データ(予測)  vs 検証データ
    ax = plt.subplot(3, n, plt_a + n )
    plt.imshow(pred[i].reshape(28,28))
    ax.get_xaxis().set_visible(True)
    ax.get_yaxis().set_visible(False)    
    value_a = ssim_loss(pred[i], x_test[i])
    ax.set_title("Reconstructed Image")
    label = 'SSIM Loss value: {:.2}'
    ax.set_xlabel(label.format(value_a) )

    plt_a+=1
plt.show()

学習結果はlossとval lossで同じ傾向であり、良好。
異常個所が明確に色付けされたヒートマップが得られ、異常度も1と9でしっかり層別できた。
今後は製造物に対してssimを検討する予定。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0