Help us understand the problem. What is going on with this article?

Keras Recurrentレイヤーメモ:return_sequences, RepeatVector, TimeDistributed

More than 3 years have passed since last update.

動機

「詳解 ディープラーニング Tensorflfow・Kerasによる時系列データ処理」で勉強をしている中で、Kerasで足し算タスクを学習するRNN Encoder-Decoderの項目がありました。

それ以前の内容は比較的追いやすかったのですが、RNN Encoder-Decoderのスクリプトの理解は少々難しかったため、KerasのReccurentレイヤーの使い方を忘れないようにメモ。

RNN Encoder Decoderとは

原著(arXiv):https://arxiv.org/abs/1406.1078
最近はRNN encoder decoderモデルの一つであるseq2seqが翻訳やチャットボットなどで使われていますね。

図:Sequence to Sequence Learning with Neural Networksより抜粋
Screen Shot 2017-09-11 at 14.01.03.png

基本的には、Encoderネットワークが[A, B, C, (EOS)]というシーケンス入力を固定長ベクトルに変換し、その固定長ベクトルを用いてDecoderネットワークが[W, X, Y, Z, (EOS)]というシーケンスを出力するようなタスクを学習するネットワークです。

私はseq2seqを先に知りました。上記の理解はseq2seqを元にしているので、厳密なRNN Encoder-Decoderはひょっとしたらまた異なる定義なのかもしれませんが、あしからず。

KerasでRNN

Keras便利ですね。model.addで簡単にネットワークを定義できるのが素晴らしい。
KerasでのRNNの定義は以下のような方法です。
上記書籍の著者の方がGithubにサンプルコードを公開しています。これもありがたい。
(Kerasを作ったFrancois CholletさんもGithubにサンプルコードをあげています)

keras_rnn_sample.py
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.layers.recurrent import SimpleRNN
from keras.optimizers import Adam

seq_length = 7
n_in = 12
n_hidden = 128
n_out = 12

model=Sequential()
model.add(SimpleRNN(units=n_hidden, input_shape=(seq_length, n_in)))
model.add(Dense(units=n_out))
model.add(Activation('softmax'))

optimizer = Adam(lr=0.001, beta_1=0.9, beta_2=0.999)
model.compile(loss='categorical_cross_entropy', optimizer=optimizer)

上記スクリプトでは、各要素が12次元のベクトルである長さ7のシーケンスを入力として想定し、
128次元のRNN層を通した後、全結合層で12次元にまとめ、ソフトマックスで活性化することで、入力シーケンスを12クラスに分類するモデルです。

さて、RNN Encoder-Decoderでは、EncoderとDecoder部分が分かれていますが、

rnn_enc_dec.py
from keras.models import Sequential
from keras.layers.core import Dense, Activation, RepeatVector
from keras.layers.recurrent import LSTM
from keras.layers.wrappers import TimeDistributed
from keras.optimizers import Adam

seq_in_length = 7
n_in = 12
n_hidden = 128
n_out = 12
seq_in_length = 4

model=Sequential()
model.add(LSTM(units=n_hidden, input_shape=(seq_in_length, n_in)))

#decoder
model.add(RepeatVector(seq_out_length))
model.add(LSTM(units=n_hidden, return_sequences=True))

model.add(TimeDistributed(Dense(units=n_out)))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer=Adam(lr=0.001, beta_1=0.9, beta_2=0.999),
              metrics=['accuracy'])

とありますが、LSTM内の引数return_sequencesと、ラッパーのRepeatVector, TimeDistributedとはなんでしょうか。

引数 return_sequences

https://keras.io/layers/recurrent/
return_sequenceとは、TrueにしてRNNレイヤーの毎時刻の出力を得るか、Falseにして最後の時刻のみの出力を得るか、のフラグのようです。デフォルトではFalseになっていて、最後の時刻のみの出力を得ます。

RNNレイヤーを複数積み重ねたい時は、各時刻で層間のデータのやり取りがあるので、(少なくとも最後の層以外は)必ずTrueにしなければならないようです。

上記のサンプルスクリプトのEncoderではFalseとなっており、最後の時刻の出力を固定長ベクトルとして取得しています。一方でDecoderではTrueとなっており、毎時刻の出力を取得します。

ラッパー RepeatVector

https://keras.io/layers/core/#repeatvector

RepeatVectorは、inputとして入ってくるベクトルを、指定した数だけ繰り返すラッパーです。Encoderから得られる固定長ベクトルを出力の長さ分だけ繰り返して、毎時刻入力できるようにしています。

ラッパー TimeDistributed

https://keras.io/layers/wrappers/#timedistributed
TimeDistributedは、入力されたシーケンスの各時刻に同様のネットワーク構造を付加できるラッパーです。上記のサンプルスクリプトでは、デコーダーのLSTMからはreturn_sequence=Trueとなっていることで毎時刻の出力を取得することができ、そこから毎時刻の出力毎に12クラスの分類を行っています。

3つまとめると

以上のものをまとめて、ネットワークを図示するとこのようになります。

decoder.png

encoder.png

ちなみに、TimeDistributedを用いず、単にModel.add(Dense)をしてしまうとどうなるのでしょうか...?
パッと思いつきで、「TimeDistributedを用いないと全ての時刻にわたってFull Conncetionしてしまうのでは?」と思っていましたが、特に出力の次元が変わることもなかったので、そうではないようです。今のところ、まだ違いがわかっていないので目下調査中です。

HotAllure
普段は自動車企業の研究所でデータ分析や、自然言語処理に関わる仕事をしています。
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした