
Comparing Extended Textual Inversion implementations

Posted at 2023-03-29

At the time I wrote the previous article, no implementation of Extended Textual Inversion (XTI) existed yet.
I have since read through the XTI implementation code, and because it was fairly involved, I am summarizing it here as a memo.

The UNet implementation in diffusers

First, let's look at the standard UNet implementation, i.e. without Extended Textual Inversion.

In the pipeline (StableDiffusionPipeline.__call__):

        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )
...
    def _encode_prompt(...
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        return prompt_embeds
...

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
...

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

In the forward function of UNet2DConditionModel:

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:

...
        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples
...
        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )
...
        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )
...

        return UNet2DConditionOutput(sample=sample)

Focus on encoder_hidden_states above.
Here the same encoder_hidden_states=encoder_hidden_states is passed to the down, mid, and up blocks.
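
As a minimal sketch of this baseline (using the same SD1.5 checkpoint as the experiment at the end of this article; the prompt and latent are arbitrary placeholders), a single prompt is encoded once and that one tensor is reused by every cross-attention layer:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

ids = pipe.tokenizer(
    "a red cube", padding="max_length",
    max_length=pipe.tokenizer.model_max_length, return_tensors="pt",
).input_ids
prompt_embeds = pipe.text_encoder(ids.to("cuda"))[0]        # [1, 77, 768]

latents = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")
t = torch.tensor(999, device="cuda")
# The same prompt_embeds tensor is handed to every cross-attention layer.
noise_pred = pipe.unet(latents, t, encoder_hidden_states=prompt_embeds).sample
print(prompt_embeds.shape, noise_pred.shape)                 # [1, 77, 768], [1, 4, 64, 64]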

jakaline-dev's XTI implementation

The embeddings are sliced out as in encoder_hidden_states=encoder_hidden_states[down_i:down_i+2], so here encoder_hidden_states is a torch.Tensor with one extra dimension that stacks the conventional encoder_hidden_states together.

  diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
  diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
  diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
...

                encoder_hidden_states = torch.stack([train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype) for s in torch.split(input_ids, 1, dim=1)])

                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
...
def unet_forward_XTI(self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
...

        # 3. down
        down_block_res_samples = (sample,)
        down_i = 0
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states[down_i:down_i+2],
                )
                down_i += 2
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])

        # 5. up
        up_i = 7
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states[up_i:up_i+3],
                    upsample_size=upsample_size,
                )
                up_i += 3
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )
...
        return UNet2DConditionOutput(sample=sample)

Here each encoder_hidden_states is sliced out of the higher-dimensional tensor:
down:
encoder_hidden_states=encoder_hidden_states[down_i:down_i+2]
mid:
encoder_hidden_states=encoder_hidden_states[6]
up:
encoder_hidden_states=encoder_hidden_states[up_i:up_i+3]

For SD1.5 this covers exactly 16 cross-attention layers: the 3 cross-attention down blocks consume 2 slices each (indices 0-5), the mid block consumes index 6, and the 3 cross-attention up blocks consume 3 slices each (indices 7-15).
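
As a minimal shape sketch (assuming SD1.5's 77-token, 768-dimensional text embeddings; the zero tensors are placeholders for real per-layer prompt embeddings), the stacked input that this patched forward consumes can be built like this:

import torch

# Stack 16 per-layer prompt embeddings along a new leading "layer" axis.
per_layer_embeds = [torch.zeros(1, 77, 768) for _ in range(16)]  # placeholders
encoder_hidden_states = torch.stack(per_layer_embeds)            # [16, 1, 77, 768]

# unet_forward_XTI then slices it as:
#   down blocks: [0:2], [2:4], [4:6]
#   mid block:   [6]
#   up blocks:   [7:10], [10:13], [13:16]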

The prompt++ implementation

In overwrite_call:

                # predict the noise residual
                noise_pred_uncond = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=negative_prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                noise_pred_text = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

In utils.py it looks like the following.
In other words, encoder_hidden_states is a dict, not a higher-dimensional torch.Tensor. encoder_hidden_states[f"CONTEXT_TENSOR_{this_idx}"] looks up the value that was stored as encoder_hidden_states[idx : idx + 1, :, :].

class PPPAttenProc:
    def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask=None):

        is_dict_format = True
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, dict):
                this_idx = encoder_hidden_states["this_idx"]

                _ehs = encoder_hidden_states[f"CONTEXT_TENSOR_{this_idx}"]
                encoder_hidden_states["this_idx"] += 1
                encoder_hidden_states["this_idx"] %= 16
            else:
                _ehs = encoder_hidden_states
        else:
            _ehs = None
...

class PPPPromptManager:
    def __init__(
        self, tokenizer, text_encoder, main_token, preserve_prefix, extend_amount
    ):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.main_token = main_token
        self.preserve_prefix = preserve_prefix
        self.extend_amount = extend_amount

    def expand_prompt(self, text: str):

        pp_extended = pp_extend(
            text, self.main_token, self.preserve_prefix, self.extend_amount
        )

        return pp_extended

    def embed_prompt(self, text: str):
        texts = self.expand_prompt(text)
        ids = self.tokenizer(
            texts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        encoder_hidden_states = self.text_encoder(ids.to(self.text_encoder.device))[0]
        _hs = {"this_idx": 0}
        for idx in range(16):
            _hs[f"CONTEXT_TENSOR_{idx}"] = encoder_hidden_states[idx : idx + 1, :, :]

        return _hs

In train_ppp.py this is added so that PPPAttenProc() is called for every attention layer of the UNet. If encoder_hidden_states is not a dict, it is treated like a conventional prompt embedding; if it is a dict, it is the XTI input.
encoder_hidden_states["this_idx"] then loops over the values 0 through 15 (0=>1...=>15=>0=>1...) each time PPPAttenProc is called, so a different encoder_hidden_states can be used at each attention layer.

    unet.set_attn_processor(PPPAttenProc())
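
As a tiny self-contained sketch of this cycling (77 tokens and hidden size 768 are assumed; the zero tensors are placeholders for real embeddings):

import torch

# A dict shaped like the one PPPPromptManager.embed_prompt returns.
hs = {"this_idx": 0}
for idx in range(16):
    hs[f"CONTEXT_TENSOR_{idx}"] = torch.zeros(1, 77, 768)  # placeholder embeddings

order = []
for _ in range(32):                   # 32 cross-attention calls = two full UNet passes
    i = hs["this_idx"]
    _ehs = hs[f"CONTEXT_TENSOR_{i}"]  # context used by the current cross-attention layer
    hs["this_idx"] = (i + 1) % 16
    order.append(i)

print(order)  # 0..15 twice: each of the 16 cross-attention layers gets its own context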

Summary:

XTI gives a different prompt to each layer of the UNet.
If, like jakaline-dev, you hold encoder_hidden_states as a torch.Tensor with one extra dimension, the XTI implementation has to reference each embedding by its layer-order index, e.g. encoder_hidden_states[6].

prompt++ instead passes encoder_hidden_states as a dict. With this approach no XTI-specific new forward has to be designed for the UNet, which looks easier.
Also, passing the prompt_embeds of an ordinary prompt (not a higher-dimensional tensor) without XTI raises no error, and there is no need to prepare a higher-dimensional torch.Tensor for the negative_prompt either.

Addendum:

Let me try the kind of composition shown in the P+ paper, mixing a "green lizard" and a "red cube".
Following the p++ implementation and using unet.set_attn_processor(), the setup below generated a "red lizard" and a "green cube".

from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0
import torch

class PPAttenProc:
    def __init__(self):
        self.idx = 0

    def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask=None):
        if encoder_hidden_states is not None:
            # If the context is 16 embeddings concatenated along the last axis,
            # slice out the 768-dim chunk for the current cross-attention layer.
            if encoder_hidden_states.size()[2] == 768 * 16:
                _ehs = encoder_hidden_states[:, :, (self.idx) * 768:(self.idx + 1) * 768]
            else:
                _ehs = encoder_hidden_states
            self.idx += 1
            self.idx %= 16
        else:
            # Self-attention layers pass encoder_hidden_states=None; the index is not advanced.
            _ehs = None

        return AttnProcessor2_0()(attn, hidden_states, _ehs, attention_mask)

class PPPromptManager:
    def __init__(self, pipe):
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder

    def embed_prompt(self, prompts_A, prompts_B, mix_range, neg_prompts, reverse=False):
        # For each of the 16 cross-attention layers, pick prompt A or prompt B depending on
        # whether the layer index falls in mix_range, paired with its negative prompt.
        batch_size = len(prompts_A)
        texts = []
        for idx in range(batch_size):
            for jdx in range(16):
                if (jdx in mix_range and not reverse) or (jdx not in mix_range and reverse):
                    texts.append(prompts_A[idx][jdx])
                else:
                    texts.append(prompts_B[idx][jdx])
                texts.append(neg_prompts[idx][jdx])
        print(texts)
        ids = self.tokenizer(texts, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt").input_ids
        x = self.text_encoder(ids.to(self.text_encoder.device))[0]

        # Rearrange to (positive/negative, batch, tokens, 16 * hidden) so the 16 per-layer
        # embeddings end up concatenated along the last axis; index 0 holds the chosen
        # prompts and index 1 the negative prompts.
        x = torch.reshape(x, (batch_size, 16, 2, self.tokenizer.model_max_length, -1))
        x = torch.permute(x, (2, 0, 3, 1, 4))
        x = torch.reshape(x, (2, batch_size, self.tokenizer.model_max_length, -1))

        return x[0], x[1]

model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.unet.set_attn_processor(PPAttenProc())

pm = PPPromptManager(pipe)
prompts_A = [["red cube" for i in range(16)]]
prompts_B = [["green lizard" for i in range(16)]]
neg_prompts = [["" for i in range(16)]]
mix_range = range(5,8)

prompt_embeds, negative_prompt_embeds = pm.embed_prompt(prompts_A, prompts_B, mix_range, neg_prompts, reverse=False)
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, num_inference_steps=25).images[0]
image.save("output.png")

prompt_embeds, negative_prompt_embeds = pm.embed_prompt(prompts_A, prompts_B, mix_range, neg_prompts, reverse=True)
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, num_inference_steps=25).images[0]
image.save("output2.png")
