4
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【機械学習】grad-CAMをpythonで実装し、予測の根拠を視覚化する。

Last updated at Posted at 2023-09-27

更新情報
2023/12/19 文章を補足。

はじめに

今回過学習を起こしている問題のあるモデルで可視化を行っています。(モデルの問題の有無を確認できる。)
過学習を起こしていないモデルで試したバージョンは、こちらの記事で試しています。

CAMとgrad-CAMを軽くおさらいした後にpythonで実装していきます。実装メインなので理論的な詳細には触れません。
最終的に医用画像を正しく認識したのか(していないのか)を視覚化します。
全体のコード例(jupyter notebook)は以下のgitリポジトリから見れます。
https://github.com/senbe0/gradCam_notebook.git

1. CAM、 grad_CAMおさらい

例えば、入力画像をクラス分類するようなモデルがある。しかしこのモデルは外側からでは、何を根拠に、入力画像をクラス分類しているのか分からず、ブラックボックスである。

CAM

そこで、aiがどこを見て判断しているのかを可視化する手法が考案された。それが CAM (Class Activation Map)である。

しかし純粋なCAMは、特徴マップを直接ソフトマックス層の前に置く必要があるため、予測の直前に畳み込みマップに対するグローバル平均プーリングを行う、特定の種類のCNNアーキテクチャにしか適用できない。
(畳み込み特徴マップ→グローバル平均プーリング→ソフトマックス層というアーキテクチャにしか対応できない。代表的な例はGoogleNet。↓↓↓)
GoogleNet.png
*googlenet最終層辺り抜粋(https://arxiv.org/abs/1409.4842 参照日2023/9/27)。

grad-CAM

この問題に対処する為に、CAMを改良したgrad-CAMがRamprasaath R. Selvarajuらによって考案された。畳み込み層の特徴マップと、ターゲットとなるクラスに対する勾配(偏微分)を使用し、(ニューラルネットワークの)アーキテクチャに依存しない形で、任意の畳み込み層における視覚化を可能とする。他にもアーキテクチャに依存しない手法は考案されたが、grad-CAMはその中でも計算量が少なく、精度が高い。

2. 計算式

対象となるクラスの最終出力層の値を、(通常一番最後の[3])畳み込み層の特徴マップで微分し(勾配を求め)、グローバルアベレージpoolingする。

grad_CAM.png
こうして得られた特徴マップに対する重みと、畳み込み層の特徴マップを、線型結合する。
最後にRelu関数で負の値を0にクリップし、得られたのが、Grad-CAM(Gradient-weighted Class Activation Mapping)。
relu.png

3. pythonで実装する。

kaggleに公開されているnotebookから、それなりに精度が高いこちら([4]https://www.kaggle.com/code/jonaspalucibarbosa/chest-x-ray-pneumonia-cnn-transfer-learning/notebook )のアーキテクチャを拝借する。今回使用するデータセットはこちら([5]https://data.mendeley.com/datasets/rscbjbr9sj/3 )の約6000枚の小児患者の胸部X線画像を拝借する。

今回は、grad-CAMの実装がメインなので、機械学習モデルの実装部分は[4]のnotebookを参照していただきたい。なお、jupyter notebookを使用しているという前提で実装していく。

参照コード通りに実装すれば、以下のような構成のモデルが得られる。

> model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0     

 conv2d (Conv2D)             (None, 222, 222, 16)      448  

 batch_normalization (BatchN  (None, 222, 222, 16)     64        
 ormalization)

 activation (Activation)     (None, 222, 222, 16)      0         
                                                                 
 max_pooling2d (MaxPooling2D  (None, 111, 111, 16)     0         
 )                                                               
                                                                 
 dropout (Dropout)           (None, 111, 111, 16)      0         
                                                                 
 conv2d_1 (Conv2D)           (None, 109, 109, 32)      4640      
                                                                 
 batch_normalization_1 (Batc  (None, 109, 109, 32)     128       
 hNormalization)                                                 
                                                                 
 activation_1 (Activation)   (None, 109, 109, 32)      0         
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 54, 54, 32)       0         
 2D)                                                             
                                                                 
 dropout_1 (Dropout)         (None, 54, 54, 32)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 52, 52, 64)        18496     
                                                                 
 conv2d_3 (Conv2D)           (None, 50, 50, 64)        36928     
                                                                 
 batch_normalization_2 (Batc  (None, 50, 50, 64)       256       
 hNormalization)                                                 
                                                                 
 activation_2 (Activation)   (None, 50, 50, 64)        0         
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 25, 25, 64)       0         
 2D)                                                             
                                                                 
 dropout_2 (Dropout)         (None, 25, 25, 64)        0         
                                                                 
 flatten (Flatten)           (None, 40000)             0         
                                                                 
 dense (Dense)               (None, 64)                2560064   
                                                                 
 dropout_3 (Dropout)         (None, 64)                0         
                                                                 
 dense_1 (Dense)             (None, 1)                 65        
                                                                 
=================================================================
Total params: 2,621,089
Trainable params: 2,620,865
Non-trainable params: 224

実装

ここから追加で、grad-CAMを実装していく。

1-1 モデルの準備

2章で示した計算過程に従って実装する為、まずこのモデルの一番深い最後の畳み込み層[3](つまり"conv2d_3")の特徴マップを抽出していく。

# 作成した学習済みのモデルをロードする。
> model = load_model("Chest_XRay_model.h5")

model.get_layerは、指定した層のみではなく、指定した層以前のレイヤーのメタデータも含む。そのため、Modelクラスのoutputオプションに、model.get_layerで取り出したレイヤーのオブジェクトを指定してあげれば、計算に必要な(中間層の)情報は渡される。

# 今回、一番深い畳み込み層"conv2d_3"を取得する。
> target_layer = model.get_layer("conv2d_3")

# 特徴マップ抽出用の新たなモデルを定義する。outputsオプションで何を出力するかを設定している。
> intermediate_model = Model(inputs=[model.inputs], outputs=[target_layer.output, model.output])

1-2 インプット画像の準備

拝借したコード[4]の通りに、データを前処理すると、以下のように1枚の画像を指定できる。

> idx = 1
> print(ds_test[idx][0].shape)
(1, 224, 224, 3)

モデルに入力する画像を一枚、取り出してみる。

# インデックス番号100を指定している。
> idx = 100
> input_img = ds_test[idx][0]
> plt.imshow(input_img[0])

chest.png

1-3 grad-CAMの計算

grad-CAMの値を計算する為には、勾配を計算する必要がある。tensorflowでは、勾配を計算するために、自動微分の仕組み(GradientTape)を提供している。tf.GradientTape()のコンテキスト内で実行された操作は記録され、その後、tape.gradient()メソッドを使用してこれらの操作に関連する損失に対する各変数の勾配を取得できる。

conv_outputには畳み込み層の特徴マップが得られる。
predictionsが最終層の出力である。
この2つの変数を使い、勾配を計算できるという事である。

import tensorflow as tf
import numpy as np

# Grad-CAMの計算
input_img_tensor = tf.convert_to_tensor(input_img.reshape(1, 224, 224, 3))  # NumPy配列をテンソルに変換
with tf.GradientTape() as tape:
    tape.watch(input_img_tensor)  # ここでテンソルをwatch
    conv_output, predictions = intermediate_model(input_img_tensor)
    class_idx = np.argmax(predictions[0])
    loss = predictions[:, class_idx]

先ほど、勾配を計算するための特徴マップと、最終層の出力が得られたので実際に計算してみる。

ここでインデックス0を指定しているが、単に今回入力した1枚目の画像に対する勾配を、指定しているだけである。
なお、コード(1)と(2)は以下の式の計算に該当する。
grad_CAM.png

grads = tape.gradient(loss, conv_output)[0] # コード(1)

得られたマップをグローバルアベレージpoolingする。

weights = np.mean(grads, axis=(0, 1)) # コード(2)

最後に得られたマップ(weights)と、畳み込み層の特徴マップ(conv_output)を線型結合し、Reluで負の値をクリップしたらgrad-CAMが得られる。
relu.png

cam = np.dot(conv_output[0], weights) # 線形結合(Linear combination)
cam = np.maximum(cam, 0)  # ReLU

1-4 得られたgrad-CAMをプロットしてみる

実際に得られたマップを視覚化してみる。プロットするために正規化をする(範囲を-1から1に収める)。

画像からも分かるとおり、背景が活性化してしまっている。つまりこのモデルは、病気かどうかを判別するのに背景を根拠に判別してしまっている。

ちなみにこの学習済みモデルは数値上の精度は、9割を超えており一見優秀なモデルにみえる。しかし実際には使い物にならないモデルである事が分かった。

import matplotlib.pyplot as plt


cam = cam / cam.max()  # 正規化

plt.subplot(121)
plt.imshow(cam, cmap="jet")

plt.subplot(122)
plt.imshow(input_img[0], cmap='gray')

plt.show()

cam_normal.png

1-5 画像の合成(おまけ)

最後に画像を合成しておこう。得られたgrad-CAMと入力画像サイズが合わないので、grad-CAMをアップサンプリングする必要がある。

from scipy.ndimage import zoom

# 50x50の特徴マップを224x224にアップサンプリング
zoom_factor = 224 / 50  # 224: ターゲットのサイズ, 50: 元のサイズ
cam_resized = zoom(cam, zoom_factor)

# アップサンプリングした画像と入力画像の位置が合わない場合は、
# 以下のコードで微調整できる。必要に応じてアンコメント。

# grad_CAMの位置を調整する。
# shift_x, shift_y = 4, 0  # ずらすピクセル数
# cam_resized = np.roll(cam_resized, shift_x, axis=0)
# cam_resized = np.roll(cam_resized, shift_y, axis=1)
# プロット
plt.subplot(131)
plt.title("Original Image")
plt.imshow(input_img[0], cmap="gray")

plt.subplot(132)
plt.title("grad_CAM")
plt.imshow(cam_resized, cmap="jet")

plt.subplot(133)
plt.title("Combined")
plt.imshow(input_img[0], cmap="gray")
plt.imshow(cam_resized, cmap='jet', alpha=0.5) 

plt.show()

nor_cam_con.png
まだ見にくいので、活性化(反応)しなかったピクセルの色を無色透明にする。そのために既存のjetカラーマップを少しカスタマイズする。

import matplotlib.colors as mcolors

# カラーマップ
cmap = plt.get_cmap('jet')
cmaplist = [cmap(i) for i in range(cmap.N)]
# 0の値を無色透明にする
cmaplist[0] = (1, 1, 1, 0.0)  # (R, G, B, Alpha)
cmap_custom = mcolors.LinearSegmentedColormap.from_list('custom_cmap', cmaplist, cmap.N)
cmap_custom

左端(0の値)だけ無色透明にした。
colormap.png
作成したカラーマップを使用し、改めて元の入力画像とgrad-CAMの画像を合成する。

import matplotlib.colors as mcolors

# 元の画像
plt.imshow(input_img[0], cmap='gray')
# CAMを透過してオーバーレイ表示
plt.imshow(cam_resized, cmap=cmap_custom, alpha=0.7, vmin=0)

plt.show()

combined_big.png

まとめ

今回grad-CAMを使用して、医用画像を判別する機械学習モデルが何を根拠に判別しているかの視覚化を行った。
結果、作成したモデルは背景にフォーカスしてしまっていた事が判明した。バイアス等や学習率等を正しく学習するように、微調整する必要がある。
正しく学習させれば、医用画像のどの部分に疾患があるのかなどの判別に役立つかもしれない。正しく学習させたバージョンはまた後日。。。

参照文献、注釈等

[1]. Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, Dhruv Batra, https://arxiv.org/abs/1610.02391, Oct 2019.
[2]. Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich, Going Deeper with Convolutions, https://arxiv.org/abs/1409.4842, Nov 2021.
[3]. 一般的に深い層ほど高度な視覚的特徴を捉える能力がある。Matthew D. Zeiler, Rob Fergus, Visualizing and Understanding Convolutional Networks, https://arxiv.org/abs/1311.2901, Nov 2013.
[4]. Jonas Paluci Barbosa, https://www.kaggle.com/code/jonaspalucibarbosa/chest-x-ray-pneumonia-cnn-transfer-learning/notebook, Licence: Apache 2.0, 2021.
[5]. Daniel Kermany, Kang Zhang, Michael Goldbaum, https://data.mendeley.com/datasets/rscbjbr9sj/3, Licence: CC BY 4.0, June 2018.

4
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?