目的
StableDiffusionに二種類のpromptの線形補間を与えた場合、TextEncoderのどこの層で線形補間を与えると良いかを考えたい。
StableDiffusionのmophingあたりで検索を掛ければprompt補間はよくやられている。
1. TextEncoderの出力
二種類のpromptのTextEncoderの二個の出力の線形補間を求めて、それをUnetに渡す。
もっとも簡単な変換である。
from diffusers import StableDiffusionPipeline
import torch
from diffusers import DDIMScheduler
import os
import numpy as np
def lerp(t, v0, v1):
return (1.0 - t) * v0 + t * v1
func = lerp
func_name = 'lerp'
model_id = "runwayml/stable-diffusion-v1-5"
ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float16,scheduler=ddim).to("cuda")
device='cuda'
seed = 100
generator = torch.Generator(device=device)
## mophing01
output_path = "./mophing01_%s/" % (func_name)
os.makedirs(output_path, exist_ok=True)
for i in range(12):
generator = generator.manual_seed(seed)
prompt1 = "a photo of an astronaut riding a horse on mars"
prompt2 = "a photo of an astronaut riding a tiger on grassland"
prompt3 = ""
token_ids1 = pipe.tokenizer(prompt1, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
token_ids2 = pipe.tokenizer(prompt2, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
token_ids3 = pipe.tokenizer(prompt3, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
prompt_embeds1 = pipe.text_encoder(token_ids1.to(device))[0]
prompt_embeds2 = pipe.text_encoder(token_ids2.to(device))[0]
prompt_embeds3 = pipe.text_encoder(token_ids3.to(device))[0]
t = np.linspace(0, 1, 12)[i]
prompt_embeds = func(t, prompt_embeds1, prompt_embeds2)
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=prompt_embeds3, num_inference_steps=25, guidance_scale=7.5, generator=generator).images[0]
image.save(output_path + "sample_%02d_%f.png" % (i, t))
2. TextEncoderの出力(その2)
ここで前節ではTextEncoderによる変換はひとまとめに書かれていてtoken_idsからprompt_embedsまでの途中の記述は一切なかった。class CLIPTextTransformer(nn.Module):
のこれを参考にpipe.text_encoder()の部分をもう少し詳細に書くと以下の様になる。
これは前節のpipe.text_encoder
内の動作を詳細に記述しただけで線形補間を行う場所自体は変わっていない。
def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length = 0):
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
## mophing02
output_path = "./mophing02_%s/" % (func_name)
os.makedirs(output_path, exist_ok=True)
for i in range(12):
generator = generator.manual_seed(seed)
prompt1 = "a photo of an astronaut riding a horse on mars"
prompt2 = "a photo of an astronaut riding a tiger on grassland"
prompt3 = ""
token_ids1 = pipe.tokenizer(prompt1, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
token_ids2 = pipe.tokenizer(prompt2, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
token_ids3 = pipe.tokenizer(prompt3, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
input_shape = token_ids1.size()
token_ids1 = token_ids1.view(-1, input_shape[-1])
token_ids2 = token_ids2.view(-1, input_shape[-1])
token_ids3 = token_ids3.view(-1, input_shape[-1])
position_ids = pipe.text_encoder.text_model.embeddings.position_ids
hidden_states1 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids1.to(device), position_ids=position_ids)
hidden_states2 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids2.to(device), position_ids=position_ids)
hidden_states3 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids3.to(device), position_ids=position_ids)
causal_attention_mask = _make_causal_mask(input_shape, hidden_states1.dtype, device=device)
encoder_outputs1 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states1, causal_attention_mask=causal_attention_mask)[0]
encoder_outputs2 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states2, causal_attention_mask=causal_attention_mask)[0]
encoder_outputs3 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states3, causal_attention_mask=causal_attention_mask)[0]
prompt_embeds1 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs1)
prompt_embeds2 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs2)
prompt_embeds3 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs3)
t = np.linspace(0, 1, 12)[i]
prompt_embeds1 = func(t, prompt_embeds1, prompt_embeds2)
image = pipe(prompt_embeds=prompt_embeds1, negative_prompt_embeds=prompt_embeds3, num_inference_steps=25, guidance_scale=7.5, generator=generator).images[0]
image.save(output_path + "sample_%02d_%f.png" % (i, t))
ここで主たるTextEncoderの中間層はTransformer(text_model.encoder)の前後であるhidde_statesとencoder_outputsがある。このそれぞれの場所において線形補間したらどうなるかを試したい。
3. encoder_outputsで線形補間
## mophing03
output_path = "./mophing03_%s/" % (func_name)
os.makedirs(output_path, exist_ok=True)
for i in range(12):
generator = generator.manual_seed(seed)
prompt1 = "a photo of an astronaut riding a horse on mars"
prompt2 = "a photo of an astronaut riding a tiger on grassland"
prompt3 = ""
token_ids1 = pipe.tokenizer(prompt1, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
token_ids2 = pipe.tokenizer(prompt2, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
token_ids3 = pipe.tokenizer(prompt3, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
input_shape = token_ids1.size()
token_ids1 = token_ids1.view(-1, input_shape[-1])
token_ids2 = token_ids2.view(-1, input_shape[-1])
token_ids3 = token_ids3.view(-1, input_shape[-1])
position_ids = pipe.text_encoder.text_model.embeddings.position_ids
hidden_states1 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids1.to(device), position_ids=position_ids)
hidden_states2 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids2.to(device), position_ids=position_ids)
hidden_states3 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids3.to(device), position_ids=position_ids)
causal_attention_mask = _make_causal_mask(input_shape, hidden_states1.dtype, device=device)
encoder_outputs1 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states1, causal_attention_mask=causal_attention_mask)[0]
encoder_outputs2 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states2, causal_attention_mask=causal_attention_mask)[0]
encoder_outputs3 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states3, causal_attention_mask=causal_attention_mask)[0]
t = np.linspace(0, 1, 12)[i]
encoder_outputs1 = func(t, encoder_outputs1, encoder_outputs2)
prompt_embeds1 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs1)
prompt_embeds3 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs3)
image = pipe(prompt_embeds=prompt_embeds1, negative_prompt_embeds=prompt_embeds3, num_inference_steps=25, guidance_scale=7.5, generator=generator).images[0]
image.save(output_path + "sample_%02d_%f.png" % (i, t))
4. hidde_statesで線形補間
## mophing04
output_path = "./mophing04_%s/" % (func_name)
os.makedirs(output_path, exist_ok=True)
for i in range(12):
generator = generator.manual_seed(seed)
prompt1 = "a photo of an astronaut riding a horse on mars"
prompt2 = "a photo of an astronaut riding a tiger on grassland"
prompt3 = ""
token_ids1 = pipe.tokenizer(prompt1, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
token_ids2 = pipe.tokenizer(prompt2, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
token_ids3 = pipe.tokenizer(prompt3, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
input_shape = token_ids1.size()
token_ids1 = token_ids1.view(-1, input_shape[-1])
token_ids2 = token_ids2.view(-1, input_shape[-1])
token_ids3 = token_ids3.view(-1, input_shape[-1])
position_ids = pipe.text_encoder.text_model.embeddings.position_ids
hidden_states1 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids1.to(device), position_ids=position_ids)
hidden_states2 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids2.to(device), position_ids=position_ids)
hidden_states3 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids3.to(device), position_ids=position_ids)
t = np.linspace(0, 1, 12)[i]
hidden_states1 = func(t, hidden_states1, hidden_states2)
causal_attention_mask = _make_causal_mask(input_shape, hidden_states1.dtype, device=device)
encoder_outputs1 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states1, causal_attention_mask=causal_attention_mask)[0]
encoder_outputs3 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states3, causal_attention_mask=causal_attention_mask)[0]
prompt_embeds1 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs1)
prompt_embeds3 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs3)
image = pipe(prompt_embeds=prompt_embeds1, negative_prompt_embeds=prompt_embeds3, num_inference_steps=25, guidance_scale=7.5, generator=generator).images[0]
image.save(output_path + "sample_%02d_%f.png" % (i, t))
結果
以下、結果を示す。
線形補間を行う場所を変えるとトラとウマの中間に表示される物体が異なる。
1.2.ではまだら模様のウマ、3.ではトラ柄のウマが表示されやすくなり、4.では謎の生物が生成される。
1.2.の結果はほぼ同じであるから、記述が等価であるのが確認できる。
1.の位置と別の位置で線形補間した3.4.は結果が変わる。
4.が最も急激に変わり、3.が最もなだらかに変化しているように見える。
1. TextEncoderの出力で線形補間
2. TextEncoderの出力で線形補間(その2)
3. encoder_outputsで線形補間
4. hidde_statesで線形補間
球面補間(slerp)
実は球面補間の方もやってみたが線形補間を行う層を変えた時ほど変化せず、小さな違いしか見えなかった(色合いが変わるくらい?)なので結果は省略する。
ノイズの初期latentの補間に球面補間を使う実装(これとかこれ)を見るので、prompt補間ではなく初期ノイズの補間には球面補間が良いのかもしれない。
まとめ
今回のseedではTextEncoderのTransformerとfinal_layer_normの間の層で線形補間を行うのが最もよく見えた。
なお、"a photo of an astronaut riding a tiger on grassland"
の生成確率はa photo of an astronaut riding a horse on mars
よりも低く、宇宙飛行士が描かれないケースがしばしばあるので多少のseedガチャが必要である。