はじめに
セマンティックセグメンテーションを行う時に普段はアノテーションを作成するのは大変な作業でしょう。もしそれが自動的に作れるのならどれくらい楽になるでしょうね。
私は「自動的に生成された画像データセットで学習して本物に適用する」ということはよくやっています。普通の分類モデルでも教師データを準備することは大変なことだから、自動生成のデータが代わりに使えたら楽ですね。
そしてその生成データはセマンティックセグメンテーションにも使えるようにすることもできます。自動的に生成したデータなので、アノテーションも当然同時に作成することができます。しかもこれは手作業より正確で完璧なアノテーションになるでしょう。
「学習データがないので自分で生成する」という話はよくあることで新しいことではないのですが、これをセマンティックセグメンテーションに使う例はあまり聞いたことないの意外でした。だから私は自分で試してみました。
この記事ではランダムで3Dモデルを作成してレンダリングしてそれを使ってセマンティックセグメンテーションのニューラルネットワークに学習させる実装の例を説明します。
今回のコードはMATLABで書いています。理由はMATLABの3Dレンダリング機能で簡単に綺麗な3D画像が作れるからです。Pythonにはmatplotlibやmayaviがあるけど、MATLABの方がよくできています。PythonはMayaやBlenderなど3D専用のソフトウェアと連携できますが、そうしない限りPythonだけでMATLABほど綺麗な3D画像を作るのは難しいでしょう。
というわけで、今回の実装は全部MATLABを使うことにしましたが、自動的に学習データとアノテーションを作成するという考え方としてPythonや他の言語も通用するでしょう。
MATLABで深層学習の実装方法及びpytorchの書き方の比較は前回の記事に書いてあります。まずその記事を読んだら今回の記事がわかりやすくなると思います。
なお、セマンティックセグメンテーションの基本については沢山記事があるので割愛します。この記事の最後に纏めておいたリンクを参考に。
セマンティックセグメンテーションがかなり人気な手法で、このqiitaにも実装する例が沢山ありますが、全部Python(pytorch, keras, tensorflow, chainer)で、MATLABによる実装の例はなくて恐らくこの記事で初めてでしょうね。だから自分で書き方をアレンジするしかないです。
3Dモデルの作成
今回使うのは自分で作成する海星のデータセットです。
まずこのコードを実行したら綺麗な海星が出てきます。
sx = 360; % 円周におけるポリゴンの分割
sy = 100; % 暑さにおけるポリゴンの分割
atsusa = 1; % 暑さ
hankei1 = 1.5; % 中の半径
hankei2 = 4; % 外の半径
% 暑さと円周のメッシュ
[ar_z,ar_theta] = meshgrid(linspace(0,atsusa-(atsusa/sy),sy), ...
linspace(0,360-(360/sx),sx));
ar_a = ((sin((1-ar_z/atsusa).^0.5*pi/2)).^2)*hankei2;
ar_r = ar_a - (hankei2-hankei1)./hankei2.*ar_a.*abs(cos(deg2rad((ar_theta)*5/2)));
ar_r = ar_r + randn(sx,sy)*0.04; % ランダムで表面を凸凹にする
% 少しねじる
ar_theta = ar_theta+15.*sin(deg2rad(ar_theta*5/2)).^2;
aaa = cumsum(randn(1,sx));
aaa = (aaa - linspace(aaa(1),aaa(end),sx))'/range(aaa)*10;
ar_theta = ar_theta+aaa;
% 直交座標系にする
ar_x = ar_r.*cos(deg2rad(ar_theta));
ar_y = ar_r.*sin(deg2rad(ar_theta));
% 頂点
ar_v1 = [ar_x(:) ar_y(:) ar_z(:)];
ar_v2 = [0 0 atsusa]; % 一番上の頂点
ar_v = [ar_v1
ar_v2];
% 面
fv1 = (1:sx) + (0:sx:sx*(sy-1)-1)';
fv2 = [2:sx 1] + (0:sx:sx*(sy-1)-1)';
fv3 = [2:sx 1] + (sx:sx:sx*sy-1)';
fv4 = (1:sx) + (sx:sx:sx*sy-1)';
ar_f1 = cat(3,fv1.',fv2.',fv3.',fv4.');
ar_f1 = reshape(ar_f1,[],4);
ar_f2 = [(sx*(sy-1)+1:sx*sy)' ...
[sx*(sy-1)+2:sx*sy sx*(sy-1)+1]' ...
repmat(sx*sy+1,sx,1) ...
NaN(sx,1)];
ar_f = [ar_f1;ar_f2];
% 海星の表面の色。HSV空間でランダムしてRGBに変換する
ar_hsv = [mod(unifrnd(0.02,0.08,[sx*sy+1 1]),1),unifrnd(0.3,1,[sx*sy+1 1]),unifrnd(0.7,1,[sx*sy+1 1])];
ar_c = hsv2rgb(ar_hsv);
close all
figure(Position=[100 100 600 600])
% 地面を入れる
[ymesh,xmesh] = meshgrid( ...
linspace(-8,8,501), ...
linspace(-8,8,501));
zmesh = -0.1.*rand(501,501);
mesh_jimen = mesh(xmesh,ymesh,zmesh,FaceColor=[0.79 0.63 0.16],LineStyle="none");
% 海星のポリゴンを入れる
patch( ...
Faces=ar_f, ...
Vertices=ar_v, ...
FaceVertexCData=ar_c, ...
FaceColor="interp", ...
LineStyle="none");
material shiny % 明るい表面にする
lighting gouraud % グーローシェーディングを使う
set(gcf,color="k");
set(gca,Position=[0 0 1 1])
light(Position=[2 5 10],color=[0.9,0.9,0.9]); % 照明を入れる
camproj("perspective") % 透視図モード
view(0,50) % 眺める角度
axis equal off % 各軸の単位を揃える
今回はこのような海星を検出するためのセマンティックセグメンテーションです。
3Dモデルを作ることは今回の記事の主題ではないので、説明は割愛しますが、MATLABによる3Dモデリングに関する基本は次の記事で説明しています。
今の簡単な海星の3Dモデルで本物の海星を代表できるとは言い難いのですが、もっと色々工夫したらできるでしょう。でもコードは複雑になるので、今回はただ簡単な例を作るためだからこれくらいにしましょう。
ランダムで作った3Dモデルから学習用の画像を量産する
以上のコートで一枚の海星の画像ができますが、もし色な大きさや地面など色々ランダムしたら、沢山違う画像ができるでしょう。
今回は学習に使うためにランダムで1000枚を作ります。同時に海星がない地面だけの空っぽ画像も作って同じように学習に使います。そして海星のピクセルを示すアノテーションも作成します。だから生成されるのは全部3000×3枚となります。
海星なしの画像も学習に使うのは、海星がない画像にも適用できるようにするためです。海星あり画像ばっかり学習すると全ての画像はどこかで海星があるという基準になる恐れがあるから。
それに海星を入れるのと入れない同じ画像のペアで学習すると海星の特徴を把握しやすいと思うから。このようなことができるのは人工的に作るデータだからこそですね。
生成コードは以下です。
sx = 360;
sy = 100;
n_img = 2000; % 生成する海星の画像の枚数
px = 256; % 保存する画像の解像度
fol_nashi = "hitodenashi"; % 海星なし画像を保存するフォルダ
fol_ari = "hitodeari"; % 海星あり画像を保存するフォルダ
fol_anno = "hitodeanno"; % アノテーション画像を保存するフォルダ
% 保存するフォルダまだ存在していなければまず作成する
if(~exist(fol_nashi,"dir"))
mkdir(fol_nashi)
end
if(~exist(fol_ari,"dir"))
mkdir(fol_ari)
end
if(~exist(fol_anno,"dir"))
mkdir(fol_anno)
end
for i = 1524:n_img % ループして一枚ずつ生成する
% 毎回色々ランダムする
atsusa = unifrnd(0.8,3);
hankei1 = unifrnd(0.8,2);
hankei2 = unifrnd(3,5);
theta0 = unifrnd(0,360);
idou_x = unifrnd(-3,3);
idou_y = unifrnd(-3,3);
h = unifrnd(0.04,0.16);
xyz_shoumei = [unifrnd(-10,10) unifrnd(-10,10) 10];
[ar_z,ar_theta] = meshgrid(linspace(0,atsusa-(atsusa/sy),sy), ...
linspace(0,360-(360/sx),sx));
ar_a = ((sin((1-ar_z/atsusa).^0.5*pi/2)).^2)*hankei2;
ar_r = ar_a - (hankei2-hankei1)./hankei2.*ar_a.*abs(cos(deg2rad((ar_theta-theta0-unifrnd(0,72))*5/2)));
ar_r = ar_r + randn(sx,sy)*0.04;
ar_theta = ar_theta+unifrnd(-15,15).*sin(deg2rad((ar_theta-theta0)*5/2)).^2;
aaa = cumsum(randn(1,sx));
aaa = (aaa - linspace(aaa(1),aaa(end),sx))'/range(aaa)*10;
ar_theta = ar_theta+aaa;
ar_x = ar_r.*cos(deg2rad(ar_theta))+idou_x;
ar_y = ar_r.*sin(deg2rad(ar_theta))+idou_y;
ar_v1 = [ar_x(:) ar_y(:) ar_z(:)];
ar_v2 = [idou_x idou_y atsusa];
ar_v = [ar_v1;ar_v2];
fv1 = (1:sx) + (0:sx:sx*(sy-1)-1)';
fv2 = [2:sx 1] + (0:sx:sx*(sy-1)-1)';
fv3 = [2:sx 1] + (sx:sx:sx*sy-1)';
fv4 = (1:sx) + (sx:sx:sx*sy-1)';
ar_f1 = cat(3,fv1.',fv2.',fv3.',fv4.');
ar_f1 = reshape(ar_f1,[],4);
ar_f2 = [(sx*(sy-1)+1:sx*sy)' ...
[sx*(sy-1)+2:sx*sy sx*(sy-1)+1]' ...
repmat(sx*sy+1,sx,1) ...
NaN(sx,1)];
ar_f = [ar_f1;ar_f2];
ar_hsv = [mod(unifrnd(h-0.03,h+0.03,[sx*sy+1 1]),1),unifrnd(0.3,1,[sx*sy+1 1]),unifrnd(unifrnd(0.2,0.9),1,[sx*sy+1 1])];
ar_c = hsv2rgb(ar_hsv);
close all
figure(visible="off",Position=[100 100 600 600])
set(gcf,color="k");
set(gca,Position=[0 0 1 1])
% 地面のメッシュを入れる
kaizoudo_jimen = randi([100,600]);
[ymesh,xmesh] = meshgrid( ...
linspace(-8,8,kaizoudo_jimen+1), ...
linspace(-8,8,kaizoudo_jimen+1));
zmesh = -0.1.*rand(kaizoudo_jimen+1,kaizoudo_jimen+1);
mesh_jimen = mesh(xmesh,ymesh,zmesh,...
FaceColor=hsv2rgb([unifrnd(0.06,0.32),unifrnd(0.1,0.6),unifrnd(0.3,0.9)]),LineStyle="none");
material shiny
lighting gouraud
camproj("perspective")
light(Position=xyz_shoumei,color=[unifrnd(0.8,1),unifrnd(0.8,1),0.9]);
axis equal off
view(0,90)
set(gca,CameraPosition=[0 0 20])
set(gca,CameraTarget=[0 0 0])
set(gca,CameraViewAngle=42)
% 人手なしの画像を保存する
img1 = imresize(getframe(gcf).cdata,[px px]);
imwrite(img1,sprintf(fullfile(fol_nashi,"%04d.jpg"),i));
% 海星を入れる
patch( ...
Faces=ar_f, ...
Vertices=ar_v, ...
FaceVertexCData=ar_c, ...
FaceColor="interp", ...
LineStyle="none");
material shiny
lighting gouraud
% 海星ありの画像を保存する
img2 = imresize(getframe(gcf).cdata,[px px]);
imwrite(img2,sprintf(fullfile(fol_ari,"%04d.jpg"),i));
delete(mesh_jimen) % 地面のメッシュを消して、海星だけ残す
% アノテーションの画像を作る
img3 = all(imresize(getframe(gcf).cdata,[px px])>0,3);
imwrite(img3,sprintf(fullfile(fol_anno,"%04d.png"),i),BitDepth=1);
end
これを実行したら生成が始まります。生成速度はかなり速いです。パソコンの性能にもよりますが、私のパソコンでは1000個で10分くらいかかります。
ここで海星なしともアノテーションとも並べて一部表示してみます。
px = 256;
n = 6;
imds_nashi = imageDatastore("hitodenashi");
imds_ari = imageDatastore("hitodeari");
imds_anno = imageDatastore("hitodeanno");
arimg = zeros(px,px*3+2,3,n,"uint8")+255;
for i = 1:n
arimg(:,1:px,:,i) = imds_nashi.read;
arimg(:,2+px:1+px*2,:,i) = imds_ari.read;
arimg(:,3+px*2:2+px*3,:,i) = repmat(imds_anno.read*255,1,1,3,1);
end
imwrite(imtile(arimg,GridSize=[n 1],BorderSize=1,BackgroundColor="w"),"hitode3x6.jpg");
このようにアノテーションもパッケージで作られました。ちゃんと海星の輪郭になっています。これで学習用のデータが整いました。
使うニューラルネットワークのモデル
セマンティックセグメンテーションに使うニューラルネットワークのモデルは色々あって、今一番評判がいいのはDeepLab v3+でしょう。MATLABでもDeepLab v3+が準備されておいて簡単に使えるのですが、今回は小さい画像でも軽く実装できるU-netを使います。MATLABではU-netが準備してあるので簡単に使えます。
U-netは私が以前から使っていて馴染んでいるモデルです。記事にも書きました。
以前使ったのはノイズ除去のためですが、本来U-netはセマンティックセグメンテーションのために作られたモデルですね。
実際にノイズ除去に使われたU-netはセマンティックセグメンテーションに使われるモデルとは少し違うが、構造は殆ど同じなのでどっちもU-netと呼びますね。
違いは例えば最後の層はセマンティックセグメンテーションの場合はソフトマックスですが、ノイズ除去ではLReLUに入れ替えられます。
adobe illustratorのextendscriptで今回使うU-netモデルの構造を描いてここに載せておきます。
ここでmは画像のサイズで、nは入力チャネル数です。今回は色画像なのでn=3となりますが、n=1のグレースケールも使えます。convは畳み込み層でconvTは逆畳み込み層であり、後ろにある数字は出力チャネル数×カーネルサイズ、s=ストライド、p=パディング。最後の層でのkは分割する種類の数です。今回は海星と地面でk=2だけとなります。
基本的にmaxpool層で半分ずつ小さくして、その後逆畳み込み層(convT)で2倍ずつ大きくする、というオートエンコーダーの構造ですが、スキップ接続があるのは特徴ですね。特徴マップのサイズは一番小さないところ(最後のmaxpoolの後)では入力サイズの1/16まで縮められます。
セマンティックセグメンテーションの実装
データが準備できて、使うニューラルネットワークのモデルも決まったら、次はこれを使ってセマンティックセグメンテーションの学習をします。
画像は256ピクセルで生成されたのですが、このサイズの画像を扱うとなるとGPUがないと無理です。今回はCPUでも気軽に動ける程度の例をしたいので、64ピクセルに縮小した画像を使います。
セマンティックセグメンテーションでは評価の基準は正確度よりもダイス係数(F1値)やIoUやTversky損失が使われることが多いのです。今回は一番わかりやすいダイス係数を使います。再現率(recall)と適合率(precision)も一緒に学習曲線グラフを描きます。10エポック経ってもダイス係数が上がらなければ学習はそれで終了となって、ダイス係数が一番高い時のモデルを採用します。
実装のコードは以下です。
px = 64; % 入力画像のピクセル
n_epoch = 200; % 最大のエポック数
mouii = 10; % 正確度が何エポック上がらなければ学習が止まる
batch_size = 32; % 各ミニバッチの数
test_size = 0.2; % テストに使うデータの割合
fol_nashi = "hitodenashi"; % 海星なしデータが保存されてあるフォルダ
fol_ari = "hitodeari"; % 海星ありデータが保存されてあるフォルダ
fol_anno = "hitodeanno"; % アノテーションデータが保存されてあるフォルダ
% モデルはU-netを使う
dlnet = unet([px px 3],2);
% 学習ループを進める関数
function [dlnet,param_ag,param_asg] = fcn_update(dlnet,X,T,param_ag,param_asg,i_iter)
if(canUseGPU) % GPUが使える場合
X = gpuArray(X);
T = gpuArray(T);
end
[Y,state] = dlnet.forward(X);
loss = crossentropy(Y,T);
param_grad = dlgradient(loss,dlnet.Learnables);
dlnet.State = state;
[dlnet,param_ag,param_asg] = adamupdate(dlnet,param_grad,param_ag,param_asg,i_iter);
end
% 海星ありと海星なしのデータを連結して準備する関数
function [X,T] = fcn_xt(ar_nashi,ar_ari,ar_anno)
X = dlarray(cat(4,ar_nashi,ar_ari),"SSCB");
T = dlarray( ...
cat(4, ...
cat(3,zeros(size(ar_anno)),ones(size(ar_anno))), ...
cat(3,ar_anno,1-ar_anno)), ...
"SSCB");
end
imds_nashi = imageDatastore(fol_nashi, ...
ReadFcn=@(fn)(single(imresize(imread(fn),[px px])))/255);
imds_ari = imageDatastore(fol_ari, ...
ReadFcn=@(fn)(single(imresize(imread(fn),[px px])))/255);
imds_anno = imageDatastore(fol_anno, ...
ReadFcn=@(fn)single(imresize(imread(fn),[px px])>0));
% 全部の画像データを読み込んでおく
ar_nashi = cat(4,imds_nashi.readall{:});
ar_ari = cat(4,imds_ari.readall{:});
ar_anno = cat(4,imds_anno.readall{:});
n_data = numel(imds_ari.Files); % 全部のデータの数
n_kenshou = round(n_data*test_size); % 検証データの数
n_kunren = n_data-n_kenshou; % 訓練データの数
n_batch = ceil(n_kunren/batch_size); % 学習の時にミニバッチに分ける回数
% ランダムで訓練と検証データに分ける
idx_rand = randperm(n_data);
idx_kunren = idx_rand(1:n_kunren);
idx_kenshou = idx_rand(n_kunren+1:end);
param_ag = [];
param_asg = [];
lis_dice_kunren = []; % 各エポックの訓練データのダイス係数を収める配列
lis_dice_kenshou = []; % 各エポックの検証データの分割間違い率を収める配列
lis_saigen_kunren = []; % 各エポックの訓練データの再現率を収める配列
lis_saigen_kenshou = []; % 各エポックの検証データの分再現率を収める配列
lis_tekigou_kunren = []; % 各エポックの訓練データの適合率を収める配列
lis_tekigou_kenshou = []; % 各エポックの検証データの分適合率を収める配列
lis_sonshitsu_kunren = []; % 各エポックの訓練データの損失を収める配列
lis_sonshitsu_kenshou = []; % 各エポックの検証データの損失を収める配列
mouiika = 0; % 上がらない回数を数える
i_iter = 1; % 繰り返しの回数を数える
figure(Position=[100 100 400 600]); % 学習曲線のグラフを描く準備
fprintf("訓練データ数:%d\n検証データ数:%d\n学習開始\n",n_kunren,n_kenshou)
t0 = datetime("now");
% 学習ループ開始
for i_epoch = 1:n_epoch
idx_perm = randperm(n_kunren);
for i_batch = 1:n_batch
idx_batch = idx_kunren(idx_perm((i_batch-1)*batch_size+1:min(i_batch*batch_size,n_kunren)));
[X,T] = fcn_xt(...
ar_nashi(:,:,:,idx_batch), ...
ar_ari(:,:,:,idx_batch), ...
ar_anno(:,:,:,idx_batch));
% パラメータの更新
[dlnet,param_ag,param_asg] = dlfeval(@fcn_update,dlnet,X,T,param_ag,param_asg,i_iter);
i_iter = i_iter+1;
end
% 各エポックの学習が終わった後、学習データと検証データで検証を行う
[X,T] = fcn_xt(ar_nashi(:,:,:,idx_kunren(1:n_kenshou)),ar_ari(:,:,:,idx_kunren(1:n_kenshou)),ar_anno(:,:,:,idx_kunren(1:n_kenshou)));
Y = minibatchpredict(dlnet,X,MiniBatchSize=batch_size);
lis_sonshitsu_kunren = cat(1,lis_sonshitsu_kunren,crossentropy(Y,T));
seikaika = (Y(:,:,1,:)>0.5)==T(:,:,1,:);
saigen = mean(seikaika(T(:,:,1,:)==1),"all");
tekigou = mean(seikaika(Y(:,:,1,:)>0.5),"all");
lis_dice_kunren = cat(1,lis_dice_kunren,2*(saigen*tekigou)/(saigen+tekigou));
lis_saigen_kunren = cat(1,lis_saigen_kunren,saigen*100);
lis_tekigou_kunren = cat(1,lis_tekigou_kunren,tekigou*100);
[X,T] = fcn_xt(ar_nashi(:,:,:,idx_kenshou),ar_ari(:,:,:,idx_kenshou),ar_anno(:,:,:,idx_kenshou));
Y = minibatchpredict(dlnet,X,MiniBatchSize=batch_size);
lis_sonshitsu_kenshou = cat(1,lis_sonshitsu_kenshou,crossentropy(Y,T));
seikaika = (Y(:,:,1,:)>0.5)==T(:,:,1,:);
saigen = mean(seikaika(T(:,:,1,:)==1),"all");
tekigou = mean(seikaika(Y(:,:,1,:)>0.5),"all");
lis_dice_kenshou = cat(1,lis_dice_kenshou,2*(saigen*tekigou)/(saigen+tekigou));
lis_saigen_kenshou = cat(1,lis_saigen_kenshou,saigen*100);
lis_tekigou_kenshou = cat(1,lis_tekigou_kenshou,tekigou*100);
% 学習曲線を描く
tiledlayout(4,1,Padding="none",TileSpacing="tight");
nexttile
hold on
[max_kunren,idxmax_kunren] = max(lis_dice_kunren);
[max_kenshou,idxmax_kenshou] = max(lis_dice_kenshou);
p1 = plot(lis_dice_kunren,"r",LineWidth=2,DisplayName=sprintf("訓練データ(最大 %.3g)",max_kunren));
p2 = plot(lis_dice_kenshou,"g",LineWidth=2,DisplayName=sprintf("検証データ(最大 %.3g)",max_kenshou));
plot(idxmax_kunren,max_kunren," or",MarkerFaceColor="r")
plot(idxmax_kenshou,max_kenshou," og",MarkerFaceColor="g")
legend([p1,p2],Location="best",FontSize=12)
xlim([0.9 numel(lis_dice_kenshou)+0.1])
xticklabels([])
ylabel("ダイス係数",FontSize=12)
grid
nexttile
hold on
plot(clip(lis_saigen_kunren,1e-6,100),"r",LineWidth=2)
plot(clip(lis_saigen_kenshou,1e-6,100),"g",LineWidth=2)
xlim([0.9 numel(lis_saigen_kenshou)+0.1])
xticklabels([])
ylabel("再現率(%)",FontSize=12)
grid
nexttile
hold on
plot(clip(lis_tekigou_kunren,1e-6,100),"r",LineWidth=2)
plot(clip(lis_tekigou_kenshou,1e-6,100),"g",LineWidth=2)
xlim([0.9 numel(lis_tekigou_kenshou)+0.1])
xticklabels([])
ylabel("適合率(%)",FontSize=12)
grid
nexttile
hold on
plot(lis_sonshitsu_kunren,"r",LineWidth=2)
plot(lis_sonshitsu_kenshou,"g",LineWidth=2)
set(gca,YScale="log")
xlim([0.9 numel(lis_sonshitsu_kunren)+0.1])
xlabel("エポック",FontSize=12)
ylabel("交差エントロピー",FontSize=12)
grid
saveas(gca,"lc_hitode.png")
fprintf("%.2f分経って、%dエポックが終わった。最大ダイス係数は%.3f\n",minutes(datetime("now")-t0),i_epoch,max_kenshou);
% 学習の各段階での予測結果の例も描いておく
X = ar_ari(:,:,:,idx_kenshou(1:36));
Y = minibatchpredict(dlnet,dlarray(X,"SSCB"),MiniBatchSize=batch_size).extractdata;
imwrite(imtile([X;repmat(Y(:,:,1,:),1,1,3,1)]),sprintf("imtile_hitode%02d.jpg",i_epoch))
% 検証データのダイス係数が最大より更に上がったかどうか
if(lis_dice_kenshou(end) <= max(lis_dice_kenshou(1:end-1)))
% 上がらなかった場合
mouiika = mouiika+1;
if(mouiika >= mouii)
break % 数回も上がっていない場合すぐ学習を終了させる
end
else
% 上がったらもう一度最初から数える
mouiika = 0;
% 今後使うために、一番いい結果を出す時の学習済みモデルを保存しておく
save("hitodenet.mat","dlnet")
end
end
GPUのない環境でこれを実行したら数時間がかかるでしょう。
結果
学習が終了した後、保存した学習曲線グラフはこうなりました。
訓練データも検証データもあまり違いがなくて、ダイス係数が1の近くまで達したので、ちゃんと学習できたとわかります。
今回は80エポックまで行ったら検証データのダイス係数が10エポック上がっていないという条件が満たされてここで終了して、一番いい結果を出した70エポック目のモデルを採用することになりました。訓練データの方は学習が続いたらもっといい結果になりそうですが、それは過学習になるだけですね。
そして各学習ループが終わった後予測の例を描いたのでここで一部載せておきます。
まずは1エポック目。学習し始めたばかりでまだあまり形になっていない状態ですね。上述のグラフでもわかる通り再現率は低くて10%くらいしかない。
2エポック目になると大体いい形になってきましたね。再現率も70%になっているから。ただし白っぽいところがいっぱいあるってのは海星ではない部分まで海星と判断しすぎて、だから適合率が低いです。再現率が高くても適合率が低いというのも、違う意味で駄目です。ダイス係数は再現率も適合率も同時に考慮するバランスのいいスコアだからこれを基準とするのに意味がありますね。
3エポック目になると再現率が更に高くなって、余計な白い部分も大分なくなって適合率も回復しています。
19エポック目では突然崩れました。グラフでもはっきり見えるほど大きな急変。これはやばいですね。
その後すぐ回復してよかったです。学習の時こういう突然崩壊することもよくあるのですね。回復しないままずっと壊れていく場合もあるから、いつもいい結果が出せる時のパラメータを保存しておくのは大事です。せっかく時間をかけて学習したのにいきなり台無しになるのは嫌ですよね。
次の段階で形は既に安定して特に変化がなくて、学習は一番いい70エポック目に辿り着きました。まあ、大体30エポックから数字も予測の形も殆ど変わらないので30エポックで終了してもいい気がしますけどね。
学習が終了の後大部分はちゃんと海星の形になっていますが、まだ境界線がおかしくて上手くいかないものがありますね。特に海星と地面が似ている例。これはやはり難易度高くなりますね。
実際にもっと学習データを沢山生成したり、毎エポックで水増ししたりすることで画像のバリエーションを増やしたら更に結果はよくなりますけど、データが増えると学習にかかる時間も長くなるので今回はこれくらいにします。これでも1エポックに2~3分かかって、80エポックは3時間以上かかりました。因みに使ったパソコンはMacBook Air M3。
また、今回は海星みたいな簡単な形をしているものだからこんな64ピクセルの小さい画像でも簡単に区別できるが、もっと複雑なものだとこんな小さな画像では難しいはずです。十分大きい入力サイズも大事ですが、その分計算の負担は莫大になってGPUが欠かせないものになります。
終わりに
以上生成データの中でセマンティックセグメンテーションが上手くいったとわかりました。でもこのモデルをこのまま実の写真に使うことは難しいでしょう。ただ簡単に数十行のコードで作成したものだから。まだ本物の海星とはあまり似ていないから。実際の自然で撮った写真はもっと複雑だから学習データもそれなりに頑張らないといけないのですね。
もし3Dモデルからレンダリングした画像で学習したセマンティックセグメンテーションが実際に上手く本物の写真に使えるようになったらそれは凄いことになると思います。アノテーションの手間は全然かからないからです。
でもどうやって本物と同じ3Dモデルを作るのか、これは課題となりますね。AI技術とは別として、3D関連の知識は勿論必要です。
私が目指しているのもそのような、3DモデリングもできるAIプログラマーです。
参考&もっと読む
この記事を書く前に私いままで沢山セマンティックセグメンテーション関連の記事を読んできました。これら参考になった記事のリンクをここに貼っておきます。
基本
- 画像セグメンテーションについて調査中
- 深層学習を用いたセグメンテーションの紹介 セグメンテーションシリーズ①
- セマンティックセグメンテーションをざっくり学ぶ
- セグメンテーションとマルチスケールアテンション
- 国土地理院の空中写真データセットでセマンティック・セグメンテーションしてみた
アノテーションの話
- 画像の差分からセグメンテーションのアノテーションを作りたい!
- アノテーション仕様の曖昧さ (アスパラ収穫環境のSemantic Segmentationを例に)
- [論文紹介]アノテーションを自動作成してくれるアルゴリズムが凄い!
- LabelMEの使い方(セマンティックセグメンテーション用アノテーションツール)
- たった数クリックでセグメンテーションのマスクを作る!? Edge Flowの紹介
各モデル
- 2022年時点リアルタイムセマンティックセグメンテーションモデルのまとめ
- segmentation_models_pytorchの使い方と実装例
- セマンティックセグメンテーションについてザックリ解説【E資格対策】
SegNet
- MathWorks機械学習・深層学習セミナーの感想
- ChainerでSegNetとU-Net
- 実装で学ぶ深層学習(segmentation編) ~SegNet の実装~
- 【semantic segmentation】SegNet : Pooling Indicesでメモリの効率化
- SegNet: 画像セグメンテーションニューラルネットワーク
- CaDIS: a Cataract Datasetで画像セグメンテーション
U-Net
- U-Net紹介
- U-Netとは セグメンテーションシリーズ②
- 【Pytorch】UNetを実装する
- ボルダリングジムのホールドセグメンテーション
- 畳み込みオートエンコーダによる画像の再現、ノイズ除去、セグメンテーション
- セマンティックセグメンテーションをやってみた
- U-net構造で、画像セグメンテーションしてみた。(1)
- U-net構造で、画像セグメンテーションしてみた。(2)
- UNet(系)で、マルチクラスセグメンテーション(1)
- U-NetでPascal VOC 2012の画像をSemantic Segmentationする (TensorFlow)
- UNetで脳腫瘍を検出してみた(colab環境)
PSPNet
- 自作データを使ったPyTorchによるPSPNetの実装
- PSPNetをpytorchで実装しようとしたけどできなかった
- セマンティックセグメンテーション PSPNet
- 1日目 目標は衛星画像のSemantic Segmentation。
- 【semantic segmentation】PSPNet : Pyramid Pooling最強説
- PSPNetで脳腫瘍を検出してみた(colab環境)
DeepLab
- 【Semantic Segmentation】DeepLab (V1) : Fully Connected CRFでセグメンテーションを向上
- 【Semantic Segmentation】DeepLab(v2) : DeepLab(v1)との違いは?
- 【Semantic Segmentation】DeepLab(v3) : DeepLab(v2)との違いは?
- 【Semantic Segmentation】DeepLab(v3+) : DeepLab(v3)との違いは?
- DeepLab で Sematic Segmentation する(デモを動かす)
- DeepLabのSemantic Segmentationやーる(Windows10、Python3.6)
- DeepLab v3+(意訳)
- DeepLab v3+でオリジナルデータを学習してセグメンテーションできるようにする
- DeepLab v3+でオリジナルデータを学習してセマンティックセグメンテーションする
- DeepLabv3+を自前のデータセットで学習させる
- 誰でも出来る!DeepLab v3+でGPUを使って自作データセットで学習・推論する
- DeepLab学習用画像の要件
- 最強のSemantic Segmentation「Deep lab v3 plus」を用いて自前データセットを学習させる
各分野の応用
- 農業領域でのSemantic Segmentationの難しさ(作物と環境の多様性)と対応方法の例
- Pytorchによる航空画像の建物セグメンテーションの作成方法.
- セマンティックセグメンテーションによる建造物識別
- アニメイラストのセグメンテーションがやりたかった話
- 肺のX線画像のセグメンテーション(U-Net)をやってみた
- U-NetからDeepLab V3に変更したときの肺X線画像セグメンテーションの精度改善
- PSPNet vs UNet 脳腫瘍データセット
その他
- [PyTorch]セグメンテーションのためのDataAugmentation
- kerasを用いたセマンティックセグメンテーション
- セマンティックセグメンテーションを試してみる(Pytorch)
- 猫画像から猫部分のみを抽出する(matting/semantig segmentation)
- U2Netを独自データでトレーニングする
- Semantic Segmentationの実装
- segmentation_models_pytorchの使い方と実装例
- Image Pyramid(画像ピラミッド)と深層学習中の応用
- 学習最適化のための損失関数とOptimizer & MRI画像を使った比較
- 纏を習得する