オブジェクトdlnetwork
を利用して,ニューラルネットワークの回忌曲線の学習をする.
このサンプルに検証データあり過学習を防ぐサンプルに変更する.
sample_dlnet_dim1_valid.m
clc
close all
clearvars
rng("default");
%% 入力層の次元
dim_in=1;
%% 出力層の次元
dim_out=1;
%% 各層の大きさ
layer_sizes=[dim_in 10 20 10 dim_out];
%% 学習データの個数
num_train=500;
%% 検証データの個数
num_valid=floor(num_train*0.3);
%% ミニバッチ環境:auto / cpu / gpu
mbq_env='auto';
%% ミニバッチサイズ:
% 小規模モデル(MNISTなど):32, 64, 128
% 中規模CNN(ResNet-18など):64, 128, 256
% 大規模モデル(ResNet-50、ViTなど):256, 512, 1024(GPUに応じて)
% NLP(BERTなど):8, 16, 32(大きい入力のため小さめ)
batch_size=64;
%% 最大エポック数100〜500など
epoch_max=500;
%% 学習率の初期値
learn_rate_init=0.01;
%% Warm-UP
% 小規模モデル:1〜3エポック
% 中〜大規模モデル(ResNet, Transformer等):5〜10エポック
% BERT・ViTなどの大規模事前学習モデル:10エポック
learn_warmup=5;
%% 早期停止における我慢の回数
% 3〜5エポック:小規模モデル(小さなCNNなど)
% 5〜10エポック:中〜大規模モデル(ResNet等)
% 10〜20エポック:ノイズの多いタスク(RNN・時系列・強化学習など)
patience=20;
%% 学習の終了条件:loss_valid/loss_train>=loss_ratio_max (1.5〜3..0)
loss_ratio_max=3.0;
%% 学習データの生成
x_train=4*rand(num_train,dim_in)-1;
y_train=func_dim1(x_train);
%% 検証データの生成
x_valid=4*rand(num_valid,1)-1;
y_valid=func_dim1(x_valid);
%% データの表示
fig1=figure('Position',[0 0 1000 700]);
func_fig_data2(fig1,x_train,y_train,x_valid,y_valid,'学習データと検証データ:y=f(x)','学習データ','検証データ');
%% ネットワークの定義
layers=[ featureInputLayer(layer_sizes(1),'Normalization','zscore','Name','input') ];
for i=2:(length(layer_sizes)-1)
layers(end+1)=fullyConnectedLayer(layer_sizes(i),'Name',"fc_"+(i-1));
layers(end+1)=reluLayer('Name',"relu_"+(i-1));
end
layers(end+1)=fullyConnectedLayer(layer_sizes(end),'Name',"fc_out");
net=dlnetwork(layers);
net=initialize(net);
fig_net=figure('Position',[1200 0 1000 700]);
func_fig_net(fig_net,net,'ネットワークの定義');
%% 学習データ用のミニバッチキュー
ds_x_train=arrayDatastore(x_train,'IterationDimension',1);
ds_y_train=arrayDatastore(y_train,'IterationDimension',1);
ds_train=combine(ds_x_train,ds_y_train);
mbq_train=minibatchqueue(ds_train, ...
'MiniBatchSize',batch_size, ...
'MiniBatchFormat',{'BC','BC'}, ...
'OutputAsDlarray',true, ...
'OutputEnvironment',mbq_env);
%% 検証データ用のミニバッチキュー
ds_x_valid=arrayDatastore(x_valid,'IterationDimension',1);
ds_y_valid=arrayDatastore(y_valid,'IterationDimension',1);
ds_valid=combine(ds_x_valid,ds_y_valid);
mbq_valid=minibatchqueue(ds_valid, ...
'MiniBatchSize',batch_size, ...
'MiniBatchFormat',{'BC','BC'}, ...
'OutputAsDlarray',true, ...
'OutputEnvironment',mbq_env);
%% Adamで学習
tic
disp('training...')
epoch=1;
done=false;
count=0;
grad_ave_train=[];
grad_ave_valid=[];
gra_sqrt_ave_train=[];
gra_sqrt_ave_valid=[];
loss_train=zeros(epoch_max,1);
loss_valid=zeros(epoch_max,1);
loss_valid_min=zeros(epoch_max,1);
figLoss=figure('Position',[0 800 1000 700]);
func_fig_loss(figLoss,epoch_max,loss_train,loss_valid_min,loss_valid);
while ~done && epoch<=epoch_max
%% 学習率スケジューラ:エポックごとに減衰させる
% Cosine Annealing:cosineで急現象
% Warmup:最初ゆっくり上げてから減衰.⼤規模モデルで主流(BERT, ViT で常⽤)
if epoch<learn_warmup
learn_rate=learn_rate_init*epoch/learn_warmup;
else
learn_rate=learn_rate_init*0.5*(1+cos(pi*(epoch-learn_warmup)/(epoch_max-learn_warmup)));
end
%% 学習データ
reset(mbq_train);
shuffle(mbq_train);
%% 学習データを全て用いて1エポックの計算
itr=0;
loss_train(epoch)=0;
while hasdata(mbq_train)
[batch_x,batch_y]=next(mbq_train);
[loss,grad]=dlfeval(@loss_mse_grad,net,batch_x,batch_y);
[net,grad_ave_train,gra_sqrt_ave_train]=adamupdate(net,grad,grad_ave_train,gra_sqrt_ave_train,epoch,learn_rate);
loss_train(epoch)=loss_train(epoch)+double(gather(extractdata(loss)));
itr=itr+1;
end
%% 1反復当たりの学習損失
loss_train(epoch)=loss_train(epoch)/itr;
%% 検証データ
reset(mbq_valid);
shuffle(mbq_valid);
%% 検証データを全て用いて1エポックの計算
itr=0;
loss_valid(epoch)=0;
while hasdata(mbq_valid)
[batch_x,batch_y]=next(mbq_valid);
[loss,grad]=dlfeval(@dlnet_loss_mse_grad,net,batch_x,batch_y);
[net,grad_ave_valid,gra_sqrt_ave_valid]=adamupdate(net,grad,grad_ave_valid,gra_sqrt_ave_valid,epoch,learn_rate);
loss_valid(epoch)=loss_valid(epoch)+double(gather(extractdata(loss)));
itr=itr+1;
end
%% 1反復当たりの検証損失
loss_valid(epoch)=loss_valid(epoch)/itr;
%% 検証損失の最小値を前のエポックの値を引き継ぐ
if epoch==1
loss_valid_min(epoch)=+inf;
else
loss_valid_min(epoch)=loss_valid_min(epoch-1);
end
%% 検証損失が最小より悪化し続けたら終了
if loss_valid(epoch)<loss_valid_min(epoch)
loss_valid_min(epoch)=loss_valid(epoch);
else
count=count+1;
if count>=patience
done=true;
end
end
%% 学習損失に帯する検証損失の比
loss_ratio_valid_train=loss_valid(epoch)/loss_train(epoch);
%% 検証損失が学習損失に比べて大きすぎたら終了
if loss_ratio_valid_train>loss_ratio_max
done=true;
end
%% 損失をグラフに表示
if mod(epoch-1,10)==0
func_fig_loss(figLoss,epoch_max,loss_train,loss_valid_min,loss_valid);
end
%% 表示
if done || epoch==1 || epoch==epoch_max || mod(epoch,10)==0
fprintf('[%04d/%04d] learn_rate=%.03e, loss_train=%.1e, loss_valid=%.1e, loss_valid_min=%.1e, loss_valid/loss_train=%.03f%s%.01f, count=%d%s%d\n', ...
epoch,epoch_max,learn_rate,loss_train(epoch),loss_valid(epoch),loss_valid_min(epoch), ...
loss_ratio_valid_train,tfv(loss_ratio_valid_train<loss_ratio_max,"<",">="),loss_ratio_max, ...
count,tfv(count<patience,"<",">="),patience);
end
%% 次のエポック
epoch=epoch+1;
end
epoch=epoch-1;
loss_train=loss_train(1:epoch);
loss_valid=loss_valid(1:epoch);
disp('done')
t=toc;
%% テストデータを予測で生成
x_test=(-1:0.01:3)';
y_pred=predict(net,x_test);
%% 予測値の表示
fig3=figure('Position',[1200 800 1000 700]);
func_fig_data2(fig3,x_train,y_train,x_test,y_pred,'学習データと予測データの比較','学習データ','予測データ');
%% 真値との誤差
YTrue=func_dim1(x_test);
Error=sum((y_pred-YTrue).^2)/num_train;
fprintf('epochs=%d, time=%.1f, loss=%.1e, error=%.1e\n',epoch,t,loss_train(epoch),Error);
%%========ローカル関数===========
%% 学習データの生成用
function y=func_dim1(x)
y=x.*(x-1).*(x-2);
end
%% モデル損失関数:平均2乗誤差MSE
function [loss,grad]=loss_mse_grad(net,x,y)
y_pred=forward(net,x);
loss=mean((y_pred-y).^2,'all');
grad=dlgradient(loss,net.Learnables);
end
function func_fig_net(fig,net,name)
figure(fig);
plot(net);
title(name);
drawnow;
end
function func_fig_data2(fig,x1,y1,x2,y2,name,l1,l2)
figure(fig);
plot(x1,y1,'o',x2,y2,'.');
grid on;
xlabel('x');
ylabel('y');
legend(l1,l2);
title(name);
drawnow;
end
function func_fig_loss(fig,epoch_max,loss_train,loss_valid_min,loss_valid)
figure(fig);
semilogy(1:epoch_max,loss_train,'-',1:epoch_max,loss_valid_min,'-r',1:epoch_max,loss_valid,'+k');
xlabel('Epoch');
ylabel('Loss');
title('学習中の損失MSEの推移');
legend('学習損失','検証損失の最小値','検証損失');
grid on;
drawnow;
end