傷に注目してNGになっている等がわかる。
判定結果は特に赤色部分に起因している。
OKは全体に分布しているのに対し、NGは傷があるところに集中している。
■はじめに
ディープラーニング(CNN)は画像分類タスクの自動化に不可欠となっています。
一方中身がブラックボックス(情報は取り出せるが数値の意味が人間が把握できない)で、
「精度が上がらないけどどこ見て判断してんの?」
「評価データでは精度高いけど、今後未知のデータでも大丈夫?」
という点が気になることもしばしば。
そこで判断根拠のヒントを得ようとする取り組みも多数あるようです。
ディープラーニングの判断根拠を理解する手法
■やったこと
今回は金属表面のキズ判定のために学習させたCNNネットワークにClass Activation Mapping(CAM)を適用して、
判定結果がどの領域の特徴と相関が強くあるのかを可視化してみました。
検証用動画、学習済みモデルを含むMATLABサンプルコード
Class Activation MappingはMITのComputer Science and Artificial Intelligence Laboratoryのチームが発表されたもののようです。(※1)
利用するネットワーク
転移学習でキズ有無によるOK/NGを判定できるようにしたGoogleNetとSqueezeNetで今回試してみました。
CAMを適用する点での違いは
演算で用いる場所のActivation Mapのサイズが異なる点です。
GoogleNet → 7x7
SqueezeNet → 14x14
ですので、SqueezeNetのほうが高分解能なMapを取り出せます。
CAMの計算
SqueezeNetの場合を例に示すと、演算に使った層は'relu_conv10'と'fc'です。
'fc'層から判定されたクラスのWeight(長さ1000のベクトル)を取り出し、'relu_conv10'層でのActivationの出力(imageActivations(14x14x1000))の各位置の出力ベクトルとWeightを要素ごとの掛け算後すべてを足し合わせることでサイズ14x14のClass Activation Mapができます。
% 全結合層から判定されたカテゴリのWeightを取り出す。
weightVector = myNet.Layers(67).Weights(classIndex,:);
% Classification Activation Mapを計算する
weightVectorSize = size(weightVector);
weightVector = reshape(weightVector,[1 weightVectorSize]);
dotProduct = bsxfun(@times,imageActivations,weightVector);
classActivationMap = sum(dotProduct,3);
その出力結果を
値大 → 赤
値小 → 青
のヒートマップとして表示しています。
■ポイント
・Class Activation Mappingはどこの特徴量が判定に寄与しているかを可視化できる
・分類する全結合層の重みとその前の層のActivation Mapとの演算で求めるため、全結合層が複数層続く構造のネットワークには向いていない。
・解像度は演算で用いる層のActivation Mapのサイズに依存
■参照
(※1) http://cnnlocalization.csail.mit.edu/Zhou_Learning_Deep_Features_CVPR_2016_paper.pdf