前回 は Candle でローカルLLMを実行しましたが、今回は Stable Diffusion 1.5 を使った画像生成を行います。
モデルの取得
今回も Hugging Face から必要なファイルを手動でダウンロードします。
まずは、https://huggingface.co/openai/clip-vit-base-patch32/tree/main 1から以下のファイルをダウンロードします。
- tokenizer.json
次に、https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main から以下のファイルをダウンロードします。
- text_encoder/model.safetensors
- unet/diffusion_pytorch_model.safetensors
- vae/diffusion_pytorch_model.safetensors
実装
Stable Diffusion 1.5 で画像生成する流れは概ねこのようになります。2
手順 | 名称 | 処理内容 |
---|---|---|
1 | Tokenizer | Tokenizer でプロンプトやネガティブプロンプトをトークン化 |
2 | CLIP | CLIPエンコーダでトークン列から意味情報を抽出 |
3 | UNET | 拡散モデルで画像の圧縮表現を生成。現在の画像(圧縮表現)と意味情報から UNET でノイズを算出し、スケジューラでノイズを除去する処理を繰り返す |
4 | VAE | VAEデコーダで画像の圧縮表現から実際の画像を生成 |
プロンプトとネガティブプロンプトをまとめて処理する方法も考えられますが、ここでは分かり易さのために別々に処理します。
スケジューラによってノイズ除去の結果が変わりますが、Candle における Stable Diffusion 1.5 のデフォルトスケジューラは DDIM(Denoising Diffusion Implicit Model)のようです。3
画像の質を上げるためにガイダンスを使って意味情報を強調するようにしていますが、ガイダンスを適用しない場合は次のようにプロンプトで算出したノイズをそのままノイズ除去に使います。
// プロンプトの意味情報を使ったノイズ算出
let noize_pred_text = unet.forward(&input, timestep as f64, &prompt_embeds)?;
// ノイズの除去
latents = scheduler.step(&noize_pred_text, timestep, &latents)?;
VAE でデコードした結果は、(rgb x 高さ x 幅)
の並びで概ね -1.0 から 1.0 の値になるようなので、これを (高さ x 幅 x rgb)
の並びで 0 から 255 の値へ変換し、画像ファイルとして保存しています。
また、全体的にそれなりのメモリを消費するので、スコープを活用して処理毎にリソースを解放するようにしています。4
use candle_core::{DType, Device, Module, Tensor};
use candle_transformers::models::stable_diffusion::{
build_clip_transformer, clip::ClipTextTransformer, StableDiffusionConfig,
};
use serde::{Deserialize, Serialize};
use std::env;
use tokenizers::{Result, Tokenizer};
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Task {
tokenizer: String,
clip: String,
unet: String,
vae: String,
prompt: String,
negative_prompt: Option<String>,
steps: usize,
guidance_scale: f64,
output_file: Option<String>,
width: Option<usize>,
height: Option<usize>,
}
fn main() -> Result<()> {
let task: Task = {
let file = env::args().nth(1).ok_or("task file")?;
serde_json::from_str(&std::fs::read_to_string(file)?)?
};
let vae_scale = 0.18215;
let pad_token = "<|endoftext|>";
let use_flash_attn = false;
let dtype = DType::F32;
let device = Device::Cpu;
let config = StableDiffusionConfig::v1_5(None, task.height, task.width);
println!("clip");
let (prompt_embeds, ng_prompt_embeds) = {
let tokenizer = Tokenizer::from_file(task.tokenizer)?;
let pad_id = tokenizer.token_to_id(&pad_token).ok_or("not found eos")?;
let clip = build_clip_transformer(&config.clip, task.clip, &device, dtype)?;
// プロンプトの意味情報
let prompt_embeds = text_embeddings(
&device,
&clip,
&tokenizer,
&task.prompt,
config.clip.max_position_embeddings,
pad_id,
)?;
// ネガティブプロンプトの意味情報
let ng_prompt_embeds = text_embeddings(
&device,
&clip,
&tokenizer,
&task.negative_prompt.unwrap_or("".into()),
config.clip.max_position_embeddings,
pad_id,
)?;
(prompt_embeds, ng_prompt_embeds)
};
println!("unet");
// 3. 拡散モデルの処理
let latents = {
let unet = config.build_unet(task.unet, &device, 4, use_flash_attn, dtype)?;
let mut scheduler = config.build_scheduler(task.steps)?;
let timesteps = scheduler.timesteps().to_vec();
// 初期画像(圧縮表現)の用意
let init_latents = Tensor::randn(
0f32,
1f32,
(1, 4, config.height / 8, config.width / 8),
&device,
)?;
let mut latents = (init_latents * scheduler.init_noise_sigma())?.to_dtype(dtype)?;
for (i, ×tep) in timesteps.iter().enumerate() {
let input = scheduler.scale_model_input(latents.clone(), timestep)?;
// プロンプトの意味情報を使ったノイズ算出
let noize_pred_text = unet.forward(&input, timestep as f64, &prompt_embeds)?;
// ネガティブプロンプトの意味情報を使ったノイズ算出
let noize_pred_ng = unet.forward(&input, timestep as f64, &ng_prompt_embeds)?;
// ガイダンスの処理
let noize_pred =
(((noize_pred_text - &noize_pred_ng)? * task.guidance_scale)? + noize_pred_ng)?;
// ノイズの除去
latents = scheduler.step(&noize_pred, timestep, &latents)?;
println!(" done: {}/{}, timestep={timestep}", i + 1, task.steps);
}
latents
};
println!("vae");
// 4. 画像生成
let imgbuf = {
let vae = config.build_vae(task.vae, &device, dtype)?;
let img = {
let img = vae.decode(&(latents / vae_scale)?)?;
let img_p = ((img + 1.)? / 2.)?.to_device(&Device::Cpu)?; // [-1.0, 1.0] => [0.0, 1.0]
(img_p.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)? // [0.0, 1.0] => [0, 255]
};
img.squeeze(0)?
.permute((1, 2, 0))?
.flatten_all()?
.to_vec1::<u8>()?
};
// 画像ファイル保存
image::save_buffer(
task.output_file.unwrap_or("output.png".into()),
&imgbuf,
config.width as u32,
config.height as u32,
image::ColorType::Rgb8,
)?;
Ok(())
}
fn text_embeddings(
device: &Device,
clip: &ClipTextTransformer,
tokenizer: &Tokenizer,
prompt: &str,
tokens_len: usize,
pad_id: u32,
) -> Result<Tensor> {
// 1. トークン化
let mut tokens = tokenizer.encode(prompt, true)?.get_ids().to_vec();
if tokens.len() > tokens_len {
println!(
"[WARN] too long prompt: current={}, max={}",
tokens.len(),
tokens_len
);
tokens.truncate(tokens_len);
} else {
while tokens.len() < tokens_len {
tokens.push(pad_id);
}
}
let tokens_t = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
// 2. 意味情報の抽出
clip.forward(&tokens_t).map_err(|e| e.into())
}
[dependencies]
candle-core = "0.8.2"
candle-nn = "0.8.2"
candle-transformers = "0.8.2"
image = "0.25.5"
serde = "1.0.217"
serde_json = "1.0.138"
tokenizers = "0.21.0"
動作確認
次の内容で実行します。
{
"tokenizer": "../model/tokenizer/tokenizer.json",
"clip": "../model/text_encoder/model.safetensors",
"unet": "../model/unet/diffusion_pytorch_model.safetensors",
"vae": "../model/vae/diffusion_pytorch_model.safetensors",
"guidance_scale": 7.5,
"prompt": "A rusty robot holding a fire torch in its hand",
"negative_prompt": "worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate",
"steps": 20
}
$ cargo run --release test1.json
...省略
clip
unet
done: 1/20, timestep=951
done: 2/20, timestep=901
done: 3/20, timestep=851
done: 4/20, timestep=801
done: 5/20, timestep=751
done: 6/20, timestep=701
done: 7/20, timestep=651
done: 8/20, timestep=601
done: 9/20, timestep=551
done: 10/20, timestep=501
done: 11/20, timestep=451
done: 12/20, timestep=401
done: 13/20, timestep=351
done: 14/20, timestep=301
done: 15/20, timestep=251
done: 16/20, timestep=201
done: 17/20, timestep=151
done: 18/20, timestep=101
done: 19/20, timestep=51
done: 20/20, timestep=1
vae
このような画像が生成されました。