PyTorch memo: two ways to implement Multi-head Attention.
Method 1
The matrix reshaping used in this approach lets all heads be computed in parallel, which gives it high computational efficiency. This style is widely used in code found online, but it can be hard for beginners to follow.
method 1
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, head_num: int):
        super(MultiHeadAttention, self).__init__()
        assert d_model % head_num == 0
        self.d_model = d_model
        self.head_num = head_num
        self.head_dim = d_model // head_num
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        # Project the input to Q, K, V: [batch_size, seq_len, d_model]
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)
        # Split d_model into head_num heads: [batch_size, head_num, seq_len, head_dim]
        q = q.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention, computed for all heads at once
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) / (self.head_dim ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        # Merge the heads back together: [batch_size, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.fc(attn_output)
        return output
d_model = 20
head_num = 5
model = MultiHeadAttention(d_model, head_num)
x = torch.randn(4, 10, d_model)  # [batch_size=4, seq_len=10, d_model=20]
print(model(x).shape)  # torch.Size([4, 10, 20])
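
As a side note, in PyTorch 2.0 and later the matmul / scaling / softmax steps inside forward can also be done with the fused call F.scaled_dot_product_attention. The snippet below is only a minimal sketch of that idea (it assumes a PyTorch 2.x environment) using randomly generated q, k, v in the same [batch_size, head_num, seq_len, head_dim] layout as method 1:

import torch
import torch.nn.functional as F

# Same layout as q, k, v inside method 1's forward: [batch_size, head_num, seq_len, head_dim]
q = torch.randn(4, 5, 10, 4)
k = torch.randn(4, 5, 10, 4)
v = torch.randn(4, 5, 10, 4)

# Computes softmax(q @ k^T / sqrt(head_dim)) @ v for all heads in parallel
attn_output = F.scaled_dot_product_attention(q, k, v)
print(attn_output.shape)  # torch.Size([4, 5, 10, 4])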
Method 2
In this approach each head is computed separately, so there is little parallelism and the computational efficiency is lower. It is not commonly used in practice, but I think it is easier for beginners to understand.
method 2
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(nn.Module):
    def __init__(self, d_model: int = 4):
        super(Attention, self).__init__()
        self.d_model = d_model
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

    def forward(self, x):
        # Single-head scaled dot-product attention
        q = self.wq(x)
        k = self.wk(x).transpose(-1, -2)
        v = self.wv(x)
        attn_weights = F.softmax(torch.matmul(q, k) / (self.d_model ** 0.5), dim=-1)
        return torch.matmul(attn_weights, v)
class AttentionMulti(nn.Module):
    def __init__(self, dim: int, d_model: int, head_num: int):
        super(AttentionMulti, self).__init__()
        self.head_num = head_num
        # One independent single-head Attention module per head
        self.heads = nn.ModuleList([Attention(dim) for _ in range(head_num)])
        self.final_fc = nn.Linear(dim * head_num, d_model)

    def forward(self, x):
        # Run each head separately, then concatenate along the feature dimension
        heads_output = [head(x) for head in self.heads]
        concat = torch.cat(heads_output, dim=-1)
        output = self.final_fc(concat)
        return output
d_model = 4
head_num = 5
dim = 4
model = AttentionMulti(dim, d_model, head_num)
x = torch.randn(4, 4)  # [seq_len=4, dim=4] (toy input without a batch dimension)
for name, param in model.named_parameters():
    print(f"Parameter name: {name}")
    print(f"Parameter: {param.data}\n")
This post is still being updated!