0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

HunyuanvideoのL1推移

Posted at

はじめに

動画生成モデルの論文において各推論stepにおけるlatentのL1変化の推移を見ることがしばしば行われる。

例えばMovieGenでは以下のような図が知られ、最初の推論stepの変化寄与が大きく後半の推論stepの寄与は小さい。これからlinear-quadratic t-scheduleという、推論stepの前半半分では線形、推論stepの後半半分では二次関数のtのシフト関数を採用することで推論stepの削減を目指した。前半では小さく動き、後半では大きく動く。
なお、図のlinear-quadraticは前半25stepで1000stepの線形stepの勾配と同じだが、実際には250stepの線形stepが使用されると論文にあり、25stepでtimestepで0.1進むと思われる。このため図が正確ではない。

image.png

一方、HunyuanVideoやWan2.1などではSD3やFLUX.1を真似たflow-shift関数
$t'=\frac{s\cdot t}{1+(s-1)\cdot t}$が使われる。このshift関数の微分式を求めるとt=1(推論stepの最初)では傾きは1/sになり、t=0(推論stepの最後)では傾きはs倍になる。linear-quadraticと推論step最初で慎重に動き、後半では大きく動くという点では共通している。

\frac{dt'}{dt}=\frac{s}{(1+(s-1)\cdot t)^2}=
\left\{\begin{array}{ll} 1/s & (t = 1) \\
s & (t = 0)
\end{array}
\right.

t=1→0に進むので、時系列に沿った表示では左右反転・上下反転が必要
image.png

import matplotlib.pyplot as plt
import numpy as np

s = 7
s2 = 17
a = 1.57
x = np.arange(51)
y1 = np.arange(51)/50
y2 = 1-(s * (1-y1))/(1+(s-1)*(1-y1))
y3 = 1-(s2 * (1-y1))/(1+(s2-1)*(1-y1))
y4 = np.where(x < 25, 0.004 * x, (x-25+a)*(x-25+a)/(25+a)/(25+a)*(1-0.1)+0.1)

plt.plot(x,y1,label='linear-50step')
plt.plot(x,y2,label='HunyuanVideo shift(s=7.0)')
plt.plot(x,y3,label='HunyuanVideo shift(s=17.0)')
plt.plot(x,y4,label='MovieGen linear-quadratic(250step-smooth)')
plt.hlines(0.11, 0, 50, colors='red', linestyle='dashed')
plt.legend()
plt.show()

SD3ではこのshift関数はs=3が用いられる。

image.png

APTでもこのshift関数があり、画像生成ではs=1、動画生成ではs=12が用いられる。
image.png

一方、これらのshift関数による推論stepの効率化とはまた別にTeaCacheがある。
これはstepごとのL1を監視してこれが小さければ次の推論stepをskipしてtimestep分を一気に進む。
しかし、TeaCacheの元論文にあるOpen Sora、Latte、OpenSora-Planは比較的初期の動画生成AIであり、HunyuanVideoのようなflow-shift関数は採用されていないと思われる。このためこれらのモデルでは推論stepの最適化が行われておらず、無駄な推論stepが多いので特にTeaCacheが有用である。
HunyuanVideoのshift関数とTeaCacheの相性についてはTeaCache論文上では明らかではないのだが、実際TeaCacheは盲目的によく使われる。

image.png

なお、推論stepの前半をあまり間引かず、推論stepの後半に従って多く間引いたとき、TeaCacheはshift関数とほぼ同じ動きをすることを以前の記事内で指摘した。

image.png

HunyuanVideoやWanのL1推移は以下の論文にみられる。
いわゆるU字型である。

image.png

また、Attentionの有用性を観察するために、推論step毎にAttentionの大きさをプロットする場合も存在するが、図として似ているだけで今回の対象ではない。
image.png

diffusersのcallback関数

diffusersのcallback関数を使って推論step毎にlatentの差分L1を計算したい。
このL1の値が推論step数、shift関数、動画フレーム長などに依存するのかを調べたい。
ただしtimestepは1000→0の方向(降順)に進むので図では1000-timestepで便宜上示す。

sample00.png

import torch
from diffusers import HunyuanVideoTransformer3DModel, HunyuanVideoPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video
import numpy as np
import matplotlib.pyplot as plt

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    transformer=transformer,
    torch_dtype=torch.float16,
)

pipe.vae.enable_tiling()
pipe.enable_sequential_cpu_offload()
pipe.enable_model_cpu_offload()

def l1_callback(pipe, step, timestep, callback_kwargs):
    global l1_log
    prev_latent = callback_kwargs['latent_model_input'].detach().clone()
    latent = callback_kwargs['latents'].detach().clone()
    
    l1 = (latent - prev_latent).abs().mean() / (prev_latent).abs().mean()
    print(f"Step {step}, Timestep {timestep}, L1 = {l1.item():.6f}")
    l1_log.append((step, timestep.item(), l1.item()))

    return callback_kwargs

fig, ax = plt.subplots(1,3,figsize=(18,5))
fig.suptitle('L1 step[10,20]')

for step in [10,20]:
    for shift in [7.0]:
        for num_frames in [1]: 
            l1_log = []
            pipe.scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
            prompt = "a photo of an astronaut riding a horse on Mars."
            video = pipe(prompt=prompt, height=320, width=512, num_frames=num_frames, num_inference_steps=step,guidance_scale=6.0,  callback_on_step_end=l1_callback).frames[0]
            export_to_video(video, "sample00.mp4", fps=15)
            l1_log = np.array(l1_log)
            print(l1_log, l1_log.shape)

            ax[0].scatter(l1_log[:,0]/float(step),l1_log[:,2],s=20,label='step=%d, shift=%2.1f, frame=%d' % (step, shift, num_frames))
            ax[0].set_xlabel('Step')
            ax[0].set_ylabel('L1')
            ax[0].legend()
            ax[1].scatter(1000-l1_log[:,1],l1_log[:,2],s=20,label='step=%d, shift=%2.1f, frame=%d' % (step, shift, num_frames))
            ax[1].set_xlabel('1000-Timestep')
            ax[1].set_ylabel('L1')
            ax[1].legend()
            ax[2].plot(l1_log[:,0]/float(step),1000-l1_log[:,1],label='step=%d, shift=%2.1f, frame=%d' % (step, shift, num_frames))
            ax[2].set_xlabel('Step')
            ax[2].set_ylabel('1000-Timestep')
            ax[2].legend()
plt.savefig('sample00.png')
plt.show()

また、引数をうまく取れなかったので['latent_model_input','latents']を直接callback関数に渡すようにちょっと元のpipelineコードを触った。

pipeline_hunyuan_video.py
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    #for k in callback_on_step_end_tensor_inputs:
                    #    callback_kwargs[k] = locals()[k]
                    callback_kwargs['latent_model_input'] = latent_model_input
                    callback_kwargs['latents'] = latents
                    
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    #latents = callback_outputs.pop("latents", latents)
                    #prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

step変化

sample01.png
sample02.png
sample03.png
sample04.png
sample05.png

shift値変化

sample06.png
sample07.png
sample08.png

num_frame変化

sample09.png

…まあ、結論から言うと事前に紹介した論文にあるL1グラフと合ってない。
推論stepの最初でのL1が最も大きくなるはずだがそうなってない。また、L1変化はTimestep変化の大きさに比例しこれは推論step数やshift値に依存する。shift値が高いほど後半のTimestep変化が大きく、step数が大きいほど1step当たりのTimestep変化が小さい。

再考:

そもそも今回計算したL1はTeaCacheのL1計算に準拠しているが、どのような定義なのか厳密には不明である。差分を取らなかったり絶対値で割らない、単純な絶対値平均という意味でのL1なのかもしれない。
image.png

また、後で気づいたがdiffusersのguidance_scaleではcfg_guidanceは何故か働いていない。cfgを明示的に使うには以下のようになる。純粋なCFGの場合はtrue_cfg_scaleを使う必要があり、guidance_scaleを使用する場合はuncond推論が行われないようである。

video = pipe(prompt=prompt, negative_prompt="", height=320, width=512, 
             num_frames=num_frames, num_inference_steps=step, 
             true_cfg_scale=6.0, guidance_scale=1.0, callback_on_step_end=l1_callback).frames[0]

以下はstep=50,shift=1.0,num_frames=1の時のcfgを有効にした場合のstep, timestep, latent_model_input, noise_pred(cond_pred), noise_pred(do_cfg), latentsのabs().mean()変化である。
latentsが次のstepでのlatent_model_inputと等しい。初期stepで変化の大きいのはnoise_pred(cond_pred)かnoise_pred(do_cfg)のように思える。

[[0.0000000e+00 1.0000000e+03 7.9687500e-01 1.0390625e+00 3.4843750e+00 7.8515625e-01]
 [1.0000000e+00 9.8000000e+02 7.8515625e-01 1.5000000e+00 1.4062500e+00 7.6953125e-01]
 [2.0000000e+00 9.6000000e+02 7.6953125e-01 1.4609375e+00 1.4765625e+00 7.5781250e-01]
 [3.0000000e+00 9.4000000e+02 7.5781250e-01 1.4687500e+00 1.6015625e+00 7.4609375e-01]
 [4.0000000e+00 9.2000000e+02 7.4609375e-01 1.4765625e+00 1.6640625e+00 7.3437500e-01]
 [5.0000000e+00 9.0000000e+02 7.3437500e-01 1.4843750e+00 1.6015625e+00 7.2265625e-01]
 [6.0000000e+00 8.8000000e+02 7.2265625e-01 1.4843750e+00 1.6171875e+00 7.1484375e-01]
 [7.0000000e+00 8.6000000e+02 7.1484375e-01 1.4765625e+00 1.6171875e+00 7.0703125e-01]
 [8.0000000e+00 8.4000000e+02 7.0703125e-01 1.4765625e+00 1.6328125e+00 7.0312500e-01]
 [9.0000000e+00 8.2000000e+02 7.0312500e-01 1.4765625e+00 1.6328125e+00 6.9921875e-01]
 [1.0000000e+01 8.0000000e+02 6.9921875e-01 1.4765625e+00 1.6406250e+00 6.9531250e-01]
 [1.1000000e+01 7.8000000e+02 6.9531250e-01 1.4765625e+00 1.6250000e+00 6.9531250e-01]
 [1.2000000e+01 7.6000000e+02 6.9531250e-01 1.4765625e+00 1.6250000e+00 6.9531250e-01]
 [1.3000000e+01 7.4000000e+02 6.9531250e-01 1.4765625e+00 1.6093750e+00 6.9531250e-01]
 [1.4000000e+01 7.2000000e+02 6.9531250e-01 1.4765625e+00 1.6093750e+00 6.9921875e-01]
 [1.5000000e+01 7.0000000e+02 6.9921875e-01 1.4765625e+00 1.6171875e+00 7.0312500e-01]
 [1.6000000e+01 6.8000000e+02 7.0312500e-01 1.4687500e+00 1.6015625e+00 7.1093750e-01]
 [1.7000000e+01 6.6000000e+02 7.1093750e-01 1.4687500e+00 1.6015625e+00 7.1875000e-01]
 [1.8000000e+01 6.4000000e+02 7.1875000e-01 1.4687500e+00 1.6171875e+00 7.2656250e-01]
 [1.9000000e+01 6.2000000e+02 7.2656250e-01 1.4765625e+00 1.6171875e+00 7.3828125e-01]
 [2.0000000e+01 6.0000000e+02 7.3828125e-01 1.4765625e+00 1.6093750e+00 7.5000000e-01]
 [2.1000000e+01 5.8000000e+02 7.5000000e-01 1.4765625e+00 1.6250000e+00 7.6562500e-01]
 [2.2000000e+01 5.6000000e+02 7.6562500e-01 1.4765625e+00 1.6250000e+00 7.7734375e-01]
 [2.3000000e+01 5.4000000e+02 7.7734375e-01 1.4765625e+00 1.6250000e+00 7.9296875e-01]
 [2.4000000e+01 5.2000000e+02 7.9296875e-01 1.4843750e+00 1.6328125e+00 8.0859375e-01]
 [2.5000000e+01 5.0000000e+02 8.0859375e-01 1.4843750e+00 1.6250000e+00 8.2812500e-01]
 [2.6000000e+01 4.8000000e+02 8.2812500e-01 1.4765625e+00 1.6250000e+00 8.4375000e-01]
 [2.7000000e+01 4.6000000e+02 8.4375000e-01 1.4843750e+00 1.6328125e+00 8.6328125e-01]
 [2.8000000e+01 4.4000000e+02 8.6328125e-01 1.4921875e+00 1.6562500e+00 8.8671875e-01]
 [2.9000000e+01 4.2000000e+02 8.8671875e-01 1.4843750e+00 1.6484375e+00 9.0625000e-01]
 [3.0000000e+01 4.0000000e+02 9.0625000e-01 1.4843750e+00 1.6328125e+00 9.2968750e-01]
 [3.1000000e+01 3.8000000e+02 9.2968750e-01 1.4921875e+00 1.6406250e+00 9.4921875e-01]
 [3.2000000e+01 3.6000000e+02 9.4921875e-01 1.4843750e+00 1.6484375e+00 9.7265625e-01]
 [3.3000000e+01 3.4000000e+02 9.7265625e-01 1.4921875e+00 1.6484375e+00 1.0000000e+00]
 [3.4000000e+01 3.2000000e+02 1.0000000e+00 1.4921875e+00 1.6406250e+00 1.0234375e+00]
 [3.5000000e+01 3.0000000e+02 1.0234375e+00 1.4921875e+00 1.6484375e+00 1.0468750e+00]
 [3.6000000e+01 2.8000000e+02 1.0468750e+00 1.4921875e+00 1.6328125e+00 1.0703125e+00]
 [3.7000000e+01 2.6000000e+02 1.0703125e+00 1.4921875e+00 1.6406250e+00 1.1015625e+00]
 [3.8000000e+01 2.4000000e+02 1.1015625e+00 1.4921875e+00 1.6484375e+00 1.1250000e+00]
 [3.9000000e+01 2.2000000e+02 1.1250000e+00 1.4843750e+00 1.6484375e+00 1.1562500e+00]
 [4.0000000e+01 2.0000000e+02 1.1562500e+00 1.4843750e+00 1.6484375e+00 1.1796875e+00]
 [4.1000000e+01 1.8000000e+02 1.1796875e+00 1.4843750e+00 1.6328125e+00 1.2109375e+00]
 [4.2000000e+01 1.6000000e+02 1.2109375e+00 1.4765625e+00 1.6250000e+00 1.2343750e+00]
 [4.3000000e+01 1.4000000e+02 1.2343750e+00 1.4687500e+00 1.6171875e+00 1.2656250e+00]
 [4.4000000e+01 1.2000000e+02 1.2656250e+00 1.4609375e+00 1.6015625e+00 1.2890625e+00]
 [4.5000000e+01 1.0000000e+02 1.2890625e+00 1.4531250e+00 1.5859375e+00 1.3203125e+00]
 [4.6000000e+01 8.0000000e+01 1.3203125e+00 1.4453125e+00 1.5781250e+00 1.3437500e+00]
 [4.7000000e+01 6.0000000e+01 1.3437500e+00 1.4218750e+00 1.5625000e+00 1.3750000e+00]
 [4.8000000e+01 4.0000000e+01 1.3750000e+00 1.3984375e+00 1.5390625e+00 1.4062500e+00]
 [4.9000000e+01 2.0000000e+01 1.4062500e+00 1.3671875e+00 1.5390625e+00 1.4296875e+00]]

ここでL1のcallback関数をnoise_pred(cond_pred)の差分推移を求めた場合、以下。

image.png

l1_log = []
pre_latent = None

def l1_callback(pipe, step, timestep, callback_kwargs):
    global l1_log
    global pre_latent

    latent_input    = callback_kwargs['latent_model_input'].detach().clone()
    noise_pred_init = callback_kwargs['noise_pred_init'].detach().clone()
    noise_pred      = callback_kwargs['noise_pred'].detach().clone()
    latent_last     = callback_kwargs['latents'].detach().clone()

    if pre_latent is None:
        pre_latent = noise_pred_init
    else:
        l1 = (noise_pred_init-pre_latent).abs().mean()/pre_latent.abs().mean()
        pre_latent = noise_pred_init
        print(f"Step {step}, Timestep {timestep}, L1 = {l1.item():.6f}")
        l1_log.append((step, timestep.item(), l1.item()))

    return callback_kwargs

ここでL1のcallback関数をnoise_pred(do_cfg)の差分推移を求めた場合、以下。
image.png

l1_log = []
pre_latent = None

def l1_callback(pipe, step, timestep, callback_kwargs):
    global l1_log
    global pre_latent

    latent_input    = callback_kwargs['latent_model_input'].detach().clone()
    noise_pred_init = callback_kwargs['noise_pred_init'].detach().clone()
    noise_pred      = callback_kwargs['noise_pred'].detach().clone()
    latent_last     = callback_kwargs['latents'].detach().clone()

    if pre_latent is None:
        pre_latent = noise_pred
    else:
        l1 = (noise_pred-pre_latent).abs().mean()/pre_latent.abs().mean()
        pre_latent = noise_pred
        print(f"Step {step}, Timestep {timestep}, L1 = {l1.item():.6f}")
        l1_log.append((step, timestep.item(), l1.item()))

    return callback_kwargs

なお、ここではpipeline_hunyuan_video.pyを以下のように修正している。

pipeline_hunyuan_video.py
                latent_model_input = latents.to(transformer_dtype)
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred_init = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    pooled_projections=pooled_prompt_embeds,
                    guidance=guidance,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]

                if do_true_cfg:
                    neg_noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        encoder_attention_mask=negative_prompt_attention_mask,
                        pooled_projections=negative_pooled_prompt_embeds,
                        guidance=guidance,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred_init - neg_noise_pred)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    #for k in callback_on_step_end_tensor_inputs:
                    #    callback_kwargs[k] = locals()[k]
                    callback_kwargs['latent_model_input'] = latent_model_input
                    callback_kwargs['noise_pred_init'] = noise_pred_init
                    callback_kwargs['noise_pred'] = noise_pred
                    callback_kwargs['latents'] = latents
                    
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    #latents = callback_outputs.pop("latents", latents)
                    #prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

ここで以下noise_pred_init(cond_pred)の変化推移について調べる。

L1推移

念のためコードを再掲しておく。またpipeline_hunyuan_video.pyの変更も必要。

import torch
from diffusers import HunyuanVideoTransformer3DModel, HunyuanVideoPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video
import numpy as np
import matplotlib.pyplot as plt

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    transformer=transformer,
    torch_dtype=torch.float16,
)

pipe.vae.enable_tiling()
pipe.enable_sequential_cpu_offload()
pipe.enable_model_cpu_offload()

def l1_callback(pipe, step, timestep, callback_kwargs):
    global l1_log
    global pre_latent

    latent_input    = callback_kwargs['latent_model_input'].detach().clone()
    noise_pred_init = callback_kwargs['noise_pred_init'].detach().clone()
    noise_pred      = callback_kwargs['noise_pred'].detach().clone()
    latent_last     = callback_kwargs['latents'].detach().clone()

    if pre_latent is None:
        pre_latent = noise_pred_init
    else:
        l1 = (noise_pred_init-pre_latent).abs().mean()/pre_latent.abs().mean()
        pre_latent = noise_pred_init
        print(f"Step {step}, Timestep {timestep}, L1 = {l1.item():.6f}")
        l1_log.append((step, timestep.item(), l1.item()))

    return callback_kwargs

fig, ax = plt.subplots(1,3,figsize=(18,5))
fig.suptitle('L1 step[10,20]')

for step in [10,20]:
    for shift in [7.0]:
        for num_frames in [1]: 
            l1_log = []
            pre_latent = None
            pipe.scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
            prompt = "a photo of an astronaut riding a horse on Mars."
            video = pipe(prompt=prompt, negative_prompt="", height=320, width=512, num_frames=num_frames, num_inference_steps=step, true_cfg_scale=6.0, guidance_scale=1.0,  callback_on_step_end=l1_callback).frames[0]
            export_to_video(video, "sample00.mp4", fps=15)
            l1_log = np.array(l1_log)
            print(l1_log, l1_log.shape)

            ax[0].scatter(l1_log[:,0]/float(step),l1_log[:,2],s=20,label='step=%d, shift=%2.1f, frame=%d' % (step, shift, num_frames))
            ax[0].set_xlabel('Step')
            ax[0].set_ylabel('L1')
            ax[0].legend()
            ax[1].scatter(1000-l1_log[:,1],l1_log[:,2],s=20,label='step=%d, shift=%2.1f, frame=%d' % (step, shift, num_frames))
            ax[1].set_xlabel('1000-Timestep')
            ax[1].set_ylabel('L1')
            ax[1].legend()
            ax[2].plot(l1_log[:,0]/float(step),1000-l1_log[:,1],label='step=%d, shift=%2.1f, frame=%d' % (step, shift, num_frames))
            ax[2].set_xlabel('Step')
            ax[2].set_ylabel('1000-Timestep')
            ax[2].legend()
plt.savefig('sample000.png')
#plt.show()

step変化

グラフ結果は拡大して見てください。
推論stepが大きいほど最大L1は小さい。
sample001.png
sample002.png
sample003.png
sample004.png
sample005.png

shift値変化

shift値が大きいほど最大L1は小さい。
sample006.png
sample007.png
sample008.png

num_frame変化

フレーム長が長いほど最初のL1が大きい。
sample009.png

解像度

解像度との関連性はよく分からない。
sample0000.png

shift関数考察

t=1、t=0の微分式での値を変えず、tが中間の時の勾配を増やすにはどうすればいいか?

f(t)=\frac{s\cdot t}{1+(s-1)\cdot t}
\frac{df(t)}{dt}=\frac{s}{(1+(s-1)\cdot t)^2}=
\left\{\begin{array}{ll} 1/s & (t = 1) \\
s & (t = 0)
\end{array}
\right.

先に微分式を適当に定義し、これの積分を考えてshift関数を作成したい。
例えば、以下のようなt=1,0で同じ微分値を得るためには

\frac{dg(t)}{dt}=\frac{s^{2}}{(1+(s-1)\cdot t)^{3}}=
\left\{\begin{array}{ll} 1/s & (t = 1) \\
s & (t = 0)
\end{array}
\right.

となる。

g(t)=\frac{s^2\cdot t\cdot (1+(s-1)/2\cdot t)}{(1+(s-1)\cdot t)^2}

しかし、この式は$g(0)=0$だが、$g(1)\neq 1$なのでtimestepのshift関数の定義は満たさない。

g(1)=\frac{s+1}{2}

この係数で割ってしまうとt=1、t=0の微分値がsによって下がってしまうため(まあ勾配が下がるのはいいとしてsの値に加速度的になる)この方針では駄目である。

また別の関数を考えたときローレンツ分布的な微分関数を与えたとき、シフト関数は元の関数の積分から求まるのでローレンツ分布の累積分布関数型の変形等しくなる。

\frac{dh(t)}{dt}=\frac{s}{(1+(s^2-1)\cdot t^2)}=
\left\{\begin{array}{ll} 1/s & (t = 1) \\
s & (t = 0)
\end{array}
\right.
h(t)=\frac{s \cdot \arctan(\sqrt{s^2-1} \cdot t)}{\sqrt{s^2-1}}

任意の分布の累積分布関数をx,yの縮尺を変形、平行移動、反転したものがtimestepのshift関数になるのだろうか?

tanh型  sigmoid型(累積分布関数)  分布  
$sgn(x)$ $\frac{1}{2}(1+sgn(x))$ ディラックのデルタ関数
$tanh(x/2)$ $\frac{1}{2}(1+tanh(x/2))=\frac{1}{1+e^{-x}}=\sigma(x)$ ロジスティック分布
$erf(x/\sqrt{2})$ $\frac{1}{2}(1+erf(x/\sqrt{2}))$ ガウス分布
$\frac{x}{\sqrt{1+x^2}}$ $\frac{1}{2}(1+\frac{x}{\sqrt{1+x^2}})$ $\nu=2$のt分布
$\frac{2}{\pi}gd(\frac{\pi}{2}x)$ $\frac{1}{2}(1+\frac{2}{\pi}gd(\frac{\pi}{2}x))$ 双曲線正割分布
$\frac{2}{\pi}\arctan(x)$ $\frac{1}{2}(1+\frac{2}{\pi}\arctan(x))$ ローレンツ分布
$Hardtanh(x)$ $Hardsigmoid(x)$ 一様分布
$softsign(x)$ $\frac{1}{2}(1+\frac{x}{1+abs(x)})$ $\frac{1}{2(1+abs(x))^2}$

また、推論step終端での勾配の大きさを最大にするようにとるのではなく、少し小さな値に落としたほうが良いのか。

image.png

その他所感:

テスト時間の都合上、解像度やフレームはかなり低めでテストしたため結果のL1推移は参考程度である。また、初期ノイズ、prompt、cfgスケール、SamplingMethod(Euler,dpm++,UniPC)などでも変わる可能性がある。
また与えられた動画解像度とフレーム長と推論step数から最適なflow-shift関数の係数を決定できるだろうか。TeaCacheなどの推論step最適化問題においてshift関数の値の影響が特に議論されてない理由は何故だろうか。

まとめ

HunyuanVideoのL1推移におけるstep数、shift値や動画フレーム数の依存性を調べた。
最初にlatentのL1推移を調べたのだが、このL1変化分はtimestepの変化分に比例する。これは拡散モデルで勾配を求め、その(勾配)×(timestepの変化分)でEuler法で1次近似するからだと思われる。
次に、拡散モデルのcond出力の変化分を調べたら論文にあるU字型のL1推移が見られた。このU字型のL1値はstep数やshift値、動画フレーム数に依存するようである。shift値が大きければstep数が小さくても最大L1は小さくできるように見える。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?