1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

[MATLAB]ニューラル ネットワーク回帰モデルの学習(dlnetwork)検証データあり

Last updated at Posted at 2025-04-18

オブジェクトdlnetworkを利用して,ニューラルネットワークの回忌曲線の学習をする.

このサンプルに検証データあり過学習を防ぐサンプルに変更する.

sample_dlnet_dim1_valid.m

clc
close all
clearvars
rng("default");

%% 入力層の次元
dim_in=1;

%% 出力層の次元
dim_out=1;

%% 各層の大きさ
layer_sizes=[dim_in 10 20 10 dim_out];

%% 学習データの個数
num_train=500;

%% 検証データの個数
num_valid=floor(num_train*0.3);

%% ミニバッチ環境:auto / cpu / gpu
mbq_env='auto';

%% ミニバッチサイズ:
% 小規模モデル(MNISTなど):32, 64, 128
% 中規模CNN(ResNet-18など):64, 128, 256
% 大規模モデル(ResNet-50、ViTなど):256, 512, 1024(GPUに応じて)
% NLP(BERTなど):8, 16, 32(大きい入力のため小さめ)
batch_size=64;

%% 最大エポック数100〜500など
epoch_max=500;

%% 学習率の初期値
learn_rate_init=0.01;

%% Warm-UP
% 小規模モデル:1〜3エポック
% 中〜大規模モデル(ResNet, Transformer等):5〜10エポック
% BERT・ViTなどの大規模事前学習モデル:10エポック
learn_warmup=5;

%% 早期停止における我慢の回数
% 3〜5エポック:小規模モデル(小さなCNNなど)
% 5〜10エポック:中〜大規模モデル(ResNet等)
% 10〜20エポック:ノイズの多いタスク(RNN・時系列・強化学習など)
patience=20;

%% 学習の終了条件:loss_valid/loss_train>=loss_ratio_max (1.5〜3..0)
loss_ratio_max=3.0;        

%% 学習データの生成
x_train=4*rand(num_train,dim_in)-1;
y_train=func_dim1(x_train);

%% 検証データの生成
x_valid=4*rand(num_valid,1)-1;
y_valid=func_dim1(x_valid);

%% データの表示
fig1=figure('Position',[0 0 1000 700]);
func_fig_data2(fig1,x_train,y_train,x_valid,y_valid,'学習データと検証データ:y=f(x)','学習データ','検証データ');

%% ネットワークの定義
layers=[ featureInputLayer(layer_sizes(1),'Normalization','zscore','Name','input') ];
for i=2:(length(layer_sizes)-1)
    layers(end+1)=fullyConnectedLayer(layer_sizes(i),'Name',"fc_"+(i-1));
    layers(end+1)=reluLayer('Name',"relu_"+(i-1));
end
layers(end+1)=fullyConnectedLayer(layer_sizes(end),'Name',"fc_out");
net=dlnetwork(layers);
net=initialize(net);
fig_net=figure('Position',[1200 0 1000 700]);
func_fig_net(fig_net,net,'ネットワークの定義');

%% 学習データ用のミニバッチキュー
ds_x_train=arrayDatastore(x_train,'IterationDimension',1);
ds_y_train=arrayDatastore(y_train,'IterationDimension',1);
ds_train=combine(ds_x_train,ds_y_train);
mbq_train=minibatchqueue(ds_train, ...
    'MiniBatchSize',batch_size, ...
    'MiniBatchFormat',{'BC','BC'}, ...
    'OutputAsDlarray',true, ...
    'OutputEnvironment',mbq_env);

%% 検証データ用のミニバッチキュー
ds_x_valid=arrayDatastore(x_valid,'IterationDimension',1);
ds_y_valid=arrayDatastore(y_valid,'IterationDimension',1);
ds_valid=combine(ds_x_valid,ds_y_valid);
mbq_valid=minibatchqueue(ds_valid, ...
    'MiniBatchSize',batch_size, ...
    'MiniBatchFormat',{'BC','BC'}, ...
    'OutputAsDlarray',true, ...
    'OutputEnvironment',mbq_env);

%% Adamで学習
tic
disp('training...')
epoch=1;
done=false;
count=0;
grad_ave_train=[];
grad_ave_valid=[];
gra_sqrt_ave_train=[];
gra_sqrt_ave_valid=[];
loss_train=zeros(epoch_max,1);
loss_valid=zeros(epoch_max,1);
loss_valid_min=zeros(epoch_max,1);
figLoss=figure('Position',[0 800 1000 700]);
func_fig_loss(figLoss,epoch_max,loss_train,loss_valid_min,loss_valid);
while ~done && epoch<=epoch_max
    %% 学習率スケジューラ:エポックごとに減衰させる
    % Cosine Annealing:cosineで急現象
    % Warmup:最初ゆっくり上げてから減衰.⼤規模モデルで主流(BERT, ViT で常⽤)
    if epoch<learn_warmup
        learn_rate=learn_rate_init*epoch/learn_warmup;
    else
        learn_rate=learn_rate_init*0.5*(1+cos(pi*(epoch-learn_warmup)/(epoch_max-learn_warmup)));
    end
    %% 学習データ
    reset(mbq_train);
    shuffle(mbq_train);
    %% 学習データを全て用いて1エポックの計算
    itr=0;
    loss_train(epoch)=0;
    while hasdata(mbq_train)
        [batch_x,batch_y]=next(mbq_train);
        [loss,grad]=dlfeval(@loss_mse_grad,net,batch_x,batch_y);
        [net,grad_ave_train,gra_sqrt_ave_train]=adamupdate(net,grad,grad_ave_train,gra_sqrt_ave_train,epoch,learn_rate);
        loss_train(epoch)=loss_train(epoch)+double(gather(extractdata(loss)));
        itr=itr+1;
    end
    %% 1反復当たりの学習損失
    loss_train(epoch)=loss_train(epoch)/itr;
    %% 検証データ
    reset(mbq_valid);
    shuffle(mbq_valid);
    %% 検証データを全て用いて1エポックの計算
    itr=0;
    loss_valid(epoch)=0;
    while hasdata(mbq_valid)
        [batch_x,batch_y]=next(mbq_valid);
        [loss,grad]=dlfeval(@dlnet_loss_mse_grad,net,batch_x,batch_y);
        [net,grad_ave_valid,gra_sqrt_ave_valid]=adamupdate(net,grad,grad_ave_valid,gra_sqrt_ave_valid,epoch,learn_rate);
        loss_valid(epoch)=loss_valid(epoch)+double(gather(extractdata(loss)));
        itr=itr+1;
    end
    %% 1反復当たりの検証損失
    loss_valid(epoch)=loss_valid(epoch)/itr;
    %% 検証損失の最小値を前のエポックの値を引き継ぐ
    if epoch==1
        loss_valid_min(epoch)=+inf;
    else
        loss_valid_min(epoch)=loss_valid_min(epoch-1);
    end
    %% 検証損失が最小より悪化し続けたら終了
    if loss_valid(epoch)<loss_valid_min(epoch)
        loss_valid_min(epoch)=loss_valid(epoch);
    else
        count=count+1;
        if count>=patience
            done=true;
        end
    end
    %% 学習損失に帯する検証損失の比
    loss_ratio_valid_train=loss_valid(epoch)/loss_train(epoch);
    %% 検証損失が学習損失に比べて大きすぎたら終了
    if loss_ratio_valid_train>loss_ratio_max
        done=true;
    end
    %% 損失をグラフに表示
    if mod(epoch-1,10)==0
        func_fig_loss(figLoss,epoch_max,loss_train,loss_valid_min,loss_valid);
    end
    %% 表示
    if done || epoch==1 || epoch==epoch_max || mod(epoch,10)==0
        fprintf('[%04d/%04d] learn_rate=%.03e, loss_train=%.1e, loss_valid=%.1e, loss_valid_min=%.1e, loss_valid/loss_train=%.03f%s%.01f, count=%d%s%d\n', ...
            epoch,epoch_max,learn_rate,loss_train(epoch),loss_valid(epoch),loss_valid_min(epoch), ...
            loss_ratio_valid_train,tfv(loss_ratio_valid_train<loss_ratio_max,"<",">="),loss_ratio_max, ...
            count,tfv(count<patience,"<",">="),patience);
    end
    %% 次のエポック
    epoch=epoch+1;
end
epoch=epoch-1;
loss_train=loss_train(1:epoch);
loss_valid=loss_valid(1:epoch);
disp('done')
t=toc;

%% テストデータを予測で生成
x_test=(-1:0.01:3)';
y_pred=predict(net,x_test);

%% 予測値の表示
fig3=figure('Position',[1200 800 1000 700]);
func_fig_data2(fig3,x_train,y_train,x_test,y_pred,'学習データと予測データの比較','学習データ','予測データ');

%% 真値との誤差
YTrue=func_dim1(x_test);
Error=sum((y_pred-YTrue).^2)/num_train;
fprintf('epochs=%d, time=%.1f, loss=%.1e, error=%.1e\n',epoch,t,loss_train(epoch),Error);

%%========ローカル関数===========

%% 学習データの生成用
function y=func_dim1(x)
y=x.*(x-1).*(x-2);
end

%% モデル損失関数:平均2乗誤差MSE
function [loss,grad]=loss_mse_grad(net,x,y)
y_pred=forward(net,x);
loss=mean((y_pred-y).^2,'all');
grad=dlgradient(loss,net.Learnables);
end

function func_fig_net(fig,net,name)
figure(fig);
plot(net);
title(name);
drawnow;
end

function func_fig_data2(fig,x1,y1,x2,y2,name,l1,l2)
figure(fig);
plot(x1,y1,'o',x2,y2,'.');
grid on;
xlabel('x');
ylabel('y');
legend(l1,l2);
title(name);
drawnow;
end

function func_fig_loss(fig,epoch_max,loss_train,loss_valid_min,loss_valid)
figure(fig);
semilogy(1:epoch_max,loss_train,'-',1:epoch_max,loss_valid_min,'-r',1:epoch_max,loss_valid,'+k');
xlabel('Epoch');
ylabel('Loss');
title('学習中の損失MSEの推移');
legend('学習損失','検証損失の最小値','検証損失');
grid on;
drawnow;
end


1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?