5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

tf-keras-visを使ってGradCAM、GradCAM++、ScoreCAM、Faster-ScoreCAM

Posted at

はじめに

ディープラーニングでの予測結果における特徴部位可視化でよく使われるGradCAMなどを出力してくれるツールtf-keras-visですが、すごく簡単に出力できるうえGradCAMだけでなくGradCAM++、ScoreCAM、Faster-ScoreCAM、Vanilla Saliency、SmoothGradといろんな種類の可視化を行うことができます
そんなtf-keras-visですが、公式のサンプルだとloss関数のクラスインデックスを書き換えて利用する方式なので、クラスインデックス・予測画像・認識モデルの3つの引数で取得するサンプルを作りました

tf-keras-vis
https://github.com/keisen/tf-keras-vis

検証環境

この記事の内容は、以下の環境で検証しました。
Python 3.6.9
TensorFlow 2.4.0-rc0
tf-keras-vis 0.5.3

環境準備

すでに必要なモジュールが入っているならとばしてください

pip install --upgrade tf-keras-vis matplotlib

モジュールの読み込み

import os
import glob
import numpy as np 
import matplotlib.pyplot as plt
import tensorflow as tf

モデル用モジュール

from tensorflow.keras.applications.vgg16 import VGG16 as Model
from tensorflow.keras.applications.vgg16 import preprocess_input

モデル読み込み

model = VGG16(include_top=False, weights='imagenet')
# model = tf.keras.models.load_model('mymodel.h5')
model.summary()

tf-keras-vis用モジュール

from tf_keras_vis.saliency import Saliency
from tf_keras_vis.gradcam import Gradcam
from tf_keras_vis.gradcam import GradcamPlusPlus
from tf_keras_vis.scorecam import ScoreCAM
from tf_keras_vis.utils import normalize
from matplotlib import cm

特徴可視化マップ取得関数

SmoothGrad

def GetSmoothGrad(cls_index, img, model):
  def loss(output):
    return (output[0][cls_index])
  def model_modifier(m):
    m.layers[-1].activation = tf.keras.activations.linear
    return m
  saliency = Saliency(model,model_modifier=model_modifier,clone=False)
  cam = saliency(loss, img, smooth_samples=20, smooth_noise=0.20)
  cam = normalize(cam)
  heatmap = np.uint8(cm.jet(cam[0])[..., :3] * 255)
  return heatmap

GradCAM

def GetGradCAM(cls_index, img, model):
  def loss(output):
    return (output[0][cls_index])
  def model_modifier(m):
    m.layers[-1].activation = tf.keras.activations.linear
    return m
  gradcam = Gradcam(model,model_modifier=model_modifier,clone=False)
  cam = gradcam(loss, img, penultimate_layer=-1)
  cam = normalize(cam)
  heatmap = np.uint8(cm.jet(cam[0])[..., :3] * 255)
  return heatmap

GradCAM++

def GetGradCAMPlusPlus(cls_index, img, model):
  def loss(output):
    return (output[0][cls_index])
  def model_modifier(m):
    m.layers[-1].activation = tf.keras.activations.linear
    return m
  gradcam = GradcamPlusPlus(model,model_modifier=model_modifier,clone=False)
  cam = gradcam(loss, img, penultimate_layer=-1)
  cam = normalize(cam)
  heatmap = np.uint8(cm.jet(cam[0])[..., :3] * 255)
  return heatmap

ScoreCAM

def GetScoreCAM(cls_index, img, model):
  def loss(output):
    return (output[0][cls_index])
  def model_modifier(m):
    m.layers[-1].activation = tf.keras.activations.linear
    return m
  scorecam = ScoreCAM(model,model_modifier=model_modifier,clone=False)
  cam = scorecam(loss, img, penultimate_layer=-1)
  cam = normalize(cam)
  heatmap = np.uint8(cm.jet(cam[0])[..., :3] * 255)
  return heatmap

Faster ScoreCAM

def GetFasterScoreCAM(cls_index, img, model):
  def loss(output):
    return (output[0][cls_index])
  def model_modifier(m):
    m.layers[-1].activation = tf.keras.activations.linear
    return m
  scorecam = ScoreCAM(model,model_modifier=model_modifier,clone=False)
  cam = scorecam(loss, img, penultimate_layer=-1, max_N=10)
  cam = normalize(cam)
  heatmap = np.uint8(cm.jet(cam[0])[..., :3] * 255)
  return heatmap

取得テスト

from tensorflow.keras.preprocessing.image import load_img

IMAGE_PATH = 'Image.JPG'
CAT_CLASS_INDEX = 0

# Load image
img = load_img(IMAGE_PATH, target_size=(224, 224))
# Preparing input data
X = preprocess_input(np.array(img))

# Get SmoothGrad
heatmap = GetSmoothGrad(CAT_CLASS_INDEX, X, model)
plt.figure(figsize=(20,20))
plt.subplot(1, 3, 1)
plt.title('SmoothGrad')
plt.imshow(heatmap)

# Get GradCAM++
heatmap = GetGradCAMPlusPlus(CAT_CLASS_INDEX, X, model)
plt.subplot(1, 3, 2)
plt.title('GradCAMPlusPlus')
plt.imshow(heatmap)

# Get FasterScoreCAM
heatmap = GetFasterScoreCAM(CAT_CLASS_INDEX, X, model)
plt.subplot(1, 3, 3)
plt.title('FasterScoreCAM')
plt.imshow(heatmap)
plt.show()

image.png

注意点

CAT_CLASS_INDEXは分類のクラスidではなくindexであることにご注意ください
また各関数に投げる画像はサンプルにあるように
img = load_img(IMAGE_PATH, target_size=(224, 224))
X = preprocess_input(np.array(img))
学習時と同じ大きさと状態の画像を投げてください

5
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?