Working around the token length limit in diffusers

Introduction

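When I passed a very long prompt to diffusers, everything beyond 77 tokens was cut off.
A message like
Token indices sequence length is longer than the specified maximum sequence length for this model (** > **)
appeared and the prompt was not read in full, so I tracked down a workaround and am writing it up here as a memo.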
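
To make the cut-off concrete, here is a minimal sketch in plain Python of what truncation to 77 tokens amounts to. It does not call the real tokenizer. `truncate_with_special_tokens` is a hypothetical helper, and the start/end token ids (49406 and 49407) are assumed to be those of the standard CLIP vocabulary. The point it illustrates is that the 77 slots include the start and end tokens, so only 75 prompt tokens actually survive.

```python
def truncate_with_special_tokens(content_ids, max_length=77, bos_id=49406, eos_id=49407):
    """Mimic truncation: reserve two slots for the start/end tokens, drop the rest.

    The bos_id / eos_id defaults are assumptions (standard CLIP vocabulary).
    """
    room = max_length - 2              # slots left for actual prompt tokens
    kept = content_ids[:room]          # tokens that survive
    dropped = content_ids[room:]       # tokens that are silently lost
    return [bos_id] + kept + [eos_id], dropped


# A fake prompt of 100 tokens (the ids themselves are only placeholders).
ids, dropped = truncate_with_special_tokens(list(range(100)))
print(len(ids), len(dropped))
```

With a 100-token prompt, 25 tokens never reach the text encoder, which is why the tail of a long prompt has no effect on the image.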

Code

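The code in the issue I referred to only handled the case where the positive prompt is too long, and it was written mainly to be easy to reference, so I rewrote it as follows to handle a long negative prompt as well.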

import torch

def token_auto_concat_embeds(pipe, positive, negative):
    """Encode prompts of any length by encoding the token ids in chunks and concatenating the embeddings."""
    device = pipe.device  # follow the pipeline's device instead of hard-coding "cuda"
    max_length = pipe.tokenizer.model_max_length

    # Tokenize without truncation to measure the real token counts.
    positive_length = pipe.tokenizer(positive, truncation=False, return_tensors="pt").input_ids.shape[-1]
    negative_length = pipe.tokenizer(negative, truncation=False, return_tensors="pt").input_ids.shape[-1]

    print(f'Model max token length: {max_length}, positive length: {positive_length}, negative length: {negative_length}.')
    if max_length < positive_length or max_length < negative_length:
        print('Concatenated embedding.')

    # Pad both prompts to the same length: the longer prompt, or max_length when both fit.
    target_length = max(max_length, positive_length, negative_length)
    positive_ids = pipe.tokenizer(positive, truncation=False, padding="max_length", max_length=target_length, return_tensors="pt").input_ids.to(device)
    negative_ids = pipe.tokenizer(negative, truncation=False, padding="max_length", max_length=target_length, return_tensors="pt").input_ids.to(device)

    # Encode each chunk of up to max_length tokens separately.
    positive_concat_embeds = []
    negative_concat_embeds = []
    with torch.no_grad():  # inference only, no gradients needed
        for i in range(0, positive_ids.shape[-1], max_length):
            positive_concat_embeds.append(pipe.text_encoder(positive_ids[:, i: i + max_length])[0])
            negative_concat_embeds.append(pipe.text_encoder(negative_ids[:, i: i + max_length])[0])

    # Join the chunk embeddings along the sequence dimension.
    positive_prompt_embeds = torch.cat(positive_concat_embeds, dim=1)
    negative_prompt_embeds = torch.cat(negative_concat_embeds, dim=1)
    return positive_prompt_embeds, negative_prompt_embeds
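
To see how the function above slices the token ids, the chunk boundaries can be sketched in plain Python without loading any model. `padded_length` and `chunk_spans` are hypothetical helpers that only mirror the padding rule and the `range(0, length, max_length)` loop. They are not part of diffusers.

```python
def padded_length(positive_length, negative_length, max_length=77):
    """Length both prompts are padded to: the longer prompt, or max_length if both fit."""
    return max(max_length, positive_length, negative_length)


def chunk_spans(total_length, max_length=77):
    """Return the (start, end) slices visited by the chunking loop."""
    return [
        (start, min(start + max_length, total_length))
        for start in range(0, total_length, max_length)
    ]


# Example: a 100-token positive prompt and a 20-token negative prompt.
total = padded_length(100, 20)
spans = chunk_spans(total)
print(total, spans)
```

In this example both prompts are padded to 100 tokens and encoded as two chunks, one of 77 tokens and one of 23, so the last chunk can be shorter than `max_length`. The concatenated embedding then has a sequence length of 100 for both prompts.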

Usage

Note that you pass prompt_embeds= instead of prompt= (and negative_prompt_embeds= instead of negative_prompt=).

import torch
import datetime
from token_auto_concat_embeds import token_auto_concat_embeds
from diffusers import DiffusionPipeline

# Generate Settings
MODEL_ID=''
POSITIVE_PROMPT = ''
NEGATIVE_PROMPT = ''
HEIGHT = 768
WIDTH = 512
SCALE = 12.0
STEP = 28
SEED = 3788086447
DEVICE = 'cuda'
CACHE_DIR = './cache'

# Create Pipeline
pipe = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    torch_dtype=torch.float16,
    cache_dir=CACHE_DIR,
).to(DEVICE)

# Token concatenated embedding
positive_embeds, negative_embeds = token_auto_concat_embeds(pipe, POSITIVE_PROMPT, NEGATIVE_PROMPT)

# Generate Image
image = pipe(
    prompt_embeds=positive_embeds,
    height=HEIGHT,
    width=WIDTH,
    num_inference_steps=STEP,
    guidance_scale=SCALE,
    negative_prompt_embeds=negative_embeds,
    generator=torch.Generator(device=DEVICE).manual_seed(SEED)
).images[0]

# Save Image (the images/ directory must already exist)
image.save("images/" + str(datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')) + ".png")
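
Before calling the pipeline as above, it can help to confirm that the two embeddings have the same shape, since the function pads both prompts to the same length for exactly that reason. Below is a minimal sketch of such a check. `check_prompt_embeds` is a hypothetical helper, and the hidden size of 768 in the example is only an illustrative value, not something taken from the model used here.

```python
def check_prompt_embeds(positive_shape, negative_shape):
    """Raise ValueError unless both shapes are (batch, sequence, hidden) and identical."""
    positive_shape = tuple(positive_shape)
    negative_shape = tuple(negative_shape)
    if len(positive_shape) != 3 or len(negative_shape) != 3:
        raise ValueError("expected 3-D embeddings: (batch, sequence, hidden)")
    if positive_shape != negative_shape:
        raise ValueError(
            f"shape mismatch: positive {positive_shape} vs negative {negative_shape}"
        )
    return positive_shape


# Example with plain tuples; with real tensors, pass positive_embeds.shape
# and negative_embeds.shape instead.
print(check_prompt_embeds((1, 100, 768), (1, 100, 768)))
```

As far as I know, the Stable Diffusion pipelines in diffusers also raise an error when `prompt` and `prompt_embeds` are passed together, so pass only the embeddings.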

References
