68
70

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchを使ってCNNの判断根拠を可視化するGrad-CAMを実装してみた

Last updated at Posted at 2021-01-31

はじめに

AIの説明性とか判断根拠に関するトピックに興味があって、画像系のデータであれば、Grad-CAMとか自然言語ならAttentionとか、いろんな手法が研究されてると思います。
そこで今回はAIの判断根拠に関するトピックで幅広く使われている(と思われる)、CNNが画像分類する際、画像のどこを見てそう判断したのかを可視化する手法であるGrad-CAMを実際にPyTorchを使って実装して試してみようと思います。

今回の記事はあくまで、まずはGrad-CAMをPyTorchで実装してみる、ということが目的であり、細かな理論面などには触れておりませんが、実装面についてはできるだけ細かくコメントを残しながら説明しているつもりです。少しでも同様のことをしようとして詰まっている方への助けとなれば幸いです。
また当方画像系に不慣れなこともあり、もしかしたらおかしな記載をしている可能性がありますので、お気づきの際はご指摘していただけると嬉しいです。

Grad-CAMに関する参考文献

いずれも理論や実装面でとても参考になるものばかりです。

  1. Grad-CAMの論文
  2. 【CNN+Grad-CAM】仕組みの解説と画像の予測根拠可視化
  3. 深層学習は画像のどこを見ている!? CNNで「お好み焼き」と「ピザ」の違いを検証
  4. 【PyTorch】GradCAMを用いたCNN(VGG16)の可視化
  5. PyTorchでGrad-CAMによるCNNの可視化.
  6. CNNを使った分類問題の判断根拠(画像編)

実装

今回は学習済みモデルのVGG19を使い、Google clabを使って実装していきます。(特にGPUは使わないので、ローカルでも全然問題ないと思われますが。)

準備

colabにGoogle Driveをマウントしていろいろインポートします。判断根拠を可視化したい画像は事前にGoogle Driveに格納しておきます。

# Google Driveをcolabにマウント
from google.colab import drive
drive.mount('/content/drive')

# 各種ライブラリインポート
%matplotlib inline
import urllib
import pickle
import cv2
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision import models

import matplotlib.pyplot as plt
import seaborn as sns
# colabをダークモードにしていると、グラフ表示したときに目盛が見えなくなってしまうことに対する対処
sns.set_style("white") 

# 判断根拠を可視化したい画像をここに置いとく
drive_dir = "drive/My Drive/Colab Notebooks/grad-cam/"

# ImageNetのラベル情報をダウンロード
labels = pickle.load(urllib.request.urlopen('https://gist.githubusercontent.com/yrevar/6135f1bd8dcf2e0cc683/raw/d133d61a09d7e5a3b36b8c111a8dd5c4b5d560ee/imagenet1000_clsid_to_human.pkl') )
# こんな感じで辞書形式で1000個のラベル情報を取得できます。VGG19は画像からこれらの1000個を分類するモデルです。
#{0: 'tench, Tinca tinca',
# 1: 'goldfish, Carassius auratus',
# 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
# ・・・
# 997: 'bolete',
# 998: 'ear, spike, capitulum',
# 999: 'toilet tissue, toilet paper, bathroom tissue'}

# VGG19の学習済みモデルを読み込む
model = models.vgg19(pretrained=True)

実際に今回使う画像データをcolab上で可視化してみましょう。私は以下のように2匹の猫の画像を用意しました。

# 検証する画像を読み込んで表示する
test_image1 = Image.open(drive_dir + "neko.png")
test_image2 = Image.open(drive_dir + "neko2.JPG")

plt.imshow(test_image1)
plt.show()

plt.imshow(test_image2)
plt.show()

image.pngimage.png

これらの画像をVGG19に流せるようにテンソル型に変換しましょう。VGG19は(224x224)の画像サイズである必要があります。

追記
本記事の結果はコメントでご指摘のある通り、当方のミスで以下の画像の前処理の段階でRGBの正規化を行っていません。
その結果として上図の明らかに猫な画像が一部猫として認識されていませんが、正規化を正しく行うことで両画像は猫と判断されます。
実際にお手元で試される方は正規化の処理をお忘れなく。

# 検証画像をVGG19のネットワークに通せるように変換する処理

# 画像の縦横を224x224にしてtensor型に変換
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.485, 0.456, 0.406], ←この処理を忘れないように。
    #                     std=[0.229, 0.224, 0.225])
])
test_image1_tensor = transform(test_image1)
test_image2_tensor = transform(test_image2)
print(test_image1_tensor.size()) # 色情報(チャネル)x高さx幅
# torch.Size([3, 224, 224])

# ネットワークに通すには(batch_size, channel, height, width)じゃないといけないので、
# unsqueezeでbatch_size=1を挿入してます。
test_image1_tensor = test_image1_tensor.unsqueeze(0)
test_image2_tensor = test_image2_tensor.unsqueeze(0)
print(test_image2_tensor.size())
# torch.Size([1, 3, 224, 224])

実際にこれらの画像をVGG19で予測させてみましょう。1枚目の私のQiitaのアイコンにもしてる猫はなぜかアンゴララビットと予測されてしまいました...
2枚目の猫も正確には茶虎猫(VGG19の正解でいうと281 tabby, tabby catだと思いますが...)なんですが、まぁいいでしょう。


# 検証画像をVGG19で予測させてみる
# モデルを検証モードに変更してから画像をネットワークに流す
# eval()は内部のDropoutやBatchNormalizationのオン/オフを制御するためのもの
model = model.eval()
with torch.no_grad():
    predict1_index = model(test_image1_tensor).max(1)[1]
    predict2_index = model(test_image2_tensor).max(1)[1]
    print(int(predict2_index))
    print('検証画像1の予測結果', labels[int(predict1_index)])
    print('検証画像2の予測結果', labels[int(predict2_index)])
# 検証画像1の予測結果 Angora, Angora rabbit
# 検証画像2の予測結果 Siamese cat, Siamese

それでは次からは、2つの画像がなぜこのように判断されたのかをGrad-CAMで可視化していきます!

その前に少しだけ理論的な話

Grad-CAMの理論的な話(数式)を全く知らないと、以降の実装が全くもって意味不明になってしまうので、簡単に根幹となる数式について触れておきます。以下の2つの式は本家の論文から引用しました。

image.png

image.png

基本的にGrad-CAMを実装する際は上式の $L^c_{Grad-CAM}$ を計算し、画像にマッピングすることになります。超簡単に上式について説明すると、$y^c$ は予測結果で、$A^k_{ij}$ は特徴マップ(畳み込み層の出力結果)を表しています。(1)式で予測結果に対する特徴マップの勾配を計算して、その特徴マップの勾配に対して、Global Average Poolingを計算し各特徴マップの重要度 $\alpha^c_k$ を算出します。Global Average Poolingって何?って方は以下の記事がとても簡潔にわかりやすくまとめられているので一読されることをお勧めします。

(2)式で(1)で求めた特徴マップの重要度 $\alpha^c_k$ を使って特徴マップ $A^k$ の要素を重み付けします。そして、$ReLU$ を施して完成って流れです。

ちなみになんで最後に$ReLU$ ($ReLU(x) = \max \{ x, 0 \} $) をとるの?って感じですが、論文内では以下のように言及されています。

We apply a ReLU to the linear combination of maps because we are only interested in the features that have a positive influence on the class of interest, $i.e.$ pixels whose intensity should be increased in order to increase $y^c$.
Negative pixels are likely to belong to other categories in the image. As expected, without this ReLU, localization maps sometimes highlight more than just the desired class and perform worse at localization.

つまり興味あるのは予測結果の正の方向に強く影響を与えている画素であり、負の方向に影響を与える画素は他の予測結果に影響を与えている可能性が高い、と。実際に$ReLU$がないと、可視化したい画素以上に強調されてしまうケースがあるようですね。

Grad-CAMの仕組み的は話は論文の他にも参考文献の2.、3.がとてもわかりやすかったです!

では、今度こそGrad-CAMの実装に入ります。実装方法は私かちょこちょこ調べた限り2パターンの方法があるようです。

実装方法その1

あらかじめVGG19のネットワークからGrad-CAMの計算に必要なレイヤー(つまり特徴マップを計算するレイヤー)を切り出して計算する方法です。
参考文献4.であげている以下の記事がそのように実装されており、なるほどーって感じで大変勉強になります。

実装方法その1で紹介する内容はこちらの記事のほぼ写経です。

まずは必要なレイヤーをVGG19モデルから切り出す。VGG19の構造を知らない場合、まずはVGG19の中身を表示して確認しましょう

# VGG19のネットワークを確認する。出力は省略
print(model)

# Grad-CAMを計算するときは特徴マップが必要なので、modelから特徴マップを計算するネットワークのところを取り出す
# 予測結果の出力も必要なので、残りのレイヤーも切り出しておく
features = model.features.eval()
avgpool = model.avgpool.eval()
classifier = model.classifier.eval()

次に、バラしたネットワークを使って予測結果を取得し、その予測結果に対して誤差逆伝播で勾配を計算します。
特にclone()とかdetach()とかが混乱するところかなと思いますが、こちらについては以下のteratailの回答に詳しく違いが記載されています。(とはいえまだちゃんと理解しきれていないですが...)

# 特徴マップを抽出する
feature = features(test_image1_tensor)
print('特徴マップサイズ', feature.size())
# 特徴マップサイズ torch.Size([1, 512, 7, 7])

# 上のteratailの回答を参照
feature = feature.clone().detach().requires_grad_(True)

# 特徴マップをVGG19の残りの全結合層に通して、予測結果を得る
pooled = avgpool(feature)
y_pred = classifier(pooled.view(-1,512*7*7))
pred_index = torch.argmax(y_pred)

# 予測結果に対して誤差逆伝播
y_pred[0][pred_index].backward()

続いて、論文の数式通りの計算を実施して、 $L^c_{Grad-CAM}$ を求めます。

# 特徴マップの勾配(feature.grad)のGlobal Average Poolingを計算する
# つまり512枚の各特徴マップの要素の平均値を算出する(512次元のベクトルになる)

# まずは7x7のそれぞれの特徴マップを1本のベクトルに変換
feature_vec = feature.grad.view(512, 7*7) # feature_vec.size() = (512, 49)

# 512本のそれぞれのベクトルの要素の平均を取る
# 論文のαが計算される
alpha = torch.mean(feature_vec, axis=1) # alpha.size() = (512)

# batch_sizeの次元を削除
# (1x512x7x7) -> (512x7x7)
feature = feature.squeeze(0)

# 論文のLを計算
L = F.relu(torch.sum(feature*alpha.view(-1,1,1),0))
L = L.detach().numpy()
# L.size() = (7x7)

あとは検証画像に上で求めた L をマッピングして完成。マッピングの仕方も参考記事をまるパクリしています。

# 0-1で正規化
L_min = np.min(L)
L_max = np.max(L - L_min)
L = (L - L_min)/L_max

# 元画像と同じサイズにリサイズする
L = cv2.resize(L, (224, 224))

# heat map に変換
def toHeatmap(x):
    x = (x*255).reshape(-1)
    cm = plt.get_cmap('jet')
    x = np.array([cm(int(np.round(xi)))[:3] for xi in x])
    return x.reshape(224,224,3)

img2 = toHeatmap(L)
img1 = test_image1_tensor.squeeze(0).permute(1,2,0)

alpha = 0.5
grad_cam_image = img1*alpha + img2*(1-alpha)

plt.imshow(grad_cam_image)

image.pngimage.png

1枚目の画像はどうやら口元を見てアンゴラウサギと判断したようです。2枚目の画像は、これは耳と髭?もしくは模様?を見てシャム猫と判断したようです。

何はともあれ、いい感じに可視化できているように感じます。1枚目は口元をガッツリ見てしまっている、逆にいうと、2枚目にもあるように耳が猫の特徴として判断されたかもしれないのに、背景が黒っぽくて耳を捉えきれていなかったのかもしれません。

実装方法その2

VGG19の順伝播時と逆伝播時にGrad-CAMの計算に必要なVGG19の中間層の出力を取得できるようにregister_forward_hookregister_backward_hookを活用する方法です。

すでにいろんな方がPyTorchを使ってGrad-CAMの実装をされていますが、大半の実装方法はこちらの方法で、私も初見でGrad-CAMをなんとか実装しようと思うと、こちらの方法で実装してたと思います。実装方法その1のように最初からレイヤーを分けるなんて賢いやり方は思い浮かばなかったかもしれません。

register_forward_hookregister_backward_hookを使った実装方法は参考文献6. であげている以下の記事が大変わかりやすくて、取り扱っている分析内容自体も興味深いです。

register_forward_hookregister_backward_hookがいまだちゃんと理解しきれているか怪しいですが、これはモデルのforward関数、backward関数が実行された時に一緒に実行したい内容をあらかじめネットワークに登録できる機能でして、以下の記事に使い方の例が記載されています。併せてリファンスもご参照ください。

VGG19の特徴マップを計算するレイヤー(model.features)に対して、register_forward_hookを使って、順伝播時の特徴マップを取得できるようにして、register_backward_hookを使って、誤差逆伝播時に特徴マップの勾配を取得できるようにします。
この辺は特に理解に苦しんだ箇所だったので、多めにコメントを残しておきます。
畳み込み層に2つのhook関数を登録しておき、検証画像をVGG19に流して、予測結果に対して誤差逆伝播を実行するところまでは以下の通りです。

# moduleのforward関数が呼ばれるときに一緒に実行したい内容を記載する
# module: 登録先のネットワーク
# inputs: moduleのforward関数のインプットとなったデータ
# outputs: moduleのforward関数のアウトプット
def forward_hook(module, inputs, outputs):
    global feature
    # 複数のinputsに対応できるようにinputsはtupleでラップされるようです
    # 今回はinputsは1つだけなので、[0]を指定する
    # hook関数の登録先であるmoduleは今回は特徴マップを計算する層(model.features)を想定しているので
    # model.featuresの計算結果であるoutputsをそのまま取得できるようにすればよし
    feature = outputs[0] # feature.size() = (1x512x7x7)

# moduleのbackward関数が呼ばれるときに一緒に実行したい内容を記載する
# module: 登録先のネットワーク
# grad_inputs: moduleのbackward関数のインプットとなったデータ
# grad_outputs: moduleのfbackward関数のアウトプット
def backward_hook(module, grad_inputs, grad_outputs):
    global feature_grad
    # hook関数の登録先であるmoduleはforward_hookと同様にmodel.featuresを想定しており、
    # 特徴マップの勾配がほしいので、model.featuresの勾配結果に相当するgrad_outputsをそのまま取得すればよし
    feature_grad = grad_outputs[0] # feature_grad.size() = (1x512x7x7)

# 畳み込み層にhook関数を登録する
# これによりVGG19にデータが流れる時、順伝搬の畳み込みの処理のところで上記のforward_hookが呼び出され(つまりグローバル変数featureに特徴マップが格納され)
# 逆伝搬時にグローバル変数feature_gradに特徴マップの勾配が格納されることになる
model.features.register_forward_hook(forward_hook)
model.features.register_backward_hook(backward_hook)

# VGG19のモデルにそのまま検証する画像を流して予測結果を得る
# このタイミングで上記のforward_hookが裏で実行されてます
y_pred = model(test_image1_tensor)
pred_index = torch.argmax(y_pred)

# 予測結果に対して誤差逆伝播
# このタイミングで上記のbackward_hookが裏で実行されてます
y_pred[0][pred_index].backward()

以降の計算は実装方法その1と同様なので、説明は省略(ソースコードは閉じておきます。)

$L^c_{Grad-CAM}$を計算して元画像に重ね合わせるところまでの実装

# 以降の計算は実装方法その1と同様でOK

# 特徴マップの勾配(feature.grad)のGlobal Average Poolingを計算する
# つまり512枚の各特徴マップの要素の平均値を算出する(512次元のベクトルになる)

# まずは7x7のそれぞれの特徴マップを1本のベクトルに変換
feature_vec = feature_grad.view(512, 7*7) # feature_vec.size() = (512, 49)

# 512本のそれぞれのベクトルの要素の平均を取る
# 論文のαが計算される
alpha = torch.mean(feature_vec, axis=1) # alpha.size() = (512)

# batch_sizeの次元を削除
# (1x512x7x7) -> (512x7x7)
feature = feature.squeeze(0)

L = F.relu(torch.sum(feature*alpha.view(-1,1,1),0))#.cpu().detach().numpy()
L = L.detach().numpy()
# L.size() = (7x7)

# 0-1で正規化
L_min = np.min(L)
L_max = np.max(L - L_min)
L = (L - L_min)/L_max

# 元画像と同じサイズにリサイズする
L = cv2.resize(L, (224, 224))

img2 = toHeatmap(L)

img1 = test_image1_tensor.squeeze(0).permute(1,2,0)

alpha = 0.5
grad_cam_image = img1*alpha + img2*(1-alpha)

plt.imshow(grad_cam_image)

Grad-CAMの可視化結果は実装方法その1と同様ですので、省略します。

Grad-CAMならでは(?)のちょっとした分析の紹介

参考文献6. でも最後のほうに触れられていますが、Grad-CAMの使い方をちょっと工夫すると、予測結果でないクラスについてもどこを見ていたかを可視化することができます。
Grad-CAMの予測結果に対する勾配を計算するところで、予測結果ではない、他のindexに対して微分すればいいだけですね。
これを使うと、本来得たかった結果になぜならなかったのかを深堀することができます。

検証画像1は猫なのにアンゴララビットと誤判定していたので、検証画像1を猫を判定するならどこを見ていたかをGrad-CAMで可視化してみましょう。
ちなみに、検証画像1はおそらく雑種だと思うのですが、VGG19の猫のカテゴリに雑種がないっぽい?ので、まだ近いと思われる281 tabby, tabby cat(虎猫)を本来の正解とみなしてみます。

ソースコードで変更する箇所は以下のようにbackwardを実行する手前で無理矢理虎猫のインデックスである281を指定するだけでOK。

model.features.register_forward_hook(forward_hook)
model.features.register_backward_hook(backward_hook)

# VGG19のモデルにそのまま検証する画像を流して予測結果を得る
# このタイミングで上記のforward_hookが裏で実行されてます
y_pred = model(test_image1_tensor)
# pred_index = torch.argmax(y_pred)
pred_index = torch.tensor(281)
# print(y_pred)

# 予測結果に対して誤差逆伝播
# このタイミングで上記のbackward_hookが裏で実行されてます
y_pred[0][pred_index].backward()

可視化結果はこちら
image.png

アンゴララビットよりも口の右側をよりみているようですが、大体みてるところ一緒ですかね。。。
検証画像2で判定されたシャム猫と判断した箇所もついでに可視化すると、

image.png

目をみてる感じですかね。アンゴララビットも他の猫の判定確率もそんなに大差ないかなと思い、分類確率を算出してみると、以下の通りなんで、数あるカテゴリの中でもまぁまぁ自信持ってアンゴララビットって言ってるのかな?

y_pred = model(test_image1_tensor)
print("アンゴララビットである確率", F.softmax(y_pred)[0][332].item())
print("虎猫である確率", F.softmax(y_pred)[0][281].item())
print("シャム猫である確率", F.softmax(y_pred)[0][284].item())
# アンゴララビットである確率 0.5156328082084656
# 虎猫である確率 0.0016278631519526243
# シャム猫である確率 0.007759877946227789

おわりに

Grad-CAMのほかにも発展系であるGrad-CAM++とかScore-CAMとかあるみたいですが、まずは判断根拠の入門としてGrad-CAMを実装してみました。
世の中的にもAIの説明性や解釈などはホットトピックですし、技術的にはどこまでのことができるか、についてはしっかりと追っていきたいと思います。

おわり

68
70
6

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
68
70

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?