
SDXL Quantization 3

Continuing from the previous article, I convert the SDXL model weights from float16 to an Int type and then back to float16.
Inference itself still runs in float16, so inference speed does not change.
The goal of this article is to write out and save the Int-converted data.

Code

As a prerequisite, the state_dict (a dict of model weights) of the SDXL unet component is assumed to have already been exported as safetensors to unet/diffusion_pytorch_model.fp16.safetensors.
safetensors can only hold plain Tensor data; it cannot store types such as list. (Everything must be converted to a Tensor before saving.)
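A minimal standalone sketch of that constraint (the file name here is just an example, not from the article's code): any auxiliary data such as shapes or block sizes has to be packed into tensors before calling save_file.

from safetensors.torch import save_file, load_file
import torch

# save_file only accepts a flat dict of tensors, so list-like metadata
# (e.g. an original shape) must be converted to a Tensor first.
tensors = {
    "weight": torch.zeros(4096, 320, dtype=torch.float16),
    "weight.orig_shape": torch.tensor([4096, 320], dtype=torch.int32),  # list -> Tensor
}
save_file(tensors, "example.safetensors")  # hypothetical file name
print(load_file("example.safetensors")["weight.orig_shape"].tolist())  # [4096, 320]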

Conversion precision:
Conv2D layers were quantized to NF5(32) (13% of the total weights).
Linear layers were quantized to NF5(32), NF4.58(24), NF4(16), NF3.58(12), and NF2.81(7), respectively (86% of the total weights).
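The number in parentheses is the number of quantization levels; the NF label is just log2 of that count, i.e. the effective bits per weight before compression. A quick check (my own calculation, not from the original code):

import math

# NF label = log2(number of levels)
for levels in [32, 24, 16, 12, 7]:
    print("NF%.2f(%d)" % (math.log2(levels), levels))
# prints NF5.00(32), NF4.58(24), NF4.00(16), NF3.58(12), NF2.81(7)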

With the linear layers at NF5(32) the model size drops to about 1/3, at NF3.58(12) to about 1/4, and at NF2.81(7) to about 1/5. Note that this only reduces the number of bits used per parameter; it does not reduce the parameter count itself the way SSD-1B (Segmind Stable Diffusion Model) does.
Nor is it the kind of approach that reduces the number of inference steps, like LCM-XL or SDXL Turbo.
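A rough sanity check on those ratios (a back-of-the-envelope estimate of my own, assuming bz2 compresses the level indices down to roughly log2(levels) bits each, using the block sizes from the code below):

import math

def approx_compression(levels, block1=64, block2=64 * 256):
    # bits per weight: ~log2(levels) for the bz2-compressed indices,
    # + one uint8 scale per block1 weights + one float32 absmax per block2 weights
    bits = math.log2(levels) + 8.0 / block1 + 32.0 / block2
    return 16.0 / bits  # relative to float16

for levels in [32, 12, 7]:
    print(levels, round(approx_compression(levels), 1))
# about 3.1x, 4.3x, 5.5x smaller -> roughly 1/3, 1/4, 1/5 for the linear layers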

As the file timestamps of the models show, the whole cycle of loading the model, converting to Int, saving, loading the saved Int data, converting back to float16, swapping in the unet weights, and generating images takes about an hour. Since the loading time only grows, this is frankly not practical at all.
Moreover, unet=None cannot be passed when loading the pipeline, so in the end it still cannot run without the original unet weights, which rather defeats the purpose.

[Image: image.png]

from diffusers import DiffusionPipeline
from safetensors.torch import load_file, save_file
import torch
import numpy as np
import bz2

def make_q(split_num=16):
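    # Build the NF (NormalFloat-style) quantization levels: split [-1, 1] into
    # split_num bins with the inverse error function, take bin midpoints,
    # normalize to [-1, 1], and append 1.3 as a sentinel so the last threshold
    # exceeds every normalized value.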
    x = np.linspace(-1, 1, split_num+1)
    Qx = torch.erfinv(torch.tensor(x, dtype=torch.float32) * 0.96)
    Qx = Qx.detach().numpy().copy()

    q = np.zeros(split_num)
    q = (Qx[:-1]+Qx[1:])/2
    q = q/np.max(q)

    q = list(q)
    q.append(1.3)
    return q

def quantization(M, split_num):
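    # Double block-wise absmax quantization: a float32 absmax per block of
    # block_size2 values (Mmax2) plus a uint8 ratio per block of block_size1
    # values (c1). Each scaled value is mapped to the nearest NF level, the
    # level index is stored as uint8 shifted by shift_num, then bz2-compressed.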
    shift_num = 50
    q = make_q(split_num)

    block_size1 = 64
    block_size2 = 64*256
    orig_shape = M.shape
    orig_length = len(M.flatten())
    if orig_length%block_size2==0:
        extend_M = M.flatten()
    else:
        extend_M = torch.cat((M.flatten(), torch.zeros(block_size2 - orig_length%block_size2)))
    extend_length = len(extend_M)

    Mmax1 = []
    for i in range(extend_length//block_size1):
        Mmax1.append(torch.max(torch.abs(extend_M[block_size1*i:block_size1*(i+1)])))
    Mmax1 = torch.stack(Mmax1, dim=0)
    Mmax1 = Mmax1.repeat_interleave(block_size1)

    Mmax2 = []
    for i in range(extend_length//block_size2):
        Mmax2.append(torch.max(torch.abs(extend_M[block_size2*i:block_size2*(i+1)])))
    Mmax2 = torch.stack(Mmax2, dim=0)
    Mmax2 = Mmax2.repeat_interleave(block_size2)
    Mmax2_save = Mmax2[::block_size2].clone()
    
    c1 = torch.round(255.0 * Mmax1 / Mmax2).to(torch.uint8)
    c1_save = c1[::block_size1].clone()
    c1 = c1.to(torch.float32).clamp(min=1e-12)
    
    M2 = extend_M / (Mmax2 * c1 / 255.0)
    for i in range(split_num):
        M2 = torch.where(M2 <= (q[i]+q[i+1])/2, shift_num+i, M2)
    M2 = torch.round(M2).to(torch.uint8)
    M2 = M2[:orig_length]
    M2 = bz2.compress(M2.detach().numpy().copy().tobytes().decode("ascii").encode("utf-8"))
    M2_save = torch.from_numpy(np.array([M2]).view(np.uint8)).clone()

    quant_param = torch.tensor([orig_length, split_num, shift_num, block_size1, block_size2], dtype=torch.int32)
    #print('before_quant:', orig_length*2, ', after_quant:', len(c1_save) + len(Mmax2_save)*4 + len(M2_save), end=', ')
    #print('quant_rate:', (len(c1_save) + len(Mmax2_save)*4 + len(M2_save))/(orig_length*2))
    return  orig_shape, quant_param, c1_save, Mmax2_save, M2_save

def save_quantization_model(model_path, output_path, linear_split_num, conv_split_num):
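    # Quantize every linear (2D, non-embedding) and conv (4D) weight of the
    # safetensors state_dict; the quantized pieces are stored under
    # '<name>.quantization.*' keys and the weight itself becomes a dummy tensor.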
    state_dict = load_file(model_path, device="cpu")
    new_state_dict = {}
    for name, param in state_dict.items():
        if len(param.shape)==2 and not('emb' in name):
            print(name, param.shape, type(param))
            M = param.clone().to(torch.float32)
            orig_shape, quant_param, c1, c2, M2 = quantization(M, linear_split_num)
            
            new_state_dict[name] = torch.tensor([0]) # dummy_data
            new_state_dict[name+'.quantization.orig_shape'] = torch.tensor(orig_shape)
            new_state_dict[name+'.quantization.param'] = quant_param.clone()
            new_state_dict[name+'.quantization.c1'] = c1.clone()
            new_state_dict[name+'.quantization.c2'] = c2.clone()
            new_state_dict[name+'.quantization.M'] = M2.clone()

        elif len(param.shape)==4:
            print(name, param.shape, type(param))
            M = param.clone().to(torch.float32)
            orig_shape, quant_param, c1, c2, M2 = quantization(M, conv_split_num)

            new_state_dict[name] = torch.tensor([0]) # dummy_data
            new_state_dict[name+'.quantization.orig_shape'] = torch.tensor(orig_shape)
            new_state_dict[name+'.quantization.param'] = quant_param.clone()
            new_state_dict[name+'.quantization.c1'] = c1.clone()
            new_state_dict[name+'.quantization.c2'] = c2.clone()
            new_state_dict[name+'.quantization.M'] = M2.clone()

        else:
            print(name, param.shape, type(param))
            new_state_dict[name] = param.clone()

    save_file(new_state_dict, output_path)


def load_quantization_model(path):
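    # Reverse of save_quantization_model: decompress the level indices, map them
    # back to NF levels, undo both block-wise scalings, and restore each weight
    # as float16 in its original shape.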
    state_dict = load_file(path, device="cpu")
    for name, param in state_dict.items():
        if not('quantization' in name) and len(param)==1:
            orig_shape = torch.Size(list(state_dict[name+'.quantization.orig_shape']))
            print(name, orig_shape, type(orig_shape))
            orig_length, split_num, shift_num, block_size1, block_size2 = tuple(state_dict[name+'.quantization.param'])
            c1 = state_dict[name+'.quantization.c1']
            c2 = state_dict[name+'.quantization.c2']
            M = state_dict[name+'.quantization.M']

            q = make_q(split_num)

            c1 = c1.repeat_interleave(block_size1).to(torch.float32).clamp(min=1e-12)[:orig_length]
            c2 = c2.repeat_interleave(block_size2)[:orig_length]
            M = M.detach().numpy().copy().tobytes()
            M = bz2.decompress(M)
            M = torch.from_numpy(np.array([M]).view(np.uint8)-int(shift_num)).clone()
            M2 = M.to(torch.float16)
            for i in range(split_num):
                M2 = torch.where(i==M, q[i], M2)
            M2 = M2 * (c2 * c1 / 255.0)
            M2 = M2[:orig_length].reshape(orig_shape).to(torch.float16)

            state_dict[name] = M2.clone()

    state_dict = {name: param for name, param in state_dict.items() if not('.quantization' in name)}
    return state_dict

model_id = './stable-diffusion-xl-base-1.0/'
model_path = './stable-diffusion-xl-base-1.0/unet/diffusion_pytorch_model.fp16.safetensors'
output_img_path = './quant3/'

prompts = ['An ice sculpture is made with the text "Happy Holidays". Christmas decorations in the background. Dslr photo.',
           'An origami of a monkey dressed as a monk riding a bike on a mountain.',
           'A storefront with "Text to Image" written on it.',
           'Greek statue of a man tripping over a cat.',
           'A portrait photo of a kangaroo wearing an orange hoodie and blue sunglasses standing on the grass in front of the Sydney Opera House holding a sign on the chest that says Welcome Friends!',
           'A map of the United States made out of sushi. It is on a table next to a glass of red wine.',
           'an armchair in the shape of an avocado',
           'A teddybear on a skateboard in Times Square.']
seed = 42
generator = torch.Generator(device="cuda")

pipe = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True, torch_dtype=torch.float16, variant="fp16").to("cuda")
pipe.enable_model_cpu_offload()

for split_num in [32, 24, 16, 12, 7]:
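    # Quantize and save the unet with this many linear-layer levels, reload the
    # dequantized fp16 weights into the pipeline, then generate all test prompts.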
    output_path = './stable-diffusion-xl-base-1.0/unet/quant_model_%d.safetensors' % (split_num)

    save_quantization_model(model_path, output_path, linear_split_num=split_num, conv_split_num=32)

    state_dict = load_quantization_model(output_path)
    pipe.unet.load_state_dict(state_dict)
    del state_dict

    for i, prompt in enumerate(prompts):
        generator = generator.manual_seed(seed)
        image = pipe(prompt=prompt, generator=generator).images[0]
        image.save(output_img_path + "img_quant_NF4(%d)_%02d.png" % (split_num, i))

Results:

[Images: concat08.png, concat09.png]

Dynamic Quantization in torch

After reading up on torch quantization, what I have been doing is close to what is called Dynamic Quantization.
That is, for $y=W*x$, only the model weight $W$ is converted to int, while the input $x$ and the output $y$ stay in float16.

Let's try this Dynamic Quantization on Stable Diffusion instead of SDXL.
The unet weights are indeed converted from torch.float16 to torch.qint8, but inference is no faster than fp16. On top of that, it takes a while before inference even starts, presumably because the int weights are being converted back to float.

pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.enable_model_cpu_offload()

unet_fp16 = pipe.unet
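# Dynamically quantize only the nn.Linear weights of the unet to int8;
# activations stay floating point and the weights are dequantized at run time.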
unet_int8 = torch.quantization.quantize_dynamic(unet_fp16, {torch.nn.Linear}, dtype=torch.qint8)
pipe.unet = unet_int8

print('fp16_weight=', unet_fp16.time_embedding.linear_1.weight)
print('int8_weight=', unet_int8.time_embedding.linear_1.weight())
print('pipe_weight=', pipe.unet.time_embedding.linear_1.weight())
-------------------------------------------
fp16_weight= Parameter containing:
tensor([[-0.0025,  0.0008,  0.0041,  ..., -0.0163,  0.0121, -0.0105],
        [ 0.0007, -0.0001,  0.0060,  ..., -0.0087,  0.0062, -0.0029],
        [ 0.0025, -0.0033,  0.0018,  ..., -0.0150, -0.0041,  0.0027],
        ...,
        [-0.0047, -0.0016,  0.0041,  ..., -0.0106,  0.0102,  0.0080],
        [-0.0020, -0.0027, -0.0004,  ...,  0.0086,  0.0011,  0.0030],
        [-0.0026,  0.0061,  0.0078,  ...,  0.0080,  0.0078,  0.0159]],
       dtype=torch.float16, requires_grad=True)
int8_weight= tensor([[-0.0019,  0.0000,  0.0037,  ..., -0.0168,  0.0112, -0.0112],
        [ 0.0000,  0.0000,  0.0056,  ..., -0.0093,  0.0056, -0.0037],
        [ 0.0019, -0.0037,  0.0019,  ..., -0.0149, -0.0037,  0.0019],
        ...,
        [-0.0056, -0.0019,  0.0037,  ..., -0.0112,  0.0093,  0.0075],
        [-0.0019, -0.0019,  0.0000,  ...,  0.0093,  0.0019,  0.0037],
        [-0.0019,  0.0056,  0.0075,  ...,  0.0075,  0.0075,  0.0168]],
       size=(1280, 320), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.0018650428391993046,
       zero_point=0)
pipe_weight= tensor([[-0.0019,  0.0000,  0.0037,  ..., -0.0168,  0.0112, -0.0112],
        [ 0.0000,  0.0000,  0.0056,  ..., -0.0093,  0.0056, -0.0037],
        [ 0.0019, -0.0037,  0.0019,  ..., -0.0149, -0.0037,  0.0019],
        ...,
        [-0.0056, -0.0019,  0.0037,  ..., -0.0112,  0.0093,  0.0075],
        [-0.0019, -0.0019,  0.0000,  ...,  0.0093,  0.0019,  0.0037],
        [-0.0019,  0.0056,  0.0075,  ...,  0.0075,  0.0075,  0.0168]],
       size=(1280, 320), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.0018650428391993046,
       zero_point=0)

On the other hand, converting not only $W$ but also $x$ and $y$ to int in $y=W*x$ is called Static Quantization.
Converting float16 => int at the entrance of the unet and int => float16 at the exit, and running the inside of the unet in int, might give a speedup. Since a diffusion model calls the unet repeatedly (once per step), the intermediate conversions back to float could perhaps be skipped, but then CFG and the sampling method (DPM++ etc.) could not be computed in between.
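For reference, eager-mode Static Quantization in PyTorch wraps a module with QuantStub/DeQuantStub, calibrates observers on representative inputs, and then converts. A minimal sketch on a toy module (not the SDXL unet; the converted model runs on CPU backends such as fbgemm):

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Toy stand-in for "float in -> int8 inside -> float out".
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> int8 at the entrance
        self.linear = nn.Linear(320, 320)
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> float at the exit
    def forward(self, x):
        return self.dequant(self.linear(self.quant(x)))

model = TinyNet().eval()
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
prepared = torch.quantization.prepare(model)   # attach observers
prepared(torch.randn(8, 320))                  # calibration pass with representative input
quantized = torch.quantization.convert(prepared)  # weights and activations become int8
print(quantized)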

Summary

I have shown code that reduces the SDXL model size.
In practical terms, the loading time and the Int-to-float16 conversion time increase, so it is not practical at all.
The model size does shrink, but the extra conversion and loading time far exceeds the download time, so there is no real benefit. It is somewhat surprising that images can still be generated with the linear layers at NF2.81(7), i.e. at roughly 1/5 of the model size, although with some prompts the degradation is obvious.
