とあるディープラーニングの勾配情報の可視化コードを見ていたら、register_forward_hookなるものがあり、少し触ってみたので備忘録。
以下、morphousさんの記事でほとんど勉強させて頂きました。
大変ありがとうございます。
register_forward_hookとは
-
PyTorchの
register_forward_hook
でnn.Module
に対して、勾配情報へのアクセスや演算ができる -
grad CAMはこの情報が必要らしい
-
register_forward_hook
は名前の通りforwardのときに機能する -
backwardのときに機能するものもあるにはあるのだが、ちょっと問題があるらしく、
nn.Module
のregister_backward_hook
は使わずに、register_forward_hook
だけで勾配の情報を取ることが推奨されている
(morphousさん記事からかなり抜粋)
MNISTで試す
勾配情報の可視化の部分以外はほとんどmorphousさんの記事から抜粋させて頂いております。
合ってるかわからないコメントを大量に追加しています。
CNN中間層の取得
- MNISTデータセットに対するシンプルなCNNでの分類問題で実装
MNISTデータ読み込み
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimizers
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
np.random.seed(1234)
torch.manual_seed(1234)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#root = os.path.join('data', 'mnist')
transform = transforms.Compose([transforms.Resize(28),
transforms.ToTensor(),
])
mnist_train = \
torchvision.datasets.MNIST(root='./',
download=True,
train=True,
transform=transform)
mnist_test = \
torchvision.datasets.MNIST(root='./',
download=True,
train=False,
transform=transform)
train_dataloader = DataLoader(mnist_train,
batch_size=100,
shuffle=True)
test_dataloader = DataLoader(mnist_test,
batch_size=1,
shuffle=False)
モデル定義
class Net(nn.Module):
def __init__(self):
super().__init__()
self.nelements = 7*7*32 # 全結合層前の要素数
self.device = device
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
self.l1 = nn.Linear(self.nelements, 10)
def forward(self, x):
h = self.conv1(x)
h = torch.relu(h)
h = F.max_pool2d(h, 2)
h = self.conv2(h)
h = torch.relu(h)
h = F.max_pool2d(h, 2)
h = h.view(-1, self.nelements)
y = self.l1(h)
return y
# モデルの設定
model = Net().to(device)
from torchsummary import summary
summary(model, input_size=[[1, 28, 28]]) # input_size=(チャンネル数, 高さ, 幅)
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 32, 28, 28] 320
Conv2d-2 [-1, 32, 14, 14] 9,248
Linear-3 [-1, 10] 15,690
================================================================
Total params: 25,258
Trainable params: 25,258
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.24
Params size (MB): 0.10
Estimated Total Size (MB): 0.34
----------------------------------------------------------------
学習
# モデルの設定
model = Net().to(device)
# 損失関数の設定
criterion = nn.CrossEntropyLoss()
# 最適化関数の設定
optimizer = optimizers.Adam(model.parameters())
# エポック数
epochs = 10
for epoch in range(epochs):
train_loss = 0.
for (x, label) in train_dataloader:
x = x.to(device)
label = label.to(device)
model.train()
preds = model(x)
loss = criterion(preds, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()
train_loss /= len(train_dataloader)
print(f"Epoch: {epoch+1}, Loss: {train_loss:.3f}")
Epoch: 1, Loss: 0.270
Epoch: 2, Loss: 0.076
Epoch: 3, Loss: 0.057
Epoch: 4, Loss: 0.046
Epoch: 5, Loss: 0.039
Epoch: 6, Loss: 0.033
Epoch: 7, Loss: 0.028
Epoch: 8, Loss: 0.025
Epoch: 9, Loss: 0.022
Epoch: 10, Loss: 0.019
CPU times: user 1min 52s, sys: 520 ms, total: 1min 53s
Wall time: 1min 54s
評価
from tqdm import tqdm
model.eval()
with torch.no_grad():
total = 0.0
correct = 0.0
for (x, label) in tqdm(test_dataloader):
x = x.to(device)
label = label.to(device)
outputs = model(x)
_, predicted = torch.max(outputs, dim=1)
correct += int((predicted==label).sum())
total += label.shape[0]
print("Validation Acc", correct/total)
精度はエポック10で98-99%の正解率
100%|██████████| 10000/10000 [00:08<00:00, 1171.97it/s]
Validation Acc 0.9898
CPU times: user 8.4 s, sys: 34.7 ms, total: 8.43 s
Wall time: 8.54 s
# モデル保存
torch.save(model.state_dict(), "model_MNIST_CNN.pth")
中間層へのアクセス
- 畳み込み2層目へのアクセスをする
# モデルの指定されたレイヤーの出力と勾配を保存するクラス
class SaveOutput:
def __init__(self, model, target_layer): # 引数:モデル, 対象のレイヤー
self.model = model
self.layer_output = []
self.layer_grad = []
# 特徴マップを取るためのregister_forward_hookを設定
self.feature_handle = target_layer.register_forward_hook(self.feature)
# 勾配を取るためのregister_forward_hookを設定
self.grad_handle = target_layer.register_forward_hook(self.gradient)
# self.feature_handleの定義時に呼び出されるメソッド
## モデルの指定されたレイヤーの出力(特徴マップ)を保存する
def feature(self, model, input, output):
activation = output
self.layer_output.append(activation.to("cpu").detach())
# self.grad_handleの定義時に呼び出されるメソッド
## モデルの指定されたレイヤーの勾配を保存する
## 勾配が存在しない場合や勾配が必要ない場合は処理をスキップ
def gradient(self, model, input, output):
# 勾配が無いとき
if not hasattr(output, "requires_grad") or not output.requires_grad:
return # ここでメソッド終了
# 勾配を取得
def _hook(grad):
# gradが定義されていないが、勾配が計算されると各テンソルのgrad属性に保存されるっぽい(詳細未確認)
self.layer_grad.append(grad.to("cpu").detach())
# PyTorchのregister_hookメソッド(https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html)
output.register_hook(_hook)
# メモリの解放を行うメソッド、フックを解除してメモリを解放する
def release(self):
self.feature_handle.remove()
self.grad_handle.remove()
# モデルとアクセスしたい層の名称を渡す
save=SaveOutput(model, model.conv2)
# モデルにデータを流し込んで、特徴マップを取得する
img, label=next(iter(test_dataloader))
print("img:", img.shape)
print("label:", label.shape)
# 特徴マップ
output = model(img.to(device))
print("output:", output.shape)
img: torch.Size([1, 1, 28, 28])
label: torch.Size([1])
output: torch.Size([1, 10])
特徴マップを可視化
# 特徴マップをチャネル毎に可視化
intermediate = save.layer_output[0].squeeze(0).numpy()
fig=plt.figure(figsize=(40,20))
for i,im in enumerate(intermediate):
ax1=fig.add_subplot(4,8,i+1)
ax1.imshow(im,'gray')
勾配情報を可視化(1枚)
# 予測クラスのインデックスを取得
## 分類結果をargmaxで求めて、対応する全結合層の出力要素に対して、backwardをかける
## (backwardはスカラーにしか使えない。分類問題なので、最大の値を返したものを該当クラスとみなして、その勾配を取りに行く)
idx=torch.argmax(output)
# 勾配を計算
## backward()メソッドの呼び出しにより勾配の計算が開始
## SaveOutputクラスのgradientメソッドが呼び出され、output.register_hook(_hook)により、_hook関数は、gradを受け取ってself.layer_gradに追加
## gradはbackward()を呼び出した時点で自動的に計算されるっぽい
### (backward()メソッドを呼び出すと、計算グラフを遡りながら各テンソルのgrad_fnに対応する勾配計算メソッドが実行され、各テンソルのgrad属性に保存される)
output[0,idx].backward()
# 2回やるとエラーになる?(詳細未確認)
これも可視化
# 勾配情報をチャネル毎に可視化
grad = save.layer_grad[0].squeeze(0).numpy()
fig=plt.figure(figsize=(40,20))
for i,im in enumerate(grad):
ax1=fig.add_subplot(4,8,i+1)
ax1.imshow(im,'gray')
- 勾配情報をチャネル方向に加算して、元画像と重ねて表示してみる
# 勾配情報をチャネル方向に加算
grad_re = np.sum(grad, axis=0)
grad_re.shape
(14, 14)
# 拡大後の画像サイズ
new_size = (28, 28)
# バイリニア補完による画像の拡大(元画像と重ねる為、同じサイズにする)
grad_re = zoom(grad_re, new_size / np.array(grad_re.shape), order=1)
grad_re.shape
(28, 28)
# 勾配情報のみ表示
fig=plt.figure(figsize=(4,4))
ax1=fig.add_subplot(1,1,1)
ax1.imshow(grad_re)
<matplotlib.image.AxesImage at 0x7f30d60d7a00>
元画像は「7」なので、それっぽくはなっているかも(黄色が勾配情報 大、青が小)
# 元画像
fig=plt.figure(figsize=(4,4))
ax1=fig.add_subplot(1,1,1)
ax1.imshow(img.squeeze().numpy(), "gray")
<matplotlib.image.AxesImage at 0x7f30d62b82b0>
# 勾配情報と元画像を重ねて表示
cam = grad_re + img.squeeze().numpy()
cam = cam / np.max(cam)
fig=plt.figure(figsize=(4,4))
ax1=fig.add_subplot(1,1,1)
ax1.imshow(cam)
<matplotlib.image.AxesImage at 0x7f304f151060>
勾配情報を可視化(複数枚)
# 32枚可視化する
num_imgs = 32
tmp_dataloader = DataLoader(mnist_test,
batch_size=num_imgs,
shuffle=True)
# batch_size分データ取り出し
imgs, labels = next(iter(tmp_dataloader))
fig=plt.figure(figsize=(40,20))
for i, (img, label) in enumerate(zip(imgs, labels)):
# 特徴量
output = model(img.to(device))
# 勾配を計算
idx = torch.argmax(output)
output[0,idx].backward()
grad = save.layer_grad[0].squeeze(0).numpy()
# 勾配をチャンネル方向に加算
grad_re = np.sum(grad, axis=0)
# 拡大後の画像サイズ
new_size = (28, 28)
# バイリニア補完による画像の拡大
grad_re = zoom(grad_re, new_size / np.array(grad_re.shape), order=1)
# 元画像と重ねる
cam = grad_re + img.squeeze().numpy()
cam = cam / np.max(cam)
# 表示
ax1=fig.add_subplot(4,8,i+1)
ax1.imshow(cam)
使い方合ってるかわからないけど、数字の部分の勾配情報が大きくなっているので良いのかな
以上