まえがき
業務でtorch.nn.MultiheadAttentionを使用する機会があったが、key_padding_mask
/attn_mask
の扱いが分からず苦労したので、何のために使うのか、どう使うのかをまとめる。
本家のドキュメントは以下
結論
- key_padding_mask
- AttentionのKey側にPaddingが入っている場合、そのPaddingの箇所を計算に含めないように制御するためのマスク
- attn_mask
- Self-Attentionで未来の情報を参照しないように制御するためのマスク
key_padding_mask
なんのために使うのか
ドキュメントでは以下のように説明されている(ChatGPTさんによる翻訳)。
指定された場合、形状が (N, S) のマスクは、アテンションの目的で key 内のどの要素を無視するか
(つまり「パディング」として扱うか)を示します。バッチ処理されていないクエリの場合、形状は (S)
である必要があります。バイナリマスクとフロートマスクがサポートされています。バイナリマスクでは、
True の値は対応する key の値がアテンションの目的で無視されることを示します。フロートマスクの場合、
それは対応する key の値に直接加算されます。
大抵の場合はattention weightの計算時にPaddingされた箇所を計算から除外するために使用する。
時系列データを扱う場合時系列データの長さが違うとミニバッチ処理ができない。
それを解決するためにPaddingによってダミーデータを入れることでデータの長さを揃えている。
ただし、このままattention weightを計算しようとするとダミーデータが計算に含まれてしまい、意図した計算にならない。
別に無視をすればいいじゃないかと思われるが、attention weightの計算ではQueryとKeyの内積後に、行方向にSoftmaxの計算が入っているのでそうもいかない。
そこでkey_padding_maskの登場である。
key_padding_maskによりPADの部分を-Infに置き換えてしまう。
これにより、PADの部分だけSoftmaxの結果が小さくなり影響を軽減できる。
どう使うのか
PADの位置がTrueになる(バッチサイズ, データ長)のマスクを生成しkey_padding_maskに設定すればよい。
PAD = 1000
data = torch.Tensor([
[1,2,3,4,PAD],
[5,6,7,PAD,PAD],
[8,9,10,11,12]
]).unsqueeze(-1)
print(data.shape)
# torch.Size([3, 5, 1])
mask = data==PAD
mask = mask[:,:,0]
print(mask.shape)
# torch.Size([3, 5])
print(mask)
#tensor([[False, False, False, False, True],
# [False, False, False, True, True],
# [False, False, False, False, False]])
実験
key_paddin_maskを設定しない場合結果を確認する。
attention = nn.MultiheadAttention(embed_dim=1, num_heads=1, batch_first=True)
_, attn_weight = attention(data, data, data)
print(attn_weight)
# tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
# [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
# [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
# [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
# [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00]],
#
# [[0.0000e+00, 0.0000e+00, 0.0000e+00, 5.0000e-01, 5.0000e-01],
# [0.0000e+00, 0.0000e+00, 0.0000e+00, 5.0000e-01, 5.0000e-01],
# [0.0000e+00, 0.0000e+00, 0.0000e+00, 5.0000e-01, 5.0000e-01],
# [0.0000e+00, 0.0000e+00, 0.0000e+00, 5.0000e-01, 5.0000e-01],
# [0.0000e+00, 0.0000e+00, 0.0000e+00, 5.0000e-01, 5.0000e-01]],
#
# [[1.1341e-04, 1.0686e-03, 1.0069e-02, 9.4870e-02, 8.9388e-01],
# [3.8020e-05, 4.7417e-04, 5.9137e-03, 7.3753e-02, 9.1982e-01],
# [1.2650e-05, 2.0883e-04, 3.4473e-03, 5.6908e-02, 9.3942e-01],
# [4.1861e-06, 9.1469e-05, 1.9986e-03, 4.3671e-02, 9.5423e-01],
# [1.3797e-06, 3.9905e-05, 1.1541e-03, 3.3380e-02, 9.6542e-01]]],
# grad_fn=<MeanBackward1>)
PADの値が大きいためその影響を受けていることが分かる。
key_padding_maskを設定すると以下のようになる。
_, attn_weight = attention(data, data, data, key_padding_mask=mask)
print(attn_weight)
# tensor([[[1.5638e-01, 2.0699e-01, 2.7398e-01, 3.6265e-01, 0.0000e+00],
# [8.9290e-02, 1.5644e-01, 2.7408e-01, 4.8019e-01, 0.0000e+00],
# [4.7240e-02, 1.0955e-01, 2.5405e-01, 5.8916e-01, 0.0000e+00],
# [2.3577e-02, 7.2372e-02, 2.2215e-01, 6.8190e-01, 0.0000e+00],
# [0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00]],
#
# [[4.6359e-02, 1.8836e-01, 7.6529e-01, 0.0000e+00, 0.0000e+00],
# [2.8329e-02, 1.5235e-01, 8.1932e-01, 0.0000e+00, 0.0000e+00],
# [1.7010e-02, 1.2108e-01, 8.6191e-01, 0.0000e+00, 0.0000e+00],
# [0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
# [0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00]],
#
# [[1.1341e-04, 1.0686e-03, 1.0069e-02, 9.4870e-02, 8.9388e-01],
# [3.8020e-05, 4.7417e-04, 5.9137e-03, 7.3753e-02, 9.1982e-01],
# [1.2650e-05, 2.0883e-04, 3.4473e-03, 5.6908e-02, 9.3942e-01],
# [4.1861e-06, 9.1469e-05, 1.9986e-03, 4.3671e-02, 9.5423e-01],
# [1.3797e-06, 3.9905e-05, 1.1541e-03, 3.3380e-02, 9.6542e-01]]],
# grad_fn=<MeanBackward1>)
maskされた部分は計算から除外されている(-Infで置き換えられている)ため0になっている。
attn_mask
なんのために使うのか
ドキュメントでは以下のように説明されている(ChatGPTさんによる翻訳)。
指定された場合、特定の位置へのアテンションを防ぐための2Dまたは3Dのマスクを使用します。
形状は (L, S) または (N⋅num_heads, L, S) でなければなりません。ここで、N はバッチサイズ、
L はターゲットシーケンスの長さ、S はソースシーケンスの長さです。2Dマスクはバッチ全体に対し
てブロードキャストされますが、3Dマスクはバッチ内の各エントリごとに異なるマスクを許可します。
バイナリマスクとフロートマスクがサポートされています。バイナリマスクでは、True の値は対応
する位置がアテンションされないことを示します。フロートマスクの場合、
マスクの値はアテンションウェイトに加算されます。attn_mask と key_padding_mask の両方が
提供される場合、それらの型は一致している必要があります。
大抵の場合、学習時に未来の情報を参照しないようにマスクする際に使用する。
言語モデルの学習の場合、文章を単語に分割してからまとめてTransformerに入力する。その後Self-Attentionで単語間の関係を学習するのだが、そのまま入力すると本来知りえない未来の単語との関係性も学習してしまう。これを防ぐために未来の情報をマスクする必要がある。
どう使うのか
torch.tril()を使うと簡単にマスクを作成できる。
このマスクをattn_maskに設定すればよい。
PAD = 1000
data = torch.Tensor([
[1,1,1,1,1],
]).unsqueeze(-1)
data.shape
# torch.Size([1, 5, 1])
attn_mask = torch.logical_not(torch.tril(torch.ones(data.shape[1], data.shape[1])))
print(attn_mask)
# tensor([[False, True, True, True, True],
# [False, False, True, True, True],
# [False, False, False, True, True],
# [False, False, False, False, True],
# [False, False, False, False, False]])
実験
attn_maskをしない場合の結果を確認する。
attention = nn.MultiheadAttention(embed_dim=1, num_heads=1, batch_first=True)
_, attn_weight = attention(data, data, data)
print(attn_weight)
# tensor([[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
# [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
# [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
# [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
# [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]], grad_fn=<MeanBackward1>)
すべて1のデータを使用しており、未来の情報も見ているのでattn_weightの値はすべて同じ値になっている。
attn_maskを設定すると以下のようになる。
_, attn_weight = attention(data, data, data, attn_mask=attn_mask)
print(attn_weight)
# tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
# [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
# [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
# [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
# [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]], grad_fn=<MeanBackward1>)
未来の情報はマスクされているので0になっており計算に含まれていない。
まとめ
- key_padding_mask
- AttentionのKey側にPaddingが入っている場合、そのPaddingの箇所を計算に含めないように制御するためのマスク
- attn_mask
- Self-Attentionで未来の情報を参照しないように制御するためのマスク
二つのマスクを適切に設定することで、効率よく学習することが出来る。