モデル
学習モデルは、MATLABのDeep Learning Toolboxのアドオンから入手できるものを用いる。
事前学習済みの深層ニューラル ネットワーク
今回は、VGG16とResNet50を例に説明する。
VGG16
VGG16.m
% --- 学習データと検証データの読み込み
% フォルダ名がクラスになる。
imds = imageDatastore('./pictures', 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
% 学習用と検証用にデータを分割
p = 0.6; % 学習用に使うデータの割合
[trainData, testData] = splitEachLabel(imds, p, 'randomized');
% ---- 事前学習済みのネットワークの読み込み。
% アドオンになければ、ダウンロードリンクが表示される。
net = vgg16;
% このままでは編集できないので、コピーのような操作を行う。
lgraph = layerGraph(net);
% ---出力数が合わないので、最後の出力層を変更する。
% 最後から二番目 (fullyConnectedLayer)
LayerName = lgraph.Layers(end-2).Name;
newLayer = fullyConnectedLayer(出力数, 'Name', LayerName);
lgraph = replaceLayer(lgraph, LayerName, newLayer);
% 最後 (classificationLayer)
LayerName = lgraph.Layers(end).Name;
newLayer = classificationLayer('Name', LayerName);
lgraph = replaceLayer(lgraph, LayerName, newLayer);
% ---ネットワークのインプットサイズに合わせて画像をリサイズ
% trainDataとtestDataの読み込みは割愛。
inputSize = lgraph.Layers(1).InputSize(1:2);
augTrainData = augmentedImageDatastore(inputSize, trainData);
augTestData = augmentedImageDatastore(inputSize, testData);
% ---学習オプション
InitialLearnRate = 1e-4;
MaxEpochs = 6;
MinBatchSize = 10;
ValidFreq = floor(length(augimdsTrain.Files) / MinBatchSize); % 1エポックに1回の検証
options = trainingOptions('sgdm', ...
'MiniBatchSize', MinBatchSize, ...
'MaxEpochs', MaxEpochs, ...
'InitialLearnRate', InitialLearnRate, ...
'Shuffle', 'every-epoch', ...
'ValidationData', augTrainData, ...
'ValidationFrequency', ValidFreq, ...
'Verbose', false, ...
'Plots', 'training-progress');
% ---学習
% training-progressウィンドが開く。時間がかかるので待つ。
netTransfer = trainNetwork(augimdsTrain, lgraph, options);
% 検証データで確認
[predLabels, scores] = classify(netTransfer, augTestData);
% ---モデル精度の検証
testDataLabels = testData.Labels;
DataCount = numel(testData.Labels);
CorrectCount = nnz(testDataLabels == predLabels);
ErrorCount = DataCount - CorrectCount;
Accuracy = CorrectCount / DataCount;
ErrorRatio = ErrorCount / DataCount;
ResNet50
VGG16とほぼ同じだが、ResNetはConnections情報が必要になる。
VGG16のlgraphと比較すると、構造がひとつ増えているのがわかるので、確認すること。
参考先だとtrainNetworkにlayerのみ入れているが、ResNetだとlayerのみでは動かないので、layerGraph構造で入れる必要がある。
ResNet50.m
% VGG16と異なる箇所のみ記述
% ---- 事前学習済みのネットワークの読み込み。
% アドオンになければ、ダウンロードリンクが表示される。
net = resnet50;
% このままでは編集できないので、コピーのような操作を行う。
lgraph = layerGraph(net);
% VGG16のlgraphと比較すると、構造がひとつ増えているのがわかるので、確認すること。
trainingOption
opt.m
options = trainingOptions("adam", ....
"InitialLearnRate", 0.0001, ....
"Plots", "training-progress", ....
"ValidationData", val)
% 最急降下法の種類はadamが一般的。
% InitialLearnRateの初期値は0.001。これだと大きいことが多いので、1/10や1/100にしてみる。
% 乱数制御を制御したい場合は、rng()を実行する。