WHY
google colabを使う場合、diffusersで毎回モデルをhugging faceからダウンロードするのに時間がかかるので、よく使うモデルはmy driveに保存しておきたい(次回以降はフォルダパスを指定すればよい)。
以下のようにすれば一応保存できるが、不必要なファイルまで一緒にダウンロードされてしまう。場合によっては数十GBになるのでこれも避けたい。
!git lfs install
!git clone repositoryURL
HOW
方針:通常のやり方でモデルをhugging faceからダウンロードするとキャッシュフォルダに保存されるので、それを保存しなおす。
まず、デフォルトのキャッシュフォルダは以下になっている。
cache_dir = "/root/.cache/huggingface"
この状態でモデルを使おうとすると先ほど設定したキャッシュフォルダに自動でダウンロードされる。
!pip install --upgrade torch diffusers accelerate -q
# 普通にpip install transformersするとバージョンが低くてエラーになる。現時点では4.26.0.dev0
!pip install git+https://github.com/huggingface/transformers
import torch
from diffusers import StableDiffusionDepth2ImgPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1",
torch_dtype=torch.float16,
)
この時どのように保存されるかというと、こんな感じにsnapshotフォルダにシンボリックリンクが置かれて、本体はblobフォルダに保存される。
/root/.cache/huggingface
├── diffusers
│ └── models--stabilityai--stable-diffusion-2-1
│ ├── blobs
│ │ ├── 1238522277c48923ff2751e238f2742c562e45643f3d50cc93d163cb30638b0c
│ │ ├── 33898360ff5abe5e3c51667a5a7d2fd41238de486cd797d2f70fbe45e1a1c310
│ │ ├── 469be27c5c010538f845f518c4f5e8574c78f7c8
│ │ ├── 5294955ff7801083f720b34b55d0f1f51313c5c5
│ │ ├── 536b82b4e3c62c4898b4ac8725bc514f2a98f5de
│ │ ├── 76e821f1b6f0a9709293c3b6b51ed90980b3166b
│ │ ├── 9b1458658e8651398962171a8c5c56c5c0bd5aea
│ │ ├── 9c60528fdcb99a7caf834426a94ea13c56cf422b
│ │ ├── a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
│ │ ├── a4302e1efa25f3a47ceb7536bc335715ad9d1f203e90c2d25507600d74006e89
│ │ ├── ae0c5be6f35217e51c4c000fd325d8de0294e99c
│ │ ├── cce6febb0b6d876ee5eb24af35e27e764eb4f9b1d0b7c026c8c3333d4cfc916c
│ │ ├── cefd6c7732bbda467ed1ec3e33cfa332a5fe32cf
│ │ ├── e9c787e9388134c1a25dc69934a51a32a2683b38b8a9b017e1f3a692b8ed6b98
│ │ ├── f4fe219b936c0e171504b4bba0c33c7bef6ea211
│ │ └── f97af6a6a8235236b1346312f328569ce2d70f81
│ ├── refs
│ │ └── main
│ └── snapshots
│ └── 36a01dc742066de2e8c91e7cf0b8f6b53ef53da1
│ ├── feature_extractor
│ │ └── preprocessor_config.json -> ../../../blobs/5294955ff7801083f720b34b55d0f1f51313c5c5
│ ├── model_index.json -> ../../blobs/cefd6c7732bbda467ed1ec3e33cfa332a5fe32cf
│ ├── scheduler
│ │ └── scheduler_config.json -> ../../../blobs/536b82b4e3c62c4898b4ac8725bc514f2a98f5de
│ ├── text_encoder
│ │ ├── config.json -> ../../../blobs/9c60528fdcb99a7caf834426a94ea13c56cf422b
│ │ ├── model.safetensors -> ../../../blobs/cce6febb0b6d876ee5eb24af35e27e764eb4f9b1d0b7c026c8c3333d4cfc916c
│ │ └── pytorch_model.bin -> ../../../blobs/e9c787e9388134c1a25dc69934a51a32a2683b38b8a9b017e1f3a692b8ed6b98
│ ├── tokenizer
│ │ ├── merges.txt -> ../../../blobs/76e821f1b6f0a9709293c3b6b51ed90980b3166b
│ │ ├── special_tokens_map.json -> ../../../blobs/ae0c5be6f35217e51c4c000fd325d8de0294e99c
│ │ ├── tokenizer_config.json -> ../../../blobs/f4fe219b936c0e171504b4bba0c33c7bef6ea211
│ │ └── vocab.json -> ../../../blobs/469be27c5c010538f845f518c4f5e8574c78f7c8
│ ├── unet
│ │ ├── config.json -> ../../../blobs/9b1458658e8651398962171a8c5c56c5c0bd5aea
│ │ ├── diffusion_pytorch_model.bin -> ../../../blobs/33898360ff5abe5e3c51667a5a7d2fd41238de486cd797d2f70fbe45e1a1c310
│ │ └── diffusion_pytorch_model.safetensors -> ../../../blobs/1238522277c48923ff2751e238f2742c562e45643f3d50cc93d163cb30638b0c
│ └── vae
│ ├── config.json -> ../../../blobs/f97af6a6a8235236b1346312f328569ce2d70f81
│ ├── diffusion_pytorch_model.bin -> ../../../blobs/a4302e1efa25f3a47ceb7536bc335715ad9d1f203e90c2d25507600d74006e89
│ └── diffusion_pytorch_model.safetensors -> ../../../blobs/a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
└── hub
└── version.txt
13 directories, 34 files
したがって、リンクから本体を特定してリネームしてやればよい。
from glob import glob
import subprocess
import os
model_dir = "saved_model"
for fl in glob(os.path.join(cache_dir, "diffusers","**", "*.*"), recursive=True):
if "/snapshots/" in fl:
src = subprocess.check_output(["readlink", "-f", fl]).decode().strip()
dst = os.path.join(model_dir, *fl.split("/snapshots/")[1].split("/")[1:])
cmd = f"mkdir -p {os.path.dirname(dst)}"
subprocess.run(cmd.split(" ")) #; print(cmd)
cmd = f"mv {src} {dst}" #; print(cmd)
subprocess.run(cmd.split(" "))