0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

WAN2.1のL1推移

Posted at

前回記事でHunyuanVideoのL1推移を調べた。

WAN2.1のT2Vの1.3Bモデルと14BモデルのL1推移を調べてみたい。

T2V-1.3B

step変化

wan_1.3B_01.png
wan_1.3B_02.png
wan_1.3B_03.png
wan_1.3B_04.png
wan_1.3B_05.png

shift値変化

wan_1.3B_06.png
wan_1.3B_07.png
wan_1.3B_08.png

num_frame変化

wan_1.3B_09.png

import torch
from diffusers import WanPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video
import numpy as np
import matplotlib.pyplot as plt

model_name = "wan_1.3B"
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
#model_name = "wan_14B"
#model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

def l1_callback(pipe, step, timestep, callback_kwargs):
    global l1_log
    global pre_latent

    latent_input    = callback_kwargs['latent_model_input'].detach().clone()
    noise_pred_init = callback_kwargs['noise_pred_init'].detach().clone()
    noise_pred      = callback_kwargs['noise_pred'].detach().clone()
    latent_last     = callback_kwargs['latents'].detach().clone()

    if pre_latent is None:
        pre_latent = noise_pred_init
    else:
        l1 = (noise_pred_init-pre_latent).abs().mean()/pre_latent.abs().mean()
        pre_latent = noise_pred_init
        print(f"Step {step}, Timestep {timestep}, L1 = {l1.item():.6f}")
        l1_log.append((step, timestep.item(), l1.item()))

    return callback_kwargs

fig, ax = plt.subplots(1,3,figsize=(18,5))
fig.suptitle('L1 step[10,20]')

for step in [10,20]:
    for shift in [7.0]:
        for num_frames in [1]:

            l1_log = []
            pre_latent = None
            pipe.scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
            prompt = "a photo of an astronaut riding a horse on Mars."
            video = pipe(prompt=prompt, negative_prompt="", height=320, width=512, num_frames=num_frames, num_inference_steps=step, guidance_scale=5.0,  callback_on_step_end=l1_callback).frames[0]
            
            export_to_video(video, "sample_wan00.mp4", fps=15)
            l1_log = np.array(l1_log)
            print(l1_log, l1_log.shape)
            label = 'step=%d, shift=%2.1f, frame=%d' % (step, shift, num_frames)
            ax[0].scatter(l1_log[:,0]/float(step),l1_log[:,2],s=20,label=label)
            ax[0].set_xlabel('Step')
            ax[0].set_ylabel('L1')
            ax[0].legend()
            ax[1].scatter(1000-l1_log[:,1],l1_log[:,2],s=20,label=label)
            ax[1].set_xlabel('1000-Timestep')
            ax[1].set_ylabel('L1')
            ax[1].legend()
            ax[2].plot(l1_log[:,0]/float(step),1000-l1_log[:,1],label=label)
            ax[2].set_xlabel('Step')
            ax[2].set_ylabel('1000-Timestep')
            ax[2].legend()
plt.savefig('%s_000.png' % model_name)
#plt.show()

diffusersのpipelineのcallback周りを以下のように記述。

pipeline_wan.py
                self._current_timestep = t
                latent_model_input = latents.to(transformer_dtype)
                timestep = t.expand(latents.shape[0])

                noise_pred_init = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]

                if self.do_classifier_free_guidance:
                    noise_uncond = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_uncond + guidance_scale * (noise_pred_init - noise_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}

                    #for k in callback_on_step_end_tensor_inputs:
                    #    callback_kwargs[k] = locals()[k]
                    callback_kwargs['latent_model_input'] = latent_model_input
                    callback_kwargs['noise_pred_init'] = noise_pred_init
                    callback_kwargs['noise_pred'] = noise_pred
                    callback_kwargs['latents'] = latents

                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    #latents = callback_outputs.pop("latents", latents)
                    #prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    #negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

TimestepがHunyuanVideoとWAN2.1で違う

左から二番目の図は横軸が1000-Timestepとなっているが、HunyuanVideoではshift=11のとき最後の推論stepは650付近までしかない。つまり最後の1stepでtimestepは350進む。
一方、WAN2.1でのshift=11のとき最後の推論stepが900付近にある。つまり最後の1stepでtimestepは100進む。

HunyuanVideo

sample007.png

WAN2.1

wan_1.3B_07.png

これはdiffusersコード内でtimestepsの定義が違っており、HunyuanVideoでは

pipeline_hunyuan_video.py
        import numpy as np
        sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1]
        self.scheduler.set_timesteps(sigmas=sigmas, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)

一方、WAN2.1では以下のようになっている。hunyuanvideoではsigmasを渡してるが、wan2.1ではsigmasを渡していない。

pipeline_wan.py
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

上記のコードによればunipcにおいてはdiffusersでのWANでの定義のtimestepでいいが、dpm++やEulerの場合はdiffusersでのHunyuanvideoの定義と同じ事がわかる。

            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sample_scheduler.set_timesteps(
                    sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas)
            else:
                raise NotImplementedError("Unsupported solver.")
...
def get_sampling_sigmas(sampling_steps, shift):
    sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
    sigma = (shift * sigma / (1 + (shift - 1) * sigma))

    return sigma

以下、単純なEulerタイプのスケジュラーを使用したいのでpipelineのtimestep定義をHunyuanVideo準拠に書き換えてL1を見る。

T2V-1.3B(timestep修正)

step変化

wan_1.3B_001.png
wan_1.3B_002.png
wan_1.3B_003.png
wan_1.3B_004.png
wan_1.3B_005.png

shift値変化

wan_1.3B_006.png
wan_1.3B_007.png
wan_1.3B_008.png

num_frame変化

wan_1.3B_009.png

T2V-14B(timestep修正)

step変化

wan_14B_001.png
wan_14B_002.png
wan_14B_003.png
wan_14B_004.png
wan_14B_005.png

shift値変化

wan_14B_006.png
wan_14B_007.png
wan_14B_008.png

num_frame変化

フレーム数が大きいほど推論step最初のL1が大きい。また生成フレーム長が9や17の低フレームにおいて全体的にL1が悪化する。
wan_14B_009.png

HunyuanVideoとWAN2.1の比較

HunyuanVideoでは推論stepが小さいとき、推論step初期のL1が大きく、shift値を大きくしていくほうがL1が下がってよくなるように見える。
反面、WANにおいては推論step初期のL1はそれほど大きくなく、また推論step後半のL1はshift値が大きいほどL1が大きいデメリットを感じる。

HunyuanVideo

sample007.png

WAN2.1

wan_14B_007.png

Shift関数再考

通常のshift関数はt=1で勾配1/s、t=0で勾配sである。

f(t)=\frac{s\cdot t}{1+(s-1)\cdot t}
\frac{df(t)}{dt}=\frac{s}{(1+(s-1)\cdot t)^2}=
\left\{\begin{array}{ll} 1/s & (t = 1) \\
s & (t = 0)
\end{array}
\right.

Wanではt=0,1で勾配1/s、t=1/2で勾配sになる関数を考えてみる。
これを与える関数は例えば以下の関数である。

\frac{dg(t)}{dt}=\frac{s}{1+(s^2-1)\cdot (2t-1)^2}=
\left\{\begin{array}{ll} 1/s & (t = 0,1) \\
s & (t = 1/2)
\end{array}
\right.

image.png(chatgptにグラフを頼んだ)

この微分関数(ローレンツ分布)を積分したものを$g(t)$として$g(0)=0$、$g(1)=1$となるように調整する。この積分した関数形はローレンツ分布の累積分布関数となる。

この関数は

g(t)=\frac{1}{2}(1+\frac{\arctan{\sqrt{s^2-1}}(2t-1)}{\arctan{\sqrt{s^2-1}}})

なお$t=1→0$に進むから本来shift関数を左右反転、上下反転する必要があるが、今回の場合、左右と上下を一緒に反転するとグラフは元の形と一致するため、反転を考慮する必要はない。

image.png

import matplotlib.pyplot as plt
import numpy as np

s = 7
s2 = 17
a = 1.57
x = np.arange(51)
y1 = np.arange(51)/50
y2 = 1-(s * (1-y1))/(1+(s-1)*(1-y1))
y3 = 1-(s2 * (1-y1))/(1+(s2-1)*(1-y1))
y4 = np.where(x < 25, 0.004 * x, (x-25+a)*(x-25+a)/(25+a)/(25+a)*(1-0.1)+0.1)
s = 3
y5 = 1/2*(1+np.arctan(np.sqrt(s**2-1)*(2*y1-1))/np.arctan(np.sqrt(s**2-1)))
s = 5
y6 = 1/2*(1+np.arctan(np.sqrt(s**2-1)*(2*y1-1))/np.arctan(np.sqrt(s**2-1)))
s = 7
y7 = 1/2*(1+np.arctan(np.sqrt(s**2-1)*(2*y1-1))/np.arctan(np.sqrt(s**2-1)))

plt.plot(x,y1,label='linear-50step')
plt.plot(x,y2,label='HunyuanVideo shift(s=7.0)')
plt.plot(x,y3,label='HunyuanVideo shift(s=17.0)')
plt.plot(x,y4,label='MovieGen linear-quadratic(250step-smooth)')
plt.plot(x,y5,label='New shift(s=3.0)')
plt.plot(x,y6,label='New shift(s=5.0)')
plt.plot(x,y7,label='New shift(s=7.0)')
plt.hlines(0.11, 0, 50, colors='red', linestyle='dashed')
plt.legend()
plt.show()

ここでdiffuserのeulerのschedulerのshift関数を

scheduling_flow_match_euler_discrete.py
            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)

以下のように書き換える。+1e-10を省略するとself.shift=1.0のとき分母をゼロで割るエラーが出る。

scheduling_flow_match_euler_discrete.py
            s = self.shift + 1e-10
            sigmas = 1/2*(1+np.arctan(np.sqrt(s**2-1)*(2*sigmas-1))/np.arctan(np.sqrt(s**2-1)))

T2V-14B(new shift)

結論から言うと前半の推論ではtimestep勾配を緩やかにしてL1を小さくできるが、後半の推論のtimestep勾配を緩やかにしてもL1推移は小さくはならなかった。まあ、schedulerを調整すればシグモイド状のshift関数も作れることを確認した。

step変化

wan_14B_001_shift.png
wan_14B_002_shift.png
wan_14B_003_shift.png
wan_14B_004_shift.png
wan_14B_005_shift.png

shift値変化

wan_14B_006_shift.png
wan_14B_007_shift.png
wan_14B_008_shift.png

num_frame変化

なし

まとめ

WAN2.1におけるL1を調べたが、HunyuanvideoのL1と勝手が違い、大きなshift値の入力が推論stepの効率化に寄与しない可能性はある。
推論step後半のL1が大きいのが気にかかり、推論step後半でtimestep勾配を小さくするシグモイド関数的なshift関数を新たに作ってみたが、後半のL1の大きさは特に改善しなかった。

テストのほとんどはテスト時間の都合で動画1フレームであり、wan2.1 T2V-14Bの9frameや17frameの時の低フレームでの出力動画も初期フレームにおいてすこぶる悪かったため、1frame生成結果もどれだけ信用できるのか不明です。このため生成動画の見た目の評価はほとんど信用できないと感じたためL1推移以外の結果をここでは述べません。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?