LoginSignup
22
14

More than 1 year has passed since last update.

事前学習済みモデルを使った学習ゼロの異常検知【MATLAB実装】

Posted at

はじめに

この記事では、少しタイムリーではないですが、異常検知界隈で盛り上がっていた下記論文を紹介し、MATLABによる実装について書いていきたいと思います。

こちらの論文に書かれている手法に通り名がないっぽいので、この記事の中では簡単のためMahalanobisADと呼ぶことにしました。手法でやっていることからするとあいまいな呼び名になりますが、著者の公式GitHubリポジトリからいただきました。

Modeling the Distribution of Normal Data in Pre-Trained Deep Features for Anomaly Detection

また、自分で実験する過程で下記論文にも自動的に巡り合いました。発表された日付(1か月違い)からするとPaDiMがMahalanobisADの後のようですね。ですがまぁほぼ一緒だと思います。個人的にはPaDiMの方が好きです。可視化ができるし細かく制御できる。

PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization

そもそもなぜこの2論文なのか、ですが、製造業向けのセミナーのために新しい手法を紹介できれば、と思ってWebで探していたら見つけた、という次第です。

ちなみにこれらの手法の後継としてPatchCoreというものもあるようですね。高速化のための工夫が入っているようですが、いろいろやりだすといつまでたってもアップできないので、今日は上の2つです。

実装環境

  • MATLAB R2021a
  • Statistics and Machine Learning Toolbox
  • Deep Learning Toolbox
  • Image Processing Toolbox

手法の特徴

どんな手法なのかざっと理解できるようまずはこの2つの手法の特徴を列挙していきます。

  • ニューラルネットワーク自体の学習はゼロ!みなさんもすぐに使える
  • 正常な画像だけあればOK
  • それぞれ発表当時はMVTecADデータセットでSOTAを達成。

ということでめちゃくちゃ簡単なのに使える、というやらなきゃソンソンな方法です。

中身の説明をしようかと思ったのですが、他の良い記事が有り余るくらい有りますので、私は実装をご覧いただきながら最低限説明していくスタンスとします。

アルゴリズムの流れ

実装に入る前に全体の流れを確認します。下記のとおりです。

  1. 正常データを準備する(ついでにテストデータも)
  2. 事前学習済みネットワーク(今回はMahalanobisADに倣ってEfficientNet-b0)を用意する
  3. 1を2に通し、中間層からの出力を抽出
  4. 3から共分散行列を計算
  5. テストデータを3と同様に中間層に通し出力を抽出
  6. 5と6からマハラノビス距離を計算
  7. 7が一定以上なら異常

実装

それではMATLABによる実装をしていきます。

正常データを準備する(ついでにテストデータも)

正常な画像を持ってきてデータストアを作成します。テスト用の正常画像、そしてテスト用の異常画像のデータストアもここで作ってしまいましょう。

データストアは、こちらのものを使っています。皆さんもダウンロード可能です。

学習用画像のデータストア

Code
imds_train = imageDatastore('data\trainingimage\',"LabelSource","foldernames");

検証用画像のデータストア

Code
imds_val = imageDatastore('data\testimage\',"LabelSource","foldernames","IncludeSubfolders",true);

可視化して確認

Code
img1 = read(imds_train);
img2 = read(imds_val);
imshow([img1 img2]);
title(['(左)学習用の正常画像   (中央)検証用の画像'])

figure_0_png.jpg

データ拡張

画像に余白部分が多いので、ナットを画像のセンターに持ってきて、ネットワークの入力画像サイズにリサイズしてしまいます。この処理は本質的には無くても良いのですが、MahalanobisADの方は、中間層からの出力に対して空間次元で平均を取る使用になっています。つまり背景部分も異常度の計算に含まれてしまうことになるので、その影響ができるだけ少なくなるような工夫は有った方がいいとおもいます。MahalanobisADは、といいましたが、PaDiMの方もパッチごとの共分散行列計算時に、できるだけ少ない枚数で正確な分布を得るためには画像上でのオブジェクトの位置がバラけないようにする方がいいと思います。というようなことが論文中にも書いてありましたしね。

Code
auds_train = transform(imds_train,@augmentationFcn); % 学習用
auds_val = transform(imds_val,@preprocessFcn); % 検証用

augmentationFcnは末尾示しますが、画処理で対象領域をざっくり求めてモルフォロジー処理で綺麗にした後その領域を切り出してきてリサイズしてランダム回転・XYランダム反転しています。

一方でpreprocessFcnの方は回転させる意味が無いので、切り出しただけで終了です。

確認。

Code
img1 = read(auds_train);
img2 = read(auds_val);
imshow([img1 img2]);
title(['(左)学習用の正常画像   (中央)検証用の画像'])

figure_1_png.jpg

事前学習済みネットワークを用意する

今回はMahalanobisADに倣ってEfficientNet-b0を使っていきます。MATLABではEfficientNet系はb0しかpre-builtの関数はないのが痛いところです。他の番号を使いたい方は、インポートして使って下さい。

idxで指定した中間出力層の番号は各ステージの最後の層にしています、たしか。間違っていたらすみません。以下のスクリプトでは7番目からの出力を使っています。このあたりはそれぞれ確認しましたが、ステージが下ってくるにつれて、細かな違いではなくなんというか大筋の違いを捉えるようになってくる様は圧巻でした。ここでは省略。

Code
net = efficientnetb0;
% 中間層インデックス(efficientnetの各レベルの最終層を中間出力とする
idx = [2 18 53 88 141 194 265 282 286];
testIdx = idx(7);
layerName = net.Layers(testIdx).Name;

中間層からの出力を抽出

いよいよ正常画像を使って特徴を抽出していきます。ダウンロードした画像は100枚しかなく、分布を得るには不安だったので、、回転、反転を加えることで、1000サンプルにかさ増ししています。
下記の実装例はより手間がかかるPaDiMの方です。パッチごとに特徴ベクトルを取り出して保存しています。

Code
% 繰り返し数を計算
sampleNum = length(imds_train.Labels);
leastSampleNum = 1000;
repNum = ceil(leastSampleNum / sampleNum);
for repi = 1:repNum
    fMap = activations(net, auds_train, layerName);
    if repi == 1 % 一回目のループだけ、全ループ分の配列を確保
        tr_normal_fMap = fMap;
        tr_normal_fMap = repmat(tr_normal_fMap,[1,1,1,repNum]);
    else
        tr_normal_fMap(:,:,:,(repi-1)*sampleNum+1:repi*sampleNum) = fMap;
    end
    reset(auds_train);
end
size(tr_normal_fMap)
Output
ans = 1x4    
           7           7         192        1000

共分散行列の計算

テスト時、正常データの分布を繰り返し計算しないで済むように、正常データの分布の共分散行列事前を計算しておきます。計算はパッチごとです。

逆行列計算が不安定になる場合にも備えて縮退した共分散行列としておきます。rhoは適当に小さい値にしています。本当は、データから決めた方がいいみたいです。そのあたりは論文にも記載がありますが、ちょっと逃げました。次の処理の準備のためにパッチごとに特徴ベクトルのバッチ平均(パッチとバッチが紛らわしい)もついでに計算しています。

Code
cov_normal = cell(size(tr_normal_fMap,1:2)); % サイズは既知なのでぴったりの配列を作っておいてもいい
mean_normal = cell(size(tr_normal_fMap,1:2)); % サイズは既知なのでぴったりの配列を作っておいてもいい
rho = 1e-4;
[h, w] = size(tr_normal_fMap,1:2);
for idx_x = 1:w
    for idx_y = 1:h
        cov_normal_raw = cov(squeeze(tr_normal_fMap(idx_y,idx_x,:,:))');
        cov_normal_shrink = (1-rho)*cov_normal_raw + rho*trace(cov_normal_raw)/size(cov_normal_raw,1)*eye(size(cov_normal_raw));
        cov_normal{idx_y,idx_x} = cov_normal_shrink;

        mean_normal{idx_y,idx_x} = mean(squeeze(tr_normal_fMap(idx_y,idx_x,:,:))');
    end
end

テスト画像から特徴ベクトルを抽出

事前準備は終わったことになります。ではテスト画像に対して異常度を計算していきます。
まず、分布計算でしたのと同様にテスト画像の特徴量を抽出していきましょう。

Code
val_fMap = activations(net, auds_val, layerName);

正常データ

Code
val_normal_fMap = val_fMap(:,:,:,imds_val.Labels=='normal');

異常データ

Code
val_abnormal_fMap = val_fMap(:,:,:,imds_val.Labels=='anomaly');

マハラノビス距離による異常度計算

ここは注意が必要です。普通に[MATLAB マハラノビス距離]でググるとmahalという関数がヒットしますが、この関数では、事前計算した共分散行列を受け付けてくれません。いちいち元の特徴量を全て入れる必要があるため、実用しようと思うと計算効率が高くありません。まぁそれでもDNNがボトルネックなので大した影響はないのですが。代わりにpdist2という関数があります。こちらは、2つのベクトルの距離を計算しますが、プロパティで距離のタイプを選ぶことができてその中にマハラノビス距離があります。このプロパティを使うと、事前計算した共分散行列を受け付けてくれます。

Code
d_normal = zeros(h,w,size(val_normal_fMap,4));
d_abnormal = zeros(h,w,size(val_abnormal_fMap,4));
% パッチごとにマハラノビス距離を計算していく
for idx_x = 1:w
    for idx_y = 1:h
        temp_tr_normal_fMap = squeeze(tr_normal_fMap(idx_y,idx_x,:,:))'; % 学習用
        % 正常テスト画像
        temp_val_normal_fMap = squeeze(val_normal_fMap(idx_y,idx_x,:,:))'; % テスト正常
        d_normal(idx_y,idx_x,:) = pdist2(mean_normal{idx_y,idx_x},temp_val_normal_fMap,'mahalanobis',cov_normal{idx_y,idx_x}).^2;
        % 異常テスト画像
        temp_val_abnormal_fMap = squeeze(val_abnormal_fMap(idx_y,idx_x,:,:))'; % テスト異常
        d_abnormal(idx_y,idx_x,:) = pdist2(mean_normal{idx_y,idx_x},temp_val_abnormal_fMap,'mahalanobis',cov_normal{idx_y,idx_x}).^2;
    end
end

結果の可視化

では結果を可視化していきましょう。ダントツで異常画像の方が少ないので正常テスト画像の方は異常画像と同じだけランダム抽出します。

Code
val_images = readall(auds_val);
val_images = mat2cell(val_images,ones(1,size(val_images,1)/224)*224,224,3);

val_normal_images = val_images(imds_val.Labels == 'normal');
val_abnormal_images = val_images(imds_val.Labels == 'anomaly');

% 異常度表示のスケールを決める
maxth = stretchlim(rescale(d_abnormal(:)))*max(d_abnormal(:));

% 異常画像の異常可視化
figure;
tiledlayout(2,2,'TileSpacing',"compact","Padding",'compact');
for i = 1:length(val_abnormal_images)
    abnormal_image = val_abnormal_images{i};
    d = d_abnormal(:,:,i);

    nexttile;
    imshow(abnormal_image);
    hold on;
    imagesc(rescale(imresize(d,[224 224],'bilinear'),'InputMax',maxth(2)),'AlphaData',0.5);
end

figure_2_png.jpg

Code
% 正常画像の異常可視化
figure;
tiledlayout(2,2,'TileSpacing',"compact","Padding",'compact');
normal_idx = randperm(length(val_normal_images),length(val_abnormal_images));
for i = normal_idx
    normal_image = val_normal_images{i};
    d = d_normal(:,:,i);

    nexttile;
    imshow(normal_image);
    hold on;
    imagesc(rescale(imresize(d,[224 224],'bilinear'),'InputMax',maxth(2)),'AlphaData',0.5);
end

figure_3_png.jpg

いかがでしょう。

・・・

・・・

微妙ですね(笑)

考察

今回は異常のモードとして2通りあります。傷とテクスチャの違い(表と裏?)です。

傷の方は割とわかりやすいので、きちんとヒートマップで出ていますが、テクスチャの方は全体に微妙な違いがまばらにある感じでうっすらと反応があるかなぁという感じです。

一方で正常画像を見ても、異常度が高くなっている場所が見られますね…。

正常画像を見ても明らかなとおり、ざらざらした質感の部分はまばらに見られます。ただ、全体としては少ないので、分布としては外れ値になりそう。つまりこういう場所はマハラノビス距離で言えば値が大きくなりそうですね。なので裏のざらざらテクスチャの異常検知はできない、ということになります。

一方で傷の方は、正常の方には見られないテクスチャなので、これは反応が綺麗に出るんでしょうね。

手法の得手不得手が出てる感じがしますが、なんかアドホックに上手いこと対応できそうです。

例えば、異常度とその面積の関係も見てみる、とか?

この手法の思想は、正常分布から外れた特徴量を持つ画像はそりゃ異常だろ、ってことですが、異常のモードがある程度わかっているのであれば、外れ方も見てあげることで、誤検知は確実に減らせると思います。

が、ちょっと今日はそこまではしんどいのでここまでとします。

ちなみに今回はランダム回転かけてオーグメンテーションしていましたが、正常データの分布のばらつきを抑えるためには、姿勢が揃えられるのであれば揃えてしまった方がいいですね。

定量的には・・・

具体的な異常度の計算ですが、PaDiMの場合は、パッチ間で最大の異常度を出してきて、あとはしきい値で二値化して終わりです。

気がつけばPaDiMしかやっていない…

MaharanobisADの方は、そもそもパッチごとの異常度計算ではなく、抽出してきた中間出力マップに対してmean poolingをかけることで空間情報をスカラ化してしまいます。そのあとで共分散行列を計算します。テスト画像に対しても同じように特徴ベクトルを求めて、それに対してマハラノビス距離を計算します。個人的には、異常は局所で発生しているケースが多いと思うので、mean poolingでその情報をなましてしまうのはいかがなものかと思うのですが、max poolingだとだめなんですかね。誰か試してみてください。このデータセット自体が余りこの手法に適していないことが分かったので、私は試しませんでした。

オーグメンテーションと前処理

興味ある方はこちら

学習用画像のオーグメンテーション

Code
function pImg = augmentationFcn(img)
    tform1 = randomAffine2d('Rotation',[-180 180],...
                            'XReflection',true, 'YReflection',true);
    J = imwarp(img,tform1,'nearest');
    bw = createMask(J);
    bw2 = imclose(bw,strel('disk',5));
    bw3 = bwareafilt(bw2,1);
    stats = regionprops(bw3,'BoundingBox');
    bbox = stats.BoundingBox;
    imc = imcrop(J,bbox);
    pImg = imresize(imc, [224 224]);
end

function [BW,maskedRGBImage] = createMask(RGB)
%createMask  Threshold RGB image using auto-generated code from colorThresholder app.
%  [BW,MASKEDRGBIMAGE] = createMask(RGB) thresholds image RGB using
%  auto-generated code from the colorThresholder app. The colorspace and
%  range for each channel of the colorspace were set within the app. The
%  segmentation mask is returned in BW, and a composite of the mask and
%  original RGB images is returned in maskedRGBImage.

% Auto-generated by colorThresholder app on 08-Sep-2021
%------------------------------------------------------

% Convert RGB image to chosen color space
I = rgb2hsv(RGB);

% Define thresholds for channel 1 based on histogram settings
channel1Min = 0.000;
channel1Max = 1.000;

% Define thresholds for channel 2 based on histogram settings
channel2Min = 0.000;
channel2Max = 1.000;

% Define thresholds for channel 3 based on histogram settings
channel3Min = 0.000;
channel3Max = 1.000;

% Create mask based on chosen histogram thresholds
sliderBW = (I(:,:,1) >= channel1Min ) & (I(:,:,1) <= channel1Max) & ...
    (I(:,:,2) >= channel2Min ) & (I(:,:,2) <= channel2Max) & ...
    (I(:,:,3) >= channel3Min ) & (I(:,:,3) <= channel3Max);

% Create mask based on selected regions of interest on point cloud projection
I = double(I);
[m,n,~] = size(I);
polyBW = false([m,n]);
I = reshape(I,[m*n 3]);

% Convert HSV color space to canonical coordinates
Xcoord = I(:,2).*I(:,3).*cos(2*pi*I(:,1));
Ycoord = I(:,2).*I(:,3).*sin(2*pi*I(:,1));
I(:,1) = Xcoord;
I(:,2) = Ycoord;
clear Xcoord Ycoord

% Project 3D data into 2D projected view from current camera view point within app
J = rotateColorSpace(I);

% Apply polygons drawn on point cloud in app
polyBW = applyPolygons(J,polyBW);

% Combine both masks
BW = sliderBW & polyBW;

% Initialize output masked image based on input image.
maskedRGBImage = RGB;

% Set background pixels where BW is false to zero.
maskedRGBImage(repmat(~BW,[1 1 3])) = 0;

end

function J = rotateColorSpace(I)

% Translate the data to the mean of the current image within app
shiftVec = [-0.192719 -0.133423 0.458825];
I = I - shiftVec;
I = [I ones(size(I,1),1)]';

% Apply transformation matrix
tMat = [-2.090720 -2.178754 -0.000000 0.706857;
    0.180389 -0.209076 0.920083 -0.496216;
    1.974238 -2.288199 -0.084069 8.724459;
    0.000000 0.000000 0.000000 1.000000];

J = (tMat*I)';
end

function polyBW = applyPolygons(J,polyBW)

% Define each manually generated ROI
hPoints(1).data = [0.091177 -0.298153;
    0.610040 -0.282819;
    0.818714 -0.175481;
    0.954069 -0.070700;
    0.931510 0.036638;
    0.192694 0.044305;
    0.000000 0.044305];

% Iteratively apply each ROI
for ii = 1:length(hPoints)
    if size(hPoints(ii).data,1) > 2
        in = inpolygon(J(:,1),J(:,2),hPoints(ii).data(:,1),hPoints(ii).data(:,2));
        in = reshape(in,size(polyBW));
        polyBW = polyBW | in;
    end
end

end

テスト用画像の前処理

Code
function pImg = preprocessFcn(img)
    bw = createMask(img);
    bw2 = imclose(bw,strel('disk',5));
    bw3 = bwareafilt(bw2,1);
    stats = regionprops(bw3,'BoundingBox');
    bbox = stats.BoundingBox;
    imc = imcrop(img,bbox);
    pImg = imresize(imc, [224 224]);
end

function [BW,maskedRGBImage] = createMask(RGB)
%createMask  Threshold RGB image using auto-generated code from colorThresholder app.
%  [BW,MASKEDRGBIMAGE] = createMask(RGB) thresholds image RGB using
%  auto-generated code from the colorThresholder app. The colorspace and
%  range for each channel of the colorspace were set within the app. The
%  segmentation mask is returned in BW, and a composite of the mask and
%  original RGB images is returned in maskedRGBImage.

% Auto-generated by colorThresholder app on 08-Sep-2021
%------------------------------------------------------

% Convert RGB image to chosen color space
I = rgb2hsv(RGB);

% Define thresholds for channel 1 based on histogram settings
channel1Min = 0.000;
channel1Max = 1.000;

% Define thresholds for channel 2 based on histogram settings
channel2Min = 0.000;
channel2Max = 1.000;

% Define thresholds for channel 3 based on histogram settings
channel3Min = 0.000;
channel3Max = 1.000;

% Create mask based on chosen histogram thresholds
sliderBW = (I(:,:,1) >= channel1Min ) & (I(:,:,1) <= channel1Max) & ...
    (I(:,:,2) >= channel2Min ) & (I(:,:,2) <= channel2Max) & ...
    (I(:,:,3) >= channel3Min ) & (I(:,:,3) <= channel3Max);

% Create mask based on selected regions of interest on point cloud projection
I = double(I);
[m,n,~] = size(I);
polyBW = false([m,n]);
I = reshape(I,[m*n 3]);

% Convert HSV color space to canonical coordinates
Xcoord = I(:,2).*I(:,3).*cos(2*pi*I(:,1));
Ycoord = I(:,2).*I(:,3).*sin(2*pi*I(:,1));
I(:,1) = Xcoord;
I(:,2) = Ycoord;
clear Xcoord Ycoord

% Project 3D data into 2D projected view from current camera view point within app
J = rotateColorSpace(I);

% Apply polygons drawn on point cloud in app
polyBW = applyPolygons(J,polyBW);

% Combine both masks
BW = sliderBW & polyBW;

% Initialize output masked image based on input image.
maskedRGBImage = RGB;

% Set background pixels where BW is false to zero.
maskedRGBImage(repmat(~BW,[1 1 3])) = 0;

end

function J = rotateColorSpace(I)

% Translate the data to the mean of the current image within app
shiftVec = [-0.192719 -0.133423 0.458825];
I = I - shiftVec;
I = [I ones(size(I,1),1)]';

% Apply transformation matrix
tMat = [-2.090720 -2.178754 -0.000000 0.706857;
    0.180389 -0.209076 0.920083 -0.496216;
    1.974238 -2.288199 -0.084069 8.724459;
    0.000000 0.000000 0.000000 1.000000];

J = (tMat*I)';
end

function polyBW = applyPolygons(J,polyBW)

% Define each manually generated ROI
hPoints(1).data = [0.091177 -0.298153;
    0.610040 -0.282819;
    0.818714 -0.175481;
    0.954069 -0.070700;
    0.931510 0.036638;
    0.192694 0.044305;
    0.000000 0.044305];

% Iteratively apply each ROI
for ii = 1:length(hPoints)
    if size(hPoints(ii).data,1) > 2
        in = inpolygon(J(:,1),J(:,2),hPoints(ii).data(:,1),hPoints(ii).data(:,2));
        in = reshape(in,size(polyBW));
        polyBW = polyBW | in;
    end
end

end

謝辞

本記事は @eigs さんのlivescript2markdownを使わせていただいてます。

22
14
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
22
14