この記事について
Transformerは自然言語処理や様々な分野で使われており、いろいろな方がわかりやすく解説されています。私もこれらの記事に触発され、使ってみようとしましたが、実際に使おうとするとPytorchのnn.Transformer
のforward
関数の引数が多く混乱しました。特にmaskの引数の意味がドキュメントを見るだけではわかりませんでした。そこでこれらの引数について調べた結果をまとめました。またPytorchのチュートリアルであるTransformerを使った言語翻訳(ドイツ語から英語)を具体例にしました。記事の作成に当たって、主に以下の情報を参考にして執筆しました。
- PytorchのGithubとAPIドキュメント (1.10)
- Attention is All You Need
- 作って理解する Transformer / Attention
- Transformerのデータの流れを追ってみる
- The Illustrated Transformer(和訳)
- A detailed guid to Pytorch's nn.Transformer() module
- transformerを理解するために実装
nn.Transformer.forwardの引数
nn.Transformer
(リンク)のforward
関数は次の8個の引数を持っています。
-
src
: ソースのシークエンス -
tgt
: ターゲットのシークエンス -
src_mask
: ソースの加法性マスク -
src_key_padding_mask
: ソースのキーのマスク -
tgt_mask
: ターゲットの加法性マスク -
tgt_key_padding_mask
: ターゲットのキーのマスク -
memory_mask
: メモリの加法性マスク -
memory_key_padding_mask
: メモリのキーのマスク
そしてforward
関数は次の値を返します。
-
output
: Transformerが計算したシークエンス
結論から述べますと、これらの引数と「Attention is All You Need」のモデルの対応関係は次のようになっています。nn.Transformer
はデコーダの最初のAttentionだけでなく、すべてのAttentionに2つずつマスクが設定できる設計となっていることがわかります。
引数 | 対応するnn.MultiHeadAttention
|
---|---|
src_mask および src_key_padding_mask
|
TransformerのエンコーダのSelf Attention |
tgt_mask およびtgt_key_padding_mask
|
TransformerのデコーダのSelf Attention |
memory_mask および memory_key_padding_mask
|
TransformerのデコーダのSource-Target Attention |
以降は、各引数の説明をしていきます。
src
, tgt
, output
について
src
はTransformerのエンコーダで処理するソースのシークエンスです。シークエンス長S、バッチサイズN、デプスEのテンソルを渡します。tgt
はデコーダで処理するターゲットのシークエンスです。シークエンス長T、バッチサイズN、デプスEのテンソルを渡します。output
はtgt
と同じサイズのテンソルで、Transformerが計算した結果です。
言語翻訳の例だとsrc
はドイツ語の文章、tgt
は英語の文章です。src
, tgt
ともにPositional Encodingを加えたEmbeddingで、複数の文章をバッチするためにシークエンス中にpaddingを含みます。output
はLinear層とsoftmax層で翻訳した結果を計算するために使われます。ソースのシークエンス長を5、ターゲットのシークエンス長を6、バッチ数を2とするとsrc
, tgt
は次のような感じです。
[src/tgt/memory]_mask
について
[src/tgt/memory]_mask
は加法性マスク(additive mask)です。加法性マスクは次のコードのようにクエリとキーから計算されるAttention Weightに加算してAttention weightを変更するために使われます。
# _scaled_dot_product_attention関数はクエリとキーから計算されるAttention Weightと加法性マスクの和を計算し、
# その和の重みに従ってバリューから情報を取り出します。
# attn_maskが加法性マスクです。
B, Nt, E = q.shape
q = q / math.sqrt(E)
# (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
attn = torch.bmm(q, k.transpose(-2, -1))
if attn_mask is not None:
attn += attn_mask
attn = softmax(attn, dim=-1)
if dropout_p > 0.0:
attn = dropout(attn, p=dropout_p)
# (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
output = torch.bmm(attn, v)
したがって、[src/tgt/memory]_mask
にAttention Weightと同じサイズ(クエリのシークエンス長、キーのシークエンス長)
のテンソル、もしくは(バッチ数×ヘッド数、クエリのシークエンス長、キーのシークエンス長)
のテンソルを渡します。(バッチ数×ヘッド数、クエリのシークエンス長、キーのシークエンス長)
のテンソルを渡した場合は、各バッチや各ヘッドに違うマスクを使うことができます。また[src/tgt/memory]_mask
のデフォルト値はNone
です。None
の場合は上記のコードのとおり、Attention weightを変更しません。
src_mask
はクエリとキーがソースのSelf Attentionに使われます。このAttention weightのサイズは(S,S)
であるので、src_mask
のサイズは(S,S)
です。同様にtgt_mask
はクエリとキーがターゲットのSelf Attentionに使われるので、そのサイズは(T, T)
です。memory_mask
はクエリがターゲット、キーがソースのSource-Target Attentionで使われます。したがってmemory_mask
のサイズは(T, S)
です。
[src/tgt/memory]_mask
のdtype
はクエリと同じdtype
、もしくはbool
です。クエリと同じdtype
の場合、そのままAttention Weightに加算されます。bool
の場合、bool
型のマスクは次のコードのようにクエリと同じdtype
のマスクに変換されてから加算されます。
# multi_head_attention_forward関数の加法性マスクのdtypeを変更する箇所
# Trueがfloat("-inf")に、Falseが0.0に置き換えられる
if attn_mask is not None and attn_mask.dtype == torch.bool:
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
attn_mask = new_attn_mask
加法性マスクはTransformerに予測する単語を使わないようにするためのマスクです。訓練時は入力するシークエンスは正解のデータを持っているので、正解を隠すためでもあります。マスクの仕方はAttention weightのマスクする場所にfloat("-inf")
またはTrue
、マスクしない場所に0.0
またはFalse
を指定します。Attention weightとの加算後にsoftmaxを適用するので、大きな負の値をもつ要素は注目を受けなくなります。
言語翻訳の例ではソースであるドイツ語の文章は翻訳する文章なのでシークエンスのすべての位置を参照できます。したがってsrc_mask
に要素がすべてFalseであるテンソルを使います。
>>> S = 5
>>> src_mask = torch.zeros((S, S), dtype=bool)
一方で、ターゲットである英語の文章はモデルが予測する文章です。訓練時は正解である文章をターゲットにいれて、tgt
に含まれている各単語の次の単語を予測します。したがって各単語からそれ以降の単語を使わないようにtgt_mask
に次のようなマスクを設定して英語の文章の単語間のSelf Attentionに制限をかけます。
>>> T = 6
>>> tgt_mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
tensor([[0., -inf, -inf, -inf, -inf, -inf],
[0., 0., -inf, -inf, -inf, -inf],
[0., 0., 0., -inf, -inf, -inf],
[0., 0., 0., 0., -inf, -inf],
[0., 0., 0., 0., 0., -inf],
[0., 0., 0., 0., 0., 0.]])
推論時は、Greedyアルゴリズムの場合、再帰的に最初の単語から最後の単語まで一つづつ予測していきます(コード)。具体的には最初の推論でBOS
(文章の始まり)のみを含む文章をターゲットとして、2個目の単語を予測します。そして予測した単語を含めて再度モデルを実行し、3個目の単語を予測します。このときモデルはソースであるドイツ語の文章と1〜2個目の英単語から3個目を予測することになります。これを繰り返し、モデルがEOS
(文章の終わり)を出力するまで繰り返します。何番目の単語を予測するかで文章の長さが変わるのでtgt_mask
は文章の長さを同じになるように値を設定します。
# 2個目の予測
>>> tgt_idx = torch.ones(1, 1).fill_(BOS).type(torch.long)
tensor([[2]])
>>> tgt_mask = torch.triu(torch.full((1, 1), float('-inf')), diagonal=1).type(torch.bool)
tensor([[False]])
#3個目の予測(2個目の結果が7とします)
>>> tgt_idx = torch.cat([tgt_idx, torch.ones(1, 1).type_as(tgt_idx.data).fill_(7)], dim=0)
tensor([[ 2],
[ 7]])
>>> tgt_mask = torch.triu(torch.full((2, 2), float('-inf')), diagonal=1).type(torch.bool)
tensor([[False, True],
[False, False]])
最後にmemory_mask
ですが、要素がすべてFalseであるテンソルを使います。これはキーやバリューであるドイツ語の文章をエンコーダによって処理したシークエンスの全体が利用できるためです。例えば日本語と英語は動詞の場所が異なります。ドイツ語から英語の翻訳においても、冒頭の単語を予測するとき、翻訳元の最後の単語を参照することは意味があると思います(私はドイツ語についての知識はありません)。
[src/tgt/memory]_key_padding_mask
について
[src/tgt/memory]_key_padding_mask
はAttentionに入力するキーのシークエンスのpaddingの位置を示すkey_paddingマスクです。key_paddingマスクは次のコードのように変形され、加法性マスクにマージされます。
# multi_head_attention_forward関数のkey_paddingマスクを変形する箇所
key_padding_mask = (
key_padding_mask.view(N, 1, 1, key_seq_length)
.expand(-1, num_heads, -1, -1)
.reshape(N * num_heads, 1, key_seq_length)
)
# 加法性マスクとkey_paddingマスクをマージする部分
# attn_maskは加法性マスクです。Noneでなければ、
# サイズは(1, query_seq_length, key_seq_length)
# もしくは(N * num_heads, query_seq_length, key_seq_length)
if attn_mask is None:
attn_mask = key_padding_mask
elif attn_mask.dtype == torch.bool:
attn_mask = attn_mask.logical_or(key_padding_mask)
else:
attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
この記事を作るために参考にさせていただいた「Transformerのデータの流れを追ってみる」ではこのマージ操作のわかりやすい図がありますので、ぜひ見てください。
key_paddingマスクのサイズはバッチ内の各シークエンスに対してpaddingの位置を示すように(バッチ数、シークエンス長)
です。具体的には[src/tgt/memory]_key_padding_mask
のサイズはそれぞれ(N, S)
、(N, T)
、(N, S)
です。またdtype
はbool
です。paddingであればTrue, そうでなければFalseの値をマスクに指定します。
key_paddingマスクは加法性マスクと同様にAttention weightを変更するマスクです。しかし加法性マスクと違い、バリューからpaddingの情報を抽出しないようにするために使います。同じシークエンスでもpaddingの有り無しでその結果が変わらないように、Transformerにシークエンス中のどこの要素を使っていいか伝えます。
言語翻訳の例ではTransformerのエンコーダのSelf AttentionとデコーダのSource-Target Attentionのキーはドイツ語の文章です。したがってsrc_key_padding_mask
とmemory_key_padding_mask
にドイツ語のpaddingの位置を示したテンソルを指定します。
>>> src_key_padding_mask = memory_key_padding_mask = (src == PAD_IDX).transpose(0, 1)
tensor([[False, False, False, False, True],
[False, False, False, False, True]])
一方でTransformerのデコーダのSelf Attentionのキーは英語の文章なので、tgt_key_padding_mask
に英語のパッディングの位置を示したテンソルを指定します。
>>> tgt_key_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
tensor([[False, False, False, False, True, True],
[False, False, False, False, False, True]])
最後に
読んでいただきありがとうございます。言語翻訳はチュートリアルを動かしただけなので、間違いや補足、怪しいところがあればぜひ指摘してください。また文体や構成など意味をなさないところが有りましたら、ぜひそれも指摘していただけると助かります。