At the time I wrote the previous article, there was no implementation of Extended Textual Inversion yet.
I have since read through Extended Textual Inversion implementation code. It was complex, so I am summarizing it here as a memo.
The UNet implementation in diffusers
First, let's look at the standard UNet usage without Extended Textual Inversion.
In the StableDiffusionPipeline:
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
...
def _encode_prompt(...
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return prompt_embeds
...
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
...
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample
The forward function of UNet2DConditionModel:
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
...
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
...
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
...
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
...
return UNet2DConditionOutput(sample=sample)
Focus on encoder_hidden_states in the code above. Here encoder_hidden_states=encoder_hidden_states passes the same tensor to the down, mid, and up layers.
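For reference, a minimal sketch of this standard path (a sketch assuming SD 1.5, i.e. 77 tokens and a hidden size of 768):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Encode a single prompt: shape (1, 77, 768).
ids = pipe.tokenizer(
    ["green lizard"],
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    return_tensors="pt",
).input_ids
prompt_embeds = pipe.text_encoder(ids.to("cuda"))[0]

# Every cross-attention layer in the down, mid and up blocks receives this
# single tensor unchanged via encoder_hidden_states.
image = pipe(prompt_embeds=prompt_embeds, num_inference_steps=25).images[0]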
jakaline-dev's XTI implementation
encoder_hidden_states=encoder_hidden_states[down_i:down_i+2]
This kind of slicing means that encoder_hidden_states here is a torch.Tensor one dimension higher than usual: a stack of the conventional encoder_hidden_states, one per layer.
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
...
encoder_hidden_states = torch.stack([train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype) for s in torch.split(input_ids, 1, dim=1)])
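# (assumed shapes for SD 1.5) the stack holds one conventional embedding per layer: (16, batch, 77, 768)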
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
...
def unet_forward_XTI(self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
...
# 3. down
down_block_res_samples = (sample,)
down_i = 0
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states[down_i:down_i+2],
)
down_i += 2
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
# 5. up
up_i = 7
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states[up_i:up_i+3],
upsample_size=upsample_size,
)
up_i += 3
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
...
return UNet2DConditionOutput(sample=sample)
Here encoder_hidden_states is sliced out of the higher-dimensional tensor:
down:
encoder_hidden_states=encoder_hidden_states[down_i:down_i+2]
mid:
encoder_hidden_states=encoder_hidden_states[6]
up:
encoder_hidden_states=encoder_hidden_states[up_i:up_i+3]
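For the SD 1.5 UNet (three CrossAttnDownBlock2D with 2 attention layers each, one mid block, three CrossAttnUpBlock2D with 3 attention layers each) these slices cover 16 entries in total, matching the 16 cross-attention layers targeted by the P+ paper. A sketch of the layout this implies, assuming the stacked tensor has shape (16, batch, 77, 768):

# Index layout implied by unet_forward_XTI for the SD 1.5 UNet (illustrative, not from the repo).
XTI_SLICES = {
    "down_block_0": (0, 2),   # CrossAttnDownBlock2D, 2 attention layers
    "down_block_1": (2, 4),
    "down_block_2": (4, 6),
    "mid_block": (6,),        # the single index 6
    "up_block_1": (7, 10),    # CrossAttnUpBlock2D, 3 attention layers
    "up_block_2": (10, 13),
    "up_block_3": (13, 16),
}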
The prompt++ implementation
In overwrite_call:
# predict the noise residual
noise_pred_uncond = self.unet(
latent_model_input,
t,
encoder_hidden_states=negative_prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample
noise_pred_text = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample
utils.py looks like the following. In other words, encoder_hidden_states is a dict, not a higher-dimensional torch.Tensor: each key f"CONTEXT_TENSOR_{this_idx}" holds one slice, encoder_hidden_states[idx : idx + 1, :, :], of the encoded prompts.
class PPPAttenProc:
def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask=None):
is_dict_format = True
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, dict):
this_idx = encoder_hidden_states["this_idx"]
_ehs = encoder_hidden_states[f"CONTEXT_TENSOR_{this_idx}"]
encoder_hidden_states["this_idx"] += 1
encoder_hidden_states["this_idx"] %= 16
else:
_ehs = encoder_hidden_states
else:
_ehs = None
...
class PPPPromptManager:
def __init__(
self, tokenizer, text_encoder, main_token, preserve_prefix, extend_amount
):
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.main_token = main_token
self.preserve_prefix = preserve_prefix
self.extend_amount = extend_amount
def expand_prompt(self, text: str):
pp_extended = pp_extend(
text, self.main_token, self.preserve_prefix, self.extend_amount
)
return pp_extended
def embed_prompt(self, text: str):
texts = self.expand_prompt(text)
ids = self.tokenizer(
texts,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
encoder_hidden_states = self.text_encoder(ids.to(self.text_encoder.device))[0]
_hs = {"this_idx": 0}
for idx in range(16):
_hs[f"CONTEXT_TENSOR_{idx}"] = encoder_hidden_states[idx : idx + 1, :, :]
return _hs
In train_ppp.py this is set up so that PPPAttenProc() is called for every attn in the UNet. If encoder_hidden_states is not a dict, it behaves the same as a conventional prompt embedding; if it is a dict, it is treated as XTI input.
encoder_hidden_states["this_idx"] then cycles 0 => 1 ... => 15 => 0 => 1 ... each time PPPAttenProc is called, so each cross-attention layer can fetch a different encoder_hidden_states.
unet.set_attn_processor(PPPAttenProc())
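Putting the pieces together, a rough usage sketch (the PPPPromptManager constructor values and the "<xti>" placeholder token are illustrative assumptions, not taken from train_ppp.py):

# Rough sketch; the concrete argument values below are assumptions for illustration.
unet.set_attn_processor(PPPAttenProc())
pm = PPPPromptManager(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    main_token="<xti>",
    preserve_prefix=0,
    extend_amount=16,
)
ehs = pm.embed_prompt("a photo of <xti>")
# ehs is a dict: {"this_idx": 0, "CONTEXT_TENSOR_0": ..., ..., "CONTEXT_TENSOR_15": ...}
# The unmodified UNet forward passes the dict straight through to each attention processor.
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=ehs).sample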
Summary:
XTI gives a different prompt to each layer of the UNet.
If, like jakaline-dev, you hold encoder_hidden_states as a torch.Tensor with one extra dimension, the XTI implementation has to index it by layer order, as in encoder_hidden_states[6].
prompt++ instead passes encoder_hidden_states as a dict. With this approach the UNet itself needs no changes, so no XTI-specific new forward has to be written, which looks easier.
Also, passing an ordinary prompt_embeds (a normal tensor rather than the stacked one, i.e. without XTI) raises no error, and there is no need to prepare a stacked torch.Tensor for the negative_prompt.
Addendum:
Let's try compositing a "green lizard" and a "red cube" like the example in the P+ paper.
Following the p++ implementation and using unet.set_attn_processor(), the code below produced a "red lizard" and a "green cube".
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0
import torch
class PPAttenProc:
def __init__(self):
self.idx=0
def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask=None):
        if encoder_hidden_states is not None:
            # Stacked P+ prompt: the 16 per-layer embeddings are concatenated along
            # the last dim (16*768), so slice out the 768-dim chunk for the current layer.
            if encoder_hidden_states.size()[2] == 768*16:
                _ehs = encoder_hidden_states[:, :, (self.idx)*768:(self.idx+1)*768]
            else:
                _ehs = encoder_hidden_states
            # Cycle 0..15 over the cross-attention layers (self-attention passes None here).
            self.idx += 1
            self.idx %= 16
        else:
            _ehs = None
        return AttnProcessor2_0()(attn, hidden_states, _ehs, attention_mask)
class PPPromptManager:
def __init__(self, pipe):
self.tokenizer = pipe.tokenizer
self.text_encoder = pipe.text_encoder
    def embed_prompt(self, prompts_A, prompts_B, mix_range, neg_prompts, reverse=False):
        batch_size = len(prompts_A)
        # Interleave the 16 per-layer positive prompts (taken from prompts_A inside
        # mix_range, otherwise from prompts_B; reverse=True flips this) with their
        # negative prompts: [cond_0, neg_0, cond_1, neg_1, ...].
        texts = []
        for idx in range(batch_size):
            for jdx in range(16):
                if (jdx in mix_range and not reverse) or (jdx not in mix_range and reverse):
                    texts.append(prompts_A[idx][jdx])
                else:
                    texts.append(prompts_B[idx][jdx])
                texts.append(neg_prompts[idx][jdx])
        print(texts)
        ids = self.tokenizer(texts, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt").input_ids
        x = self.text_encoder(ids.to(self.text_encoder.device))[0]
        # (batch*16*2, 77, 768) -> (batch, 16 layers, {cond, neg}, 77, 768)
        x = torch.reshape(x, (batch_size, 16, 2, self.tokenizer.model_max_length, -1))
        # -> ({cond, neg}, batch, 77, 16 layers, 768)
        x = torch.permute(x, (2, 0, 3, 1, 4))
        # Flatten the per-layer embeddings into the last dim: ({cond, neg}, batch, 77, 16*768)
        x = torch.reshape(x, (2, batch_size, self.tokenizer.model_max_length, -1))
        return x[0], x[1]
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.unet.set_attn_processor(PPAttenProc())
pm = PPPromptManager(pipe)
prompts_A = [["red cube" for i in range(16)]]
prompts_B = [["green lizard" for i in range(16)]]
neg_prompts = [["" for i in range(16)]]
mix_range = range(5,8)
prompt_embeds, negative_prompt_embeds = pm.embed_prompt(prompts_A, prompts_B, mix_range, neg_prompts, reverse=False)
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, num_inference_steps=25).images[0]
image.save("output.png")
prompt_embeds, negative_prompt_embeds = pm.embed_prompt(prompts_A, prompts_B, mix_range, neg_prompts, reverse=True)
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, num_inference_steps=25).images[0]
image.save("output2.png")