Two Simple Ways to Build Multi-Head Attention

Posted at 2024-09-30

A PyTorch note: this article introduces two ways to implement Multi-Head Attention.

Method 1

The matrix shape transformation used in this method lets all heads be computed in parallel, which makes it computationally efficient. It is the style you will most often see in code online, but it can be hard for beginners to follow; a small sketch that traces the intermediate shapes is included after the example below.

method 1
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, head_num: int):
        super(MultiHeadAttention, self).__init__()
        assert d_model % head_num == 0
        
        self.d_model = d_model
        self.head_num = head_num
        self.head_dim = d_model // head_num

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        q = self.wq(x)
        k = self.wk(x) 
        v = self.wv(x) 

        q = q.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)  # [batch_size, head_num, seq_len, head_dim]
        k = k.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)  # [batch_size, head_num, seq_len, head_dim]
        v = v.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)  # [batch_size, head_num, seq_len, head_dim]

        attn_scores = torch.matmul(q, k.transpose(-1, -2)) / (self.head_dim ** 0.5)  # [batch_size, head_num, seq_len, seq_len]
        attn_weights = F.softmax(attn_scores, dim=-1)

        attn_output = torch.matmul(attn_weights, v)  # [batch_size, head_num, seq_len, head_dim]

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)  # [batch_size, seq_len, d_model]
        output = self.fc(attn_output) 

        return output

d_model = 20
head_num = 5
model = MultiHeadAttention(d_model, head_num)
x = torch.randn(4, 10, d_model)  # [batch_size, seq_len, d_model]

print(model(x).shape)  # torch.Size([4, 10, 20])
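
To make the shape transformation described above more concrete, here is a minimal sketch (not part of the original code; it assumes the same batch_size = 4, seq_len = 10, d_model = 20, and head_num = 5 as the example) that traces how view and transpose split the model dimension into heads:

import torch

batch_size, seq_len, d_model, head_num = 4, 10, 20, 5
head_dim = d_model // head_num

# After the wq linear layer, q still has shape [batch_size, seq_len, d_model]
q = torch.randn(batch_size, seq_len, d_model)

# view splits the last dimension into (head_num, head_dim)
q = q.view(batch_size, seq_len, head_num, head_dim)
print(q.shape)  # torch.Size([4, 10, 5, 4])

# transpose(1, 2) moves the head dimension forward, so a single batched
# matmul can process all heads at once
q = q.transpose(1, 2)
print(q.shape)  # torch.Size([4, 5, 10, 4])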

Method 2

In this method each head is computed separately, so there is little parallelism and the computational efficiency is lower. It is not generally used in practice, but I think it is easier for beginners to understand; a quick sanity check is included after the example below.

method 2
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, d_model: int = 4):
        super(Attention, self).__init__()
        self.d_model = d_model
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

    def forward(self, x):
        q = self.wq(x)
        k = self.wk(x).transpose(-1, -2)
        v = self.wv(x)
        attn_weights = F.softmax(torch.matmul(q, k) / (self.d_model ** 0.5), dim=-1)
        return torch.matmul(attn_weights, v)

class AttentionMulti(nn.Module):
    def __init__(self, dim: int, d_model: int, head_num: int):
        super(AttentionMulti, self).__init__()
        self.head_num = head_num
        self.heads = nn.ModuleList([Attention(dim) for _ in range(head_num)])
        self.final_fc = nn.Linear(dim * head_num, d_model)
        
    def forward(self, x):
        heads_output = [head(x) for head in self.heads]
        concat = torch.cat(heads_output, dim=-1)

        output = self.final_fc(concat)
        return output
        
d_model = 4  
head_num = 5  
dim = 4  
model = AttentionMulti(dim, d_model, head_num)
x = torch.randn(4, 4) 

for name, param in model.named_parameters():
    print(f"パラメータのネーム: {name}")
    print(f"パラメータ: {param.data}\n")

This article is still being updated!
