0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

SFTTrainerのLiger-Kernel+勾配チェックポイントでメモリ効率約8倍!(多分)

Last updated at Posted at 2025-01-09

背景

8GBメモリのGPUでも9Bモデルのインストラクションチューニングをしてみたいと思ったのでメモリ削減について調べていたらメモリ使用量は減らないが訓練速度が上がるpackingを見つけました。
で、packingのドキュメント読んでたらuse_ligerが目に入ったのでそっちも調べました(use_ligerのDocStringは意味不明)。

結論から

pip install liger-kernelしてSFTConfigでgradient_checkpointing=Trueuse_liger=Trueに設定するだけ。

計測結果

計測結果といっても完全に記憶頼りです。大体あってるはずです。

Liger-Kernel 勾配チェックポイント バッチサイズ シーケンス長 GPUメモリ
False False 2 96 OOM
False True 2 96 約7.8GB
True False 2 96 約7.8GB
True True 2 96 約6.8GB
True True 4 96 約7.2GB
True True 4 384 約7.8GB

Liger-Kernelによってバッチサイズ2倍にしたらGPUの計算効率が上がって訓練速度がかなり上がりましたし、長いシーケンス長でも訓練できるようになりました。
バッチサイズを1にすれば一応1200トークン位まで行けそうです。

多分現時点最高のファインチューニング用Config

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=R,
    lora_alpha=A,
    target_modules=["k_proj", "q_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.1,
    bias="none",
    use_rslora=True,
)
quantization_config = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_compute_dtype": "bfloat16",
}
args = SFTConfig(
    bf16=True,
    dataloader_pin_memory=True,
    gradient_checkpointing=True,
    optim="adamw_bnb_8bit",
    packing=True,
    chars_per_token=chars_per_tokens,
    max_seq_length=MAX_SEQ_LENGTH,
    num_of_sequences=num_of_sequences,
    gradient_checkpointing_kwargs={"use_reentrant": True},
    model_init_kwargs={
        "quantization_config": quantization_config,
        "device_map": "auto",
    },
    use_liger=True,
)
trainer = SFTTrainer(
    モデル名,
    args,
    train_dataset=train_dataset,
    peft_config=lora_config,
    formatting_func=tokenizer.apply_chat_templateを内部で使う関数など,
)

formatting_funcとかはSFTTrainerのpackingについて調べれば出てきます。

Unslothについて

Unslothが最近Qiitaでも見かけられるようになりました。
Liger-Kernelより歴史が長いようです。
というかLiger-KernelはUnslothを参考にして作られたんだとか?(ソース忘れました。)
gradient_checkpointing="unsloth"というオプションが使えるらしいのでUnslothのほうがメモリ効率や訓練速度が優れていそうですが、比較記事は特に見つからなかったので簡単に使えるLiger-Kernelを使うのが良さそうです。※私個人の意見です。

追記

https://github.com/linkedin/Liger-Kernel/issues/57によるとLiger-KernelとUnslothはメモリ効率が同等で、速度についてはLiger-Kernelのほうが優れているようです。

余談

Liger-Kernelが無効で勾配チェックポイントが有効の時と、その逆の時とでほぼGPUメモリが同じくらいでした。
しかもLiger-Kernelが有効の時のほうが速度が遅かったです。
なので、実は一度「Liger-Kernel、役に立たないな…」と勘違いしてました。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?