1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

はじめての記事投稿
Qiita Engineer Festa20242024年7月17日まで開催中!

torch.nn.MultiheadAttentionのkey_padding_mask/attn_maskとは何か

Posted at

まえがき

業務でtorch.nn.MultiheadAttentionを使用する機会があったが、key_padding_mask/attn_maskの扱いが分からず苦労したので、何のために使うのか、どう使うのかをまとめる。

本家のドキュメントは以下

結論

  • key_padding_mask
    • AttentionのKey側にPaddingが入っている場合、そのPaddingの箇所を計算に含めないように制御するためのマスク
  • attn_mask
    • Self-Attentionで未来の情報を参照しないように制御するためのマスク

key_padding_mask

なんのために使うのか

ドキュメントでは以下のように説明されている(ChatGPTさんによる翻訳)。

指定された場合、形状が (N, S) のマスクは、アテンションの目的で key 内のどの要素を無視するか
(つまり「パディング」として扱うか)を示します。バッチ処理されていないクエリの場合、形状は (S) 
である必要があります。バイナリマスクとフロートマスクがサポートされています。バイナリマスクでは、
True の値は対応する key の値がアテンションの目的で無視されることを示します。フロートマスクの場合、
それは対応する key の値に直接加算されます。

大抵の場合はattention weightの計算時にPaddingされた箇所を計算から除外するために使用する。

時系列データを扱う場合時系列データの長さが違うとミニバッチ処理ができない。
それを解決するためにPaddingによってダミーデータを入れることでデータの長さを揃えている。

無題のプレゼンテーション.png

無題のプレゼンテーション (1).png

ただし、このままattention weightを計算しようとするとダミーデータが計算に含まれてしまい、意図した計算にならない。
別に無視をすればいいじゃないかと思われるが、attention weightの計算ではQueryとKeyの内積後に、行方向にSoftmaxの計算が入っているのでそうもいかない。

そこでkey_padding_maskの登場である。
key_padding_maskによりPADの部分を-Infに置き換えてしまう。
これにより、PADの部分だけSoftmaxの結果が小さくなり影響を軽減できる。

どう使うのか

PADの位置がTrueになる(バッチサイズ, データ長)のマスクを生成しkey_padding_maskに設定すればよい。

PAD = 1000
data = torch.Tensor([
    [1,2,3,4,PAD],
    [5,6,7,PAD,PAD],
    [8,9,10,11,12]
    ]).unsqueeze(-1)

print(data.shape)
# torch.Size([3, 5, 1])

mask = data==PAD
mask = mask[:,:,0]
print(mask.shape)
# torch.Size([3, 5])

print(mask)
#tensor([[False, False, False, False,  True],
#       [False, False, False,  True,  True],
#       [False, False, False, False, False]])

実験

key_paddin_maskを設定しない場合結果を確認する。

attention = nn.MultiheadAttention(embed_dim=1, num_heads=1, batch_first=True)
_, attn_weight = attention(data, data, data)

print(attn_weight)
# tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
#         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
#         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
#         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
#         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00]],
#
#        [[0.0000e+00, 0.0000e+00, 0.0000e+00, 5.0000e-01, 5.0000e-01],
#         [0.0000e+00, 0.0000e+00, 0.0000e+00, 5.0000e-01, 5.0000e-01],
#         [0.0000e+00, 0.0000e+00, 0.0000e+00, 5.0000e-01, 5.0000e-01],
#         [0.0000e+00, 0.0000e+00, 0.0000e+00, 5.0000e-01, 5.0000e-01],
#         [0.0000e+00, 0.0000e+00, 0.0000e+00, 5.0000e-01, 5.0000e-01]],
#
#        [[1.1341e-04, 1.0686e-03, 1.0069e-02, 9.4870e-02, 8.9388e-01],
#         [3.8020e-05, 4.7417e-04, 5.9137e-03, 7.3753e-02, 9.1982e-01],
#         [1.2650e-05, 2.0883e-04, 3.4473e-03, 5.6908e-02, 9.3942e-01],
#         [4.1861e-06, 9.1469e-05, 1.9986e-03, 4.3671e-02, 9.5423e-01],
#         [1.3797e-06, 3.9905e-05, 1.1541e-03, 3.3380e-02, 9.6542e-01]]],
#       grad_fn=<MeanBackward1>)

PADの値が大きいためその影響を受けていることが分かる。

key_padding_maskを設定すると以下のようになる。

_, attn_weight = attention(data, data, data, key_padding_mask=mask)

print(attn_weight)
# tensor([[[1.5638e-01, 2.0699e-01, 2.7398e-01, 3.6265e-01, 0.0000e+00],
#         [8.9290e-02, 1.5644e-01, 2.7408e-01, 4.8019e-01, 0.0000e+00],
#         [4.7240e-02, 1.0955e-01, 2.5405e-01, 5.8916e-01, 0.0000e+00],
#         [2.3577e-02, 7.2372e-02, 2.2215e-01, 6.8190e-01, 0.0000e+00],
#         [0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00]],
#
#        [[4.6359e-02, 1.8836e-01, 7.6529e-01, 0.0000e+00, 0.0000e+00],
#         [2.8329e-02, 1.5235e-01, 8.1932e-01, 0.0000e+00, 0.0000e+00],
#         [1.7010e-02, 1.2108e-01, 8.6191e-01, 0.0000e+00, 0.0000e+00],
#         [0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
#         [0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00]],
#
#        [[1.1341e-04, 1.0686e-03, 1.0069e-02, 9.4870e-02, 8.9388e-01],
#         [3.8020e-05, 4.7417e-04, 5.9137e-03, 7.3753e-02, 9.1982e-01],
#         [1.2650e-05, 2.0883e-04, 3.4473e-03, 5.6908e-02, 9.3942e-01],
#         [4.1861e-06, 9.1469e-05, 1.9986e-03, 4.3671e-02, 9.5423e-01],
#         [1.3797e-06, 3.9905e-05, 1.1541e-03, 3.3380e-02, 9.6542e-01]]],
#       grad_fn=<MeanBackward1>)

maskされた部分は計算から除外されている(-Infで置き換えられている)ため0になっている。

attn_mask

なんのために使うのか

ドキュメントでは以下のように説明されている(ChatGPTさんによる翻訳)。

指定された場合、特定の位置へのアテンションを防ぐための2Dまたは3Dのマスクを使用します。
形状は (L, S) または (N⋅num_heads, L, S) でなければなりません。ここで、N はバッチサイズ、
L はターゲットシーケンスの長さ、S はソースシーケンスの長さです。2Dマスクはバッチ全体に対し
てブロードキャストされますが、3Dマスクはバッチ内の各エントリごとに異なるマスクを許可します。
バイナリマスクとフロートマスクがサポートされています。バイナリマスクでは、True の値は対応
する位置がアテンションされないことを示します。フロートマスクの場合、
マスクの値はアテンションウェイトに加算されます。attn_mask と key_padding_mask の両方が
提供される場合、それらの型は一致している必要があります。

大抵の場合、学習時に未来の情報を参照しないようにマスクする際に使用する。

言語モデルの学習の場合、文章を単語に分割してからまとめてTransformerに入力する。その後Self-Attentionで単語間の関係を学習するのだが、そのまま入力すると本来知りえない未来の単語との関係性も学習してしまう。これを防ぐために未来の情報をマスクする必要がある。

無題のプレゼンテーション (3).png

どう使うのか

torch.tril()を使うと簡単にマスクを作成できる。
このマスクをattn_maskに設定すればよい。

PAD = 1000
data = torch.Tensor([
    [1,1,1,1,1],
    ]).unsqueeze(-1)
data.shape
# torch.Size([1, 5, 1])

attn_mask = torch.logical_not(torch.tril(torch.ones(data.shape[1], data.shape[1])))
print(attn_mask)
# tensor([[False,  True,  True,  True,  True],
#        [False, False,  True,  True,  True],
#        [False, False, False,  True,  True],
#        [False, False, False, False,  True],
#        [False, False, False, False, False]])

実験

attn_maskをしない場合の結果を確認する。

attention = nn.MultiheadAttention(embed_dim=1, num_heads=1, batch_first=True)
_, attn_weight = attention(data, data, data)
print(attn_weight)
# tensor([[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
#         [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
#         [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
#         [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
#         [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]], grad_fn=<MeanBackward1>)

すべて1のデータを使用しており、未来の情報も見ているのでattn_weightの値はすべて同じ値になっている。

attn_maskを設定すると以下のようになる。

_, attn_weight = attention(data, data, data, attn_mask=attn_mask)
print(attn_weight)
# tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
#         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
#         [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
#         [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]], grad_fn=<MeanBackward1>)

未来の情報はマスクされているので0になっており計算に含まれていない。

まとめ

  • key_padding_mask
    • AttentionのKey側にPaddingが入っている場合、そのPaddingの箇所を計算に含めないように制御するためのマスク
  • attn_mask
    • Self-Attentionで未来の情報を参照しないように制御するためのマスク

二つのマスクを適切に設定することで、効率よく学習することが出来る。

参考文献

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?