0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

CustomDiffusionの実装の謎

Posted at

CustomDiffusionの実装コードを読んでいて理解しにくかった点に関して記録する。

CustomDiffusionの基本

以下の図がCustomDiffusionを良く示している。
image.png
つまり、TextEncoderにV*という追加token(例えば<cat-toy>とか)を追加するTextual Inversionを行い。拡散モデルのUnetのレイヤーの内、CrossAttentionの$K,V$の全結合重み(Unet全体の5%のパラメーター)のみを学習する。

LoRAとの違いは
①CustomDiffusionは重みをそのまま学習。そのため破滅的忘却を起こすことがある。LoRAは元モデルの重みを維持し、変化差分をLowRankAdaptorで学習する。
②CustomDiffusionはUnetのCrossAttentionの$K,V$の全結合重みのみ、LoRAは$Q,K,V,out$のCrossAttentionの全ての全結合重み。場合によってはUnetのSelf-AttentionやTextEncoderのSelf-Attention、1x1のConv計算(全結合重みと同等)を変化させることもある。

class LoRAAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
...
        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

これだけ見るとCustomDiffusionはLoRAよりもずっと簡単でTextual Inversionに毛が生えた程度に思えるが、実際にコードを読んでみるといくつか謎がある。

detach

の以下

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        if crossattn:
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :]*0.
            key = detach*key + (1-detach)*key.detach()
            value = detach*value + (1-detach)*value.detach()

または、

の以下

            if crossattn:
                modifier = torch.ones_like(k)
                modifier[:, :1, :] = modifier[:, :1, :]*0.
                k = modifier*k + (1-modifier)*k.detach()
                v = modifier*v + (1-modifier)*v.detach()

最初これを読んでこの構造部分が何をどう変化させるのかよく分からなかった。
key、valueの値自体は特に変わってないからである。
これを理解するために以下の様なテストコードを書いた。

import torch

x = torch.rand(5)
W = torch.ones(5, 3, requires_grad=True)
y_pred = x @ W
y_true = torch.rand(3)

loss = torch.nn.MSELoss()(y_pred, y_true)
loss.backward()
print('W.grad =', W.grad)

#x = torch.rand(5)
W2 = torch.ones(5, 3, requires_grad=True)
y_pred = x @ W2
#y_true = torch.rand(3)

detach = torch.ones_like(y_pred)
detach[:1] = detach[:1] * 0.0
y_pred = detach*y_pred + (1-detach)*y_pred.detach()
print(detach)
print(1-detach)

loss = torch.nn.MSELoss()(y_pred, y_true)
loss.backward()
print('W2.grad =', W2.grad)
-----------------------------------------
W.grad = tensor([[0.3849, 0.3652, 0.2658],
        [1.5737, 1.4931, 1.0869],
        [1.0702, 1.0155, 0.7392],
        [0.9520, 0.9033, 0.6576],
        [0.6186, 0.5869, 0.4272]])
tensor([0., 1., 1.])
tensor([1., 0., 0.])
W2.grad = tensor([[0.0000, 0.3652, 0.2658],
        [0.0000, 1.4931, 1.0869],
        [0.0000, 1.0155, 0.7392],
        [0.0000, 0.9033, 0.6576],
        [0.0000, 0.5869, 0.4272]])

detachをy_predに掛けるとWの勾配の一部がゼロになった。ある重みの内、一部だけを更新を無効にするのに必要なのかと思った。
問題はこれがどの文脈(どの重みを固定するために)挿入されているのかは不明である。
例えば78トークン中の先頭トークンの寄与する重みの変化の無効化なのだろうか。
またはtextual inversionの文脈で追加embedding重み以外を固定させるために必要なのだろうか。
image.png

set_attn_processor

最近、見直してみると以前と少し実装が変わっている。以前はnew_forwardを定義してた気がするがset_attn_processorを使用している。

def create_custom_diffusion(unet, freeze_model):
...
    unet.set_attn_processor(CustomDiffusionAttnProcessor())
    return unet

最近、この潮流があるのかDiffuserのLoRA実装もこのset_attn_processorを使っている。

        lora_attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

        unet.set_attn_processor(lora_attn_procs)

他にprompt++(XTI)とかxformerの高速化とかもset_attn_processorを使用する。
set_attn_processorを使うとinference時にtrain時と同じnew_forwardをいちから作らなくてよい。反面、複数の異なるattn_processorを適用出来たりはしないので複数set_attn_processorを使う場合の実装が気にかかる。例えばLoRAとxformerを両方適用するためにLoRAXFormersAttnProcessorなんていうAttnProcessorを新規に作成している。組み合わせによって無数に新規AttnProcessorを作ることにならないか心配である。

from .attention_processor import (  # noqa: F401
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor2_0,
    LoRAAttnProcessor,
    LoRALinearLayer,
    LoRAXFormersAttnProcessor,
    SlicedAttnAddedKVProcessor,
    SlicedAttnProcessor,
    XFormersAttnProcessor,
)

SVD(特異値分解)

compress(圧縮)のコードに以下の様なSVD(特異値分解)を利用したデータサイズ圧縮のコードがある。
しかしながら、特異値分解が何物なのかが分からないからこのコードを見ても意味が理解できない。

def compress(delta_ckpt, ckpt, diffuser=False, compression_ratio=0.6, device='cuda'):
...
    for name in layers:
        if 'to_k' in name or 'to_v' in name:
            W = st[name].to(device)
            Wpretrain = pretrained_st[name].clone().to(device)
            deltaW = W-Wpretrain

            u, s, vt = torch.linalg.svd(deltaW.clone())

            explain = 0
            all_ = (s).sum()
            for i, t in enumerate(s):
                explain += t/(all_)
                if explain > compression_ratio:
                    break

            compressed_st[compressed_key][f'{name}'] = {}
            compressed_st[compressed_key][f'{name}']['u'] = (u[:, :i]@torch.diag(s)[:i, :i]).clone()
            compressed_st[compressed_key][f'{name}']['v'] = vt[:i].clone()
        else:
            compressed_st[compressed_key][f'{name}'] = st[name]

特異値分解とは

例えば$768$次元を$320$次元に変形する全結合重みの行列を考えればそれは$(768,320)$のサイズの行列である。入力次元が$768$で、出力次元が$320$である。
特異値分解は例えば$rank=32$なら$U=(768,32)$、$\Gamma=(32,32)$、$V^{T}=(32,320)$とおける。
ここで$U,V$は直交行列、つまり$UU^{T}=I$、$VV^{T}=I$。また$\Gamma$は対角要素$\sigma_1,...\sigma_{32}$を取る対角行列である。この対角要素の事を一般に特異値と呼ぶ。特異値の二乗を固有値と呼ぶ。
特異値の総和を求めた時、この大きい順に並んでいる特異値の部分総和が特異値の全体総和の60%を超える時の$rank$を探す。そして$u=U *\Gamma,v=V^{T}$という二種の重みを作成する。
0.png

この計算は実の所、以下のリンク先の内容と近い。
リンク先は任意のファインチューニングモデル(と学習元のモデルとの差分)からLoRA重みを生成している。
ただし、リンク先は$rank$の次元の大きさは固定で事前に決定しており、上述したコードでは特異値(大きい順に並んでいる)の部分総和と全体総和の比率から$rank$を決定している。

CustomDiffusionのcompress(圧縮)はFineTuneモデルからLoRAを生成する実装と近い。

Multi-Concept Optimization

LoRAでは二種類のLoRA重みをモデルにマージする方法は比較的簡単である。
例えばモデルの元重み$W_0$、一個目のLoRA重み$U_1,V_1,Scale_1$、二個目のLoRA重み$U_2,V_2,Scale_2$とする時、合成重みは単純な線形和で示される。

W=W_0+Scale_1*U_{1}*V_{1}^{T}+Scale_2*U_{2}*V_{2}^{T}

一方、別々に学習した二種類のCustomDiffusion重みをマージする方法は複雑である。
モデルの元重み$W_0$、一個目のCustomDiffusion重み$W_1$、二個目のCustomDiffusion重み$W_2$とする時

W=W_0+(W_1-W_0)+(W_2-W_0)\\
=W_1+(W_2-W_0)

とは出来ない。なぜならCustomDiffusionの追加TIのプロンプト変化を考慮していないためである。
一般的なプロンプト$C_{reg}$、一個目のCustomDiffusionの追加コンセプト(例えば<cat-toy>)のプロンプト$C_{1}$、二個目のCustomDiffusionの追加コンセプトのプロンプト$C_{2}$とする。ここで$C$はプロンプトをTextEncoderに通した後の次元である。
ここでマージ後の重みを$W$とする時、合成重み$W$は以下の三式を満たす必要がある。

WC_{reg}=W_0C_{reg}\\
WC_{1}=W_1C_{1}\\
WC_{2}=W_2C_{2}

また仮に$W=W_0+ΔW_1+ΔW_2$となる時、

(W_0+ΔW_1)C_{reg}=W_0C_{reg}\\
(W_0+ΔW_2)C_{reg}=W_0C_{reg}\\
(W_0+ΔW_1)C_{1}=W_1C_{1}\\
ΔW_2C_{1}=0\\
(W_0+ΔW_2)C_{2}=W_2C_{2}\\
ΔW_1C_{2}=0

ただし$(W_0+ΔW_1)\neq W_1, (W_0+ΔW_2)\neq W_2$。
これを解くのにラグランジュの未定乗数法(method of Lagrange multipliers)を用いるとあるがこの辺から何を書いてあるのかよく理解できていない。
image.png
image.png

ここでコードの引数は以下の様であると考えられる。
$K=C_{reg},V=W_0C_{reg},K_{target}=[C_1,C_2],V_{target}=[W_1C_1,W_2C_2],W=W_0$
LU分解(lu_factor,lu_solve)やら逆行列計算(torch.linalg.inv)が用いられているため正直、途中から自分にはほぼ意味不明である。

def gdupdateWexact(K, V, Ktarget1, Vtarget1, W, device='cuda'):
    input_ = K
    output = V
    C = input_.T@input_
    d = []
    lu, piv = lu_factor(C.cpu().numpy())
    for i in range(Ktarget1.size(0)):
        sol = lu_solve((lu, piv), Ktarget1[i].reshape(-1, 1).cpu().numpy())
        d.append(torch.from_numpy(sol).to(K.device))

    d = torch.cat(d, 1).T

    e2 = d@Ktarget1.T
    e1 = (Vtarget1.T - W@Ktarget1.T)
    delta = e1@torch.linalg.inv(e2)

    Wnew = W + delta@d
    lambda_split1 = Vtarget1.size(0)

    input_ = torch.cat([Ktarget1.T, K.T], dim=1)
    output = torch.cat([Vtarget1, V], dim=0)

    loss = torch.norm((Wnew@input_).T - output, 2, dim=1)
    print(loss[:lambda_split1].mean().item(), loss[lambda_split1:].mean().item())

    return Wnew

正直あまり分かってないが、このマージ手法ではpromptに従属する成分を上手くマージできるのかもしれない。
ただし、逆説的にいうならこれが上手く行くのはCustomDiffusionがprompt入力を受け取る全結合$K,V$のみしか学習してないから上手く行くのであって、$K,V$以外の他の全結合においても学習するLoRAやDreamBoothでこのマージ手法は意味があるのか不明である。

CustomDiffusionの後続

ELITE

画像を使ってTI重みを作成するのとControlNetのように画像をエンコードした入力を足す。
前者は学習時間を下げる効果がある。これに似たアイデアはInSTE4TSuTIOTIPALAVRAなどに見られる。
image.png
image.png

SVDiff

CutmixでAttentionMaskを使う。
また、畳み込み部分もSVD(特異値分解)でパラメータ削減できる。
image.png
image.png

Cones

image.png
image.png

まとめ:

CustomDiffusionのコードを読んでて特異値分解やLU分解やらが出てくる。それに対する自分の理解をメモした。

参考:

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?