0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

CandleでStable Diffusion 1.5を使った画像生成

Posted at

前回Candle でローカルLLMを実行しましたが、今回は Stable Diffusion 1.5 を使った画像生成を行います。

モデルの取得

今回も Hugging Face から必要なファイルを手動でダウンロードします。

まずは、https://huggingface.co/openai/clip-vit-base-patch32/tree/main 1から以下のファイルをダウンロードします。

  • tokenizer.json

次に、https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main から以下のファイルをダウンロードします。

  • text_encoder/model.safetensors
  • unet/diffusion_pytorch_model.safetensors
  • vae/diffusion_pytorch_model.safetensors

実装

Stable Diffusion 1.5 で画像生成する流れは概ねこのようになります。2

手順 名称 処理内容
1 Tokenizer Tokenizer でプロンプトやネガティブプロンプトをトークン化
2 CLIP CLIPエンコーダでトークン列から意味情報を抽出
3 UNET 拡散モデルで画像の圧縮表現を生成。現在の画像(圧縮表現)と意味情報から UNET でノイズを算出し、スケジューラでノイズを除去する処理を繰り返す
4 VAE VAEデコーダで画像の圧縮表現から実際の画像を生成

プロンプトとネガティブプロンプトをまとめて処理する方法も考えられますが、ここでは分かり易さのために別々に処理します。

スケジューラによってノイズ除去の結果が変わりますが、Candle における Stable Diffusion 1.5 のデフォルトスケジューラは DDIM(Denoising Diffusion Implicit Model)のようです。3

画像の質を上げるためにガイダンスを使って意味情報を強調するようにしていますが、ガイダンスを適用しない場合は次のようにプロンプトで算出したノイズをそのままノイズ除去に使います。

ガイダンスを適用しない場合のノイズ除去例
// プロンプトの意味情報を使ったノイズ算出
let noize_pred_text = unet.forward(&input, timestep as f64, &prompt_embeds)?;
// ノイズの除去
latents = scheduler.step(&noize_pred_text, timestep, &latents)?;

VAE でデコードした結果は、(rgb x 高さ x 幅) の並びで概ね -1.0 から 1.0 の値になるようなので、これを (高さ x 幅 x rgb) の並びで 0 から 255 の値へ変換し、画像ファイルとして保存しています。

また、全体的にそれなりのメモリを消費するので、スコープを活用して処理毎にリソースを解放するようにしています。4

SD1.5による画像生成コード例
use candle_core::{DType, Device, Module, Tensor};
use candle_transformers::models::stable_diffusion::{
    build_clip_transformer, clip::ClipTextTransformer, StableDiffusionConfig,
};
use serde::{Deserialize, Serialize};
use std::env;
use tokenizers::{Result, Tokenizer};

#[derive(Debug, Clone, Serialize, Deserialize)]
struct Task {
    tokenizer: String,
    clip: String,
    unet: String,
    vae: String,
    prompt: String,
    negative_prompt: Option<String>,
    steps: usize,
    guidance_scale: f64,
    output_file: Option<String>,
    width: Option<usize>,
    height: Option<usize>,
}

fn main() -> Result<()> {
    let task: Task = {
        let file = env::args().nth(1).ok_or("task file")?;
        serde_json::from_str(&std::fs::read_to_string(file)?)?
    };

    let vae_scale = 0.18215;
    let pad_token = "<|endoftext|>";
    let use_flash_attn = false;

    let dtype = DType::F32;

    let device = Device::Cpu;

    let config = StableDiffusionConfig::v1_5(None, task.height, task.width);

    println!("clip");

    let (prompt_embeds, ng_prompt_embeds) = {
        let tokenizer = Tokenizer::from_file(task.tokenizer)?;
        let pad_id = tokenizer.token_to_id(&pad_token).ok_or("not found eos")?;

        let clip = build_clip_transformer(&config.clip, task.clip, &device, dtype)?;
        // プロンプトの意味情報
        let prompt_embeds = text_embeddings(
            &device,
            &clip,
            &tokenizer,
            &task.prompt,
            config.clip.max_position_embeddings,
            pad_id,
        )?;
        // ネガティブプロンプトの意味情報
        let ng_prompt_embeds = text_embeddings(
            &device,
            &clip,
            &tokenizer,
            &task.negative_prompt.unwrap_or("".into()),
            config.clip.max_position_embeddings,
            pad_id,
        )?;

        (prompt_embeds, ng_prompt_embeds)
    };

    println!("unet");
    // 3. 拡散モデルの処理
    let latents = {
        let unet = config.build_unet(task.unet, &device, 4, use_flash_attn, dtype)?;

        let mut scheduler = config.build_scheduler(task.steps)?;

        let timesteps = scheduler.timesteps().to_vec();
        // 初期画像(圧縮表現)の用意
        let init_latents = Tensor::randn(
            0f32,
            1f32,
            (1, 4, config.height / 8, config.width / 8),
            &device,
        )?;

        let mut latents = (init_latents * scheduler.init_noise_sigma())?.to_dtype(dtype)?;

        for (i, &timestep) in timesteps.iter().enumerate() {
            let input = scheduler.scale_model_input(latents.clone(), timestep)?;
            // プロンプトの意味情報を使ったノイズ算出
            let noize_pred_text = unet.forward(&input, timestep as f64, &prompt_embeds)?;
            // ネガティブプロンプトの意味情報を使ったノイズ算出
            let noize_pred_ng = unet.forward(&input, timestep as f64, &ng_prompt_embeds)?;
            // ガイダンスの処理
            let noize_pred =
                (((noize_pred_text - &noize_pred_ng)? * task.guidance_scale)? + noize_pred_ng)?;
            // ノイズの除去
            latents = scheduler.step(&noize_pred, timestep, &latents)?;

            println!("  done: {}/{}, timestep={timestep}", i + 1, task.steps);
        }

        latents
    };

    println!("vae");
    // 4. 画像生成
    let imgbuf = {
        let vae = config.build_vae(task.vae, &device, dtype)?;

        let img = {
            let img = vae.decode(&(latents / vae_scale)?)?;

            let img_p = ((img + 1.)? / 2.)?.to_device(&Device::Cpu)?; // [-1.0, 1.0] => [0.0, 1.0]
            (img_p.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)? // [0.0, 1.0] => [0, 255]
        };

        img.squeeze(0)?
            .permute((1, 2, 0))?
            .flatten_all()?
            .to_vec1::<u8>()?
    };
    // 画像ファイル保存
    image::save_buffer(
        task.output_file.unwrap_or("output.png".into()),
        &imgbuf,
        config.width as u32,
        config.height as u32,
        image::ColorType::Rgb8,
    )?;

    Ok(())
}

fn text_embeddings(
    device: &Device,
    clip: &ClipTextTransformer,
    tokenizer: &Tokenizer,
    prompt: &str,
    tokens_len: usize,
    pad_id: u32,
) -> Result<Tensor> {
    // 1. トークン化
    let mut tokens = tokenizer.encode(prompt, true)?.get_ids().to_vec();

    if tokens.len() > tokens_len {
        println!(
            "[WARN] too long prompt: current={}, max={}",
            tokens.len(),
            tokens_len
        );
        tokens.truncate(tokens_len);
    } else {
        while tokens.len() < tokens_len {
            tokens.push(pad_id);
        }
    }

    let tokens_t = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
    // 2. 意味情報の抽出
    clip.forward(&tokens_t).map_err(|e| e.into())
}
Cargo.tomlの依存定義
[dependencies]
candle-core = "0.8.2"
candle-nn = "0.8.2"
candle-transformers = "0.8.2"
image = "0.25.5"
serde = "1.0.217"
serde_json = "1.0.138"
tokenizers = "0.21.0"

動作確認

次の内容で実行します。

test1.json
{
    "tokenizer": "../model/tokenizer/tokenizer.json",
    "clip": "../model/text_encoder/model.safetensors",
    "unet": "../model/unet/diffusion_pytorch_model.safetensors",
    "vae": "../model/vae/diffusion_pytorch_model.safetensors",
    "guidance_scale": 7.5,
    "prompt": "A rusty robot holding a fire torch in its hand",
    "negative_prompt": "worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate",
    "steps": 20
}
実行例
$ cargo run --release test1.json
...省略
clip
unet
  done: 1/20, timestep=951
  done: 2/20, timestep=901
  done: 3/20, timestep=851
  done: 4/20, timestep=801
  done: 5/20, timestep=751
  done: 6/20, timestep=701
  done: 7/20, timestep=651
  done: 8/20, timestep=601
  done: 9/20, timestep=551
  done: 10/20, timestep=501
  done: 11/20, timestep=451
  done: 12/20, timestep=401
  done: 13/20, timestep=351
  done: 14/20, timestep=301
  done: 15/20, timestep=251
  done: 16/20, timestep=201
  done: 17/20, timestep=151
  done: 18/20, timestep=101
  done: 19/20, timestep=51
  done: 20/20, timestep=1
vae

このような画像が生成されました。

output.png

  1. stable-diffusion-v1-5 のモデルには tokenizer.json が含まれていないのでこちらを使います

  2. Stable Diffusion XL は Tokenizer と CLIP エンコーダが 2種類に増えますが、処理の流れは同じでした

  3. Stable Diffusion 1.5 ベースの他のモデルではスケジューラを変更する必要があるかもしれません。例えば dreamshaper-8-lcm は DDPMScheduler の方が合ってました

  4. これにより、メモリ8GBの古いノートPC(Linux)でコンテナ実行しても一応は動作しました

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?