3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

LoRAにおいてlora_upの重みを固定化出来るか?

Last updated at Posted at 2023-06-08

LoRAはStableDiffusionにおいて少ない重みでオブジェクト学習を行うための手法である。

CustomDiffusionに倣いUnetのある$K,V$重みにおいてLoRAモジュールを計算するとする。
この場合、LoRAは例えば$(768,320)$の全結合重みに対して、LoRAのRank=4とすれば並列に$(768,4)$と$(4,320)$の二種類の全結合重みを付け加える。それぞれの全結合重みをlora_down、lora_upという。
lora_downとlora_upの行列を掛けると$(768,4)×(4,320)=(768,320)$で元の全結合重みの変化差分を低次元の二個の行列で学習する事が出来る。
input_dimは$K,V$重みの場合768(TextEncoderの次元)でそれ以外の場合はoutput_dimと等しい。
output_dimはUnetの層によって異なる。通常は$320,640,1280$のいずれかである
image.png

さて、自分が考えたのはlora_upの重みをembedding重みで固定し、lora_downのみの重みで学習できないかという事を考えた。この考察の目的はlora_downの重みが768長だからTextual Inversion重みに変換して保持する事が可能なのではないかという仮定を考えた時、seedを固定したらEmbeddingは毎回同じガウス分布を生成するので、$K,V$重みのみを持つLoRAを任意のTextual Inversion重みに変換できるのではと考えた。これはExtended Textual InversionみたいにAttention毎に異なるTextual Inversion重みを取り出すAttnProcessorを設計せねばならないが、まあそれはさておきこの記事内では単にlora_downの初期重みをゼロ、lora_upにEmbedding重みで固定化した場合、LoRAが学習可能かどうかを確認してみたい。
image.png

LoRAコード

diffuserにおける従来のLoRALinearLayerは以下の通りである。
従来はlora_downが正規分布、lora_upが初期重みをゼロである。
これをlora_downが初期重みをゼロ、lora_upにEmbedding重みで固定化としたい。

class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

これを変えて以下の様に変更してみる。
embeddings = nn.Embedding(vocab_size, emb_dim)で、lora_upとembeddingは行列の次元が逆なので転置させている。lora_upのrequires_gradは固定化し、lora_downのみゼロから学習させるように変更している。

class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)

        seed = 0
        torch.manual_seed(seed)
        embedding = nn.Embedding(rank, out_features)

        nn.init.zeros_(self.down.weight)
        self.up.weight = nn.Parameter(embedding.weight.T.detach())
        self.up.weight.requires_grad=False

また、LoRACrossAttnProcessorからto_q_loraとto_out_loraを消し、$K,V$のLoRA重みのみにした。

class LoRACrossAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
...

        #self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        #self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
...

        #query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.to_q(hidden_states)
...
        #hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)

学習結果

学習データ
1.jpeg

①TI学習のみ

SD1.5モデルを元にnum_vec_per_tokenを変えてTI学習を行った。
diffusers/examples/research_projects/mulit_token_textual_inversion/を少し改造して使用。
具体的はMultiTokenCLIPTokenizerを使わずに学習する。

①-1: num_vec_per_token=1, infer_token=1

prompt="A <cat-toy>"
cat-backpack_cat_1_TI_1_sample_0.png
prompt="A <cat-toy> backpack"
cat-backpack_cat_1_TI_1_sample_0.png

①-2: num_vec_per_token=2, infer_token=2

prompt="A <cat-toy>_0 <cat-toy>_1"
cat-backpack_cat_2_TI_2_sample_0.png
prompt="A <cat-toy>_0 <cat-toy>_1 backpack"
cat-backpack_cat_2_TI_2_sample_0.png

①-3: num_vec_per_token=4, infer_token=4

prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 <cat-toy>_3"
cat-backpack_cat_4_TI_4_sample_0.png
prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 <cat-toy>_3 backpack"
cat-backpack_cat_4_TI_4_sample_0.png

①-4: num_vec_per_token=4, infer_token=3

prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2"
cat-backpack_cat_4_TI_3_sample_0.png
prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 backpack"
cat-backpack_cat_4_TI_3_sample_0.png

①-5: num_vec_per_token=8, infer_token=8

prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 ... <cat-toy>_7"
cat-backpack_cat_8_TI_8_sample_0.png
prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 ... <cat-toy>_7 backpack"
cat-backpack_cat_8_TI_8_sample_0.png

学習時のnum_vec_per_tokenが多いほど、オブジェクト精度のimageアライメントが上昇するが、代わりにprompt変化性であるtextアライメントが低下する。示した例ではリュックサックになる割合が減る。生成時(inference)において生成promptのTI token長さを減らして調整するとtextアライメントを上昇させることは出来る。

②TI学習+従来LoRA

①のTI学習したモデルを元にdiffuserの従来LoRA学習を行う。
diffusers/examples/dreambooth/train_dreambooth_lora.pyを使用。

num_vec_per_token=4の場合の--instance_prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 <cat-toy>_3"としている。
通常DreamBoothやLoRAでは学習対象にsksトークンなどの出現頻度の低いレアトークンを使用するが、今回は①で追加したTIトークンを対象にして学習をする。
multitoken TIの場合は作法として(TIトークン)_数字という形式で複数のTIトークンが追加されている。

②-1: num_vec_per_token=1, infer_token=1

prompt="A <cat-toy>"
cat-backpack_cat_1_TI_1_sample_0.png
prompt="A <cat-toy> backpack"
cat-backpack_cat_1_TI_1_sample_0.png

②-2: num_vec_per_token=2, infer_token=2

prompt="A <cat-toy>_0 <cat-toy>_1"
cat-backpack_cat_2_TI_2_sample_0.png
prompt="A <cat-toy>_0 <cat-toy>_1 backpack"
cat-backpack_cat_2_TI_2_sample_0.png

②-3: num_vec_per_token=4, infer_token=4

prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 <cat-toy>_3"
cat-backpack_cat_4_TI_4_sample_0.png
prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 <cat-toy>_3 backpack"
cat-backpack_cat_4_TI_4_sample_0.png

②-4: num_vec_per_token=4, infer_token=3

prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2"
cat-backpack_cat_4_TI_3_sample_0.png
prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 backpack"
cat-backpack_cat_4_TI_3_sample_0.png

②-5: num_vec_per_token=8, infer_token=8

prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 ... <cat-toy>_7"
cat-backpack_cat_8_TI_8_sample_0.png
prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 ... <cat-toy>_7 backpack"
cat-backpack_cat_8_TI_8_sample_0.png

③TI学習+LoRA_up固定化

cross_attention.pyを改造。LoRA_downの初期値ゼロ、LoRA_upをembedding初期重みで固定化した。
このLoRAサイズは$K,V$のみなので1.6MB。従来の$Q,K,V,out$の時は3.2MB
LoRA_upは生成できるので学習重みは実質これの更に半分だが、今回は自動生成のLoRA_upも含まれている。

③-1: num_vec_per_token=1, infer_token=1

prompt="A <cat-toy>"
cat-backpack_cat_1_TI_1_sample_0.png
prompt="A <cat-toy> backpack"
cat-backpack_cat_1_TI_1_sample_0.png

③-2: num_vec_per_token=2, infer_token=2

prompt="A <cat-toy>_0 <cat-toy>_1"
cat-backpack_cat_2_TI_2_sample_0.png
prompt="A <cat-toy>_0 <cat-toy>_1 backpack"
cat-backpack_cat_2_TI_2_sample_0.png

③-3: num_vec_per_token=4, infer_token=4

prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 <cat-toy>_3"
cat-backpack_cat_4_TI_4_sample_0.png
prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 <cat-toy>_3 backpack"
cat-backpack_cat_4_TI_4_sample_0.png

③-4: num_vec_per_token=4, infer_token=3

prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2"

prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 backpack"
cat-backpack_cat_4_TI_3_sample_0.png

③-5: num_vec_per_token=8, infer_token=8

prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 ... <cat-toy>_7"
cat-backpack_cat_8_TI_8_sample_0.png
prompt="A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 ... <cat-toy>_7 backpack"
cat-backpack_cat_8_TI_8_sample_0.png

④TI学習+LoRA_up固定化(Self-Attention除外)

再度確認して分かったのだが今回の③学習ではLoRAはCross-AttentionのみではなくSelf-AttentionにおいてもLoRA重みを形成していた。 Cross-Attentionの場合のLoRA_downは$(768,4)$次元なのでTextEncoderの次元$(768)$と等しいので結合可能だが、Self-Attentionの場合のLoRA_downは$(320,4),(640,4),(1280,4)$のいずれかなのでこれをTextEncoderのEmbeddingには計上できない。
Self-AttentionとCross-Attentionは同数でそれぞれ16層ずつだった。出力次元は320が5層、640が5層、1280が6層である。

Self-Attention除外するにはLoRACrossAttnProcessorを更に以下の様に変更する必要がある。
③の学習LoRA重みからSelf-AttentionのLoRA重みの適用を除外した。
この論文CustomDiffusionによればCross-Attentionの効果が重要であり、Self-AttentionのLoRA重みを除外してもそこまで影響はないと思われる。

class LoRACrossAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
...
        if cross_attention_dim is not None:
            self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
            self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
...
    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
...
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states
            key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
        else:
            encoder_hidden_states = hidden_states
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

④-1: num_vec_per_token=1, infer_token=1

prompt="A <cat-toy>"

prompt="A <cat-toy> backpack"
cat-backpack_cat_1_TI_1_sample_0.png

④-2: num_vec_per_token=2, infer_token=2

prompt="A <cat-toy>_0 <cat-toy>_1"

prompt="A <cat-toy>_0 <cat-toy>_1 backpack"
cat-backpack_cat_2_TI_2_sample_0.png

と、ここまで書いてQiitaのアップロードできる画像の上限に達してしまったので以下省略。
最初からリサイズしてから画像を張ればよかった…。

これにてLoRA_downはUnetの層によらず全て$(768,4)$次元なので、これをTextEncoderのEmbeddingに計上する事が理論上は出来る。
TIおよびLoRAのファイルサイズは以下。

TI(1token) TI(2token) TI(4token) TI(8token) TI(4*16*[K,V]token)
4KB 7KB 13KB 25KB 400KB?
従来LoRA K,Vのみ(Q,out除外) Self-Attention除外 lora_up固定化
3.2MB 1.6MB 0.8MB 0.4MB?

TI学習の学習率は10.0e-04の1500step
LoRA学習の学習率は1.0e-04の500step

参考:Textアライメント(編集性)

・DreamBoothにおいて学習stepを大きすぎ(過学習)にならないよう調整する。(ただ、必要学習stepは初期class_tokenの生成近さ、学習元モデルの描画可能かどうか、学習画像のばらつきの大きさ(≒学習画像枚数)、Data Augmentation(random_crop、flip等)の手法、学習率の大きさ、TextEncoderの学習の有無、正則学習の有無等に依存するので、実際の学習途中の生成結果を見ずに最適stepは決めにくい)
・上述の例では付与単語"backpack"を複数回繰り返したり、強調したりすればTextアライメントが向上すると思われる。
・逆にnum_vec_per_token=4, infer_token=3のように生成トークンを意図的に削ったり、TIトークンのテキスト強度を弱めればTextアライメントが向上すると思われる。

・LoRAの適用scaleを減らしてもTextアライメントが向上すると思われる。(知見なし)
・学習時に学習promptをランダムでDropする手法等がある。(知見なし)

まとめ

LoRA重みをTextEncoderのembeddingテーブル上のTI重みに計上することは出来ないだろうかという考えから、LoRA_up重みを固定化してLoRA学習を実行してみた。結果としては従来LoRA学習よりも精度は若干下がるようだが、学習が破綻するほどではない。

従って$K,V$重みのみを考えるCrossAttentionのみのLoRAにおいてLoRA重みをTI重み(TextEncoderのEmbedding重み)に結合できる可能性を示した。なお、結合する事自体には特にメリットはない。
image.png

考察:LoRAの挿入位置の移動

ここで本来の全結合重み$W$に対して逆行列$W^{-1}$が存在すると仮定すると、この逆行列を使用してLoRAの位置を任意の位置に移動が可能ではないかと思われる。
ここで$W\cdot W^{-1},W^{-1}\cdot W$は共に単位行列。全結合重み$W$の大きさは$(768,320)$、$W^{-1}$の大きさは$(320,768)$である。
image.png
image.png
image.png

すなわちsoftmaxを考えない単純な任意の$QK^TV$計算を考えた場合、$W_Q,W_K,W_V,W_{out}$の各LoRAは以下の様にまとめられるから、$W_Q,W_{out}$のLoRAは同ランクの$W_K,W_V$のLoRAにまとめられ、4つのLoRAは2倍のランク大きさの$W_K,W_V$のLoRAと等しいのではないかと思った。この場合、CrossAttention構造の前後に挿入するAdapterと等しいように思える。

image.png

また、以下の様であればTextEncoderとUnet間に挿入するAdapterと解釈できる。
image.png

この論文2CustomDiffusionによればCross-Attentionの寄与は以下の通りである。
image.png
image.png

またこの論文では上記考察のようにLoRAを変形できるAdapterの内、
$CA_{C}-CA_{C}$…TextEncoderとUnetの間
$CA_{in}-CA_{in}$…CrossAttentionの入口
$CA_{out}-FFN_{in}$…CrossAttentionの出口
の位置に挿入すると影響はいずれも高いように見える。

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?