4
6

More than 1 year has passed since last update.

Google ColabでLoRAを実行する

Last updated at Posted at 2023-02-23

GPUのスペック

スクリーンショット 2023-02-23 20.23.06.png

メモリは16GBですね。

作成するファイル

今回はGoogle Colab上で

output_lora_weight.ipynb

output_lora_image.ipynb

の二種類のファイルを作成します。

注意点

2023年2月23日時点で動くコードです。

また、本来作成するファイルの拡張子は.ipynbですが、Qiitaのマークダウン記述では.pyとすると、色がつくので、ソースコードの部分では.ipynbが.pyになっています。

output_lora_weight.pyはoutput_lora_weight.ipynbで、
output_lora_image.pyはoutput_lora_image.ipynbです。

ソースコード(output_lora_weight.ipynb)

セル1

output_lora_weight.py
!sudo apt-get install python3.10
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.7 1
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 2
!sudo update-alternatives --config python3
!curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
!python3 get-pip.py --force-reinstall

python3.10をColabにインストールします。

実行したら数字を入力する場面が出てきます。

好きなバージョンに該当する数字を入れてください。

セル2

output_lora_weight.py
from google.colab import drive
drive.mount('/content/drive')

ドライブをマウント

セル3

output_lora_weight.py
import os
os.makedirs("/content/drive/My Drive/work", exist_ok=True)
%cd '/content/drive/My Drive/work'

作業フォルダへ移動します

セル4

output_lora_weight.py
!git clone https://github.com/cloneofsimo/lora.git
!pip install accelerate bitsandbytes
%cd lora
!pip install .
os.makedirs("./instance_data", exist_ok=True)
os.makedirs("./output", exist_ok=True)

ここでGoogle Driveのwork/lora内にinstance_dataというフォルダができるので、Google Driveのコンソールから、学習させたいデータを入れてください。私は5枚ほどにしました。

セル5

output_lora_weight.py
!accelerate launch   \
  --num_processes=1  \
  --num_machines=1  \
  --mixed_precision="fp16"  \
  --dynamo_backend="no"  \
  training_scripts/train_lora_dreambooth.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base"  \
  --instance_prompt="sks" \
  --instance_data_dir="./instance_data" \
  --output_dir="./output" \
  --resolution=512 \
  --train_batch_size=1 \
  --color_jitter \
  --learning_rate=1e-4 \
  --learning_rate_text=5e-5 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --gradient_accumulation_steps=1 \
  --max_train_steps=1 \
  --train_text_encoder

ファインチューニングの実行を行います。

--train_batch_sizeを2以上にすると、
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
というエラーが発生します。

--max_train_stepsを2以上にすると
torch.cuda.OutOfMemoryError: CUDA out of memory.
というエラーが発生します。

しかし
それぞれの値が1では思ったような画像が生成されない可能性が高いです。

実際、私は思ったような画像が出力できませんでした。

このセルでは、Lora TRAINING DONE!という表示が出れば成功です。

ソースコード(output_lora_image.ipynb)

続いて、output_lora_image.ipynbを作成していきます。

セル1

output_lora_image.py
from google.colab import drive
drive.mount('/content/drive')

セル2

output_lora_image.py
import os
os.makedirs("/content/drive/My Drive/work", exist_ok=True)
%cd '/content/drive/My Drive/work'

セル3

output_lora_image.py
!git clone https://github.com/cloneofsimo/lora.git
!pip install accelerate bitsandbytes
%cd lora
!pip install .

ここまでは一つ目のファイルと同じです。強いて言えば、以下のコードを抜かしています。

もう既にあるディレクトリは作成する必要がないからです。

.py
os.makedirs("./instance_data", exist_ok=True)
os.makedirs("./output", exist_ok=True)

セル4

output_lora_image.py
import torch
from lora_diffusion import monkeypatch_or_replace_lora, tune_lora_scale
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
model_id = "stabilityai/stable-diffusion-2-1-base"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, 
    scheduler=EulerDiscreteScheduler.from_pretrained(
        model_id, 
        subfolder="scheduler"
    ),     
    torch_dtype=torch.float16
).to("cuda")

monkeypatch_or_replace_lora(pipe.unet, 
    torch.load(os.path.join("./output", "lora_weight.pt")))
monkeypatch_or_replace_lora(pipe.text_encoder, 
    torch.load(os.path.join("./output", "lora_weight.text_encoder.pt")), 
    target_replace_module=["CLIPAttention"])

ここでStable Diffusionのパイプラインを設定しています。

セル5

output_lora_image.py
tune_lora_scale(pipe.unet, 0.8)
tune_lora_scale(pipe.text_encoder, 0.8)

image = pipe(
    "sks on the moon", 
    num_inference_steps=50, 
    guidance_scale=7
).images[0]
image

こちらでLoRAの重みを調整して、推論を実行しています。

このセルを実行することで画像が生成されます。

参考:
https://note.com/npaka/n/ndb287a48b682

4
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
6