0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Multi-Head Attentionの仕組みについての覚書

Posted at

はじめに

どうも、機械学習を用いて動画認識を行っている大学院生です。
最近、動画認識を行う際にTransformerを使用しました。
そこで、Transformerの中のMulti-Head Attentionの仕組みを覚書として残します。

Transformer

  • Transformerとは、Attentionを用いて一連データを一度に処理できるニューラルネットワーク
  • Transoformerのモデル図は以下の通り(原論文より引用)

スクリーンショット 2024-06-02 9.54.46.png

  • Multi-Head Attentionの説明をする前にInput EmbeddingとPostional Encondingを軽く説明します

1. Input Embedding

Input Embeddingとは、Inputからきた単語や文章などのデータをベクトル空間に変換を行っています。
例ですが、「This is a pen.」というInputが来た場合
Thisのベクトル [0.9, 0.1,...0.4], isのベクトル [0.1, 0.01,...0.9], aのベクトル[0.3, 0.2,...0.01], penのベクトル [0.5, 0.03,...0.11] .のベクトル [0.3, 0.1,...0.9]という形に変換します
詳しくは、embeddingとは?手を動かしながら学ぶを見て貰えばわかると思います。

2. Postional Encoding

Postional Encodingとは、ベクトルを足して位置情報を表現しています。
次に入力するAttentionでは、位置情報を加味する仕組みがないためPostional Encodingを行っています。
詳しくは、Positional Encodingを理解したいを見て貰えばわかると思います。

Multi-Head Attention

Multi-Head Attentionは、図(原論文より引用)に示すようにSingle-Head Attentionの複数の集まりです。
そして、Single-Head Attentionは、Q, K, Vに対してLinear層で変換をし、Scaled Dot-Product Attentionを行い、類似度の計算をするものです。従って、Q, K, Vについて、Scaled Dot-Product Attentionについて説明していきます。

スクリーンショット 2024-06-04 11.54.55.png

Q, K, Vについて

Q, K, Vは、Query, Key, Valueと言います。Qは、入力ベクトルであり、Kは、入力ベクトルに対してどれくらい関連性があるかを計算のために使用されます。そしてVは、QとKで計算された関連性を強調するものです。
Q, K, Vの元は、上記で説明したPostional Encodingの付与を行ったベクトルです。このベクトルの値に対して3つ異なる線形変換を行いQ, K, Vを生成しています。

Scaled Dot-Prodcutについて

Scaled Dot-Prodcutの流れは、QとKの内積を計算し、スケーリングを行い、softmaxを適用しVとかけて算出する。
式は以下の通り

   \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

Scaled Dot-Productは、上記図に示すように複数個存在します。この理由は、さまざまな表現を学習可能にするため。
複数個のScaled Dot-Productを繋げて(concat)、線形変換を行い、出力をしています。

プログラム

これまでのことをコードで書いたのでよかったら

import torch
import torch.nn as nn
import torch.nn.functional as F

# サンプルデータ
sentence = "This is a pen ."
vocab = {"This": 0, "is": 1, "a": 2, "pen": 3, ".": 4}
input_ids = torch.tensor([vocab[word] for word in sentence.split()])
print(input_ids)
# パラメータ
d_model = 512  # エンベディング次元
max_len = 10   # 最大文長
vocab_size = len(vocab)
num_heads = 8  # ヘッドの数
batch_size = 1  # サンプルのためバッチサイズは1とする

# 入力エンベディング
embedding = nn.Embedding(vocab_size, d_model)
x = embedding(input_ids)  # (seq_len, d_model)

# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]

pos_encoder = PositionalEncoding(d_model, max_len)
x = pos_encoder(x)  # (seq_len, d_model)
x = x.unsqueeze(0)  # (1, seq_len, d_model)

# MultiHeadAttentionのクラス
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, v, k, q):
        batch_size = q.size(0)

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len, depth)

        return q, k, v

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(ScaledDotProductAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads

    def forward(self, q, k, v, mask=None):
        matmul_qk = torch.matmul(q, k.transpose(-2, -1))  # (batch_size, num_heads, seq_len_q, seq_len_k)
        dk = torch.tensor(k.shape[-1], dtype=torch.float32)
        scaled_attention_logits = matmul_qk / torch.sqrt(dk)  # スケーリング
        print( torch.sqrt(dk), (k[0]))
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)  # マスクの適用

        attention_weights = F.softmax(scaled_attention_logits, dim=-1)  # ソフトマックス
        output = torch.matmul(attention_weights, v)  # (batch_size, num_heads, seq_len_q, depth)

        return output, attention_weights

# インスタンス生成
mha = MultiHeadAttention(d_model, num_heads)
attention = ScaledDotProductAttention(d_model, num_heads)

def multi_head_attention_forward(mha, x):
    q, k, v = mha(x, x, x)
    attention_output, attention_weights = attention(q, k, v)
    
    # ヘッドを結合する
    attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
    attention_output = attention_output.view(batch_size, -1, d_model)  # (batch_size, seq_len, d_model)
    
    # 最終線形層
    output = mha.dense(attention_output)
    
    return output, attention_weights

# 出力の計算
output, attention_weights = multi_head_attention_forward(mha, x)
print(f"Output shape: {output.shape}, Attention Weights shape: {attention_weights.shape}")

おわりに

もし間違った表現などをしていたら教えて頂けると嬉しいです。
また、初めてQiitaに記事を書いたので文章が伝わりづらいかもしれませんが大目に見てください

参考文献

Attention Is All You Need
embeddingとは?手を動かしながら学ぶ
Positional Encodingを理解したい
Multi-Head Attentionの仕組み

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?