1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PyTorchのregister_forward_hookを少し触ってみた

Last updated at Posted at 2023-06-12

とあるディープラーニングの勾配情報の可視化コードを見ていたら、register_forward_hookなるものがあり、少し触ってみたので備忘録。

以下、morphousさんの記事でほとんど勉強させて頂きました。
大変ありがとうございます。

register_forward_hookとは

  • PyTorch 公式

  • PyTorchのregister_forward_hooknn.Moduleに対して、勾配情報へのアクセスや演算ができる

  • grad CAMはこの情報が必要らしい

  • register_forward_hookは名前の通りforwardのときに機能する

  • backwardのときに機能するものもあるにはあるのだが、ちょっと問題があるらしく、nn.Moduleregister_backward_hookは使わずに、register_forward_hookだけで勾配の情報を取ることが推奨されている

(morphousさん記事からかなり抜粋)

MNISTで試す

勾配情報の可視化の部分以外はほとんどmorphousさんの記事から抜粋させて頂いております。
合ってるかわからないコメントを大量に追加しています。

CNN中間層の取得

  • MNISTデータセットに対するシンプルなCNNでの分類問題で実装

MNISTデータ読み込み

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimizers
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib
import matplotlib.pyplot as plt
from scipy.ndimage import zoom

np.random.seed(1234)
torch.manual_seed(1234)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#root = os.path.join('data', 'mnist')
transform = transforms.Compose([transforms.Resize(28),
                                transforms.ToTensor(),
                                ])

mnist_train = \
    torchvision.datasets.MNIST(root='./',
                                      download=True,
                                      train=True,
                                      transform=transform)

mnist_test = \
    torchvision.datasets.MNIST(root='./',
                                      download=True,
                                      train=False,
                                      transform=transform)
    
train_dataloader = DataLoader(mnist_train,
                              batch_size=100,
                              shuffle=True)

test_dataloader = DataLoader(mnist_test,
                              batch_size=1,
                              shuffle=False)

モデル定義

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.nelements = 7*7*32 # 全結合層前の要素数
        self.device = device
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.l1 = nn.Linear(self.nelements, 10)

    def forward(self, x):
        h = self.conv1(x)
        h = torch.relu(h)
        h = F.max_pool2d(h, 2)
        h = self.conv2(h)
        h = torch.relu(h)
        h = F.max_pool2d(h, 2)
        h = h.view(-1, self.nelements)
        y = self.l1(h)

        return y
# モデルの設定
model = Net().to(device)
from torchsummary import summary
summary(model, input_size=[[1, 28, 28]]) # input_size=(チャンネル数, 高さ, 幅)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 28, 28]             320
            Conv2d-2           [-1, 32, 14, 14]           9,248
            Linear-3                   [-1, 10]          15,690
================================================================
Total params: 25,258
Trainable params: 25,258
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.24
Params size (MB): 0.10
Estimated Total Size (MB): 0.34
----------------------------------------------------------------

学習

# モデルの設定
model = Net().to(device)
# 損失関数の設定
criterion = nn.CrossEntropyLoss()
# 最適化関数の設定
optimizer = optimizers.Adam(model.parameters())
# エポック数
epochs = 10

for epoch in range(epochs):
    train_loss = 0.

    for (x, label) in train_dataloader:
        x = x.to(device)
        label = label.to(device)
        model.train()

        preds = model(x)
        loss = criterion(preds, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_dataloader)

    print(f"Epoch: {epoch+1}, Loss: {train_loss:.3f}")
Epoch: 1, Loss: 0.270
Epoch: 2, Loss: 0.076
Epoch: 3, Loss: 0.057
Epoch: 4, Loss: 0.046
Epoch: 5, Loss: 0.039
Epoch: 6, Loss: 0.033
Epoch: 7, Loss: 0.028
Epoch: 8, Loss: 0.025
Epoch: 9, Loss: 0.022
Epoch: 10, Loss: 0.019
CPU times: user 1min 52s, sys: 520 ms, total: 1min 53s
Wall time: 1min 54s

評価

from tqdm import tqdm

model.eval()
with torch.no_grad():
    total = 0.0
    correct = 0.0

    for (x, label) in tqdm(test_dataloader):
        x = x.to(device)
        label = label.to(device)
        outputs = model(x)
        _, predicted = torch.max(outputs, dim=1)
        correct += int((predicted==label).sum())
        total += label.shape[0]

    print("Validation Acc", correct/total)

精度はエポック10で98-99%の正解率

100%|██████████| 10000/10000 [00:08<00:00, 1171.97it/s]

Validation Acc 0.9898
CPU times: user 8.4 s, sys: 34.7 ms, total: 8.43 s
Wall time: 8.54 s
# モデル保存
torch.save(model.state_dict(), "model_MNIST_CNN.pth")

中間層へのアクセス

  • 畳み込み2層目へのアクセスをする
# モデルの指定されたレイヤーの出力と勾配を保存するクラス
class SaveOutput:
    def __init__(self, model, target_layer):  # 引数:モデル, 対象のレイヤー
        self.model = model
        self.layer_output = []
        self.layer_grad = []
        
        # 特徴マップを取るためのregister_forward_hookを設定
        self.feature_handle = target_layer.register_forward_hook(self.feature)
        # 勾配を取るためのregister_forward_hookを設定
        self.grad_handle = target_layer.register_forward_hook(self.gradient)

    # self.feature_handleの定義時に呼び出されるメソッド
    ## モデルの指定されたレイヤーの出力(特徴マップ)を保存する
    def feature(self, model, input, output):
         activation = output
         self.layer_output.append(activation.to("cpu").detach())

    # self.grad_handleの定義時に呼び出されるメソッド
    ## モデルの指定されたレイヤーの勾配を保存する
    ## 勾配が存在しない場合や勾配が必要ない場合は処理をスキップ
    def gradient(self, model, input, output):
        # 勾配が無いとき
        if not hasattr(output, "requires_grad") or not output.requires_grad:
            return # ここでメソッド終了

        # 勾配を取得
        def _hook(grad): 
            # gradが定義されていないが、勾配が計算されると各テンソルのgrad属性に保存されるっぽい(詳細未確認)
            self.layer_grad.append(grad.to("cpu").detach())

        # PyTorchのregister_hookメソッド(https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html)
        output.register_hook(_hook) 

    # メモリの解放を行うメソッド、フックを解除してメモリを解放する
    def release(self):
        self.feature_handle.remove()
        self.grad_handle.remove()
# モデルとアクセスしたい層の名称を渡す
save=SaveOutput(model, model.conv2)
# モデルにデータを流し込んで、特徴マップを取得する
img, label=next(iter(test_dataloader))
print("img:", img.shape)
print("label:", label.shape)

# 特徴マップ
output = model(img.to(device))
print("output:", output.shape)
img: torch.Size([1, 1, 28, 28])
label: torch.Size([1])
output: torch.Size([1, 10])

特徴マップを可視化

# 特徴マップをチャネル毎に可視化
intermediate = save.layer_output[0].squeeze(0).numpy()

fig=plt.figure(figsize=(40,20))
for i,im in enumerate(intermediate):
    ax1=fig.add_subplot(4,8,i+1)
    ax1.imshow(im,'gray')

output_36_0.png

勾配情報を可視化(1枚)

# 予測クラスのインデックスを取得
## 分類結果をargmaxで求めて、対応する全結合層の出力要素に対して、backwardをかける
## (backwardはスカラーにしか使えない。分類問題なので、最大の値を返したものを該当クラスとみなして、その勾配を取りに行く)
idx=torch.argmax(output)

# 勾配を計算
## backward()メソッドの呼び出しにより勾配の計算が開始
## SaveOutputクラスのgradientメソッドが呼び出され、output.register_hook(_hook)により、_hook関数は、gradを受け取ってself.layer_gradに追加
## gradはbackward()を呼び出した時点で自動的に計算されるっぽい
### (backward()メソッドを呼び出すと、計算グラフを遡りながら各テンソルのgrad_fnに対応する勾配計算メソッドが実行され、各テンソルのgrad属性に保存される)
output[0,idx].backward()

# 2回やるとエラーになる?(詳細未確認)

これも可視化

# 勾配情報をチャネル毎に可視化
grad = save.layer_grad[0].squeeze(0).numpy()

fig=plt.figure(figsize=(40,20))
for i,im in enumerate(grad):
    ax1=fig.add_subplot(4,8,i+1)
    ax1.imshow(im,'gray')

output_40_0.png

  • 勾配情報をチャネル方向に加算して、元画像と重ねて表示してみる
# 勾配情報をチャネル方向に加算
grad_re = np.sum(grad, axis=0)
grad_re.shape
(14, 14)
# 拡大後の画像サイズ
new_size = (28, 28)

# バイリニア補完による画像の拡大(元画像と重ねる為、同じサイズにする)
grad_re = zoom(grad_re, new_size / np.array(grad_re.shape), order=1)
grad_re.shape
(28, 28)
# 勾配情報のみ表示
fig=plt.figure(figsize=(4,4))
ax1=fig.add_subplot(1,1,1)
ax1.imshow(grad_re)
<matplotlib.image.AxesImage at 0x7f30d60d7a00>

元画像は「7」なので、それっぽくはなっているかも(黄色が勾配情報 大、青が小)
output_44_1.png

# 元画像
fig=plt.figure(figsize=(4,4))
ax1=fig.add_subplot(1,1,1)
ax1.imshow(img.squeeze().numpy(), "gray")
<matplotlib.image.AxesImage at 0x7f30d62b82b0>

output_45_1.png

# 勾配情報と元画像を重ねて表示
cam = grad_re + img.squeeze().numpy()
cam = cam / np.max(cam)

fig=plt.figure(figsize=(4,4))
ax1=fig.add_subplot(1,1,1)
ax1.imshow(cam)
<matplotlib.image.AxesImage at 0x7f304f151060>

こんな感じの使い方でいいのかな
output_46_1.png

勾配情報を可視化(複数枚)

# 32枚可視化する
num_imgs = 32

tmp_dataloader = DataLoader(mnist_test,
                              batch_size=num_imgs,
                              shuffle=True)
# batch_size分データ取り出し
imgs, labels = next(iter(tmp_dataloader))

fig=plt.figure(figsize=(40,20))
for i, (img, label) in enumerate(zip(imgs, labels)):
    # 特徴量
    output = model(img.to(device))
    
    # 勾配を計算
    idx = torch.argmax(output)
    output[0,idx].backward()
    grad = save.layer_grad[0].squeeze(0).numpy()
    
    # 勾配をチャンネル方向に加算
    grad_re = np.sum(grad, axis=0)

    # 拡大後の画像サイズ
    new_size = (28, 28)
    # バイリニア補完による画像の拡大
    grad_re = zoom(grad_re, new_size / np.array(grad_re.shape), order=1)

    # 元画像と重ねる
    cam = grad_re + img.squeeze().numpy()
    cam = cam / np.max(cam)

    # 表示
    ax1=fig.add_subplot(4,8,i+1)
    ax1.imshow(cam)

使い方合ってるかわからないけど、数字の部分の勾配情報が大きくなっているので良いのかな
output_49_0.png

以上

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?