6
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

機械学習Advent Calendar 2023

Day 16

3Dモデルから自動的に生成した画像とアノテーションのデータセットで学習するセマンティックセグメンテーション

Last updated at Posted at 2024-05-25

はじめに

セマンティックセグメンテーションを行う時に普段はアノテーションを作成するのは大変な作業でしょう。もしそれが自動的に作れるのならどれくらい楽になるでしょうね。

私は「自動的に生成された画像データセットで学習して本物に適用する」ということはよくやっています。普通の分類モデルでも教師データを準備することは大変なことだから、自動生成のデータが代わりに使えたら楽ですね。

そしてその生成データはセマンティックセグメンテーションにも使えるようにすることもできます。自動的に生成したデータなので、アノテーションも当然同時に作成することができます。しかもこれは手作業より正確で完璧なアノテーションになるでしょう。

「学習データがないので自分で生成する」という話はよくあることで新しいことではないのですが、これをセマンティックセグメンテーションに使う例はあまり聞いたことないの意外でした。だから私は自分で試してみました。

この記事ではランダムで3Dモデルを作成してレンダリングしてそれを使ってセマンティックセグメンテーションのニューラルネットワークに学習させる実装の例を説明します。

今回のコードはMATLABで書いています。理由はMATLABの3Dレンダリング機能で簡単に綺麗な3D画像が作れるからです。Pythonにはmatplotlibやmayaviがあるけど、MATLABの方がよくできています。PythonはMayaやBlenderなど3D専用のソフトウェアと連携できますが、そうしない限りPythonだけでMATLABほど綺麗な3D画像を作るのは難しいでしょう。

というわけで、今回の実装は全部MATLABを使うことにしましたが、自動的に学習データとアノテーションを作成するという考え方としてPythonや他の言語も通用するでしょう。

MATLABで深層学習の実装方法及びpytorchの書き方の比較は前回の記事に書いてあります。まずその記事を読んだら今回の記事がわかりやすくなると思います。

なお、セマンティックセグメンテーションの基本については沢山記事があるので割愛します。この記事の最後に纏めておいたリンクを参考に。

セマンティックセグメンテーションがかなり人気な手法で、このqiitaにも実装する例が沢山ありますが、全部Python(pytorch, keras, tensorflow, chainer)で、MATLABによる実装の例はなくて恐らくこの記事で初めてでしょうね。だから自分で書き方をアレンジするしかないです。

3Dモデルの作成

今回使うのは自分で作成する海星ヒトデのデータセットです。

まずこのコードを実行したら綺麗な海星ヒトデが出てきます。

sx = 360; % 円周におけるポリゴンの分割
sy = 100; % 暑さにおけるポリゴンの分割
atsusa = 1; % 暑さ
hankei1 = 1.5; % 中の半径
hankei2 = 4; % 外の半径

% 暑さと円周のメッシュ
[ar_z,ar_theta] = meshgrid(linspace(0,atsusa-(atsusa/sy),sy), ...
                           linspace(0,360-(360/sx),sx));
ar_a = ((sin((1-ar_z/atsusa).^0.5*pi/2)).^2)*hankei2;
ar_r = ar_a - (hankei2-hankei1)./hankei2.*ar_a.*abs(cos(deg2rad((ar_theta)*5/2)));
ar_r = ar_r + randn(sx,sy)*0.04; % ランダムで表面を凸凹にする

% 少しねじる
ar_theta = ar_theta+15.*sin(deg2rad(ar_theta*5/2)).^2;
aaa = cumsum(randn(1,sx));
aaa = (aaa - linspace(aaa(1),aaa(end),sx))'/range(aaa)*10;
ar_theta = ar_theta+aaa;

% 直交座標系にする
ar_x = ar_r.*cos(deg2rad(ar_theta));
ar_y = ar_r.*sin(deg2rad(ar_theta));

% 頂点
ar_v1 = [ar_x(:) ar_y(:) ar_z(:)];
ar_v2 = [0 0 atsusa]; % 一番上の頂点
ar_v = [ar_v1
        ar_v2];

% 面
fv1 = (1:sx) + (0:sx:sx*(sy-1)-1)';
fv2 = [2:sx 1] + (0:sx:sx*(sy-1)-1)';
fv3 = [2:sx 1] + (sx:sx:sx*sy-1)';
fv4 = (1:sx) + (sx:sx:sx*sy-1)';
ar_f1 = cat(3,fv1.',fv2.',fv3.',fv4.');
ar_f1 = reshape(ar_f1,[],4);
ar_f2 = [(sx*(sy-1)+1:sx*sy)' ...
         [sx*(sy-1)+2:sx*sy sx*(sy-1)+1]' ...
         repmat(sx*sy+1,sx,1) ...
         NaN(sx,1)];
ar_f = [ar_f1;ar_f2];

% 海星の表面の色。HSV空間でランダムしてRGBに変換する
ar_hsv = [mod(unifrnd(0.02,0.08,[sx*sy+1 1]),1),unifrnd(0.3,1,[sx*sy+1 1]),unifrnd(0.7,1,[sx*sy+1 1])];
ar_c = hsv2rgb(ar_hsv);

close all
figure(Position=[100 100 600 600])

% 地面を入れる
[ymesh,xmesh] = meshgrid( ...
    linspace(-8,8,501), ...
    linspace(-8,8,501));
zmesh = -0.1.*rand(501,501);
mesh_jimen = mesh(xmesh,ymesh,zmesh,FaceColor=[0.79 0.63 0.16],LineStyle="none");

% 海星のポリゴンを入れる
patch( ...
    Faces=ar_f, ...
    Vertices=ar_v, ...
    FaceVertexCData=ar_c, ...
    FaceColor="interp", ...
    LineStyle="none");
material shiny % 明るい表面にする
lighting gouraud % グーローシェーディングを使う

set(gcf,color="k");
set(gca,Position=[0 0 1 1])
light(Position=[2 5 10],color=[0.9,0.9,0.9]); % 照明を入れる
camproj("perspective") % 透視図モード
view(0,50) % 眺める角度
axis equal off % 各軸の単位を揃える

hitode.jpg

今回はこのような海星ヒトデを検出するためのセマンティックセグメンテーションです。

3Dモデルを作ることは今回の記事の主題ではないので、説明は割愛しますが、MATLABによる3Dモデリングに関する基本は次の記事で説明しています。

今の簡単な海星ヒトデの3Dモデルで本物の海星ヒトデを代表できるとは言い難いのですが、もっと色々工夫したらできるでしょう。でもコードは複雑になるので、今回はただ簡単な例を作るためだからこれくらいにしましょう。

ランダムで作った3Dモデルから学習用の画像を量産する

以上のコートで一枚の海星ヒトデの画像ができますが、もし色な大きさや地面など色々ランダムしたら、沢山違う画像ができるでしょう。

今回は学習に使うためにランダムで1000枚を作ります。同時に海星ヒトデがない地面だけの空っぽ画像も作って同じように学習に使います。そして海星ヒトデのピクセルを示すアノテーションも作成します。だから生成されるのは全部3000×3枚となります。

海星ヒトデなしの画像も学習に使うのは、海星ヒトデがない画像にも適用できるようにするためです。海星ヒトデあり画像ばっかり学習すると全ての画像はどこかで海星ヒトデがあるという基準になる恐れがあるから。

それに海星ヒトデを入れるのと入れない同じ画像のペアで学習すると海星ヒトデの特徴を把握しやすいと思うから。このようなことができるのは人工的に作るデータだからこそですね。

生成コードは以下です。

sx = 360;
sy = 100;
n_img = 2000; % 生成する海星の画像の枚数
px = 256; % 保存する画像の解像度
fol_nashi = "hitodenashi"; % 海星なし画像を保存するフォルダ
fol_ari = "hitodeari"; % 海星あり画像を保存するフォルダ
fol_anno = "hitodeanno"; % アノテーション画像を保存するフォルダ

% 保存するフォルダまだ存在していなければまず作成する
if(~exist(fol_nashi,"dir"))
    mkdir(fol_nashi)
end
if(~exist(fol_ari,"dir"))
    mkdir(fol_ari)
end
if(~exist(fol_anno,"dir"))
    mkdir(fol_anno)
end

for i = 1524:n_img % ループして一枚ずつ生成する
    % 毎回色々ランダムする
    atsusa = unifrnd(0.8,3);
    hankei1 = unifrnd(0.8,2);
    hankei2 = unifrnd(3,5);
    theta0 = unifrnd(0,360);
    idou_x = unifrnd(-3,3);
    idou_y = unifrnd(-3,3);
    h = unifrnd(0.04,0.16);
    xyz_shoumei = [unifrnd(-10,10) unifrnd(-10,10) 10];
    
    [ar_z,ar_theta] = meshgrid(linspace(0,atsusa-(atsusa/sy),sy), ...
                               linspace(0,360-(360/sx),sx));
    ar_a = ((sin((1-ar_z/atsusa).^0.5*pi/2)).^2)*hankei2;
    ar_r = ar_a - (hankei2-hankei1)./hankei2.*ar_a.*abs(cos(deg2rad((ar_theta-theta0-unifrnd(0,72))*5/2)));
    ar_r = ar_r + randn(sx,sy)*0.04;
    
    ar_theta = ar_theta+unifrnd(-15,15).*sin(deg2rad((ar_theta-theta0)*5/2)).^2;
    aaa = cumsum(randn(1,sx));
    aaa = (aaa - linspace(aaa(1),aaa(end),sx))'/range(aaa)*10;
    ar_theta = ar_theta+aaa;

    ar_x = ar_r.*cos(deg2rad(ar_theta))+idou_x;
    ar_y = ar_r.*sin(deg2rad(ar_theta))+idou_y;
    
    ar_v1 = [ar_x(:) ar_y(:) ar_z(:)];
    ar_v2 = [idou_x idou_y atsusa];
    ar_v = [ar_v1;ar_v2];
    
    fv1 = (1:sx) + (0:sx:sx*(sy-1)-1)';
    fv2 = [2:sx 1] + (0:sx:sx*(sy-1)-1)';
    fv3 = [2:sx 1] + (sx:sx:sx*sy-1)';
    fv4 = (1:sx) + (sx:sx:sx*sy-1)';
    ar_f1 = cat(3,fv1.',fv2.',fv3.',fv4.');
    ar_f1 = reshape(ar_f1,[],4);
    ar_f2 = [(sx*(sy-1)+1:sx*sy)' ...
             [sx*(sy-1)+2:sx*sy sx*(sy-1)+1]' ...
             repmat(sx*sy+1,sx,1) ...
             NaN(sx,1)];
    ar_f = [ar_f1;ar_f2];
    
    ar_hsv = [mod(unifrnd(h-0.03,h+0.03,[sx*sy+1 1]),1),unifrnd(0.3,1,[sx*sy+1 1]),unifrnd(unifrnd(0.2,0.9),1,[sx*sy+1 1])];
    ar_c = hsv2rgb(ar_hsv);
    
    
    close all
    figure(visible="off",Position=[100 100 600 600])
    set(gcf,color="k");
    set(gca,Position=[0 0 1 1])

    % 地面のメッシュを入れる
    kaizoudo_jimen = randi([100,600]);
    [ymesh,xmesh] = meshgrid( ...
        linspace(-8,8,kaizoudo_jimen+1), ...
        linspace(-8,8,kaizoudo_jimen+1));
    zmesh = -0.1.*rand(kaizoudo_jimen+1,kaizoudo_jimen+1);
    mesh_jimen = mesh(xmesh,ymesh,zmesh,...
             FaceColor=hsv2rgb([unifrnd(0.06,0.32),unifrnd(0.1,0.6),unifrnd(0.3,0.9)]),LineStyle="none");
    material shiny
    lighting gouraud

    camproj("perspective")
    light(Position=xyz_shoumei,color=[unifrnd(0.8,1),unifrnd(0.8,1),0.9]);
    axis equal off
    view(0,90)
    set(gca,CameraPosition=[0 0 20])
    set(gca,CameraTarget=[0 0 0])
    set(gca,CameraViewAngle=42)
    
    % 人手なしの画像を保存する
    img1 = imresize(getframe(gcf).cdata,[px px]);
    imwrite(img1,sprintf(fullfile(fol_nashi,"%04d.jpg"),i));
    
    % 海星を入れる
    patch( ...
        Faces=ar_f, ...
        Vertices=ar_v, ...
        FaceVertexCData=ar_c, ...
        FaceColor="interp", ...
        LineStyle="none");
    material shiny
    lighting gouraud
    
    % 海星ありの画像を保存する
    img2 = imresize(getframe(gcf).cdata,[px px]);
    imwrite(img2,sprintf(fullfile(fol_ari,"%04d.jpg"),i));

    delete(mesh_jimen) % 地面のメッシュを消して、海星だけ残す
    % アノテーションの画像を作る
    img3 = all(imresize(getframe(gcf).cdata,[px px])>0,3);
    imwrite(img3,sprintf(fullfile(fol_anno,"%04d.png"),i),BitDepth=1);
end

これを実行したら生成が始まります。生成速度はかなり速いです。パソコンの性能にもよりますが、私のパソコンでは1000個で10分くらいかかります。

このようにいっぱい海星ヒトデができました。
截屏2024-05-26-03.38.50.jpg

ここで海星ヒトデなしともアノテーションとも並べて一部表示してみます。

px = 256;
n = 6;
imds_nashi = imageDatastore("hitodenashi");
imds_ari = imageDatastore("hitodeari");
imds_anno = imageDatastore("hitodeanno");

arimg = zeros(px,px*3+2,3,n,"uint8")+255;
for i = 1:n
    arimg(:,1:px,:,i) = imds_nashi.read;
    arimg(:,2+px:1+px*2,:,i) = imds_ari.read;
    arimg(:,3+px*2:2+px*3,:,i) = repmat(imds_anno.read*255,1,1,3,1);
end

imwrite(imtile(arimg,GridSize=[n 1],BorderSize=1,BackgroundColor="w"),"hitode3x6.jpg");

hitode3x6.jpg

このようにアノテーションもパッケージで作られました。ちゃんと海星ヒトデの輪郭になっています。これで学習用のデータが整いました。

使うニューラルネットワークのモデル

セマンティックセグメンテーションに使うニューラルネットワークのモデルは色々あって、今一番評判がいいのはDeepLab v3+でしょう。MATLABでもDeepLab v3+が準備されておいて簡単に使えるのですが、今回は小さい画像でも軽く実装できるU-netを使います。MATLABではU-netが準備してあるので簡単に使えます。

U-netは私が以前から使っていて馴染んでいるモデルです。記事にも書きました。

以前使ったのはノイズ除去のためですが、本来U-netはセマンティックセグメンテーションのために作られたモデルですね。

実際にノイズ除去に使われたU-netはセマンティックセグメンテーションに使われるモデルとは少し違うが、構造は殆ど同じなのでどっちもU-netと呼びますね。

違いは例えば最後の層はセマンティックセグメンテーションの場合はソフトマックスですが、ノイズ除去ではLReLUに入れ替えられます。

adobe illustratorのextendscriptで今回使うU-netモデルの構造を描いてここに載せておきます。

unet.png

ここでmは画像のサイズで、nは入力チャネル数です。今回は色画像なのでn=3となりますが、n=1のグレースケールも使えます。convは畳み込み層でconvTは逆畳み込み層であり、後ろにある数字は出力チャネル数×カーネルサイズ、s=ストライド、p=パディング。最後の層でのkは分割する種類の数です。今回は海星ヒトデと地面でk=2だけとなります。

基本的にmaxpool層で半分ずつ小さくして、その後逆畳み込み層(convT)で2倍ずつ大きくする、というオートエンコーダーの構造ですが、スキップ接続があるのは特徴ですね。特徴マップのサイズは一番小さないところ(最後のmaxpoolの後)では入力サイズの1/16まで縮められます。

セマンティックセグメンテーションの実装

データが準備できて、使うニューラルネットワークのモデルも決まったら、次はこれを使ってセマンティックセグメンテーションの学習をします。

画像は256ピクセルで生成されたのですが、このサイズの画像を扱うとなるとGPUがないと無理です。今回はCPUでも気軽に動ける程度の例をしたいので、64ピクセルに縮小した画像を使います。

セマンティックセグメンテーションでは評価の基準は正確度よりもダイス係数(F1値)やIoUTversky損失が使われることが多いのです。今回は一番わかりやすいダイス係数を使います。再現率(recall)と適合率(precision)も一緒に学習曲線グラフを描きます。10エポック経ってもダイス係数が上がらなければ学習はそれで終了となって、ダイス係数が一番高い時のモデルを採用します。

実装のコードは以下です。

px = 64; % 入力画像のピクセル
n_epoch = 200; % 最大のエポック数
mouii = 10; % 正確度が何エポック上がらなければ学習が止まる
batch_size = 32; % 各ミニバッチの数
test_size = 0.2; % テストに使うデータの割合
fol_nashi = "hitodenashi"; % 海星なしデータが保存されてあるフォルダ
fol_ari = "hitodeari"; % 海星ありデータが保存されてあるフォルダ
fol_anno = "hitodeanno"; % アノテーションデータが保存されてあるフォルダ

% モデルはU-netを使う
dlnet = unet([px px 3],2);

% 学習ループを進める関数
function [dlnet,param_ag,param_asg] = fcn_update(dlnet,X,T,param_ag,param_asg,i_iter)
    if(canUseGPU) % GPUが使える場合
        X = gpuArray(X);
        T = gpuArray(T);
    end
    [Y,state] = dlnet.forward(X);
    loss = crossentropy(Y,T);
    param_grad = dlgradient(loss,dlnet.Learnables);
    dlnet.State = state;
    [dlnet,param_ag,param_asg] = adamupdate(dlnet,param_grad,param_ag,param_asg,i_iter);
end

% 海星ありと海星なしのデータを連結して準備する関数
function [X,T] = fcn_xt(ar_nashi,ar_ari,ar_anno)
    X = dlarray(cat(4,ar_nashi,ar_ari),"SSCB");
    T = dlarray( ...
        cat(4, ...
            cat(3,zeros(size(ar_anno)),ones(size(ar_anno))), ...
            cat(3,ar_anno,1-ar_anno)), ...
        "SSCB");
end

imds_nashi = imageDatastore(fol_nashi, ...
    ReadFcn=@(fn)(single(imresize(imread(fn),[px px])))/255);
imds_ari = imageDatastore(fol_ari, ...
    ReadFcn=@(fn)(single(imresize(imread(fn),[px px])))/255);
imds_anno = imageDatastore(fol_anno, ...
    ReadFcn=@(fn)single(imresize(imread(fn),[px px])>0));

% 全部の画像データを読み込んでおく
ar_nashi = cat(4,imds_nashi.readall{:});
ar_ari = cat(4,imds_ari.readall{:});
ar_anno = cat(4,imds_anno.readall{:});

n_data = numel(imds_ari.Files); % 全部のデータの数
n_kenshou = round(n_data*test_size); % 検証データの数
n_kunren = n_data-n_kenshou; % 訓練データの数
n_batch = ceil(n_kunren/batch_size); % 学習の時にミニバッチに分ける回数

% ランダムで訓練と検証データに分ける
idx_rand = randperm(n_data);
idx_kunren = idx_rand(1:n_kunren);
idx_kenshou = idx_rand(n_kunren+1:end);

param_ag = [];
param_asg = [];
lis_dice_kunren = []; % 各エポックの訓練データのダイス係数を収める配列
lis_dice_kenshou = []; % 各エポックの検証データの分割間違い率を収める配列
lis_saigen_kunren = []; % 各エポックの訓練データの再現率を収める配列
lis_saigen_kenshou = []; % 各エポックの検証データの分再現率を収める配列
lis_tekigou_kunren = []; % 各エポックの訓練データの適合率を収める配列
lis_tekigou_kenshou = []; % 各エポックの検証データの分適合率を収める配列
lis_sonshitsu_kunren = []; % 各エポックの訓練データの損失を収める配列
lis_sonshitsu_kenshou = []; % 各エポックの検証データの損失を収める配列
mouiika = 0; % 上がらない回数を数える
i_iter = 1; % 繰り返しの回数を数える

figure(Position=[100 100 400 600]); % 学習曲線のグラフを描く準備
fprintf("訓練データ数:%d\n検証データ数:%d\n学習開始\n",n_kunren,n_kenshou)
t0 = datetime("now");
% 学習ループ開始
for i_epoch = 1:n_epoch
    idx_perm = randperm(n_kunren);
    for i_batch = 1:n_batch
        idx_batch = idx_kunren(idx_perm((i_batch-1)*batch_size+1:min(i_batch*batch_size,n_kunren)));
        [X,T] = fcn_xt(...
            ar_nashi(:,:,:,idx_batch), ...
            ar_ari(:,:,:,idx_batch), ...
            ar_anno(:,:,:,idx_batch));
        % パラメータの更新
        [dlnet,param_ag,param_asg] = dlfeval(@fcn_update,dlnet,X,T,param_ag,param_asg,i_iter);
        i_iter = i_iter+1;
    end

    % 各エポックの学習が終わった後、学習データと検証データで検証を行う
    [X,T] = fcn_xt(ar_nashi(:,:,:,idx_kunren(1:n_kenshou)),ar_ari(:,:,:,idx_kunren(1:n_kenshou)),ar_anno(:,:,:,idx_kunren(1:n_kenshou)));
    Y = minibatchpredict(dlnet,X,MiniBatchSize=batch_size);
    lis_sonshitsu_kunren = cat(1,lis_sonshitsu_kunren,crossentropy(Y,T));

    seikaika = (Y(:,:,1,:)>0.5)==T(:,:,1,:);
    saigen = mean(seikaika(T(:,:,1,:)==1),"all");
    tekigou = mean(seikaika(Y(:,:,1,:)>0.5),"all");
    lis_dice_kunren = cat(1,lis_dice_kunren,2*(saigen*tekigou)/(saigen+tekigou));
    lis_saigen_kunren = cat(1,lis_saigen_kunren,saigen*100);
    lis_tekigou_kunren = cat(1,lis_tekigou_kunren,tekigou*100);

    [X,T] = fcn_xt(ar_nashi(:,:,:,idx_kenshou),ar_ari(:,:,:,idx_kenshou),ar_anno(:,:,:,idx_kenshou));
    Y = minibatchpredict(dlnet,X,MiniBatchSize=batch_size);
    lis_sonshitsu_kenshou = cat(1,lis_sonshitsu_kenshou,crossentropy(Y,T));

    seikaika = (Y(:,:,1,:)>0.5)==T(:,:,1,:);
    saigen = mean(seikaika(T(:,:,1,:)==1),"all");
    tekigou = mean(seikaika(Y(:,:,1,:)>0.5),"all");
    lis_dice_kenshou = cat(1,lis_dice_kenshou,2*(saigen*tekigou)/(saigen+tekigou));
    lis_saigen_kenshou = cat(1,lis_saigen_kenshou,saigen*100);
    lis_tekigou_kenshou = cat(1,lis_tekigou_kenshou,tekigou*100);
    
    % 学習曲線を描く
    tiledlayout(4,1,Padding="none",TileSpacing="tight");
    nexttile
    hold on
    [max_kunren,idxmax_kunren] = max(lis_dice_kunren);
    [max_kenshou,idxmax_kenshou] = max(lis_dice_kenshou);
    p1 = plot(lis_dice_kunren,"r",LineWidth=2,DisplayName=sprintf("訓練データ(最大 %.3g)",max_kunren));
    p2 = plot(lis_dice_kenshou,"g",LineWidth=2,DisplayName=sprintf("検証データ(最大 %.3g)",max_kenshou));
    plot(idxmax_kunren,max_kunren," or",MarkerFaceColor="r")
    plot(idxmax_kenshou,max_kenshou," og",MarkerFaceColor="g")
    legend([p1,p2],Location="best",FontSize=12)
    xlim([0.9 numel(lis_dice_kenshou)+0.1])
    xticklabels([])
    ylabel("ダイス係数",FontSize=12)
    grid

    nexttile
    hold on
    plot(clip(lis_saigen_kunren,1e-6,100),"r",LineWidth=2)
    plot(clip(lis_saigen_kenshou,1e-6,100),"g",LineWidth=2)
    xlim([0.9 numel(lis_saigen_kenshou)+0.1])
    xticklabels([])
    ylabel("再現率(%)",FontSize=12)
    grid

    nexttile
    hold on
    plot(clip(lis_tekigou_kunren,1e-6,100),"r",LineWidth=2)
    plot(clip(lis_tekigou_kenshou,1e-6,100),"g",LineWidth=2)
    xlim([0.9 numel(lis_tekigou_kenshou)+0.1])
    xticklabels([])
    ylabel("適合率(%)",FontSize=12)
    grid

    nexttile
    hold on
    plot(lis_sonshitsu_kunren,"r",LineWidth=2)
    plot(lis_sonshitsu_kenshou,"g",LineWidth=2)
    set(gca,YScale="log")
    xlim([0.9 numel(lis_sonshitsu_kunren)+0.1])
    xlabel("エポック",FontSize=12)
    ylabel("交差エントロピー",FontSize=12)
    grid
    saveas(gca,"lc_hitode.png")

    fprintf("%.2f分経って、%dエポックが終わった。最大ダイス係数は%.3f\n",minutes(datetime("now")-t0),i_epoch,max_kenshou);

    % 学習の各段階での予測結果の例も描いておく
    X = ar_ari(:,:,:,idx_kenshou(1:36));
    Y = minibatchpredict(dlnet,dlarray(X,"SSCB"),MiniBatchSize=batch_size).extractdata;
    imwrite(imtile([X;repmat(Y(:,:,1,:),1,1,3,1)]),sprintf("imtile_hitode%02d.jpg",i_epoch))

    % 検証データのダイス係数が最大より更に上がったかどうか
    if(lis_dice_kenshou(end) <= max(lis_dice_kenshou(1:end-1)))
        % 上がらなかった場合
        mouiika = mouiika+1;
        if(mouiika >= mouii)
            break % 数回も上がっていない場合すぐ学習を終了させる
        end
    else
        % 上がったらもう一度最初から数える
        mouiika = 0;
        % 今後使うために、一番いい結果を出す時の学習済みモデルを保存しておく
        save("hitodenet.mat","dlnet")
    end
end

GPUのない環境でこれを実行したら数時間がかかるでしょう。

結果

学習が終了した後、保存した学習曲線グラフはこうなりました。

lc_hitode.png

訓練データも検証データもあまり違いがなくて、ダイス係数が1の近くまで達したので、ちゃんと学習できたとわかります。

今回は80エポックまで行ったら検証データのダイス係数が10エポック上がっていないという条件が満たされてここで終了して、一番いい結果を出した70エポック目のモデルを採用することになりました。訓練データの方は学習が続いたらもっといい結果になりそうですが、それは過学習になるだけですね。

そして各学習ループが終わった後予測の例を描いたのでここで一部載せておきます。

まずは1エポック目。学習し始めたばかりでまだあまり形になっていない状態ですね。上述のグラフでもわかる通り再現率は低くて10%くらいしかない。
imtile_hitode01.jpg

2エポック目になると大体いい形になってきましたね。再現率も70%になっているから。ただし白っぽいところがいっぱいあるってのは海星ヒトデではない部分まで海星ヒトデと判断しすぎて、だから適合率が低いです。再現率が高くても適合率が低いというのも、違う意味で駄目です。ダイス係数は再現率も適合率も同時に考慮するバランスのいいスコアだからこれを基準とするのに意味がありますね。
imtile_hitode02.jpg

3エポック目になると再現率が更に高くなって、余計な白い部分も大分なくなって適合率も回復しています。
imtile_hitode03.jpg

それが続いて海星ヒトデの形が整っていきます。
5エポック目。
imtile_hitode05.jpg

17エポック目。大分きれいになってきました。
imtile_hitode17.jpg

19エポック目では突然崩れました。グラフでもはっきり見えるほど大きな急変。これはやばいですね。
imtile_hitode19.jpg

その後すぐ回復してよかったです。学習の時こういう突然崩壊することもよくあるのですね。回復しないままずっと壊れていく場合もあるから、いつもいい結果が出せる時のパラメータを保存しておくのは大事です。せっかく時間をかけて学習したのにいきなり台無しになるのは嫌ですよね。

次の段階で形は既に安定して特に変化がなくて、学習は一番いい70エポック目に辿り着きました。まあ、大体30エポックから数字も予測の形も殆ど変わらないので30エポックで終了してもいい気がしますけどね。

imtile_hitode70.jpg

学習が終了の後大部分はちゃんと海星ヒトデの形になっていますが、まだ境界線がおかしくて上手くいかないものがありますね。特に海星ヒトデと地面が似ている例。これはやはり難易度高くなりますね。

実際にもっと学習データを沢山生成したり、毎エポックで水増ししたりすることで画像のバリエーションを増やしたら更に結果はよくなりますけど、データが増えると学習にかかる時間も長くなるので今回はこれくらいにします。これでも1エポックに2~3分かかって、80エポックは3時間以上かかりました。因みに使ったパソコンはMacBook Air M3。

また、今回は海星ヒトデみたいな簡単な形をしているものだからこんな64ピクセルの小さい画像でも簡単に区別できるが、もっと複雑なものだとこんな小さな画像では難しいはずです。十分大きい入力サイズも大事ですが、その分計算の負担は莫大になってGPUが欠かせないものになります。

終わりに

以上生成データの中でセマンティックセグメンテーションが上手くいったとわかりました。でもこのモデルをこのまま実の写真に使うことは難しいでしょう。ただ簡単に数十行のコードで作成したものだから。まだ本物の海星ヒトデとはあまり似ていないから。実際の自然で撮った写真はもっと複雑だから学習データもそれなりに頑張らないといけないのですね。

もし3Dモデルからレンダリングした画像で学習したセマンティックセグメンテーションが実際に上手く本物の写真に使えるようになったらそれは凄いことになると思います。アノテーションの手間は全然かからないからです。

でもどうやって本物と同じ3Dモデルを作るのか、これは課題となりますね。AI技術とは別として、3D関連の知識は勿論必要です。
私が目指しているのもそのような、3DモデリングもできるAIプログラマーです。

参考&もっと読む

この記事を書く前に私いままで沢山セマンティックセグメンテーション関連の記事を読んできました。これら参考になった記事のリンクをここに貼っておきます。

基本

アノテーションの話

各モデル

SegNet

U-Net

PSPNet

DeepLab

各分野の応用

その他

6
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?