65
67

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

kerasでScore-CAM実装.Grad-CAMとの比較

Last updated at Posted at 2019-11-04

概要

  • 2019年10月3日にarxivに投稿された論文の手法「Score-CAM」をkerasで実装しました.
  • Grad-CAM, Grad-CAM++と比較しました.
  • 独自に高速化した Faster Score-CAM を実装しました.
  • DAGMデータセットで学習したモデルに適用しました.
  • コードはgithubにあります.
compare.png

class6.jpg

Score-CAMとは

  • CNNの判断根拠の可視化手法のひとつ
    • 判断根拠の理解についてはこちらに素晴らしいまとめがあります.
  • 先行研究として,Grad-CAM, Grad-CAM++などがあります
  • ノイズが減って安定性が向上し,勾配計算に依存しなくなりました

環境

  • Python 3.6.8
  • Keras 2.2.4
  • tensorflow-gpu 1.14.0

Score-CAMの手続き

  1. 活性化マップを取得して,最初の入力と同じサイズまでバイリニア補完で拡大する
  2. 各チャンネルを区間 [0,1] に正規化(normalize)する
  3. 入力画像に対し,各チャンネルごとに積をとってマスクされた画像をチャンネル数ぶん用意する.
  4. 各マスク済み画像をCNNに通して,ソフトマックス演算後の配列を取得する
  5. 各チャンネルの重要度を,ソフトマックス後の,対象クラスのスコアで定義する
  6. 各チャンネルに先ほどの重要度をかけて足し合わせてReLU演算を施したものを,最終的なスコアマップとする.(ReLUを通すのは,ネガティブな要素に興味がないから)

kerasで実装

import cv2
import numpy as np
from keras.models import Model

def ScoreCam(model, img_array, layer_name, max_N=-1):

    cls = np.argmax(model.predict(img_array))
    act_map_array = Model(inputs=model.input, outputs=model.get_layer(layer_name).output).predict(img_array)
    
    # extract effective maps
    if max_N != -1:
        act_map_std_list = [np.std(act_map_array[0,:,:,k]) for k in range(act_map_array.shape[3])]
        unsorted_max_indices = np.argpartition(-np.array(act_map_std_list), max_N)[:max_N]
        max_N_indices = unsorted_max_indices[np.argsort(-np.array(act_map_std_list)[unsorted_max_indices])]
        act_map_array = act_map_array[:,:,:,max_N_indices]

    input_shape = model.layers[0].output_shape[1:]  # get input shape
    # 1. upsampled to original input size
    act_map_resized_list = [cv2.resize(act_map_array[0,:,:,k], input_shape[:2], interpolation=cv2.INTER_LINEAR) for k in range(act_map_array.shape[3])]
    # 2. normalize the raw activation value in each activation map into [0, 1]
    act_map_normalized_list = []
    for act_map_resized in act_map_resized_list:
        if np.max(act_map_resized) - np.min(act_map_resized) != 0:
            act_map_normalized = act_map_resized / (np.max(act_map_resized) - np.min(act_map_resized))
        else:
            act_map_normalized = act_map_resized
        act_map_normalized_list.append(act_map_normalized)
    # 3. project highlighted area in the activation map to original input space by multiplying the normalized activation map
    masked_input_list = []
    for act_map_normalized in act_map_normalized_list:
        masked_input = np.copy(img_array)
        for k in range(3):
            masked_input[0,:,:,k] *= act_map_normalized
        masked_input_list.append(masked_input)
    masked_input_array = np.concatenate(masked_input_list, axis=0)
    # 4. feed masked inputs into CNN model and softmax
    pred_from_masked_input_array = softmax(model.predict(masked_input_array))
    # 5. define weight as the score of target class
    weights = pred_from_masked_input_array[:,cls]
    # 6. get final class discriminative localization map as linear weighted combination of all activation maps
    cam = np.dot(act_map_array[0,:,:,:], weights)
    cam = np.maximum(0, cam)  # Passing through ReLU
    cam /= np.max(cam)  # scale 0 to 1.0
    
    return cam

def softmax(x):
    f = np.exp(x)/np.sum(np.exp(x), axis = 1, keepdims = True)
    return f

VGG16でScore-CAM

ImageNetで学習済みのVGG16に適用してみます.

画像読み込み

from keras.preprocessing.image import load_img
import matplotlib.pyplot as plt

orig_img = np.array(load_img('./image/hummingbird.jpg'),dtype=np.uint8)
plt.imshow(orig_img)
plt.show()

hummingbird.png

Score-CAMによって得られるマップ

from keras.applications.vgg16 import VGG16
from gradcamutils import read_and_preprocess_img
import matplotlib.pyplot as plt

model = VGG16(include_top=True, weights='imagenet')
layer_name = 'block5_conv3'
img_array = read_and_preprocess_img('./image/hummingbird.jpg', size=(224,224))

score_cam = ScoreCam(model,img_array,layer_name)

plt.imshow(score_cam)
plt.show()

heatmap.png

引数解説

  • model: kerasのモデルインスタンス
  • img_array: 注視領域を判定したい画像の前処理後のデータ.model.predict(img_array)のように,すぐにpredictを実行できる形であること
  • layer_name: 最終convolution層の直後のactivation層の名前.activation層がconvolution層に含まれている場合はconvolution層の名前でよい.層の名前はmodel.summary()で確認できる.
  • max_N: 私が勝手に実装した高速化のための設定値.-1ならオリジナルのScore-CAM.自然数を指定すると,CNNの推論回数をその数まで削減する.推奨値は10くらい.大きい値は処理時間が伸びるだけでヒートマップにはあまり影響がありませんが,小さくしすぎるとヒートマップが変になってきます.

その他注意

  • RGB, BGRなどの3チャンネル画像を入力にもつモデルを想定しています.
  • 層の多いモデルは,最終convolution層での配列の座標と入力画像での縦横の座標が乖離していることがあり,いい感じのヒートマップが出ないことがあります.ResNetやXceptionやMobileNetなどの有名なモデルを使う場合,層の深さに注意してください

Grad-CAM, Grad-CAM++と比較

上で得られたヒートマップを,元画像と重ねて表示してみます.

Grad-CAM, Grad-CAM++についてはgradcam++ for kerasのコードを使用させていただきました.

実行コードはgithubにあります.

sample_outputs.png

  • emphasizedと表示されている画像は,ヒートマップに閾値処理を入れて,よりクッキリと注視部位を表したものです.
  • Score-CAMは注視部位を満遍なく拾えているように見えます.
  • Guided Backpropagationは一応掲載していますが,こちらで指摘されているように,ニューラルネットの情報を反映していない疑いがあります.
  • 注視している輪郭等を抽出するならば,画像としての勾配を表示したほうがまだマシということで,画像としてのgradを計算して重ねたものが最下段の画像です.

他の画像についても結果だけ表示します.

dog.png

spoonbill.png

border_collie.png

処理速度比較

Google colaboratoryで処理速度を測ります.GPU使用.

print("Grad-CAM")
%timeit grad_cam = GradCam(model, img_array, layer_name)
print("Grad-CAM++")
%timeit grad_cam_plus_plus = GradCamPlusPlus(model, img_array, layer_name)
print("Score-Cam")
%timeit score_cam = ScoreCam(model, img_array, layer_name)
print("Faster-Score-Cam N=10")
%timeit faster_score_cam = ScoreCam(model, img_array, layer_name, max_N=10)
print("Faster-Score-Cam N=3")
%timeit faster_score_cam = ScoreCam(model, img_array, layer_name, max_N=3)
print("Guided-BP}")
%timeit saliency = GuidedBackPropagation(guided_model, img_array, layer_name)
Grad-CAM
1 loop, best of 3: 196 ms per loop
Grad-CAM++
1 loop, best of 3: 221 ms per loop
Score-Cam
1 loop, best of 3: 5.24 s per loop
Faster-Score-Cam N=10
1 loop, best of 3: 307 ms per loop
Faster-Score-Cam N=3
The slowest run took 4.45 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 238 ms per loop
Guided-BP}
1 loop, best of 3: 415 ms per loop

このように,Score-CAMは非常に処理が重いことがわかります.Grad-CAMの25倍以上も時間がかかっています

処理速度改善(Faster-Score-CAM)

最終convolution層の出力(VGG16だと512チャンネル)を使って実験したところ,最終的なヒートマップの生成には数個のチャンネルが支配的であると考えて,各チャンネルのうち,潜在変数マップの分散が大きいものを優先的にマスク画像として使う という処理を加えたものがFaster-Score-CAMになります.(名前は勝手につけました.max_N=-1とすればオリジナルのScore-CAMになります)

効果は処理速度比較で載せたとおりで,10倍以上の高速化が可能です.
それでもGrad-CAM++のほうが高速です.

自前のモデルを使用する場合

実用性を確認するため,オープンデータセットで学習させた自前のモデルにScore-CAMを適用します.

データセットとしてDAGMデータセット,モデルとしてResNet(80層程度の浅いもの)を使います.

DAGMデータセット準備

DAGMデータセットをダウンロードして解凍した場所のパスdagm_pathは適宜書き換えてください.

from keras.utils import to_categorical
import numpy as np  
import glob
from sklearn.model_selection import train_test_split  
from gradcamutils import read_and_preprocess_img

num_classes = 2                               
img_size = (224,224)
dagm_path = "./DAGM"

def get_dagm_data(names):
    x = []
    y = []
    for i, name in enumerate(names):
        for path in glob.glob(f"{dagm_path}/{name}/*.png"):    
            img_array = read_and_preprocess_img(path, size=img_size)
            x.append(img_array)  
            y.append(i) 

    x = np.concatenate(x, axis=0)   
    y = np.array(y)  

    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=111)

    y_train = to_categorical(y_train, num_classes)
    y_test = to_categorical(y_test, num_classes)

    print(x_train.shape[0], 'train samples')
    print(x_test.shape[0], 'test samples')
    return x_train, x_test, y_train, y_test

x_train, x_test, y_train, y_test = get_dagm_data(["Class1","Class1_def"])

ResNet準備

横着してkerasのapplicationsに入っているResNetから切り取ったものを使用します.(あまり良い書き方ではありませんが動くので良しとします)

from keras.applications.resnet50 import ResNet50
from keras.models import Model
from keras.optimizers import Adam
from keras.layers import Dense, Input, Activation, GlobalAveragePooling2D
from keras.callbacks import EarlyStopping, ModelCheckpoint

def build_ResNet():
    model = ResNet50(include_top=True, input_tensor=Input(shape=(img_size[0],img_size[1],3)))

    x = model.layers[-98].output
    x = Activation('relu', name="act_last")(x)
    x = GlobalAveragePooling2D()(x)
    x = Dense(2, name="dense_out")(x)
    outputs = Activation('softmax')(x)

    model = Model(model.input, outputs)
    # model.summary()

    model.compile(loss='binary_crossentropy',
                  optimizer=Adam(amsgrad=True),
                  metrics=['accuracy'])
    return model

model = build_ResNet()

es_cb = EarlyStopping(monitor='val_loss', patience=5, verbose=1, mode='auto')
chkpt = './resnet_weight_DAGM.h5'
cp_cb = ModelCheckpoint(filepath = chkpt, monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=True, mode='auto')

epochs = 15
batch_size = 32

history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    validation_data=(x_test, y_test),
                    callbacks=[es_cb,cp_cb],
                    class_weight={0: 1., 1: 6.},
                    shuffle=True)

# 重みをロード
model.load_weights('./resnet_weight_DAGM.h5')

傷あり画像に対してGrad-CAM, ++, Score-CAM適用

import matplotlib.pyplot as plt
import cv2
import numpy as np
from gradcamutils import GradCam, GradCamPlusPlus, ScoreCam, GuidedBackPropagation, superimpose, read_and_preprocess_img

def build_ResNet_and_load():
    model = build_ResNet()
    model.load_weights('./resnet_weight_DAGM.h5')
    return model

img_path = f'{dagm_path}/Class1_def/12.png'
orig_img = np.array(load_img(img_path),dtype=np.uint8)
img_array = read_and_preprocess_img(img_path, size=(224,224))

layer_name = "act_last"

grad_cam=GradCam(model,img_array,layer_name)
grad_cam_superimposed = superimpose(img_path, grad_cam)
grad_cam_emphasized = superimpose(img_path, grad_cam, emphasize=True)

grad_cam_plus_plus=GradCamPlusPlus(model,img_array,layer_name)
grad_cam_plus_plus_superimposed = superimpose(img_path, grad_cam_plus_plus)
grad_cam_plus_plus_emphasized = superimpose(img_path, grad_cam_plus_plus, emphasize=True)

score_cam=ScoreCam(model,img_array,layer_name)
score_cam_superimposed = superimpose(img_path, score_cam)
score_cam_emphasized = superimpose(img_path, score_cam, emphasize=True)

faster_score_cam=ScoreCam(model,img_array,layer_name, max_N=10)
faster_score_cam_superimposed = superimpose(img_path, faster_score_cam)
faster_score_cam_emphasized = superimpose(img_path, faster_score_cam, emphasize=True)

guided_model = build_guided_model(build_ResNet_and_load)
saliency = GuidedBackPropagation(guided_model, img_array, layer_name)
saliency_resized = cv2.resize(saliency, (orig_img.shape[1], orig_img.shape[0]))

grad_cam_resized = cv2.resize(grad_cam, (orig_img.shape[1], orig_img.shape[0]))
guided_grad_cam = saliency_resized * grad_cam_resized[..., np.newaxis]

grad_cam_plus_plus_resized = cv2.resize(grad_cam_plus_plus, (orig_img.shape[1], orig_img.shape[0]))
guided_grad_cam_plus_plus = saliency_resized * grad_cam_plus_plus_resized[..., np.newaxis]

score_cam_resized = cv2.resize(score_cam, (orig_img.shape[1], orig_img.shape[0]))
guided_score_cam = saliency_resized * score_cam_resized[..., np.newaxis]

faster_score_cam_resized = cv2.resize(faster_score_cam, (orig_img.shape[1], orig_img.shape[0]))
guided_faster_score_cam = saliency_resized * faster_score_cam_resized[..., np.newaxis]

img_gray = cv2.imread(img_path, 0)
dx = cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=3)
dy = cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=3)
grad = np.sqrt(dx ** 2 + dy ** 2)  # 画像の勾配を取得
grad = cv2.dilate(grad,kernel=np.ones((5,5)), iterations=1)  # 太らせる処理
grad -= np.min(grad)
grad /= np.max(grad)  # scale 0. to 1.

grad_times_grad_cam = grad * grad_cam_resized
grad_times_grad_cam_plus_plus = grad * grad_cam_plus_plus_resized
grad_times_score_cam = grad * score_cam_resized
grad_times_faster_score_cam = grad * faster_score_cam_resized

fig, ax = plt.subplots(nrows=4,ncols=5, figsize=(18, 16))
ax[0,0].imshow(orig_img)
ax[0,0].set_title("input image")
ax[0,1].imshow(grad_cam_superimposed)
ax[0,1].set_title("Grad-CAM")
ax[0,2].imshow(grad_cam_plus_plus_superimposed)
ax[0,2].set_title("Grad-CAM++")
ax[0,3].imshow(score_cam_superimposed)
ax[0,3].set_title("Score-CAM")
ax[0,4].imshow(faster_score_cam_superimposed)
ax[0,4].set_title("Faster-Score-CAM")
ax[1,0].imshow(orig_img)
ax[1,0].set_title("input image")
ax[1,1].imshow(grad_cam_emphasized)
ax[1,1].set_title("Grad-CAM emphasized")
ax[1,2].imshow(grad_cam_plus_plus_emphasized)
ax[1,2].set_title("Grad-CAM++ emphasized")
ax[1,3].imshow(score_cam_emphasized)
ax[1,3].set_title("Score-CAM emphasized")
ax[1,4].imshow(faster_score_cam_emphasized)
ax[1,4].set_title("Faster-Score-CAM emphasized")
ax[2,0].imshow(saliency_resized)
ax[2,0].set_title("Guided-BP")
ax[2,1].imshow(guided_grad_cam)
ax[2,1].set_title("Guided-Grad-CAM")
ax[2,2].imshow(guided_grad_cam_plus_plus)
ax[2,2].set_title("Guided-Grad-CAM++")
ax[2,3].imshow(guided_score_cam)
ax[2,3].set_title("Guided-Score-CAM")
ax[2,4].imshow(guided_faster_score_cam)
ax[2,4].set_title("Guided-Faster-Score-CAM")
ax[3,0].imshow(grad, 'gray')
ax[3,0].set_title("grad")
ax[3,1].imshow(grad_times_grad_cam, 'gray')
ax[3,1].set_title("grad * Grad-CAM")
ax[3,2].imshow(grad_times_grad_cam_plus_plus, 'gray')
ax[3,2].set_title("grad * Grad-CAM++")
ax[3,3].imshow(grad_times_score_cam, 'gray')
ax[3,3].set_title("grad * Score-CAM")
ax[3,4].imshow(grad_times_faster_score_cam, 'gray')
ax[3,4].set_title("grad * Faster-Score-CAM")
for i in range(4):
    for j in range(5):
        ax[i,j].axis('off')
plt.show()

class1.png

  • どの手法もうまく傷の位置を検出できているようです.
  • 傷だけを強調するような表示は難しいようです.(最終conv層の情報を取り出しているので仕方ない)

Class2からClass6まで

各クラス5枚ずつ,閾値処理を施した結果だけ掲載します.

Class2

Class2_result_0.png

Class3

Class3_result_0.png

Class4

Class4_result_0.png

Class5

Class5_result_0.png

Class6

Class6_result_0.png

概ね正しく傷の位置を表せているようです.

所感

  • 異常検知において,異常部位の大まかな可視化に使えそう.
  • Grad-CAM, ++, Score-CAMのいずれも,使えるモデルが制限されるのが非常に不便.
    • 最終conv層の座標と入力画像の座標がうまく対応する必要があり,層数の多いモデルが使いづらい.

参照

65
67
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
65
67

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?