2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

mmsegmentationのモデルをダウンロードして中間層の特徴量を取り出してみる

Posted at

はじめに

  • mmsegmentationとはsemantic segmentationを実現するためのpytorchベースのToolboxです。
  • 多数のアーキテクチャのモデルが提供されており、config一つで学習・推論できます。
  • この記事ではmmsegmentationのモデルで中間層の特徴量を取り出すコードを紹介します。

環境準備

ライブラリのインストール

git clone https://github.com/open-mmlab/mmsegmentation.git
cd mmsegmentation
pip install -v -e .
pip install mmcv-full
  • 動作確認環境
    • python 3.6.8
    • mmsegmentation(0.29.1)
      • configを読み込む場合、githubからcloneした方が良いかもしれません。
      • wgetだとconfigがうまく開けませんでした。
    • mmcv-full(1.7.0)

モデルのダウンロード

checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes/deeplabv3plus_r50-d8_512x1024_40k_cityscapes_20200605_094610-d222ffcd.pth'

!pip install wget
# google colabで実行する場合は”wget $checkpoint_file”でOK
!python -m wget $checkpoint_file

中間層の出力を得る方法

pytorchのmodelはnn.Moduleクラスを継承しています。
nn.Moduleクラスのregister_forward_hookメソッドにforward後に呼び出す関数を引数指定することで、中間層の出力を得ることが出来ます。
他にも中間層の出力を得る方法はあるようですが、今回はregister_forward_hookを使用していきます。

コード

指定した中間層の出力を得る関数

def middle_layer_output(target, inputs):
    """
    中間層の特徴量を得る関数
    """
    feature = None

    # forward後に呼び出される関数
    def forward_hook(module, inputs, outputs):
        global features
        # detach()で計算グラフから切り離し、clone()で値の共有をしないようにする
        features = outputs.detach().clone()

    handle = target.register_forward_hook(forward_hook)
    model.eval()
    # NOTE: mmsegmentationでない場合はmodel(inputs)で良い
    inference_segmentor(model, inputs)
    handle.remove()
    return features

こちらの関数を使用します。

コード全体

from pathlib import Path
import mmcv
from mmseg.apis import init_segmentor, inference_segmentor

mmseg_path = Path('./mmsegmentation') # path to your mmsegmentation
config_file_path = mmseg_path / 'configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py'
checkpoint_file_path = './deeplabv3plus_r50-d8_512x1024_40k_cityscapes_20200605_094610-d222ffcd.pth'

# configの読み込み
cfg = mmcv.Config.fromfile(config_file_path)

# 複数GPUでBNを行うSyncBNが定義されており、single CPU/GPUで実行するとエラーになってしまうため、configのSyncBNをBNに書き換える
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg

# モデルの読み込み
model = init_segmentor(cfg, checkpoint_file_path, device='cpu')

def middle_layer_output(target, inputs):
    """
    中間層の特徴量を得る関数
    """
    feature = None

    # forward後に呼び出される関数
    def forward_hook(module, inputs, outputs):
        global features
        # detach()で計算グラフから切り離し、clone()で値の共有をしないようにする
        features = outputs.detach().clone()

    handle = target.register_forward_hook(forward_hook)
    model.eval()
    # NOTE: mmsegmentationでない場合はmodel(inputs)で良い
    inference_segmentor(model, inputs)
    handle.remove()
    return features

# 特徴量を得たい中間層を指定
target = model.backbone.layer4[-1].conv3
img = 'demo.jpg'
middle_layer_output(target, img)

References

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?