お絵描きAI、Stable Diffusion 今回は img2img を扱います。
Stable Diffusion 画像を作っている途中を見たい!では txt2img のみでしたが、
同じモデルで実装できるということで、同時に使えるようにします。
使用したPCは
OS:Windows10 Home 64bit
CPU:Intel(R) Core(TM) i7-10750H
GPU:NVIDIA GeForce GTX 1650 Ti
RAM:16GB
です。
今回の目標は
● StableDiffusionPipelineを使わない。
● Jupyter Notebook で txt2img/img2img をシームレスに実行できるようにする。
● 各関数に visible フラグを追加して途中画像の表示/非表示を選択できるようにする。
● 作っている途中を GIF アニメーションに保存する。
とします。
「環境構築」
基本はStable DiffusionをノートPCで持ち歩きたい!と同じです。
以上の記事の 「Notebookを作成する」 で新たな Notebook を作成することにします。
ただし、 GIF アニメションを作成するため、 pillow ライブラリをインストールする必要があります。
インストールには Anaconda Prompt を使用します。手順は以下の通りです。
Anaconda Navigater を Windows のメニューから起動します。
左メニューから Environments を選択。環境 ldm の ▶ マークから Open Terminal を
クリックして Anaconda Prompt を起動します。
conda install pillow
と入力してインストールします。
「新たな Notebook の作成」
一旦、すべての Window を閉じます。
スタートメニューから Anacnda3(64ビット)→ Jupyter Notebook(ldm) を起動します。
ダークモードを使用しているので見かけが違いますが気にしないでください。
右上の New プルダウンから Python 3(ipykernel) を選択し、空のNotebookを作成します。
名前はUntitledになりますが、保存後に Rename できます。
以降、コードセルにコードを入力し、実行していきます。
「環境の確認」
Notebookの初めのセルで以下のコードを入力し、動作を確認します。
import sys
print("Python = "+sys.version)
import numpy as np
print("Numpy = "+np.__version__)
import torch
print("Pytorch = "+torch.__version__)
print("Pytorch GPU =",torch.cuda.is_available())
if torch.cuda.is_available():
print("Pytorch GPU Name = "+torch.cuda.get_device_name())
!nvcc --version
import matplotlib
print("Matplotlib = "+matplotlib.__version__)
!pwd
結果は
Pytorch Cuda が動作し、 GPU が認識されていることがわかります。
「optimized_txt2img/optimized_img2imgを関数化する」
次のセルに以下を入力(コピペでOKです)。
import argparse, os, re
import torch
import numpy as np
from random import randint
from omegaconf import OmegaConf
from PIL import Image
from tqdm.auto import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
import importlib
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def instantiate_from_config(config):
if not "target" in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def split_weighted_subprompts(text):
remaining = len(text)
prompts = []
weights = []
while remaining > 0:
if ":" in text:
idx = text.index(":")
prompt = text[:idx]
remaining -= idx
text = text[idx+1:]
if " " in text:
idx = text.index(" ")
else:
idx = len(text)
if idx != 0:
try:
weight = float(text[:idx])
except:
print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
weight = 1.0
else:
weight = 1.0
remaining -= idx
text = text[idx+1:]
prompts.append(prompt)
weights.append(weight)
else:
if len(text) > 0:
prompts.append(text)
weights.append(1.0)
remaining = 0
return prompts, weights
def load_model_from_config(ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
return sd
def load_img(image, h0, w0):
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
if h0 is not None and w0 is not None:
h, w = h0, w0
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32
print(f"New image size ({w}, {h})")
image = image.resize((w, h), resample=Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
config = "../stable-diffusion-main/optimizedSD/v1-inference.yaml"
ckpt = "../stable-diffusion-main/models/ldm/stable-diffusion-v1/model.ckpt"
sd = load_model_from_config(f"{ckpt}")
li, lo = [], []
for key, value in sd.items():
sp = key.split(".")
if (sp[0]) == "model":
if "input_blocks" in sp:
li.append(key)
elif "middle_block" in sp:
li.append(key)
elif "time_embed" in sp:
li.append(key)
else:
lo.append(key)
for key in li:
sd["model1." + key[6:]] = sd.pop(key)
for key in lo:
sd["model2." + key[6:]] = sd.pop(key)
config = OmegaConf.load(f"{config}")
model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False)
modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False)
modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
del sd
def txt2img(prompts="",H=512,W=512,C=4,f=8,ddim_steps=50,fixed_code=50,ddim_eta=0.0,
n_rows=0,scale=7.5,device='cuda',seed=None,unet_bs=1,precision='full',sampler='plms',visible=False):
tic = time.time()
if seed == None:
seed = randint(0, 1000000)
seed_everything(seed)
model.eval()
model.unet_bs = unet_bs
model.cdevice = device
model.turbo = False
modelCS.eval()
modelCS.cond_stage_model.device = device
modelFS.eval()
if device != "cpu" and precision == "autocast":
model.half()
modelCS.half()
start_code = None
if fixed_code:
start_code = torch.randn([1,C,H // f, W // f], device=device)
n_rows = n_rows if n_rows > 0 else 1
if precision == "autocast" and device != "cpu":
precision_scope = autocast
else:
precision_scope = nullcontext
with torch.no_grad():
all_samples = list()
with precision_scope("cuda"):
modelCS.to(device)
uc = None
if scale != 1.0:
uc = modelCS.get_learned_conditioning([""])
subprompts, weights = split_weighted_subprompts(prompts)
if len(subprompts) > 1:
c = torch.zeros_like(uc)
totalWeight = sum(weights)
for i in range(len(subprompts)):
weight = weights[i]
# if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
c = modelCS.get_learned_conditioning(prompts)
shape = [1, C, H // f, W // f]
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelCS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
time.sleep(1)
samples_ddim = model.sample(
S=ddim_steps,
callfunc = modelFS.decode_first_stage,#9/22追加
conditioning = c,
seed = seed,
shape = shape,
verbose = False,
unconditional_guidance_scale = scale,
unconditional_conditioning = uc,
eta = ddim_eta,
x_T = start_code,
sampler = sampler,
visible = visible
)
modelFS.to(device)
x_samples_ddim = modelFS.decode_first_stage(samples_ddim[0].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
image=x_sample.astype(np.uint8)
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
time.sleep(1)
del samples_ddim
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
toc = time.time()
time_taken = (toc - tic) / 60.0
print(("Samples finished in {0:.2f} minutes " + prompts + "\nSeeds used = {1:}").format(time_taken,seed))
#return c.cpu().numpy(),image#9/29変更
return image
print('txt2img defined !!')
def img2img(prompts="",init_img=None,ddim_steps=50,ddim_eta=0.0,n_iter=1,H=512,W=512,strength=0.75,n_samples=5,
n_rows=0,scale=7.5,device='cuda',seed=None,unet_bs=1,precision='full',sampler='ddim',visible=False):
tic = time.time()
if seed == None:
seed = randint(0, 1000000)
seed_everything(seed)
model.eval()
model.cdevice = device
model.unet_bs = unet_bs
model.turbo = False
modelCS.eval()
modelCS.cond_stage_model.device = device
modelFS.eval()
init_image = load_img(init_img, H, W).to(device)
if device != "cpu" and precision == "autocast":
model.half()
modelCS.half()
modelFS.half()
init_image = init_image.half()
batch_size = 1
n_rows = n_rows if n_rows > 0 else 1
modelFS.to(device)
init_image = repeat(init_image, "1 ... -> b ...", b=1)
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
time.sleep(1)
t_enc = int(strength*ddim_steps)
print(f"target t_enc is {t_enc} steps")
if precision == "autocast" and device != "cpu":
precision_scope = autocast
else:
precision_scope = nullcontext
with torch.no_grad():
all_samples = list()
for n in trange(n_iter,desc="Sampling"):
with precision_scope("cuda"):
modelCS.to(device)
uc = None
if scale != 1.0:
uc = modelCS.get_learned_conditioning([""])
subprompts, weights = split_weighted_subprompts(prompts)
if len(subprompts) > 1:
c = torch.zeros_like(uc)
totalWeight = sum(weights)
# normalize each "sub prompt" and add it
for i in range(len(subprompts)):
weight = weights[i]
# if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
c = modelCS.get_learned_conditioning(prompts)
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelCS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
time.sleep(1)
# encode (scaled latent)
z_enc = model.stochastic_encode(
x0 = init_latent,
t = torch.tensor([t_enc]).to(device),
ddim_eta = ddim_eta,
ddim_steps = ddim_steps,
seed = seed
)
# decode it
samples_ddim = model.sample(
S = t_enc,
conditioning = c,
callfunc = modelFS.decode_first_stage,#10/16追加
x0 = z_enc,
unconditional_guidance_scale = scale,
unconditional_conditioning = uc,
sampler = sampler,
visible = visible
)
modelFS.to(device)
x_samples_ddim = modelFS.decode_first_stage(samples_ddim[0].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
image=x_sample.astype(np.uint8)
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
time.sleep(1)
del samples_ddim
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
toc = time.time()
time_taken = (toc - tic) / 60.0
print(("Samples finished in {0:.2f} minutes " + prompts + "\nSeeds used = {1:}").format(time_taken,seed))
return image
print('img2img defined !!')
内容は ライブラリ/モージュールのインポート、共通の関数の定義、モデルの構築、関数 txt2img/img2img の定義です。
txt2img/img2img では途中画像表示に必要な callfunc = modelFS.decode_first_stage と visibleフラグに関する部分が大きな変更点です。
config = "../stable-diffusion-main/optimizedSD/v1-inference.yaml"
ckpt = "../stable-diffusion-main/models/ldm/stable-diffusion-v1/model.ckpt"
は構成したフォルダの配置で変わってきますので、適当に修正してください。
このまま実行するとエラーとなるので、実行はお待ちください!
「ddpm.pyの変更」
フォルダ構成を確認しておきましょう。
C:\Users
+---XXXX
+---Documents
+---Source
+---Python
+---stable-diffusion-main
+---optimizesdSD
+---ddpm.py ←このファイルを変更します。
+---
+---
+---models
+---ldm
+---stable-diffusion-v1
+----model.ckpt
ddpm.py の変更は多いので、以降の行数は手順を実行していく途中の行数を記入します。
※2022/10/23現在 ここ からダウンロードできるファイルを使用しています。
1.冒頭のライブラリ/モジュールの追加をします。(行数9行目)
import matplotlib.pyplot as plt#9/22追加
from PIL import Image#gif生成用
import os#gif生成用
import time, math
from tqdm.auto import trange, tqdm
import torch
from einops import rearrange
#from tqdm import tqdm#9/22削除
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.models.autoencoder import VQModelInterface
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from functools import partial
from pytorch_lightning.utilities.distributed import rank_zero_only
from ldm.util import exists, default, instantiate_from_config
from ldm.modules.diffusionmodules.util import make_beta_schedule
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
#from samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff#9/22削除
2.model.sample 関数への引数を追加をします。(行数470行目)
def sample(self,
S,
conditioning,
callfunc,#9/22追加
x0=None,
shape = None,
seed=1234,
callback=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
sampler = "plms",
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
visible=False#10/16追加
):
2.plms_sampling/ddim_sampling 関数への引数を追加をします。(行数515行目)
if sampler == "plms":
print(f'Data shape for PLMS sampling is {shape}')
samples = self.plms_sampling(conditioning, batch_size, x_latent,
callfunc,#9/22追加
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
visible=visible#10/16追加
)
elif sampler == "ddim":
samples = self.ddim_sampling(x_latent, conditioning,
callfunc,#9/22追加
S, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
mask = mask,init_latent=x_T,use_original_steps=False,#10/16変更','を追記
visible=visible#10/16追加
)
3.plms_sampling 関数を変更します。(行数576~616行目)
生成画像の表示機能と GIF アニメーション生成機能を追加します。
def plms_sampling(self, cond,b, img,
callfunc,#9/22追加
ddim_use_original_steps=False,
callback=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
visible=False#10/16追加
):
device = self.betas.device
timesteps = self.ddim_timesteps
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
old_eps = []
plt_show = False#9/22追加
fig_old = None#9/22追加
frames=[]#gif生成用
for i, step in enumerate(iterator):
#以下9/22追加
if (plt_show==True) & visible:#10/16変更
img_tmp = callfunc(img[0].cpu().unsqueeze(0))
img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
image = img_tmp.astype(np.uint8)
fig_new = plt.figure(figsize=(4, 4), dpi=120)
if fig_old is not None:
plt.close(fig_old)
plt.axis('off')
plt.text(0,-10,'{}/{}'.format(i,total_steps))
plt.imshow(image)
frames.append(Image.fromarray(image))#gif生成用
plt.pause(1.0)
fig_new.clear()
fig_old,plt_show = fig_new,False
#以上9/22追加
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
#以下9/22追加
plt_show=True
#以下gif生成
if visible:#10/16変更
img_tmp = callfunc(img[0].cpu().unsqueeze(0))
img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
image = img_tmp.astype(np.uint8)
frames.append(Image.fromarray(image))
pre_frames=[]
if os.path.isfile('plms_output.gif'):
frames_duration=[]
gif_frames = Image.open('plms_output.gif')
for idx in range(gif_frames.n_frames):
gif_frames.seek(idx)
frames_duration.append(gif_frames.info['duration'])
pre_frames.append(gif_frames.copy())
gif_frames.close()
else:
black_frame=np.zeros_like((image))
pre_frames.append(Image.fromarray(black_frame))
frames_duration=[2000]
pre_frames.extend(frames)
[frames_duration.append(500) for _ in range(len(frames)-1)]
frames_duration.append(2000)
frames=pre_frames.copy()
frames[0].save('plms_output.gif',
save_all=True, append_images=frames[1:], optimize=False, duration=frames_duration, loop=0)
#以上gif生成用
if plt_show == True:
plt.close()
plt_show = False
#以上9/22追加
return img
4.ddim_sampling 関数を変更します。(行数778~806行目)
生成画像の表示機能と GIF アニメーション生成機能を追加します。
def ddim_sampling(self, x_latent, cond,
callfunc,#9/22追加
t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
mask = None,init_latent=None,use_original_steps=False,
visible=False#10/16追加
):
timesteps = self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
x0 = init_latent
plt_show = False#10/16追加
fig_old = None#10/16追加
frames=[]#gif生成用
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
if mask is not None:
# x0_noisy = self.add_noise(mask, torch.tensor([index] * x0.shape[0]).to(self.cdevice))
x0_noisy = x0
x_dec = x0_noisy* mask + (1. - mask) * x_dec
x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
#以下10/16追加
if visible:#10/16変更
img_tmp = callfunc(x_dec[0].cpu().unsqueeze(0))
img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
image = img_tmp.astype(np.uint8)
fig_new = plt.figure(figsize=(4, 4), dpi=120)
if fig_old is not None:
plt.close(fig_old)
plt.axis('off')
plt.text(0,-10,'{}/{}'.format(i+1,total_steps))
plt.imshow(image)
frames.append(Image.fromarray(image))#gif生成用
plt.pause(1.0)
fig_new.clear()
fig_old,plt_show = fig_new,False
plt_show=True
#以下gif生成
if visible:#10/16変更
pre_frames=[]
if os.path.isfile('ddim_output.gif'):
frames_duration=[]
gif_frames = Image.open('ddim_output.gif')
for idx in range(gif_frames.n_frames):
gif_frames.seek(idx)
frames_duration.append(gif_frames.info['duration'])
pre_frames.append(gif_frames.copy())
gif_frames.close()
else:
black_frame=np.zeros_like((image))
pre_frames.append(Image.fromarray(black_frame))
frames_duration=[2000]
pre_frames.extend(frames)
[frames_duration.append(500) for _ in range(len(frames)-1)]
frames_duration.append(2000)
frames=pre_frames.copy()
frames[0].save('ddim_output.gif',
save_all=True, append_images=frames[1:], optimize=False, duration=frames_duration, loop=0)
#以上gif生成用
if (plt_show == True) & visible:#10/16変更
plt.pause(1.0)
plt.close()
plt_show = False
#以上10/16追加
if mask is not None:
return x0 * mask + (1. - mask) * x_dec
return x_dec
ddim.py を保存し、変更は完了です。
あとがきですが、修正は Visual Studio Code を使用しました。
「モデルの読み込みと txt2img/img2img の定義」
Jupyter Notebook に戻り、2番目のセルを実行します。
実行結果は以下のようになりました。
使われていないパラメータがある?などのメッセージが出ますが、問題ありませんでした。
最終的に txt2img defined !! img2img defined !! が表示されて、処理が完了します。
「txt2imgを使ってみる」
次のセルで以下のコードを入力し、実行します。
import matplotlib.pyplot as plt
%matplotlib
prompt="A photo of bright red convertible Porsche driving through the green wood"
image=txt2img(prompts=prompt,H=512,W=512,seed=14,scale=7.5,ddim_steps=25,precision='full',visible=True)
%matplotlib inline
plt.figure(figsize=(6, 6), dpi=120)
plt.axis('off')
plt.imshow(image)
pil_img = Image.fromarray(image)
pil_img.save('Car.png')
visible=True としましたので、
のように進んでいくと、別 Window に途中結果がポップアップで表示されます。
実行が完了したら、カレントディレクトリ(フォルダ)に plms_output.gif が作成されます。
※ txt2img のデフォルト sampler='plms' なので上のファイル名としました。
理由は後述しますが、ファイルがあるとマージされる仕様にしていますので、
実行前に削除もしくはリネームしてください。
リサイズ、間引きしたものが以下の通りです。最初の黒画面と最終画像だけ2秒となっています。
prompt="A photo of bright red convertible Porsche driving through the green wood"
「緑の中を走り抜けてく真っ赤なポルシェ」(Part2だけに!)
ですが、人が乗っていないと止まって見えます。
次のセルで以下のコードを入力し、実行します。
import matplotlib.pyplot as plt
%matplotlib
prompt="A portrait of Woman,beautiful face,short hair,cute eyes,beautiful composition"
image=txt2img(prompts=prompt,H=512,W=512,seed=0,scale=7.5,ddim_steps=25,precision='full',visible=False)#False:1:58s
%matplotlib inline
plt.figure(figsize=(6, 6), dpi=120)
plt.axis('off')
plt.imshow(image)
pil_img = Image.fromarray(image)
pil_img.save('Girl.png')
visible=False としましたので、途中の表示はなく、短時間で画像が生成されます。
簡単な prompt でこのような画像ができるとは「すごい」の一言です。
「img2imgを使ってみる」
はっきり言って、 img2img の遊び方で悩んで、今までの画像を作ってきました。
上の2枚の画像を Paint などのツールをつかって、下の画像を作りました。
ファイル名を input_girl_car.png として、 img2img の init_img として使います。
次のセルで以下のコードを入力し、実行します。
import matplotlib.pyplot as plt
%matplotlib
prompt="A Photo of woman sitting in bright red convertible Porsche through the green wood,\
beautiful face,short hair,cute eyes,beautiful composition,\
golden hour,overhead sunlight,dramatic lighting"
init_image=Image.open('input_girl_car.png')
image=img2img(prompts=prompt,init_img=init_image,ddim_steps=25,H=512,W=512,strength=0.75,
scale=7.5,device='cuda',seed=6,n_iter=4,precision='full',visible=True)
%matplotlib inline
plt.figure(figsize=(6, 6), dpi=120)
plt.axis('off')
plt.imshow(image)
visible=True としましたので、別 Window に途中結果がポップアップで表示され、
実行が完了したら、カレントディレクトリ(フォルダ)に ddim_output.gif が作成されます。
※ img2img のデフォルト sampler='ddim' なので上のファイル名としました。
img2img には n_iter があり、sampler 内で途中なのか初めなのか判断できません。
そのため、ファイルが無いと黒画面を追加、ファイルがあるとマージされる
仕様にしました。
次に GIF アニメーションを新規作成する場合は、実行前に削除もしくは
リネームしてください。
リサイズ、間引きしたものが以下の通りです。 iter=4 としているので、ノイズ⇒画像のように4回 sampling を繰り返しています。最初と最終画像だけ2秒となってわかるようにしました。
上の例では18×4=72回 sampling が行われます。
※ n_iter=4 でしたが、同じ画像を繰り返しており、できる画像に影響はなさそうです。
次に、以下のコードを入力し、実行してみます。
import matplotlib.pyplot as plt
%matplotlib
prompt="A Photo of woman sitting in bright red convertible Porsche through the green wood,\
beautiful face,short hair,cute eyes,beautiful composition,\
golden hour,overhead sunlight,dramatic lighting"
init_image=Image.open('input_girl_car.png')
image=img2img(prompts=prompt,init_img=init_image,ddim_steps=100,H=512,W=512,strength=0.75,
scale=7.5,device='cuda',seed=6,n_iter=1,precision='full',visible=True)#6,
%matplotlib inline
plt.figure(figsize=(6, 6), dpi=120)
plt.axis('off')
plt.imshow(image)
途中経過と最終画像は以下の通りです。
sampling は75回で、ボンネットのエンブレム、ホイール周りや地面のディテールは細かくなった気がします。顔は苦手なので、崩れ始めています。
ここまで、サラッと書いてきましたが、 prompt/seed で生成される画像が大きく変わるため、試行錯誤(いわゆるガチャ回し)を行っていました。
prompt も driving car/riding car などとすると、立ち上がったり、馬乗り、箱乗り、最悪は貞子ばりにフロントグラスから出てきたりもしました。最終的に sitting in car に落ち着きました。
「txt2imgとimg2imgを連携させる」
次に、txt2img で生成した画像を img2img に渡して画像を生成してみます。
次のセルに以下のコードを入力し、実行してみます。
import matplotlib.pyplot as plt
%matplotlib
prompt="A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors, 8k"
image=txt2img(prompts=prompt,H=512,W=512,seed=25,scale=7.5,ddim_steps=7,precision='full',visible=True)
image=img2img(prompts=prompt,init_img=Image.fromarray(image),ddim_steps=25,H=512,W=512,strength=0.75,
scale=7.5,device='cuda',seed=6,precision='full',visible=True)
%matplotlib inline
plt.figure(figsize=(6, 6), dpi=120)
plt.axis('off')
plt.imshow(image)
txt2img でぼやけた画像を生成、 img2img で詳細な画像を生成しました。
「これに何の意味があるの?」と言われてしまうかもしれません。
モデルを一回読み込んでしまえば、以降の実行で待たされる時間が短くなります。
ネット環境も要らない完全ローカル版なので、どんな画像を生成させても問題ありません。
「まとめ」
本題である img2img の関数化はそんなに難しいところはなく、存在するコードから必要のない部分の削除だけです。時間がかかったのは、
・記事にして興味を持っていただける例を探す(ガチャを回す)。
・GIFアニメーション生成の難しさ。
でした。
ガチャ回しでは Google Colaboratory を使ったりもしましたが、立ち上げの際の環境構築の時間、モデル構築の時間、ネット環境必要などデメリットも感じました。ネットでの Stable Diffusion 関連の記事で Google Colaboratry 勢が多いことを考えると、無料のサービスが制限されるのでは?と危惧しています。
GIFアニメーションについては Qiita での記事投稿で必要となりましたが、これだけで記事が書けるくらいの収穫がありました。写経ばかりでプログラミングする楽しみが薄れていましたが、記事の投稿のためとモチベーションが加わり、頑張ってみようと思っています。
長い記事となってしまいました。ここまで読み進めていただきましてありがとうございました。