LoginSignup
8
6

More than 1 year has passed since last update.

Mask R-CNNの層を深くしたり広くしたりした

Last updated at Posted at 2021-06-25

0.はじめに

この記事では、単にMaskRCNNのモデルのいじり方を紹介するものではありません。自分がどのように調べてモデルを変更するまでにたどり着いたのかを紹介する記事です。したがって、解説記事のように推敲してまとめたような記事ではありません。ご了承ください。まとまった記事はそのうち書きます。

私も初心者なので、「このコードを見るとここのネットワークがこうなっているから、ここをこういじればいいじゃん!」ということは全くできませんでした。最初はみんなそうだと思います。

結局どこをいじればいいんだって思うのであれば最後の部分を見てください。

1.前提

使ったフレームワーク
PyTorch 1.7
torchvision 0.8

2.実際にやったこと

2.1はじめに

PyTorchの公式文書に従ってMask R-CNNを作成したものの、精度がいまいち出ていない。「モデルのネットワーク構造を変更して精度を上げられないだろうか」と考えた。データを学習しやすい形になるように調整するのに飽きてしまったわけではない。
そこで、まずは他の人が解説記事を出していないか調べたが、欲しい点をカバーしている記事を見つけることはできなかった。そこで、公式が公にしているコードを読んで何とかしよう、と考えるに至った。

番外編1 解説記事の発見

そんな折、素晴らしい解説記事を見つけた。Mask R-CNNのネットワーク構造、学習の進め方(今回は書いていないけど)などはほとんどこの記事から勉強させていただいた。本当にありがとうございます。

2.2モデルのネットワーク構造の確認

本当はどんなネットワーク構造をしているのか事細かに理解してモデルを変更する、なんてことをやりたいのだが、私も初心者であるし、そんな難しいことができない。そこで、「どのようにプログラムをするとネットワーク構造が作られていくか」という点に着目してネットワーク構造の解析(というより自分なりの解釈)を試みた。

1 - Finetuning from a pretrained model
Let’s suppose that you want to start from a model pre-trained on COCO and want to finetune it for your particular classes. Here is a possible way of doing it:

    import torchvision
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

    # load a model pre-trained pre-trained on COCO
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # replace the classifier with a new one, that has
    # num_classes which is user-defined
    num_classes = 2  # 1 class (person) + background
    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

これを見ると、初めに
"model=torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)"
の部分でモデルを作成し、その後、
"model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)"
の部分でネットワーク構造を何かいじっているのだろうか?ということを考えた。

2.3公式コードを読んでいく

"model=torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)"を実行することで一体どんなことが行われているのかを理解するために、公式のgithubを読んでいくことにした。
"torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)"を見るに、「torchvisionの中のmodelの中のdetectionの中のfastrcnn_resnet50_fpnを見てみよう」と考えた。
それがここである。

ここを読むと無駄な部分が多いがコメントアウトによってたくさん説明されているが、コメントアウト部分以外の部分を見ると、

def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
                          num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs):

trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)

if pretrained:
    pretrained_backbone = False
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers)
model = MaskRCNN(backbone, num_classes, **kwargs)
if pretrained:
    state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'],
                                              progress=progress)
    model.load_state_dict(state_dict)
    overwrite_eps(model, 0.0)
return model

となっている。これを読むと"trainable_backbone_layers"を指定した後、"backbone"を指定し、上の方で定義されているMaskRCNNのclassを作っているのだろう、ということが読めた。
つまり(現状の)流れとしては、
1."trainable_backbone_layers"を指定
2."backbone"の指定
3."MaskRCNN"のclassを作成
4."model.roi_heads.box_predictor"を作成
という流れになるだろう、と感じた。

余談1

image.png
こうしてみると、

from torchvision.models.detection.mask_rcnn import ~~~

とあらわされるように、基本的には"."によって下の階層(フォルダ、または.pyファイル)に進んでいくんだな、ということが分かった。

2.3.1 trainable_backbone_layers

この定義を確認してみると"_validate_trainable_layers"によって定義されていることがわかる。これは、mask_rcnn.pyの最初の方にimportしているように、同じ階層にある"backbone_utils.py"に格納されているらしい。

image.png
確認してみると、何らかの数字をreturn しているらしい。名前的に、学習させる層の数だろうか。

image.png
もう一歩踏み込んで、"resnet_fpn_backbone"の関数の中でどのようにふるまっているのかを確認した。
今回の"resnet_fpn_backbone"の中では、この値は0~5の範囲で選ぶらしい。
image.png

2.3.2 backbone

2.3.1と同じmaskrcnn_resnet50_fpnの中で、backboneを定義している部分を探してみる。そうすると、以下のように定義している部分があることがわかる。
image.png

そこで、resnet_fpn_backboneが定義されている箇所を確認する。

image.png

すると、引数に"backbone_name"というものがあることに気が付いた。
コメントアウトの部分を見てみると
image.png

と書いてあった。この部分を変更することで様々なモデルのネットワークを使えるようだ。無駄な部分が多いとか言ってしまって申し訳ない

2.3.3 MaskRCNNのclassを作成

ここはfastrcnn_resnet50_fpnの中身にもあるように

model = MaskRCNN(backbone, num_classes, **kwargs)

のようなコードを使えば十分だろうと思う。
しかし、MaskRCNNクラスの中身、特に"init"の中身をよく見ると

image.png
このように"mask_roi_pool", "mask_head", "mask_predictor"の定義が書いてあり、それぞれで"MultiScaleRoIAlign", "MaskRCNNHeads", "MaskRCNNPredictor"によって定義されていることがわかる。デフォルトの設定でもfastrcnn_resnet50_fpnの中身を見る限り十分動作するプログラムであることは読み取れる。
これは2.3.4の後に中身を確認したい。

2.3.4 model.roi_heads.box_predictor

この部分のclassであるFastRCNNPredictorを見てみると以下のようになっていた。
image.png
これを見ると、特に大きな変更ができそうな余地はないように見える(関数を変更したりはできるけれども)ので、このまま使えそうだということが分かった。

2.3.5 mask_roi_pool

MultiScaleRoIAlignのclassを見てみたが、正直全然中身を読み取れなかった。これは後々、勉強していきたい。

2.3.6 mask_head

MaskRCNN class内の該当箇所を見てみると
image.png
となっている。これを見ると、MaskRCNNHeadsのclassに対して"out_channels","mask_layers","mask_dilation"を決めればモデルのネットワークを変更できそうだ、ということがわかる。MaskRCNNHeadsのclassを見ても、これは単にネットワークを作っているだけのように見受けられるので、この理解であっていそうな気がする。
例えばこの部分でmask_layersを

mask_layers=(256,256,512,512,1024)

なんて形へ変更できそうだ。

ちなみに、ここの"mask_dilation"の役割については説明が難しかった。
image.png
機能としては図のように、nn.Conv2dの関数の中に定義されているらしい。デフォルトでいいかな、と判断。
これらをまとめて、この部分を変更するにはどうすればよいのかを考えたときに、

model.roi_heads.box_predictor=FastRCNNPredictor(in_channels, num_classes)

によって"model"内の"box_predictor"を変更していたことを参考にしてみた。

モデルは

model=MaskRCNN(backbone,num_classes)

で定義しているため、MaskRCNNのclassの中の"init"によって"mask_head"が定義されていると考えられる。よく見ると、out_channelsの値を用いてこの部分を作っているため、

image.png

この部分の定義の仕方にのっとって、

out_channels = model.backbone.out_channels
mask_layers = (256,256,512,512,1024)
mask_dilation = 1
mask_head = MaskRCNNHeads(out_channels, mask_layers, mask_dilation)

と記述することでこの部分を変更できそうだとわかった。

2.3.7 mask_predictor

MaskRCNNPredictorのclassも基本的には引数を元にネットワークを作るクラスだということがわかる。
image.png
ここも数字を適当に繕って作成すればよさそうだが、

if mask_predictor is None:
    mask_predictor_in_channels = 256  # == mask_layers[-1]
    mask_dim_reduced = 256
    mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels,
                                       mask_dim_reduced, num_classes)

コメントアウトに"==mask_layers[-1]"と記載されている。これは"mask_layers"の最後の値と一致させることを要求されていると思われる。今回のような場合

mask_layers=(256,256,512,512,1024)

であれば、

mask_predictor_in_channels = 1024

とすればよいことがわかる。

これを利用してこの部分のネットワークを書き換えるなら、"model"内の"mask_predictor"を変更すればよいので、

mask_predictor_in_channels=mask_layers[-1]
mask_dim_reduced=256 #デフォルト値だが変更可能
model.mask_predictor=MaskRCNNPredictor(mask_predictor_in_channels,
                                       mask_dim_reduced, num_classes)

とすれば変更できると思われる。

2.4 まとめ

ざっくりとまとめると、モデルのネットワークの変更は次のようになる。
ここでは、

  • resnet101の学習済みモデルを使う
  • classの数は20
  • trainable_layersは4にする
  • mask_layersを(256,256,512,512,1024)にする
  • mask_dim_reducedを256にする

この条件のもと、MaskRCNNのモデルを作成すると

from torchvision.models.detection.backbone_utils import resnet_fpn_backbone #resnetの学習済みモデルを使うのに必要
from torchvision.models.detection.mask_rcnn import MaskRCNN, MaskRCNNHeads, MaskRCNNPredictor
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
num_classes=20
backbone=resnet_fpn_backbone('resnet101', pretrained=True,
trainable_layers=4)
model=MaskRCNN(backbone, num_classes)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
#ここまで元となるモデル作り
#公式のチュートリアルとだいたい同じことをしている
#ここから下は自分で新しくネットワークを追加していく作業になる
out_channels = model.backbone.out_channels 
mask_layers = (256,256,512,512,1024)
mask_dilation = 1
mask_head = MaskRCNNHeads(out_channels, mask_layers, mask_dilation)
mask_predictor_in_channels=mask_layers[-1]
mask_dim_reduced=256 #mask_reducedの値
model.mask_predictor=MaskRCNNPredictor(mask_predictor_in_channels,
                                   mask_dim_reduced, num_classes)

となる。これで、Mask R-CNNを改造するための第一歩になったのではないかと思う。動作は確認しました。

3 最後に

お読みいただきありがとうございました。
自分が苦労して学んだことをこういう形で残せれば、そして同じ苦労をする人が減らせればと思います。もう一つ、「こうすれば自力でモデルを変更できるようになった」という一つの参考になればと思います。
ちなみにもうすでにもっとわかりやすい解説記事があれば教えてください。私はそっちを読みますし、この記事にもそのリンクを貼り付けます。

8
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
6