はじめに
LSTMの自習として脳波生波形の回帰を試してみたので、備忘として載せておきます。
ちなみに脳波は正規化してありますので、生波形と言っていいかは微妙なところですが。
やったこと
・脳波の生波形を0.1秒窓で分ける。
・今の窓の生波形を入力に、次の窓の生波形を出力にして、学習をさせる。
・上で学習させたネットワークに脳波生波形を食べさせて、0.1秒後の脳波を予測させる。
使ったデータ
以前の記事でもご紹介した、Brain Computer Interface research at NUST Pakistan様のSubject1_2D.matを拝借します。
再掲ですが、データの詳細は下記の通り。
・被験者:21歳、男性、右利き、健常者
・EEGチャンネル: 全19チャンネル(FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ CZ PZ)
・データ: Neurofax EEG システムを用いて取得ののち、Eemagine EEGを用いてエクスポート。サンプリングレートは500 Hz。
環境
MATLAB R2019b
Deep Learning Toolbox
実際の処理
学習データと評価データの準備
画像赤線で示した、実際の手の運動のデータ4種類ともを、学習データ・評価データとして使ってしまいます。それぞれの運動で3トライアル分の脳波があるので、最初と2つ目の2トライアル分の生波形を学習データとして使い、3トライアル目は評価データとして残しておくことにします。
コードを見て頂くとわかる通り、トライアルの種類に関係なくごちゃまぜで、ある0.1秒窓分のデータを入力データXTrain・XTestに、その次の窓分のデータを出力データYTrain・YTestに、それぞれ設定しています。
データの準備の部分は、冗長なコードが続くので(こういう部分かっこよく書けない・・)折りたたみます。
%% Loading the data
load Subject1_2D.mat
%% Normalize
LeftBackward1 = normalize(LeftBackward1);
LeftBackward2 = normalize(LeftBackward2);
LeftBackward3 = normalize(LeftBackward3);
LeftForward1 = normalize(LeftForward1);
LeftForward2 = normalize(LeftForward2);
LeftForward3 = normalize(LeftForward3);
RightBackward1 = normalize(RightBackward1);
RightBackward2 = normalize(RightBackward2);
RightBackward3 = normalize(RightBackward3);
RightForward1 = normalize(RightForward1);
RightForward2 = normalize(RightForward2);
RightForward3 = normalize(RightForward3);
%% Prepare the training data
for ii = 1:59
XTrain{ii} = LeftBackward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
XTrain{ii+59} = LeftBackward2(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
XTrain{ii+118} = LeftForward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
XTrain{ii+177} = LeftForward2(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
XTrain{ii+236} = RightBackward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
XTrain{ii+295} = RightBackward2(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
XTrain{ii+354} = RightForward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
XTrain{ii+413} = RightForward2(50*(ii-1)+1:50*ii,:)';
end
XTrain = XTrain';
for ii = 2:60
YTrain{ii-1} = LeftBackward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
YTrain{ii+58} = LeftBackward2(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
YTrain{ii+117} = LeftForward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
YTrain{ii+176} = LeftForward2(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
YTrain{ii+235} = RightBackward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
YTrain{ii+294} = RightBackward2(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
YTrain{ii+353} = RightForward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
YTrain{ii+412} = RightForward2(50*(ii-1)+1:50*ii,:)';
end
YTrain = YTrain';
%% Prepare the test data
for ii = 1:59
XTest{ii} = LeftBackward3(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
XTest{ii+59} = LeftForward3(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
XTest{ii+118} = RightBackward3(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
XTest{ii+177} = RightForward3(50*(ii-1)+1:50*ii,:)';
end
XTest = XTest';
for ii = 2:60
YTest{ii-1} = LeftBackward3(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
YTest{ii+58} = LeftForward3(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
YTest{ii+117} = RightBackward3(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
YTest{ii+176} = RightForward3(50*(ii-1)+1:50*ii,:)';
end
YTest = YTest';
そうすると、19(脳波チャンネル下図)×50(0.1秒分のサンプル数)の行列が入ったセルが出来上がります。
LSTMネットワークの設定
今回LSTMを使うということで、MATLAB公式のこちらのページを参考にしながら、ネットワークを作成していきます。
%% Layer settings
numResponses = size(YTrain{1},1);
featureDimension = size(XTrain{1},1);
numHiddenUnits = 200;
layers = [ ...
sequenceInputLayer(featureDimension)
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(50)
dropoutLayer(0.5)
fullyConnectedLayer(numResponses)
regressionLayer];
%% Option settings
maxEpochs = 100;
miniBatchSize = 32;
options = trainingOptions('adam', ...
'MaxEpochs',maxEpochs, ...
'MiniBatchSize',miniBatchSize, ...
'InitialLearnRate',0.01, ...
'GradientThreshold',1, ...
'Shuffle','never', ...
'Plots','training-progress',...
'Verbose',0);
いざ学習
%% Train the data
net = trainNetwork(XTrain,YTrain,layers, options);
上のコードを実行すると、下記の画面が立ち上がり、学習が進行していきます。
RMSEは二乗平均平方根誤差で、誤差が大きいほど値が大きくなるので、RMSEは小さい方が適していると言えます。
学習したネットワークで次の0.1秒の脳波を予測
%% Validate the data
YPredict = predict(net, XTest);
結果の可視化
さて、どんな結果になったでしょうか。試しに、予測結果の1窓目と本来の波形の1窓目を重ね書きして可視化するために、下記のコードを実行します。
%% Visualize the accuracy
figure;
plot(YTest{1}(1,:), '-o');
hold on
plot(YPredict{1}(1,:), '-*');
legend('Original', 'Predicted');
完全に一致とは言わないまでも、なんとなくトレンドは押さえられています。
3窓目を試しに可視化してみても、似たような感じです。
おわりに
predict関数で予測された脳波を使って、また次の0.1秒分を予測すると、永遠に未来の脳波を予測できそうですね。