35
28

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

kerasとtensorflowでGrad-CAMを実装してみた

Last updated at Posted at 2020-08-16

はじめに

今回は自前のCNNモデルにGradCAMを実装してみました。
GoogleColaboratoryを使っていますが、ローカルでも、jupyter notebookでも普通に使えると思います。

CNNのモデルはkerasを用いて実装、学習済みのAlexNetになります。

画像は実際の研究で使用しているモノであるため公開は出来ないので、参考にする方は環境に合わせて読み替えてください。

Grad-CAMとは?

そもそもGrad-CAMってなに?って方は
深層学習は画像のどこを見ている!? CNNで「お好み焼き」と「ピザ」の違いを検証
の記事で結構詳しく見てみてください。

対象読者

  • Grad-CAMってのがあるのは知っているけど実際にどう実装するの?ってなっている人
  • tensorflowのバージョンに苦しんでいる人
  • 他の記事を見たけど↓のエラーが出て困っている人 (よーわからん...)
> RuntimeError: tf.gradients is not supported when eager execution is enabled. Use tf.GradientTape instead.

バージョン

実装に移る前にバージョンを確認しておきます
セルに以下のコードを入れてバージョンの確認をしておきます

import tensorflow as tf
import keras
print('tensorflow version: ', tf.__version__)
print('keras version: ', keras.__version__)

実行結果

tensorflow version:  2.3.0
keras version:  2.4.3

実装

まずは必要なことを済ませましょう
↓のコマンドでグーグルドライブをマウントします。

from google.colab import drive
drive.mount('/content/drive')

画像やモデルをを読み込むために現在のディレクトリのパスを定義しておきます。
My Drive以下はみなさんの環境に合わせて変更してください。

current_directory_path = '/content/drive/My Drive/Research/AlexNet/'

必要なモジュールのインポート

import numpy as np
import cv2

# 画像用
from keras.preprocessing.image import array_to_img, img_to_array, load_img
# モデル読み込み用
from keras.models import load_model
# Grad−CAM計算用
from tensorflow.keras import models
import tensorflow as tf

定数の定義
ここも皆さんの環境に合わせて変更してください。

IMAGE_SIZE  = (32, 32)

Grad-CAMを計算するメソッド

def grad_cam(input_model, x, layer_name):
    """
    Args: 
        input_model(object): モデルオブジェクト
        x(ndarray): 画像
        layer_name(string): 畳み込み層の名前
    Returns:
        output_image(ndarray): 元の画像に色付けした画像
    """

    # 画像の前処理
    # 読み込む画像が1枚なため、次元を増やしておかないとmode.predictが出来ない
    X = np.expand_dims(x, axis=0)
    preprocessed_input = X.astype('float32') / 255.0    

    grad_model = models.Model([input_model.inputs], [input_model.get_layer(layer_name).output, input_model.output])
    
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(preprocessed_input)
        class_idx = np.argmax(predictions[0])
        loss = predictions[:, class_idx]

    # 勾配を計算
    output = conv_outputs[0]
    grads = tape.gradient(loss, conv_outputs)[0]

    gate_f = tf.cast(output > 0, 'float32')
    gate_r = tf.cast(grads > 0, 'float32')

    guided_grads = gate_f * gate_r * grads

    # 重みを平均化して、レイヤーの出力に乗じる
    weights = np.mean(guided_grads, axis=(0, 1))
    cam = np.dot(output, weights)

    # 画像を元画像と同じ大きさにスケーリング
    cam = cv2.resize(cam, IMAGE_SIZE, cv2.INTER_LINEAR)
    # ReLUの代わり
    cam  = np.maximum(cam, 0)
    # ヒートマップを計算
    heatmap = cam / cam.max()

    # モノクロ画像に疑似的に色をつける
    jet_cam = cv2.applyColorMap(np.uint8(255.0*heatmap), cv2.COLORMAP_JET)
    # RGBに変換
    rgb_cam = cv2.cvtColor(jet_cam, cv2.COLOR_BGR2RGB)
    # もとの画像に合成
    output_image = (np.float32(rgb_cam) + x / 2)  
    
    return output_image

Grad-CAMを計算

まずは、モデルと画像を読み込みます。
それぞれのパスはみなさんの環境に合わせてください。

model_path = current_directory_path + '/model.hdf5'
image_path = current_directory_path + '/vis_images/1/2014_04_1_3.png'

model = load_model(model_path)
x = img_to_array(load_img(image_path, target_size=IMAGE_SIZE))

読み込んだ画像が合っているか確認します。

array_to_img(x)

Grad-CAMを計算します

target_layer = 'conv_filter5'
cam = grad_cam(model, x, target_layer)

計算した画像を確認します。

array_to_img(cam)

おわりに

今回はGoogleColaboratoryでGrad-CAMを実装してみました。
tensorflowのバージョンで苦しんだ人の参考になっていれば嬉しいです。

ここ間違ってるよ!ってのがあったらぜひ教えてほしいです!

参考

深層学習は画像のどこを見ている!? CNNで「お好み焼き」と「ピザ」の違いを検証
Grad CAM implementation with Tensorflow 2

35
28
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
35
28

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?