お絵描きAI、Stable Diffusion ですが、
スペックの低いPCでは待ち時間がヒマじゃないですか?
Stable DiffusionをノートPCで持ち歩きたい!
で、完全ローカル化しましたので、
コードをいじって画像ができてくる途中を見えるようにしました。
使用したPCは
OS:Windows10 Home 64bit
CPU:Intel(R) Core(TM) i7-10750H
GPU:NVIDIA GeForce GTX 1650 Ti
RAM:16GB
です。
CPUのみでも動作しますので上の記事を参考にしてください。
※Stable Diffusion 画像を作っている途中を見たい!Part2 を投稿しました。
txt2imgに加えてimg2imgも扱っています。興味のある方はそちらをご覧ください。
今回はサンプリング途中の画像リストを modelFirstStage の decode_first_stage 関数で処理して生成途中の画像を見えるようにします。
「jupyter notebook 内の変更」
サンプリング中に decode_first_stage 関数が使えるように
jupyter notebook セルのコードを変更します。
def txt2img(prompts="",H=512,W=512,C=4,f=8,dim_steps=50,fixed_code=50,ddim_eta=0.0,n_rows=0,scale=7.5,device='cuda',seed=None,unet_bs=1,precision='full',format_type='png',sampler='plms'):
の中で、model.sample 関数に引数 callfunc を追加して、関数の中から decode_first_stage 関数を使えるようにします。
samples_ddim = model.sample(
S=dim_steps,
+ callfunc = modelFS.decode_first_stage,#←追加
conditioning=c,
seed=seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code,
sampler = sampler,
)
「ddpm.pyの変更」
model.Sample 側も変更が必要です。
フォルダ構成を確認しておきましょう。。
C:\Users
+---XXXX
+---Documents
+---Source
+---Python
+---stable-diffusion-main
+---optimizesdSD
+---ddpm.py ←このファイルを変更します。
+---
+---
+---models
+---ldm
+---stable-diffusion-v1
+----model.ckpt
となっています。
ダウンロードしてきた stable-diffusion-main 内 optimizedSD 下の ddpm.py の中を変更します。
テキストエディターなどで ddpm.py を開き編集します。
先頭付近で必要なモジュールを追加します。
...
+ import matplotlib.pyplot as plt
import time, math
from tqdm import trange, tqdm
import torch
...
次に class UNet の 関数 sample に引数 callfunc を追加。(470行付近)
@torch.no_grad()
def sample(self,
S,
conditioning,
+ callfunc,#←追加
x0=None,
shape = None,
seed=1234,
callback=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
sampler = "plms",
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
):
さらに(500行付近)
if sampler == "plms":
print(f'Data shape for PLMS sampling is {shape}')
samples = self.plms_sampling(conditioning, batch_size, x_latent,
+ callfunc,#←追加
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditioning=unconditional_conditioning,
)
elif sampler == "ddim":
samples = self.ddim_sampling(x_latent, conditioning,
+ callfunc,#←追加
S, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
mask = mask,init_latent=x_T,use_original_steps=False)
plms_sampling の定義を大幅に書き換えます。(550行付近)
@torch.no_grad()
def plms_sampling(self, cond,b, img,
+ callfunc,#←追加
ddim_use_original_steps=False,
callback=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,):
device = self.betas.device
timesteps = self.ddim_timesteps
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
old_eps = []
+ plt_show = False#←追加
+ fig_old = None#←追加
for i, step in enumerate(iterator):
#以下追加
+ if plt_show==True:
+ img_tmp = callfunc(img[0].cpu().unsqueeze(0))
+ img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
+ img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
+ image = img_tmp.astype(np.uint8)
+ fig_new = plt.figure(figsize=(4, 4), dpi=120)
+ if fig_old is not None:
+ plt.close(fig_old)
+ plt.axis('off')
+ plt.text(0,-10,'{}/{}'.format(i,total_steps))
+ plt.imshow(image)
+ plt.pause(1.0)
+ fig_new.clear()
+ fig_old,plt_show = fig_new,False
#以上追加
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
#以下追加
+ plt_show=True
+ if plt_show == True:
+ plt.close()
+ plt_show = False
#以上追加
return img
ddim_sampling は引数だけを追加します。(730行付近)
@torch.no_grad()
def ddim_sampling(self, x_latent, cond,
+ callfunc,#←追加
t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
mask = None,init_latent=None,use_original_steps=False):
ddpm.py の変更は以上です。
「txt2imgの起動」
jupyter notebook に戻り、kernel を Restart させます。
セルを順次実行し、txt2img が定義されたら以下のようなセルを実行します。
import matplotlib.pyplot as plt
-%matplotlib inline#←削除
+%matplotlib#←追加
prompt="A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors, 8k"
image=txt2img(prompts=prompt,H=512,W=512,seed=25,scale=7.5,dim_steps=50,precision='full')
+%matplotlib inline#←追加
plt.figure(figsize=(8, 8), dpi=120)
plt.axis('off')
plt.imshow(image)
途中結果は以下の通り。
Using matplotlib backend: QtAgg
seeds used = [25]
Data shape for PLMS sampling is [1, 4, 64, 64]
Running PLMS Sampling with 50 timesteps
PLMS Sampler: 2%|█▍ | 1/50 [00:10<08:13, 10.07s/it]
のように進んでいくと、別 Window に途中結果がポップアップで表示されます。
「実行結果」
完成画像の細部に影響しそうな dim_steps を変えて途中過程と完成画像を並べてみました。
dim_steps=10
dim_steps=25
dim_steps=50
※アニメーションはアップロードに限りがあり、間引いていますので時間軸は一致していません。
「まとめ」
今回は画像を作っている途中が見えるようにしました。
画像細部に影響しそうな dim_steps を変えて比較してみました。
大きくすれば完成画像が高画質になるようですが、ステップが進む途中で画像が鮮明になるのではなく、最後の数ステップで細部が決まってくるようです。
実行環境は前述のように、GPU もありますが、画像表示は CPU に移して実行が必要なため、
dim_step=50 で14分かかっています。
2022年10月11日以下追記。
記事の中で GIF アニメーションを作っていますが、
備忘録としてコードを記録しておきます。
まず、ライブラリ Pillow が必要なため、 Anaconda Navigater の環境(今は ldm )から
Open Terminal でコマンドプロンプトを起動します。
(ldm) C:\Users\XXXX>conda install pillow
と入力し Pillow をインストールします。
修正するファイルは以下の ddpm.py です。
C:\Users
+---XXXX
+---Documents
+---Source
+---Python
+---stable-diffusion-main
+---optimizesdSD
+---ddpm.py ←このファイルを変更します。
+---
+---
+---models
+---ldm
+---stable-diffusion-v1
+----model.ckpt
先頭付近で必要なモジュールを追加します。
...
import matplotlib.pyplot as plt#記事で追加済み
+from PIL import Image #←GIF作成用
import time, math
...
UNet クラスの plms_sampling メソッドのループ処理付近(570行付近)に以下を追加します。
...
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
old_eps = []
plt_show = False #記事で追加済み
fig_old = None #記事で追加済み
+ frames=[] #←GIF作成用 空のリストを作成します。
for i, step in enumerate(iterator):
#以下、記事で追加済み
if plt_show==True:
img_tmp = callfunc(img[0].cpu().unsqueeze(0))
img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
image = img_tmp.astype(np.uint8)
fig_new = plt.figure(figsize=(4, 4), dpi=120)
if fig_old is not None:
plt.close(fig_old)
plt.axis('off')
plt.text(0,-10,'{}/{}'.format(i,total_steps))
plt.imshow(image)
+ frames.append(Image.fromarray(image))#←GIF作成用 リストにイメージを追加します。
plt.pause(1.0)
fig_new.clear()
fig_old,plt_show = fig_new,False
#以上、記事で追加済み
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
#以下、記事で追加済み
plt_show=True
#以下GIF生成用コード
+ img_tmp = callfunc(img[0].cpu().unsqueeze(0))
+ img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
+ img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
+ image = img_tmp.astype(np.uint8)
+ frames.append(Image.fromarray(image))#以上で最終画像をリストに追加します。
+ frames[0].save('output.gif',
+ save_all=True, append_images=frames[1:], optimize=False, duration=500, loop=0)
#以上GIF生成用
if plt_show == True:
plt.close()
plt_show = False
#以上、記事で追加済み
return img
...
以上で、カレントディレクトリ(フォルダ)に output.gif が作成されます。
大きさは512×512×(dim_steps×0.5秒)となります。
記事ではオンラインで GIF アニメーションを編集できるサイトを使用しました。