LLM の学習において、メモリ消費量が一番の敵です。
今回は LLM 学習に使えそうな optimizer をまとめます。
8bit-AdamW
- AdamW の 8bit 版です
- ちなみに AdamW は Adam の問題点を改良したものです
- 量子化しているので精度は落ちますが、ほぼ影響はないです
- 特に素の Adam/AdamW はメモリ消費量がやばいので、普通にやってみて動かせなかった場合は 8bit 版を使うとよさそうです
- huggingface なら
adamw_bnb_8bit
を指定するだけで使えます
training_args = TrainingArguments(per_device_train_batch_size=4, optim="adamw_bnb_8bit", **default_args)
Adafactor
- Adam のメモリ消費量を抑えた改良版
- 素の Adam よりも多少性能面で落ちているようですが、計算コストやメモリの面で大幅に改善されている
- Adafactor も huggingface なら
adafactor
を指定するだけで使えます
training_args = TrainingArguments(per_device_train_batch_size=4, optim="adafactor", **default_args)
Galore
- 勾配を低ランク行列に落とし込んでメモリ消費量を削減した optimizer
- Galore 自体は勾配を低ランク行列に落とし込む手法っぽいので、Adam (Galore 版)、Adafactor (Galore 版) とかが存在するようです
- 8bit Adam と比較しても性能低下はそれほど見られないらしい
- 一方でメモリ消費量は 8bit Adam 以下です
- Galore 論文ではメモリ消費量をわかりやすく画像でまとめてくれていますね
- galore も optim を指定するだけ
# “galore_adamw”, “galore_adamw_8bit”, “galore_adafactor” とかがある
args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
optim_target_modules=["attn", "mlp"]
)