概要
個人的な備忘録を兼ねたPyTorchの基本的な解説とまとめです。第6回は中間層の特徴量を抽出する 方法になります。
画像分類の回でCNNやプーリング層の特徴量を可視化する内容について触れてきました。
今回はregister_forward_hookを利用した方法を紹介していきます。
演習用のファイル
1. 直接的な可視化の方法の確認と今回の目的
これまで紹介した可視化の方法は、モデルを構成するforwardの部分に直接書き込むスタイルでした。下記のコードのh1やh2のように、直接出力を指定することで中間層の値を抽出していました。この方法は直観的で良いのですが、コード変更時のミスも発生しやすい上、学習後に確認したい出力の値を追加するのも非常に面倒
class DNN(nn.Module):
def __init__(self):
super(DNN, self).__init__()
self.cnn1 = nn.Conv2d(in_channels=1, out_channels=5 ,kernel_size=5)
self.act1 = nn.LeakyReLU(negative_slope=0.01) # Leaky ReLU導入
self.pool1= nn.MaxPool2d(kernel_size=2, stride=2)
self.flat = nn.Flatten()
self.fc1 = nn.Linear(in_features=5*12*12, out_features=100)
self.act2 = nn.ReLU()
self.fc2 = nn.Linear(in_features=100, out_features=5)
def forward(self, x):
h1 = self.cnn1(x) # h1:CNNの出力
h = self.act1(h1)
h2 = self.pool1(h) # h2:プーリング層の出力
h = self.flat(h2)
h = self.act2(self.fc1(h))
y = self.fc2(h)
return y, h1, h2
model = DNN()
学習後に出力値を確認できる方法の確認が今回の目的となります。数行コード部分が増えますが、register_forward_hookと呼ばれるメソッドを利用することで簡単に実装することができるみたいです。
2. コードと解説
2.1 ネットワークの書き方
コード全体はこちらに記載しておきます。#05のネットワークモデルからh1やh2を除いた形でforward部分を記述します。すでに学習は終了した状態としましょう。
class DNN(nn.Module):
def __init__(self):
super(DNN, self).__init__()
self.cnn1 = nn.Conv2d(in_channels=1, out_channels=5 ,kernel_size=5)
self.act1 = nn.LeakyReLU(negative_slope=0.01) # Leaky ReLU導入
self.pool1= nn.MaxPool2d(kernel_size=2, stride=2)
self.flat = nn.Flatten()
self.fc1 = nn.Linear(in_features=5*12*12, out_features=100)
self.act2 = nn.ReLU()
self.fc2 = nn.Linear(in_features=100, out_features=5)
def forward(self, x):
h = self.cnn1(x)
h = self.act1(h)
h = self.pool1(h)
h = self.flat(h)
h = self.fc1(h)
h = self.act2(h)
y = self.fc2(h)
return y
model = DNN()
繰り返しになりますが、forwardの戻り値を指定する部分がreturn y
だけです。途中の出力結果を抽出するためのh1やh2という変数はありません。
2.2 register_forward_hook
register_forward_hookの簡単な使い方は、PyTorchのModule部分に記載があります。
基本的には、出力するネットワーク層の名前をregister_forward_hook() を使って順番に登録していく形となります。コードの解説は後回しで内容に進みます。
# 登録する関数
feature_maps = {}
def register_activation_hook(name):
def hook(module, input, output):
feature_maps[name] = output.detach()
return hook
# register_forward_hookで登録
hooks = []
hooks.append(model.cnn1.register_forward_hook(register_activation_hook("cnn1")))
hooks.append(model.act1.register_forward_hook(register_activation_hook("act1")))
hooks.append(model.pool1.register_forward_hook(register_activation_hook("pool1")))
feature_mapsがネットワーク名とその出力値を対応させた辞書、hooksが登録したネットワーク名に関連するリストになります。
今回利用しているモデルは第5回の画像分類で利用したネットワーク構造です。第5回で利用した「あ・い・う・え・お」の画像データから適当に一つ選択してmodelに入力して、feature_maps辞書に登録したネットワークの出力値の形状を確認してみましょう。
model.eval()
# テスト用の一つの画像を選択 あ:300, い:500 う:200 え:600 お:100
sample_image = x[500].unsqueeze(0) # バッチ次元を追加 (1, 1, 28, 28)
# sample_imageの出力結果をみる場合
with torch.no_grad():
y = model(sample_image)
# feature_maps辞書へ登録したネットワークの出力の形状を確認
for layer_name, feature_map in feature_maps.items():
print(f"{layer_name}: {feature_map.shape}")
# printの結果
# cnn1: torch.Size([1, 5, 24, 24])
# act1: torch.Size([1, 5, 24, 24])
# pool1: torch.Size([1, 5, 12, 12])
feature_mapsは辞書形式なので、キーを指定すれば出力値を取得することができます。サンプルコードでは、register_activation_hook("cnn1")
のように、ネットワークの名前をそのままキーとしています。
ネットワーク名 | 出力値 |
---|---|
cnn1 | feature_maps["cnn1"] |
act1 | feature_maps["act1"] |
pool1 | feature_maps["pool1"] |
サンプルコードの簡単な解説
- feature_maps = {} :中間層の出力を保存するための辞書
- register_forward_hook(関数) :引数を関数とするメソッドです。その関数はmodule、input、outputの3種類を引数とすると決められています。表現は不正確ですがfunction(module, input, output)をregister_forward_hookの引数にしてあげればOKということです。1
- サンプルコード上ではmodule、input、outputを引数に持つ関数をhookと定義して、register_activatioin_hookの戻り値に指定しています。その際、ネットワーク名(name)とその出力値(output)が対応するようにしています。
feature_maps[name] = output.detach()
の部分です。register_activation_hook("cnn1") とすることでcnn1に対応するhook関数が作られることになります。
引数 | 簡単な解説 | 例 |
---|---|---|
module | ネットワーク層 | nn.Linear,nn.Conv2d |
input | moduleに入力される値 | 画像だと(b, c, h, w) |
output | moduleの出力値でリスト | 画像だとoutput[0]がCNN出力の(b, c, h, w) |
-
register_activation_hook(name) :引数を抽出したいネットワークの呼び名(name)、戻り値を前述のhook関数とする関数です。もうちょっと他のネーミングが良かった気もします
- hooksリストに抽出するネットワーク名を追加していきます
だいたいこんな感じ(分けて記述してみた)
- cnn1_hook = register_activation_hook("cnn1") :cnn1用のhook関数の作成
- model.cnn1.register_forward_hook(cnn1_hook) :cnn1の出力の登録
2.3 出力値を可視化してみる
CNNの出力値
cnn_feature_map = feature_maps["cnn1"].squeeze().cpu().numpy()
のように形状を整えてmatplotlibで画像を表示します。チャンネル数が5だったので5枚の画像を表示してみましょう。cnn_feature_map
として出力値が取得できるので、あとは好みの方法で画像を表示すればOKです
# 評価モードに変更
model.eval()
# テスト用の一つの画像を選択 あ:300, い:500 う:200 え:600 お:100
sample_image = x[500].unsqueeze(0) # バッチ次元を追加 (1, 1, 28, 28)
# sample_imageの出力結果をみる場合
with torch.no_grad():
y = model(sample_image)
# cnn1の出力結果をみる
cnn_feature_map = feature_maps["cnn1"].squeeze().cpu().numpy()
fig, axes = plt.subplots(1, 5, figsize=(10, 3))
plt.suptitle('CNN出力 (特徴マップ)')
for i, ax in enumerate(axes.flat):
ax.imshow(cnn_feature_map[i], cmap="gray")
ax.axis('off')
plt.tight_layout()
plt.show()
最初に入力する画像データを1つ決めます。sample_imageとしてみました。model(sample_image)
とすれば、sample_imageに対する登録したネットワークの出力値を取得できます。2
活性化関数後の出力値
活性化関数の出力結果も同様に出力することが可能です。act_feature_map
が活性化関数の出力結果になります。
act_feature_map = feature_maps["act1"].squeeze().cpu().numpy()
fig, axes = plt.subplots(1, 5, figsize=(10, 3))
plt.suptitle('act層 (特徴マップ)')
for i, ax in enumerate(axes.flat):
ax.imshow(act_feature_map[i], cmap="gray")
ax.axis('off')
plt.tight_layout()
plt.show()
プーリング層の出力値
完全に繰り返しですが、プーリング層の出力も同様の形となります。
pool_feature_map = feature_maps["pool1"].squeeze().cpu().numpy()
fig, axes = plt.subplots(1, 5, figsize=(10, 3))
plt.suptitle('POOL層 (特徴マップ)')
for i, ax in enumerate(axes.flat):
ax.imshow(pool_feature_map[i], cmap="gray")
ax.axis('off')
plt.tight_layout()
plt.show()
3 その他と次回
register_forward_hookの方法を使えば、公開されているモデルの中間層の特徴量も抽出することができます。画像分類で言えば、VGGやResNetでも中間層の特徴量を抽出・可視化することも可能になります。次回は公開されているモデルでの中間層の特徴量の抽出方法に触れてみたい
目次ページ