0
1

More than 1 year has passed since last update.

tf-keras-visで画像Classificationの判断根拠を可視化する(Grad-CAM, Grad-CAM++, Score-CAM, Layer-CAM)

Posted at

要約

  • 画像Classificationモデルの判断根拠を可視化する
  • CAM系の手法をふんわり理解する
  • tf-keras-visを使ってみる

環境

  • Python 3.6.8
  • tensorflow 2.3.0
  • tf-keras-vis 0.8.0

手法

より詳しく正確な説明がこちらでされています。
ここではふんわりとまとめました。

Grad-CAM

  • 「判断根拠となる領域ほど重みが大きく更新される」という仮定に基づく。
    gradcam.png
  • 可視化対象クラスにおける最終層の勾配(①)に対してGlobalAveragePoolingを行う(②)
  • (②)を各チャネルの重要度として、特徴マップに掛け合わせる(③)
  • (③)のsummationに対してReLUを掛けたもの(④)を得る →Grad-CAM
  • 文献リンク

Grad-CAM++

  • Grad-CAMの発展形
  • Grad-CAMとの差分
    • 特徴マップ(上図③)に対してReLUを掛けるようにした。
    • 各チャネルの重要度(上図②)のsummationが対象クラスのスコアになるようにした。
  • 文献リンク

Score-CAM

  • 勾配情報を用いない可視化方法。Grad-CAMやGrad-CAM++における勾配消失の懸念を払拭した。
    scorecam.png
  • Phase 1では特徴マップを元画像と同じサイズにUpsamplingする。
  • Phase 2ではPhase 1の特徴マップをマスクとして元画像に掛け合わせ、そのoverlay画像を元にモデルに推論させる →これをPhase 1の特徴マップの重要度とする
    • 対象クラスに対する推論スコアが大きいなら、そのマスク自体が推論に大きな影響を与えていると考えられる
  • Phase 1の特徴マップとPhase 2の重要度を掛け合わせる →Score-CAM
  • 文献リンク

Layer-CAM

  • 従来のX-CAMでは深い層の特徴マップを対象としており、元画像サイズにUpsamplingことで粗くなってしまうという課題があった。
    layercam.png
  • 浅い層も含めてX-CAMを適用し、得られたSaliency mapをsummationする。
  • 文献リンク

コード

今回はimagenetで提供されているピザの画像を使います。おいしそう。
imagenet_pizza.png

import numpy as np
import pprint
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
from tf_keras_vis.utils.scores import CategoricalScore
from tf_keras_vis.gradcam import Gradcam
from tf_keras_vis.gradcam_plus_plus import GradcamPlusPlus
from tf_keras_vis.scorecam import Scorecam
from tf_keras_vis.layercam import Layercam

MobileNetV2のimagenet学習済モデルを使用します。

model = tf.keras.applications.mobilenet_v2.MobileNetV2(weights='imagenet', include_top=True)

# 画像の読み込みと前処理
img = load_img('n07873807_pizza.JPEG', target_size=(224, 224))
img = np.asarray(img)
X = preprocess_input(img)
X = X[tf.newaxis, ...]
plt.axis('off')
plt.imshow(img)

# 推論
predict = model.predict(X)
result = tf.keras.applications.mobilenet_v2.decode_predictions(predict, top=5)
pprint.pprint(result)

おいしそうなpizzaのclass indexを調べておきます。

class_indexes = tf.keras.applications.mobilenet_v2.decode_predictions(np.expand_dims(np.arange(1000), 0), top=1000)
for class_info in class_indexes[0]:
    if class_info[1] == 'pizza':
        print(class_info[2])

963

ここでtf-keras-visで提供されているクラスを使用します。
今回はGrad-CAM, Grad-CAM++, Score-CAM, Layer-CAMで可視化します。

# モデル最終層のactivationをlinearに置き換えるクラス
replace2linear = ReplaceToLinear()
# Scoreを取得するクラス
class_index = 963
score = CategoricalScore([class_index])

# Generate instance
gradcam = Gradcam(model, model_modifier=replace2linear, clone=True)
gradcam_pp = GradcamPlusPlus(model, model_modifier=replace2linear, clone=True)
scorecam = Scorecam(model)
layercam = Layercam(model, model_modifier=replace2linear, clone=True)

# Generate heatmap with X-CAM
target_layer_index = -2
heatmap_gradcam = gradcam(score, X, penultimate_layer=target_layer_index)
heatmap_gradcam_pp = gradcam_pp(score, X, penultimate_layer=target_layer_index)
heatmap_scorecam = scorecam(score, X, penultimate_layer=target_layer_index)
heatmap_layercam = layercam(score, X, penultimate_layer=target_layer_index)

元画像にheatmapをoverlayして結果を表示します。

cams = {
    'GradCAM': heatmap_gradcam,
    'GradCAM++': heatmap_gradcam_pp,
    'ScoreCAM': heatmap_scorecam,
    'LayerCAM': heatmap_layercam
}

fig, axes = plt.subplots(1, len(cams), figsize=(20, 10))
for i, (cam_name, cam) in enumerate(cams.items()):
    heatmap = np.uint8(cm.jet(cam[0])[..., :3] * 255)
    axes[i].imshow(img)
    axes[i].imshow(heatmap, cmap='jet', alpha=0.5) 
    axes[i].set_title(cam_name)
    axes[i].axis('off')
plt.tight_layout()
plt.show()

cams_results.png
Layer-CAMは全体を注視している印象です。

この画像でも試してみました。対象クラスはmenuです。
imagenet_menu.png
menuの判断として、料理のカテゴリが記載されている画像上部を注視していることが分かります。
cams_results_menu.png

References

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1