前回記事でHunyuanVideoのL1推移を調べた。
WAN2.1のT2Vの1.3Bモデルと14BモデルのL1推移を調べてみたい。
T2V-1.3B
step変化
shift値変化
num_frame変化
import torch
from diffusers import WanPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video
import numpy as np
import matplotlib.pyplot as plt
model_name = "wan_1.3B"
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
#model_name = "wan_14B"
#model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
def l1_callback(pipe, step, timestep, callback_kwargs):
global l1_log
global pre_latent
latent_input = callback_kwargs['latent_model_input'].detach().clone()
noise_pred_init = callback_kwargs['noise_pred_init'].detach().clone()
noise_pred = callback_kwargs['noise_pred'].detach().clone()
latent_last = callback_kwargs['latents'].detach().clone()
if pre_latent is None:
pre_latent = noise_pred_init
else:
l1 = (noise_pred_init-pre_latent).abs().mean()/pre_latent.abs().mean()
pre_latent = noise_pred_init
print(f"Step {step}, Timestep {timestep}, L1 = {l1.item():.6f}")
l1_log.append((step, timestep.item(), l1.item()))
return callback_kwargs
fig, ax = plt.subplots(1,3,figsize=(18,5))
fig.suptitle('L1 step[10,20]')
for step in [10,20]:
for shift in [7.0]:
for num_frames in [1]:
l1_log = []
pre_latent = None
pipe.scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
prompt = "a photo of an astronaut riding a horse on Mars."
video = pipe(prompt=prompt, negative_prompt="", height=320, width=512, num_frames=num_frames, num_inference_steps=step, guidance_scale=5.0, callback_on_step_end=l1_callback).frames[0]
export_to_video(video, "sample_wan00.mp4", fps=15)
l1_log = np.array(l1_log)
print(l1_log, l1_log.shape)
label = 'step=%d, shift=%2.1f, frame=%d' % (step, shift, num_frames)
ax[0].scatter(l1_log[:,0]/float(step),l1_log[:,2],s=20,label=label)
ax[0].set_xlabel('Step')
ax[0].set_ylabel('L1')
ax[0].legend()
ax[1].scatter(1000-l1_log[:,1],l1_log[:,2],s=20,label=label)
ax[1].set_xlabel('1000-Timestep')
ax[1].set_ylabel('L1')
ax[1].legend()
ax[2].plot(l1_log[:,0]/float(step),1000-l1_log[:,1],label=label)
ax[2].set_xlabel('Step')
ax[2].set_ylabel('1000-Timestep')
ax[2].legend()
plt.savefig('%s_000.png' % model_name)
#plt.show()
diffusersのpipelineのcallback周りを以下のように記述。
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
noise_pred_init = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred_init - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
#for k in callback_on_step_end_tensor_inputs:
# callback_kwargs[k] = locals()[k]
callback_kwargs['latent_model_input'] = latent_model_input
callback_kwargs['noise_pred_init'] = noise_pred_init
callback_kwargs['noise_pred'] = noise_pred
callback_kwargs['latents'] = latents
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
#latents = callback_outputs.pop("latents", latents)
#prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
#negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
TimestepがHunyuanVideoとWAN2.1で違う
左から二番目の図は横軸が1000-Timestepとなっているが、HunyuanVideoではshift=11のとき最後の推論stepは650付近までしかない。つまり最後の1stepでtimestepは350進む。
一方、WAN2.1でのshift=11のとき最後の推論stepが900付近にある。つまり最後の1stepでtimestepは100進む。
HunyuanVideo
WAN2.1
これはdiffusersコード内でtimestepsの定義が違っており、HunyuanVideoでは
import numpy as np
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1]
self.scheduler.set_timesteps(sigmas=sigmas, device=device)
timesteps = self.scheduler.timesteps
num_inference_steps = len(timesteps)
一方、WAN2.1では以下のようになっている。hunyuanvideoではsigmasを渡してるが、wan2.1ではsigmasを渡していない。
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
上記のコードによればunipcにおいてはdiffusersでのWANでの定義のtimestepでいいが、dpm++やEulerの場合はdiffusersでのHunyuanvideoの定義と同じ事がわかる。
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
...
def get_sampling_sigmas(sampling_steps, shift):
sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
sigma = (shift * sigma / (1 + (shift - 1) * sigma))
return sigma
以下、単純なEulerタイプのスケジュラーを使用したいのでpipelineのtimestep定義をHunyuanVideo準拠に書き換えてL1を見る。
T2V-1.3B(timestep修正)
step変化
shift値変化
num_frame変化
T2V-14B(timestep修正)
step変化
shift値変化
num_frame変化
フレーム数が大きいほど推論step最初のL1が大きい。また生成フレーム長が9や17の低フレームにおいて全体的にL1が悪化する。
HunyuanVideoとWAN2.1の比較
HunyuanVideoでは推論stepが小さいとき、推論step初期のL1が大きく、shift値を大きくしていくほうがL1が下がってよくなるように見える。
反面、WANにおいては推論step初期のL1はそれほど大きくなく、また推論step後半のL1はshift値が大きいほどL1が大きいデメリットを感じる。
HunyuanVideo
WAN2.1
Shift関数再考
通常のshift関数はt=1で勾配1/s、t=0で勾配sである。
f(t)=\frac{s\cdot t}{1+(s-1)\cdot t}
\frac{df(t)}{dt}=\frac{s}{(1+(s-1)\cdot t)^2}=
\left\{\begin{array}{ll} 1/s & (t = 1) \\
s & (t = 0)
\end{array}
\right.
Wanではt=0,1で勾配1/s、t=1/2で勾配sになる関数を考えてみる。
これを与える関数は例えば以下の関数である。
\frac{dg(t)}{dt}=\frac{s}{1+(s^2-1)\cdot (2t-1)^2}=
\left\{\begin{array}{ll} 1/s & (t = 0,1) \\
s & (t = 1/2)
\end{array}
\right.
この微分関数(ローレンツ分布)を積分したものを$g(t)$として$g(0)=0$、$g(1)=1$となるように調整する。この積分した関数形はローレンツ分布の累積分布関数となる。
この関数は
g(t)=\frac{1}{2}(1+\frac{\arctan{\sqrt{s^2-1}}(2t-1)}{\arctan{\sqrt{s^2-1}}})
なお$t=1→0$に進むから本来shift関数を左右反転、上下反転する必要があるが、今回の場合、左右と上下を一緒に反転するとグラフは元の形と一致するため、反転を考慮する必要はない。
import matplotlib.pyplot as plt
import numpy as np
s = 7
s2 = 17
a = 1.57
x = np.arange(51)
y1 = np.arange(51)/50
y2 = 1-(s * (1-y1))/(1+(s-1)*(1-y1))
y3 = 1-(s2 * (1-y1))/(1+(s2-1)*(1-y1))
y4 = np.where(x < 25, 0.004 * x, (x-25+a)*(x-25+a)/(25+a)/(25+a)*(1-0.1)+0.1)
s = 3
y5 = 1/2*(1+np.arctan(np.sqrt(s**2-1)*(2*y1-1))/np.arctan(np.sqrt(s**2-1)))
s = 5
y6 = 1/2*(1+np.arctan(np.sqrt(s**2-1)*(2*y1-1))/np.arctan(np.sqrt(s**2-1)))
s = 7
y7 = 1/2*(1+np.arctan(np.sqrt(s**2-1)*(2*y1-1))/np.arctan(np.sqrt(s**2-1)))
plt.plot(x,y1,label='linear-50step')
plt.plot(x,y2,label='HunyuanVideo shift(s=7.0)')
plt.plot(x,y3,label='HunyuanVideo shift(s=17.0)')
plt.plot(x,y4,label='MovieGen linear-quadratic(250step-smooth)')
plt.plot(x,y5,label='New shift(s=3.0)')
plt.plot(x,y6,label='New shift(s=5.0)')
plt.plot(x,y7,label='New shift(s=7.0)')
plt.hlines(0.11, 0, 50, colors='red', linestyle='dashed')
plt.legend()
plt.show()
ここでdiffuserのeulerのschedulerのshift関数を
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
以下のように書き換える。+1e-10
を省略するとself.shift=1.0
のとき分母をゼロで割るエラーが出る。
s = self.shift + 1e-10
sigmas = 1/2*(1+np.arctan(np.sqrt(s**2-1)*(2*sigmas-1))/np.arctan(np.sqrt(s**2-1)))
T2V-14B(new shift)
結論から言うと前半の推論ではtimestep勾配を緩やかにしてL1を小さくできるが、後半の推論のtimestep勾配を緩やかにしてもL1推移は小さくはならなかった。まあ、schedulerを調整すればシグモイド状のshift関数も作れることを確認した。
step変化
shift値変化
num_frame変化
なし
まとめ
WAN2.1におけるL1を調べたが、HunyuanvideoのL1と勝手が違い、大きなshift値の入力が推論stepの効率化に寄与しない可能性はある。
推論step後半のL1が大きいのが気にかかり、推論step後半でtimestep勾配を小さくするシグモイド関数的なshift関数を新たに作ってみたが、後半のL1の大きさは特に改善しなかった。
テストのほとんどはテスト時間の都合で動画1フレームであり、wan2.1 T2V-14Bの9frameや17frameの時の低フレームでの出力動画も初期フレームにおいてすこぶる悪かったため、1frame生成結果もどれだけ信用できるのか不明です。このため生成動画の見た目の評価はほとんど信用できないと感じたためL1推移以外の結果をここでは述べません。