50
43

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

pytorch-gradcamで簡単にGrad-CAMを行う

Posted at

pytorch-gradcamで簡単にGrad-CAMを実行できる

Grad-CAMと呼ばれるCNNの可視化技術があり、画像分類の際にどの特徴量を根拠にして分類しているのかを可視化することができます。これによって分類規則の根拠を考察したり、場合によってはそこから得られた知見などを元にしてマーケティングなどに役立てたりします。

下は、VGG16を使ってある画像に対して注目している特徴量を可視化した結果になります。

スクリーンショット 2020-04-28 17.31.13.png

こちらの実装手順ですが、以下のような形で実現しています。

  • 畳み込みの出力層の手前でGlobal Average Pooling
  • あるクラスにおける最終層の各チャンネルの重みを決定
  • 重みに応じてそれぞれのチャンネルを掛けて足し合わせる
  • それらをRelu関数に通す

(自前で実装するときに参考にした資料はこちらです: PyTorchでGrad-CAMによるCNNの可視化.)

当初はpytorchを使って自力でコーディングしていたのですが、並列GPUで学習済みだとtorch.nn.DataParallelがmodelをラッピングしてしまってモデルの階層構造が変化していたり、ファインチューニングで選んだネットワークによってモジュールの定義が異なっていたり...僕のコーディング能力もあると思いますが、異なるネットワークを使うたびに面倒臭い汗

ストレスを感じることがなく、気軽にpytorchでGrad-CAMを実行できるライブラリとか誰か作っていないかなと思って先週に漁っていたら...ありました。

pytorch-gradcam

GradCAMとGradCAM++の結果を可視化することができ、かつalexnet, vgg, resnet, densenet, squeezenetに対応しています。非常に有難い!

しかもインストール方法は簡単で、
pip install pytorch-gradcam
を行うだけです!

ソースコードですが、以下のように実行して可視化できます(ソースコードでは、既にdensenet161を用いてあるデータセットを学習済のモデルをロードしていて、5クラス分類の学習済モデルになります)。

# Basic Modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

# PyTorch Modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
import torchvision.transforms as transforms
from torch.utils.data.dataset import Subset
import torchvision.models as models
import torch.optim as optim
from torchvision.utils import make_grid, save_image

# Grad-CAM
from gradcam.utils import visualize_cam
from gradcam import GradCAM, GradCAMpp


device = torch.device("cuda:0" if torch.cuda.is_available()  else "cpu")
model = models.densenet161(pretrained=True)
model.fc = nn.Linear(2048,5)
model = torch.nn.DataParallel(model).to(device)
model.eval()
model.load_state_dict(torch.load('trained_model.pt'))

# Grad-CAM
target_layer = model.module.features
gradcam = GradCAM(model, target_layer)
gradcam_pp = GradCAMpp(model, target_layer)

images = []
# あるラベルの検証用データセットを呼び出してる想定
for path in glob.glob("{}/label1/*".format(config['dataset'])):
    img = Image.open(path)
    torch_img = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])(img).to(device)
    normed_torch_img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(torch_img)[None]
    mask, _ = gradcam(normed_torch_img)
    heatmap, result = visualize_cam(mask, torch_img)

    mask_pp, _ = gradcam_pp(normed_torch_img)
    heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)
    
    images.extend([torch_img.cpu(), heatmap, heatmap_pp, result, result_pp])
grid_image = make_grid(images, nrow=5)

# 結果の表示
transforms.ToPILImage()(grid_image)

気をつけるところはtarget_layer = model.module.featuresの部分で対象のレイヤーを指定する必要があるのですが、Githubのutils.pyを参考にして、各ネットワークモデルに対応したtarget_layerの名前を調べられます--->utils.py。下記はutils.pyに書いてある一部をそのまま抜粋しましたが、各対応ネットワーク毎にtarget_layerの名前が記載されています。

| @register_layer_finder('densenet') |
|:--|
| def find_densenet_layer(arch, target_layer_name): |
|     """Find densenet layer to calculate GradCAM and GradCAM++ |
|     Args: |
|         arch: default torchvision densenet models |
|         target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below. |
|             target_layer_name = 'features' |
|             target_layer_name = 'features_transition1' |
|             target_layer_name = 'features_transition1_norm' |
|             target_layer_name = 'features_denseblock2_denselayer12' |
|             target_layer_name = 'features_denseblock2_denselayer12_norm1' |
|             target_layer_name = 'features_denseblock2_denselayer12_norm1' |
|             target_layer_name = 'classifier' |
|     Return: |
|         target_layer: found layer. this layer will be hooked to get forward/backward pass information. |
|     """ |

また、target_layer = model.module.featuresでmoduleをfeatureの前に挟んでいるのはDataParallelを使って並列GPUでの学習済みモデルを使用しているためです。詳しく知りたい方はこちらに並列GPUを行う上での躓きポイントがまとめてありますので参考にしてください【PyTorch】DataParallelを使った並列GPU化の躓きどころ。別に並列GPUの形で学習モデルを作成していなければmoduleはいらないです。

まとめ

今回は、pytorchで簡単にGrad-CAMを行うためのモジュールを紹介しました。

50
43
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
50
43

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?