LoginSignup
10
7

More than 1 year has passed since last update.

【Stable Diffusion】DiffusersでWebUIのHires.fix(latent)を再現してみる(Jupyter/Colab/ipythonなどで実行可能)

Last updated at Posted at 2023-06-19

はじめに

Stable Diffusionの優れた拡張機能が使えるアプリとして有名なAUTOMATIC1111さんのstable-diffusion-webuiがありますが、中でも高画質化に貢献するHires.fixをDiffusersだけで再現できないかと思い、実装してみました。

実装のリポジトリ

こちらにアップロードしました。あくまで私が解釈した処理なので、本家の処理方法と少し違うかもしれませんがご容赦ください。

本記事で行える機能

  • :white_check_mark: HuggingFaceやcivitaiにあるsafetensorsファイルを用意するだけでOK
  • :white_check_mark: Hires.fixをDiffusersで再現 WebUI:AUTOMATIC1111/stable-diffusion-webui
    • :white_check_mark: 潜在変数モード(Latent)
    • :heavy_multiplication_x: GANモデル
  • :white_check_mark: 多段階のアップスケーリング(Hires.fixの拡張)
  • :white_check_mark: LoRAの読み込み(Safetensorsファイル)
  • :heavy_multiplication_x: Controlnet
  • :heavy_multiplication_x: マルチバッチ生成

Hires.fix(Latentモード)の仕組み

本家の実装の中身と、Web UIで変更できる各パラメータの挙動を観察したところ、Hires.fixのLatentモードは、以下のステップから構成されていると推測しました。

  1. 通常のtxt2imgで画像生成
  2. 1.で得られたLatent(潜在変数)をアップサンプリング
  3. アップサンプリング後のLatentをVAEでデコードし画像を取得
  4. 3.で取得した画像を初期画像としてimg2imgで画像生成

これをDiffusersだけで再現してみます。

実装(抜粋)

Jupyterやipythonでも気軽に生成できるように、画像ジェネレータをクラスとして定義しました。
txt2imgとimg2imgを両方扱うので、単純にpipeを2個用意します。

class StableDiffusionImageGenerator:
    def __init__(
            self,
            sd_safetensor_path: str,
            device: str="cuda",
            dtype: torch.dtype=torch.float16,
            ):
        self.device = torch.device(device)
        self.pipe = StableDiffusionPipeline.from_ckpt(
            sd_safetensor_path,
            torch_dtype=dtype,
        ).to(device)
        self.pipe_i2i = StableDiffusionImg2ImgPipeline.from_ckpt(
            sd_safetensor_path,
            torch_dtype=dtype,
        ).to(device)

safetensors形式のLoRAを読み込む機能を追加します。

class StableDiffusionImageGenerator:
    ...
    def load_lora(self, safetensor_path, alpha=0.75):
        self.pipe = load_safetensors_lora(self.pipe, safetensor_path, alpha=alpha, device=self.device)
        self.pipe_i2i = load_safetensors_lora(self.pipe_i2i, safetensor_path, alpha=alpha, device=self.device)

txt2imgを行う関数を定義します。

class StableDiffusionImageGenerator:
    ...
    def diffusion_from_noise(
        self,
        prompt,
        negative_prompt,
        scheduler_name="dpm++_2m_karras",
        num_inference_steps=20, 
        guidance_scale=9.5,
        width=512,
        height=512,
        output_type="pil",
        decode_factor=0.18215,
        seed=1234,
        save_path=None
        ):
        
        self.pipe.scheduler = SCHEDULERS[scheduler_name].from_config(self.pipe.scheduler.config)
        self.pipe.scheduler.set_timesteps(num_inference_steps, self.device)
        seed = random.randint(1, 1000000000) if seed == -1 else seed
        
        with torch.no_grad():
            latents = self.pipe(
                prompt=prompt, 
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps, 
                generator=torch.manual_seed(seed),
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                output_type="latent"
            ).images # 1x4x(W/8)x(H/8)
        
            if save_path is not None:
                pil_image = self.decode_latents_to_PIL_image(latents, decode_factor)
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                pil_image.save(save_path, quality=95)
        
            if output_type == "latent":
                return latents
            elif output_type == "pil":
                return self.decode_latents_to_PIL_image(latents, decode_factor)
            else:
                raise NotImplementedError()

img2imgを行う関数を定義します。

class StableDiffusionImageGenerator:
    ...
    def diffusion_from_image(
            self,
            prompt,
            negative_prompt,
            image,
            scheduler_name="dpm++_2m_karras",
            num_inference_steps=20,
            denoising_strength=0.58,
            guidance_scale=10,
            output_type="pil",
            decode_factor=0.18215,
            seed=1234,
            save_path=None
            ):

        self.pipe_i2i.scheduler = SCHEDULERS[scheduler_name].from_config(self.pipe_i2i.scheduler.config)
        self.pipe_i2i.scheduler.set_timesteps(num_inference_steps, self.device)
        seed = random.randint(1, 1000000000) if seed == -1 else seed

        with torch.no_grad():
            latents = self.pipe_i2i(
                prompt=prompt, 
                negative_prompt=negative_prompt,
                image=image,
                num_inference_steps=num_inference_steps, 
                strength=denoising_strength,
                generator=torch.manual_seed(seed),
                guidance_scale=guidance_scale,
                output_type="latent"
            ).images # 1x4x(W/8)x(H/8)

            if save_path is not None:
                pil_image = self.decode_latents_to_PIL_image(latents, decode_factor)
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                pil_image.save(save_path, quality=95)

            if output_type == "latent":
                return latents
            elif output_type == "pil":
                return self.decode_latents_to_PIL_image(latents, decode_factor)
            else:
                raise NotImplementedError()

さいごに、txt2img、img2imgを組み合わせて多段階のスケーリングを行う関数を定義します。

class StableDiffusionImageGenerator:
    ...
    def diffusion_enhance(
            self,
            prompt,
            negative_prompt,
            scheduler_name="dpm++_2m_karras",
            num_inference_steps=20,
            num_inference_steps_enhance=20,
            guidance_scale=10,
            width=512,
            height=512,
            seed=1234,
            upscale_target="latent", # "latent" or "pil"
            interpolate_mode="nearest",
            antialias = True,
            upscale_by=1.8,
            enhance_steps=2, # 2=Hires.fix
            denoising_strength=0.58,
            output_type="pil",
            decode_factor=0.15,
            decode_factor_final=0.18215,
            save_dir="output"
            ):
        
        with torch.no_grad():
            w_init = calc_pix_8(width)
            h_init = calc_pix_8(height)
            w_final = calc_pix_8(w_init * upscale_by)
            h_final = calc_pix_8(h_init * upscale_by)
            resolution_pairs = [(calc_pix_8(x), calc_pix_8(y)) for x, y 
                    in zip(np.linspace(w_init, w_final, enhance_steps),
                            np.linspace(h_init, h_final, enhance_steps))
                    ]
            image = None
            now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

            if enhance_steps == 1: # Single generation
                image = self.diffusion_from_noise(
                        prompt,
                        negative_prompt,
                        scheduler_name=scheduler_name,
                        num_inference_steps=num_inference_steps, 
                        guidance_scale=guidance_scale,
                        width=w_final,
                        height=h_final,
                        output_type=output_type,
                        decode_factor=decode_factor_final,
                        seed=seed,
                        save_path=os.path.join(save_dir, f"{now_str}.jpg")
                    )
                return image

            
            for i, (w, h) in enumerate(resolution_pairs):

                if image is None: # Step 1: Generate low-quality image
                    image = self.diffusion_from_noise(
                        prompt,
                        negative_prompt,
                        scheduler_name=scheduler_name,
                        num_inference_steps=num_inference_steps, 
                        guidance_scale=guidance_scale,
                        width=w,
                        height=h,
                        output_type=upscale_target,
                        decode_factor=decode_factor,
                        seed=seed,
                        save_path=os.path.join(save_dir, f"{now_str}_{i}.jpg")
                    )
                    continue

                # Step 2: Interpolate latent or image -> PIL image
                if upscale_target == "latent":
                    image = torch.nn.functional.interpolate(
                            image,
                            (h // 8, w // 8),
                            mode=interpolate_mode,
                            antialias=True if antialias and interpolate_mode != "nearest" else False,
                        )
                    image = self.decode_latents_to_PIL_image(image, decode_factor)
                else:
                    image = image.resize((w, h), Image.Resampling.LANCZOS)

                # Step 3: Generate image (i2i) 
                if i < len(resolution_pairs) - 1:
                    image = self.diffusion_from_image(
                        prompt,
                        negative_prompt,
                        image,
                        scheduler_name=scheduler_name,
                        num_inference_steps=int(num_inference_steps_enhance / denoising_strength) + 1, 
                        denoising_strength=denoising_strength,
                        guidance_scale=guidance_scale,
                        output_type=upscale_target,
                        decode_factor=decode_factor,
                        seed=seed,
                        save_path=os.path.join(save_dir, f"{now_str}_{i}.jpg")
                    )

                else: # Final enhance
                    image = self.diffusion_from_image(
                        prompt,
                        negative_prompt,
                        image,
                        scheduler_name=scheduler_name,
                        num_inference_steps=int(num_inference_steps_enhance / denoising_strength) + 1, 
                        denoising_strength=denoising_strength,
                        guidance_scale=guidance_scale,
                        output_type=output_type,
                        decode_factor=decode_factor_final,
                        seed=seed,
                        save_path=os.path.join(save_dir, f"{now_str}_{i}.jpg")
                    )
                    return image

Jupyter/Colab/ipython等での使い方

  • メインクラスをインポートします。
from s2d2 import StableDiffusionImageGenerator
  • 好きなSDモデル、LoRAを読み込みます(適用する強さ(アルファ値)を指定)。
generator = StableDiffusionImageGenerator(
    "braBeautifulRealistic_brav5.safetensors",
    device="cuda",
)
generator.load_lora("hogeLoRA.safetensors", alpha=0.2)
generator.load_lora("fugaLoRA.safetensors", alpha=0.15)
  • 各パラメータを指定して、画像生成を実行するだけです。
image = generator.diffusion_enhance(
          prompt,
          negative_prompt,
          scheduler_name="dpm++_2m_karras", # [1]
          num_inference_steps=20, # [2]
          num_inference_steps_enhance=20, # [3]
          guidance_scale=10,  # [4]
          width=700, # [5]
          height=500, # [6]
          seed=-1, # [7]
          upscale_target="latent", # [8] "latent" or "pil". pil mode is temporary implemented.
          interpolate_mode="bicubic", # [9]
          antialias=True, # [10]
          upscale_by=1.8, # [11]
          enhance_steps=2, # [12] 2=Hires.fix
          denoising_strength=0.60, # [13]
          output_type="pil", # [14] "latent" or "pil"
          decode_factor=0.15, # [15] Denominator when decoding latents. Used to adjust the saturation of the image during decoding.
          decode_factor_final=0.18215, # [16] Denominator when decoding final latents.
          )
image.save("generated_image.jpg") # or just "image" to display image in jupyter

Web UIのパラメータとの対応
image

画像生成サンプル

  • 利用したモデル: Counterfeit-V30.safetensors
  • 初期解像度: 696x496
  • アップスケーリング倍数値: 1.8
  • 最終目的解像度: 696x496(x1.8, 最も近い8の倍数) = 1248x888

2段階生成(Hires.fix)

image

N段階生成(例:4段階)

初期解像度と目的解像度の間を等間隔に分け、少しずつ解像度を上げます。
image

Latentアップスケーリングの有無による生成画像の比較

  • アップスケーリングなし:696x496での一回の生成のみ

  • アップスケーリングあり: 2段階生成(Hires.fix)、 696x496から1248x888へアップスケーリング

  • プロンプト: "1girl, solo, full body, blue eyes, looking at viewer, hairband, bangs, brown hair, long hair, smile, blue eyes, wine-red dress, outdoor, night, moonlight, castle, flowers, garden"

  • ネガティブプロンプト: "EasyNegative, extra fingers, fewer fingers, bad hands"

”wine-red dress”があまり考慮されておらず、青っぽい衣装になってしまいました。。
”wine red dress”とすべきだったかもしれません(ハイフンでつなげるとトークナイザが分解できないため)。

image
image
image
image
image
image
image

単体で生成した画像よりも、ディテールが細かく描かれた画像が出力されていることがわかります。

最初から高解像度で生成することもできますが、いきなり高解像度の潜在変数を拡散するのは得意としないため、高解像度の画像生成ができる「Hires.fix」では①まずは得意な低解像度で生成、②①で生成された画像の潜在変数を使うことで、構図を保って高解像度化する、という流れができたようです。

さいごに

思い付きで実装したものであるため正常に生成できないケースがあることや、できることの制限が多いのはご容赦いただければ幸いです。

次はControlNetを追加してみようかなと思います。

参考文献

実装するにあたり、以下のリポジトリを参考にさせていただきました。

10
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
7