LoginSignup
13
5

More than 3 years have passed since last update.

LSTMで脳波を学習させて0.1秒後の生波形を予測

Last updated at Posted at 2019-11-30

はじめに

LSTMの自習として脳波生波形の回帰を試してみたので、備忘として載せておきます。
ちなみに脳波は正規化してありますので、生波形と言っていいかは微妙なところですが。

やったこと

・脳波の生波形を0.1秒窓で分ける。
・今の窓の生波形を入力に、次の窓の生波形を出力にして、学習をさせる。
・上で学習させたネットワークに脳波生波形を食べさせて、0.1秒後の脳波を予測させる。

使ったデータ

以前の記事でもご紹介した、Brain Computer Interface research at NUST Pakistan様のSubject1_2D.matを拝借します。

再掲ですが、データの詳細は下記の通り。
・被験者:21歳、男性、右利き、健常者
・EEGチャンネル: 全19チャンネル(FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ CZ PZ)
・データ: Neurofax EEG システムを用いて取得ののち、Eemagine EEGを用いてエクスポート。サンプリングレートは500 Hz。

環境

MATLAB R2019b
Deep Learning Toolbox

実際の処理

学習データと評価データの準備

画像赤線で示した、実際の手の運動のデータ4種類ともを、学習データ・評価データとして使ってしまいます。それぞれの運動で3トライアル分の脳波があるので、最初と2つ目の2トライアル分の生波形を学習データとして使い、3トライアル目は評価データとして残しておくことにします。

image.png

コードを見て頂くとわかる通り、トライアルの種類に関係なくごちゃまぜで、ある0.1秒窓分のデータを入力データXTrain・XTestに、その次の窓分のデータを出力データYTrain・YTestに、それぞれ設定しています。

データの準備の部分は、冗長なコードが続くので(こういう部分かっこよく書けない・・)折りたたみます。
%% Loading the data
load Subject1_2D.mat

%% Normalize
LeftBackward1 = normalize(LeftBackward1);
LeftBackward2 = normalize(LeftBackward2);
LeftBackward3 = normalize(LeftBackward3);
LeftForward1 = normalize(LeftForward1);
LeftForward2 = normalize(LeftForward2);
LeftForward3 = normalize(LeftForward3);
RightBackward1 = normalize(RightBackward1);
RightBackward2 = normalize(RightBackward2);
RightBackward3 = normalize(RightBackward3);
RightForward1 = normalize(RightForward1);
RightForward2 = normalize(RightForward2);
RightForward3 = normalize(RightForward3);

%% Prepare the training data
for ii = 1:59
    XTrain{ii} = LeftBackward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
    XTrain{ii+59} = LeftBackward2(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
    XTrain{ii+118} = LeftForward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
    XTrain{ii+177} = LeftForward2(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
    XTrain{ii+236} = RightBackward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
    XTrain{ii+295} = RightBackward2(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
    XTrain{ii+354} = RightForward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
    XTrain{ii+413} = RightForward2(50*(ii-1)+1:50*ii,:)';
end
XTrain = XTrain';

for ii = 2:60
    YTrain{ii-1} = LeftBackward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
    YTrain{ii+58} = LeftBackward2(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
    YTrain{ii+117} = LeftForward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
    YTrain{ii+176} = LeftForward2(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
    YTrain{ii+235} = RightBackward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
    YTrain{ii+294} = RightBackward2(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
    YTrain{ii+353} = RightForward1(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
    YTrain{ii+412} = RightForward2(50*(ii-1)+1:50*ii,:)';
end
YTrain = YTrain';

%% Prepare the test data
for ii = 1:59
    XTest{ii} = LeftBackward3(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
    XTest{ii+59} = LeftForward3(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
    XTest{ii+118} = RightBackward3(50*(ii-1)+1:50*ii,:)';
end
for ii = 1:59
    XTest{ii+177} = RightForward3(50*(ii-1)+1:50*ii,:)';
end
XTest = XTest';

for ii = 2:60
    YTest{ii-1} = LeftBackward3(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
    YTest{ii+58} = LeftForward3(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
    YTest{ii+117} = RightBackward3(50*(ii-1)+1:50*ii,:)';
end
for ii = 2:60
    YTest{ii+176} = RightForward3(50*(ii-1)+1:50*ii,:)';
end
YTest = YTest';

そうすると、19(脳波チャンネル下図)×50(0.1秒分のサンプル数)の行列が入ったセルが出来上がります。

image.png

LSTMネットワークの設定

今回LSTMを使うということで、MATLAB公式のこちらのページを参考にしながら、ネットワークを作成していきます。

%% Layer settings
numResponses = size(YTrain{1},1);
featureDimension = size(XTrain{1},1);
numHiddenUnits = 200;

layers = [ ...
    sequenceInputLayer(featureDimension)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')
    fullyConnectedLayer(50)
    dropoutLayer(0.5)
    fullyConnectedLayer(numResponses)
    regressionLayer];

%% Option settings
maxEpochs = 100;
miniBatchSize = 32;

options = trainingOptions('adam', ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'InitialLearnRate',0.01, ...
    'GradientThreshold',1, ...
    'Shuffle','never', ...
    'Plots','training-progress',...
    'Verbose',0);

いざ学習

%% Train the data
net = trainNetwork(XTrain,YTrain,layers, options);

上のコードを実行すると、下記の画面が立ち上がり、学習が進行していきます。

image.png

RMSEは二乗平均平方根誤差で、誤差が大きいほど値が大きくなるので、RMSEは小さい方が適していると言えます。

学習したネットワークで次の0.1秒の脳波を予測

%% Validate the data
YPredict = predict(net, XTest);

結果の可視化

さて、どんな結果になったでしょうか。試しに、予測結果の1窓目と本来の波形の1窓目を重ね書きして可視化するために、下記のコードを実行します。

%% Visualize the accuracy
figure;
plot(YTest{1}(1,:), '-o');
hold on
plot(YPredict{1}(1,:), '-*');
legend('Original', 'Predicted');

すると・・・
image.png

完全に一致とは言わないまでも、なんとなくトレンドは押さえられています。

3窓目を試しに可視化してみても、似たような感じです。

image.png

おわりに

predict関数で予測された脳波を使って、また次の0.1秒分を予測すると、永遠に未来の脳波を予測できそうですね。

13
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
5