4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

SDXLの量子化4

Posted at

はじめに

BitNetの論文を読んで1bit量子化、1.58bit量子化をSDXLで試してみる。
1bit量子化はモデル重みを1bitに量子化するという内容だが、どうして精度が維持できてるのか最初聞いたとき疑問だった。

前回の記事と同様に$W$のfloat16重みを1bit量子化したあと再度float16に戻すので、この実装では推論は高速化しない。 また、自分がやったのは単なる重み変換だけで追加学習はしていない。

なお、論文を読んだ自分の解釈で実装しているので、この解釈が合っているかどうかに関しては保証しません。

前回記事:

図で説明する量子化

雑にだが1bit量子化が前回の量子化とどう違うかを図で説明する。
1番目の図で一般的な8bit量子化は重みの最大値を取ってこれを等間隔$-127~127$に分割する。最も初歩的な量子化だが基本的に外れ値に非常に弱い。
2番目の図でBlockwiseでは本来のWより小さなBlock内の最大値を取って外れ値を取り除こうとする。これは量子化時にgroup_sizeなりblock_sizeを定義する。
3番目の図でNF4では範囲を非等間隔に分割することによって各bitを効率的に使う。この分割間隔は正規分布の累積分布関数(シグモイド的)を縦軸に等間隔に分割したときの$x$座標に等しい。これは要するに逆誤差関数(erfinv)である。

さて、今までの量子化は$W$の絶対値最大値($absmax(W)$)を重要視していた。
1bit量子化は$W$の絶対値平均値($mean(abs(W))$)を係数として用い、それ以上の値を$1$で$clip$する。また、$W-\alpha$を量子化する。

image.png

\alpha=\frac{1}{nm}\sum_{ij}W_{ij}
\beta=\frac{1}{nm}\sum_{ij}|W_{ij}-\alpha|
\tilde W = sign(W-\alpha)=2.0*(clip(round(\frac{W-\alpha}{\beta}+0.5),0.0,1.0)-0.5)\cdots(when bit1)
\tilde W = clip(round(\frac{W-\alpha}{\beta}),-1.0,1.0)= clip(round(\frac{W-\alpha}{\beta}+1.0),0.0,2.0)-1.0\cdots(when bit1.58)
y=Wx=\tilde W x × \beta

因みに、$W$が正規分布($randn$)で$mean(abs(W))=0.798$になった。
これは$\frac{2\sigma}{\sqrt{2\pi}}=0.798\sigma$である事を推定させる。
結局、この絶対値の平均の大きさは分散に比例するならL1距離じゃなくて分散を採用した方がいいのではと思った。
一方、$absmax$と違って$W$の外れ値をほとんど無視できる点、分散に比例するので$n,m$の大きさに依存しない点はいいと思う。

1bit量子化のテストコード

論文に書かれている$x,W$の量子化した結果の量子化前の元のベクトルとのcos類似度を確認したコードを示す。
$x$は768次元で$W$は(360,768)次元、出力$y$は360次元とする。

cos類似度0~2の結果より従来の$W$の量子化で量子化bitの小さいほど性能が劣化している。ただし、量子化によってどれだけ性能劣化するかは$W$の平均値に依存する。$W$の平均値がゼロならcos類似度の量子化の劣化は比較的小さい。
一方、bit1量子化の系統はclip関数を使って、絶対値平均値以上の値は$1,-1$にclipする。
bit1.58量子化は$-1,0,1$の三値に量子化する。
bit1量子化は$max(abs(W))$の代わりに$mean(abs(W))$の係数による変化が大きい。

この類推を使ってbit1量子化の手法を使って任意分割数$s$の関数を立てる。
分割数$s=2$の場合$[-0.5,0.5]$ではなく$sign$関数によって$[-1,1]$としているので分割数2の場合のみ係数が2倍違うが、それ以外では任意分割数に対応できる。

\tilde W = clip(round(\frac{W-\alpha}{\beta}+\frac{s-1.0}{2}),0.0,(s-1.0))-\frac{s-1.0}{2}

さらに、論文は$x,W$をともに量子化してるが、$W$の量子化のみでも結果は変わらなかった。また$x$の正規化は重要なようでx = LN(x)を抜いた時、量子化後のcos類似度は再現しなかった。

まとめると$W$の平均がゼロでない場合、演算の性能低下が顕著でbit1量子化は$W$から平均値$\alpha$を引いている点と係数に絶対値平均を用いていることで大きい値をclipしている点で従来の量子化と異なる。
$W-\alpha$を量子化している点に関しては、仮に$x$を正規化してれば$x_{mean}=0$かつ$W_{mean} x=(\sum_{i,j} W)x=(\sum_{i}{W})x_{mean}=0$の為、$(W-W_{mean})(x-x_{mean})=(W-W_{mean})x=Wx$となって量子化の逆変換に係数$\alpha$は無視できる(?)。

image.png

import numpy as np

x = np.random.randn(768) * 3.2 + 1.5
W = np.random.randn(360, 768) * 0.3 - 7.2

epsilon = 10e-9
Qb = 127.0

def LN(x):
    xmean = np.mean(x)
    xstd  = np.std(x)
    zscore = (x-xmean)/(xstd+epsilon)
    return zscore

def Quant(x, Qb=127.0): # {-127, -126, ... , 126, 127}
    gamma = np.max(np.abs(x)) / Qb
    tilde_x = np.clip(np.round(x / gamma), -Qb, Qb)
    return gamma, tilde_x

def Quant2(W): # {-1, 1}
    alpha = np.mean(W)
    beta = np.mean(np.abs(W-alpha))
    tilde_W = np.sign(W - alpha)
    return beta, tilde_W

def Quant3(W): # {-1, 0, 1}
    alpha = np.mean(W)
    beta = np.mean(np.abs(W-alpha))
    tilde_W = np.clip(np.round((W-alpha)/(beta+epsilon)), -1.0, 1.0)
    return beta, tilde_W

def Quant4(W, split_num=4.0): # {-1.5, -0.5, 0.5, 1.5}, {-2, -1, 0, 1, 2}
    alpha = np.mean(W)
    beta = np.mean(np.abs(W-alpha))
    tilde_W = np.clip(np.round((split_num-1.0)/2.0 + (W-alpha)/(beta+epsilon)), 0.0, (split_num-1.0))
    tilde_W -= (split_num-1.0)/2.0
    if split_num==2.0:
        beta *= 2.0
    return beta, tilde_W

x = LN(x)
y = W @ x
print(x.shape, W.shape, y.shape)

gamma, tilde_x = Quant(x)
beta0, tilde_W0 = Quant(W, Qb=127.0)
beta1, tilde_W1 = Quant(W, Qb=31.0)
beta2, tilde_W2 = Quant(W, Qb=15.0)
beta3, tilde_W3 = Quant2(W)
beta4, tilde_W4 = Quant3(W)
beta5, tilde_W5 = Quant4(W, split_num=2.0)
beta6, tilde_W6 = Quant4(W, split_num=3.0)
beta7, tilde_W7 = Quant4(W, split_num=4.0)
beta8, tilde_W8 = Quant4(W, split_num=8.0)

print(tilde_x[:10])
print(tilde_W0[:10,0])
print(tilde_W1[:10,0])
print(tilde_W2[:10,0])
print(tilde_W3[:10,0])
print(tilde_W4[:10,0])
print(tilde_W5[:10,0])
print(tilde_W6[:10,0])
print(tilde_W7[:10,0])
print(tilde_W8[:10,0])


tilde_y0 = (tilde_W0 * beta0) @ x 
tilde_y1 = (tilde_W1 * beta1) @ x 
tilde_y2 = (tilde_W2 * beta2) @ x 
tilde_y3 = tilde_W3 @ tilde_x * beta3 * gamma
tilde_y4 = tilde_W4 @ tilde_x * beta4 * gamma
tilde_y5 = tilde_W5 @ tilde_x * beta5 * gamma
tilde_y6 = tilde_W6 @ tilde_x * beta6 * gamma
tilde_y7 = tilde_W7 @ tilde_x * beta7 * gamma
tilde_y8 = tilde_W8 @ tilde_x * beta8 * gamma
tilde_y9 = (tilde_W8 * beta8) @ x

'''
print('orignal_y=', y[:10])
print('quant_y0=', tilde_y0[:10])
print('quant_y1=', tilde_y1[:10])
print('quant_y2=', tilde_y2[:10])
print('quant_y3=', tilde_y3[:10])
print('quant_y4=', tilde_y4[:10])
print('quant_y5=', tilde_y5[:10])
print('quant_y6=', tilde_y6[:10])
print('quant_y7=', tilde_y7[:10])
'''

print('cos_sim0=', np.sum(y*tilde_y0)/np.sqrt(np.sum(y*y)*np.sum(tilde_y0*tilde_y0)), ', bit8_W_only')
print('cos_sim1=', np.sum(y*tilde_y1)/np.sqrt(np.sum(y*y)*np.sum(tilde_y1*tilde_y1)), ', bit6_W_only')
print('cos_sim2=', np.sum(y*tilde_y2)/np.sqrt(np.sum(y*y)*np.sum(tilde_y2*tilde_y2)), ', bit4_W_only')
print('cos_sim3=', np.sum(y*tilde_y3)/np.sqrt(np.sum(y*y)*np.sum(tilde_y3*tilde_y3)), ', bit1')
print('cos_sim4=', np.sum(y*tilde_y4)/np.sqrt(np.sum(y*y)*np.sum(tilde_y4*tilde_y4)), ', bit1.58')
print('cos_sim5=', np.sum(y*tilde_y5)/np.sqrt(np.sum(y*y)*np.sum(tilde_y5*tilde_y5)), ', bit1')
print('cos_sim6=', np.sum(y*tilde_y6)/np.sqrt(np.sum(y*y)*np.sum(tilde_y6*tilde_y6)), ', bit1.58')
print('cos_sim7=', np.sum(y*tilde_y7)/np.sqrt(np.sum(y*y)*np.sum(tilde_y7*tilde_y7)), ', bit2')
print('cos_sim8=', np.sum(y*tilde_y8)/np.sqrt(np.sum(y*y)*np.sum(tilde_y8*tilde_y8)), ', bit3')
print('cos_sim9=', np.sum(y*tilde_y9)/np.sqrt(np.sum(y*y)*np.sum(tilde_y9*tilde_y9)), ', bit3_W_only')

SDXLの1bit量子化の実行

以下のコードで、embedprojto_kto_vto_qto_outffconv2Dto_のレイヤーをそれぞれ、1bit量子化、1.58bit量子化、2bit量子化、3bit量子化の結果を示す。
前回の結論としてConv2Dレイヤーはlinearレイヤーよりも量子化に弱いという特徴があったが、今回の1bit量子化ではto_kto_vto_qto_outffは1bit量子化がまだ可能で、projは1bit量子化が難しかった。様々なレイヤーで量子化可能bitが異なるというのは下記のMixed-Bitで述べられている。

前の議論と違って量子化の対象が$W-\alpha$ではないが、これは大部分で入れても入れなくても$\alpha$が小さいので実は変わらなかったうえprojに関しては却って悪化したので除いた。

LLMではTransformerを用いるのでどの全結合層も1bit量子化に問題ないがUnetであるSDXLにおいては少なくとも1bit量子化可能な層は限られる。また、馬は出力されると言ってもモデル全体の1割くらいを1bit量子化しても出力の劣化は無視できなく(拡大して見て貰えば)、複数のレイヤーを同時に1bit量子化すると結局破綻してしまうので、1bit量子化をノーリスクで選べるわけではない。

concat08.png

image.png

from diffusers import DiffusionPipeline
import torch
import numpy as np

model_id = './stable-diffusion-xl-base-1.0/'
output_path = './quant_1bit/'

prompt = "a photo of an astronaut riding a horse on mars"
seed = 42
generator = torch.Generator(device="cuda")
generator = generator.manual_seed(seed)

epsilon = 10e-9
split_num = 8.0
case = 6

pipe = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True, torch_dtype=torch.float16, variant="fp16").to("cuda")
pipe.enable_model_cpu_offload()

total_param = 0
quant_param = 0
for name, param in pipe.unet.named_parameters():
    total_param += param.data.numel()
    if case==0:
        condition = (len(param.size())==2 and 'emb' in name) # 0.010
    if case==1:
        condition = (len(param.size())==2 and not('emb' in name) and ('proj' in name)) # 0.328
    if case==2:
        condition = (len(param.size())==2 and ('to_k' in name)) # 0.106
    if case==3:
        condition = (len(param.size())==2 and ('to_v' in name)) # 0.106
    if case==4:
        condition = (len(param.size())==2 and ('to_q' in name)) # 0.080
    if case==5:
        condition = (len(param.size())==2 and ('to_out' in name)) # 0.080
    if case==6:
        condition = (len(param.size())==2 and ('ff' in name) and not('proj' in name)) # 0.160
    if case==7:
        condition = len(param.size())==4 # Conv2D 0.130
    if case==8:
        condition = (len(param.size())==2 and ('to_' in name)) # 0.372

    if condition:
        print(name, param.size(), param.data.dtype)
        quant_param += param.data.numel()
        orig_data_type = param.data.dtype

        M = param.data.clone().to(torch.float32)
        alpha = torch.mean(M)
        beta = torch.mean(torch.abs(M-alpha))
        if split_num==2.0:
            beta *= 2.0
        M = (torch.clip(torch.round(M / (beta+epsilon) + (split_num-1.0)/2.0), 0.0, split_num-1.0)).to(torch.uint8)
        M2 = (M.to(torch.float32)-(split_num-1.0)/2.0) * beta
        param.data = M2.to(orig_data_type)
        del M, M2

torch.cuda.empty_cache()
print('split_num=', split_num, ', case=', case)
print('quant_param/total_param=', quant_param/total_param)
generator = generator.manual_seed(seed)
image = pipe(prompt=prompt, generator=generator).images[0]
image.save(output_path + "img_1bit(%d)_case%d.png" % (int(split_num), case))

BitNet+NF

上記の拡張はbit1とbit1.58のBitNetの類推から等間隔に分割数を増やした場合を勝手に追加しているが、NF4のように非等間隔に分割数を増やしてみる。
具体的には$\beta$から$\sigma$を求めて(正規分布では$\beta=\frac{2\sigma}{\sqrt{2\pi}}$として)、$-3\sigma~3\sigma$をNFで非等間隔に分割する。
色々試したのだが、少なくともprojconvの$\alpha\not=0.0$のレイヤーではこの量子化手法は前回のBlockwise-NFを特に上回ることは無かった。同じ分割数での比較対象では劣化しかない。
to_kto_vto_qto_outffの量子化に関してはBitNet+NFでも問題なかったが、単に量子化に対する余地の存在だけかもしれない。Blockwise-NFの代わりに採用するメリットは各ブロックごとの絶対値最大値を求める計算が軽くなるが、絶対値最大値付近の正確性は落ちてしまうだろう。
これを使ったときto_kto_vto_qto_outffに対しては分割数12(bit3.58)=>分割数8(bit3)に一部の層で下げられ、SDXLのモデル全体の0.532に対して0.58bit下げられ、全体で0.31bitほど下げられうる。
concat07.png

def make_q(split_num=16):
    x = np.linspace(-1, 1, split_num+1)
    Qx = []
    for i in list(x):
        Qx.append(torch.erfinv(torch.tensor(i/1.0*0.96, dtype=torch.float32)))
    Qx = torch.stack(Qx)
    Qx = Qx.to('cpu').detach().numpy().copy()

    q = np.zeros(split_num)
    for i in range(split_num):
        q[i] = (Qx[i]+Qx[i+1])/2
    q = q/np.max(q)

    q = list(q)
    q.append(1.5)
    return q

def Int8_quant(M):
    c = 127.0 / torch.max(torch.abs(M))
    M2 = (torch.round(M * c)).to(torch.int8)
    M2 = M2.to(torch.float16) / c
    return M2

def Blockwise_NF_quant(M, split_num):
    q = make_q(split_num)

    block_size = 64
    orig_shape = M.shape
    orig_length = len(M.flatten())
    if orig_length%block_size==0:
        extend_M = M.flatten()
    else:
        extend_M = torch.cat((M.flatten(), torch.zeros(block_size - orig_length%block_size)))
    extend_length = len(extend_M)

    Mmax = []
    for i in range(extend_length//block_size):
        Mmax.append(torch.max(torch.abs(extend_M[block_size*i:block_size*(i+1)])))
    Mmax = torch.stack(Mmax, dim=0)
    Mmax = Mmax.repeat_interleave(block_size)

    M2 = extend_M / Mmax
    for i in range(split_num):
        M2 = torch.where(M2 <= (q[i]+q[i+1])/2, 50+i, M2)
    M2 = torch.round(M2)

    M2 = M2.to(torch.int8)
    M3 = M2.to(torch.float16)
    for i in range(split_num):
        M3 = torch.where(i+50==M2, q[i], M3)
    M3 = M3 * Mmax
    M3 = M3[:orig_length].reshape(orig_shape).to(torch.float16)
    return M3

def Bitlinear_NF_quant(M, split_num=12, sigma_num=3.0):
    q = make_q(split_num)
    alpha = torch.mean(M)
    beta = torch.mean(torch.abs(M-alpha))
    sigma = beta * float(np.sqrt(2.0 * 3.14159265)/2.0 * sigma_num)
     
    M = torch.clip(M/sigma, -1.0, 1.0)
    for i in range(split_num):
        M = torch.where(M <= (q[i]+q[i+1])/2, 50+i, M)
    M = M.to(torch.int8)
    M2 = M.to(torch.float16)
    for i in range(split_num):
        M2 = torch.where(i+50==M, q[i], M2)
    M2 = M2 * sigma
    return M2
...
for name, param in pipe.unet.named_parameters():
    orig_data_type = param.data.dtype
    M = param.data.clone().to(torch.float32)
    if len(param.size())==2 and 'emb' in name: # emb
        M2 = Int8_quant(M)
    elif len(param.size())==2 and not('emb' in name) and ('proj' in name): # proj
        M2 = Blockwise_NF_quant(M, 12)
    elif len(param.size())==2 and ('to_k' in name): # to_k
        M2 = Bitlinear_NF_quant(M, split_num=8, sigma_num=3.0)
    elif len(param.size())==2 and ('to_v' in name): # to_v
        M2 = Bitlinear_NF_quant(M, split_num=8, sigma_num=3.0)
    elif len(param.size())==2 and ('to_q' in name): # to_q
        M2 = Bitlinear_NF_quant(M, split_num=8, sigma_num=3.0)
    elif len(param.size())==2 and ('to_out' in name): # to_out
        M2 = Bitlinear_NF_quant(M, split_num=8, sigma_num=3.0)
    elif len(param.size())==2 and ('ff' in name) and not('proj' in name): # ff
        M2 = Bitlinear_NF_quant(M, split_num=8, sigma_num=3.0)
    elif len(param.size())==4: # conv
        M2 = Blockwise_NF_quant(M, 32)
    else:
        M2 = Int8_quant(M)
    param.data = M2.to(orig_data_type)

まとめ:

1bit量子化をSDXLにて確認したが、劣化は大きかった。1bit量子化で精度を出すためには後述する追加学習が必要なのだろうか。
勝手に分割数を増やした場合の拡張を与え、to_kto_vto_qto_outffにおいてはこれでも問題ないのを確認した。しかし、projconvに対しては従来(Blockwise-NF)より良くなかった。

その他:

xの量子化について

論文中の$x$の量子化の式は下記だがclipではなくround(四捨五入関数)だと思う。このままだと$\tilde{x}$は範囲が[-127.0,127.0]なだけで整数値にはならないし、当然小数も取りうるので量子化に相当しないと思う。
image.png

追加学習(quantization-aware training:量子化適応学習)

それとも以下のgithub実装のようにLLAMA-16bit=>モデルのレイヤー置換=>BitLLAMA Mixed 16bit=>BitLLAMAモデル追加学習=>bit1型変換とするのか。
SDXLでも単なる1bit重み変換だけでなく、量子化済み重みで画像が生成できるような追加学習(float16)が必要なのだろうか。

現実をマインクラフトに変換する例でたとえるなら全ての物質(例えば目覚まし時計)をAmazonのスカスカな段ボールに入れ(レイヤー置換、現実の目覚まし時計もブロック化した目覚まし時計も両方持つ)、私たちはこの段ボールの扱いを段ボールの中身と同じようにふるまう(追加学習)、その上で本物の目覚まし時計をブロックに変換(bit1変換)しても目覚まし時計の扱いに性能低下はない。
追加学習をやらず最初に目覚まし時計をブロックに変換した場合、現実の目覚まし時計との差が大きく現実世界にはなじまない(性能低下が酷い)。しかし、目覚まし時計がブロック状で生産される国で生まれた人間にはそういうものとして特に違和感を感じないだろう。

Straight-through estimatorについて

重み$W$に$sign$なりで量子化すると同時に勾配も量子化してしまう。detach()等を使えば勾配を量子化せずに量子化重みが使えるのだろう。これはforward時はW3_quant、backward時はW3を使うと解釈できる。

import torch

x = torch.randn(128)
y_true = torch.randn(32)

torch.manual_seed(0)
W1 = torch.randn((32,128), requires_grad=True)
y_pred1 = W1 @ x
loss = torch.nn.MSELoss()(y_pred1, y_true)
loss.backward()

torch.manual_seed(0)
W2 = torch.randn((32,128), requires_grad=True)
alpha = torch.mean(W2)
beta = torch.mean(torch.abs(W2-alpha))
W2_quant = torch.sign(W2-alpha) * beta
y_pred2 = W2_quant @ x
loss2 = torch.nn.MSELoss()(y_pred2, y_true)
loss2.backward()

torch.manual_seed(0)
W3 = torch.randn((32,128), requires_grad=True)
alpha = torch.mean(W3)
beta = torch.mean(torch.abs(W3-alpha))
W3_quant = (torch.sign(W3-alpha) * beta - W3).detach() + W3
y_pred3 = W3_quant @ x
loss3 = torch.nn.MSELoss()(y_pred3, y_true)
loss3.backward()


print('y_pred1 =', y_pred1[:10])
print('y_pred2 =', y_pred2[:10])
print('y_pred3 =', y_pred3[:10])
print()
print('W1.grad =', W1.grad)
print('W2.grad =', W2.grad)
print('W3.grad =', W3.grad)
-------------------------------------------------------
y_pred1 = tensor([15.6954, -1.6001,  8.4295,  8.3844,  8.5929, 18.6297,  0.7706, -6.3625,
        11.4120, 10.2367], grad_fn=<SliceBackward0>)
y_pred2 = tensor([19.3574, -2.4435,  1.9252,  1.7094, 14.4517,  9.3322, -4.5798,  3.9876,
         5.4392,  7.9817], grad_fn=<SliceBackward0>)
y_pred3 = tensor([19.3574, -2.4435,  1.9252,  1.7094, 14.4517,  9.3322, -4.5798,  3.9876,
         5.4392,  7.9817], grad_fn=<SliceBackward0>)

W1.grad = tensor([[ 0.2808, -2.1831, -0.5253,  ..., -0.9004, -0.5513,  0.4446],
        [-0.0155,  0.1204,  0.0290,  ...,  0.0497,  0.0304, -0.0245],
        [ 0.1574, -1.2242, -0.2945,  ..., -0.5049, -0.3091,  0.2493],
        ...,
        [-0.0554,  0.4312,  0.1037,  ...,  0.1778,  0.1089, -0.0878],
        [ 0.0607, -0.4716, -0.1135,  ..., -0.1945, -0.1191,  0.0960],
        [-0.0613,  0.4765,  0.1146,  ...,  0.1965,  0.1203, -0.0970]])
W2.grad = tensor([[-0.0514, -0.0514, -0.0514,  ...,  0.0525,  0.0525,  0.0525],
        [-0.0514, -0.0514, -0.0514,  ...,  0.0525,  0.0525,  0.0525],
        [-0.0514, -0.0514, -0.0514,  ...,  0.0525, -0.0514, -0.0514],
        ...,
        [ 0.0525,  0.0525,  0.0525,  ..., -0.0514,  0.0525,  0.0525],
        [-0.0514, -0.0514, -0.0514,  ..., -0.0514, -0.0514,  0.0525],
        [ 0.0525, -0.0514, -0.0514,  ..., -0.0514, -0.0514,  0.0525]])
W3.grad = tensor([[ 0.3416, -2.6560, -0.6390,  ..., -1.0954, -0.6707,  0.5409],
        [-0.0295,  0.2293,  0.0552,  ...,  0.0946,  0.0579, -0.0467],
        [ 0.0494, -0.3844, -0.0925,  ..., -0.1585, -0.0971,  0.0783],
        ...,
        [-0.0507,  0.3944,  0.0949,  ...,  0.1627,  0.0996, -0.0803],
        [ 0.0316, -0.2457, -0.0591,  ..., -0.1013, -0.0621,  0.0500],
        [-0.1310,  1.0187,  0.2451,  ...,  0.4201,  0.2572, -0.2075]])

Group Quantizationについて

groupの値が小さければ余計に保持するパラメータ量に対して議論してない。例えば$\beta_{type}=float32$で$group=64$なら$32bit/64=0.5bit$分グループ化した$\beta$を保持するのに費やすはずである。
$randn$の乱数group個を1000回発生させた時、平均と偏差は以下である。
$max(abs(W))$は平均を下げられて量子化が効率よくなる。一方、$mean(abs(W))$は偏差が小さくなるだけなのでgroup化の恩恵は小さいように思う。$量子化bit*(1-\frac{mean_{group}}{mean_N})+量子化bit*\frac{std}{mean}-使用bit$は$mean(abs(W))$ではほぼ誤差みたいなものである。
1bit量子化においては採用して改善するメリットが0.005bitもない。

image.png

4
0
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?