0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Scaled dot product attention 計算の q と k の内積を q と k のユークリッド距離の逆数にしても学習します。

Last updated at Posted at 2024-02-19

はじめに

Transformer のことを勉強していたら、Scaled dot product attention の q と k の内積は、q と k の類似度を見ているという説明がありました。類似度なら、内積以外にも計算方法があるぞと思い、コサイン類似度とユークリッド距離を試してみました。コサイン類似度は距離の情報がなくなるということで学習には使えないようです。ユークリッド距離の逆数を使うと学習することが分かったので、ご報告させていただきます。

Scaled Euclid Distance Attention クラス

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, dim_hidden: int, n_head: int, dropout: float=0.1, qkv_bias: bool=False):
        super().__init__()
        self.n_head = n_head

        self.proj_q = nn.Linear( dim_hidden, dim_hidden, bias=qkv_bias)
        self.proj_k = nn.Linear( dim_hidden, dim_hidden, bias=qkv_bias)
        self.proj_v = nn.Linear( dim_hidden, dim_hidden, bias=qkv_bias)
        self.proj_out = nn.Linear(dim_hidden, dim_hidden)

        self.qkv_attention = ScaledEuclidDistanceAttention( dim_hidden, n_head, dropout )
        
    def forward( self,query,key,value, attn_mask = None ):

        q = self.proj_q( query )
        k = self.proj_k( key )
        v = self.proj_v( value )
        
        output = self.qkv_attention(q, k, v, attn_mask)

        output = self.proj_out( output )
        
        return output 

    
class ScaledEuclidDistanceAttention(nn.Module):
    def __init__(self, dim_hidden: int, num_heads: int, dropout: float=0.1):
        super().__init__()

        assert dim_hidden % num_heads == 0

        self.num_heads = num_heads

        dim_head = dim_hidden // num_heads

        # Scale Value of Softmax
        self.scale = dim_head ** -0.5
        
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None ):

        q = q.view( q.size(0), q.size(1), self.num_heads, -1 ).permute( 0,2,1,3 )
        k = k.view( k.size(0), k.size(1), self.num_heads, -1 ).permute( 0,2,1,3 )
        v = v.view( v.size(0), v.size(1), self.num_heads, -1 ).permute( 0,2,1,3 )

        qk_dis = torch.cdist( q, k, p = 2 ) * self.scale # cdist is distance function of pytorch
        attn =  1 / ( qk_dis + 1e-9 )
        # Check of learning was carried out with mask = None
        if mask is not None:
            attn = attn + torch.unsqueeze( torch.unsqueeze( mask, dim = 0 ), dim = 0 ).to(torch.float16) * -1e9
        attn = ( attn ).softmax(dim=-1)  

        attn = self.dropout( attn )
        
        x = attn.matmul(v)

        x = x.permute(0, 2, 1, 3).flatten(2)

        return x    
if __name__ == "__main__":

    num_batch = 8
    q_seq = 300
    k_seq = 100
    dim_hidden = 512
    num_heads = 8

    func2 = MultiHeadAttention( dim_hidden, num_heads )

    q = torch.randn( ( num_batch, q_seq, dim_hidden))
    k = torch.randn( ( num_batch, k_seq, dim_hidden))
    v = k
    mask = torch.randint(low=0, high=2, size=(q_seq, k_seq)).to( torch.bool )

    x = func2( q,k,v,mask)

    print( x.size() )

Vision Transformer での確認。

こちらでの確認を報告します。画像分類のプログラムのTransformer Encoder の Self Attention にユークリッド距離の attention を使って試したのですが、学習がすすみました。内積だと1エポック12分程度(cpu)で、ユークリッド距離だと14分程度(cpu)です。この確認のための学習だけでモデルの良さの結論は出せません。「Python で学ぶ画像認識」という本のプログラムで、q と k の類似度に通常の内積を使った場合、ある画像分類の問題のテストデータの正解率が 63.8% であったと本に書いてあります。本のプログラムで内積の代わりにユークリッド距離の逆数を使った確認のための学習では、同じ問題について、正解率が 66.7 %になりました。

画像分類プログラムで学習確認したときの Loss の推移

Loss.png

Accuracy の推移

Accuracy.png

機械翻訳プログラムでの確認

機械翻訳の Transformer ですが、プログラム的にはTransformer Decoder の Self Attention でも Cross Attention でも通ります。また、機械翻訳のプログラムで、Transformer Encoder と Transformer Decoder にユークリッド距離による attention を使って学習が進むことを確認しました。

Github

この案件は、英語でも報告したほうが良さそうなので、Github でつたない英語ですが報告しておきます。

2024年3月8日追記、image captioning の Transforemer Decdoer で動作確認

時間がとれたので、既出の本「Python で学ぶ画像認識」の p. 331 で述べられている Transformer を使った image captioning のプログラムの MultiHeadAttention モジュールをこのページで提案しているモジュールに修正して学習を行いました。キャプショニングの結果を示します。

adorable-1849992_1920.jpg
<start> a dog is playing with a frisbee in its mouth <end>

africa-1170179_1920.jpg
<start> a herd of giraffes are standing in a field <end>

airplane-3702676_1920.jpg
<start> a plane flying over a plane flying over a plane <end>

automotive-1846910_1920.jpg
<start> a car is parked on a car on the street <end>

beach-1837030_1920.jpg
<start> a man is surfing on a wave on a wave <end>

caravan-339564_1920.jpg
<start> a group of people riding on a beach <end>

cat-4467818_1920.jpg
<start> a cat is sitting on a colorful rug <end>

cherry-1468933_1920.jpg
<start> a bowl of fruit and a bowl of fruit <end>

couple-955926_1280.jpg
<start> a woman is riding a beach with a beach <end>

dog-7367949_1920.jpg
<start> a man is playing with a frisbee in the water <end>

hit-1407826_1920.jpg
<start> a baseball player is swinging a baseball game <end>

man-498473_1920.jpg
<start> a man riding skis down a snow covered slope <end>

musician-743973_1920.jpg
<start> three men are standing on a beach with surfboards <end>

port-5788261_1920.jpg
<start> a large clock tower towering over a city street <end>

profile-7579739_1920.jpg
<start> a man riding a skateboard on a skateboard <end>

ural-owl-4808774_1920.jpg
<start> a bird perched on a branch in a tree branch <end>

wine-bar-2139973_1920.jpg
<start> a man in a man in a bar <end>

woman-3432069_1920.jpg
<start> a horse is standing in the grass field <end>

zebras-1883654_1920.jpg
<start> a zebra standing in a field with a field <end>

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?