#経緯
大学院の講義の課題で映像や画像処理の論文を読んで実装するというものが出たので、セマンティックセグメンテーションに関する実装をやってみようと思ったのですが、うまく実装できませんでした、というお話です。もはや単位は怪しいですが「こんな記事を残したままにはできないだろう!」という未来の自分へ課題を与えるために初Qiita記事を書き残しておきます。
##セマンティックセグメンテーションとは
画像内の物体をその輪郭まで正確に推定して、ラベル付けやカテゴリ分けを行うことをセマンティックセグメンテーションと言います。
[参考] セマンティックセグメンテーション - MATLAB & Simulink
単純な画像であれば既存のモデルでも十分な結果を出せますが、画像の情報量が多く複雑になると境界の検出やラベル付けも難しくなります。課題に取り組もうとした当初はGoogleによるセマンティックセグメンテーションである「DeepLab v3+」について書かれた2018年の論文「Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation」に目を通していました。高いパフォーマンスであるセマンティックセグメンテーションモデルなのですが、構造はシンプルで既存の手法であった空間ピラミッドプーリングとエンコーダ-デコーダ(とAtrous畳み込み)を組み合わせた比較的わかりやすい形になっています。
さて、機械学習歴(≒プログラマー歴)半年程度の私は、機械学習のコードを書くときはtensorflow(というよりKeras)派なのですが、DeepLab v3+はGoogleのモデルですので公開されているコードは言わずもがなtensorflowで書かれています。そのままtensorflowで再実装するのも味気ないので、今まで触れてこなかったpytorchを勉強する目的も兼ねてDeepLab v3+をpytorchで実装しようと思いました。
しかし、論文だけでは実装の細かい部分がわからず、ソースコードも公開はされているのですが、tensorflowの2系から入った私にとって1系で書かれたコードはなかなか読めず……。このままでは課題が終わらないと思い、DeepLab v3+の実装は一旦諦め、DeepLab v3+を構成する要素である空間ピラミッドプーリングを用いたモデルである「PSPNet」の実装をまずやることにしました。
※なお、2020/8/5現在、「PSPNet」の実装はまだできていません。
#PSPNet
さて、ようやくここから本題です。
PSPNetは論文「Pyramid Scene Parsing Network」にて2017年に提唱された手法で、ピラミッド構造のプーリングを挿入しています。これにより、局所的情報だけでなく、大域的な情報を獲得することに成功し、それ以前の手法より高いパフォーマンスを発揮しました。論文の結果(以下の図)を見ると、確かに細かい部分まで大域的な情報を失わずにセグメンテーションができていることがわかります。
- 入力をCNNに通し特徴マップを得る(今回はResNet50と呼ばれるモデルを用いて転移学習をする)
- 特徴マップをPPM(Pyramid Pooling Module)に通す
- 畳み込み計算を行い、出力を得る(この部分に関しては論文中に詳しく言及されていないっぽい?)
従って、この1~3を実装できれば、それらを組み合わせてPSPNetが構築できることになります。
###CNN部分
上記で述べた通り、ここではResNet50のモデルを用います。そして嬉しいことに、pytorchではtorchvision.modelsを使うことでResNetが簡単に定義でき、学習済みモデルも使用することができます。
[参考] 実践Pytorch
ただ、このモデルをそのまま使うことはできないので、多少書き換える必要があります。torchvision.models.resnetのドキュメントを見ると、畳み込み層の後ろにプールする層と出力する層がありますので、今回は最後の2つの層を削ります。すると、CNN部分は次のように書けます。
import torch
from torch import nn
import torchvision.models as models
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
resnet = models.resnet50(pretrained=True)
self.resnet = nn.Sequential(*list(resnet.children())[:-2])
def forward(self, x):
x = self.resnet(x)
return x
###PPM部分
本手法の核となる部分です。ここで行っている手順は次の通りです。
- CNNを通して得られた特徴マップについて、サイズが1×1、2×2、3×3、6×6の4種類でプーリングを行う
- プーリングしたそれぞれに1×1の畳み込みを行い、チャネル数を1/N(Nはピラミッドの個数、つまりN=4)にする
- それぞれにバイリニア補完で元の特徴マップと同じサイズにアップサンプリングし、元の特徴マップと合わせて結合する
この大きさの異なる複数種類のプーリングを結合することが、局所的情報と大域的情報を獲得することに寄与していると考えられます。ちなみに先に述べたDeepLab v3+ではプーリング部分をAtrous畳み込みという手法にすることで、極力データを圧縮しないで情報を保存して、より高精度なパフォーマンスを出しているようです。
さて、この部分のコードは次のように記述しました。
class PPM(nn.Module):
def __init__(self, in_channels, out_channels, bins):
super(PPM, self).__init__()
self.mods = []
for bin in bins:
self.mods.append(nn.Sequential(
nn.AdaptiveAvgPool2d(bin),
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
))
self.mods = nn.ModuleList(self.mods)
def forward(self, x):
size = x.size()
concat = [x]
for module in self.mods:
concat.append(nn.Upsample(size=size[2:], mode='bilinear', align_corners=True)(module(x)))
return torch.cat(concat, 1)
bins = [1,2,3,6]
とすれば1×1、2×2、3×3、6×6の4種類の処理がキレイに書けます。先月の私だったら一個一個ハードコーディングしてしまいそうですが、pytorchだとこのようにまとめて記述しやすいのが良いですね。なお、2つ目のチャネル数を1/Nにする処理は、out_channels = in_channels / N
とすれば良いので、こちらはインスタンスを作るときに処理します。
また、ReLUのinplace=True
やConv2dのbias=False
といったオプションの部分は先に述べたResNetの中身を見て、それを真似して書いています。なぜ、このようにしないといけないのかという理由に関しては正直よく理解できていないです。勉強しないと……。
###PSPNet実装
最後の畳み込みの処理の部分が残っていますが、今まで書いたCNNとPPMをまとめてPSPNetモデルを作成します。
class PSPNet(nn.Module):
def __init__(self, bins=[1, 2, 3, 6], dropout=0.1, classes=21):
super(PSPNet, self).__init__()
self.cnn = CNN()
feature_channels = 2048
self.ppm = PPM(feature_channels, int(feature_channels/len(bins)), bins)
self.conv = nn.Sequential(
nn.Conv2d(feature_channels*2, 512, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Dropout2d(p=dropout),
nn.Conv2d(512, classes, kernel_size=1)
)
def forward(self, x):
size = x.size()
x = self.cnn(x)
x = self.ppm(x)
x = self.conv(x)
x = nn.Upsample(size=size[2:], mode='bilinear', align_corners=True)(x)
return x
classesはセグメンテーションで何種類のオブジェクトを分類するかによって決まる数値です。今回はPASCAL VOC 2012のデータセットを用いようと思ったので21種類が対象になります。なお、畳み込みの層の部分だけは論文だけでは処理がわからなかったので元のソースコードを参考して書きました。畳み込み層のインプットのチャネルが特徴マップのチャネルの2倍になっているのは、PPMで記述したように、元の特徴マップと、チャネル数を1/Nしたピラミッドの層がN個(つまり1+1/N*N
)がインプットになるからです。
これにより、PSPNetモデルの定義ができたことになります。
##データ準備
あとはデータを流し込めば訓練ができるわけですが、PASCAL VOC 2012の前処理でかなりハマりました。というか未だによくわかっていないです。思考を放棄してこちらの記事を参考にさせていただきました。論文中にはData Augmentation(データ水増し)の仕方についても言及されているのですが、とりあえず先ほど作ったモデルがちゃんと動作するか確認したかったので、Data Augmentationは行いませんでした。
結局、それ以前の問題があった(後述)ですが、ここらへんの処理をちゃんと理解したら、こちらに追記したいと思います。
##学習
水増しはできていないので過学習の恐れはありますが、とりあえずモデルの学習が進むかを確認します。
しかし、ここでトラブルが。
###GPUが割り当てられない
私は自宅のPCでVRとかも動かすタイプの人なのでGTX 1080Tiというそれなりに良いグラボを積んでいるのですが、いざ学習を始めようとすると、
RuntimeError: CUDA out of memory. Tried to allocate 2.42 GiB (GPU 0; 11.00 GiB total capacity; 6.37 GiB already allocated; 2.35 GiB free; 6.38 GiB reserved in total by PyTorch)
というエラーが。えっ、1080Tiでダメですか?確かにPASCAL VOC 2012自体が2GB近くあるデータセットなのでなかなか処理が大変ではあるだろうけど、Data Augmentationなしでこれなのか……。
初めてpytorchで機械学習を行うのでもしかしたら環境構築あたりで失敗している可能性もありますが……。
###CPUに切り替えるとメモリが足りない
仕方なくGPUで計算するのを諦めて、CPUの計算にします。CPUで計算するとなると非常に遅くなるイメージですが、私が使っているCPUはRyzen 9 3950Xですので他のCPUと比べればだいぶ早い方だから何とかなるだろうと信じて学習を開始します。
RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 162267136 bytes. Buy new RAM!
えー、メモリ128GB積んでいるんですけど……。
ちなみにjupyter notebookで作業しているんですが、このメモリエラーに関しては、カーネルを再起動すればギリギリ何とかなることがわかりました。
###損失関数に交差エントロピーを指定するとエラーが出る
分類問題なので損失関数は交差エントロピーを選択しました。が、損失計算するところでエラーが。こちらは調べれば解決しそうで現在調査中なんですが、もしかしたらモデルの出力の形やデータセットの整形の仕方を変える必要があるかもしれません。
一旦、平均二乗誤差に損失関数を入れ替えたところ、学習は進んだので、モデル自体は問題なさそうです。
#反省と課題
そんなわけでモデルの評価までたどり着きませんでした。初めての画像処理、初めてのpytorchということもあって色んな部分で躓いて時間が足りなくなってしまいました。こればかりはたくさんコード書いたり読んだりして精進するしかありませんね。
とは言え、モデルの評価までやるとなると、Data Augmentationあたりはしっかりやる必要が出てくると思うので、おそらくローカル環境だと実行が厳しいかもしれません。前処理等がきちんとできればうまくいくかもしれませんが……。
本当はgithubでコードを公開するところまでが課題になるのですが、今回上記のモデル以外はエラーを吐いている学習部分のコードしかないので、少なくとも交差エントロピーの部分のエラーを解消して、精度悪くてもちゃんとした学習を終えてから全体のコードを公開しようと思います(さらば単位)。プログラマー歴少なすぎるのでたぶんそれが初コミットになりそう。
こちらのページにどんどん追記していくか、もしくは新しい記事を書き直すかもしれませんが、実際にセグメンテーションを行うところまでは実装したいと思います。忙しいけど8月中には何とかしたいですね。