Let's check the shapes of W_Q, W_K, and W_V in BERT LARGE.
Checking the shapes of the attention weights
```python
from transformers import BertModel

# Load the pretrained BERT model
model_name = "bert-large-uncased"
model = BertModel.from_pretrained(model_name)

# Inspect the model structure (optional)
# print(model)

# Inspect the attention-related parameters in the encoder.
# BERT uses only the encoder part of the Transformer,
# so the Multi-Head Attention lives in the encoder layers.
for i, layer in enumerate(model.encoder.layer):
    print(f"Layer {i+1}:")

    # Query, Key, and Value weights and biases
    query_weight = layer.attention.self.query.weight
    query_bias = layer.attention.self.query.bias
    key_weight = layer.attention.self.key.weight
    key_bias = layer.attention.self.key.bias
    value_weight = layer.attention.self.value.weight
    value_bias = layer.attention.self.value.bias

    # Weights and bias of the output projection applied after the heads are combined
    output_weight = layer.attention.output.dense.weight
    output_bias = layer.attention.output.dense.bias

    # Print the shapes
    print(" Query Weight:", query_weight.shape)
    print(" Query Bias:", query_bias.shape)
    print(" Key Weight:", key_weight.shape)
    print(" Key Bias:", key_bias.shape)
    print(" Value Weight:", value_weight.shape)
    print(" Value Bias:", value_bias.shape)
    print(" Output Dense Weight:", output_weight.shape)
    print(" Output Dense Bias:", output_bias.shape)
    print()
```
Layer 1:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 2:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 3:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 4:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 5:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 6:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 7:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 8:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 9:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 10:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 11:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 12:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 13:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 14:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 15:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 16:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 17:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 18:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 19:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 20:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 21:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 22:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 23:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
Layer 24:
Query Weight: torch.Size([1024, 1024])
Query Bias: torch.Size([1024])
Key Weight: torch.Size([1024, 1024])
Key Bias: torch.Size([1024])
Value Weight: torch.Size([1024, 1024])
Value Bias: torch.Size([1024])
Output Dense Weight: torch.Size([1024, 1024])
Output Dense Bias: torch.Size([1024])
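W_Q, W_K, and W_V are each 1024×1024 matrices.
We can also confirm that the number of hidden layers (Transformer blocks) is 24.
You might wonder why the number of self-attention heads (16, discussed below) does not appear anywhere in these shapes. Each 1024×1024 matrix holds the projections for all 16 heads side by side: its 1024 output dimensions are split into 16 chunks of 1024/16 = 64, and each head uses one chunk (a 64×1024 slice in PyTorch's (out_features, in_features) weight layout). Because the heads only partition the matrix, its overall shape does not depend on the number of heads.
The diagrams on the following site give a good intuition for how the multi-head attention weights are organized:
https://developers.agirobots.com/jp/multi-head-attention/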
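To make the head split concrete, here is a minimal pure-Python sketch (no torch needed). It assumes the PyTorch `nn.Linear` convention, where `weight` has shape `(out_features, in_features)`, and that the projected output is cut into `num_heads` contiguous chunks of `head_dim` features, which is how the Hugging Face BERT implementation appears to split heads. The helper `head_row_slices` is hypothetical and written only for illustration.

```python
# Minimal sketch: which rows of one Q/K/V weight matrix each head uses.
# Assumption: nn.Linear layout (out_features, in_features) and heads taken
# as contiguous chunks of the output features.


def head_row_slices(hidden_size: int, num_heads: int) -> list[tuple[int, int]]:
    """Return the [start, stop) row range of the weight matrix used by each head."""
    if hidden_size % num_heads != 0:
        raise ValueError("hidden_size must be divisible by num_heads")
    head_dim = hidden_size // num_heads
    return [(h * head_dim, (h + 1) * head_dim) for h in range(num_heads)]


hidden_size, num_heads = 1024, 16
slices = head_row_slices(hidden_size, num_heads)
head_dim = hidden_size // num_heads

print("head_dim:", head_dim)                               # 64
print("head 0 rows:", slices[0])                           # (0, 64)
print("head 15 rows:", slices[-1])                         # (960, 1024)
print("per-head weight shape:", (head_dim, hidden_size))   # (64, 1024)

# The full matrix shape is the same whatever the head count:
for a in (1, 8, 16):
    rows = sum(stop - start for start, stop in head_row_slices(hidden_size, a))
    print(f"A={a}: full W_Q shape = ({rows}, {hidden_size})")
```

Changing the number of heads only changes how the 1024 rows are grouped, never the total, which is why the printed weight shapes carry no trace of the head count.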
Checking the other configuration values
The paper states:
In this work, we denote the number of layers (i.e., Transformer blocks) as L, the hidden size as H, and the number of self-attention heads as A. We primarily report results on two model sizes: BERT_BASE (L=12, H=768, A=12, Total Parameters=110M) and BERT_LARGE (L=24, H=1024, A=16, Total Parameters=340M).
These values can be checked with BertConfig.
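```python
from transformers import BertConfig

# Load the configuration of the same checkpoint (model_name is defined above).
# The loaded model exposes the same settings as model.config.
config = BertConfig.from_pretrained(model_name)
print(config)
```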
```
BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.47.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
```
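Here, too, the number of hidden layers (Transformer blocks), num_hidden_layers, is 24, matching the paper's L, and hidden_size is 1024, matching H.
The number of self-attention heads, num_attention_heads, is 16, matching A.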
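The correspondence between the paper's symbols and the BertConfig field names can be spelled out in a small sketch. The values below are copied by hand from the printout above, and `compare_with_paper` is a hypothetical helper, not part of `transformers`; in practice you would read the same fields directly from `config` (for example `config.hidden_size`).

```python
# Hypothetical helper: compare BertConfig-style fields with the paper's (L, H, A).
# Field values are copied by hand from the BertConfig printout (bert-large-uncased).
large_config = {
    "num_hidden_layers": 24,    # paper: L
    "hidden_size": 1024,        # paper: H
    "num_attention_heads": 16,  # paper: A
    "intermediate_size": 4096,  # feed-forward inner size (not in the paper quote)
}


def compare_with_paper(cfg: dict, L: int, H: int, A: int) -> dict:
    """Check config fields against the paper's L, H, A and derive per-head size."""
    return {
        "L_matches": cfg["num_hidden_layers"] == L,
        "H_matches": cfg["hidden_size"] == H,
        "A_matches": cfg["num_attention_heads"] == A,
        # Each head works in hidden_size / num_attention_heads dimensions
        "head_dim": cfg["hidden_size"] // cfg["num_attention_heads"],
        # Ratio of the feed-forward inner size to the hidden size
        "ffn_ratio": cfg["intermediate_size"] / cfg["hidden_size"],
    }


print(compare_with_paper(large_config, L=24, H=1024, A=16))
# {'L_matches': True, 'H_matches': True, 'A_matches': True, 'head_dim': 64, 'ffn_ratio': 4.0}

# BERT_BASE in the paper (H=768, A=12) gives the same per-head size
print(768 // 12)  # 64
```

Both model sizes in the paper keep the per-head dimension at 64; BERT_LARGE gets its extra capacity from more layers and more heads rather than wider heads.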
Computing the total number of parameters
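```python
# numel() returns the number of elements in each parameter tensor
total_params = sum(p.numel() for p in model.parameters())
# Parameters with requires_grad=True are the ones updated during training
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")
```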
```
Total Parameters: 335,141,888
Trainable Parameters: 335,141,888
```
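As a cross-check, the 335,141,888 figure can be reproduced by hand from the config values. This sketch assumes that `BertModel` consists of the embedding layer (word, position, and token-type tables plus one LayerNorm), 24 identical encoder layers, and the pooler, with no pre-training heads. `bert_param_count` is a hypothetical helper. The paper's 340M is a rounded figure, so a small gap from it is expected.

```python
def bert_param_count(vocab_size: int, hidden: int, layers: int, intermediate: int,
                     max_positions: int, type_vocab: int, with_pooler: bool = True) -> dict:
    """Analytic parameter count for a BertModel-style encoder (assumed layout)."""
    # Embeddings: word, position, and token-type tables plus one LayerNorm (weight + bias)
    embeddings = (vocab_size + max_positions + type_vocab) * hidden + 2 * hidden

    # One encoder layer
    qkv = 3 * (hidden * hidden + hidden)            # W_Q, W_K, W_V and their biases
    attn_out = hidden * hidden + hidden             # attention output dense
    attn_norm = 2 * hidden                          # LayerNorm after attention
    ffn_in = hidden * intermediate + intermediate   # dense: hidden -> intermediate
    ffn_out = intermediate * hidden + hidden        # dense: intermediate -> hidden
    ffn_norm = 2 * hidden                           # LayerNorm after the feed-forward block
    per_layer = qkv + attn_out + attn_norm + ffn_in + ffn_out + ffn_norm

    # Pooler: one dense layer applied to the [CLS] token's hidden state
    pooler = (hidden * hidden + hidden) if with_pooler else 0

    encoder = layers * per_layer
    return {
        "embeddings": embeddings,
        "per_layer": per_layer,
        "encoder": encoder,
        "pooler": pooler,
        "total": embeddings + encoder + pooler,
    }


# Values taken from the BertConfig printout for bert-large-uncased
counts = bert_param_count(vocab_size=30522, hidden=1024, layers=24, intermediate=4096,
                          max_positions=512, type_vocab=2)
for name, value in counts.items():
    print(f"{name}: {value:,}")
```

With these values the total comes out to 335,141,888, the same as the `numel()` sum above. Within each layer, the attention projections (Q, K, V, and the output dense) account for 4 × 1,049,600 = 4,198,400 parameters, and the feed-forward block accounts for most of the rest.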