3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

MATLABでCNN

Last updated at Posted at 2022-06-14

モデル

学習モデルは、MATLABのDeep Learning Toolboxのアドオンから入手できるものを用いる。
事前学習済みの深層ニューラル ネットワーク
image.png

今回は、VGG16とResNet50を例に説明する。

VGG16

VGG16.m
% --- 学習データと検証データの読み込み
% フォルダ名がクラスになる。
imds = imageDatastore('./pictures', 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
% 学習用と検証用にデータを分割
p = 0.6;    % 学習用に使うデータの割合
[trainData, testData] = splitEachLabel(imds, p, 'randomized');

% ---- 事前学習済みのネットワークの読み込み。
% アドオンになければ、ダウンロードリンクが表示される。
net = vgg16;

% このままでは編集できないので、コピーのような操作を行う。
lgraph = layerGraph(net);

% ---出力数が合わないので、最後の出力層を変更する。
% 最後から二番目 (fullyConnectedLayer)
LayerName = lgraph.Layers(end-2).Name;
newLayer = fullyConnectedLayer(出力数, 'Name', LayerName);
lgraph = replaceLayer(lgraph, LayerName, newLayer);
% 最後 (classificationLayer)
LayerName = lgraph.Layers(end).Name;
newLayer = classificationLayer('Name', LayerName);
lgraph = replaceLayer(lgraph, LayerName, newLayer);

% ---ネットワークのインプットサイズに合わせて画像をリサイズ
% trainDataとtestDataの読み込みは割愛。
inputSize = lgraph.Layers(1).InputSize(1:2);
augTrainData = augmentedImageDatastore(inputSize, trainData);
augTestData = augmentedImageDatastore(inputSize, testData);

% ---学習オプション
InitialLearnRate = 1e-4;
MaxEpochs = 6;
MinBatchSize = 10;
ValidFreq = floor(length(augimdsTrain.Files) / MinBatchSize);    % 1エポックに1回の検証
options = trainingOptions('sgdm', ...
    'MiniBatchSize', MinBatchSize, ...
    'MaxEpochs', MaxEpochs, ...
    'InitialLearnRate', InitialLearnRate, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', augTrainData, ...
    'ValidationFrequency', ValidFreq, ...
    'Verbose', false, ...
    'Plots', 'training-progress');

% ---学習
% training-progressウィンドが開く。時間がかかるので待つ。
netTransfer = trainNetwork(augimdsTrain, lgraph, options);
% 検証データで確認
[predLabels, scores] = classify(netTransfer, augTestData);

% ---モデル精度の検証
testDataLabels = testData.Labels;
DataCount = numel(testData.Labels);
CorrectCount = nnz(testDataLabels == predLabels);
ErrorCount = DataCount - CorrectCount;
Accuracy = CorrectCount / DataCount;
ErrorRatio = ErrorCount / DataCount;

ResNet50

VGG16とほぼ同じだが、ResNetはConnections情報が必要になる。
VGG16のlgraphと比較すると、構造がひとつ増えているのがわかるので、確認すること。
参考先だとtrainNetworkにlayerのみ入れているが、ResNetだとlayerのみでは動かないので、layerGraph構造で入れる必要がある。

ResNet50.m
% VGG16と異なる箇所のみ記述
% ---- 事前学習済みのネットワークの読み込み。
% アドオンになければ、ダウンロードリンクが表示される。
net = resnet50;

% このままでは編集できないので、コピーのような操作を行う。
lgraph = layerGraph(net);

% VGG16のlgraphと比較すると、構造がひとつ増えているのがわかるので、確認すること。

trainingOption

opt.m
options = trainingOptions("adam", ....
    "InitialLearnRate", 0.0001, ....
    "Plots", "training-progress", ....
    "ValidationData", val)
% 最急降下法の種類はadamが一般的。
% InitialLearnRateの初期値は0.001。これだと大きいことが多いので、1/10や1/100にしてみる。
% 乱数制御を制御したい場合は、rng()を実行する。

参考

matlabで始めるディープラーニング 画像分類(1)
matlabで始めるディープラーニング 画像分類(2)

3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?