1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

🔰PyTorchでニューラルネットワーク基礎 #07【中間層特徴量の抽出 2】

Last updated at Posted at 2025-06-04

概要

個人的な備忘録を兼ねたPyTorchの基本的な解説とまとめです。第7回は学習済みモデルでの中間層の特徴量を抽出する 方法となります。画像分類でおなじみのVGG16モデルとPyTorchのregister_forward_hook を使って特徴量抽出と可視化を試みてみます。第6回の自作ネットワークでの中間層の特徴量の抽出も参考にしてもらえると助かります。

演習用のファイル

1. 今回の目的

第6回では 「自作のネットワークモデルに対して中間層の出力値を取得する」 ことに焦点をあてました。自作モデルの場合、ネットワークの各層に付ける名前は自由に決められるため、register_forward_hookで中間層を登録する際にも、どの層を指定すればよいか迷うケースは比較的少ないでしょう。このregister_forward_hookを使った手法は、自作モデルだけでなく学習済みモデルに対しても同様に適用できます。というか学習済みモデルの内部動作を理解したい局面でこそ、この技が威力を発揮する?

今回は学習済みモデルに対して中間層の出力値をどのように抽出し、可視化するかについて見ていきます。具体例として、画像分類タスクで広く使われているVGG16を取り上げ、CNNの畳み込み層やプーリング層から得られる特徴マップの可視化を試みてみます:smile:

2. コードと解説

2.1 VGG16モデルのダウンロード

今回利用するライブラリを一度に読み込んでしまいましょう。PyTorchでVGGなどの学習済みのモデルを読み込むにはtorchvisionを使います。利用できるモデルについては、公式ドキュメントに詳しく書かれています。

ライブラリの読み込み
import torch
import torchvision.transforms as transforms
from torchvision.models import vgg16, VGG16_Weights
from PIL import Image
import json
import matplotlib.pyplot as plt
import japanize_matplotlib

これまでよりも若干多い気がしますが、モデルの読み込み、画像の読み込み、分類クラスを記述したJSONファイルの読み込み、可視化のためのmatplotlibという流れです。

モデルの読み込みは実質1行model=vgg16(weights=VGG16_Weights.IMAGENET1K_V1)です。

モデルの読み込み
weights = VGG16_Weights.IMAGENET1K_V1
model = vgg16(weights=weights)          # ImageNetで学習済みの重みを持つVGG16モデル取得
preprocess = weights.transforms()       # 前処理も取得

サンプル画像で画像分類を行う時の画像の前処理ですが、weights.transforms()でOKです。VGG16モデルを実際に試すときには、

  1. 画像を読み込んで
  2. preprocessで前処理を行い
  3. model(加工後の画像)

とすることで、1000分類の特徴量(1000次元のベクトル)が表示されます。画像の前処理についてですが、作業工程を詳しく記述すると次のようになります。

weights.transforms()の詳細

preprocess = transforms.Compose([
    transforms.Resize(256),            # 画像を256x256にリサイズ
    transforms.CenterCrop(224),        # 中央を224x224に切り抜く
    transforms.ToTensor(),             # PyTorchテンソルに変換 (0-255 -> 0-1)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) #正規化
])

詳しい内容は PyTorchドキュメント VGG16 にあります。前処理は重要とはわかりつつも地味に面倒!ありがたい機能です:smile:

2.2 VGG16動作確認

本題からそれますが、サンプル画像:lion_face:を利用してモデルの動作を確認してみます。

サンプル画像での分類
sample_image = Image.open("./data/lion.jpg").convert("RGB")  # 画像のロード
input_tensor = preprocess(sample_image)  # 画像の前処理 -> shape: (3, 224, 224)
input_batch = input_tensor.unsqueeze(0)  # shape: (1, 3, 224, 224)

# 推論
model.eval()
# GPUが利用可能であればGPUへ
if torch.cuda.is_available():
    input_batch = input_batch.to("cuda")
    model.to("cuda")

with torch.inference_mode():
    output = model(input_batch)

# 確率に変換
probabilities = torch.nn.functional.softmax(output[0], dim=0)

# 最も確率の高い上位3つのクラスを表示
top3_prob, top3_id = torch.topk(probabilities, 3)
#
# top3_prob: [9.9957e-01, 3.7367e-04, 2.3764e-05]
# top3_id: [291, 260, 292]

291番の確率が99%なのですが、291番?なのでラベルと対応することにしました。imagenet.jsonが0〜999番までのIDと名前の対応表になります。ちょっと強引ですが、json形式で読み込んでIDをキーとしてラベル名を表示させてみました。1つだけなので、imagenet.jsonのID部分を見るほうが早いと思います:sweat_smile:

idをラベルに
# imagenet.jsonを使ってIDからラベルへ
with open("./data/imagenet.json", "rt") as file:
        labels = json.load(file)
for num, id in enumerate(top3_id):
    print(f"{num+1}位: {labels[str(id.cpu().numpy())]}")

# 1位: lion
# 2位: chow
# 3位: tiger

ライオン:lion_face:と正しく分類されています。

2.3 VGG16のモデル構造

VGG16のモデルの構造を簡単に図で示しておきます。画像:lion_face:が入力されて、CNNやプーリング層を複数経由して最終的に1000種類の画像分類になります。VGG16についてはたくさんのサイトで解説されています。例えば、ネットワーク図と解説がVGG16とは何ですか?—VGG16の概要に掲載されています。

vgg16_model.jpg

上記の図がVGG16の主要構造となります。実際は活性化関数やドロップアウトなども入り込んでいるので更に細かいです。modelをprintしたものもすべて記載しておきます。よく解説されているように3x3のカーネルが活躍しています。

VGG16の構造
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

featuresが図のconv 1-1からPooling_5、classifierが図のLinearの3層に対応しています。

2.4 register_forward_hook

register_forward_hookの簡単な使い方は、第6回を参考にしてください。基本的には、抽出するネットワーク層の名前をregister_forward_hook() を使って順番に登録していく形でした。抽出するネットワーク層の名前を知ることがポイントになります。

VGG16で使用されているネットワークの名前は、前述の「VGG16の構造」から簡単に調べることができます。features部分には0〜30までの番号が振られており、この番号とmodel(VGG16のモデルをmodelとして定義:model = vgg16(weights=weights))を使って、model.features[番号]という形でfeatures部分の各ネットワーク層の名称を取得できます。

具体例

ネットワーク名称 解説 具体的な中身
model.features[0] 最初のCNNでconv1-1 Conv2d(3, 64,...)
model.features[1] 最初のReLU関数 ReLU(inplace=True)
model.features[2] 2番目のCNNでconv1-2 Conv2d(64, 64,...)

深い意味はありませんが、試しにconv1_1、conv5_3、maxpool_5の3種類の特徴量を抽出してみます。

サンプルコード
feature_maps = {}
def register_activation_hook(name):
    def hook(module, input, output):
        feature_maps[name] = output.detach()
    return hook

# 抽出するネットワーク名を登録する
hooks = []
hooks.append(model.features[0].register_forward_hook(register_activation_hook("conv1_1")))
hooks.append(model.features[28].register_forward_hook(register_activation_hook("conv5_3")))
hooks.append(model.features[30].register_forward_hook(register_activation_hook("maxpool_5")))

feature_mapsがネットワーク名(キー)とその出力値を対応させた辞書、hooksが登録したネットワーク名に関連するリストとなります。

サンプルコードの簡単な解説

  • feature_maps = {} :中間層の出力を保存するための辞書

  • register_forward_hook(関数) :引数を関数とするメソッドです。その関数はmodule、input、outputの3種類を引数とすると決められています。表現は不正確ですがfunction(module, input, output)をregister_forward_hookの引数にしてあげればOKということです。1

  • サンプルコード上ではmodule、input、outputを引数に持つ関数をhookと定義して、register_activatioin_hookの戻り値に指定しています。その際、ネットワーク名(name)とその出力値(output)が対応するようにしています。feature_maps[name] = output.detach()の部分です。register_activation_hook("conv1_1") とすることでconv1_1に対応するhook関数が作られます。

  • register_activation_hook(name) :引数を抽出したいネットワークの呼び名(name)、戻り値を前述のhook関数とする関数です。

  • hooksリストに抽出するネットワーク名を追加していきます

ポイント(分けて記述してみた)
1行目: conv_hook = register_activation_hook("conv1_1")
2行目: model.features[0].register_forward_hook(conv_hook)

1行目:conv1_1用のhook関数の作成
2行目 :conv1_1の出力の登録

2.5 出力値を可視化してみる

中間層の特徴量を出力するために、サンプルが画像を準備します。再び、ライオンさんの写真を使います:lion_face:

画像の準備
sample_image = Image.open("./data/lion.jpg").convert("RGB")  # 画像のロード
input_tensor = preprocess(sample_image)  # 画像の前処理 -> shape: (3, 224, 224)
input_batch = input_tensor.unsqueeze(0)  # shape: (1, 3, 224, 224)

# 推論
model.eval()
# GPUが利用可能であればGPUへ
if torch.cuda.is_available():
    input_batch = input_batch.to("cuda")
    model.to("cuda")

with torch.inference_mode():
    output = model(input_batch)

conv1_1の出力

feature_map = feature_maps["conv1_1"].squeeze().cpu().numpy()のように形状を整えてmatplotlibで画像を表示します。15枚の画像を表示してみましょう。あとは好みの方法で画像を表示すればOKです:smile:

conv 1_1の出力
network_name = "conv1_1"

feature_map = feature_maps[network_name].squeeze().cpu().numpy()

fig, axes = plt.subplots(3, 5, figsize=(10, 6))
plt.suptitle(network_name+"の特徴マップ")
for i, ax in enumerate(axes.flat):
    ax.imshow(feature_map[i], cmap="gray")
    ax.axis('off')
plt.tight_layout()
plt.show()

conv1_1.jpg

まだ、第1層目なのでライオンの面影がしっかり残っています。カーネルも表示すればどのように特徴を捉えているのかもわかる可能性があります。

conv5_3の出力

最後の畳み込み層の出力値も見てみましょう。feature_maps["conv5_1"]で値を抽出して
.squeeze().cpu().numpy()のように形状を適切に整えてmatplotlibで画像を表示します。

conv 5_3の出力
network_name = "conv5_3"

feature_map = feature_maps[network_name].squeeze().cpu().numpy()

fig, axes = plt.subplots(3, 5, figsize=(10, 6))
plt.suptitle(network_name+"の特徴マップ")
for i, ax in enumerate(axes.flat):
    ax.imshow(feature_map[i], cmap="gray")
    ax.axis('off')
plt.tight_layout()
plt.show()

conv5_3.jpg

さすが最後の畳み込み層の出力結果、もやはライオンかどうかは不明な状況です。画像の白っぽい場所が注目している部分と考えられます。

maxpool_5の出力

プーリング層の出力値も同様に出力することができます。feature_maps["maxpool_5"]で値を抽出して、形状を適切に整えてmatplotlibで画像を表示します。

maxpool_5の出力
network_name = "maxpool_5"

feature_map = feature_maps[network_name].squeeze().cpu().numpy()

fig, axes = plt.subplots(3, 5, figsize=(10, 6))
plt.suptitle(network_name+"の特徴マップ")
for i, ax in enumerate(axes.flat):
    ax.imshow(feature_map[i], cmap="gray")
    ax.axis('off')
plt.tight_layout()
plt.show()

maxpool_5.jpg

先程同様もやはライオンかどうかは不明な状況です。一応、画像の白っぽい場所が注目している部分と考えられます。

3 次回

:cactus: 中間層の特徴量の抽出方法については一旦終了
:cactus: 次回は再帰ネットワークの予定

目次ページ

  1. (module, input, output)の3種類ですが、これは(x, y, z)みたいな文字でももちろんOKです。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?