はじめに
Transformer のことを勉強していたら、Scaled dot product attention の q と k の内積は、q と k の類似度を見ているという説明がありました。類似度なら、内積以外にも計算方法があるぞと思い、コサイン類似度とユークリッド距離を試してみました。コサイン類似度は距離の情報がなくなるということで学習には使えないようです。ユークリッド距離の逆数を使うと学習することが分かったので、ご報告させていただきます。
Scaled Euclid Distance Attention クラス
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, dim_hidden: int, n_head: int, dropout: float=0.1, qkv_bias: bool=False):
super().__init__()
self.n_head = n_head
self.proj_q = nn.Linear( dim_hidden, dim_hidden, bias=qkv_bias)
self.proj_k = nn.Linear( dim_hidden, dim_hidden, bias=qkv_bias)
self.proj_v = nn.Linear( dim_hidden, dim_hidden, bias=qkv_bias)
self.proj_out = nn.Linear(dim_hidden, dim_hidden)
self.qkv_attention = ScaledEuclidDistanceAttention( dim_hidden, n_head, dropout )
def forward( self,query,key,value, attn_mask = None ):
q = self.proj_q( query )
k = self.proj_k( key )
v = self.proj_v( value )
output = self.qkv_attention(q, k, v, attn_mask)
output = self.proj_out( output )
return output
class ScaledEuclidDistanceAttention(nn.Module):
def __init__(self, dim_hidden: int, num_heads: int, dropout: float=0.1):
super().__init__()
assert dim_hidden % num_heads == 0
self.num_heads = num_heads
dim_head = dim_hidden // num_heads
# Scale Value of Softmax
self.scale = dim_head ** -0.5
self.dropout = nn.Dropout(dropout)
def forward(self, q, k, v, mask=None ):
q = q.view( q.size(0), q.size(1), self.num_heads, -1 ).permute( 0,2,1,3 )
k = k.view( k.size(0), k.size(1), self.num_heads, -1 ).permute( 0,2,1,3 )
v = v.view( v.size(0), v.size(1), self.num_heads, -1 ).permute( 0,2,1,3 )
qk_dis = torch.cdist( q, k, p = 2 ) * self.scale # cdist is distance function of pytorch
attn = 1 / ( qk_dis + 1e-9 )
# Check of learning was carried out with mask = None
if mask is not None:
attn = attn + torch.unsqueeze( torch.unsqueeze( mask, dim = 0 ), dim = 0 ).to(torch.float16) * -1e9
attn = ( attn ).softmax(dim=-1)
attn = self.dropout( attn )
x = attn.matmul(v)
x = x.permute(0, 2, 1, 3).flatten(2)
return x
if __name__ == "__main__":
num_batch = 8
q_seq = 300
k_seq = 100
dim_hidden = 512
num_heads = 8
func2 = MultiHeadAttention( dim_hidden, num_heads )
q = torch.randn( ( num_batch, q_seq, dim_hidden))
k = torch.randn( ( num_batch, k_seq, dim_hidden))
v = k
mask = torch.randint(low=0, high=2, size=(q_seq, k_seq)).to( torch.bool )
x = func2( q,k,v,mask)
print( x.size() )
Vision Transformer での確認。
こちらでの確認を報告します。画像分類のプログラムのTransformer Encoder の Self Attention にユークリッド距離の attention を使って試したのですが、学習がすすみました。内積だと1エポック12分程度(cpu)で、ユークリッド距離だと14分程度(cpu)です。この確認のための学習だけでモデルの良さの結論は出せません。「Python で学ぶ画像認識」という本のプログラムで、q と k の類似度に通常の内積を使った場合、ある画像分類の問題のテストデータの正解率が 63.8% であったと本に書いてあります。本のプログラムで内積の代わりにユークリッド距離の逆数を使った確認のための学習では、同じ問題について、正解率が 66.7 %になりました。
画像分類プログラムで学習確認したときの Loss の推移
Accuracy の推移
機械翻訳プログラムでの確認
機械翻訳の Transformer ですが、プログラム的にはTransformer Decoder の Self Attention でも Cross Attention でも通ります。また、機械翻訳のプログラムで、Transformer Encoder と Transformer Decoder にユークリッド距離による attention を使って学習が進むことを確認しました。
Github
この案件は、英語でも報告したほうが良さそうなので、Github でつたない英語ですが報告しておきます。
2024年3月8日追記、image captioning の Transforemer Decdoer で動作確認
時間がとれたので、既出の本「Python で学ぶ画像認識」の p. 331 で述べられている Transformer を使った image captioning のプログラムの MultiHeadAttention モジュールをこのページで提案しているモジュールに修正して学習を行いました。キャプショニングの結果を示します。
<start> a dog is playing with a frisbee in its mouth <end>
<start> a herd of giraffes are standing in a field <end>
<start> a plane flying over a plane flying over a plane <end>
<start> a car is parked on a car on the street <end>
<start> a man is surfing on a wave on a wave <end>
<start> a group of people riding on a beach <end>
<start> a cat is sitting on a colorful rug <end>
<start> a bowl of fruit and a bowl of fruit <end>
<start> a woman is riding a beach with a beach <end>
<start> a man is playing with a frisbee in the water <end>
<start> a baseball player is swinging a baseball game <end>
<start> a man riding skis down a snow covered slope <end>
<start> three men are standing on a beach with surfboards <end>
<start> a large clock tower towering over a city street <end>
<start> a man riding a skateboard on a skateboard <end>
<start> a bird perched on a branch in a tree branch <end>
<start> a man in a man in a bar <end>