LoginSignup
22
15

More than 1 year has passed since last update.

PyTorchで深層学習モデルの中間層の特徴ベクトルを取り出す方法のまとめ

Posted at

はじめに

学習済みの深層学習モデル(特に画像認識のCNN)は特徴抽出器としても有用で、予測するだけでなく中間層の特徴ベクトルを取り出して別のことに使いたいこともよくあります。

PyTorchで中間層の特徴ベクトルを取り出すには、モデルの必要ない層を恒等関数に置き換えたり、必要な層だけを取り出して新たなモデルを作成したりするなど、何通りか方法があります。
特に、Torchvisionのv0.11から新たに feature_extraction というモジュールが追加されたことで、簡単に任意のモデルから中間層の特徴ベクトルを取り出せるようになっています1

本記事では新しい feature_extraction を使う方法と、それ以前の方法をまとめます。

feature_extraction モジュールを使う方法

torchvision.models.feature_extraction.create_feature_extractor を使用すると任意のモデルの任意の中間層の特徴ベクトルを取り出すモデルを作成してくれます。
以下のコードはTorchvisionで提供されているResNetから、最後の畳み込み後の特徴ベクトルを取り出す例です。

import torch
from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor

net = resnet18()
feature_extractor = create_feature_extractor(net, {"avgpool": "feature"})

x = torch.rand((1, 3, 224, 224)).float()
feature_dict = feature_extractor(x)
print(feature_dict["feature"].shape)

create_feature_extractor の第一引数にモデルを、第二引数に {"層の名前": "参照したい名前"} の形式の辞書を渡します。第二引数では取り出したい中間層を同時に2つ以上指定することもできます。
返却される feature_extractor は普通のモデルと同じように使えます。
出力として辞書が返ってくるので、先程指定した名前をキーとして中間層出力を参照できます。

feature_extraction 以前の方法

上記以外の方法で学習済みモデルから中間層の特徴ベクトルを取り出したい場合、主に3通りの方法があります。

  • 不要な層を恒等関数に置き換える
  • 必要な層のみを取り出して新たなモデルを作成
  • forward_hook の利用

不要な層を恒等関数に置き換える

取り出したい層以降の層を nn.Identity() のような恒等関数に置き換えます。

import torch
from torchvision.models import resnet18
from torch import nn

net = resnet18()
net.fc = nn.Identity()

x = torch.rand((1, 3, 224, 224)).float()
feature = net(x)
print(feature.shape)

最終層の分類をするための全結合層 net.fcnn.Identity() で上書きしています。
取り出したい層以降はすべて上書きする必要があるため、層の多いモデルで真ん中や最初の方を取り出したい場合は不便です。

必要な層のみを取り出して新たなモデルを作成

入力層から取り出したい層までを nn.Sequential に渡したり自作モデルに組み込んだりして新たに特徴抽出用のモデルを作成します。

import torch
from torchvision.models import resnet18
from torch import nn

net = resnet18()
layers = list(net.children())[:-1]
feature_extractor = nn.Sequential(*layers)

x = torch.rand((1, 3, 224, 224)).float()
feature = feature_extractor(x)
print(feature.shape)

net.children()resnet18 が持つ層を列挙し、最終層以外を nn.Sequential に渡してモデルを作成しています。
最初や最後の方の層以外を取り出したい場合は [:-1] のようにインデックスで指定するとミスが起きやすいです。
nn.Sequential(net.layer1, net.layer2, ...) のように1つずつ名前で指定するのも層数が多いと大変です。

forward_hook の利用

PyTorchのモデルなら必ず継承しているはずの nn.Module クラスには register_forward_hook というメソッドがあります。
これを使うことで、そのモデル(層)の forward 実行後に呼び出す関数を登録できるので、呼び出し先で出力を記録します。

import torch
from torchvision.models import resnet18

def store_feature(module, input, output):
    global feature
    feature = output

net = resnet18()
net.avgpool.register_forward_hook(store_feature)

x = torch.rand((1, 3, 224, 224)).float()
y = net(x)
print(feature.shape)

store_feature に渡された outputnet.avgpool の出力になっており、グローバル変数 feature に記録しています。
必要な層だけを直接指定できるので上2つの方法より汎用性が高いですが、取り出したい層以降の計算も実行されてしまうので無駄があります。

おわりに

v0.11以降のTorchvisionを使っているなら feature_extraction が一番おすすめです。
古い3つの方法は一長一短ありましたが、 feature_extraction ではいずれの短所も解消されています。

22
15
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
22
15