stable diffusion XLをfp8で学習させたい
先日、stable diffusionのfp8(8bit浮動小数点数)でのトレーニングがdevブランチで可能になりました。
追記:現在はmainブランチにもマージされているようです。
fp8で学習させる事により、消費VRAMがかなり抑えられます。
本記事ではsdxlのトレーニングをしたことがある人向けに手順を書いていきます。
設定にもよりますが、vram6GB
でのトレーニングも可能なようです。
トレーニングは5%ほど遅くなりますが、3GBほどvramを節約できるようです。
cuda11.8の場合(cuda12.1の場合は適宜改変してください)
新しく適当なフォルダを作り、power shellを管理者権限
で開いて、以下の内容をコピペします。
自分のディレクトリに合うように最初の一行目(cdのところ)を改変してください
fp8でのトレーニングにはpytorchの2.1以上が必要なので必ず2.1以上のバージョンとそれに対応するxformersをインストールする必要があります。
cuda11.8以外の環境の場合は自分に合ったバージョンを下のリンクから選択してください。
2024/1/21日現在mainブランチにはマージされていないようなので今回はdevブランチをクローンします。
1/24追記 23日にmainブランチにfp8でのトレーニングがマージされました。
本記事ではdevブランチをcloneしていますが、mainブランチでも同様の手順で行けると思います。
cd D:\hoge
git clone -b dev https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install xformers==0.0.22.post7 --index-url https://download.pytorch.org/whl/cu118
accelerate config
その後いくつか選択肢があるので次のように選択してください。
- This machine
- No distributed training
- NO
- NO
- NO
- all
- fp16
次にbitsandbytesをインストールします。
python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
続いてlycorisをインストールします。
python -m pip install lycoris_lora
おまけでlionオプティマイザーを使う人向け
python -m pip install lion_pytorch
ここまでで環境作成は終了です。
fp8でのトレーニング方法
普段使っているsdxl用のbatファイルに--fp8_baseを追加してください。
--fp8_base
今回の学習で私が使用したbatファイルです。
vram24GBでの設定です。OOMした場合はbatch数を下げて下さい。
rem sd-scriptsのパス 自身の環境に合わせて変更してください
set sd_path="D:\sd-scripts"
rem 学習に使うモデルのパス 自身の環境に合わせて変更してください
set ckpt_file="D:\ComfyUI\models\checkpoints\animagine-xl-3.0.safetensors"
rem 学習データセットのパス 自身の環境に合わせて変更してください
set image_path="D:\chara"
set output_path="D:\sd-scripts\output_xl"
rem output name
set file_prefix="chara_animaginev3_locon"
set train_mode=--train_data_dir=%image_path%
set learning_rate=4e-7
rem default 4e-7
set train_batch_size=10
set num_epochs=8
set save_every_n_epochs=2
cd /d %sd_path%
call venv\Scripts\activate.bat & call :main & pause & exit
:main
accelerate launch --num_cpu_threads_per_process 4 sdxl_train_network.py ^
--network_module=lycoris.kohya ^
--pretrained_model_name_or_path=%ckpt_file% ^
%train_mode% ^
--output_dir=%output_path% ^
--caption_extension=".txt" ^
--shuffle_caption ^
--prior_loss_weight=1 ^
--resolution=1024,1024 ^
--enable_bucket ^
--min_bucket_reso=128 ^
--max_bucket_reso=1980 ^
--train_batch_size=%train_batch_size% ^
--learning_rate=%learning_rate% ^
--lr_warmup_steps=0 ^
--network_train_unet_only ^
--max_train_epochs=%num_epochs% ^
--save_every_n_epochs=%save_every_n_epochs% ^
--mixed_precision="fp16" ^
--save_precision="fp16" ^
--xformers ^
--max_data_loader_n_workers=10 ^
--save_model_as=safetensors ^
--output_name=%file_prefix% ^
--seed=42 ^
--network_dim=64 ^
--network_alpha=1 ^
--network_args "algo=locon" "conv_dim=64" "conv_alpha=1" "dropout=0.05" ^
--max_token_length=225 ^
--lr_scheduler=cosine ^
--lr_scheduler_num_cycles=7 ^
--optimizer_type="lion" ^
--gradient_checkpointing ^
--cache_latents ^
--fp8_base ^
--persistent_data_loader_workers
exit /b
素のanimaginev3とfp8でのloraの比較
素のanimagine xl v3 (toki bunny blue archive)
fp8で学習させたtoki (toki bunny blue archive)
最後に、私のx(twitter)アカウントです。(宣伝)