はじめに
これはhuggingface/diffusersモジュールを利用して画像生成する際に、safetensors
の lora
を読み込む方法です。
Civitai 等で公開されている Lora
は safetensors
ファイルであり、pipe.unet.load_attn_procs(model_path)
で読み込むことができない。そのため、別の方法で読み込む必要がある。
今回はそれを見つけたため、備忘録
コード
Diffusers
の Scripts
の中には様々なモデルの変換をコマンドラインベースで利用できるようにまとめた Python
ファイルがあり、その中に Lora
を Diffusers
モデルに変換するためのコードを見つけた。
そのコードを見る限り、読み込むだけでも利用できそうだったため、その部分を利用させていただきました。
コードの大部分はいじってませんので、アップデートが加わっていたら編集することをお勧めします。
引用元はこちら
以下のコードを別のファイル等に保存してモジュールとして利用する
今回はファイル名 loadLora.py
で保存してある。
モジュール
import torch
from safetensors.torch import load_file
def load_safetensors_lora(pipeline, checkpoint_path, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.75):
# load LoRA weight from .safetensors
state_dict = load_file(checkpoint_path)
visited = []
# directly update weight in diffusers model
for key in state_dict:
# it is suggested to print out the key, it usually will be something like below
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
# as we have set the alpha beforehand, so just skip
if ".alpha" in key or key in visited:
continue
if "text" in key:
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
curr_layer = pipeline.text_encoder
else:
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
curr_layer = pipeline.unet
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
# update visited list
for item in pair_keys:
visited.append(item)
return pipeline
利用例
pipe
に StableDiffusionPipeline.from_pretrained
等で作ったオブジェクトが入っている。
pipe = load_safetensors_lora(
pipe,
'YOUR SAFETENSORS LORA PATH'
).to("cuda")
全体
import datetime
import time
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL
from loadLora import load_safetensors_lora
print('model load')
load_time = time.time()
pipe = StableDiffusionPipeline.from_pretrained(
pretrained_model_name_or_path='YOUR MODEL PATH',
torch_dtype=torch.float16,
vae=AutoencoderKL.from_pretrained(
pretrained_model_name_or_path='YOUR VAE MODEL PATH',
torch_dtype=torch.float16
),
)
pipe = load_safetensors_lora(
pipe,
'YOUR SAFETENSORS LORA PATH'
).to("cuda")
pipe.safety_checker = None if pipe.safety_checker is None else lambda images, **kwargs: (images, False)
pipe.enable_attention_slicing()
time_load = time.time() - load_time
print(f"Models loaded in {time_load:.2f}s")
image = pipe(
prompt='YOUR PROMPT',
height=768,
width=512,
num_inference_steps=28,
guidance_scale=12.0,
negative_prompt='YOUR PROMPT'
).images[0]
image.save("images/" + str(datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')) + ".png")
参考