0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Timestep Samplingについて

Last updated at Posted at 2025-08-29

musubi-tunerなどで使われるTimestepサンプリング分布を調べた。

一般にはDiTで推論(Inference)やデノイズの事をサンプリングと呼ぶが、この事をこの記事では指していない。この記事ではLoRAの学習時において、事前に決まった分布から学習Timestepを決定するのを指して便宜上Timestep Samplingとしている。
なお推論過程では1.0(ノイズ)→0.0(動画、画像)である。

Uniform Sampling

timestep[0.0,1.0]間の単純な一様分布である。
ただWan2.2において一点注意しないといけないのはhigh/lowノイズモデルの複合学習と個別学習によって学習強度が違う点である。

複合学習においてはtimestep_boundary=0.875なので学習stepの1/8しかhigh noiseモデルで学習されない。

image.png

個別学習においては学習stepが1:1の場合はtimestep密度は1:7となる。

image.png

個人的には、Wan2.2ではUniform Samplingでは複合学習ではhigh noiseモデルの学習割合が少なすぎ、個別学習においては学習stepが1:1だとhigh noiseモデルの学習割合が多すぎると思われる。

Sigmoid Sampling

ソースコード見るとshift=1.0のShift Samplingと等しい。

Shift Sampling

ガウス分布にシグモイドを通してshift関数をかける。
推論時のTimestep密度と等しいのでこれに合わせるか、または元のモデルの学習のshift値に近いと良いと思われる。また、高いshift値で学習させた場合、動きが悪くなることが報告され、Wan2.1ではshift=3.0程度が推奨値である。一方、Wan2.2では現在(2025/08)は推奨値不明である。
(Wan2.2ではモデルが分かれているため学習stepが仮にhighモデルに寄りすぎたとしても、学習後にhigh/lowノイズモデルで適用lora強度を異なる値で調整するのが可能なので重要視されてないのだろうか?)

import matplotlib.pyplot as plt
import torch

for i, shift in enumerate([1.0, 3.0, 7.0]):
    x = torch.randn(40000)
    t = torch.sigmoid(x)
    t = (t * shift) / (1 + (shift - 1) * t)
    t = t.to('cpu').detach().numpy()
    plt.hist(t, bins=100, density=True, alpha=0.4, label="shift_sampling s=%1.f" % (shift))

plt.legend()
plt.show()

image.png

Wan2.2におけるshift値とhigh/low モデル割合

複合学習にてtimestep_boundary=0.875におけるhigh/lowモデルの学習stepの選択割合を示したい。
Uniform Sampligの割合(12.5%)に近いのはshift=2.0である。
high/lowの学習stepが1:1の個別学習に近いのはshift=7.0である。
shift=3.0ではhighモデルの選択割合は20%くらいである。これは個別学習でhighモデルの学習stepがlowモデルの学習stepの1/4程度の場合に近い。
逆にshift=1.0においてはhigh noiseモデルを元に学習stepのたった2.6%しか学習されない。
またshift=12.0においては学習stepの7割がhigh noiseモデルに偏る。

import torch
for shift in [1.0, 2.0, 3.0, 4.0, 5.0, 7.0, 9.0, 12.0]:
    boundary = 0.875

    logits_norm = torch.randn(1000000)
    t = torch.sigmoid(logits_norm)
    t = (t * shift) / (1 + (shift - 1) * t)
    x1 = torch.sum(torch.where(t > boundary, 1.0, 0.0)).item()/1000000 * 100
    x2 = torch.sum(torch.where(t < boundary, 1.0, 0.0)).item()/1000000 * 100
    print('shift=%1.1f, high=%2.1f, low=%2.1f' % (shift,x1,x2))
-----------------------------------
shift=1.0, high=2.6, low=97.4
shift=2.0, high=10.5, low=89.5
shift=3.0, high=19.8, low=80.2
shift=4.0, high=28.8, low=71.2
shift=5.0, high=36.8, low=63.2
shift=7.0, high=50.0, low=50.0
shift=9.0, high=59.9, low=40.1
shift=12.0, high=70.5, low=29.5  

Logit-norm Sampling:

一般に二個のパラメータm,s(またはμ,σ)で示される。
正規分布の確率密度関数はnorm.pdf(x, loc=0, scale=1)で示される。
image.png

from scipy.special import logit
from scipy.stats import norm

for m, s in [(0.0,1.0), (-0.5, 0.6), (0.5, 0.6), (3.0,0.5), (3.0,1.5)]:
    t = np.linspace(10e-10, 1-10e-10, 1000)
    y = norm.pdf(logit(t), loc=m, scale=s)/(t*(1-t))
    plt.plot(t, y, label="logit-norm (m,s)=(%1.1f,%1.1f)" % (m,s))

plt.legend()
plt.show()

image.png

StableDiffusion3:
image.png

Log-SNR Sampling:

image.png

Logit-normサンプリングと同一と思っていたのだが僅かに異なる。
musubi-tunerの実装をさらえば以下の通りである。

eps = 1e-7

for mean, std in [(0.0,2.0), (-1.0, 1.2), (1.0, 1.2), (-6.0,1.0), (-6.0,3.0)]:

    t_uniform = torch.rand(40000)
    t_uniform = torch.clamp(t_uniform, eps, 1.0 - eps)
    term = 2.0 * t_uniform - 1.0
    logsnr = mean + std * np.sqrt(2.0) * torch.erfinv(term)
    t = torch.sigmoid(-logsnr / 2)
    plt.hist(t, bins=100, density=True, alpha=0.4, label="log_snr (mean,std)=(%1.1f,%1.1f)" % (mean,std))

plt.legend()
plt.show()

image.png

Logit-normとlog_snrを比較するとmの符号が反転しているのと、係数がそれぞれ2倍になっている違いがある。

image.png

追記:一様分布のerf_invは正規分布と同じ

[0,1]の一様分布を[-1,1]の一様分布に変換しこれにerf_inv関数をかけたとき、これは単純に正規分布randnから開始するのに等しい。厳密にはepsの分正規分布乱数の値に上限をつけることができるくらい。

eps = 1e-7
t_uniform = torch.rand(40000)
t_uniform = torch.clamp(t_uniform, eps, 1.0 - eps)
term = 2.0 * t_uniform - 1.0
x1 = np.sqrt(2.0) * torch.erfinv(term)

x2 = torch.randn(40000)

plt.hist(x1, bins=100, density=True, alpha=0.4, label="x1(erf_inv(uniform))_sampling")
plt.hist(x2, bins=100, density=True, alpha=0.4, label="x2(randn)_sampling")

plt.legend()
plt.show()

image.png

sigmoid関数 vs erf関数

sigmoid関数ではなくガウス分布の累積分布関数に由来するerf関数を使うべきでは?という思いつき。ただ、理屈的にどちらが正しいのかはよく分からない。
logit関数の逆関数がsigmoid関数であり、erf関数の逆関数がerf_inv関数である。
sigmoid関数はロジスティック分布の累積分布関数で、erf関数は正規分布(ガウス分布)の累積分布関数である。
なお$\frac{\sqrt{3}}{\pi}$がどこから出たのかについてはロジスティック分布の分散の大きさである。

x = torch.linspace(-10.0, 10, 1000)
t = torch.sigmoid(x)
t2 = (torch.erf(x*np.sqrt(1.5)/np.pi)+1)/2
plt.plot(x, t, label="sigmoid")
plt.plot(x, t2, label="erf")

plt.legend()
plt.show()

image.png

image.png

sigmoid関数の代わりにerf関数を用いれば若干の違いがみられる。ただ、logit-normの分布と等しいのはerfinv関数にsigmoid関数をかけた従来の方である。

t = torch.sigmoid(x)
t2 = (torch.erf(x*np.sqrt(1.5)/np.pi)+1)/2
eps = 1e-7

for mean, std in [(0.0,1.0), (-0.5, 0.6), (0.5, 0.6), (3.0,0.5), (3.0,1.5)]:

    t_uniform = torch.rand(40000)
    t_uniform = torch.clamp(t_uniform, eps, 1.0 - eps)
    term = 2.0 * t_uniform - 1.0
    logsnr = -2*mean + 2*std * np.sqrt(2.0) * torch.erfinv(term)
    t = torch.sigmoid(-logsnr / 2) # sigmoid
    plt.hist(t, bins=100, density=True, alpha=0.4, label="log_snr(sigmoid) (mean,std)=(%1.1f,%1.1f)" % (-mean*2,std*2))

    t_uniform = torch.rand(40000)
    t_uniform = torch.clamp(t_uniform, eps, 1.0 - eps)
    term = 2.0 * t_uniform - 1.0
    logsnr = -2*mean + 2*std * np.sqrt(2.0) * torch.erfinv(term)
    t = (torch.erf(-logsnr / 2 *np.sqrt(1.5)/np.pi)+1)/2 # erf
    plt.hist(t, bins=100, density=True, alpha=0.4, label="log_snr(erf) (mean,std)=(%1.1f,%1.1f)" % (-mean*2,std*2))

    t = np.linspace(10e-10, 1-10e-10, 1000)
    y = norm.pdf(logit(t), loc=mean, scale=std)/(t*(1-t))
    plt.plot(t, y, label="logit-norm (m,s)=(%1.1f,%1.1f)" % (mean,std))

plt.legend()
plt.show()

image.png

Shift Sampling vs Log-SNR Sampling

実はShift SamplingとLog-SNR Samplingは簡単な関係で示せ、mean = -2 * np.log(shift), std=2.0ならばこの分布は一致する。反面std=3.0とする分布はshiftサンプリングでは再現出来ない。
StableDiffusion3の論文には以下のような導出が見える。
image.png

image.png

eps = 1e-7

for shift in [1.0, 3.0, 7.0]:
    x = torch.randn(40000)
    t = torch.sigmoid(x)
    t = (t * shift) / (1 + (shift - 1) * t)
    t = t.to('cpu').detach().numpy()
    plt.hist(t, bins=100, density=True, alpha=0.4, label="shift_sampling s=%1.f" % (shift))
    
    mean = -2 * np.log(shift)
    std = 2.0
    t_uniform = torch.rand(40000)
    t_uniform = torch.clamp(t_uniform, eps, 1.0 - eps)
    term = 2.0 * t_uniform - 1.0
    logsnr = mean + std * np.sqrt(2.0) * torch.erfinv(term)
    t = torch.sigmoid(-logsnr / 2)
    plt.hist(t, bins=100, density=True, alpha=0.4, label="log_snr (mean,std)=(%1.1f,%1.1f)" % (mean,std))

plt.legend()
plt.show()

image.png

Wan2.2におけるLog-SNRとhigh/low モデル割合

shift samplingではLog-SNRのstd=2.0の分布しかできないが、Log-SNRではstd=3.0の分布も作成できる。std=3.0ではstd=2.0よりもhigh/lowの割合が均等な方向に移動する。
また、mean=-6.0の分布はshift=20に相当する大きさである。

eps = 1e-7

for std in [2.0, 3.0]:
    for shift in [1.0, 2.0, 3.0, 4.0, 5.0, 7.0, 9.0, 12.0, 20.0]:
        boundary = 0.875
        mean = -2.0 * np.log(shift)
    
        t_uniform = torch.rand(1000000)
        t_uniform = torch.clamp(t_uniform, eps, 1.0 - eps)
        term = 2.0 * t_uniform - 1.0
        logsnr = mean + std * np.sqrt(2.0) * torch.erfinv(term)
        t = torch.sigmoid(-logsnr / 2)
        x1 = torch.sum(torch.where(t > boundary, 1.0, 0.0)).item()/1000000 * 100
        x2 = torch.sum(torch.where(t < boundary, 1.0, 0.0)).item()/1000000 * 100
        print('shift=%1.1f, mean=%1.2f, std=%1.1f  high=%2.1f, low=%2.1f' % (shift,mean,std,x1,x2))
-----------------------------
shift=1.0, mean=-0.00, std=2.0  high=2.6, low=97.4
shift=2.0, mean=-1.39, std=2.0  high=10.5, low=89.5
shift=3.0, mean=-2.20, std=2.0  high=19.8, low=80.2
shift=4.0, mean=-2.77, std=2.0  high=28.9, low=71.1
shift=5.0, mean=-3.22, std=2.0  high=36.9, low=63.1
shift=7.0, mean=-3.89, std=2.0  high=50.0, low=50.0
shift=9.0, mean=-4.39, std=2.0  high=59.9, low=40.1
shift=12.0, mean=-4.97, std=2.0  high=70.5, low=29.5
shift=20.0, mean=-5.99, std=2.0  high=85.3, low=14.7
shift=1.0, mean=-0.00, std=3.0  high=9.8, low=90.2
shift=2.0, mean=-1.39, std=3.0  high=20.3, low=79.7
shift=3.0, mean=-2.20, std=3.0  high=28.6, low=71.4
shift=4.0, mean=-2.77, std=3.0  high=35.5, low=64.5
shift=5.0, mean=-3.22, std=3.0  high=41.1, low=58.9
shift=7.0, mean=-3.89, std=3.0  high=50.1, low=49.9
shift=9.0, mean=-4.39, std=3.0  high=56.5, low=43.5
shift=12.0, mean=-4.97, std=3.0  high=64.0, low=36.0
shift=20.0, mean=-5.99, std=3.0  high=75.8, low=24.2

もしかしたらargs.sigmoid_scaleを調整すれば、shift関数からでもstd=3.0の分布を作れるのかもしれないが…。
追記: sigmoid_scale=1.5ならstd=3.0の分布となる。

Mode Sampling

SD3の論文によれば以下のようにある。
uが[0.0,1.0]の一様分布であればモード関数の密度関数はfの逆関数を微分したものである。

image.png

import matplotlib.pyplot as plt
import numpy as np

u = np.linspace(0.0, 1.0, 1000)
for s in [-0.54, 0, 0.81, 1.29]:
    f_mode = 1 - u - s * (np.cos((np.pi/2*u))**2 - 1 + u)

    density = np.zeros(1000)
    for i in range(0,999):
        density[i] = np.abs((u[i+1] - u[i]) / (f_mode[i+1] - f_mode[i]))
    density[999] = density[998]

    plt.plot(f_mode,density,label='mode_density, s=%1.2f' % s)
plt.legend()
plt.show()

image.png
この結果はSD3論文の図と合う。

image.png
追記: 上記は関数の微小変化の微分を計算しているが、一応こうもかける。

    u = np.linspace(0.0, 1.0, 1000)
    f_mode = 1 - u - s * (np.cos((np.pi/2*u))**2 - 1 + u)
    df_du  = -(1 + s) + (s * np.pi / 2.0) * np.sin(np.pi * u)
    density = 1.0 / np.abs(df_du)
    plt.plot(f_mode,density,label='mode_density, s=%1.2f' % s)

また、Waverの論文でもこのmode関数に触れられている。

image.png

確かにshift関数を適用すればそんな歪んだ感じになる。

u = np.linspace(0.0, 1.0, 1000)
for s in [1.29]:
    f_mode = 1 - u - s * (np.cos((np.pi/2*u))**2 - 1 + u)

    density = np.zeros(1000)
    for i in range(0,999):
        density[i] = np.abs((u[i+1] - u[i]) / (f_mode[i+1] - f_mode[i]))
    density[999] = density[998]
    plt.plot(f_mode,density,label='mode_density, s=%1.2f' % s)

    t = f_mode
    shift = 3
    f_mode = (t * shift) / (1 + (shift - 1) * t)

    density = np.zeros(1000)
    for i in range(0,999):
        density[i] = np.abs((u[i+1] - u[i]) / (f_mode[i+1] - f_mode[i]))
    density[999] = density[998]
    plt.plot(f_mode,density,label='mode_density, s=%1.2f shift=3' % s)

plt.legend()
plt.show()

image.png

flux_shift, qwen_shift

ここからはmusubi-tunerの実装しか見ていない。
この二種類のshiftは以下のような関数があり、中間のToken数の場合を一次fitしている。
1tokenを16x16と考える。(VAEで(1/8,1/8)になり、patchfyで4個のlatentで1tokenとなる)

flux_shiftなら、Image Token数が256tokenの場合、mu=0.5で、4096tokenの場合、mu=1.15である。解像度に直すと256x256の場合mu=0.5でshift=1.65、1024x1024の場合mu=1.15でshift=3.16である。
qwen_shiftなら、Image Token数が256tokenの場合、mu=0.5で、8192tokenの場合、mu=0.9である。解像度に直すと256x256の場合mu=0.5でshift=1.65、1024x2048の場合mu=0.9でshift=2.46である。

def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b
...
h, w = latents.shape[-2:]
# we are pre-packed so must adjust for packed size
if args.timestep_sampling == "flux_shift":
    mu = train_utils.get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
elif args.timestep_sampling == "qwen_shift":
    mu = train_utils.get_lin_function(x1=256, y1=0.5, x2=8192, y2=0.9)((h // 2) * (w // 2))
# def time_shift(mu: float, sigma: float, t: torch.Tensor):
#     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) # sigma=1.0
shift = math.exp(mu)

以下のように書いて解像度256~1024の画像に対するshift値を求めるとfluxでは2~3のshift値になる。
musubi-tunerでは静止画学習においてbucket学習するので異なるtoken数で学習しない(と思う)ため、解像度に依存する動的shiftを求めるのにうまみがないように感じる。
動画学習なら意味があるのかと思ったがtoken数の計算にフレーム次元の依存性はない。

def get_lin_function(x1 = 256, y1 = 0.5, x2 = 4096, y2 = 1.15, x = 256):
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return m * x + b

h = np.linspace(256, 1024, 1000)
token_num = (h//16) * (h//16)

flux_mu = get_lin_function(x1=256, y1=0.5, x2=4096, y2=1.15, x=token_num)
qwen_mu = get_lin_function(x1=256, y1=0.5, x2=8192, y2=0.90, x=token_num)
flux_shift = np.exp(flux_mu)
qwen_shift = np.exp(qwen_mu)
plt.plot(h,flux_shift,label='flux_shift')
plt.plot(h,qwen_shift,label='qwen_shift')

plt.legend()
plt.show()

image.png

また、Waverの論文でもshift値が解像度によって変わっているのが見える。
image.png

qinglong

以下のような三つ山のサンプリングを取るらしい。
image.png

これは、スタイルの学習、モデルの安定性、ディテールの再現性のバランスを取るために、3つの異なるサンプラーを組み合わせたハイブリッドサンプリング手法です。Style-Friendly SNR Samplerにインスパイアされた実験的な機能です。PR #407 で sdbds (Qing Long) 氏により提案されました。
各学習ステップにおいて、バッチ内の各サンプルに対して、あらかじめ定義された比率に基づき以下のいずれかのサンプラーが選択されます。
flux_shift または qwen_shift (80%): 高解像度モデル向けの標準的なサンプラー。全体的な安定性を重視します。
logsnr (7.5%): Style-Friendlyサンプラー。スタイルの学習を重視します。
logsnr2 (12.5%): 低ノイズ領域(高いlog-SNR値)に焦点を当てたサンプラー。細部のディテール学習を向上させることを目的とします。

何故このようにするかはよく分からないのだが、一点思うことは推論時にeulerサンプルだとtimestepの0-200あたりを通る点が存在する。(dpm++だとこの辺は通らない)
従来shiftだと0-200あたりの学習密度は非常に低い。
また、900-1000にある山は左右でバランスをとる目的で存在するのだろうか。

t = np.linspace(0.0, 1000.0, 1000)
t2 = np.linspace(0.0, 1.0, 10)

plt.scatter(t[:-1]/1000, 1000-t[:-1],label='linear, shift=1.0, step=1000')

scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
scheduler.set_timesteps(10)
timesteps = scheduler.timesteps.tolist()
print("euler,shift=7.0,step=10,timesteps:", timesteps)
print()
plt.scatter(t2, np.array(timesteps),label='euler, shift=7.0, step=10')

scheduler = FlowDPMSolverMultistepScheduler(shift=7.0)
scheduler.set_timesteps(10)
timesteps = scheduler.timesteps.tolist()
print("dpm++,shift=7.0,step=10,timesteps:", timesteps)
print()
plt.scatter(t2, np.array(timesteps),label='dpm++, shift=7.0, step=10')

plt.legend()
plt.show()
----------------------------------
euler,shift=7.0,step=10,timesteps: [1000.0, 982.5910034179688, 961.120849609375, 933.98095703125, 898.5838012695312, 850.4849243164062, 781.3526000976562, 673.5293579101562, 481.9136657714844, 46.75572204589844]

dpm++,shift=7.0,step=10,timesteps: [999, 984, 965, 942, 913, 874, 823, 749, 636, 437]

image.png

まとめ

musubi-tunerにおけるTimestepサンプリングの違いを調べた。
shift以外の違いは良く分かってなかったが、Log-SNRが係数の大きさが違うlogit-normと同じで裾の広いshift関数を記述できることが分かった。

参考:

musubi-tuner:

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?