お絵描きAI、Stable Diffusion ですが、
スペックの低いPCでは待ち時間がヒマじゃないですか?
Stable DiffusionをノートPCで持ち歩きたい!
で、完全ローカル化しましたので、
コードをいじって画像ができてくる途中を見えるようにしました。
使用したPCは
OS:Windows10 Home 64bit
CPU:Intel(R) Core(TM) i7-10750H
GPU:NVIDIA GeForce GTX 1650 Ti
RAM:16GB
です。
CPUのみでも動作しますので上の記事を参考にしてください。
※Stable Diffusion 画像を作っている途中を見たい!Part2 を投稿しました。
 txt2imgに加えてimg2imgも扱っています。興味のある方はそちらをご覧ください。
今回はサンプリング途中の画像リストを modelFirstStage の decode_first_stage 関数で処理して生成途中の画像を見えるようにします。
「jupyter notebook 内の変更」
サンプリング中に decode_first_stage 関数が使えるように
jupyter notebook セルのコードを変更します。
def txt2img(prompts="",H=512,W=512,C=4,f=8,dim_steps=50,fixed_code=50,ddim_eta=0.0,n_rows=0,scale=7.5,device='cuda',seed=None,unet_bs=1,precision='full',format_type='png',sampler='plms'):
の中で、model.sample 関数に引数 callfunc を追加して、関数の中から decode_first_stage 関数を使えるようにします。
samples_ddim = model.sample(
                   S=dim_steps,
+                  callfunc = modelFS.decode_first_stage,#←追加
                   conditioning=c,
                   seed=seed,
                   shape=shape,
                   verbose=False,
                   unconditional_guidance_scale=scale,
                   unconditional_conditioning=uc,
                   eta=ddim_eta,
                   x_T=start_code,
                   sampler = sampler,
               )
「ddpm.pyの変更」
model.Sample 側も変更が必要です。
フォルダ構成を確認しておきましょう。。
C:\Users
 +---XXXX
   +---Documents
     +---Source
       +---Python
       +---stable-diffusion-main
         +---optimizesdSD
            +---ddpm.py ←このファイルを変更します。
            +---
            +---
         +---models
           +---ldm
             +---stable-diffusion-v1
               +----model.ckpt
となっています。
ダウンロードしてきた stable-diffusion-main 内 optimizedSD 下の ddpm.py の中を変更します。
テキストエディターなどで ddpm.py を開き編集します。
先頭付近で必要なモジュールを追加します。
...
+ import matplotlib.pyplot as plt
import time, math
from tqdm import trange, tqdm
import torch
...
次に class UNet の 関数 sample に引数 callfunc を追加。(470行付近)
@torch.no_grad()
   def sample(self,
              S,
              conditioning,
+             callfunc,#←追加
              x0=None,
              shape = None,
              seed=1234, 
              callback=None,
              img_callback=None,
              quantize_x0=False,
              eta=0.,
              mask=None,
              sampler = "plms",
              temperature=1.,
              noise_dropout=0.,
              score_corrector=None,
              corrector_kwargs=None,
              verbose=True,
              x_T=None,
              log_every_t=100,
              unconditional_guidance_scale=1.,
              unconditional_conditioning=None,
              ):
さらに(500行付近)
       if sampler == "plms":
           print(f'Data shape for PLMS sampling is {shape}')
           samples = self.plms_sampling(conditioning, batch_size, x_latent,
+                                       callfunc,#←追加
                                        callback=callback,
                                        img_callback=img_callback,
                                        quantize_denoised=quantize_x0,
                                        mask=mask, x0=x0,
                                        ddim_use_original_steps=False,
                                        noise_dropout=noise_dropout,
                                        temperature=temperature,
                                        score_corrector=score_corrector,
                                        corrector_kwargs=corrector_kwargs,
                                        log_every_t=log_every_t,
                                        unconditional_guidance_scale=unconditional_guidance_scale,
                                        unconditioning=unconditional_conditioning,
                                       )
       elif sampler == "ddim":
           samples = self.ddim_sampling(x_latent, conditioning,
+                                       callfunc,#←追加
                                        S, unconditional_guidance_scale=unconditional_guidance_scale,
                                        unconditional_conditioning=unconditional_conditioning,
                                        mask = mask,init_latent=x_T,use_original_steps=False)
plms_sampling の定義を大幅に書き換えます。(550行付近)
   @torch.no_grad()
   def plms_sampling(self, cond,b, img,
+                    callfunc,#←追加
                     ddim_use_original_steps=False,
                     callback=None, quantize_denoised=False,
                     mask=None, x0=None, img_callback=None, log_every_t=100,
                     temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                     unconditional_guidance_scale=1., unconditional_conditioning=None,):
       
       device = self.betas.device
       timesteps = self.ddim_timesteps
       time_range = np.flip(timesteps)
       total_steps = timesteps.shape[0]
       print(f"Running PLMS Sampling with {total_steps} timesteps")
       iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
       old_eps = []
+      plt_show = False#←追加
+      fig_old = None#←追加
       for i, step in enumerate(iterator):
#以下追加
+          if plt_show==True:
+              img_tmp = callfunc(img[0].cpu().unsqueeze(0))
+              img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
+              img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
+              image = img_tmp.astype(np.uint8)
+              fig_new = plt.figure(figsize=(4, 4), dpi=120)
+              if fig_old is not None:
+                  plt.close(fig_old)
+              plt.axis('off')
+              plt.text(0,-10,'{}/{}'.format(i,total_steps))
+              plt.imshow(image)
+              plt.pause(1.0)
+              fig_new.clear()
+              fig_old,plt_show = fig_new,False
#以上追加 
           index = total_steps - i - 1
           ts = torch.full((b,), step, device=device, dtype=torch.long)
           ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
           if mask is not None:
               assert x0 is not None
               img_orig = self.q_sample(x0, ts)  # TODO: deterministic forward pass?
               img = img_orig * mask + (1. - mask) * img
           outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                     quantize_denoised=quantize_denoised, temperature=temperature,
                                     noise_dropout=noise_dropout, score_corrector=score_corrector,
                                     corrector_kwargs=corrector_kwargs,
                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                     unconditional_conditioning=unconditional_conditioning,
                                     old_eps=old_eps, t_next=ts_next)
           img, pred_x0, e_t = outs
           old_eps.append(e_t)
           if len(old_eps) >= 4:
               old_eps.pop(0)
           if callback: callback(i)
           if img_callback: img_callback(pred_x0, i)
#以下追加
+          plt_show=True
+      if plt_show == True:
+          plt.close()
+          plt_show = False
#以上追加
       return img
ddim_sampling は引数だけを追加します。(730行付近)
  @torch.no_grad()
   def ddim_sampling(self, x_latent, cond,
+                    callfunc,#←追加
                     t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
                     mask = None,init_latent=None,use_original_steps=False):
ddpm.py の変更は以上です。
「txt2imgの起動」
jupyter notebook に戻り、kernel を Restart させます。
セルを順次実行し、txt2img が定義されたら以下のようなセルを実行します。
import matplotlib.pyplot as plt
-%matplotlib inline#←削除
+%matplotlib#←追加
prompt="A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors, 8k"
image=txt2img(prompts=prompt,H=512,W=512,seed=25,scale=7.5,dim_steps=50,precision='full')
+%matplotlib inline#←追加
plt.figure(figsize=(8, 8), dpi=120)
plt.axis('off')
plt.imshow(image)
途中結果は以下の通り。
Using matplotlib backend: QtAgg
seeds used =  [25]
Data shape for PLMS sampling is [1, 4, 64, 64]
Running PLMS Sampling with 50 timesteps
PLMS Sampler:   2%|█▍                                                                   | 1/50 [00:10<08:13, 10.07s/it]
のように進んでいくと、別 Window に途中結果がポップアップで表示されます。

「実行結果」
完成画像の細部に影響しそうな dim_steps を変えて途中過程と完成画像を並べてみました。
dim_steps=10
 
    
dim_steps=25
 
    
dim_steps=50
 
    
※アニメーションはアップロードに限りがあり、間引いていますので時間軸は一致していません。
「まとめ」
今回は画像を作っている途中が見えるようにしました。
画像細部に影響しそうな dim_steps を変えて比較してみました。
大きくすれば完成画像が高画質になるようですが、ステップが進む途中で画像が鮮明になるのではなく、最後の数ステップで細部が決まってくるようです。
実行環境は前述のように、GPU もありますが、画像表示は CPU に移して実行が必要なため、
dim_step=50 で14分かかっています。
2022年10月11日以下追記。
記事の中で GIF アニメーションを作っていますが、
備忘録としてコードを記録しておきます。
まず、ライブラリ Pillow が必要なため、 Anaconda Navigater の環境(今は ldm )から
Open Terminal でコマンドプロンプトを起動します。
(ldm) C:\Users\XXXX>conda install pillow
と入力し Pillow をインストールします。
修正するファイルは以下の ddpm.py です。
C:\Users
 +---XXXX
   +---Documents
     +---Source
       +---Python
       +---stable-diffusion-main
         +---optimizesdSD
            +---ddpm.py ←このファイルを変更します。
            +---
            +---
         +---models
           +---ldm
             +---stable-diffusion-v1
               +----model.ckpt
先頭付近で必要なモジュールを追加します。
...
import matplotlib.pyplot as plt#記事で追加済み
+from PIL import Image  #←GIF作成用
import time, math
...
UNet クラスの plms_sampling メソッドのループ処理付近(570行付近)に以下を追加します。
...
       iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
       old_eps = []
       plt_show = False #記事で追加済み
       fig_old = None   #記事で追加済み
+      frames=[]        #←GIF作成用 空のリストを作成します。
       for i, step in enumerate(iterator):
#以下、記事で追加済み
           if plt_show==True:
               img_tmp = callfunc(img[0].cpu().unsqueeze(0))
               img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
               img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
               image = img_tmp.astype(np.uint8)
               fig_new = plt.figure(figsize=(4, 4), dpi=120)
               if fig_old is not None:
                   plt.close(fig_old)
               plt.axis('off')
               plt.text(0,-10,'{}/{}'.format(i,total_steps))
               plt.imshow(image)
+              frames.append(Image.fromarray(image))#←GIF作成用 リストにイメージを追加します。
               plt.pause(1.0)
               fig_new.clear()
               fig_old,plt_show = fig_new,False
#以上、記事で追加済み
           index = total_steps - i - 1
           ts = torch.full((b,), step, device=device, dtype=torch.long)
           ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
           if mask is not None:
               assert x0 is not None
               img_orig = self.q_sample(x0, ts)  # TODO: deterministic forward pass?
               img = img_orig * mask + (1. - mask) * img
           outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                     quantize_denoised=quantize_denoised, temperature=temperature,
                                     noise_dropout=noise_dropout, score_corrector=score_corrector,
                                     corrector_kwargs=corrector_kwargs,
                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                     unconditional_conditioning=unconditional_conditioning,
                                     old_eps=old_eps, t_next=ts_next)
           img, pred_x0, e_t = outs
           old_eps.append(e_t)
           if len(old_eps) >= 4:
               old_eps.pop(0)
           if callback: callback(i)
           if img_callback: img_callback(pred_x0, i)
#以下、記事で追加済み
           plt_show=True
#以下GIF生成用コード
+      img_tmp = callfunc(img[0].cpu().unsqueeze(0))
+      img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
+      img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
+      image = img_tmp.astype(np.uint8)
+      frames.append(Image.fromarray(image))#以上で最終画像をリストに追加します。
+      frames[0].save('output.gif',
+               save_all=True, append_images=frames[1:], optimize=False, duration=500, loop=0)
#以上GIF生成用
       if plt_show == True:
           plt.close()
           plt_show = False
#以上、記事で追加済み
       return img
...
以上で、カレントディレクトリ(フォルダ)に output.gif が作成されます。
大きさは512×512×(dim_steps×0.5秒)となります。
記事ではオンラインで GIF アニメーションを編集できるサイトを使用しました。
