LoginSignup
9

More than 5 years have passed since last update.

ChainerでLSTM/GRU/RNNの出力の最後のベクトルを取り出す

Last updated at Posted at 2017-07-11

自然言語処理において文(sentence)を単方向LSTM/GRUに入力して最後のベクトルを取り出してsentence vectorを文書分類などに使いたい時があります。

そこでNStepGRUに入力して, 最後のベクトルを得るコードを書いてみましょう。

2017.7.12 追記
※個人的に任意のベクトルを取ってくるときに下記のようにF.embed_idを使うのが便利だと思うのですが、LSTM, GRUの場合は返り値として最後のベクトルは取れるようになっています。
詳しくは
http://qiita.com/aonotas/items/8e38693fb517e4e90535#hy
を御覧ください。失礼しました。

コード

class SentenceEncoderGRU(chainer.Chain):

    def __init__(self, n_vocab, emb_dim, hidden_dim, use_dropout):
        super(SentenceEncoderGRU, self).__init__(
            word_embed=L.EmbedID(n_vocab, emb_dim, ignore_label=-1),
            gru=L.NStepGRU(n_layers=1, in_size=emb_dim,
                           out_size=hidden_dim, dropout=use_dropout)
        )
        self.use_dropout = use_dropout

    def __call__(self, x_data):
        batchsize = len(x_data)
        xp = self.xp
        hx = None
        xs = []
        lengths = []
        for i, x in enumerate(x_data):
            x = Variable(x)
            x = self.word_embed(x)
            x = F.dropout(x, ratio=self.use_dropout)
            xs.append(x)
            lengths.append(len(x))
        # GRU
        _hy, ys = self.gru(hx=hx, xs=xs)

        #######################################################
        ## Extract Last Vector 最後のベクトルを取り出す
        ## 例: 
        ##    x_data = [[0, 1, 2, 3], [4, 5, 6], [7, 8]]の時
        ##    lengthsは[4, 3, 2]というリストになる
        ##    xp.cumsum(lengths)は [4, 7, 9] というArrayを返す
        ##    last_idxは [3, 6, 8] というインデックスになる
        #######################################################
        last_idx = xp.cumsum(lengths).astype(xp.int32) - 1 

        #######################################################
        ##    ysは [(4, 400), (3, 400), (2, 400)] のshapeのListのVariable
        ##    F.concat(ys, axis=0) は   (9, 400) のshapeに変換する
        ##    F.embed_idで最後のインデックスを取り出す
        #######################################################
        last_vecs = F.embed_id(last_idx, F.concat(ys, axis=0))
        last_vecs = F.dropout(last_vecs, ratio=self.use_dropout)
        return last_vecs

実行

enc = SentenceEncoderGRU(5000, 100, 100, 0.33)
x_data = [[0, 1, 2, 3], [4, 5, 6], [7, 8]] # 可変長データ (4, 3, 2)の長さのデータとする
x_data = [np.array(x, dtype=np.int32) for x in x_data] # numpyに変換する

vec = enc(x_data)

print(vec.shape) # (3, 100)のベクトルが返ってくる

解説

  • 入力データの長さをlengthsとして計算しておく
  • xp.cumsum(lengths)で累積和を計算しておく。これがベクトルの最後のインデックスを示す。
  • F.embed_idで行列中の該当するインデックスのベクトルを取り出す

質問や間違い訂正などは @aonotas までお願いします。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9