1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

拡散モデル入門③、EDMをMNISTで実装してみた(Tensorflow)

Last updated at Posted at 2025-02-02

入門①:DDPMの理論とMNISTの実装
入門②:SDE/ODEの基礎理論(Tensorflow実装付き)
入門③:ここ
入門④:条件付きU-Net(MNIST実装付き)

EDM(Elucidating the Design Space of Diffusion-Based Generative Models)

DIAMONDの論文でEDMの実装を使っているので元の論文を調べてみました。

論文: https://arxiv.org/abs/2206.00364
Github: https://github.com/NVlabs/edm/tree/main

立ち位置?

拡散モデルの流れは知らないのでこのEDMがどういう立ち位置か簡単に調べてみました。
(そこまでちゃんと調べていないので参考程度に…)

手法 概要 論文
2019,2020 SBM データ分布のスコア関数(確率密度関数の勾配)を推定し、それを利用したデータ生成 1,2
2020 DDPM 生成過程を明示的に定義し、逆過程によりデータ生成 1
2021 Improved DDPM 分散の学習やスケジューリングの改善などDDPMを改良 1
2021 Score-Based SDE 拡散プロセスを確率微分方程式(SDE)として定式化した論文 1
2022 LDM 拡散プロセスを低次元の潜在空間(latent space)で実行することで、計算コストを大幅に削減した手法。(Stable Diffusionで使われてる?) 1
2022 EDM 拡散モデルの要素(拡散プロセスなど)の影響を分類・整理し、各手法における性能への影響を明確化した論文 1
2023 DiTs UNetの代わりにTransformersを使い、大幅に精度向上 1

以下の記事によるとNVIDIAのチームが発表した論文らしく、NeurIPS 2022 で優秀論文賞を受賞しているとの事。
https://developer.nvidia.com/blog/generative-ai-research-spotlight-demystifying-diffusion-based-models/

論文の概要

この論文が主張している内容は以下です。

  1. 拡散モデルの理論を実用的な観点でまとめ、各手法の関連や影響を調査した(これによるシステム設計への貢献)
  2. ルンゲ・クッタ法(Runge–Kutta method)適用によるサンプリングプロセスの大幅な改善
  3. トレーニングを改善するためのベストプラクティスの提示

拡散モデルの共通フレームワーク

以下の表の縦軸をまとめたことが最初の主張になるかと思います。

縦軸が設計時に選択できるコンポーネントで、横軸が主要な手法を並べ、具体的に値や方法をのせています。
これらのコンポーネントは独立して変更可能であるそうです。(ただ性能向上には特定のコンポーネントの組み合わせが重要らしい)

sss1.png

本記事では論文で提案された新しい手法"EDM"に焦点をあてます。
他の手法との比較や特徴などの議論は論文を参照してください。

またサンプリングで使う Time steps を生成するコードを以下に書いておきます。

def create_timesptes(N: int, sigma_min=0.002, sigma_max=80, rho=7):
    timesteps = []
    for i in range(N):
        t = (sigma_max ** (1 / rho) + i / (N - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
        timesteps.append(t)
    timesteps += [0]  # N=0
    return timesteps

以下みたいになります。

Figure_1.png

ノイズ除去モデルの学習

ノイズ除去モデル(Denoiser) $D(x,\sigma)$ はノイズ入り画像 $x$ を入力するとノイズを除去した後の画像を出力するモデルです。

sss2.png

図は論文より、Denoiserの理想的な出力結果で、分散 $\sigma$ が大きいほど元画像の平均値が出力されます。

Denoiserですが、ニューラルネットワークで直接学習することはあまり適切ではないそうです。
これは入力値 $x$ が元画像 $y$ とガウスノイズ $n$ の組み合わせのため、入力値が $\sigma$ の値によって大きく変動するためです。
なので一般的には代わりのネットワーク $F_{\theta}$ を用いて以下のように表現し、$F_{\theta}$ を学習します。

sss3.png

$c_{skip}(\sigma)$ はスキップ接続の調整、$c_{in}(\sigma)$と$c_{out}(\sigma)$は入出力値の調整、$c_{noise}(\sigma)$はノイズレベルの調整をします。
損失は以下です。

sss4.png

学習時のノイズレベル $\sigma$ がサンプリングされる確率を $p_{train}(\sigma)$ とした場合、対応する重みが $\lambda(\sigma)$ となります。

EDM版のコード例は以下です。

sigma_data = 0.5
f_model = ニューラルネットワーク

def calc_c(sigma):
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / np.sqrt(sigma**2 + sigma_data**2)
    c_in = 1 / np.sqrt(sigma_data**2 + sigma**2)
    c_noise = np.log(sigma) / 4
    return c_skip, c_out, c_in, c_noise

def model_forward(x, sigma):
    c_skip, c_out, c_in, c_noise = calc_c(sigma)
    dx = c_skip * x + c_out * f_model(c_in * x, c_noise)
    return dx

def model_compute_loss(y, p_mean = -1.2, p_std = 1.2):
    # Noise distibution
    r = np.random.randn(y.shape[0], 1, 1, 1)
    sigma = np.exp(p_mean + r * p_std)
    c_skip, c_out, c_in, c_noise = calc_c(sigma)

    # λ(σ)
    lambda_w = (sigma**2 + sigma_data**2) / ((sigma * sigma_data) ** 2)

    # ノイズ
    n = np.random.randn(*y.shape) * sigma

    # 損失の計算
    network_output = f_model(c_in * (y + n), c_noise)
    training_target = (y - c_skip * (y + n)) / c_out
    loss = lambda_w * (c_out**2) * ((network_output - training_target) ** 2)

    lossでモデルを更新

決定的サンプリング

論文に記載があったアルゴリズム1を実装してみました。

sss_alg1.png

コード例は以下です。(メインじゃない部分は削っています、動くコードは最後に)
ただ上手く生成できず、何か間違っているかもしれません…。
基本こちらは使わないので参考程度に

def generate_deterministic(size: int, N: int, sigma_min=0.002, sigma_max=80, rho=7):
    # Time steps
    timesteps = create_timesptes(N, sigma_min, sigma_max, rho)

    # 初期サンプル
    sigma_0 = timesteps[0]  # Schedule
    s_0 = 1  # Scaling
    var = (sigma_0**2) * (s_0**2)
    x = np.random.normal(0, var, size=(size,) + img_shape)

    for i in range(N - 1):
        # 各変数
        t = timesteps[i]
        t_next = timesteps[i + 1]
        sigma_t = timesteps[i]  # Schedule
        sigma_dt = timesteps[i] - timesteps[i+1]
        s_t = 1  # Scaling
        s_dt = 1 - 1

        denoise = model_forward(x / s_t, sigma_t)

        coeff = (sigma_dt / sigma_t + s_dt / s_t)
        d = coeff * x - (sigma_dt * s_t / sigma_t) * denoise
        x_next = x + (t_next - t) * d

        if t_next != 0:
            sigma_t_next = timesteps[i + 1]
            sigma_dt_next = timesteps[i+1] - timesteps[i + 2]
            s_t_next = 1
            s_dt_next = 1 - 1

            denoise = model_forward(x_next / s_t_next, sigma_t_next)

            coeff = (sigma_dt_next / sigma_t_next + s_dt_next / s_t_next)
            d_dash = coeff * x_next - (sigma_dt_next * s_t_next / sigma_t_next) *  denoise
            x_next = x + (t_next - t) * (d / 2 + d_dash / 2)
        
        x = x_next
    return x

確率的サンプリング

同じく論文に記載があったアルゴリズム2の実装です。
こちらはちゃんと生成できました。

sss_alg2.png

def generate_stochastic(
    size: int,
    N: int,
    sigma_min=0.002,
    sigma_max=80,
    rho=7,
    s_churn=0,
    s_min=0,
    s_max=float("inf"),
    s_noise=1,
):
    # Time steps
    timesteps = create_timesptes(N, sigma_min, sigma_max, rho)

    # 初期サンプル
    var = timesteps[0]
    x = np.random.normal(0, var, size=(size,) + img_shape)

    for i in range(N - 1):
        t = timesteps[i]
        t_next = timesteps[i + 1]

        gamma = min(s_churn / N, np.sqrt(2) - 1) if s_min <= t <= s_max else 0
        t_hat = t + gamma * t
        e = np.random.normal(0, s_noise, size=(size,) + img_shape)
        x_hat = x + np.sqrt(t_hat**2 - t**2) * e

        denoise = model_forward(x_hat, t_hat)

        d = (x_hat - denoise) / t_hat
        x_next = x_hat + (t_next - t_hat) * d
        if t_next != 0:
            denoise = model_forward(x_next, t_next)
            d_dash = (x_next - denoise) / t_next
            x_next = x_hat + (t_next - t_hat) * (d / 2 + d_dash / 2)
        x = x_next

    return x

結果

・学習過程

Epoch 1/5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:31<00:00, 14.66it/s, loss=0.258]
Epoch 2/5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:27<00:00, 16.97it/s, loss=0.188] 
Epoch 3/5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:27<00:00, 17.00it/s, loss=0.197] 
Epoch 4/5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:27<00:00, 17.11it/s, loss=0.176] 
Epoch 5/5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:27<00:00, 17.00it/s, loss=0.174] 

・生成結果

Figure_2.png

・生成過程

Figure_3.png

ノイズもなく綺麗な画像が生成されていますね。

全体コード

モデルは今まで通りUNetを使い、位置エンコーディングはサインコサインエンコードを使っています。

from pathlib import Path

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow import keras
from tqdm import tqdm

kl = keras.layers

img_shape = (28, 28, 1)


def create_dataset():
    (x_train, y_train), (_, _) = keras.datasets.mnist.load_data()
    # x_train = x_train[y_train == 1]
    x_train = (x_train[..., np.newaxis] / 255.0) * 2 - 1  # [0,255] -> [-1,1]
    x_train = x_train.astype(np.float32)
    return x_train


def decode_image(x):
    img = np.clip(x, -1.0, 1.0)
    # [-1,1] -> [0,255]
    img = (((img + 1) / 2) * 255).astype(np.uint8)
    return img


# ----------------------------------
# model
# ----------------------------------
class PositionalEmbedding(kl.Layer):
    def __init__(self, embedding_dim: int, max_position: int, **kwargs):
        """
        サインコサイン位置埋め込みを計算するKerasレイヤー
        :param embedding_dim: 埋め込み次元数
        :param max_position: 最大の位置数
        """
        super().__init__(**kwargs)

        # 位置エンコーディングの計算
        positions = np.arange(max_position)[:, np.newaxis]  # (max_position, 1)
        dims = np.arange(embedding_dim)[np.newaxis, :]  # (1, embedding_dim)

        # サイン・コサイン関数で位置埋め込みを計算
        angle_rates = 1 / np.power(10000, (2 * (dims // 2)) / embedding_dim)
        angle_rads = positions * angle_rates

        # 偶数インデックスはsin、奇数インデックスはcos
        positional_encoding = np.zeros_like(angle_rads)
        positional_encoding[:, 0::2] = np.sin(angle_rads[:, 0::2])
        positional_encoding[:, 1::2] = np.cos(angle_rads[:, 1::2])

        self.positional_encoding = tf.constant(positional_encoding, dtype=tf.float32)

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        indices = tf.cast(tf.reshape(inputs, [-1]), tf.int32)
        return tf.gather(self.positional_encoding, indices)

    @staticmethod
    def plot(embedding_dim: int = 500, max_position: int = 1000):  # for debug
        pos_emb = PositionalEmbedding(embedding_dim, max_position)
        step = tf.constant(np.arange(0, max_position)[..., np.newaxis])
        emb = pos_emb(step).numpy()
        plt.pcolormesh(emb.T, cmap="RdBu")
        plt.ylabel("dimension")
        plt.xlabel("time step")
        plt.colorbar()
        plt.show()


def build_unet(img_shape):
    x_t = kl.Input(shape=img_shape)
    sigma = keras.Input(shape=(1, 1, 1))

    # --- ノイズ情報の埋め込み
    # (batch, 1) -> (batch, dim) -> (batch, 1, 1, dim) -> (batch, 28, 28, dim)
    t_embedding = PositionalEmbedding(128, 10000)(sigma)
    t_embedding = kl.Dense(128, activation="gelu")(t_embedding)
    t_embedding = kl.Reshape((1, 1, 128))(t_embedding)
    t_embedding = kl.UpSampling2D(img_shape[:2])(t_embedding)

    # 埋め込み情報をチャンネルに追加
    x = kl.Concatenate()([x_t, t_embedding])

    # --- down sampling
    c1 = kl.Conv2D(64, (3, 3), padding="same", activation="relu")(x)
    c1 = kl.Conv2D(64, (3, 3), padding="same", activation="relu")(c1)
    p1 = kl.MaxPooling2D((2, 2))(c1)  # 28x28 -> 14x14
    c2 = kl.Conv2D(128, (3, 3), padding="same", activation="relu")(p1)
    c2 = kl.Conv2D(128, (3, 3), padding="same", activation="relu")(c2)
    p2 = kl.MaxPooling2D((2, 2))(c2)  # 14x14 -> 7x7

    # --- ボトム
    p2 = kl.Conv2D(256, (3, 3), activation="relu", padding="same")(p2)

    # --- up sampling
    u1 = kl.UpSampling2D((2, 2))(p2)  # 7x7 -> 14x14
    u1 = kl.Concatenate()([u1, c2])
    u1 = kl.Conv2D(128, (3, 3), activation="relu", padding="same")(u1)
    u1 = kl.Conv2D(128, (3, 3), activation="relu", padding="same")(u1)
    u2 = kl.UpSampling2D((2, 2))(u1)  # 14x14 -> 28x28
    u2 = kl.Concatenate()([u2, c1])
    u2 = kl.Conv2D(64, (3, 3), activation="relu", padding="same")(u2)
    u2 = kl.Conv2D(64, (3, 3), activation="relu", padding="same")(u2)

    y = kl.Conv2D(1, (1, 1), padding="same")(u2)

    model = keras.Model(inputs=[x_t, sigma], outputs=y, name="u_net")
    return model


class EDM:
    def __init__(self, img_shape, sigma_data=0.5):
        self.sigma_data = sigma_data
        self.model = build_unet(img_shape)
        self.optimizer = keras.optimizers.Adam(learning_rate=0.0005)

    def calc_c(self, sigma):
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / tf.sqrt(sigma**2 + self.sigma_data**2)
        c_in = 1 / tf.sqrt(self.sigma_data**2 + sigma**2)
        c_noise = tf.math.log(sigma) / 4
        return c_skip, c_out, c_in, c_noise

    def call(self, x, sigma, training=False):
        c_skip, c_out, c_in, c_noise = self.calc_c(sigma)
        fx = self.model([c_in * x, c_noise], training=training)
        dx = c_skip * x + c_out * fx
        return dx

    def train(self, y, p_mean=-1.2, p_std=1.2):
        # Noise distibution: ln(σ) ~ N(Pmena, Pstd^2)
        r = np.random.randn(y.shape[0], 1, 1, 1).astype(np.float32)
        sigma = np.exp(p_mean + r * p_std)
        c_skip, c_out, c_in, c_noise = self.calc_c(sigma)

        # λ(σ): (σ^2+σd^2) / (σ*σd)^2
        lambda_w = (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2)

        # ノイズ
        n = np.random.randn(*y.shape) * sigma

        # tensorflowに取り込める形に変換
        n = tf.convert_to_tensor(n, dtype=tf.float32)
        sigma = tf.convert_to_tensor(sigma.reshape((-1, 1, 1, 1)), dtype=tf.float32)

        with tf.GradientTape() as tape:
            output = self.model([c_in * (y + n), c_noise])
            target = (y - c_skip * (y + n)) / c_out
            loss = lambda_w * (c_out**2) * ((output - target) ** 2)
            loss = tf.reduce_mean(loss)
        grad = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grad, self.model.trainable_variables))

        return loss.numpy()

    def generate_deterministic(self, size: int, N: int, sigma_min=0.002, sigma_max=80, rho=7):
        # Time steps
        timesteps = self.create_timesptes(N, sigma_min, sigma_max, rho)

        samples_history = []

        # 初期サンプル
        sigma_0 = timesteps[0]  # Schedule
        s_0 = 1  # Scaling
        var = (sigma_0**2) * (s_0**2)
        x = np.random.normal(0, var, size=(size,) + img_shape).astype(np.float32)

        x = tf.convert_to_tensor(x, dtype=tf.float32)
        for i in tqdm(range(N - 1), desc="sampling loop"):
            t = timesteps[i]
            t_next = timesteps[i + 1]

            sigma_t = timesteps[i]  # Schedule: 論文の表より t
            sigma_dt = timesteps[i] - timesteps[i + 1]
            s_t = 1  # Scaling: 論文の表より t
            s_dt = 1 - 1

            # tf
            sigma_t = tf.convert_to_tensor([sigma_t] * size, dtype=tf.float32)
            sigma_t = sigma_t[..., tf.newaxis, tf.newaxis, tf.newaxis]
            denoise = self.call(x / s_t, sigma_t)

            coeff = sigma_dt / sigma_t + s_dt / s_t
            d = coeff * x - (sigma_dt * s_t / sigma_t) * denoise
            x_next = x + (t_next - t) * d

            if t_next != 0:
                sigma_t_next = timesteps[i + 1]
                sigma_dt_next = timesteps[i + 1] - timesteps[i + 2]
                s_t_next = 1
                s_dt_next = 1 - 1

                # tf
                sigma_t_next = tf.convert_to_tensor([sigma_t_next] * size, dtype=tf.float32)
                sigma_t_next = sigma_t_next[..., tf.newaxis, tf.newaxis, tf.newaxis]
                denoise = self.call(x_next / s_t_next, sigma_t_next)

                coeff = sigma_dt_next / sigma_t_next + s_dt_next / s_t_next
                d_dash = coeff * x_next - (sigma_dt_next * s_t_next / sigma_t_next) * denoise
                x_next = x + (t_next - t) * (d / 2 + d_dash / 2)

            samples_history.append(x_next)
            x = x_next

        return x, samples_history

    def generate_stochastic(
        self,
        size: int,
        N: int,
        sigma_min=0.002,
        sigma_max=80,
        rho=7,
        s_churn=0,
        s_min=0,
        s_max=float("inf"),
        s_noise=1,
    ):
        # Time steps
        timesteps = self.create_timesptes(N, sigma_min, sigma_max, rho)

        samples_history = []

        # 初期サンプル
        var = timesteps[0]
        x = np.random.normal(0, var, size=(size,) + img_shape).astype(np.float32)

        x = tf.convert_to_tensor(x, dtype=tf.float32)
        for i in tqdm(range(N - 1), desc="sampling loop"):
            t = timesteps[i]
            t_next = timesteps[i + 1]

            gamma = min(s_churn / N, np.sqrt(2) - 1) if s_min <= t <= s_max else 0
            t_hat = t + gamma * t
            e = np.random.normal(0, s_noise, size=(size,) + img_shape).astype(np.float32)
            x_hat = x + np.sqrt(t_hat**2 - t**2) * e

            # tf
            t_hat = tf.convert_to_tensor([t_hat] * size, dtype=tf.float32)
            t_hat = t_hat[..., tf.newaxis, tf.newaxis, tf.newaxis]
            denoise = self.call(x_hat, t_hat)

            d = (x_hat - denoise) / t_hat
            x_next = x_hat + (t_next - t_hat) * d
            if t_next != 0:
                # tf
                t_next = tf.convert_to_tensor([t_next] * size, dtype=tf.float32)
                t_next = t_next[..., tf.newaxis, tf.newaxis, tf.newaxis]
                denoise = self.call(x_next, t_next)

                d_dash = (x_next - denoise) / t_next
                x_next = x_hat + (t_next - t_hat) * (d / 2 + d_dash / 2)

            samples_history.append(x_next)
            x = x_next

        return x, samples_history

    @staticmethod
    def create_timesptes(N: int, sigma_min=0.002, sigma_max=80, rho=7):
        timesteps = []
        for i in range(N):
            t = (sigma_max ** (1 / rho) + i / (N - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
            timesteps.append(t)
        timesteps += [0]  # N=0
        return timesteps


# ----------------------------------
# main
# ----------------------------------
def sample():
    # --- Time steps
    timesptes = EDM.create_timesptes(20)
    plt.figure(figsize=(8, 5))
    plt.plot(range(len(timesptes)), timesptes, marker=".")
    plt.xlabel("i")
    plt.ylabel("Time step")
    plt.title("Visualization of Time steps")
    plt.grid()
    plt.show()

    # --- 位置エンコーディングの可視化
    PositionalEmbedding.plot()

    # --- Model
    model = build_unet(img_shape)
    model.summary()


def train(epochs: int, batch_size: int = 128):
    x_train = create_dataset()

    # モデル
    edm = EDM(img_shape)

    # 学習用にデータをバッチ化
    train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(len(x_train)).batch(batch_size)

    for epoch in range(epochs):
        with tqdm(train_dataset, desc=f"Epoch {epoch + 1}/{epochs}") as pbar:
            for y in pbar:
                loss = edm.train(y)
                pbar.set_postfix(loss=loss)  # 損失を進捗バーに表示
    edm.model.save_weights(Path(__file__).parent / "edm.weights.h5")


def generate(steps, gen_type="stochastic", **gen_kwargs):
    edm = EDM(img_shape)
    edm.model.load_weights(Path(__file__).parent / "edm.weights.h5")

    # 生成
    num_samples = 16
    if gen_type == "stochastic":
        samples, samples_history = edm.generate_stochastic(num_samples, steps, **gen_kwargs)
    else:
        samples, samples_history = edm.generate_deterministic(num_samples, steps, **gen_kwargs)

    samples = decode_image(samples)
    samples_history = decode_image(samples_history)

    # 結果
    plt.figure(figsize=(10, 10))
    for i in range(num_samples):
        plt.subplot(4, 4, i + 1)
        plt.imshow(samples[i, :, :, 0], cmap="gray")
        plt.axis("off")
    plt.show()

    # 作成過程
    index = 4
    img_list = np.array(samples_history)[:, index, :, :, 0]
    plt.figure(figsize=(20, 5))
    step_idxs = list(range(0, len(img_list), int(steps / 12)))  # 多いので一定間隔で抜き出し
    step_idxs += [len(img_list) - 1]  # 最後も追加
    for i, idx in enumerate(step_idxs):
        plt.subplot(1, len(step_idxs), i + 1)
        plt.imshow(img_list[idx], cmap="gray")
        plt.xticks([])
        plt.yticks([])
        plt.xlabel(f"step={idx}")
    plt.show()


if __name__ == "__main__":
    # sample()
    train(epochs=5)
    generate(steps=50, gen_type="stochastic")
    # generate(steps=50, gen_type="deterministic")
    # generate(steps=256, s_churn=40, s_min=0.05, s_max=50,s_noise=1.003)

最後に

ここにたどり着くまで長かった…。
深堀するとかなり難しい内容でした…。
特に日本語の資料が全然なかったのがきつかった。
誰かの参考になれば幸いです。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?