33
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Pytorchのnn.Transformerのforward関数を理解する

Last updated at Posted at 2021-12-22

この記事について

Transformerは自然言語処理や様々な分野で使われており、いろいろな方がわかりやすく解説されています。私もこれらの記事に触発され、使ってみようとしましたが、実際に使おうとするとPytorchのnn.Transformerforward関数の引数が多く混乱しました:eyes:。特にmaskの引数の意味がドキュメントを見るだけではわかりませんでした。そこでこれらの引数について調べた結果をまとめました。またPytorchのチュートリアルであるTransformerを使った言語翻訳(ドイツ語から英語)を具体例にしました。記事の作成に当たって、主に以下の情報を参考にして執筆しました。

nn.Transformer.forwardの引数

nn.Transformer(リンク)のforward関数は次の8個の引数を持っています。

  1. src: ソースのシークエンス
  2. tgt: ターゲットのシークエンス
  3. src_mask: ソースの加法性マスク
  4. src_key_padding_mask: ソースのキーのマスク
  5. tgt_mask: ターゲットの加法性マスク
  6. tgt_key_padding_mask: ターゲットのキーのマスク
  7. memory_mask: メモリの加法性マスク
  8. memory_key_padding_mask: メモリのキーのマスク

そしてforward関数は次の値を返します。

  1. output: Transformerが計算したシークエンス

結論から述べますと、これらの引数と「Attention is All You Need」のモデルの対応関係は次のようになっています。nn.Transformerはデコーダの最初のAttentionだけでなく、すべてのAttentionに2つずつマスクが設定できる設計となっていることがわかります。

parameter_and_arch.png

引数 対応するnn.MultiHeadAttention
src_maskおよび src_key_padding_mask TransformerのエンコーダのSelf Attention
tgt_maskおよびtgt_key_padding_mask TransformerのデコーダのSelf Attention
memory_maskおよび memory_key_padding_mask TransformerのデコーダのSource-Target Attention

以降は、各引数の説明をしていきます。

src, tgt, outputについて

srcはTransformerのエンコーダで処理するソースのシークエンスです。シークエンス長S、バッチサイズN、デプスEのテンソルを渡します。tgtはデコーダで処理するターゲットのシークエンスです。シークエンス長T、バッチサイズN、デプスEのテンソルを渡します。outputtgtと同じサイズのテンソルで、Transformerが計算した結果です。

言語翻訳の例だとsrcはドイツ語の文章、tgtは英語の文章です。src, tgtともにPositional Encodingを加えたEmbeddingで、複数の文章をバッチするためにシークエンス中にpaddingを含みます。outputはLinear層とsoftmax層で翻訳した結果を計算するために使われます。ソースのシークエンス長を5、ターゲットのシークエンス長を6、バッチ数を2とするとsrc, tgtは次のような感じです。

src_tgt.png

[src/tgt/memory]_maskについて

[src/tgt/memory]_maskは加法性マスク(additive mask)です。加法性マスクは次のコードのようにクエリとキーから計算されるAttention Weightに加算してAttention weightを変更するために使われます。

# _scaled_dot_product_attention関数はクエリとキーから計算されるAttention Weightと加法性マスクの和を計算し、
# その和の重みに従ってバリューから情報を取り出します。
# attn_maskが加法性マスクです。
B, Nt, E = q.shape
q = q / math.sqrt(E)
# (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
attn = torch.bmm(q, k.transpose(-2, -1))
if attn_mask is not None:
    attn += attn_mask
attn = softmax(attn, dim=-1)
if dropout_p > 0.0:
   attn = dropout(attn, p=dropout_p)
# (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
output = torch.bmm(attn, v)

したがって、[src/tgt/memory]_maskにAttention Weightと同じサイズ(クエリのシークエンス長、キーのシークエンス長)のテンソル、もしくは(バッチ数×ヘッド数、クエリのシークエンス長、キーのシークエンス長)のテンソルを渡します。(バッチ数×ヘッド数、クエリのシークエンス長、キーのシークエンス長)のテンソルを渡した場合は、各バッチや各ヘッドに違うマスクを使うことができます。また[src/tgt/memory]_maskのデフォルト値はNoneです。Noneの場合は上記のコードのとおり、Attention weightを変更しません。

src_maskはクエリとキーがソースのSelf Attentionに使われます。このAttention weightのサイズは(S,S)であるので、src_maskのサイズは(S,S)です。同様にtgt_maskはクエリとキーがターゲットのSelf Attentionに使われるので、そのサイズは(T, T)です。memory_maskはクエリがターゲット、キーがソースのSource-Target Attentionで使われます。したがってmemory_maskのサイズは(T, S)です。

[src/tgt/memory]_maskdtypeはクエリと同じdtype、もしくはboolです。クエリと同じdtypeの場合、そのままAttention Weightに加算されます。boolの場合、bool型のマスクは次のコードのようにクエリと同じdtypeのマスクに変換されてから加算されます。

# multi_head_attention_forward関数の加法性マスクのdtypeを変更する箇所
# Trueがfloat("-inf")に、Falseが0.0に置き換えられる
if attn_mask is not None and attn_mask.dtype == torch.bool:
    new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
    new_attn_mask.masked_fill_(attn_mask, float("-inf"))
    attn_mask = new_attn_mask

加法性マスクはTransformerに予測する単語を使わないようにするためのマスクです。訓練時は入力するシークエンスは正解のデータを持っているので、正解を隠すためでもあります。マスクの仕方はAttention weightのマスクする場所にfloat("-inf")またはTrue、マスクしない場所に0.0またはFalseを指定します。Attention weightとの加算後にsoftmaxを適用するので、大きな負の値をもつ要素は注目を受けなくなります。

言語翻訳の例ではソースであるドイツ語の文章は翻訳する文章なのでシークエンスのすべての位置を参照できます。したがってsrc_maskに要素がすべてFalseであるテンソルを使います。

>>> S = 5
>>> src_mask = torch.zeros((S, S), dtype=bool)

一方で、ターゲットである英語の文章はモデルが予測する文章です。訓練時は正解である文章をターゲットにいれて、tgtに含まれている各単語の次の単語を予測します。したがって各単語からそれ以降の単語を使わないようにtgt_maskに次のようなマスクを設定して英語の文章の単語間のSelf Attentionに制限をかけます。

>>> T = 6
>>> tgt_mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
tensor([[0., -inf, -inf, -inf, -inf, -inf],
        [0.,   0., -inf, -inf, -inf, -inf],
        [0.,   0.,   0., -inf, -inf, -inf],
        [0.,   0.,   0.,   0., -inf, -inf],
        [0.,   0.,   0.,   0.,   0., -inf],
        [0.,   0.,   0.,   0.,   0.,   0.]])

推論時は、Greedyアルゴリズムの場合、再帰的に最初の単語から最後の単語まで一つづつ予測していきます(コード)。具体的には最初の推論でBOS(文章の始まり)のみを含む文章をターゲットとして、2個目の単語を予測します。そして予測した単語を含めて再度モデルを実行し、3個目の単語を予測します。このときモデルはソースであるドイツ語の文章と1〜2個目の英単語から3個目を予測することになります。これを繰り返し、モデルがEOS(文章の終わり)を出力するまで繰り返します。何番目の単語を予測するかで文章の長さが変わるのでtgt_maskは文章の長さを同じになるように値を設定します。

# 2個目の予測
>>> tgt_idx = torch.ones(1, 1).fill_(BOS).type(torch.long)
tensor([[2]])
>>> tgt_mask = torch.triu(torch.full((1, 1), float('-inf')), diagonal=1).type(torch.bool)
tensor([[False]])
#3個目の予測(2個目の結果が7とします)
>>> tgt_idx = torch.cat([tgt_idx, torch.ones(1, 1).type_as(tgt_idx.data).fill_(7)], dim=0)
tensor([[ 2],
        [ 7]])
>>> tgt_mask = torch.triu(torch.full((2, 2), float('-inf')), diagonal=1).type(torch.bool)
tensor([[False,  True],
        [False, False]])

最後にmemory_maskですが、要素がすべてFalseであるテンソルを使います。これはキーやバリューであるドイツ語の文章をエンコーダによって処理したシークエンスの全体が利用できるためです。例えば日本語と英語は動詞の場所が異なります。ドイツ語から英語の翻訳においても、冒頭の単語を予測するとき、翻訳元の最後の単語を参照することは意味があると思います:thinking:(私はドイツ語についての知識はありません)。

[src/tgt/memory]_key_padding_maskについて

[src/tgt/memory]_key_padding_maskはAttentionに入力するキーのシークエンスのpaddingの位置を示すkey_paddingマスクです。key_paddingマスクは次のコードのように変形され、加法性マスクにマージされます。

# multi_head_attention_forward関数のkey_paddingマスクを変形する箇所
key_padding_mask = (
    key_padding_mask.view(N, 1, 1, key_seq_length)
    .expand(-1, num_heads, -1, -1)
    .reshape(N * num_heads, 1, key_seq_length)
)
# 加法性マスクとkey_paddingマスクをマージする部分
# attn_maskは加法性マスクです。Noneでなければ、
# サイズは(1, query_seq_length, key_seq_length)
# もしくは(N * num_heads, query_seq_length, key_seq_length)
if attn_mask is None:
    attn_mask = key_padding_mask
elif attn_mask.dtype == torch.bool:
    attn_mask = attn_mask.logical_or(key_padding_mask)
else:
    attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))

この記事を作るために参考にさせていただいた「Transformerのデータの流れを追ってみる」ではこのマージ操作のわかりやすい図がありますので、ぜひ見てください。

key_paddingマスクのサイズはバッチ内の各シークエンスに対してpaddingの位置を示すように(バッチ数、シークエンス長)です。具体的には[src/tgt/memory]_key_padding_maskのサイズはそれぞれ(N, S)(N, T)(N, S)です。またdtypeboolです。paddingであればTrue, そうでなければFalseの値をマスクに指定します。

key_paddingマスクは加法性マスクと同様にAttention weightを変更するマスクです。しかし加法性マスクと違い、バリューからpaddingの情報を抽出しないようにするために使います。同じシークエンスでもpaddingの有り無しでその結果が変わらないように、Transformerにシークエンス中のどこの要素を使っていいか伝えます。

言語翻訳の例ではTransformerのエンコーダのSelf AttentionとデコーダのSource-Target Attentionのキーはドイツ語の文章です。したがってsrc_key_padding_maskmemory_key_padding_maskにドイツ語のpaddingの位置を示したテンソルを指定します。

>>> src_key_padding_mask = memory_key_padding_mask = (src == PAD_IDX).transpose(0, 1)
tensor([[False, False, False, False,  True],
        [False, False, False, False,  True]])

一方でTransformerのデコーダのSelf Attentionのキーは英語の文章なので、tgt_key_padding_maskに英語のパッディングの位置を示したテンソルを指定します。

>>> tgt_key_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
tensor([[False, False, False, False,  True,  True],
        [False, False, False, False, False,  True]])

最後に

読んでいただきありがとうございます。言語翻訳はチュートリアルを動かしただけなので、間違いや補足、怪しいところがあればぜひ指摘してください:pray:。また文体や構成など意味をなさないところが有りましたら、ぜひそれも指摘していただけると助かります。

33
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
33
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?