5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Stable Diffusion 画像を作っている途中を見たい!Part 2

Last updated at Posted at 2022-10-23

お絵描きAI、Stable Diffusion 今回は img2img を扱います。
Stable Diffusion 画像を作っている途中を見たい!では txt2img のみでしたが、
同じモデルで実装できるということで、同時に使えるようにします。

使用したPCは
OS:Windows10 Home 64bit
CPU:Intel(R) Core(TM) i7-10750H
GPU:NVIDIA GeForce GTX 1650 Ti
RAM:16GB
です。

今回の目標は
● StableDiffusionPipelineを使わない。
● Jupyter Notebook で txt2img/img2img をシームレスに実行できるようにする。
● 各関数に visible フラグを追加して途中画像の表示/非表示を選択できるようにする。
● 作っている途中を GIF アニメーションに保存する。
とします。

「環境構築」

基本はStable DiffusionをノートPCで持ち歩きたい!と同じです。
以上の記事の 「Notebookを作成する」 で新たな Notebook を作成することにします。
ただし、 GIF アニメションを作成するため、 pillow ライブラリをインストールする必要があります。
インストールには Anaconda Prompt を使用します。手順は以下の通りです。
Anaconda Navigater を Windows のメニューから起動します。
Anaconda.jpg
左メニューから Environments を選択。環境 ldm の ▶ マークから Open Terminal を
クリックして Anaconda Prompt を起動します。
Anaconda_prompt.jpg

conda install pillow

と入力してインストールします。

「新たな Notebook の作成」

一旦、すべての Window を閉じます。
スタートメニューから Anacnda3(64ビット)→ Jupyter Notebook(ldm) を起動します。
jupyternotebook.jpg
ダークモードを使用しているので見かけが違いますが気にしないでください。
右上の New プルダウンから Python 3(ipykernel) を選択し、空のNotebookを作成します。
名前はUntitledになりますが、保存後に Rename できます。
以降、コードセルにコードを入力し、実行していきます。

「環境の確認」

Notebookの初めのセルで以下のコードを入力し、動作を確認します。

import sys
print("Python = "+sys.version)
import numpy as np
print("Numpy = "+np.__version__)
import torch
print("Pytorch = "+torch.__version__)
print("Pytorch GPU =",torch.cuda.is_available())
if torch.cuda.is_available():
    print("Pytorch GPU Name = "+torch.cuda.get_device_name())
    !nvcc --version
import matplotlib
print("Matplotlib = "+matplotlib.__version__)
!pwd

結果は
Env.jpg
Pytorch Cuda が動作し、 GPU が認識されていることがわかります。

「optimized_txt2img/optimized_img2imgを関数化する」

次のセルに以下を入力(コピペでOKです)。

import argparse, os, re
import torch
import numpy as np
from random import randint
from omegaconf import OmegaConf
from PIL import Image
from tqdm.auto import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
import importlib

def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())

def instantiate_from_config(config):
    if not "target" in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)

def split_weighted_subprompts(text):
    remaining = len(text)
    prompts = []
    weights = []
    while remaining > 0:
        if ":" in text:
            idx = text.index(":") 
            prompt = text[:idx]
            remaining -= idx
            text = text[idx+1:]
            if " " in text:
                idx = text.index(" ")
            else: 
                idx = len(text)
            if idx != 0:
                try:
                    weight = float(text[:idx])
                except:
                    print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
                    weight = 1.0
            else:
                weight = 1.0
            remaining -= idx
            text = text[idx+1:]
            prompts.append(prompt)
            weights.append(weight)
        else:
            if len(text) > 0:
                prompts.append(text)
                weights.append(1.0)
            remaining = 0
    return prompts, weights

def load_model_from_config(ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    return sd

def load_img(image, h0, w0):

    w, h = image.size

    print(f"loaded input image of size ({w}, {h})")
    if h0 is not None and w0 is not None:
        h, w = h0, w0

    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 32

    print(f"New image size ({w}, {h})")
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0

config = "../stable-diffusion-main/optimizedSD/v1-inference.yaml"
ckpt = "../stable-diffusion-main/models/ldm/stable-diffusion-v1/model.ckpt"

sd = load_model_from_config(f"{ckpt}")

li, lo = [], []
for key, value in sd.items():
    sp = key.split(".")
    if (sp[0]) == "model":
        if "input_blocks" in sp:
            li.append(key)
        elif "middle_block" in sp:
            li.append(key)
        elif "time_embed" in sp:
            li.append(key)
        else:
            lo.append(key)
for key in li:
    sd["model1." + key[6:]] = sd.pop(key)
for key in lo:
    sd["model2." + key[6:]] = sd.pop(key)

config = OmegaConf.load(f"{config}")

model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False)

modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False)

modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)

del sd

def txt2img(prompts="",H=512,W=512,C=4,f=8,ddim_steps=50,fixed_code=50,ddim_eta=0.0,
            n_rows=0,scale=7.5,device='cuda',seed=None,unet_bs=1,precision='full',sampler='plms',visible=False):
    
    tic = time.time()
    if seed == None:
        seed = randint(0, 1000000)
        seed_everything(seed)

    model.eval()
    model.unet_bs = unet_bs
    model.cdevice = device
    model.turbo = False

    modelCS.eval()
    modelCS.cond_stage_model.device = device

    modelFS.eval()

    if device != "cpu" and precision == "autocast":
        model.half()
        modelCS.half()

    start_code = None
    if fixed_code:
        start_code = torch.randn([1,C,H // f, W // f], device=device)

    n_rows = n_rows if n_rows > 0 else 1

    if precision == "autocast" and device != "cpu":
        precision_scope = autocast
    else:
        precision_scope = nullcontext

    with torch.no_grad():

        all_samples = list()
        with precision_scope("cuda"):
            modelCS.to(device)
            uc = None
            if scale != 1.0:
                uc = modelCS.get_learned_conditioning([""])
                subprompts, weights = split_weighted_subprompts(prompts)
                if len(subprompts) > 1:
                    c = torch.zeros_like(uc)
                    totalWeight = sum(weights)
                    for i in range(len(subprompts)):
                        weight = weights[i]
                        # if not skip_normalize:
                        weight = weight / totalWeight
                        c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
                else:
                    c = modelCS.get_learned_conditioning(prompts)

                shape = [1, C, H // f, W // f]

                if device != "cpu":
                    mem = torch.cuda.memory_allocated() / 1e6
                    modelCS.to("cpu")
                    while torch.cuda.memory_allocated() / 1e6 >= mem:
                         time.sleep(1)

                samples_ddim = model.sample(
                    S=ddim_steps,
                    callfunc = modelFS.decode_first_stage,#9/22追加
                    conditioning = c,
                    seed = seed,
                    shape = shape,
                    verbose = False,
                    unconditional_guidance_scale = scale,
                    unconditional_conditioning = uc,
                    eta = ddim_eta,
                    x_T = start_code,
                    sampler = sampler,
                    visible = visible
                )

                modelFS.to(device)

                x_samples_ddim = modelFS.decode_first_stage(samples_ddim[0].unsqueeze(0))
                x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                image=x_sample.astype(np.uint8)
                
            if device != "cpu":
                mem = torch.cuda.memory_allocated() / 1e6
                modelFS.to("cpu")
                while torch.cuda.memory_allocated() / 1e6 >= mem:
                    time.sleep(1)
            del samples_ddim
            print("memory_final = ", torch.cuda.memory_allocated() / 1e6)

    toc = time.time()
    time_taken = (toc - tic) / 60.0

    print(("Samples finished in {0:.2f} minutes " + prompts + "\nSeeds used = {1:}").format(time_taken,seed))
    #return c.cpu().numpy(),image#9/29変更
    return image

print('txt2img defined !!')

def img2img(prompts="",init_img=None,ddim_steps=50,ddim_eta=0.0,n_iter=1,H=512,W=512,strength=0.75,n_samples=5,
            n_rows=0,scale=7.5,device='cuda',seed=None,unet_bs=1,precision='full',sampler='ddim',visible=False):

    tic = time.time()
    if seed == None:
        seed = randint(0, 1000000)
        seed_everything(seed)
    
    model.eval()
    model.cdevice = device
    model.unet_bs = unet_bs
    model.turbo = False
    
    modelCS.eval()
    modelCS.cond_stage_model.device = device
    
    modelFS.eval()

    init_image = load_img(init_img, H, W).to(device)

    if device != "cpu" and precision == "autocast":
        model.half()
        modelCS.half()
        modelFS.half()
        init_image = init_image.half()
    batch_size = 1
    n_rows = n_rows if n_rows > 0 else 1

    modelFS.to(device)

    init_image = repeat(init_image, "1 ... -> b ...", b=1)
    init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image))  # move to latent space

    if device != "cpu":
        mem = torch.cuda.memory_allocated() / 1e6
        modelFS.to("cpu")
        while torch.cuda.memory_allocated() / 1e6 >= mem:
            time.sleep(1)

    t_enc = int(strength*ddim_steps)
    print(f"target t_enc is {t_enc} steps")

    if precision == "autocast" and device != "cpu":
        precision_scope = autocast
    else:
        precision_scope = nullcontext
        
    with torch.no_grad():
        all_samples = list()
        for n in trange(n_iter,desc="Sampling"):
            with precision_scope("cuda"):
                modelCS.to(device)
                uc = None
                if scale != 1.0:
                    uc = modelCS.get_learned_conditioning([""])
                subprompts, weights = split_weighted_subprompts(prompts)
                if len(subprompts) > 1:
                    c = torch.zeros_like(uc)
                    totalWeight = sum(weights)
                    # normalize each "sub prompt" and add it
                    for i in range(len(subprompts)):
                        weight = weights[i]
                        # if not skip_normalize:
                        weight = weight / totalWeight
                        c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
                else:
                    c = modelCS.get_learned_conditioning(prompts)

                if device != "cpu":
                    mem = torch.cuda.memory_allocated() / 1e6
                    modelCS.to("cpu")
                    while torch.cuda.memory_allocated() / 1e6 >= mem:
                        time.sleep(1)

                # encode (scaled latent)
                z_enc = model.stochastic_encode(
                    x0 = init_latent,
                    t = torch.tensor([t_enc]).to(device),
                    ddim_eta = ddim_eta,
                    ddim_steps = ddim_steps,
                    seed = seed
                    )
                # decode it
                samples_ddim = model.sample(
                    S = t_enc,
                    conditioning = c,
                    callfunc = modelFS.decode_first_stage,#10/16追加
                    x0 = z_enc,
                    unconditional_guidance_scale = scale,
                    unconditional_conditioning = uc,
                    sampler = sampler,
                    visible = visible
                )

                modelFS.to(device)
                
                x_samples_ddim = modelFS.decode_first_stage(samples_ddim[0].unsqueeze(0))
                x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                image=x_sample.astype(np.uint8)
                    
                if device != "cpu":
                    mem = torch.cuda.memory_allocated() / 1e6
                    modelFS.to("cpu")
                    while torch.cuda.memory_allocated() / 1e6 >= mem:
                        time.sleep(1)

                del samples_ddim
                print("memory_final = ", torch.cuda.memory_allocated() / 1e6)

    toc = time.time()
    time_taken = (toc - tic) / 60.0
    print(("Samples finished in {0:.2f} minutes " + prompts + "\nSeeds used = {1:}").format(time_taken,seed))

    return image

print('img2img defined !!')

内容は ライブラリ/モージュールのインポート、共通の関数の定義、モデルの構築、関数 txt2img/img2img の定義です。
txt2img/img2img では途中画像表示に必要な callfunc = modelFS.decode_first_stage と visibleフラグに関する部分が大きな変更点です。

config = "../stable-diffusion-main/optimizedSD/v1-inference.yaml"
ckpt = "../stable-diffusion-main/models/ldm/stable-diffusion-v1/model.ckpt"

は構成したフォルダの配置で変わってきますので、適当に修正してください。
このまま実行するとエラーとなるので、実行はお待ちください!

「ddpm.pyの変更」

フォルダ構成を確認しておきましょう。

C:\Users
 +---XXXX
   +---Documents
     +---Source
       +---Python
       +---stable-diffusion-main
         +---optimizesdSD
            +---ddpm.py ←このファイルを変更します。
            +---
            +---
         +---models
           +---ldm
             +---stable-diffusion-v1
               +----model.ckpt

ddpm.py の変更は多いので、以降の行数は手順を実行していく途中の行数を記入します。
※2022/10/23現在 ここ からダウンロードできるファイルを使用しています。

1.冒頭のライブラリ/モジュールの追加をします。(行数9行目)

import matplotlib.pyplot as plt#9/22追加
from PIL import Image#gif生成用
import os#gif生成用
import time, math
from tqdm.auto import trange, tqdm
import torch
from einops import rearrange
#from tqdm import tqdm#9/22削除
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.models.autoencoder import VQModelInterface
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from functools import partial
from pytorch_lightning.utilities.distributed import rank_zero_only
from ldm.util import exists, default, instantiate_from_config
from ldm.modules.diffusionmodules.util import make_beta_schedule
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
#from samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff#9/22削除

2.model.sample 関数への引数を追加をします。(行数470行目)

    def sample(self,
               S,
               conditioning,
               callfunc,#9/22追加
               x0=None,
               shape = None,
               seed=1234, 
               callback=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               sampler = "plms",
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               visible=False#10/16追加
               ):

2.plms_sampling/ddim_sampling 関数への引数を追加をします。(行数515行目)

if sampler == "plms":
            print(f'Data shape for PLMS sampling is {shape}')
            samples = self.plms_sampling(conditioning, batch_size, x_latent,
                                        callfunc,#9/22追加
                                        callback=callback,
                                        img_callback=img_callback,
                                        quantize_denoised=quantize_x0,
                                        mask=mask, x0=x0,
                                        ddim_use_original_steps=False,
                                        noise_dropout=noise_dropout,
                                        temperature=temperature,
                                        score_corrector=score_corrector,
                                        corrector_kwargs=corrector_kwargs,
                                        log_every_t=log_every_t,
                                        unconditional_guidance_scale=unconditional_guidance_scale,
                                        unconditional_conditioning=unconditional_conditioning,
                                        visible=visible#10/16追加
                                        )

        elif sampler == "ddim":
            samples = self.ddim_sampling(x_latent, conditioning,
                                         callfunc,#9/22追加
                                         S, unconditional_guidance_scale=unconditional_guidance_scale,
                                         unconditional_conditioning=unconditional_conditioning,
                                         mask = mask,init_latent=x_T,use_original_steps=False,#10/16変更','を追記
                                         visible=visible#10/16追加
                                         )

3.plms_sampling 関数を変更します。(行数576~616行目)
 生成画像の表示機能と GIF アニメーション生成機能を追加します。

    def plms_sampling(self, cond,b, img,
                      callfunc,#9/22追加
                      ddim_use_original_steps=False,
                      callback=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      visible=False#10/16追加
                      ):
        
        device = self.betas.device
        timesteps = self.ddim_timesteps
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        old_eps = []

        plt_show = False#9/22追加
        fig_old = None#9/22追加
        frames=[]#gif生成用

        for i, step in enumerate(iterator):
#以下9/22追加
            if (plt_show==True) & visible:#10/16変更
                img_tmp = callfunc(img[0].cpu().unsqueeze(0))
                img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
                img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
                image = img_tmp.astype(np.uint8)
                fig_new = plt.figure(figsize=(4, 4), dpi=120)
                if fig_old is not None:
                    plt.close(fig_old)
                plt.axis('off')
                plt.text(0,-10,'{}/{}'.format(i,total_steps))
                plt.imshow(image)
                frames.append(Image.fromarray(image))#gif生成用
                plt.pause(1.0)
                fig_new.clear()
                fig_old,plt_show = fig_new,False
#以上9/22追加 
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

#以下9/22追加
            plt_show=True
#以下gif生成
        if visible:#10/16変更
            img_tmp = callfunc(img[0].cpu().unsqueeze(0))
            img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
            img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
            image = img_tmp.astype(np.uint8)
            frames.append(Image.fromarray(image))

            pre_frames=[]

            if os.path.isfile('plms_output.gif'):
                frames_duration=[]
                gif_frames = Image.open('plms_output.gif')
                for idx in range(gif_frames.n_frames):
                    gif_frames.seek(idx)
                    frames_duration.append(gif_frames.info['duration'])
                    pre_frames.append(gif_frames.copy())
                gif_frames.close()
            else:
                black_frame=np.zeros_like((image))
                pre_frames.append(Image.fromarray(black_frame))
                frames_duration=[2000]
            pre_frames.extend(frames)
            [frames_duration.append(500) for _ in range(len(frames)-1)]
            frames_duration.append(2000)
            frames=pre_frames.copy()    
            frames[0].save('plms_output.gif',
                     save_all=True, append_images=frames[1:], optimize=False, duration=frames_duration, loop=0)
#以上gif生成用
            if plt_show == True:
                plt.close()
                plt_show = False
#以上9/22追加
        return img

4.ddim_sampling 関数を変更します。(行数778~806行目)
 生成画像の表示機能と GIF アニメーション生成機能を追加します。

    def ddim_sampling(self, x_latent, cond,
                      callfunc,#9/22追加
                      t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
                      mask = None,init_latent=None,use_original_steps=False,
                      visible=False#10/16追加
                      ):

        timesteps = self.ddim_timesteps
        timesteps = timesteps[:t_start]
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        x0 = init_latent
        plt_show = False#10/16追加
        fig_old = None#10/16追加
        frames=[]#gif生成用

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)            

            if mask is not None:
                # x0_noisy = self.add_noise(mask, torch.tensor([index] * x0.shape[0]).to(self.cdevice))
                x0_noisy = x0
                x_dec = x0_noisy* mask + (1. - mask) * x_dec

            x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
#以下10/16追加
            if visible:#10/16変更
                img_tmp = callfunc(x_dec[0].cpu().unsqueeze(0))
                img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
                img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
                image = img_tmp.astype(np.uint8)
                fig_new = plt.figure(figsize=(4, 4), dpi=120)
                if fig_old is not None:
                    plt.close(fig_old)
                plt.axis('off')
                plt.text(0,-10,'{}/{}'.format(i+1,total_steps))
                plt.imshow(image)
                frames.append(Image.fromarray(image))#gif生成用
                plt.pause(1.0)
                fig_new.clear()
                fig_old,plt_show = fig_new,False

            plt_show=True
#以下gif生成
        if visible:#10/16変更
            pre_frames=[]
            if os.path.isfile('ddim_output.gif'):
                frames_duration=[]
                gif_frames = Image.open('ddim_output.gif')
                for idx in range(gif_frames.n_frames):
                    gif_frames.seek(idx)
                    frames_duration.append(gif_frames.info['duration'])
                    pre_frames.append(gif_frames.copy())
                gif_frames.close()
            else:
                black_frame=np.zeros_like((image))
                pre_frames.append(Image.fromarray(black_frame))
                frames_duration=[2000]
            pre_frames.extend(frames)
            [frames_duration.append(500) for _ in range(len(frames)-1)]
            frames_duration.append(2000)
            frames=pre_frames.copy()
            frames[0].save('ddim_output.gif',
                     save_all=True, append_images=frames[1:], optimize=False, duration=frames_duration, loop=0)
#以上gif生成用
        if (plt_show == True) & visible:#10/16変更
            plt.pause(1.0)
            plt.close()
            plt_show = False
#以上10/16追加
        if mask is not None:
            return x0 * mask + (1. - mask) * x_dec

        return x_dec

ddim.py を保存し、変更は完了です。
あとがきですが、修正は Visual Studio Code を使用しました。

「モデルの読み込みと txt2img/img2img の定義」

Jupyter Notebook に戻り、2番目のセルを実行します。
実行結果は以下のようになりました。
model.jpg
使われていないパラメータがある?などのメッセージが出ますが、問題ありませんでした。
最終的に txt2img defined !!  img2img defined !! が表示されて、処理が完了します。

「txt2imgを使ってみる」

次のセルで以下のコードを入力し、実行します。

import matplotlib.pyplot as plt
%matplotlib

prompt="A photo of bright red convertible Porsche driving through the green wood"
image=txt2img(prompts=prompt,H=512,W=512,seed=14,scale=7.5,ddim_steps=25,precision='full',visible=True)

%matplotlib inline
plt.figure(figsize=(6, 6), dpi=120)
plt.axis('off')
plt.imshow(image)

pil_img = Image.fromarray(image)
pil_img.save('Car.png')

visible=True としましたので、
plms_progress.jpg
のように進んでいくと、別 Window に途中結果がポップアップで表示されます。
Car_Progress.jpg
実行が完了したら、カレントディレクトリ(フォルダ)に plms_output.gif が作成されます。
※ txt2img のデフォルト sampler='plms' なので上のファイル名としました。
  理由は後述しますが、ファイルがあるとマージされる仕様にしていますので、
  実行前に削除もしくはリネームしてください。
リサイズ、間引きしたものが以下の通りです。最初の黒画面と最終画像だけ2秒となっています。
trim.gif Car.png
prompt="A photo of bright red convertible Porsche driving through the green wood"
「緑の中を走り抜けてく真っ赤なポルシェ」(Part2だけに!)
ですが、人が乗っていないと止まって見えます。
次のセルで以下のコードを入力し、実行します。

import matplotlib.pyplot as plt
%matplotlib

prompt="A portrait of Woman,beautiful face,short hair,cute eyes,beautiful composition"
image=txt2img(prompts=prompt,H=512,W=512,seed=0,scale=7.5,ddim_steps=25,precision='full',visible=False)#False:1:58s

%matplotlib inline
plt.figure(figsize=(6, 6), dpi=120)
plt.axis('off')
plt.imshow(image)

pil_img = Image.fromarray(image)
pil_img.save('Girl.png')

visible=False としましたので、途中の表示はなく、短時間で画像が生成されます。
Girl.jpg
簡単な prompt でこのような画像ができるとは「すごい」の一言です。

「img2imgを使ってみる」

はっきり言って、 img2img の遊び方で悩んで、今までの画像を作ってきました。
上の2枚の画像を Paint などのツールをつかって、下の画像を作りました。
input_girl_car.jpg
ファイル名を input_girl_car.png として、 img2img の init_img として使います。
次のセルで以下のコードを入力し、実行します。

import matplotlib.pyplot as plt
%matplotlib

prompt="A Photo of woman sitting in bright red convertible Porsche through the green wood,\
        beautiful face,short hair,cute eyes,beautiful composition,\
        golden hour,overhead sunlight,dramatic lighting"

init_image=Image.open('input_girl_car.png')

image=img2img(prompts=prompt,init_img=init_image,ddim_steps=25,H=512,W=512,strength=0.75,
              scale=7.5,device='cuda',seed=6,n_iter=4,precision='full',visible=True)

%matplotlib inline
plt.figure(figsize=(6, 6), dpi=120)
plt.axis('off')
plt.imshow(image)

visible=True としましたので、別 Window に途中結果がポップアップで表示され、
実行が完了したら、カレントディレクトリ(フォルダ)に ddim_output.gif が作成されます。
※ img2img のデフォルト sampler='ddim' なので上のファイル名としました。
  img2img には n_iter があり、sampler 内で途中なのか初めなのか判断できません。
  そのため、ファイルが無いと黒画面を追加、ファイルがあるとマージされる
  仕様にしました。
  次に GIF アニメーションを新規作成する場合は、実行前に削除もしくは
  リネームしてください。
リサイズ、間引きしたものが以下の通りです。 iter=4 としているので、ノイズ⇒画像のように4回 sampling を繰り返しています。最初と最終画像だけ2秒となってわかるようにしました。
trim.gif Final.png
上の例では18×4=72回 sampling が行われます。
※ n_iter=4 でしたが、同じ画像を繰り返しており、できる画像に影響はなさそうです。

次に、以下のコードを入力し、実行してみます。

import matplotlib.pyplot as plt
%matplotlib

prompt="A Photo of woman sitting in bright red convertible Porsche through the green wood,\
        beautiful face,short hair,cute eyes,beautiful composition,\
        golden hour,overhead sunlight,dramatic lighting"

init_image=Image.open('input_girl_car.png')

image=img2img(prompts=prompt,init_img=init_image,ddim_steps=100,H=512,W=512,strength=0.75,
                  scale=7.5,device='cuda',seed=6,n_iter=1,precision='full',visible=True)#6,

%matplotlib inline
plt.figure(figsize=(6, 6), dpi=120)
plt.axis('off')
plt.imshow(image)

途中経過と最終画像は以下の通りです。
trim.gif Final100.jpg
sampling は75回で、ボンネットのエンブレム、ホイール周りや地面のディテールは細かくなった気がします。顔は苦手なので、崩れ始めています。

ここまで、サラッと書いてきましたが、 prompt/seed で生成される画像が大きく変わるため、試行錯誤(いわゆるガチャ回し)を行っていました。
prompt も driving car/riding car などとすると、立ち上がったり、馬乗り、箱乗り、最悪は貞子ばりにフロントグラスから出てきたりもしました。最終的に sitting in car に落ち着きました。

「txt2imgとimg2imgを連携させる」

次に、txt2img で生成した画像を img2img に渡して画像を生成してみます。
次のセルに以下のコードを入力し、実行してみます。

import matplotlib.pyplot as plt
%matplotlib

prompt="A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors, 8k"

image=txt2img(prompts=prompt,H=512,W=512,seed=25,scale=7.5,ddim_steps=7,precision='full',visible=True)
image=img2img(prompts=prompt,init_img=Image.fromarray(image),ddim_steps=25,H=512,W=512,strength=0.75,
                  scale=7.5,device='cuda',seed=6,precision='full',visible=True)

%matplotlib inline
plt.figure(figsize=(6, 6), dpi=120)
plt.axis('off')
plt.imshow(image)

txt2img でぼやけた画像を生成、 img2img で詳細な画像を生成しました。
trim.gif FinalX.jpg
「これに何の意味があるの?」と言われてしまうかもしれません。
モデルを一回読み込んでしまえば、以降の実行で待たされる時間が短くなります。
ネット環境も要らない完全ローカル版なので、どんな画像を生成させても問題ありません。

「まとめ」

本題である img2img の関数化はそんなに難しいところはなく、存在するコードから必要のない部分の削除だけです。時間がかかったのは、
・記事にして興味を持っていただける例を探す(ガチャを回す)。
・GIFアニメーション生成の難しさ。
でした。
ガチャ回しでは Google Colaboratory を使ったりもしましたが、立ち上げの際の環境構築の時間、モデル構築の時間、ネット環境必要などデメリットも感じました。ネットでの Stable Diffusion 関連の記事で Google Colaboratry 勢が多いことを考えると、無料のサービスが制限されるのでは?と危惧しています。
GIFアニメーションについては Qiita での記事投稿で必要となりましたが、これだけで記事が書けるくらいの収穫がありました。写経ばかりでプログラミングする楽しみが薄れていましたが、記事の投稿のためとモチベーションが加わり、頑張ってみようと思っています。
長い記事となってしまいました。ここまで読み進めていただきましてありがとうございました。

5
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?