MultiheadAttentionの出力の形状を確認してみます。
適当な設定で宣言をし、乱数をquery, key, valueとして入力します。
乱数が入力されているかつ学習もされていないので、出力の値には意味がありません。あくまでも、形状のチェックが目的です。
import torch
import torch.nn as nn
# 設定
batch_size = 4 # バッチサイズ
seq_length = 10 # シーケンスの長さ
embedding_size = 32 # 埋め込みサイズ
num_heads = 2 # Attentionのヘッド数
# ランダムデータの生成
# バッチサイズ x シーケンス長さ x 埋め込みサイズ のランダムデータ
query = torch.randn(seq_length, batch_size, embedding_size) # クエリ
key = torch.randn(seq_length, batch_size, embedding_size) # キー
value = torch.randn(seq_length, batch_size, embedding_size) # バリュー
# MultiheadAttentionの設定
multihead_attention = nn.MultiheadAttention(
embed_dim=embedding_size, # 埋め込みサイズ
num_heads=num_heads, # Attentionのヘッド数
dropout=0.1 # ドロップアウト率
)
# アテンションの計算
# query, key, value はすべて (seq_length, batch_size, embedding_size) という形状です
output, attention_weights = multihead_attention(query, key, value)
# 出力のサイズを確認
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
print(output)
Output shape: torch.Size([10, 4, 32])
Attention weights shape: torch.Size([4, 10, 10])
tensor([[[ 0.1713, -0.2676, -0.0591, ..., -0.0712, -0.0633, 0.1592],
[-0.0424, -0.1004, 0.1188, ..., 0.0526, 0.0797, -0.1483],
[-0.0829, -0.1702, 0.0319, ..., 0.1743, 0.0711, 0.1157],
[-0.0258, -0.0024, 0.0192, ..., -0.0217, 0.1270, -0.1470]],
[[ 0.0273, -0.1164, 0.1897, ..., 0.1240, -0.1581, 0.0513],
[-0.0379, 0.0415, 0.1967, ..., 0.0629, 0.0220, -0.1570],
[-0.1271, -0.2191, 0.0715, ..., -0.0848, -0.0655, 0.0260],
[ 0.0506, -0.0920, -0.0293, ..., 0.0104, 0.2500, 0.1290]],
[[ 0.0452, -0.0528, -0.1012, ..., -0.0429, 0.1069, 0.0166],
[ 0.0841, -0.0791, 0.2010, ..., -0.1204, 0.0417, -0.1295],
[ 0.0679, -0.2161, 0.0100, ..., -0.0142, -0.0047, 0.0991],
[ 0.0333, -0.1052, -0.1009, ..., -0.1023, 0.1071, -0.0949]],
...,
[[ 0.1572, -0.2214, 0.0624, ..., -0.0173, -0.0019, 0.1955],
[ 0.0669, -0.1636, 0.2195, ..., 0.0009, -0.0018, -0.1998],
[-0.3079, -0.1937, -0.0058, ..., -0.1234, 0.1939, 0.0098],
[ 0.0121, -0.0515, -0.0999, ..., -0.0573, 0.1333, -0.1043]],
[[ 0.1728, -0.4110, -0.0294, ..., -0.0683, 0.0476, 0.1301],
[ 0.0846, -0.0632, 0.0253, ..., -0.1755, 0.0539, -0.0639],
[-0.1748, -0.1014, 0.0090, ..., -0.0198, -0.0062, -0.0573],
[-0.0305, 0.1544, -0.0525, ..., -0.1008, 0.0284, -0.0273]],
[[ 0.0601, -0.1530, 0.1416, ..., 0.1298, -0.1676, 0.1102],
[ 0.0972, -0.0521, 0.2555, ..., -0.1308, -0.0005, -0.2223],
[-0.0917, -0.3137, 0.1493, ..., -0.0399, 0.0784, 0.0322],
[ 0.0555, -0.0470, -0.0173, ..., 0.0620, 0.0864, -0.0523]]],
grad_fn=<ViewBackward0>)
より理解を進めるためには重みの形状だけではなく、こちらの記事なども参考にMultiheadAttentionの計算の中身を確認してみると良さそうです。
https://blog.amedama.jp/entry/pytorch-multi-head-attention-verify