16
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Stable Diffusion 画像を作っている途中を見たい!

Last updated at Posted at 2022-09-25

お絵描きAI、Stable Diffusion ですが、
スペックの低いPCでは待ち時間がヒマじゃないですか?
Stable DiffusionをノートPCで持ち歩きたい!
で、完全ローカル化しましたので、
コードをいじって画像ができてくる途中を見えるようにしました。

使用したPCは
OS:Windows10 Home 64bit
CPU:Intel(R) Core(TM) i7-10750H
GPU:NVIDIA GeForce GTX 1650 Ti
RAM:16GB
です。
CPUのみでも動作しますので上の記事を参考にしてください。

Stable Diffusion 画像を作っている途中を見たい!Part2 を投稿しました。
 txt2imgに加えてimg2imgも扱っています。興味のある方はそちらをご覧ください。

今回はサンプリング途中の画像リストを modelFirstStage の decode_first_stage 関数で処理して生成途中の画像を見えるようにします。

「jupyter notebook 内の変更」

サンプリング中に decode_first_stage 関数が使えるように
jupyter notebook セルのコードを変更します。

def txt2img(prompts="",H=512,W=512,C=4,f=8,dim_steps=50,fixed_code=50,ddim_eta=0.0,n_rows=0,scale=7.5,device='cuda',seed=None,unet_bs=1,precision='full',format_type='png',sampler='plms'):

の中で、model.sample 関数に引数 callfunc を追加して、関数の中から decode_first_stage 関数を使えるようにします。

samples_ddim = model.sample(
                   S=dim_steps,
+                  callfunc = modelFS.decode_first_stage,#←追加
                   conditioning=c,
                   seed=seed,
                   shape=shape,
                   verbose=False,
                   unconditional_guidance_scale=scale,
                   unconditional_conditioning=uc,
                   eta=ddim_eta,
                   x_T=start_code,
                   sampler = sampler,
               )

「ddpm.pyの変更」

model.Sample 側も変更が必要です。
フォルダ構成を確認しておきましょう。。

C:\Users
 +---XXXX
   +---Documents
     +---Source
       +---Python
       +---stable-diffusion-main
         +---optimizesdSD
            +---ddpm.py このファイルを変更します
            +---
            +---
         +---models
           +---ldm
             +---stable-diffusion-v1
               +----model.ckpt

となっています。
ダウンロードしてきた stable-diffusion-main 内 optimizedSD 下の ddpm.py の中を変更します。
テキストエディターなどで ddpm.py を開き編集します。
先頭付近で必要なモジュールを追加します。

...
+ import matplotlib.pyplot as plt
import time, math
from tqdm import trange, tqdm
import torch
...

次に class UNet の 関数 sample に引数 callfunc を追加。(470行付近)

@torch.no_grad()
   def sample(self,
              S,
              conditioning,
+             callfunc,#←追加
              x0=None,
              shape = None,
              seed=1234, 
              callback=None,
              img_callback=None,
              quantize_x0=False,
              eta=0.,
              mask=None,
              sampler = "plms",
              temperature=1.,
              noise_dropout=0.,
              score_corrector=None,
              corrector_kwargs=None,
              verbose=True,
              x_T=None,
              log_every_t=100,
              unconditional_guidance_scale=1.,
              unconditional_conditioning=None,
              ):

さらに(500行付近)

       if sampler == "plms":
           print(f'Data shape for PLMS sampling is {shape}')
           samples = self.plms_sampling(conditioning, batch_size, x_latent,
+                                       callfunc,#←追加
                                        callback=callback,
                                        img_callback=img_callback,
                                        quantize_denoised=quantize_x0,
                                        mask=mask, x0=x0,
                                        ddim_use_original_steps=False,
                                        noise_dropout=noise_dropout,
                                        temperature=temperature,
                                        score_corrector=score_corrector,
                                        corrector_kwargs=corrector_kwargs,
                                        log_every_t=log_every_t,
                                        unconditional_guidance_scale=unconditional_guidance_scale,
                                        unconditioning=unconditional_conditioning,
                                       )

       elif sampler == "ddim":
           samples = self.ddim_sampling(x_latent, conditioning,
+                                       callfunc,#←追加
                                        S, unconditional_guidance_scale=unconditional_guidance_scale,
                                        unconditional_conditioning=unconditional_conditioning,
                                        mask = mask,init_latent=x_T,use_original_steps=False)

plms_sampling の定義を大幅に書き換えます。(550行付近)

   @torch.no_grad()
   def plms_sampling(self, cond,b, img,
+                    callfunc,#←追加
                     ddim_use_original_steps=False,
                     callback=None, quantize_denoised=False,
                     mask=None, x0=None, img_callback=None, log_every_t=100,
                     temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                     unconditional_guidance_scale=1., unconditional_conditioning=None,):
       
       device = self.betas.device
       timesteps = self.ddim_timesteps
       time_range = np.flip(timesteps)
       total_steps = timesteps.shape[0]
       print(f"Running PLMS Sampling with {total_steps} timesteps")

       iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
       old_eps = []

+      plt_show = False#←追加
+      fig_old = None#←追加

       for i, step in enumerate(iterator):
#以下追加
+          if plt_show==True:
+              img_tmp = callfunc(img[0].cpu().unsqueeze(0))
+              img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
+              img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
+              image = img_tmp.astype(np.uint8)
+              fig_new = plt.figure(figsize=(4, 4), dpi=120)
+              if fig_old is not None:
+                  plt.close(fig_old)
+              plt.axis('off')
+              plt.text(0,-10,'{}/{}'.format(i,total_steps))
+              plt.imshow(image)
+              plt.pause(1.0)
+              fig_new.clear()
+              fig_old,plt_show = fig_new,False
#以上追加 
           index = total_steps - i - 1
           ts = torch.full((b,), step, device=device, dtype=torch.long)
           ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

           if mask is not None:
               assert x0 is not None
               img_orig = self.q_sample(x0, ts)  # TODO: deterministic forward pass?
               img = img_orig * mask + (1. - mask) * img

           outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                     quantize_denoised=quantize_denoised, temperature=temperature,
                                     noise_dropout=noise_dropout, score_corrector=score_corrector,
                                     corrector_kwargs=corrector_kwargs,
                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                     unconditional_conditioning=unconditional_conditioning,
                                     old_eps=old_eps, t_next=ts_next)
           img, pred_x0, e_t = outs
           old_eps.append(e_t)
           if len(old_eps) >= 4:
               old_eps.pop(0)
           if callback: callback(i)
           if img_callback: img_callback(pred_x0, i)

#以下追加
+          plt_show=True
+      if plt_show == True:
+          plt.close()
+          plt_show = False
#以上追加
       return img

ddim_sampling は引数だけを追加します。(730行付近)

  @torch.no_grad()
   def ddim_sampling(self, x_latent, cond,
+                    callfunc,#←追加
                     t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
                     mask = None,init_latent=None,use_original_steps=False):

ddpm.py の変更は以上です。

「txt2imgの起動」

jupyter notebook に戻り、kernel を Restart させます。
セルを順次実行し、txt2img が定義されたら以下のようなセルを実行します。

import matplotlib.pyplot as plt
-%matplotlib inline#←削除
+%matplotlib#←追加

prompt="A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors, 8k"

image=txt2img(prompts=prompt,H=512,W=512,seed=25,scale=7.5,dim_steps=50,precision='full')

+%matplotlib inline#←追加

plt.figure(figsize=(8, 8), dpi=120)
plt.axis('off')
plt.imshow(image)

途中結果は以下の通り。

Using matplotlib backend: QtAgg
seeds used =  [25]
Data shape for PLMS sampling is [1, 4, 64, 64]
Running PLMS Sampling with 50 timesteps
PLMS Sampler:   2%|█▍                                                                   | 1/50 [00:10<08:13, 10.07s/it]

のように進んでいくと、別 Window に途中結果がポップアップで表示されます。
StableDiffusion.jpg

「実行結果」

完成画像の細部に影響しそうな dim_steps を変えて途中過程と完成画像を並べてみました。

dim_steps=10
Step10.gif    Step10x.png
dim_steps=25
Step25.gif    Step25x.png
dim_steps=50
Step50.gif    Step50x.png
※アニメーションはアップロードに限りがあり、間引いていますので時間軸は一致していません。

「まとめ」

今回は画像を作っている途中が見えるようにしました。
画像細部に影響しそうな dim_steps を変えて比較してみました。
大きくすれば完成画像が高画質になるようですが、ステップが進む途中で画像が鮮明になるのではなく、最後の数ステップで細部が決まってくるようです。
実行環境は前述のように、GPU もありますが、画像表示は CPU に移して実行が必要なため、
dim_step=50 で14分かかっています。

2022年10月11日以下追記。
記事の中で GIF アニメーションを作っていますが、
備忘録としてコードを記録しておきます。
まず、ライブラリ Pillow が必要なため、 Anaconda Navigater の環境(今は ldm )から
Open Terminal でコマンドプロンプトを起動します。

(ldm) C:\Users\XXXX>conda install pillow

と入力し Pillow をインストールします。
修正するファイルは以下の ddpm.py です。

C:\Users
 +---XXXX
   +---Documents
     +---Source
       +---Python
       +---stable-diffusion-main
         +---optimizesdSD
            +---ddpm.py このファイルを変更します
            +---
            +---
         +---models
           +---ldm
             +---stable-diffusion-v1
               +----model.ckpt

先頭付近で必要なモジュールを追加します。

...
import matplotlib.pyplot as plt#記事で追加済み
+from PIL import Image  #←GIF作成用
import time, math
...

UNet クラスの plms_sampling メソッドのループ処理付近(570行付近)に以下を追加します。

...
       iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
       old_eps = []

       plt_show = False #記事で追加済み
       fig_old = None   #記事で追加済み
+      frames=[]        #←GIF作成用 空のリストを作成します。

       for i, step in enumerate(iterator):
#以下、記事で追加済み
           if plt_show==True:
               img_tmp = callfunc(img[0].cpu().unsqueeze(0))
               img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
               img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
               image = img_tmp.astype(np.uint8)
               fig_new = plt.figure(figsize=(4, 4), dpi=120)
               if fig_old is not None:
                   plt.close(fig_old)
               plt.axis('off')
               plt.text(0,-10,'{}/{}'.format(i,total_steps))
               plt.imshow(image)
+              frames.append(Image.fromarray(image))#←GIF作成用 リストにイメージを追加します。
               plt.pause(1.0)
               fig_new.clear()
               fig_old,plt_show = fig_new,False
#以上、記事で追加済み
           index = total_steps - i - 1
           ts = torch.full((b,), step, device=device, dtype=torch.long)
           ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

           if mask is not None:
               assert x0 is not None
               img_orig = self.q_sample(x0, ts)  # TODO: deterministic forward pass?
               img = img_orig * mask + (1. - mask) * img

           outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                     quantize_denoised=quantize_denoised, temperature=temperature,
                                     noise_dropout=noise_dropout, score_corrector=score_corrector,
                                     corrector_kwargs=corrector_kwargs,
                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                     unconditional_conditioning=unconditional_conditioning,
                                     old_eps=old_eps, t_next=ts_next)
           img, pred_x0, e_t = outs
           old_eps.append(e_t)
           if len(old_eps) >= 4:
               old_eps.pop(0)
           if callback: callback(i)
           if img_callback: img_callback(pred_x0, i)

#以下、記事で追加済み
           plt_show=True
#以下GIF生成用コード
+      img_tmp = callfunc(img[0].cpu().unsqueeze(0))
+      img_tmp = torch.clamp((img_tmp + 1.0) / 2.0, min=0.0, max=1.0)
+      img_tmp = 255.0 * rearrange(img_tmp[0].numpy(), "c h w -> h w c")
+      image = img_tmp.astype(np.uint8)
+      frames.append(Image.fromarray(image))#以上で最終画像をリストに追加します。
+      frames[0].save('output.gif',
+               save_all=True, append_images=frames[1:], optimize=False, duration=500, loop=0)
#以上GIF生成用
       if plt_show == True:
           plt.close()
           plt_show = False
#以上、記事で追加済み
       return img

...

以上で、カレントディレクトリ(フォルダ)に output.gif が作成されます。
大きさは512×512×(dim_steps×0.5秒)となります。
記事ではオンラインで GIF アニメーションを編集できるサイトを使用しました。

16
11
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?