LoginSignup
4

diffusers で Lora (safetensors) を読み込んで生成する方法

Posted at

はじめに

これはhuggingface/diffusersモジュールを利用して画像生成する際に、safetensorslora を読み込む方法です。

Civitai 等で公開されている Lorasafetensors ファイルであり、pipe.unet.load_attn_procs(model_path) で読み込むことができない。そのため、別の方法で読み込む必要がある。

今回はそれを見つけたため、備忘録

コード

DiffusersScripts の中には様々なモデルの変換をコマンドラインベースで利用できるようにまとめた Python ファイルがあり、その中に LoraDiffusers モデルに変換するためのコードを見つけた。

そのコードを見る限り、読み込むだけでも利用できそうだったため、その部分を利用させていただきました。
コードの大部分はいじってませんので、アップデートが加わっていたら編集することをお勧めします。

引用元はこちら

以下のコードを別のファイル等に保存してモジュールとして利用する
今回はファイル名 loadLora.py で保存してある。

モジュール

import torch
from safetensors.torch import load_file


def load_safetensors_lora(pipeline, checkpoint_path, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.75):
    # load LoRA weight from .safetensors
    state_dict = load_file(checkpoint_path)

    visited = []

    # directly update weight in diffusers model
    for key in state_dict:
        # it is suggested to print out the key, it usually will be something like below
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        # as we have set the alpha beforehand, so just skip
        if ".alpha" in key or key in visited:
            continue

        if "text" in key:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # update weight
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)

        # update visited list
        for item in pair_keys:
            visited.append(item)

    return pipeline

利用例

pipeStableDiffusionPipeline.from_pretrained 等で作ったオブジェクトが入っている。

pipe = load_safetensors_lora(
    pipe, 
    'YOUR SAFETENSORS LORA PATH'
).to("cuda")

全体

import datetime
import time
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL
from loadLora import load_safetensors_lora

print('model load')
load_time = time.time()
pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path='YOUR MODEL PATH',
    torch_dtype=torch.float16,
    vae=AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path='YOUR VAE MODEL PATH', 
        torch_dtype=torch.float16
    ),
)

pipe = load_safetensors_lora(
    pipe, 
    'YOUR SAFETENSORS LORA PATH'
).to("cuda")

pipe.safety_checker = None if pipe.safety_checker is None else lambda images, **kwargs: (images, False)
pipe.enable_attention_slicing()
time_load = time.time() - load_time
print(f"Models loaded in {time_load:.2f}s")

image = pipe(
    prompt='YOUR PROMPT',
    height=768,
    width=512,
    num_inference_steps=28,
    guidance_scale=12.0,
    negative_prompt='YOUR PROMPT'
).images[0]

image.save("images/" + str(datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')) + ".png")

参考

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
4