LoginSignup
5
2

More than 1 year has passed since last update.

DonkeyCarのディープラーニングモデルをMATLABでトレーニングしてみる

Last updated at Posted at 2022-05-02

はじめに

DonkeyCarは、オープンソースのAIカーで、ディープラーニングのフレームワークとしてPythonベースのKerasとTensorFlowが使われています。ここでは、MATLABのDeep Learning Toolboxを使ってDonkeyCarの学習用データセットを読み込み、ディープラーニングモデルのトレーニングを試みてみます。

使用したツール

MATLABライブスクリプトを下記、GitHubに公開しています。(リンク先のページにあるコマンドをMATLAB コマンドウインドウにコピー&ペーストすると、自動でファイルを展開します。)
Train_DonkeyCar_MATLAB

トレーニング用データとディープラーニングモデル

GitHubに公開されているこちらのトレーニングデータを使います。

  • 使用するデータセット circuit_launch_20210716_1826.tar.gz

ベースのデープラーニングモデルとして、以下の*.h5ファイル(Kerasモデル)を使用します。
https://github.com/autorope/donkey_datasets/tree/master/circuit_launch_20210716/models

トレーニングデータとディープラーニングモデルの準備

データセットのダウンロード

以下のMATLABスクリプトを実行し、DonkeyCarのデータセットファイル(*.gz)をGitHubからダウンロードし、カレントフォルダ下に解凍します。
※MATLAB Onlineでの利用を考慮し、No.4のデータのみ再圧縮したファイルのリンクに変更しました。:smiley:

%Download DonkeyCar dataset on GitHub
% url = 'https://github.com/autorope/donkey_datasets/raw/master/circuit_launch_20210716/';%Original URL
% tarfile = 'circuit_launch_20210716_1826.tar.gz'; %Original data
url = 'https://github.com/covao/donkey_dataset_example/raw/main/';
tarfile = 'circuit_launch_20210716_1826_no4.tar.gz'; %Recompress only No4 data

datapath = './murmurpi4_circuit_launch_20210716_1826/data/';%Data Path

if(~exist(datapath,'file')) %if not unzipped
    disp('Wait a few minutes for downloading and unziping.')
    websave('temp.gz',[url tarfile],weboptions('Timeout',Inf));
    untar('temp.gz');
end

事前学習済みディープラーニングモデルのダウンロード

DonkeyCarの事前学習済みディープラーニングモデルファイル(Keras)をダウンロードします。ここでは、pilot_21-08-12_10.h5を使います。

%Download Donkeycar deep learning mode
netfile='pilot_21-08-12_10.h5';
websave(netfile,[url 'models/' netfile]);

事前学習済みモデルの読み込み

モデルをワークスペースに読み込みます

donkeyKerasNet = importKerasNetwork(netfile); %Load pre-trained network

catalogファイルの読み込み

DonkeyCarの走行データのリストが、*.catalogファイルに記録されています。画像データファイル名とステアリング値、スロットル値を含むデータセットをMATLABに読み込みます。ここでは、catalog_4.catalogを読み込みます。

%Import catalog
catalogfile = [datapath 'catalog_4.catalog'];%Catalog file
clear dataset;
js = readlines(catalogfile);
for(i = 1:length(js)-1)
    dataset(i) = jsondecode(js(i));
end

%Set image path
imgfolder = [datapath 'images/'];
imgfiles = {dataset.cam_image_array};
imgpath = {};
for(i = 1:length(imgfiles))
    imgpath{i} = strcat(imgfolder,imgfiles{i});
end

%Create Output Data
ang = [dataset.user_angle];
th = [dataset.user_throttle];

トレーニング用データと検証用データの作成

トレーニング用データと検証用データを分割します。
ここでは、70%をトレーニング用データとします。

%Count trainining and test data
train_split = 0.7;%Split data for training and varidation 
dataN = length(ang);
trainN = round(dataN*train_split);

%Create training data
tmp_im = imageDatastore(imgpath(1:trainN)','Labels',imgfiles(1:trainN)');
tmp_ang = arrayDatastore(ang(1:trainN)');
TrainData = combine(tmp_im,tmp_ang);%Training dataset

%Create test data
tmp_im = imageDatastore(imgpath(trainN+1:dataN)','Labels',imgfiles(trainN+1:dataN)');
tmp_ang = arrayDatastore(ang(trainN+1:dataN)');
ValidationData = combine(tmp_im,tmp_ang);%Validation dataset

データセットのプレビュー

画像とステアリング、スロットルの値のサンプル表示します。

%Dataset preview
figure;
plotN = 6*6;
for(i = 1:plotN)
    no = i*10;
    subplot(sqrt(plotN),sqrt(plotN),i);
    imshow(imread(imgpath{no}));
    title(sprintf('Ang=%.2f,Th=%.2f',ang(no),th(no) ) )
end

image.png

Deep Network Designerによるモデルの編集とトレーニング

Deep Network Designerの起動とネットワーク読み込み

Deep Network Designerを起動し、ワークスペースからDonkeyCarのネットワークを読み込みます。

image.png
image.png

DeepNet.gif

ネットワークモデルの編集

Deep Network Designerでは、単一出力ネットワークのTrainingのみ対応しているため、出力層の一方を削除します。
Analyzeをクリックし、エラーがないことを確認します。
image.png

image.png

トレーニングデータの設定

Dataタブをクリックし、Import Data > Import Data Storeを選択し、TrainingDataとValidationDataを設定します。

image.png
image.png

トレーニング

デフォルトのTraining Options設定で、Trainボタンをクリックし、ネットワークの学習を行います。
image.png
本設定では、発散してしまいました。:dizzy_face:残念ながら、InitialLearnRateを0.001に変更しても、改善せず。:dizzy_face:

次に、Trainingタブをクリックし、Training OptionsでSolverをadamに変更し、学習してみます。
image.png

モデルの確認

Export > Export Trained Network and Resultsをクリックし、ネットワークをワークスペースにエクスポートします。
検証用データでを用いて、モデルを確認をしてみます。(青:学習モデルによるステアリング予測、赤:ステアリング検証用データ)

if exist('trainedNetwork_1')
    figure;
    tmp = predict(trainedNetwork_1,ValidationData);
    plot([tmp,ang(trainN+1:dataN)']);
end

image.png

まとめ

DonkeyCarの学習をMATLABでやってみました。トレーニングのオプションパラメータの調整を行うことで、もう少し精度が改善しそうです。:sunglasses:
次回は、DonkeyCarのデータセットをSimulinkでリプレイしてみます:arrow_forward:

参考

5
2
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
2