4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

分類CNNをいじってニセCAM、ニセ検出器にする

Last updated at Posted at 2020-02-28

本記事の要約

分類用のCNNの全結合層を畳み込み層に置き換えるだけで、画像中のどこに何が映っているかがわかるようになり、ニセのCAMや、複数物体のクラス分類などが可能になる。
figure_1.png

動機

この記事は前回記事の続きです。
前回記事:MATLABで入力画像サイズを合わせずに分類CNNを使う
モチベーションとしては、分類CNNは全結合層があるため、入力画像サイズが固定となっているのを、使いやすくできないかというものでした。前回の記事では、これを(MATLABにおいて)回避する一つの小技を紹介しました。今回は別のやり方を紹介しようと思った次第です。
近年のネットワークの特徴として、YOLOv3など入力画像サイズに依存しないものが多く提案されており、とても使い勝手がいいです。こうしたネットワークの特徴は全結合層を含まずに、畳み込み層を中心に作られていることです。
そこで、既存の分類用CNNの全結合層を畳み込み層に置き換えることで、入力画像サイズに依存しないネットワークに改変してみます。これによって分類結果のみならず、その画像内での分布も可視化できますし、複数のクラスが映っている場合にも適切な結果を得ることができるようになります。

準備

前回と同じくGooglenetを使っていきます。

net = googlenet;
classes = net.Layers(end).Classes; % ラベル
img = imread('peppers.png');
imshow(img)

figure_0.png

全結合層を畳み込み化して分類「画像」を作る

ネットワークの中身を少し中身をのぞいてみます。

net.Layers(end-5:end)
ans = 
  次の層をもつ 6x1 の Layer 配列:

     1   'inception_5b-output'   深さ連結                   4 入力の深さ連結
     2   'pool5-7x7_s1'          Global Average Pooling   Global average pooling
     3   'pool5-drop_7x7_s1'     ドロップアウト              40% ドロップアウト
     4   'loss3-classifier'      全結合                    1000 全結合層
     5   'prob'                  ソフトマックス              ソフトマックス
     6   'output'                分類出力                   'tench' および 999 個のその他のクラスの crossentropyex

特徴マップから全結合の流れを畳み込みで置き換えるため、GoogleNetでは、'inception_5b-output'レイヤから全結合層の重みをもった畳み込み層に直接つないでみます。dropoutは、推論では不要なため無視しましょう。結果をsoftmaxに通して出てくる確率値(と言っていいのかわかりませんが、本記事ではこれでゴリ押します)を最終出力とします。

lastFeatureMapLayerName = 'inception_5b-output';  % 全結合手前の特徴マップレイヤ
fcLayerName = 'loss3-classifier'; % 全結合レイヤ

% DAGネットワーク情報の記録
layers = net.Layers;
layersName = {net.Layers.Name};
connections = net.Connections;

% layersから特徴マップレイヤと全結合レイヤのインデックスを取得、情報を抜き出す。
lastFeatureMapLayerInd = find(strcmp(string(layersName.'), lastFeatureMapLayerName));
lastFeatureMapLayer = layers(lastFeatureMapLayerInd);
fcLayerInd = find(strcmp(string(layersName.'), fcLayerName));
fcLayer = layers(fcLayerInd);
fcWeights = fcLayer.Weights;
fcBias = fcLayer.Bias;
% 特徴マップまでを残して学習用のネットワーク(layerGraph)を作成
connections(find(strcmp(string(connections.Source), lastFeatureMapLayerName)):end,:) = [];
lgraph = createLgraphUsingConnections(layers(1:lastFeatureMapLayerInd), connections);
% 全結合層以下を畳み込み層+softmaxレイヤに置き換える
kernelSize = [3 3];
new_layers = [ ...
    convolution2dLayer(kernelSize, numel(fcBias),"Name","conv_fc", ...
    "BiasLearnRateFactor",0,"Padding", [0 0], ...
    "Stride", [1 1], "Bias", reshape(fcBias, [1 1 numel(fcBias)]), ...
    "Weights",1./prod(kernelSize)*repmat(reshape(fcWeights.', [1 1 size(fcWeights,2,1)]), [kernelSize 1 1]));
    softmaxLayer("Name", "prob");]

付け加えるnew_layersの中身は下記のようになります。

new_layers = 
  次の層をもつ 2x1 の Layer 配列:

     1   'conv_fc'   畳み込み        ストライド [1  1] およびパディング [0  0  0  0] の 1000 3x3x1024 畳み込み
     2   'prob'      ソフトマックス   ソフトマックス

次元の順番がちょっと紛らわしいので注意してください。
全結合層のweight(fcWeights)と、作成した畳み込み層のweightの次元サイズを確認しておきます。

size(fcWeights), size(new_layers(1).Weights)
ans =
        1000        1024
ans =
           3           3        1024        1000

全結合層の重みは(変換後のチャンネル数)x(変換前のチャンネル数)になっているのに対して、
畳み込み層の重みは(行方向カーネルサイズ)x(列方向カーネルサイズ)x(変換前のチャンネル数)x(変換後のチャンネル数)
です。変換前後の順番が違うので、畳み込み層を作るときには、fcWeightsに転置(.')をした次第です。これ忘れるとエラーは出ませんが、結果がズタボロになるので要注意。

また、kernelSizeは、要はglobal average poolingの代わりの役目を果たすので、とりあえず3x3にしています。1x1でも5x5でも大丈夫。この辺は適当に決めています。

これを特徴マップまで残したネットワークの下にくっつけます。

lgraph = addLayers(lgraph, new_layers);
lgraph = connectLayers(lgraph, lastFeatureMapLayerName, "conv_fc");
% analyzeNetwork(lgraph) % ネットワークの次元数の確認のため必要があればコメント外してください

MATLABでは分類のためのレイヤや回帰のためのレイヤなど明確に「ここで終わり」と言えるレイヤ以外のレイヤで終わる場合、学習済みネットワークとして定義されません。なのでそれを入力として必要とするactivationsも使えません。中途半端なネットワークでもいろいろ操作できるdlnetworkというオブジェクトにして処理を進めます。

dlnet = dlnetwork(lgraph);

出来たネットワークを使っての分類画像の作成

クラスごとの確率値が低分解能画像として出てきます。つまり1,000チャンネルの画像ですね。この中から最大の確率値となっているチャンネルの値と、インデックスを取り出してきます。

Probs = predict(dlnet, dlarray(single(img), 'SSC'));
[maxProb, predClassMap] = max(Probs, [], 3);
% dlarray型からsingleに戻す
maxProb = extractdata(maxProb);
predClassMap = extractdata(predClassMap);
% 確度が高いクラスを抽出
th = 0.85; % しきい値
predClasses = unique(predClassMap(maxProb > th));

予測したクラスは下の通り。

ans = 
  3×1  categorical 配列
     butternut squash 
     cucumber 
     bell pepper 

CAMっぽく可視化

せっかく画像として確率やラベル情報があるのでCAMっぽく可視化してみます。

gridN = ceil(10^(log10(numel(predClasses))/2)); % グリッド上に配置するため、行列数を計算しているだけ
figure;
for classi = 1:numel(predClasses) % ラベルごとにエセCAM表示
    classmap = maxProb;
    classmap(predClassMap ~= predClasses(classi)) = 0;
    subplot(gridN, gridN, classi);
    imshow(img); hold on;
    heatMap = imresize(classmap, size(img,1,2), "Method", "bicubic");
    imagesc(heatMap, 'AlphaData', 0.5); % 確率値をヒートマップとして元画像に重ねる
    colormap(hot);
    hold off;
    title(classes(predClasses(classi)));    
end

figure_1.png
冒頭の画像です。思ったよりはるかにそれっぽくできましたね。逆伝搬必要ないので、とってもお手軽なCAMです。
butternut squashcucumberと誤分類されていますが、それぞれまぁ許容範囲ではないでしょうか。

検出タスクっぽく可視化

ちょっと遊んでみました。もちろん本物の検出器には及ぶべくもないので、だいた~いな目で見てください。

figure;
imgBox = img;
annotationColors = cool(numel(predClasses))*255; % クラスごとに異なる色の線とする。
for classi = 1:numel(predClasses)
    binaryClassMap = imresize(maxProb>0.5 & predClassMap == predClasses(classi), ...
        size(img,1,2), "Method", "nearest");
    stats = regionprops(binaryClassMap, 'BoundingBox');
    bbox = cat(1,stats.BoundingBox);
    
    imgBox = insertObjectAnnotation(imgBox, "rectangle", bbox, classes(predClasses(classi)), "Color", annotationColors(classi,:));
    imshow(imgBox);
end

figure_2.png

これをwebcamに通して、リアルタイムもできました。今回は割愛します。

終わり

ちょっと遊んだ結果をシェアしてみました。それぞれあくまで簡易的なものなのですが、いちいち検出器を用意しなくても1,000クラスの検出ができるってのもお手軽でいいです。精度は高くはありませんが…。何かコメントがあると嬉しいです。

謝辞

前回記事と今回記事ですが、eigsさんlivescript2markdown使わせていただいてます。本当に楽。神様!ありがとうございます!

4
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?