1
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

RNNの基本

Last updated at Posted at 2023-01-28

はじめに

今回は時系列データを用いる際に使用する,Recurent Neural Network(RNN)についてまとめておきます.
最もシンプルなRNN,進化系のLSTM,LSTMを軽くしたGRUについて紹介します.

また,RNNへの入力形式についてもお話しできればと思います.
時系列データは画像等とは少し違ったデータ形式にして,入力しないといけません.
(はじめて実装するときに悩んだので,同じ悩みを抱える人の助けになればいいな:relieved:

因みに,モデルの実装はPytorchだとnn.RNNと書くだけなので楽勝です.
(勿論中身を理解した方がいいですが)

目次

1 RNN
2 LSTM
3 GRU
4 RNNへの入力形式

1 RNN

RNN(Recurent Neural Network)とは,時系列データを扱うDeep Neural Networkです.
RNNでは現在に対する過去の影響を考慮します.

時刻 (t) における状態を考えるとき,時刻 (t) の入力である入力 x(t) に加え,時刻 (t-1) の状態を表す h(t-1) を保持し,時刻 (t) に伝えます.
t-1 には t-2 の情報が含まれているため, t にも t-2 の情報が含まれていることになります.このように再帰的に考えて,過去の情報が反映されているということで,Recurent な NN と呼ばれています.
RNN.png

2 LSTM

しかし,シンプルなRNNには欠点があります.それは,長期的な依存関係を捉えられないことです.この問題を解決しているのがLSTMです.LSTMのモチベーションは以下の通りです.

  • 過去の情報をネットワーク内に保持しておきたい
  • 過去の情報を必要なタイミングで取得したい

LSTMはメモリーセル(ct)と呼ばれる状態を持ち,これに忘却ゲート (ft),入力ゲート (it),出力ゲート (ot) と呼ばれる情報の取捨選択機構を持っています.

  • 忘却ゲート:メモリーセルに保持している情報をリセットするかを選択する.
  • 入力ゲート:入力を取り込むか否かを選択し,セルの状態を更新する役.
  • 出力ゲート:次の時刻にどの程度情報を伝えるかを選択する.

各ゲートでの情報の取捨選択は,シグモイド関数 (σ) によって行われます.シグモイド関数は [0,1] の値を出力し,0 であればゲートを通さず,1 であれば全て通します.
LSTM.png

3 GRU

GRUはLSTMの精度はそのままで高速化したモデルです.といいつつもLSTMとGRUどちらの方が高い精度を出せるかはタスクやデータによってまちまちとのことなので,どちらも試して良い方のモデルを採用するのが良いでしょう.

因みに私がしていたタスクだと,パラメータ数が RNN:LSTM:GRU=1:4:3 になっていました.精度はほぼ同じでしたね.

4 RNNへの入力形式

RNNへの入力には,何時刻前までを予測に使いたいのかを決める必要があります.
私のしていた心電図タスクで考えてみたいと思います.(医療データは外部持ち出し:no_good_tone1:なので数値はでたらめです.)
RNN1.png
Xをデータ,yをGroundTruth(正解ラベル)とすると,以下のようになると思います.
RNN2.png
RNNへの入力形式は以下のようにします.3時刻を見るとすると,対象となる時刻,その1つ前,その2つ前を1セットにします.それに対して,対象時刻のラベルを予測するという形式になります.
RNN3.png

5 系列データのバリエーション

系列データには,入出力のパターンが複数あり,それぞれ”入力to出力”の形で呼ばれます.

  • One to One
    • 入力も出力も固定サイズのベクトル.普通のNN.
  • One to Many
    • 入力は固定サイズのベクトル,出力は系列.
  • Many to One
    • 入力は系列,出力は固定サイズのベクトル.系列データのクラス分類など.
  • Many to Many
    • 入力も出力も系列.言語翻訳など.

diags.jpeg
参照:http://karpathy.github.io/2015/05/21/rnn-effectiveness/

6 高速化

6.1 処理の高速化(numba)

6.2 ネットワークの高速化(toch.jit,Julia)

さいごに

随時追記予定です.
RNNはなんだか考えていると頭がこんがらがってきますね💦 いきなり実装じゃなくて一度データ形式等を考えてからコーディングをするとうまくいくと思います👍
私も学んでいる身なので,間違えている部分があれば教えてください🙇

1
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?