1
0

Colab Enterpriseを使ってLoRA学習

Last updated at Posted at 2023-12-21

Colab Enterpriseを使ってStable DiffusionのLoRA学習を行なってみたいと思います。

Colab Enterpriseは、ひとことで言うと「GCP上のColabratory」に相当するサービスです。ColabratoryやJupyter Notebookと同じように、ブラウザ上でPythonのコードを実行できます。課金は、使用量に応じて課金される従量課金です。有料のColab環境という点ではColabratory Proに似ていますが、個人のクレジットカード登録などが必要ないため、企業内で使用しやすいサービスになっています。

今回はこちらをAI画像生成の実験環境として使用します。Colab Enterpriseを使うメリットは、開発者だけではなく、デザイナーやイラストレーターの方々にも気軽に試してもらいやすいことです。

ただし、今回は画像の生成ではなく、気軽にGPU環境を用意できるクラウドのメリットを活かし、LoRAの学習について紹介してみたいと思います。

LoRAは少ない計算コストでモデルの追加学習を行なう手法です。画像生成の場合、絵柄やキャラクターを学習するために使用されることが多いようです。LoRA学習用途でよく使用されるKohya's GUIをColab Enterprise上で動かしてみます。

1. 環境準備

Colab Enterpriseでは、ランタイムテンプレートを設定することで、ノートブックの実行環境をきめ細かく設定することができます。今回は、次のようなランタイムテンプレートを作成しました。

  • インスタンスタイプ: n1-standard-4
  • GPU: T4
  • ディスクサイズ: 100GB
  • アイドル状態でのシャットダウン: 20分

今回実行したノートブックはこちらに公開しました。
https://gist.github.com/takada-at/545214ed2ddb68b2c526e25bb0f22e0c

基本的にはこちらを順番に実行していくだけです。コードの中身はリポジトリをチェックアウトし、依存ライブラリをインストールしているだけのシンプルなものです。サービスはgradioを通じて公開されるので、パスワード認証を設定しています。

2. 実行

ノートブックを順番に実行していくと、Kohya's GUIが起動します。以下のようにURLが表示されるのでこちらをブラウザで開きます。

kohya.png

表示されたURLにアクセスし、パスワードを入力すると以下のような画面が表示され、LoRAの学習を実行することができます。

スクリーンショット 2023-12-20 12.50.27.png

3. SDXLの学習

Stable Diffusion XL1.0(SDXL1.0)はStable Diffusionの最新版のモデルです。モデル自体が巨大なため、LoRAの学習にもかなりの性能が要求され、実行のハードルが高いです。今回は、Colab Enterprise上でもSDXLでの実行が確認できたので、実行時の設定を共有したいと思います。

ランタイムテンプレート

  • インスタンスタイプ: n1-highmem-4
  • GPU: T4
  • ディスクサイズ: 100GB
  • アイドル状態でのシャットダウン: 20分

メモリ16GBではSDXLの学習に不足するようだったので、メモリ32GBのn1-highmem-4インスタンスを使用します。

学習時の設定

以下のKohya's GUIの公式リポジトリに従って、SDXL用の設定を適用していきます。わかりにくいところのみ以下に解説します。

LR Scheduler / Optimizerの設定

  • LR Scheduler: constant_with_warmupが推奨されているようです。
  • Optimizer: Adafactorに設定し、Optimizer extra argumentに"scale_parameter=False", "relative_step=False", "warmup_init=False"を設定(下記スクリーンショットを参照)。

スクリーンショット 2023-12-20 12.52.19.png

  • Cache text encoder outputs: 「Parameters」「Basic」の「Cache text encoder outputs」をオンにする。

  • UNETのみ学習: 「Parameters」「Advanced」「Additional parameters」に"--network_train_unet_only"を設定し、UNETのみ学習するように設定します。

スクリーンショット 2023-12-20 12.51.21.png

  • Gradient checkpointing: 「Parameters」「Advanced」の「Gradient checkpointing」をオンにする。

以上ひととおりの設定をすれば、T4環境でもLoRA学習が実行できました。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0