Help us understand the problem. What is going on with this article?

ChainerでLSTM/GRU/RNNの出力の最後のベクトルを取り出す

More than 3 years have passed since last update.

自然言語処理において文(sentence)を単方向LSTM/GRUに入力して最後のベクトルを取り出してsentence vectorを文書分類などに使いたい時があります。

そこでNStepGRUに入力して, 最後のベクトルを得るコードを書いてみましょう。

2017.7.12 追記
※個人的に任意のベクトルを取ってくるときに下記のようにF.embed_idを使うのが便利だと思うのですが、LSTM, GRUの場合は返り値として最後のベクトルは取れるようになっています。
詳しくは
http://qiita.com/aonotas/items/8e38693fb517e4e90535#hy
を御覧ください。失礼しました。

コード

class SentenceEncoderGRU(chainer.Chain):

    def __init__(self, n_vocab, emb_dim, hidden_dim, use_dropout):
        super(SentenceEncoderGRU, self).__init__(
            word_embed=L.EmbedID(n_vocab, emb_dim, ignore_label=-1),
            gru=L.NStepGRU(n_layers=1, in_size=emb_dim,
                           out_size=hidden_dim, dropout=use_dropout)
        )
        self.use_dropout = use_dropout

    def __call__(self, x_data):
        batchsize = len(x_data)
        xp = self.xp
        hx = None
        xs = []
        lengths = []
        for i, x in enumerate(x_data):
            x = Variable(x)
            x = self.word_embed(x)
            x = F.dropout(x, ratio=self.use_dropout)
            xs.append(x)
            lengths.append(len(x))
        # GRU
        _hy, ys = self.gru(hx=hx, xs=xs)

        #######################################################
        ## Extract Last Vector 最後のベクトルを取り出す
        ## 例: 
        ##    x_data = [[0, 1, 2, 3], [4, 5, 6], [7, 8]]の時
        ##    lengthsは[4, 3, 2]というリストになる
        ##    xp.cumsum(lengths)は [4, 7, 9] というArrayを返す
        ##    last_idxは [3, 6, 8] というインデックスになる
        #######################################################
        last_idx = xp.cumsum(lengths).astype(xp.int32) - 1 

        #######################################################
        ##    ysは [(4, 400), (3, 400), (2, 400)] のshapeのListのVariable
        ##    F.concat(ys, axis=0) は   (9, 400) のshapeに変換する
        ##    F.embed_idで最後のインデックスを取り出す
        #######################################################
        last_vecs = F.embed_id(last_idx, F.concat(ys, axis=0))
        last_vecs = F.dropout(last_vecs, ratio=self.use_dropout)
        return last_vecs

実行

enc = SentenceEncoderGRU(5000, 100, 100, 0.33)
x_data = [[0, 1, 2, 3], [4, 5, 6], [7, 8]] # 可変長データ (4, 3, 2)の長さのデータとする
x_data = [np.array(x, dtype=np.int32) for x in x_data] # numpyに変換する

vec = enc(x_data)

print(vec.shape) # (3, 100)のベクトルが返ってくる

解説

  • 入力データの長さをlengthsとして計算しておく
  • xp.cumsum(lengths)で累積和を計算しておく。これがベクトルの最後のインデックスを示す。
  • F.embed_idで行列中の該当するインデックスのベクトルを取り出す

質問や間違い訂正などは @aonotas までお願いします。

aonotas
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away