自然言語処理において文(sentence)を単方向LSTM/GRUに入力して最後のベクトルを取り出してsentence vectorを文書分類などに使いたい時があります。
そこでNStepGRUに入力して, 最後のベクトルを得るコードを書いてみましょう。
2017.7.12 追記
※個人的に任意のベクトルを取ってくるときに下記のようにF.embed_idを使うのが便利だと思うのですが、LSTM, GRUの場合は返り値として最後のベクトルは取れるようになっています。
詳しくは
http://qiita.com/aonotas/items/8e38693fb517e4e90535#hy
を御覧ください。失礼しました。
コード
class SentenceEncoderGRU(chainer.Chain):
def __init__(self, n_vocab, emb_dim, hidden_dim, use_dropout):
super(SentenceEncoderGRU, self).__init__(
word_embed=L.EmbedID(n_vocab, emb_dim, ignore_label=-1),
gru=L.NStepGRU(n_layers=1, in_size=emb_dim,
out_size=hidden_dim, dropout=use_dropout)
)
self.use_dropout = use_dropout
def __call__(self, x_data):
batchsize = len(x_data)
xp = self.xp
hx = None
xs = []
lengths = []
for i, x in enumerate(x_data):
x = Variable(x)
x = self.word_embed(x)
x = F.dropout(x, ratio=self.use_dropout)
xs.append(x)
lengths.append(len(x))
# GRU
_hy, ys = self.gru(hx=hx, xs=xs)
#######################################################
## Extract Last Vector 最後のベクトルを取り出す
## 例:
## x_data = [[0, 1, 2, 3], [4, 5, 6], [7, 8]]の時
## lengthsは[4, 3, 2]というリストになる
## xp.cumsum(lengths)は [4, 7, 9] というArrayを返す
## last_idxは [3, 6, 8] というインデックスになる
#######################################################
last_idx = xp.cumsum(lengths).astype(xp.int32) - 1
#######################################################
## ysは [(4, 400), (3, 400), (2, 400)] のshapeのListのVariable
## F.concat(ys, axis=0) は (9, 400) のshapeに変換する
## F.embed_idで最後のインデックスを取り出す
#######################################################
last_vecs = F.embed_id(last_idx, F.concat(ys, axis=0))
last_vecs = F.dropout(last_vecs, ratio=self.use_dropout)
return last_vecs
実行
enc = SentenceEncoderGRU(5000, 100, 100, 0.33)
x_data = [[0, 1, 2, 3], [4, 5, 6], [7, 8]] # 可変長データ (4, 3, 2)の長さのデータとする
x_data = [np.array(x, dtype=np.int32) for x in x_data] # numpyに変換する
vec = enc(x_data)
print(vec.shape) # (3, 100)のベクトルが返ってくる
解説
- 入力データの長さをlengthsとして計算しておく
- xp.cumsum(lengths)で累積和を計算しておく。これがベクトルの最後のインデックスを示す。
- F.embed_idで行列中の該当するインデックスのベクトルを取り出す
質問や間違い訂正などは @aonotas までお願いします。