2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

拡散モデル入門①、DDPMの理論とMNISTの実装付き(Tensorflow)

Last updated at Posted at 2025-01-13

いろいろと調べたので備忘録です。

入門①:ここ
入門②:SDE/ODEの基礎理論(Tensorflow実装付き)
入門③:未定

拡散モデルとは

拡散モデル(Diffusion Model)は画像生成AIであるStableDiffusionで有名になった手法です。
StableDiffusionは従来よりもとても高品質な画像を生成する事で話題になりました。

有名になったのが画像生成なので画像生成モデルと思われがちですが、画像以外の生成にも使えます。
分類としては生成モデルの1種類で、他の有名な生成モデルとしては、VAE(AutoEncoder)、GAN、少し知名度が下がりますがFlow-based models等があります。

この記事ではDDPMベースの拡散モデルを解説し、実際にMNISTを学習させるまでを見ていきたいと思います。

参考
(論文) Denoising Diffusion Probabilistic Models
(huggingface) The Annotated Diffusion Model
 →ここがこの記事のメインソース
(amazon) 拡散モデル データ生成技術の数理
 →体系的に学びたい方におすすめの本
(Qiita) 拡散モデルの基礎と研究事例: Imagen
 →数式から見たい方の記事
(Watching the AI) DiffusionモデルをPyTorchで実装する① ~ Diffusionモデル実装編
 →Torchの実装と解説ですが、より論文に近い内容です。
(Lil'Log) What are Diffusion Models?

拡散モデルの技術内容

拡散モデルはデータにノイズを加える「順方向プロセス」と、逆にノイズを除去して元のデータを生成する「逆方向プロセス」から成り、ノイズの除去方法を学習する事でノイズから元のデータを生成する手法です。

イメージは以下です。

aa.drawio.png

順方向プロセス

拡散過程とも呼ばれるフェーズで、オリジナルデータから徐々にノイズを追加し、完全なノイズを生成するフェーズです。
これを漸化式で表すと以下です。

x_t = \sqrt{1-\beta_t} \cdot x_{t-1} + \sqrt{\beta_t} \cdot \epsilon_t

時間ステップ $t \in { 0, 1, \dots, T }$ において、オリジナルデータ $x_0$ からそれぞれのステップでノイズ $\epsilon \sim \mathcal{N}(0,I)$ を一定割合で加えたデータが $x_t$ となります。
$\mathcal{N}(0,I)$ は標準正規分布($I$が単位行列を表し1と同じ意味)を表し、ノイズは標準正規分布に従う乱数となります。
また、$\beta_t$がある時刻$t$で追加されるノイズの強度を表し、$0<\beta_1<\beta_2<...<\beta_T<1$ と時刻に従って強くなる定数です。

これを条件付確率分布で表すと以下です。

q(x_t \mid x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I)

$\mathcal{N}(\mu,\sigma^2)$は平均$\mu$、分散$\sigma^2$の正規分布で、元のデータを平均、分散をノイズとしてノイズの入ったデータ$x$を表現しています。
この式はある任意の時刻でも解析的に解くことができ、以下に変形できます。

\begin{align} 
q(x_t \mid x_{0}) &= \mathcal{N}(x_t; \sqrt{\bar{\alpha_t}} x_{0}, (1-\bar{\alpha_t}) I) \\
\alpha_t &:= 1 - \beta_t \\
\bar{\alpha_t} &:= \prod_{s=1}^{t} \alpha_s \\
\end{align} 

この式の証明は参考書籍、拡散モデル データ生成技術の数理の59p「任意時刻の拡散条件付確率の証明」などを見てください。

この任意の時刻のデータを得ることができる式ですが、学習に関係しており、深層学習のような大きなモデルに対して効率的に学習することができます。(各サンプルデータを独立して学習できるので誤差逆伝播と相性がいい)

最後にプログラムで書きやすいように変形しておきます。

q(x_0, t) = \sqrt{\bar{\alpha_t}} \cdot x_{0} + \sqrt{1-\bar{\alpha_t}} \cdot \epsilon_t

・$\beta$のルートに関して
これは正規分布の分散のスケーリングを調整するためについています。
正規分布の分散は$\sigma^2$でスケールパラメータの2乗に比例します。
正規分布の値を適切にスケールするために係数に平方根をつけて調整しています。
またわざわざスケールする理由ですが、拡散モデルでは一定間隔でノイズを加えることが重要なのでスケールが必要となります。

逆方向プロセス

逆拡散過程とも呼ばれ、拡散モデルのメイン部分です。
ここではノイズ化されたデータ$x_T$から元のデータ$x_0$を復元します。

やりたい事は順方向の逆で、$x_t$から$x_{t-1}$を予測すること $p(x_{t-1}|x_t)$ です。
ただこれを計算するにはデータが取りうる全ての分布を知る必要があり難しいです。
ですのでパラメータ$\theta$を用いて $p_{\theta}(x_{t-1}|x_t)$ を近似します。(いつもの深層学習パターンですね)

逆拡散過程を表す条件付確率分布を正規分布と仮定すると以下です。

p_{\theta}(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_{\theta}(x_t, t), \Sigma_{\theta}(x_t, t))

漸化式で書くと以下です。

\begin{align} 
x_{t-1} &= \mu_{\theta}(x_t, t) + \Sigma_{\theta}(x_t, t) z_t \\
z_t &\sim \mathcal{N}(0,I)\\
\end{align} 

正規分布で求める必要のあるパラメータは平均$\mu_{\theta}$と分散$\Sigma_{\theta}$です。
ただ分散に関しては単純化のために固定や定数として扱われることが多いそうです。
DDPMの論文では、$\Sigma_{\theta}(x_t, t)=\sigma^2=\beta_t$ と $\sigma^2=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\bar{\alpha}_t}}\beta_t$ の両方を試したところ結果は変わらなったそうです。

・分散に関して
参考のHuggingFaceにて以下の言及がありました。

This was then later improved in the Improved diffusion models paper, where a neural network also learns the variance of this backwards process, besides the mean.

DDPMの論文では分散は固定でしたが、のちの論文で分散も学習すると精度が向上した、みたいな内容があるようです。

損失関数

学習ですが、順方向プロセス$q$と逆方向プロセス$p_{\theta}$の関係はVAE(変分オートエンコーダー)として見ることができます。
ですのでVAEと同じく各ステップの損失$L_t$はELBOを最大化して求めます。(ELBOについては以前書いたの記事をどうぞ)
また各ステップの損失$L_t$は足しても同じなので$L=L_0+L_1+...+L_T$として学習できます。
最終的には$q$と$p_{\theta}$の正規分布のKLダイバージェンスの最小化(=平均二乗誤差の最小化)として扱うことができます。

ノイズ予測モデルと平均の再パラメータ化

ここがちょっとトリッキーです。
平均の学習ですが、そのまま学習するのではなく、新しくノイズ予測モデル $\epsilon_{\theta}(x_t,t)$ を導入し、これを用いて間接的に学習させます。(平均の再パラメータ化)
ノイズ予測モデル$\epsilon_{\theta}$は$x_t$に含まれるノイズ成分$\epsilon$を直接予測するモデルです。

これを用いると平均は以下になります。(途中計算などは省略しています)

\mu_{\theta}(x_t,t) = \frac{1}{\sqrt{\alpha_t}} \bigg(x_t - \frac{\beta_t}{\sqrt{\bar{\alpha_t}}} \epsilon_{\theta}(x_t,t) \bigg) \\

最終的な損失関数$L_{\theta}$は以下です。

\begin{align} 
L_{\theta} &= \mathbb{E}_q [\parallel \epsilon - \epsilon_{\theta}(\sqrt{\bar{\alpha_t}}x_0 + \sqrt{1-\bar{\alpha_t}} \epsilon,t) \parallel ^2 ] \\
&= \mathbb{E}_q [\parallel \epsilon - \epsilon_{\theta}(x_t,t) \parallel ^2] \\
\end{align} 

ある時刻$t$で発生した純粋なノイズ$\epsilon$とニューラルネットワークから予測されたノイズ$\epsilon_{\theta}(x_t,t)$との差で学習されるシンプルな形になりました。
($x_t$はある時刻$t$でのノイズ入りデータ)

・ノイズ予測モデル?
huggingfaceから解釈した内容ですが、拡散モデルの特性の一つにデータの直接予測だけではなく、ネットワークを介してノイズも予測ができるというものがあります。
これを利用してノイズ同士を直接学習すると、計算がすごく簡略化できると言った内容でした。

推論(サンプリング)

最後にデータの生成手順です。
ノイズ予測モデルが入るので順方向プロセスとは異なります。

  1. 初期化
    ノイズ $x_T \sim \mathcal{N}(0,I)$を生成します。(完全なランダムノイズ)
  2. 逆拡散ループ
    時間 $t=T,T-1,...,1$ に対して以下を実行
    • ノイズ $z \sim \mathcal{N}(0,I)$ をサンプリング(t>1のみ、t=1の場合はz=0)
    • 以下の式でデノイズ
      $x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \Big( x_t- \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha_t}}} \epsilon_{\theta}(x_t,t) \Big) + \Sigma_t z$
  3. 出力
    $x_0$が生成されたデータです。

実装時に考える事

ここからはデータを画像に限定して話します。

ノイズ予測モデル(ニューラルネットワーク)

具体的なニューラルネットワークの話です。
ニューラルネットワークとしては、画像$x_t$とタイムステップ$t$を入力として予想されるノイズを出力するモデルとなります。
(ここでノイズのサイズは入力した画像と同じ大きさ)

入力と出力が画像になるのでモデルの選択肢としては、基本はAutoEncoderベースのアーキテクチャになります。
一般的にはDDPMの論文で採用されていたU-Netが多いようです。
ただDDPMではそのまま採用しているわけではなく、畳み込み層を重み標準化バージョンに置き換えたResNetやAttention層を追加していたりするそうです。(詳細は参考文献のhuggingfaceを見てください)

また、他にもTransformerベースにして精度を上げた Diffusion Transformer(DiT)などU-Net以外のモデルや組み合わせもあるようです。

この記事では単純なU-Netで実装しています。

タイムステップ情報の埋め込み(Position embeddings)

ニューラルネットワークのもう一つの入力であるタイムステップ情報をどう埋め込むかという話になります。

DDPMではTransformerで使われている位置エンコーディング(Positional Encoding)をそのまま使っているようです。
Transformerで使われているサイン・コサインエンコード(Sinusoidal Encoding)は以下の情報を埋め込む方法です。

$PE(t,2i) = \sin(t/10000^{2i/d})$
$PE(t,2i+1) = \cos(t/10000^{2i/d})$

※$i$が次元のインデックス、$d$が埋め込みの次元数

ノイズスケジュールの定義

順方向プロセスで実行するノイズのスケーリング方法($\beta$の決定方法)ですが必ずしも線形である必要はないようです。
線形以外にもコサインでスケジュールした方が性能が良くなったとの論文もあるそうです。
(詳細は参考のHuggingfaceを見てください)

この記事ではHuggingfaceの実装にならって $\beta_1=0.0001$ から $\beta_T=0.02$ の線形増加とします。

MNISTによる実装(Tensorflow)

ポイントのみを書いて、最後にコード全体を載せています。

Version

  • python 3.12.7
  • tensorflow 2.18.0

データの準備

画像データは正規分布ノイズに合わせるために[0,255]→[-1,1]に正規化します。

img_size = 28
img_shape = (img_size, img_size, 1)

# 使うデータはx_trainのみです
(x_train, _), (_, _) = keras.datasets.mnist.load_data()
x_train = (x_train / 255.0) * 2 - 1  # [0,255] -> [-1,1]
x_train = x_train.astype(np.float32)

各定数の定義

事前に計算できる値($\beta$や$\alpha$等)を実装します。

timesteps = 500  # T

# define beta schedule
betas = np.linspace(0.0001, 0.02, timesteps)

# define alphas 
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)  # 累積積

ノイズ関数の定義

順方向プロセスは以下でした。

q(x_0, t) = \sqrt{\bar{\alpha_t}} \cdot x_{0} + \sqrt{1-\bar{\alpha_t}} \cdot \epsilon_t

関数で実装します。
バッチ処理を想定しているので引数のデータのshapeの最初はbatchサイズとなります。
また、学習で使うので、生成に使ったノイズもreturnします。

def add_noise(x_start, t):
    assert x_start.shape[0] == t.shape[0]  # バッチサイズ

    # ^at を作成
    a_cumprod = alphas_cumprod[t]
    a_cumprod = a_cumprod[..., np.newaxis, np.newaxis, np.newaxis]

    # q(x0, t) = sqrt(^at) * x0 + sqrt(1-^at)* e
    noise = np.random.normal(0, 1, size=x_start.shape).astype(np.float32)
    noised_x = np.sqrt(a_cumprod) * x_start + np.sqrt(1 - a_cumprod) * noise
    return noised_x, noise

# --- ノイズ画像をサンプルで描画
def plot_sample(img, noised_img):
    # 正規分布が無限の値をとるので[-1,1]にclip
    noised_img = np.clip(noised_img, -1.0, 1.0)
    # [-1,1] -> [0,255]
    img = (((img + 1) / 2) * 255).astype(np.uint8)
    noised_img = (((noised_img + 1) / 2) * 255).astype(np.uint8)

    fig, axes = plt.subplots(2, len(img), figsize=(12, 5))
    for i in range(len(img)):
        axes[0, i].imshow(img[i], cmap="gray")
        axes[0, i].axis("off")
        axes[1, i].imshow(noised_img[i], cmap="gray")
        axes[1, i].axis("off")
    plt.tight_layout()
    plt.show()

time_list = np.array([0, 50, 100, 300])
img = x_train[: len(time_list)]
noised_img, noise = add_noise(img, time_list)
plot_sample(img, noised_img)

Figure_1.png

上が元画像で下がノイズを追加した後の画像になります。

U-Net

U-NetはCNNアーキテクチャの1つで、シンプルながら効果的な結果を出すモデルとして有名です。

unet_architecture.jpg

ここの実装はメインではないので実装コードのみとなり、またそのままだと時間がかかるので少し小さくしています。
またタイムステップ情報レイヤー(PositionalEncoding)については後述してあります。

def build_unet(img_shape: tuple[int, int, int], timesteps: int) -> keras.Model:
    # 入力はノイズ入り画像とタイムステップ
    noised_img = kl.Input(shape=img_shape)
    t = keras.Input(shape=(1,))

    # --- タイムステップ情報を追加
    # (1) -> (dim) -> (1, 1, dim) -> (28, 28, dim)
    t_embedding = PositionalEncoding(timesteps, 128)(t)
    t_embedding = kl.Dense(128, activation="gelu")(t_embedding)
    t_embedding = kl.Reshape((1, 1, 128))(t_embedding)
    t_embedding = kl.UpSampling2D(img_shape[:2])(t_embedding)

    # タイムステップ情報をチャンネルに追加
    x = kl.Concatenate()([noised_img, t_embedding])

    # --- down sampling
    c1 = kl.Conv2D(64, (3, 3), padding="same", activation="relu")(x)
    c1 = kl.Conv2D(64, (3, 3), padding="same", activation="relu")(c1)
    p1 = kl.MaxPooling2D((2, 2))(c1)  # 28x28 -> 14x14
    c2 = kl.Conv2D(128, (3, 3), padding="same", activation="relu")(p1)
    c2 = kl.Conv2D(128, (3, 3), padding="same", activation="relu")(c2)
    p2 = kl.MaxPooling2D((2, 2))(c2)  # 14x14 -> 7x7

    # --- bottom
    p2 = kl.Conv2D(256, (3, 3), activation="relu", padding="same")(p2)

    # --- up sampling
    u1 = kl.UpSampling2D((2, 2))(p2)  # 7x7 -> 14x14
    u1 = kl.Concatenate()([u1, c2])
    u1 = kl.Conv2D(128, (3, 3), activation="relu", padding="same")(u1)
    u1 = kl.Conv2D(128, (3, 3), activation="relu", padding="same")(u1)
    u2 = kl.UpSampling2D((2, 2))(u1)  # 14x14 -> 28x28
    u2 = kl.Concatenate()([u2, c1])
    u2 = kl.Conv2D(64, (3, 3), activation="relu", padding="same")(u2)
    u2 = kl.Conv2D(64, (3, 3), activation="relu", padding="same")(u2)

    # 出力は正規分布のノイズ画像
    y = kl.Conv2D(1, (1, 1), padding="same")(u2)

    model = keras.Model(inputs=[noised_img, t], outputs=y, name="u_net")
    return model

model = build_unet(img_shape, timesteps)
model.summary()
summaryの結果
Model: "u_net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 input_2 (InputLayer)           [(None, 1)]          0           []

 positional_encoding (Positiona  (None, 128)         0           ['input_2[0][0]']
 lEncoding)

 dense (Dense)                  (None, 128)          16512       ['positional_encoding[0][0]']

 reshape (Reshape)              (None, 1, 1, 128)    0           ['dense[0][0]']

 input_1 (InputLayer)           [(None, 28, 28, 1)]  0           []

 up_sampling2d (UpSampling2D)   (None, 28, 28, 128)  0           ['reshape[0][0]']

 concatenate (Concatenate)      (None, 28, 28, 129)  0           ['input_1[0][0]',
                                                                  'up_sampling2d[0][0]']

 conv2d (Conv2D)                (None, 28, 28, 64)   74368       ['concatenate[0][0]']

 conv2d_1 (Conv2D)              (None, 28, 28, 64)   36928       ['conv2d[0][0]']

 max_pooling2d (MaxPooling2D)   (None, 14, 14, 64)   0           ['conv2d_1[0][0]']

 conv2d_2 (Conv2D)              (None, 14, 14, 128)  73856       ['max_pooling2d[0][0]']

 conv2d_3 (Conv2D)              (None, 14, 14, 128)  147584      ['conv2d_2[0][0]']

 max_pooling2d_1 (MaxPooling2D)  (None, 7, 7, 128)   0           ['conv2d_3[0][0]']

 conv2d_4 (Conv2D)              (None, 7, 7, 256)    295168      ['max_pooling2d_1[0][0]']

 up_sampling2d_1 (UpSampling2D)  (None, 14, 14, 256)  0          ['conv2d_4[0][0]']

 concatenate_1 (Concatenate)    (None, 14, 14, 384)  0           ['up_sampling2d_1[0][0]',
                                                                  'conv2d_3[0][0]']

 conv2d_5 (Conv2D)              (None, 14, 14, 128)  442496      ['concatenate_1[0][0]']

 conv2d_6 (Conv2D)              (None, 14, 14, 128)  147584      ['conv2d_5[0][0]']

 up_sampling2d_2 (UpSampling2D)  (None, 28, 28, 128)  0          ['conv2d_6[0][0]']

 concatenate_2 (Concatenate)    (None, 28, 28, 192)  0           ['up_sampling2d_2[0][0]',
                                                                  'conv2d_1[0][0]']

 conv2d_7 (Conv2D)              (None, 28, 28, 64)   110656      ['concatenate_2[0][0]']

 conv2d_8 (Conv2D)              (None, 28, 28, 64)   36928       ['conv2d_7[0][0]']

 conv2d_9 (Conv2D)              (None, 28, 28, 1)    65          ['conv2d_8[0][0]']

==================================================================================================
Total params: 1,382,145
Trainable params: 1,382,145
Non-trainable params: 0

タイムステップ情報レイヤー

class PositionalEncoding(kl.Layer):
    def __init__(self, max_position: int, embedding_dim: int, **kwargs):
        """
        サインコサイン位置埋め込みを計算するKerasレイヤー
        :param max_position: 最大の位置数
        :param embedding_dim: 埋め込み次元数
        """
        super().__init__(**kwargs)

        # 位置エンコーディングの計算
        positions = np.arange(max_position)[:, np.newaxis]  # (max_position, 1)
        dims = np.arange(embedding_dim)[np.newaxis, :]  # (1, embedding_dim)

        # サイン・コサイン関数で位置埋め込みを計算
        angle_rates = 1 / np.power(10000, (2 * (dims // 2)) / embedding_dim)
        angle_rads = positions * angle_rates

        # 偶数インデックスはsin、奇数インデックスはcos
        positional_encoding = np.zeros_like(angle_rads)
        positional_encoding[:, 0::2] = np.sin(angle_rads[:, 0::2])
        positional_encoding[:, 1::2] = np.cos(angle_rads[:, 1::2])

        self.positional_encoding = tf.constant(positional_encoding, dtype=tf.float32)

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        indices = tf.cast(tf.reshape(inputs, [-1]), tf.int32)
        return tf.gather(self.positional_encoding, indices)

    @staticmethod
    def plot(max_position: int = 1000, embedding_dim: int = 500):  # for debug
        pos_emb = PositionalEncoding(max_position, embedding_dim)
        step = tf.constant(np.arange(0, max_position)[..., np.newaxis])
        emb = pos_emb(step).numpy()
        plt.pcolormesh(emb.T, cmap="RdBu")
        plt.ylabel("dimension")
        plt.xlabel("time step")
        plt.colorbar()
        plt.show()

# 可視化
PositionalEncoding.plot()

Figure_2.png

今回sin,cosを交互に追加する形で実装しましたがsinとcosを積み重ねる実装でも問題ないようです。
https://stackoverflow.com/questions/75995195/is-tensorflow-positional-encoding-wrong

タイムステップ情報の埋め込み

埋め込む方法ですが、あまり明示されている内容が見つからず…、以下の種類がありそうです。

  • 追加方法
    • 画像情報に直接加算(huggingface実装)
    • チャンネル情報として追加
  • 追加場所
    • 一番最初に追加
    • 中間層含めすべてに追加(huggingface実装)
    • U-Netのボトム層に追加(ChatGPT実装)

この記事では一番最初にチャンネル情報として追加しています。

学習

学習データを元にepoch毎にノイズ画像を生成しdatasetとします。

epochs = 10
batch_size = 128
model.compile(optimizer="adam", loss="mse")
for _ in range(epochs):
    # タイムステップtをランダムに生成させる
    t = np.random.randint(0, timesteps, len(x_train))
    # tを元にノイズ入り画像を作成
    noisy_images, noise = add_noise(x_train, t)

    # 学習: 入力はノイズ画像とタイムステップ、教師データはノイズでMSEで学習させる
    model.fit([noisy_images, t], noise, batch_size=batch_size, epochs=1)
出力
469/469 [==============================] - 27s 51ms/step - loss: 0.0634
469/469 [==============================] - 22s 48ms/step - loss: 0.0223
469/469 [==============================] - 23s 48ms/step - loss: 0.0204
469/469 [==============================] - 23s 48ms/step - loss: 0.0195
469/469 [==============================] - 22s 48ms/step - loss: 0.0189
469/469 [==============================] - 23s 48ms/step - loss: 0.0183
469/469 [==============================] - 22s 48ms/step - loss: 0.0181
469/469 [==============================] - 23s 48ms/step - loss: 0.0177
469/469 [==============================] - 22s 48ms/step - loss: 0.0175
469/469 [==============================] - 23s 48ms/step - loss: 0.0173

推論

ランダムなノイズ $x_T$ から以下で画像を生成していきます。

\begin{align} 
x_{t-1} &= \frac{1}{\sqrt{\alpha_t}} \Big( x_t- \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha_t}}} \epsilon_{\theta}(x_t,t) \Big) + \Sigma_t z \\
\Sigma_t &= \sigma^2=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t\\
\end{align} 

ただ生成後の画像ですが、正規分布のノイズがはいるので[-1,1]の範囲を超えます。
この処理をどうするかが検索しても分からず…、とりあえずクリップする形で実装しています。


def sample_images(img_shape, model: keras.Model, timesteps: int, num_samples: int) -> np.ndarray:
    samples_history = []

    samples = keras.random.normal((num_samples,) + img_shape, dtype=np.float32)
    for t in reversed(range(timesteps)):
        batch_t = tf.constant([[t] for _ in range(num_samples)])
        noise_pred = model([samples, batch_t])

        # x_t-1 = 1/sqrt(at) * (xt - (1-at)/sqrt(1-^at) e(xt,t)) [t=0以外] + v*z
        samples = (1 / np.sqrt(alphas[t])) * (samples - noise_pred * betas[t] / np.sqrt(1 - alphas_cumprod[t]))
        if t > 0:
            noise = keras.random.normal(samples.shape, dtype=np.float32)
            v = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]
            samples += np.sqrt(v) * noise

        samples_history.append(samples.numpy())

    samples = tf.clip_by_value(samples, -1.0, 1.0)
    samples = samples.numpy()

    # [-1,1] -> [0,255]
    samples = (((samples + 1) / 2) * 255).astype(np.uint8)
    return samples, samples_history

num_samples = 16
generated_images, images_history = sample_images(img_shape, model, timesteps, num_samples)

# 結果
plt.figure(figsize=(10, 10))
for i in range(num_samples):
    plt.subplot(4, 4, i + 1)
    plt.imshow(generated_images[i, :, :, 0], cmap="gray")
    plt.axis("off")
plt.show()

# 作成過程
index = 4
img_list = np.array(images_history)[:, index, :, :, 0]
plt.figure(figsize=(20, 5))
step_idxs = list(range(0, len(img_list), int(timesteps / 12)))  # 多いので一定間隔で抜き出し
step_idxs += [len(img_list) - 1]  # 最後も追加
for i, idx in enumerate(step_idxs):
    plt.subplot(1, len(step_idxs), i + 1)
    plt.imshow(img_list[idx], cmap="gray")
    plt.xticks([])
    plt.yticks([])
    plt.xlabel(f"step={idx}")
plt.show()

・結果
Figure_3-2.png

・生成過程
Figure_4-2.png

ノイズもなく綺麗に生成されていますね。
読めない画像は元々のデータセットも結構読めない文字が含まれていたりするのでそれに引っ張られていたりするんでしょうか。

コード全体

コード全体
from pathlib import Path

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow import keras
from tqdm import tqdm

kl = keras.layers


# --- 値の定義
img_size = 28
img_shape = (img_size, img_size, 1)

timesteps = 500  # T

# define beta schedule
betas = np.linspace(0.0001, 0.02, timesteps, dtype=np.float32)

# define alphas
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)  # 累積積


def create_dataset():
    (x_train, _), (_, _) = keras.datasets.mnist.load_data()
    x_train = (x_train[..., np.newaxis] / 255.0) * 2 - 1  # [0,255] -> [-1,1]
    x_train = x_train.astype(np.float32)
    return x_train


def add_noise(x_start, t):
    assert x_start.shape[0] == t.shape[0]  # バッチサイズ

    # ^at を作成
    a_cumprod = alphas_cumprod[t]
    a_cumprod = a_cumprod[..., np.newaxis, np.newaxis, np.newaxis]

    # q(x0, t) = sqrt(^at) * x0 + sqrt(1-^at)* e
    noise = np.random.normal(0, 1, size=x_start.shape).astype(np.float32)
    noised_x = np.sqrt(a_cumprod) * x_start + np.sqrt(1 - a_cumprod) * noise
    return noised_x, noise


class PositionalEncoding(kl.Layer):
    def __init__(self, max_position: int, embedding_dim: int, **kwargs):
        """
        サインコサイン位置埋め込みを計算するKerasレイヤー
        :param max_position: 最大の位置数
        :param embedding_dim: 埋め込み次元数
        """
        super().__init__(**kwargs)

        # 位置エンコーディングの計算
        positions = np.arange(max_position)[:, np.newaxis]  # (max_position, 1)
        dims = np.arange(embedding_dim)[np.newaxis, :]  # (1, embedding_dim)

        # サイン・コサイン関数で位置埋め込みを計算
        angle_rates = 1 / np.power(10000, (2 * (dims // 2)) / embedding_dim)
        angle_rads = positions * angle_rates

        # 偶数インデックスはsin、奇数インデックスはcos
        positional_encoding = np.zeros_like(angle_rads)
        positional_encoding[:, 0::2] = np.sin(angle_rads[:, 0::2])
        positional_encoding[:, 1::2] = np.cos(angle_rads[:, 1::2])

        self.positional_encoding = tf.constant(positional_encoding, dtype=tf.float32)

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        indices = tf.cast(tf.reshape(inputs, [-1]), tf.int32)
        return tf.gather(self.positional_encoding, indices)

    @staticmethod
    def plot(max_position: int = 1000, embedding_dim: int = 500):  # for debug
        pos_emb = PositionalEncoding(max_position, embedding_dim)
        step = tf.constant(np.arange(0, max_position)[..., np.newaxis])
        emb = pos_emb(step).numpy()
        plt.pcolormesh(emb.T, cmap="RdBu")
        plt.ylabel("dimension")
        plt.xlabel("time step")
        plt.colorbar()
        plt.show()


def build_unet(img_shape: tuple[int, int, int], timesteps: int) -> keras.Model:
    # 入力はノイズ入り画像とタイムステップ
    noised_img = kl.Input(shape=img_shape)
    t = keras.Input(shape=(1,))

    # --- タイムステップ情報を追加
    # (1) -> (dim) -> (1, 1, dim) -> (28, 28, dim)
    t_embedding = PositionalEncoding(timesteps, 128)(t)
    t_embedding = kl.Dense(128, activation="gelu")(t_embedding)
    t_embedding = kl.Reshape((1, 1, 128))(t_embedding)
    t_embedding = kl.UpSampling2D(img_shape[:2])(t_embedding)

    # タイムステップ情報をチャンネルに追加
    x = kl.Concatenate()([noised_img, t_embedding])

    # --- down sampling
    c1 = kl.Conv2D(64, (3, 3), padding="same", activation="relu")(x)
    c1 = kl.Conv2D(64, (3, 3), padding="same", activation="relu")(c1)
    p1 = kl.MaxPooling2D((2, 2))(c1)  # 28x28 -> 14x14
    c2 = kl.Conv2D(128, (3, 3), padding="same", activation="relu")(p1)
    c2 = kl.Conv2D(128, (3, 3), padding="same", activation="relu")(c2)
    p2 = kl.MaxPooling2D((2, 2))(c2)  # 14x14 -> 7x7

    # --- bottom
    p2 = kl.Conv2D(256, (3, 3), activation="relu", padding="same")(p2)

    # --- up sampling
    u1 = kl.UpSampling2D((2, 2))(p2)  # 7x7 -> 14x14
    u1 = kl.Concatenate()([u1, c2])
    u1 = kl.Conv2D(128, (3, 3), activation="relu", padding="same")(u1)
    u1 = kl.Conv2D(128, (3, 3), activation="relu", padding="same")(u1)
    u2 = kl.UpSampling2D((2, 2))(u1)  # 14x14 -> 28x28
    u2 = kl.Concatenate()([u2, c1])
    u2 = kl.Conv2D(64, (3, 3), activation="relu", padding="same")(u2)
    u2 = kl.Conv2D(64, (3, 3), activation="relu", padding="same")(u2)

    # 出力は正規分布のノイズ画像
    y = kl.Conv2D(1, (1, 1), padding="same")(u2)

    model = keras.Model(inputs=[noised_img, t], outputs=y, name="u_net")
    return model


def sample():
    x_train = create_dataset()

    # --- ノイズ画像をサンプルで描画
    def plot_sample(img, noised_img):
        # 正規分布が無限の値をとるので[-1,1]にclip
        noised_img = np.clip(noised_img, -1.0, 1.0)
        # [-1,1] -> [0,255]
        img = (((img + 1) / 2) * 255).astype(np.uint8)
        noised_img = (((noised_img + 1) / 2) * 255).astype(np.uint8)

        fig, axes = plt.subplots(2, len(img), figsize=(12, 5))
        for i in range(len(img)):
            axes[0, i].imshow(img[i], cmap="gray")
            axes[0, i].axis("off")
            axes[1, i].imshow(noised_img[i], cmap="gray")
            axes[1, i].axis("off")
        plt.tight_layout()
        plt.show()

    time_list = np.array([0, 50, 100, 300])
    img = x_train[: len(time_list)]
    noised_img, noise = add_noise(img, time_list)
    plot_sample(img, noised_img)

    # --- Model
    model = build_unet(img_shape, timesteps)
    model.summary()

    # --- 位置エンコーディングの可視化
    PositionalEncoding.plot()


def train():
    x_train = create_dataset()
    model = build_unet(img_shape, timesteps)

    # --- train
    epochs = 10
    batch_size = 128
    model.compile(optimizer="adam", loss="huber")  # huberの方が学習が早い
    # model.compile(optimizer="adam", loss="mse")

    for _ in range(epochs):
        # タイムステップtをランダムに生成させる
        t = np.random.randint(0, timesteps, len(x_train))
        # tを元にノイズ入り画像を作成
        noisy_images, noise = add_noise(x_train, t)

        # 学習: 入力はノイズ画像とタイムステップ、教師データはノイズでMSEで学習させる
        model.fit([noisy_images, t], noise, batch_size=batch_size, epochs=1)
    model.save_weights(Path(__file__).parent / "diff.weights.h5")


def generate():
    model = build_unet(img_shape, timesteps)
    model.load_weights(Path(__file__).parent / "diff.weights.h5")

    def sample_images(img_shape, model: keras.Model, timesteps: int, num_samples: int) -> np.ndarray:
        samples_history = []

        samples = tf.random.normal((num_samples,) + img_shape, dtype=np.float32)
        for t in tqdm(
            reversed(range(timesteps)),
            desc="sampling loop time step",
            total=timesteps,
        ):
            batch_t = tf.constant([[t] for _ in range(num_samples)])
            noise_pred = model([samples, batch_t])

            # x_t-1 = 1/sqrt(at) * (xt - (1-at)/sqrt(1-^at) e(xt,t)) [t=0以外] + v*z
            samples = (1 / np.sqrt(alphas[t])) * (samples - noise_pred * betas[t] / np.sqrt(1 - alphas_cumprod[t]))
            if t > 0:
                noise = tf.random.normal(samples.shape, dtype=np.float32)
                v = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]
                samples += np.sqrt(v) * noise

            samples_history.append(samples.numpy())

        samples = tf.clip_by_value(samples, -1.0, 1.0)
        samples = samples.numpy()

        # [-1,1] -> [0,255]
        samples = (((samples + 1) / 2) * 255).astype(np.uint8)
        return samples, samples_history

    num_samples = 16
    generated_images, images_history = sample_images(img_shape, model, timesteps, num_samples)

    # 結果
    plt.figure(figsize=(10, 10))
    for i in range(num_samples):
        plt.subplot(4, 4, i + 1)
        plt.imshow(generated_images[i, :, :, 0], cmap="gray")
        plt.axis("off")
    plt.show()

    # 作成過程
    index = 4
    img_list = np.array(images_history)[:, index, :, :, 0]
    plt.figure(figsize=(20, 5))
    step_idxs = list(range(0, len(img_list), int(timesteps / 12)))  # 多いので一定間隔で抜き出し
    step_idxs += [len(img_list) - 1]  # 最後も追加
    for i, idx in enumerate(step_idxs):
        plt.subplot(1, len(step_idxs), i + 1)
        plt.imshow(img_list[idx], cmap="gray")
        plt.xticks([])
        plt.yticks([])
        plt.xlabel(f"step={idx}")
    plt.show()


if __name__ == "__main__":
    sample()
    train()
    generate()

Tips

SBMとDDPM

拡散モデルの代表的なアルゴリズムとしてはSBM(Score-based Generative Models)(論文リンク)DDPM(Denoising Diffusion Probabilistic Models)(論文リンク)があるようです。
この記事はDDPMがベースになっています。

それぞれの違いは以下となります。(ソースはChatGPT)

項目 SBM DDPM
基礎理論 確率微分方程式(SDE) マルコフ連鎖
学習対象 スコア関数$\nabla_x \log{p(x)}$ ノイズ除去プロセス
生成プロセス 逆拡散プロセス(連続) 逐次的なステップ
損失関数 スコアマッチング 変分下限(ELBO)
特徴 時間連続的な表現、理論的柔軟性 高品質な生成、逐次的生成

終わりに

これを応用した手法がいろいろあるので勉強してみました。
概要だけ見ると簡単ですが実際に実装してみると細かい部分で分からない箇所が色々出てきますね。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?