0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【PyTorch】モデルの中間層の出力を取得する「Hook」の使い方:ResNetでの具体例付き

Posted at

PyTorchで深層学習モデルの**中間層の出力(特徴量)**を取得したいとき、モデル構造を壊さずにアクセスできるのが「Hook(フック)」です。

この記事では、WideResNet50 を例にとって、forward_hook を使って中間特徴マップを抽出する方法を、実用コードとともに解説します。

0.Hookとは?なぜ使うのか?

PyTorchの Hook(フック) とは、モデルの 特定の層の順伝播や逆伝播のタイミングで、入力や出力にアクセスできる仕組み です。とくに forward_hook を使うと、「順伝播(forward)」の 出力 をキャッチできます。

フックが活躍する場面:

  • 中間層の特徴マップ(feature map)を抽出したいとき
  • 活性化値を可視化したいとき(例:Grad-CAM)
  • 複数スケールの特徴を使った異常検知(例:PaDiM、PatchCore)
  • モデル内部のデバッグ

1. モデルと環境の準備

import torch
import torchvision.models as models
import os

# Intel MKLの重複読み込みを防止(環境依存の対策)
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# GPUが利用可能ならGPU、なければCPUを使用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Current Device is {device}")

# ImageNetで事前学習されたWideResNet50を読み込み
model = models.wide_resnet50_2(weights=models.Wide_ResNet50_2_Weights.IMAGENET1K_V1)
model.eval().to(device)

2. Forward Hookの定義と登録

出力保存用のリストを準備

outputs = []

フック関数の定義

def hook(module, input, output):
    outputs.append(output.clone().detach().cpu().numpy())

この関数は、対象の層が順伝播した直後に自動的に呼び出され、その出力テンソルをNumPy形式で outputs に保存します。

フックを登録する

model.layer1[-1].register_forward_hook(hook)
model.layer2[-1].register_forward_hook(hook)
model.layer3[-1].register_forward_hook(hook)

ここで layer1[-1] は、ResNetのlayer1ブロックの最後のBottleneck層を指しています。

3. ダミー画像で動作を確認

dummy_input_tensor = torch.randn(1, 3, 512, 512).to(device)

with torch.no_grad():
    _ = model(dummy_input_tensor)

4. フックの結果を確認する

for i, layer_output_np in enumerate(outputs):
    print(f"Output from Hook {i+1}")
    print(f"  Shape: {layer_output_np.shape}")
    print(f"  First values: {layer_output_np.flatten()[:5]}")

例:

Output from Hook 1
  Shape: (1, 256, 128, 128)
  First values: [0.003, 0.014, 0.018, ...]

5. 活用例:異常検知や可視化に

取得した中間特徴は、以下のような応用が可能です:

活用法 説明
PaDiM・PatchCore 異常検知用に特徴を保存し、類似度や距離を計算
Grad-CAM 勾配と組み合わせて注目領域を可視化
転移学習の前段処理 中間出力を別モデルへ渡して学習に利用

6. 注意点と補足

  • フックは forward(順伝播)時のみ 呼ばれます(backward_hook は逆伝播用)
  • 登録されたフックは解除しないと残り続けます。後述の handle.remove() を使いましょう
  • 多くの層にHookを登録するとメモリを圧迫します

補足①:ResNetのブロック構造と layer1[-1]

ブロック名 Bottleneck数 出力チャンネル 出力サイズ(入力224x224時)
layer1 3 256 56×56
layer2 4 512 28×28
layer3 6 1024 14×14
layer4 3 2048 7×7

layer4 は空間分解能が小さすぎるため、局所的な異常検知には不向きなことがあります。したがって、通常は layer1layer3 を使用します。


補足②:register_forward_hook() の構造と文法

以下のような関数を登録することで、モデルの各層の「出力」を取得できます:

def hook(module, input, output):
    print(output.shape)

handle = model.layer2[-1].register_forward_hook(hook)

引数の意味:

引数 内容
module 呼び出されたモジュール(例:Conv2d)
input 入力(タプルで渡される)
output 出力テンソル(これを保存したり可視化)

フックの解除方法:

handle.remove()

Hookの一連の流れ(図解)

hook.png

まとめ

  • PyTorchのforward_hookを使うと、中間層の出力を簡単に取得できる
  • ResNetの layer1[-1] のような書き方で、任意の層の出力を抽出可能
  • 特徴抽出・異常検知・可視化など、多くの応用に使える
  • Hookは登録→使用→解除の流れで安全に活用
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?