0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Hugging Faceからダウンロードしたモデルの詳細を確認する

0
Posted at

1. はじめに

Hugging Faceのモデルをロードした後にパラメータ数、メモリ使用量、モデルの構造等を確認する方法を記載します。

確認に使用したモデル
https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2

2. 確認方法

2-1. 必要なもの

pip install torch
pip install transformers

2-2. モデルのロード

python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
# 初回ロード時のみ15GBくらいのダウンロードが走ります

2-3. 各種情報確認

config確認

print(model.config)
# 以下出力
MistralConfig {
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "dtype": "bfloat16",
  "eos_token_id": 2,
  "head_dim": 128,
# ~長いため省略~
# config.jsonとほぼ同じ内容

print(model.config.num_hidden_layers)
# 32

モデルのフォーマット

print(model.dtype)
# torch.bfloat16 ※BF16

デバイス確認

print(model.device)
# cpu
model.cuda()
# GPUに移動 ※CPUのメモリには残り続けるので注意
print(model.device)
# cuda:0

全パラメータ数

sum(p.numel() for p in model.parameters())
# 7241732096
sum(p.numel() for p in model.parameters() if p.requires_grad)
# 7241732096 ※学習可能なパラメータのみ

メモリ使用量

total_params = sum(p.numel() for p in model.parameters())
# 7241732096 ※全パラメータ数
mem_params = sum(p.nelement() * p.element_size() for p in model.parameters())
# 14483464192 ※メモリ使用量
print(mem_params / total_params)
# 2.0 ※パラメータ1つあたりのメモリ使用量 BF16と一致

2-4. モデルの形状を表示

有名なコマンドかと思いますが、以下でモデルの構造とパラメータ数を確認できます。
右に記載の表の通り、単純に要素同士を掛けて合計すれば2-3で表示させた総数と一致します。
今回使用したモデルは32層なので、DecoderLayer内のパラメータは32個分になります。
(LlamaやMistralなどの標準的なモデルでいくつか試した限り一致していましたが、複雑なモデルだと一致しないかもしれません)
※表を右に挿入する方法が分からなかったので画像にしています

print(model)

image.png

以下の場合、上記の計算が一致しないと思われます
・以下がTrueになる場合
上記での(embed_tokens)と(lm_head)で同じものを流用(重み共有)するためユニークなパラメータの数は減ります

print(model.config.tie_word_embeddings)

・print(model)内にbias=Trueのパラメータがある場合
bias分の数が増えます

2-5. パラメータの中身を確認

各パラメータの中身を見ていきます。

sd = model.state_dict()
sd_name = [i for i,j in sd.items()]
# 全パラメータの名前を取得

print(sd_name[0])
# 'model.embed_tokens.weight'

print(sd[sd_name[0]])
# tensor([[-2.0336e-36,  3.3208e-37, -1.5517e-35,  ..., -4.9371e-36,
#          -7.9934e-36, -5.9480e-36],
# 長いため省略

sd_nameのままだと32層分のレイヤがすべて表示され見づらいので、layer.nに置き換えます

import re
sd_name_layer_n = [re.sub("layers.[0-9]+","layers.n", i) for i,j in sd.items()]
sd_name_layer_n = list(dict.fromkeys(sd_name_layer_n))
print(sd_name_layer_n)
# ['model.embed_tokens.weight', 'model.layers.n.self_attn.q_proj.weight', 'model.layers.n.self_attn.k_proj.weight', 'model.layers.n.self_attn.v_proj.weight', 'model.layers.n.self_attn.o_proj.weight', 'model.layers.n.mlp.gate_proj.weight', 'model.layers.n.mlp.up_proj.weight', 'model.layers.n.mlp.down_proj.weight', 'model.layers.n.input_layernorm.weight', 'model.layers.n.post_attention_layernorm.weight', 'model.norm.weight', 'lm_head.weight']
# レイヤをnに書き換えた重みの内訳

レイヤ1層分にしたパラメータの形状は以下でまとめて確認できます。
print(model)の表示と比較してみると、Linearの部分はPyTorchのnn.Linearの実装のためtranspose(転置)状態で配置されているようです。

for i in sd_name_layer_n:
    i = i.replace(".n.", ".0.")
    print("Name:", i, "Shape:", sd[i].shape)
# 以下出力
Name: model.embed_tokens.weight Shape: torch.Size([32000, 4096])
Name: model.layers.0.self_attn.q_proj.weight Shape: torch.Size([4096, 4096])
Name: model.layers.0.self_attn.k_proj.weight Shape: torch.Size([1024, 4096])
Name: model.layers.0.self_attn.v_proj.weight Shape: torch.Size([1024, 4096])
Name: model.layers.0.self_attn.o_proj.weight Shape: torch.Size([4096, 4096])
Name: model.layers.0.mlp.gate_proj.weight Shape: torch.Size([14336, 4096])
Name: model.layers.0.mlp.up_proj.weight Shape: torch.Size([14336, 4096])
Name: model.layers.0.mlp.down_proj.weight Shape: torch.Size([4096, 14336])
Name: model.layers.0.input_layernorm.weight Shape: torch.Size([4096])
Name: model.layers.0.post_attention_layernorm.weight Shape: torch.Size([4096])
Name: model.norm.weight Shape: torch.Size([4096])
Name: lm_head.weight Shape: torch.Size([32000, 4096])

念のための確認ですが、上記を32層分にすると2-3で表示させた総数と一致します。

params = 0
for i in sd_name_layer_n:
    i = i.replace(".n.", ".0.")
    if ".0." in i:
        params += sd[i].numel()*32
    else:
        params += sd[i].numel()

print(params)
# 7241732096

3. まとめ

Transformersでロード可能なモデルであれば、上記の通り簡単にモデルの詳細情報を確認することができます。

実行環境

Python 3.12.3
torch 2.12.0
transformers 5.8.1

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?