8
5

More than 1 year has passed since last update.

diffusers で Lora (safetensors) を読み込んで生成する方法

Posted at

はじめに

これはhuggingface/diffusersモジュールを利用して画像生成する際に、safetensorslora を読み込む方法です。

Civitai 等で公開されている Lorasafetensors ファイルであり、pipe.unet.load_attn_procs(model_path) で読み込むことができない。そのため、別の方法で読み込む必要がある。

今回はそれを見つけたため、備忘録

コード

DiffusersScripts の中には様々なモデルの変換をコマンドラインベースで利用できるようにまとめた Python ファイルがあり、その中に LoraDiffusers モデルに変換するためのコードを見つけた。

そのコードを見る限り、読み込むだけでも利用できそうだったため、その部分を利用させていただきました。
コードの大部分はいじってませんので、アップデートが加わっていたら編集することをお勧めします。

引用元はこちら

以下のコードを別のファイル等に保存してモジュールとして利用する
今回はファイル名 loadLora.py で保存してある。

モジュール

import torch
from safetensors.torch import load_file


def load_safetensors_lora(pipeline, checkpoint_path, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.75):
    # load LoRA weight from .safetensors
    state_dict = load_file(checkpoint_path)

    visited = []

    # directly update weight in diffusers model
    for key in state_dict:
        # it is suggested to print out the key, it usually will be something like below
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        # as we have set the alpha beforehand, so just skip
        if ".alpha" in key or key in visited:
            continue

        if "text" in key:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # update weight
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)

        # update visited list
        for item in pair_keys:
            visited.append(item)

    return pipeline

利用例

pipeStableDiffusionPipeline.from_pretrained 等で作ったオブジェクトが入っている。

pipe = load_safetensors_lora(
    pipe, 
    'YOUR SAFETENSORS LORA PATH'
).to("cuda")

全体

import datetime
import time
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL
from loadLora import load_safetensors_lora

print('model load')
load_time = time.time()
pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path='YOUR MODEL PATH',
    torch_dtype=torch.float16,
    vae=AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path='YOUR VAE MODEL PATH', 
        torch_dtype=torch.float16
    ),
)

pipe = load_safetensors_lora(
    pipe, 
    'YOUR SAFETENSORS LORA PATH'
).to("cuda")

pipe.safety_checker = None if pipe.safety_checker is None else lambda images, **kwargs: (images, False)
pipe.enable_attention_slicing()
time_load = time.time() - load_time
print(f"Models loaded in {time_load:.2f}s")

image = pipe(
    prompt='YOUR PROMPT',
    height=768,
    width=512,
    num_inference_steps=28,
    guidance_scale=12.0,
    negative_prompt='YOUR PROMPT'
).images[0]

image.save("images/" + str(datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')) + ".png")

参考

8
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
5