
Linear interpolation of prompts in Stable Diffusion

Purpose

When feeding Stable Diffusion a linear interpolation between two prompts, I want to examine at which layer of the TextEncoder the interpolation is best applied.
Prompt interpolation itself is commonly done; searching for something like "Stable Diffusion morphing" turns up many examples.

1. Output of the TextEncoder

Compute the linear interpolation of the two TextEncoder outputs for the two prompts and pass it to the UNet.
This is the simplest approach.

from diffusers import StableDiffusionPipeline
import torch
from diffusers import DDIMScheduler
import os
import numpy as np

def lerp(t, v0, v1):
    return (1.0 - t) * v0 + t * v1

func = lerp
func_name = 'lerp'

model_id = "runwayml/stable-diffusion-v1-5"
ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, scheduler=ddim).to("cuda")

device='cuda'
seed = 100
generator = torch.Generator(device=device)

## mophing01
output_path = "./mophing01_%s/" % (func_name)
os.makedirs(output_path, exist_ok=True)

for i in range(12):
    generator = generator.manual_seed(seed)
    prompt1 = "a photo of an astronaut riding a horse on mars"
    prompt2 = "a photo of an astronaut riding a tiger on grassland"
    prompt3 = ""
    token_ids1 = pipe.tokenizer(prompt1, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
    token_ids2 = pipe.tokenizer(prompt2, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
    token_ids3 = pipe.tokenizer(prompt3, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
    prompt_embeds1 = pipe.text_encoder(token_ids1.to(device))[0]
    prompt_embeds2 = pipe.text_encoder(token_ids2.to(device))[0]
    prompt_embeds3 = pipe.text_encoder(token_ids3.to(device))[0]

    t = np.linspace(0, 1, 12)[i]
    prompt_embeds = func(t, prompt_embeds1, prompt_embeds2)
    
    image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=prompt_embeds3, num_inference_steps=25, guidance_scale=7.5, generator=generator).images[0]
    image.save(output_path + "sample_%02d_%f.png" % (i, t))

2. Output of the TextEncoder (part 2)

In the previous section, the TextEncoder transformation was written as a single call, with no description of the intermediate steps from token_ids to prompt_embeds. Referring to class CLIPTextTransformer(nn.Module) in the transformers library, the pipe.text_encoder() part can be expanded in more detail as follows.
This merely spells out what happens inside pipe.text_encoder in the previous section; the location where the linear interpolation is applied is unchanged.

def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length = 0):
    # Causal (lower-triangular) attention mask, as used inside the CLIP text encoder
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)
    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

## mophing02
output_path = "./mophing02_%s/" % (func_name)
os.makedirs(output_path, exist_ok=True)

for i in range(12):
    generator = generator.manual_seed(seed)
    prompt1 = "a photo of an astronaut riding a horse on mars"
    prompt2 = "a photo of an astronaut riding a tiger on grassland"
    prompt3 = ""
    token_ids1 = pipe.tokenizer(prompt1, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
    token_ids2 = pipe.tokenizer(prompt2, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
    token_ids3 = pipe.tokenizer(prompt3, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids

    input_shape = token_ids1.size()
    token_ids1 = token_ids1.view(-1, input_shape[-1])
    token_ids2 = token_ids2.view(-1, input_shape[-1])
    token_ids3 = token_ids3.view(-1, input_shape[-1])
    position_ids = pipe.text_encoder.text_model.embeddings.position_ids

    hidden_states1 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids1.to(device), position_ids=position_ids)
    hidden_states2 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids2.to(device), position_ids=position_ids)
    hidden_states3 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids3.to(device), position_ids=position_ids)

    causal_attention_mask = _make_causal_mask(input_shape, hidden_states1.dtype, device=device)
    encoder_outputs1 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states1, causal_attention_mask=causal_attention_mask)[0]
    encoder_outputs2 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states2, causal_attention_mask=causal_attention_mask)[0]
    encoder_outputs3 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states3, causal_attention_mask=causal_attention_mask)[0]

    prompt_embeds1 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs1)
    prompt_embeds2 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs2)
    prompt_embeds3 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs3)

    t = np.linspace(0, 1, 12)[i]
    prompt_embeds1 = func(t, prompt_embeds1, prompt_embeds2)
    
    image = pipe(prompt_embeds=prompt_embeds1, negative_prompt_embeds=prompt_embeds3, num_inference_steps=25, guidance_scale=7.5, generator=generator).images[0]
    image.save(output_path + "sample_%02d_%f.png" % (i, t))

The main intermediate representations of the TextEncoder are hidden_states and encoder_outputs, located just before and just after the Transformer (text_model.encoder). I want to try linear interpolation at each of these locations and see what happens.

(Figure: image.png — the interpolation locations within the TextEncoder)
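
For reference, the three interpolation points can be written compactly in a single helper. This is my own sketch (the function name encode_with_interp and the where argument are not from the original code), reusing lerp, _make_causal_mask, and device defined above:

def encode_with_interp(pipe, t, token_ids_a, token_ids_b, where="encoder_outputs"):
    # Encode two token sequences and interpolate at one of three locations:
    #   "hidden_states"   : after the embedding layer (section 4)
    #   "encoder_outputs" : after the Transformer, before final_layer_norm (section 3)
    #   "prompt_embeds"   : after final_layer_norm, i.e. the TextEncoder output (sections 1 and 2)
    text_model = pipe.text_encoder.text_model
    position_ids = text_model.embeddings.position_ids
    h_a = text_model.embeddings(input_ids=token_ids_a.to(device), position_ids=position_ids)
    h_b = text_model.embeddings(input_ids=token_ids_b.to(device), position_ids=position_ids)
    if where == "hidden_states":
        h_a = lerp(t, h_a, h_b)
    mask = _make_causal_mask(token_ids_a.size(), h_a.dtype, device=device)
    e_a = text_model.encoder(inputs_embeds=h_a, causal_attention_mask=mask)[0]
    e_b = text_model.encoder(inputs_embeds=h_b, causal_attention_mask=mask)[0]
    if where == "encoder_outputs":
        e_a = lerp(t, e_a, e_b)
    p_a = text_model.final_layer_norm(e_a)
    if where == "prompt_embeds":
        p_a = lerp(t, p_a, text_model.final_layer_norm(e_b))
    return p_a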

3. Linear interpolation at encoder_outputs

## mophing03
output_path = "./mophing03_%s/" % (func_name)
os.makedirs(output_path, exist_ok=True)

for i in range(12):
    generator = generator.manual_seed(seed)
    prompt1 = "a photo of an astronaut riding a horse on mars"
    prompt2 = "a photo of an astronaut riding a tiger on grassland"
    prompt3 = ""
    token_ids1 = pipe.tokenizer(prompt1, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
    token_ids2 = pipe.tokenizer(prompt2, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
    token_ids3 = pipe.tokenizer(prompt3, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids

    input_shape = token_ids1.size()
    token_ids1 = token_ids1.view(-1, input_shape[-1])
    token_ids2 = token_ids2.view(-1, input_shape[-1])
    token_ids3 = token_ids3.view(-1, input_shape[-1])
    position_ids = pipe.text_encoder.text_model.embeddings.position_ids

    hidden_states1 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids1.to(device), position_ids=position_ids)
    hidden_states2 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids2.to(device), position_ids=position_ids)
    hidden_states3 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids3.to(device), position_ids=position_ids)

    causal_attention_mask = _make_causal_mask(input_shape, hidden_states1.dtype, device=device)
    encoder_outputs1 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states1, causal_attention_mask=causal_attention_mask)[0]
    encoder_outputs2 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states2, causal_attention_mask=causal_attention_mask)[0]
    encoder_outputs3 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states3, causal_attention_mask=causal_attention_mask)[0]

    t = np.linspace(0, 1, 12)[i]
    encoder_outputs1 = func(t, encoder_outputs1, encoder_outputs2)

    prompt_embeds1 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs1)
    prompt_embeds3 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs3)

    image = pipe(prompt_embeds=prompt_embeds1, negative_prompt_embeds=prompt_embeds3, num_inference_steps=25, guidance_scale=7.5, generator=generator).images[0]
    image.save(output_path + "sample_%02d_%f.png" % (i, t))

4. Linear interpolation at hidden_states

## mophing04
output_path = "./mophing04_%s/" % (func_name)
os.makedirs(output_path, exist_ok=True)

for i in range(12):
    generator = generator.manual_seed(seed)
    prompt1 = "a photo of an astronaut riding a horse on mars"
    prompt2 = "a photo of an astronaut riding a tiger on grassland"
    prompt3 = ""
    token_ids1 = pipe.tokenizer(prompt1, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
    token_ids2 = pipe.tokenizer(prompt2, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids
    token_ids3 = pipe.tokenizer(prompt3, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids

    input_shape = token_ids1.size()
    token_ids1 = token_ids1.view(-1, input_shape[-1])
    token_ids2 = token_ids2.view(-1, input_shape[-1])
    token_ids3 = token_ids3.view(-1, input_shape[-1])
    position_ids = pipe.text_encoder.text_model.embeddings.position_ids

    hidden_states1 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids1.to(device), position_ids=position_ids)
    hidden_states2 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids2.to(device), position_ids=position_ids)
    hidden_states3 = pipe.text_encoder.text_model.embeddings(input_ids=token_ids3.to(device), position_ids=position_ids)

    t = np.linspace(0, 1, 12)[i]
    hidden_states1 = func(t, hidden_states1, hidden_states2)

    causal_attention_mask = _make_causal_mask(input_shape, hidden_states1.dtype, device=device)
    encoder_outputs1 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states1, causal_attention_mask=causal_attention_mask)[0]
    encoder_outputs3 = pipe.text_encoder.text_model.encoder(inputs_embeds=hidden_states3, causal_attention_mask=causal_attention_mask)[0]

    prompt_embeds1 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs1)
    prompt_embeds3 = pipe.text_encoder.text_model.final_layer_norm(encoder_outputs3)

    image = pipe(prompt_embeds=prompt_embeds1, negative_prompt_embeds=prompt_embeds3, num_inference_steps=25, guidance_scale=7.5, generator=generator).images[0]
    image.save(output_path + "sample_%02d_%f.png" % (i, t))

Results

The results are shown below.
Changing where the linear interpolation is applied changes what appears between the tiger and the horse.
In 1. and 2. a mottled horse tends to appear, in 3. a tiger-striped horse, and in 4. a mysterious creature is generated.
Since the results of 1. and 2. are almost identical, this confirms that the two descriptions are equivalent.
Interpolating at a location other than that of 1., as in 3. and 4., changes the results.
4. appears to change most abruptly, while 3. changes most gradually.

1. Linear interpolation at the TextEncoder output

mophing01_lerp.png

2. Linear interpolation at the TextEncoder output (part 2)

mophing02_lerp.png

3. Linear interpolation at encoder_outputs

mophing03_lerp.png

4. Linear interpolation at hidden_states

mophing04_lerp.png

Spherical interpolation (slerp)

I actually tried spherical interpolation as well, but it did not change the results as much as changing the layer at which the linear interpolation is applied; only small differences were visible (roughly a shift in color tone?), so those results are omitted.
Since there are implementations that use spherical interpolation for interpolating the initial noise latent (this one and this one), slerp may be better suited to interpolating the initial noise than to prompt interpolation.
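
For completeness, a minimal slerp sketch in the usual formulation (flatten the tensors to compute the angle, fall back to lerp when they are nearly parallel), which could be swapped in via func = slerp, is shown below. This is my reconstruction, not the exact code used for the omitted experiment:

def slerp(t, v0, v1, eps=1e-7):
    # Spherical linear interpolation between two tensors of the same shape
    v0_f = v0.flatten().float()
    v1_f = v1.flatten().float()
    dot = torch.dot(v0_f, v1_f) / (v0_f.norm() * v1_f.norm())
    dot = dot.clamp(-1.0, 1.0)
    if torch.abs(dot) > 1.0 - eps:
        # Nearly parallel vectors: plain lerp is numerically safer
        return (1.0 - t) * v0 + t * v1
    theta = torch.acos(dot)
    out = (torch.sin((1.0 - t) * theta) * v0 + torch.sin(t * theta) * v1) / torch.sin(theta)
    return out.to(v0.dtype)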

Summary

With the seed used here, applying the linear interpolation between the TextEncoder's Transformer and final_layer_norm (i.e. at encoder_outputs, as in 3.) looked best.

Note that "a photo of an astronaut riding a tiger on grassland" is generated successfully less often than "a photo of an astronaut riding a horse on mars", and the astronaut is frequently missing, so a certain amount of seed hunting is required.
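
A minimal way to do that seed hunting (the seed values below are arbitrary examples, not from the article) is to generate only the harder endpoint prompt for a few candidate seeds first:

for seed in [100, 200, 300, 400]:  # arbitrary candidate seeds
    generator = generator.manual_seed(seed)
    image = pipe(prompt="a photo of an astronaut riding a tiger on grassland",
                 num_inference_steps=25, guidance_scale=7.5, generator=generator).images[0]
    image.save("seed_check_%03d.png" % seed)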
