0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【2024年版】拡散モデルの仕組みを理解してDALL-E 2とStable Diffusionを実装してみた

Posted at

はじめに

2022年後半のAI画像生成ブームから約2年。当初は「すげー」で終わっていた人も多いと思いますが、実際に中身を理解して実装したことありますか?

この記事では、DALL-E 2とStable Diffusionのコア技術である拡散モデルを実装レベルで解説し、両者のアーキテクチャの違いを比較します。「なんとなく理解した気になってる」エンジニアから一歩抜け出したい方向けです。

拡散モデルとは:逆問題を解く天才的発想

従来のGANとの決定的な違い

# GANの場合(Generator)
def generate_image(noise):
    return generator(noise)  # ノイズから直接画像を生成

# 拡散モデルの場合
def generate_image(noise, prompt, steps=50):
    x = noise
    for t in reversed(range(steps)):
        predicted_noise = model(x, t, prompt)
        x = denoise_step(x, predicted_noise, t)  # 段階的にノイズ除去
    return x

拡散モデルの革新性は「逆問題として画像生成を定式化した」ことにあります。

Forward Process(順拡散過程)

数式で表すと:

q(x_t | x_{t-1}) = N(x_t; √(1-β_t) x_{t-1}, β_t I)

実装では:

def add_noise(x0, t, noise_schedule):
    """画像にノイズを段階的に追加"""
    alpha_t = noise_schedule.alpha_cumprod[t]
    noise = torch.randn_like(x0)
    return torch.sqrt(alpha_t) * x0 + torch.sqrt(1 - alpha_t) * noise, noise

Reverse Process(逆拡散過程)

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        # U-NetアーキテクチャでノイズPredictorを実装
        
    def forward(self, x_t, t, condition=None):
        # タイムステップとCondition(テキスト)を埋め込み
        t_emb = self.time_embedding(t)
        if condition is not None:
            c_emb = self.text_encoder(condition)
            t_emb = t_emb + c_emb
        
        # U-Netでノイズを予測
        predicted_noise = self.unet(x_t, t_emb)
        return predicted_noise

Latent Diffusion:計算量をハックする技術

Stable Diffusionの最大の革新は「Latent Space」での処理です。

VAE(Variational Autoencoder)による次元圧縮

class VAE(nn.Module):
    def encode(self, x):
        # 512x512 → 64x64に圧縮(8倍圧縮)
        return self.encoder(x)
    
    def decode(self, z):
        # 64x64 → 512x512に復元
        return self.decoder(z)

# 計算量比較
# ピクセル空間: 512×512×3 = 786,432次元
# Latent空間: 64×64×4 = 16,384次元(約1/48)

この圧縮により、メモリ使用量と計算時間が劇的に改善されます。

DALL-E 2 vs Stable Diffusion:アーキテクチャ比較

DALL-E 2のアーキテクチャ

class DALLE2Pipeline:
    def __init__(self):
        self.clip_text_encoder = CLIPTextEncoder()
        self.prior = Prior()  # テキスト→画像埋め込み変換
        self.decoder = Decoder()  # CLIP画像埋め込み→画像
        
    def generate(self, text):
        text_emb = self.clip_text_encoder(text)
        image_emb = self.prior(text_emb)  # 2段階目
        image = self.decoder(image_emb)   # 3段階目
        return image

特徴:

  • 2段階生成(Prior + Decoder)
  • CLIP埋め込み空間を中間表現として利用
  • 高品質だが計算コスト大

Stable Diffusionのアーキテクチャ

class StableDiffusionPipeline:
    def __init__(self):
        self.text_encoder = CLIPTextEncoder()
        self.unet = UNet()  # ノイズ予測器
        self.vae = VAE()    # エンコーダー・デコーダー
        
    def generate(self, text):
        text_emb = self.text_encoder(text)
        latent = torch.randn(1, 4, 64, 64)  # ランダムノイズ
        
        for t in self.scheduler.timesteps:
            noise_pred = self.unet(latent, t, text_emb)
            latent = self.scheduler.step(noise_pred, t, latent)
            
        image = self.vae.decode(latent)
        return image

特徴:

  • 1段階生成(Latent Diffusion)
  • VAEによる効率化
  • オープンソース

実装で学ぶ:ミニ拡散モデル

import torch
import torch.nn as nn
from torchvision import transforms

class SimpleDiffusion:
    def __init__(self, timesteps=1000):
        self.timesteps = timesteps
        # ノイズスケジュール(線形)
        self.betas = torch.linspace(0.0001, 0.02, timesteps)
        self.alphas = 1. - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)
        
    def add_noise(self, x0, t):
        """Forward process"""
        noise = torch.randn_like(x0)
        alpha_t = self.alpha_cumprod[t]
        noisy_image = torch.sqrt(alpha_t) * x0 + torch.sqrt(1 - alpha_t) * noise
        return noisy_image, noise
    
    def denoise_step(self, x_t, predicted_noise, t):
        """Reverse process(1ステップ)"""
        alpha_t = self.alpha_cumprod[t]
        alpha_t_prev = self.alpha_cumprod[t-1] if t > 0 else torch.tensor(1.0)
        
        # DDPM sampling
        x_prev = (x_t - torch.sqrt(1 - alpha_t) * predicted_noise) / torch.sqrt(alpha_t)
        if t > 0:
            noise = torch.randn_like(x_t)
            x_prev = x_prev + torch.sqrt(1 - alpha_t_prev) * noise
        
        return x_prev

# 訓練ループ
def train_step(model, x0, diffusion):
    t = torch.randint(0, diffusion.timesteps, (x0.shape[0],))
    x_t, noise = diffusion.add_noise(x0, t)
    
    predicted_noise = model(x_t, t)
    loss = nn.MSELoss()(predicted_noise, noise)
    return loss

パフォーマンス比較

指標 DALL-E 2 Stable Diffusion
生成時間 ~30秒 ~5秒
VRAM使用量 10GB+ 4GB
画質 高品質 高品質(調整可能)
カスタマイズ性 極高
商用利用 制限あり オープン

実践:Stable Diffusionのファインチューニング

from diffusers import StableDiffusionPipeline
import torch

# ベースモデル読み込み
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True
)

# LoRAでファインチューニング
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,  # rank
    lora_alpha=32,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    lora_dropout=0.1,
)

# U-NetにLoRA適用
pipe.unet = get_peft_model(pipe.unet, lora_config)

オープンソースの威力:コミュニティ拡張

Stable Diffusionがオープンソース化された結果:

主要な拡張機能

  • ControlNet: 構図制御
  • InPainting: 部分修正
  • LoRA: 軽量ファインチューニング
  • Textual Inversion: 概念学習

数字で見るインパクト

# Hugging Face上のStable Diffusion関連モデル数
$ curl -s "https://huggingface.co/api/models?filter=stable-diffusion" | jq '. | length'
15000+  # 2024年時点

# GitHub上のStars(AUTOMATIC1111/stable-diffusion-webui)
156,000+ stars

まとめ:エンジニアとして押さえるべきポイント

  1. 拡散モデルは逆問題として画像生成を定式化した革新的手法
  2. Latent Spaceでの処理が実用化の鍵
  3. DALL-E 2とStable Diffusionは異なる哲学(クローズド vs オープン)
  4. オープンソース化がイノベーションを爆発的に加速

次のステップ


参考リンク:

この記事が拡散モデルの理解に役立てば幸いです!質問があればコメントください 🚀

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?