LoginSignup
1
1

More than 1 year has passed since last update.

diffusers用のstable diffusionモデルをmy driveに保存する方法(google colab)。

Last updated at Posted at 2023-01-22

WHY

google colabを使う場合、diffusersで毎回モデルをhugging faceからダウンロードするのに時間がかかるので、よく使うモデルはmy driveに保存しておきたい(次回以降はフォルダパスを指定すればよい)。

以下のようにすれば一応保存できるが、不必要なファイルまで一緒にダウンロードされてしまう。場合によっては数十GBになるのでこれも避けたい。

!git lfs install
!git clone repositoryURL

HOW

方針:通常のやり方でモデルをhugging faceからダウンロードするとキャッシュフォルダに保存されるので、それを保存しなおす。

まず、デフォルトのキャッシュフォルダは以下になっている。

cache_dir = "/root/.cache/huggingface"

この状態でモデルを使おうとすると先ほど設定したキャッシュフォルダに自動でダウンロードされる。

!pip install  --upgrade torch diffusers accelerate -q

# 普通にpip install transformersするとバージョンが低くてエラーになる。現時点では4.26.0.dev0
!pip install git+https://github.com/huggingface/transformers

import torch
from diffusers import StableDiffusionDepth2ImgPipeline

pipe = StableDiffusionPipeline.from_pretrained(
   "stabilityai/stable-diffusion-2-1",
   torch_dtype=torch.float16,
)

この時どのように保存されるかというと、こんな感じにsnapshotフォルダにシンボリックリンクが置かれて、本体はblobフォルダに保存される。

/root/.cache/huggingface
├── diffusers
│   └── models--stabilityai--stable-diffusion-2-1
│       ├── blobs
│       │   ├── 1238522277c48923ff2751e238f2742c562e45643f3d50cc93d163cb30638b0c
│       │   ├── 33898360ff5abe5e3c51667a5a7d2fd41238de486cd797d2f70fbe45e1a1c310
│       │   ├── 469be27c5c010538f845f518c4f5e8574c78f7c8
│       │   ├── 5294955ff7801083f720b34b55d0f1f51313c5c5
│       │   ├── 536b82b4e3c62c4898b4ac8725bc514f2a98f5de
│       │   ├── 76e821f1b6f0a9709293c3b6b51ed90980b3166b
│       │   ├── 9b1458658e8651398962171a8c5c56c5c0bd5aea
│       │   ├── 9c60528fdcb99a7caf834426a94ea13c56cf422b
│       │   ├── a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
│       │   ├── a4302e1efa25f3a47ceb7536bc335715ad9d1f203e90c2d25507600d74006e89
│       │   ├── ae0c5be6f35217e51c4c000fd325d8de0294e99c
│       │   ├── cce6febb0b6d876ee5eb24af35e27e764eb4f9b1d0b7c026c8c3333d4cfc916c
│       │   ├── cefd6c7732bbda467ed1ec3e33cfa332a5fe32cf
│       │   ├── e9c787e9388134c1a25dc69934a51a32a2683b38b8a9b017e1f3a692b8ed6b98
│       │   ├── f4fe219b936c0e171504b4bba0c33c7bef6ea211
│       │   └── f97af6a6a8235236b1346312f328569ce2d70f81
│       ├── refs
│       │   └── main
│       └── snapshots
│           └── 36a01dc742066de2e8c91e7cf0b8f6b53ef53da1
│               ├── feature_extractor
│               │   └── preprocessor_config.json -> ../../../blobs/5294955ff7801083f720b34b55d0f1f51313c5c5
│               ├── model_index.json -> ../../blobs/cefd6c7732bbda467ed1ec3e33cfa332a5fe32cf
│               ├── scheduler
│               │   └── scheduler_config.json -> ../../../blobs/536b82b4e3c62c4898b4ac8725bc514f2a98f5de
│               ├── text_encoder
│               │   ├── config.json -> ../../../blobs/9c60528fdcb99a7caf834426a94ea13c56cf422b
│               │   ├── model.safetensors -> ../../../blobs/cce6febb0b6d876ee5eb24af35e27e764eb4f9b1d0b7c026c8c3333d4cfc916c
│               │   └── pytorch_model.bin -> ../../../blobs/e9c787e9388134c1a25dc69934a51a32a2683b38b8a9b017e1f3a692b8ed6b98
│               ├── tokenizer
│               │   ├── merges.txt -> ../../../blobs/76e821f1b6f0a9709293c3b6b51ed90980b3166b
│               │   ├── special_tokens_map.json -> ../../../blobs/ae0c5be6f35217e51c4c000fd325d8de0294e99c
│               │   ├── tokenizer_config.json -> ../../../blobs/f4fe219b936c0e171504b4bba0c33c7bef6ea211
│               │   └── vocab.json -> ../../../blobs/469be27c5c010538f845f518c4f5e8574c78f7c8
│               ├── unet
│               │   ├── config.json -> ../../../blobs/9b1458658e8651398962171a8c5c56c5c0bd5aea
│               │   ├── diffusion_pytorch_model.bin -> ../../../blobs/33898360ff5abe5e3c51667a5a7d2fd41238de486cd797d2f70fbe45e1a1c310
│               │   └── diffusion_pytorch_model.safetensors -> ../../../blobs/1238522277c48923ff2751e238f2742c562e45643f3d50cc93d163cb30638b0c
│               └── vae
│                   ├── config.json -> ../../../blobs/f97af6a6a8235236b1346312f328569ce2d70f81
│                   ├── diffusion_pytorch_model.bin -> ../../../blobs/a4302e1efa25f3a47ceb7536bc335715ad9d1f203e90c2d25507600d74006e89
│                   └── diffusion_pytorch_model.safetensors -> ../../../blobs/a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
└── hub
    └── version.txt

13 directories, 34 files

したがって、リンクから本体を特定してリネームしてやればよい。

from glob import glob
import subprocess
import os

model_dir = "saved_model"

for fl in glob(os.path.join(cache_dir, "diffusers","**", "*.*"), recursive=True):
  if "/snapshots/" in fl:
    src = subprocess.check_output(["readlink", "-f", fl]).decode().strip()
    dst = os.path.join(model_dir, *fl.split("/snapshots/")[1].split("/")[1:])
    cmd = f"mkdir -p {os.path.dirname(dst)}"
    subprocess.run(cmd.split(" ")) #; print(cmd)
    cmd = f"mv {src} {dst}" #; print(cmd)
    subprocess.run(cmd.split(" "))
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1