Checking the Shapes of BERT's Attention Weights

Posted at 2025-01-05

This article checks the shapes of W_Q, W_K, and W_V for BERT LARGE.

Checking the shapes of the attention weights

from transformers import BertModel

# Load the pretrained BERT model
model_name = "bert-large-uncased"
model = BertModel.from_pretrained(model_name)

# Inspect the model structure (optional)
# print(model)

# Check the attention-related parameters of the encoder.
# BERT uses only the Transformer encoder, so the
# Multi-Head Attention modules live in the encoder layers.
for i, layer in enumerate(model.encoder.layer):
    print(f"Layer {i+1}:")
    
    # Query, Key, and Value weights and biases
    query_weight = layer.attention.self.query.weight
    query_bias = layer.attention.self.query.bias
    key_weight = layer.attention.self.key.weight
    key_bias = layer.attention.self.key.bias
    value_weight = layer.attention.self.value.weight
    value_bias = layer.attention.self.value.bias

    # Output projection weights and biases
    output_weight = layer.attention.output.dense.weight
    output_bias = layer.attention.output.dense.bias

    # Print the shapes
    print("  Query Weight:", query_weight.shape)
    print("  Query Bias:", query_bias.shape)
    print("  Key Weight:", key_weight.shape)
    print("  Key Bias:", key_bias.shape)
    print("  Value Weight:", value_weight.shape)
    print("  Value Bias:", value_bias.shape)
    print("  Output Dense Weight:", output_weight.shape)
    print("  Output Dense Bias:", output_bias.shape)
    print()

Layer 1:
  Query Weight: torch.Size([1024, 1024])
  Query Bias: torch.Size([1024])
  Key Weight: torch.Size([1024, 1024])
  Key Bias: torch.Size([1024])
  Value Weight: torch.Size([1024, 1024])
  Value Bias: torch.Size([1024])
  Output Dense Weight: torch.Size([1024, 1024])
  Output Dense Bias: torch.Size([1024])

... (Layers 2 through 23 show exactly the same shapes and are omitted) ...

Layer 24:
  Query Weight: torch.Size([1024, 1024])
  Query Bias: torch.Size([1024])
  Key Weight: torch.Size([1024, 1024])
  Key Bias: torch.Size([1024])
  Value Weight: torch.Size([1024, 1024])
  Value Bias: torch.Size([1024])
  Output Dense Weight: torch.Size([1024, 1024])
  Output Dense Bias: torch.Size([1024])

This confirms that W_Q, W_K, and W_V are each 1024×1024 matrices.

We can also see that the number of hidden layers (Transformer blocks) is 24.

You might wonder where the number of self-attention heads (= 16, see below) shows up. Each of the 16 heads is tied to a 1024/16 = 64-dimensional slice of these weight matrices, so the weight shapes themselves do not depend on the number of heads (see the sketch below).
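
As a minimal sketch (this reshape is an illustration added here, not part of the original code), the first layer's 1024×1024 query weight can be viewed as 16 per-head blocks of 64 output rows each, matching how the projection output is split across heads:

num_heads = 16                 # confirmed via BertConfig below
head_size = 1024 // num_heads  # = 64

# nn.Linear stores its weight as (out_features, in_features) = (1024, 1024)
query_weight = model.encoder.layer[0].attention.self.query.weight

# View the output dimension as (num_heads, head_size): one 64x1024 block per head
per_head = query_weight.view(num_heads, head_size, 1024)
print(per_head.shape)  # torch.Size([16, 64, 1024])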

The diagrams on the following site give a good intuition for how the multi-head attention weights are organized.
https://developers.agirobots.com/jp/multi-head-attention/

Checking the other configuration values

The paper describes the model sizes as follows:

In this work, we denote the number of layers (i.e., Transformer blocks) as L, the hidden size as H, and the number of self-attention heads as A. We primarily report results on two model sizes: BERTBASE (L=12, H=768, A=12, Total Parameters=110M) and BERTLARGE (L=24, H=1024, A=16, Total Parameters=340M).

These values can be checked with BertConfig.

from transformers import BertConfig
config = BertConfig.from_pretrained(model_name)
print(config)
BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.47.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

Here, too, we can confirm that the number of hidden layers (Transformer blocks) is 24.
We can also confirm that the number of self-attention heads is 16.
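
If you only need L, H, and A in the paper's notation, you can also read them directly from the config object (a small illustrative addition, not in the original article):

# L, H, A in the paper's notation
print("L (num_hidden_layers):", config.num_hidden_layers)      # 24
print("H (hidden_size):", config.hidden_size)                  # 1024
print("A (num_attention_heads):", config.num_attention_heads)  # 16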

Computing the total number of parameters

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")
Total Parameters: 335,141,888
Trainable Parameters: 335,141,888
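
As a rough cross-check (a sketch that assumes the standard BERT layer layout; none of this is from the original article), the figure 335,141,888 can be reproduced from the config values:

H = config.hidden_size              # 1024
L = config.num_hidden_layers        # 24
I = config.intermediate_size        # 4096
V = config.vocab_size               # 30522
P = config.max_position_embeddings  # 512
T = config.type_vocab_size          # 2

# Embeddings: word + position + token-type embeddings, plus LayerNorm weight & bias
embeddings = (V + P + T) * H + 2 * H

# One encoder layer: Q/K/V/output projections (with biases),
# two LayerNorms, and the feed-forward block
per_layer = 4 * (H * H + H) + 2 * 2 * H + (H * I + I) + (I * H + H)

# Pooler: one H x H dense layer with bias
pooler = H * H + H

print(embeddings + L * per_layer + pooler)  # 335141888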