PyTorchで深層学習モデルの**中間層の出力(特徴量)**を取得したいとき、モデル構造を壊さずにアクセスできるのが「Hook(フック)」です。
この記事では、WideResNet50 を例にとって、forward_hook
を使って中間特徴マップを抽出する方法を、実用コードとともに解説します。
0.Hookとは?なぜ使うのか?
PyTorchの Hook(フック) とは、モデルの 特定の層の順伝播や逆伝播のタイミングで、入力や出力にアクセスできる仕組み です。とくに forward_hook
を使うと、「順伝播(forward)」の 出力 をキャッチできます。
フックが活躍する場面:
- 中間層の特徴マップ(feature map)を抽出したいとき
- 活性化値を可視化したいとき(例:Grad-CAM)
- 複数スケールの特徴を使った異常検知(例:PaDiM、PatchCore)
- モデル内部のデバッグ
1. モデルと環境の準備
import torch
import torchvision.models as models
import os
# Intel MKLの重複読み込みを防止(環境依存の対策)
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
# GPUが利用可能ならGPU、なければCPUを使用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Current Device is {device}")
# ImageNetで事前学習されたWideResNet50を読み込み
model = models.wide_resnet50_2(weights=models.Wide_ResNet50_2_Weights.IMAGENET1K_V1)
model.eval().to(device)
2. Forward Hookの定義と登録
出力保存用のリストを準備
outputs = []
フック関数の定義
def hook(module, input, output):
outputs.append(output.clone().detach().cpu().numpy())
この関数は、対象の層が順伝播した直後に自動的に呼び出され、その出力テンソルをNumPy形式で outputs
に保存します。
フックを登録する
model.layer1[-1].register_forward_hook(hook)
model.layer2[-1].register_forward_hook(hook)
model.layer3[-1].register_forward_hook(hook)
ここで layer1[-1]
は、ResNetのlayer1ブロックの最後のBottleneck層を指しています。
3. ダミー画像で動作を確認
dummy_input_tensor = torch.randn(1, 3, 512, 512).to(device)
with torch.no_grad():
_ = model(dummy_input_tensor)
4. フックの結果を確認する
for i, layer_output_np in enumerate(outputs):
print(f"Output from Hook {i+1}")
print(f" Shape: {layer_output_np.shape}")
print(f" First values: {layer_output_np.flatten()[:5]}")
例:
Output from Hook 1
Shape: (1, 256, 128, 128)
First values: [0.003, 0.014, 0.018, ...]
5. 活用例:異常検知や可視化に
取得した中間特徴は、以下のような応用が可能です:
活用法 | 説明 |
---|---|
PaDiM・PatchCore | 異常検知用に特徴を保存し、類似度や距離を計算 |
Grad-CAM | 勾配と組み合わせて注目領域を可視化 |
転移学習の前段処理 | 中間出力を別モデルへ渡して学習に利用 |
6. 注意点と補足
- フックは forward(順伝播)時のみ 呼ばれます(
backward_hook
は逆伝播用) - 登録されたフックは解除しないと残り続けます。後述の
handle.remove()
を使いましょう - 多くの層にHookを登録するとメモリを圧迫します
補足①:ResNetのブロック構造と layer1[-1]
ブロック名 | Bottleneck数 | 出力チャンネル | 出力サイズ(入力224x224時) |
---|---|---|---|
layer1 |
3 | 256 | 56×56 |
layer2 |
4 | 512 | 28×28 |
layer3 |
6 | 1024 | 14×14 |
layer4 |
3 | 2048 | 7×7 |
layer4
は空間分解能が小さすぎるため、局所的な異常検知には不向きなことがあります。したがって、通常は layer1
〜layer3
を使用します。
補足②:register_forward_hook()
の構造と文法
以下のような関数を登録することで、モデルの各層の「出力」を取得できます:
def hook(module, input, output):
print(output.shape)
handle = model.layer2[-1].register_forward_hook(hook)
引数の意味:
引数 | 内容 |
---|---|
module |
呼び出されたモジュール(例:Conv2d) |
input |
入力(タプルで渡される) |
output |
出力テンソル(これを保存したり可視化) |
フックの解除方法:
handle.remove()
Hookの一連の流れ(図解)
まとめ
- PyTorchの
forward_hook
を使うと、中間層の出力を簡単に取得できる - ResNetの
layer1[-1]
のような書き方で、任意の層の出力を抽出可能 - 特徴抽出・異常検知・可視化など、多くの応用に使える
- Hookは登録→使用→解除の流れで安全に活用