
Computing LSTM + Attention with Chainer


This is a quick walkthrough of how to use LSTMs in Chainer, for anyone who wants to get started.

Preparation

Input format

import numpy as np
from chainer import Variable
import chainer.functions as F
import chainer.links as L

## Prepare the input data
x_list = [[0, 1, 2, 3], [4, 5, 6], [7, 8]]  # variable-length data with lengths (4, 3, 2)
x_list = [np.array(x, dtype=np.int32) for x in x_list]  # convert to numpy arrays
batchsize = len(x_list)  # 3

The data above is a minibatch of size 3 whose sequences have lengths 4, 3, and 2.

Word Embeddings

n_vocab = 500
emb_dim = 100
word_embed = L.EmbedID(n_vocab, emb_dim, ignore_label=-1)  # index -1 is ignored, which is convenient for padding
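
For reference, applying word_embed to one of the index sequences produces one embedding vector per token. A minimal sketch (the shapes follow from emb_dim=100):

emb = word_embed(Variable(x_list[0]))  # word indices (4,) -> embeddings (4, 100)
print(emb.shape)  # (4, 100)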

NStepLSTM, NStepBiLSTM

use_dropout = 0.25
in_size = 100
hidden_size = 200
n_layers = 1

bi_lstm = L.NStepBiLSTM(n_layers=n_layers, in_size=in_size,
                        out_size=hidden_size, dropout=use_dropout)
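
The heading also mentions NStepLSTM; the unidirectional variant is constructed the same way and differs only in the class name. A sketch for reference (it is not used in the rest of this article):

uni_lstm = L.NStepLSTM(n_layers=n_layers, in_size=in_size,
                       out_size=hidden_size, dropout=use_dropout)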

Passing the input data to the LSTM

# Passing None gives you zero vectors for the initial states. For the decoder of an
# Encoder-Decoder model you would usually pass an initial hidden state hx instead.
hx = None
cx = None

xs_f = []
for i, x in enumerate(x_list):
    x = word_embed(Variable(x))  # convert word indices to word embeddings
    x = F.dropout(x, ratio=use_dropout)
    xs_f.append(x)

# xs_f is now a list of Variables with shapes
# [(4, 100), (3, 100), (2, 100)]

hy, cy, ys = bi_lstm(hx=hx, cx=cx, xs=xs_f)

ys is a list containing, for each input sequence, the per-step output vectors of the final NStepBiLSTM layer.
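
A quick shape check (a sketch; the shapes follow from n_layers=1, batchsize=3, hidden_size=200, with the forward and backward outputs of the BiLSTM concatenated):

print(hy.shape)               # (2, 3, 200): (2 * n_layers, batchsize, hidden_size) final hidden states
print(cy.shape)               # (2, 3, 200): final cell states
print([y.shape for y in ys])  # [(4, 400), (3, 400), (2, 400)]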

Computing Attention

concat_ys = F.concat(ys, axis=0)  # (9, 400)
linear = L.Linear(400, 1)  # Link mapping 400 -> 1 for the attention scores (depending on the model, the decoder's 400-dim hidden vector may be involved here)
attn = linear(concat_ys)  # (9, 1)
split_attention = F.split_axis(attn, np.cumsum([len(x) for x in x_list])[:-1], axis=0)  # [_.shape for _ in split_attention] is [(4, 1), (3, 1), (2, 1)]
split_attention_pad = F.pad_sequence(split_attention, padding=-1024.0)  # pad with -1024
# Padding with -1024.0 means exp(-1024.0) => 0.0 when the softmax is computed
attn_softmax = F.softmax(split_attention_pad, axis=1)  # compute the softmax

### The padded positions become 0.0
print(attn_softmax)
# variable([[[ 0.26594046],
#           [ 0.2388579 ],
#           [ 0.2923921 ],
#           [ 0.20280953]],
#
#          [[ 0.30548543],
#           [ 0.39249265],
#           [ 0.30202195],
#           [ 0.        ]],
#
#          [[ 0.46441966],
#           [ 0.53558034],
#           [ 0.        ],
#           [ 0.        ]]])
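
The padding=-1024.0 trick works because the exponential of such a large negative number underflows to exactly 0.0 in floating point, so padded positions receive zero attention weight. A quick check:

print(np.exp(-1024.0))  # 0.0 (underflow), so the padded scores contribute nothing to the softmax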

ys_pad = F.pad_sequence(ys, length=None, padding=0.0)  # pad the hidden vectors with 0.0
# ys_pad.shape => (3, 4, 400)
ys_pad_reshape = F.reshape(ys_pad, (-1, ys_pad.shape[-1]))
# ys_pad_reshape.shape => (12, 400)

attn_softmax_reshape = F.broadcast_to(F.reshape(attn_softmax, (-1, attn_softmax.shape[-1])), ys_pad_reshape.shape)
# attn_softmax_reshape.shape => (12, 400)

attention_hidden = ys_pad_reshape * attn_softmax_reshape 
# attention_hidden.shape => (12, 400)

## F.sum
attention_hidden_reshape = F.reshape(attention_hidden, (batchsize, -1, attention_hidden.shape[-1])) # (3, 4, 400)
result = F.sum(attention_hidden_reshape, axis=1) # (3, 400)
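
Putting the pieces together, the whole pipeline can be bundled into a single Chain. This is a minimal sketch under the assumptions above (the class name BiLSTMAttention and its constructor arguments are made up for illustration; it reuses the np, Variable, F, and L imports from the beginning):

import chainer

class BiLSTMAttention(chainer.Chain):
    """Hypothetical Chain bundling the steps above; names and arguments are illustrative."""

    def __init__(self, n_vocab, emb_dim, hidden_size, n_layers=1, use_dropout=0.25):
        super(BiLSTMAttention, self).__init__()
        with self.init_scope():
            self.word_embed = L.EmbedID(n_vocab, emb_dim, ignore_label=-1)
            self.bi_lstm = L.NStepBiLSTM(n_layers, emb_dim, hidden_size, use_dropout)
            self.attn_linear = L.Linear(hidden_size * 2, 1)
        self.use_dropout = use_dropout

    def __call__(self, x_list):
        # Embed and apply dropout to each variable-length sequence
        xs_f = [F.dropout(self.word_embed(Variable(x)), ratio=self.use_dropout) for x in x_list]
        hy, cy, ys = self.bi_lstm(None, None, xs_f)

        # Attention scores per step, split back into sequences and padded with -1024.0
        concat_ys = F.concat(ys, axis=0)
        attn = self.attn_linear(concat_ys)
        split_attention = F.split_axis(attn, np.cumsum([len(x) for x in x_list])[:-1], axis=0)
        attn_softmax = F.softmax(F.pad_sequence(split_attention, padding=-1024.0), axis=1)

        # Weighted sum of the zero-padded hidden vectors
        ys_pad = F.pad_sequence(ys, padding=0.0)
        ys_pad_reshape = F.reshape(ys_pad, (-1, ys_pad.shape[-1]))
        attn_reshape = F.broadcast_to(
            F.reshape(attn_softmax, (-1, attn_softmax.shape[-1])), ys_pad_reshape.shape)
        weighted = F.reshape(ys_pad_reshape * attn_reshape,
                             (len(x_list), -1, ys_pad.shape[-1]))
        return F.sum(weighted, axis=1)  # (batchsize, hidden_size * 2)

With the hyperparameters used above, BiLSTMAttention(500, 100, 200)(x_list) would return a (3, 400) Variable, matching result.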
