3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

🔰PyTorchでニューラルネットワーク基礎 #06【中間層特徴量の抽出】

Last updated at Posted at 2025-05-27

概要

個人的な備忘録を兼ねたPyTorchの基本的な解説とまとめです。第6回は中間層の特徴量を抽出する 方法になります。

画像分類の回でCNNやプーリング層の特徴量を可視化する内容について触れてきました。

今回はregister_forward_hookを利用した方法を紹介していきます。

演習用のファイル

1. 直接的な可視化の方法の確認と今回の目的

これまで紹介した可視化の方法は、モデルを構成するforwardの部分に直接書き込むスタイルでした。下記のコードのh1やh2のように、直接出力を指定することで中間層の値を抽出していました。この方法は直観的で良いのですが、コード変更時のミスも発生しやすい上、学習後に確認したい出力の値を追加するのも非常に面倒 :sweat_smile:

#05紹介のネットワークモデル
class DNN(nn.Module):
    def __init__(self):
        super(DNN, self).__init__()
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=5 ,kernel_size=5)
        self.act1 = nn.LeakyReLU(negative_slope=0.01)  # Leaky ReLU導入
        self.pool1= nn.MaxPool2d(kernel_size=2, stride=2)
        self.flat = nn.Flatten()
        self.fc1 =  nn.Linear(in_features=5*12*12, out_features=100)
        self.act2 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=100, out_features=5)
        
    def forward(self, x):
        h1 = self.cnn1(x)       # h1:CNNの出力
        h = self.act1(h1)
        h2 = self.pool1(h)      # h2:プーリング層の出力
        h = self.flat(h2)
        h = self.act2(self.fc1(h))
        y = self.fc2(h)
        return y, h1, h2
model = DNN()

学習後に出力値を確認できる方法の確認が今回の目的となります。数行コード部分が増えますが、register_forward_hookと呼ばれるメソッドを利用することで簡単に実装することができるみたいです。

2. コードと解説

2.1 ネットワークの書き方

コード全体はこちらに記載しておきます。#05のネットワークモデルからh1やh2を除いた形でforward部分を記述します。すでに学習は終了した状態としましょう。

#05のモデルを一部改造
class DNN(nn.Module):
    def __init__(self):
        super(DNN, self).__init__()
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=5 ,kernel_size=5)
        self.act1 = nn.LeakyReLU(negative_slope=0.01)  # Leaky ReLU導入
        self.pool1= nn.MaxPool2d(kernel_size=2, stride=2)
        self.flat = nn.Flatten()
        self.fc1 =  nn.Linear(in_features=5*12*12, out_features=100)
        self.act2 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=100, out_features=5)
        
    def forward(self, x):
        h = self.cnn1(x)
        h = self.act1(h)
        h = self.pool1(h)
        h = self.flat(h)
        h = self.fc1(h)
        h = self.act2(h)
        y = self.fc2(h)
        return y

model = DNN()

繰り返しになりますが、forwardの戻り値を指定する部分がreturn yだけです。途中の出力結果を抽出するためのh1やh2という変数はありません。

2.2 register_forward_hook

register_forward_hookの簡単な使い方は、PyTorchのModule部分に記載があります。

基本的には、出力するネットワーク層の名前をregister_forward_hook() を使って順番に登録していく形となります。コードの解説は後回しで内容に進みます。

サンプルコード
# 登録する関数 
feature_maps = {}
def register_activation_hook(name):
    def hook(module, input, output):
        feature_maps[name] = output.detach()
    return hook

# register_forward_hookで登録
hooks = []
hooks.append(model.cnn1.register_forward_hook(register_activation_hook("cnn1")))
hooks.append(model.act1.register_forward_hook(register_activation_hook("act1")))
hooks.append(model.pool1.register_forward_hook(register_activation_hook("pool1")))

feature_mapsがネットワーク名とその出力値を対応させた辞書、hooksが登録したネットワーク名に関連するリストになります。

今回利用しているモデルは第5回の画像分類で利用したネットワーク構造です。第5回で利用した「あ・い・う・え・お」の画像データから適当に一つ選択してmodelに入力して、feature_maps辞書に登録したネットワークの出力値の形状を確認してみましょう。

feature_mapsに登録したネットワークの出力値の形状
model.eval()
# テスト用の一つの画像を選択   あ:300, い:500 う:200 え:600 お:100
sample_image = x[500].unsqueeze(0)  # バッチ次元を追加 (1, 1, 28, 28)

# sample_imageの出力結果をみる場合
with torch.no_grad():
    y = model(sample_image)

# feature_maps辞書へ登録したネットワークの出力の形状を確認
for layer_name, feature_map in feature_maps.items():
    print(f"{layer_name}: {feature_map.shape}")
    
# printの結果
# cnn1: torch.Size([1, 5, 24, 24])
# act1: torch.Size([1, 5, 24, 24])
# pool1: torch.Size([1, 5, 12, 12])

feature_mapsは辞書形式なので、キーを指定すれば出力値を取得することができます。サンプルコードでは、register_activation_hook("cnn1")のように、ネットワークの名前をそのままキーとしています。

ネットワーク名 出力値
cnn1 feature_maps["cnn1"]
act1 feature_maps["act1"]
pool1 feature_maps["pool1"]

サンプルコードの簡単な解説

  • feature_maps = {} :中間層の出力を保存するための辞書
  • register_forward_hook(関数) :引数を関数とするメソッドです。その関数はmodule、input、outputの3種類を引数とすると決められています。表現は不正確ですがfunction(module, input, output)をregister_forward_hookの引数にしてあげればOKということです。1
  • サンプルコード上ではmodule、input、outputを引数に持つ関数をhookと定義して、register_activatioin_hookの戻り値に指定しています。その際、ネットワーク名(name)とその出力値(output)が対応するようにしています。feature_maps[name] = output.detach()の部分です。register_activation_hook("cnn1") とすることでcnn1に対応するhook関数が作られることになります。
引数 簡単な解説
module ネットワーク層 nn.Linear,nn.Conv2d
input moduleに入力される値 画像だと(b, c, h, w)
output moduleの出力値でリスト 画像だとoutput[0]がCNN出力の(b, c, h, w)
  • register_activation_hook(name) :引数を抽出したいネットワークの呼び名(name)、戻り値を前述のhook関数とする関数です。もうちょっと他のネーミングが良かった気もします:sweat_smile:
  • hooksリストに抽出するネットワーク名を追加していきます

だいたいこんな感じ(分けて記述してみた)

  1. cnn1_hook = register_activation_hook("cnn1") :cnn1用のhook関数の作成
  2. model.cnn1.register_forward_hook(cnn1_hook) :cnn1の出力の登録

2.3 出力値を可視化してみる

CNNの出力値

cnn_feature_map = feature_maps["cnn1"].squeeze().cpu().numpy()のように形状を整えてmatplotlibで画像を表示します。チャンネル数が5だったので5枚の画像を表示してみましょう。cnn_feature_mapとして出力値が取得できるので、あとは好みの方法で画像を表示すればOKです:smile:

# 評価モードに変更
model.eval()
# テスト用の一つの画像を選択   あ:300, い:500 う:200 え:600 お:100
sample_image = x[500].unsqueeze(0)  # バッチ次元を追加 (1, 1, 28, 28)

# sample_imageの出力結果をみる場合
with torch.no_grad():
    y = model(sample_image)

# cnn1の出力結果をみる
cnn_feature_map = feature_maps["cnn1"].squeeze().cpu().numpy()

fig, axes = plt.subplots(1, 5, figsize=(10, 3))
plt.suptitle('CNN出力 (特徴マップ)')
for i, ax in enumerate(axes.flat):
    ax.imshow(cnn_feature_map[i], cmap="gray")
    ax.axis('off')
plt.tight_layout()
plt.show()

最初に入力する画像データを1つ決めます。sample_imageとしてみました。model(sample_image)とすれば、sample_imageに対する登録したネットワークの出力値を取得できます。2

hook_cnn.jpg

活性化関数後の出力値

活性化関数の出力結果も同様に出力することが可能です。act_feature_mapが活性化関数の出力結果になります。

act_feature_map = feature_maps["act1"].squeeze().cpu().numpy()

fig, axes = plt.subplots(1, 5, figsize=(10, 3))
plt.suptitle('act層 (特徴マップ)')
for i, ax in enumerate(axes.flat):
    ax.imshow(act_feature_map[i], cmap="gray")
    ax.axis('off')
plt.tight_layout()
plt.show()

hook_act.jpg

プーリング層の出力値

完全に繰り返しですが、プーリング層の出力も同様の形となります。

pool_feature_map = feature_maps["pool1"].squeeze().cpu().numpy()

fig, axes = plt.subplots(1, 5, figsize=(10, 3))
plt.suptitle('POOL層 (特徴マップ)')
for i, ax in enumerate(axes.flat):
    ax.imshow(pool_feature_map[i], cmap="gray")
    ax.axis('off')
plt.tight_layout()
plt.show()

hook_pool.jpg

3 その他と次回

register_forward_hookの方法を使えば、公開されているモデルの中間層の特徴量も抽出することができます。画像分類で言えば、VGGやResNetでも中間層の特徴量を抽出・可視化することも可能になります。次回は公開されているモデルでの中間層の特徴量の抽出方法に触れてみたい:cactus::cactus::cactus:

目次ページ

  1. (module, input, output)の3種類ですが、これは(x, y, z)みたいな文字でももちろんOKです。

  2. with torch.no_grad()として勾配計算を無効化するように格好良く書いていますがmodel(sample_image)があれば問題なく動作します:cactus:

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?