はじめに
学習済みの深層学習モデル(特に画像認識のCNN)は特徴抽出器としても有用で、予測するだけでなく中間層の特徴ベクトルを取り出して別のことに使いたいこともよくあります。
PyTorchで中間層の特徴ベクトルを取り出すには、モデルの必要ない層を恒等関数に置き換えたり、必要な層だけを取り出して新たなモデルを作成したりするなど、何通りか方法があります。
特に、Torchvisionのv0.11から新たに feature_extraction
というモジュールが追加されたことで、簡単に任意のモデルから中間層の特徴ベクトルを取り出せるようになっています1。
本記事では新しい feature_extraction
を使う方法と、それ以前の方法をまとめます。
feature_extraction
モジュールを使う方法
torchvision.models.feature_extraction.create_feature_extractor
を使用すると任意のモデルの任意の中間層の特徴ベクトルを取り出すモデルを作成してくれます。
以下のコードはTorchvisionで提供されているResNetから、最後の畳み込み後の特徴ベクトルを取り出す例です。
import torch
from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor
net = resnet18()
feature_extractor = create_feature_extractor(net, {"avgpool": "feature"})
x = torch.rand((1, 3, 224, 224)).float()
feature_dict = feature_extractor(x)
print(feature_dict["feature"].shape)
create_feature_extractor
の第一引数にモデルを、第二引数に {"層の名前": "参照したい名前"}
の形式の辞書を渡します。第二引数では取り出したい中間層を同時に2つ以上指定することもできます。
返却される feature_extractor
は普通のモデルと同じように使えます。
出力として辞書が返ってくるので、先程指定した名前をキーとして中間層出力を参照できます。
feature_extraction
以前の方法
上記以外の方法で学習済みモデルから中間層の特徴ベクトルを取り出したい場合、主に3通りの方法があります。
- 不要な層を恒等関数に置き換える
- 必要な層のみを取り出して新たなモデルを作成
-
forward_hook
の利用
不要な層を恒等関数に置き換える
取り出したい層以降の層を nn.Identity()
のような恒等関数に置き換えます。
import torch
from torchvision.models import resnet18
from torch import nn
net = resnet18()
net.fc = nn.Identity()
x = torch.rand((1, 3, 224, 224)).float()
feature = net(x)
print(feature.shape)
最終層の分類をするための全結合層 net.fc
を nn.Identity()
で上書きしています。
取り出したい層以降はすべて上書きする必要があるため、層の多いモデルで真ん中や最初の方を取り出したい場合は不便です。
必要な層のみを取り出して新たなモデルを作成
入力層から取り出したい層までを nn.Sequential
に渡したり自作モデルに組み込んだりして新たに特徴抽出用のモデルを作成します。
import torch
from torchvision.models import resnet18
from torch import nn
net = resnet18()
layers = list(net.children())[:-1]
feature_extractor = nn.Sequential(*layers)
x = torch.rand((1, 3, 224, 224)).float()
feature = feature_extractor(x)
print(feature.shape)
net.children()
でresnet18
が持つ層を列挙し、最終層以外を nn.Sequential
に渡してモデルを作成しています。
最初や最後の方の層以外を取り出したい場合は [:-1]
のようにインデックスで指定するとミスが起きやすいです。
nn.Sequential(net.layer1, net.layer2, ...)
のように1つずつ名前で指定するのも層数が多いと大変です。
forward_hook
の利用
PyTorchのモデルなら必ず継承しているはずの nn.Module
クラスには register_forward_hook
というメソッドがあります。
これを使うことで、そのモデル(層)の forward
実行後に呼び出す関数を登録できるので、呼び出し先で出力を記録します。
import torch
from torchvision.models import resnet18
def store_feature(module, input, output):
global feature
feature = output
net = resnet18()
net.avgpool.register_forward_hook(store_feature)
x = torch.rand((1, 3, 224, 224)).float()
y = net(x)
print(feature.shape)
store_feature
に渡された output
が net.avgpool
の出力になっており、グローバル変数 feature
に記録しています。
必要な層だけを直接指定できるので上2つの方法より汎用性が高いですが、取り出したい層以降の計算も実行されてしまうので無駄があります。
おわりに
v0.11以降のTorchvisionを使っているなら feature_extraction
が一番おすすめです。
古い3つの方法は一長一短ありましたが、 feature_extraction
ではいずれの短所も解消されています。