
Checking the output shape of MultiheadAttention

Let's check the output shape of MultiheadAttention.
We instantiate it with arbitrary settings and feed random tensors as the query, key, and value.
Because the inputs are random and the module has not been trained, the output values themselves are meaningless; the goal here is purely to check the shapes.

import torch
import torch.nn as nn

# Settings
batch_size = 4      # batch size
seq_length = 10     # sequence length
embedding_size = 32 # embedding size
num_heads = 2       # number of attention heads

# Generate random data
# Random tensors of shape (seq_length, batch_size, embedding_size),
# the default layout expected by nn.MultiheadAttention (batch_first=False)
query = torch.randn(seq_length, batch_size, embedding_size)   # query
key = torch.randn(seq_length, batch_size, embedding_size)     # key
value = torch.randn(seq_length, batch_size, embedding_size)   # value

# Configure MultiheadAttention
multihead_attention = nn.MultiheadAttention(
    embed_dim=embedding_size,  # embedding size
    num_heads=num_heads,       # number of attention heads
    dropout=0.1                # dropout rate
)

# Compute attention
# query, key, and value all have shape (seq_length, batch_size, embedding_size)
output, attention_weights = multihead_attention(query, key, value)

# Check the output shapes
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")

print(output)
Output shape: torch.Size([10, 4, 32])
Attention weights shape: torch.Size([4, 10, 10])
tensor([[[ 0.1713, -0.2676, -0.0591,  ..., -0.0712, -0.0633,  0.1592],
         [-0.0424, -0.1004,  0.1188,  ...,  0.0526,  0.0797, -0.1483],
         [-0.0829, -0.1702,  0.0319,  ...,  0.1743,  0.0711,  0.1157],
         [-0.0258, -0.0024,  0.0192,  ..., -0.0217,  0.1270, -0.1470]],

        [[ 0.0273, -0.1164,  0.1897,  ...,  0.1240, -0.1581,  0.0513],
         [-0.0379,  0.0415,  0.1967,  ...,  0.0629,  0.0220, -0.1570],
         [-0.1271, -0.2191,  0.0715,  ..., -0.0848, -0.0655,  0.0260],
         [ 0.0506, -0.0920, -0.0293,  ...,  0.0104,  0.2500,  0.1290]],

        [[ 0.0452, -0.0528, -0.1012,  ..., -0.0429,  0.1069,  0.0166],
         [ 0.0841, -0.0791,  0.2010,  ..., -0.1204,  0.0417, -0.1295],
         [ 0.0679, -0.2161,  0.0100,  ..., -0.0142, -0.0047,  0.0991],
         [ 0.0333, -0.1052, -0.1009,  ..., -0.1023,  0.1071, -0.0949]],

        ...,

        [[ 0.1572, -0.2214,  0.0624,  ..., -0.0173, -0.0019,  0.1955],
         [ 0.0669, -0.1636,  0.2195,  ...,  0.0009, -0.0018, -0.1998],
         [-0.3079, -0.1937, -0.0058,  ..., -0.1234,  0.1939,  0.0098],
         [ 0.0121, -0.0515, -0.0999,  ..., -0.0573,  0.1333, -0.1043]],

        [[ 0.1728, -0.4110, -0.0294,  ..., -0.0683,  0.0476,  0.1301],
         [ 0.0846, -0.0632,  0.0253,  ..., -0.1755,  0.0539, -0.0639],
         [-0.1748, -0.1014,  0.0090,  ..., -0.0198, -0.0062, -0.0573],
         [-0.0305,  0.1544, -0.0525,  ..., -0.1008,  0.0284, -0.0273]],

        [[ 0.0601, -0.1530,  0.1416,  ...,  0.1298, -0.1676,  0.1102],
         [ 0.0972, -0.0521,  0.2555,  ..., -0.1308, -0.0005, -0.2223],
         [-0.0917, -0.3137,  0.1493,  ..., -0.0399,  0.0784,  0.0322],
         [ 0.0555, -0.0470, -0.0173,  ...,  0.0620,  0.0864, -0.0523]]],
       grad_fn=<ViewBackward0>)
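
The output keeps the input shape, (seq_length, batch_size, embedding_size) = (10, 4, 32), while the attention weights come back as (batch_size, seq_length, seq_length) = (4, 10, 10) because the weights are averaged over the heads by default. As a small extra check (my own addition, assuming a recent PyTorch version whose forward call accepts the average_attn_weights argument), you can ask for the per-head weights instead:

# Continuing from the code above: request per-head attention weights
# instead of the head-averaged ones (assumes average_attn_weights is available)
output, per_head_weights = multihead_attention(
    query, key, value, average_attn_weights=False
)

# Expected: (batch_size, num_heads, seq_length, seq_length) = (4, 2, 10, 10)
print(f"Per-head attention weights shape: {per_head_weights.shape}")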

To deepen your understanding, it seems worthwhile to go beyond the output shapes and verify what MultiheadAttention actually computes internally, for example with the help of this article:
https://blog.amedama.jp/entry/pytorch-multi-head-attention-verify
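
As a final shape-related aside (my own addition, not part of the article linked above): the (seq_length, batch_size, embedding_size) layout used here is only the default. Constructing the module with batch_first=True switches both the expected inputs and the returned output to (batch_size, seq_length, embedding_size). A minimal sketch, reusing the settings defined earlier:

# A minimal sketch assuming batch_first=True (available in current PyTorch versions):
# inputs and outputs now use the (batch_size, seq_length, embedding_size) layout
mha_bf = nn.MultiheadAttention(
    embed_dim=embedding_size,
    num_heads=num_heads,
    dropout=0.1,
    batch_first=True
)

q = torch.randn(batch_size, seq_length, embedding_size)
k = torch.randn(batch_size, seq_length, embedding_size)
v = torch.randn(batch_size, seq_length, embedding_size)

out_bf, weights_bf = mha_bf(q, k, v)

# Expected: (batch_size, seq_length, embedding_size) = (4, 10, 32)
print(f"Output shape (batch_first=True): {out_bf.shape}")
# The head-averaged attention weights stay (batch_size, seq_length, seq_length) = (4, 10, 10)
print(f"Attention weights shape: {weights_bf.shape}")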
