0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

拡散モデル入門④、条件付きU-Net(MNIST実装付き)

Last updated at Posted at 2025-02-24

入門①:DDPMの理論とMNISTの実装
入門②:SDE/ODEの基礎理論(Tensorflow実装付き)
入門③:EDMの解説とMNISTの実装
入門④:ここ

最後にU-Netの詳細と条件の埋め込み方について見ていきます。

参考
https://huggingface.co/blog/annotated-diffusion
https://github.com/eloialonso/diamond/

DDPMベースのU-Net

DDPMの実装ではU-Netに Wide ResNet block の使用、Attention層の追加、バッチ正規化からグループ正規化への変更がされています。
EDMでもU-Netの実装自体は変わりはないようです。
本記事では、細かい違いはありますが内部でEDMが使われているDIAMONDの実装(GitHub)をベースに見ていきます。

全体から各モジュールに掘り下げる形で見ていきます。

U-Net

U-Netは画像を入力として画像を出力するモデルとなります。
解像度を下げるDownsampleと解像度を上げるUpsample、それらを直接つなげるスキップ接続(Skip Connection)があるのが特徴です。

aa-ページ1.drawio.png

画像に情報を付与する場合は、条件付けとして Condition を入力します。
Conditionは全部の層に対して入力されます。

また、ブロックの順序ですがネットの実装を見てみると "Down/Up→ResBlock" と "ResBock→Down/Up" の両方の実装が見られました。
違いが分からなかったのでChatGPTに聞いたところ以下のようです。

順序 特徴 メリット/デメリット
ResBlock → Downsample 特徴を先に抽出してから縮小 計算コストが高いが、高解像度の情報を保持 DDPM,画像生成
Downsample → ResBlock 縮小してから特徴抽出 計算コストが低いが、情報の損失が大きい 分類モデル,小規模な特徴抽出
順序 特徴 メリット デメリット
ResBlock → Upsample 特徴を先に抽出し、拡大する 高解像度の情報を保持しやすい アップサンプリング後の調整が少ない DDPM,画像生成
Upsample → ResBlock 先に拡大し、特徴を再調整 高解像度の特徴を学習できる 拡大時の情報損失が大きくなる可能性 超解像, 高解像度の画像処理

Condition

Condition は各画像の追加情報で、拡散モデルのタイムステップの情報も含まれます。
今回はMNISTのカテゴリ情報も追加してみました。

aa-ページ4.drawio.png

タイムステップはDIAMONDの実装ではフーリエ変換された特徴量を使っていましたが、この記事ではDDPMに倣ってサインコサインエンコーディングにしています。
MNISTのカテゴリ情報はTFのEmbbedingレイヤーでベクトル化しました。

Downsample/Upsample

DIAMONDの実装では、DownsampleはConv2Dのstrides=2でサイズを半分にし、Upsampleは最近傍補間で2倍にしていましたので、それに倣っています。

class Downsample(keras.layers.Layer):
    def build(self, input_shape):
        self.conv = kl.Conv2D(
            filters=input_shape[-1],
            kernel_size=3,
            strides=2,
            padding="same",
            kernel_initializer=keras.initializers.Orthogonal(),
        )

    def call(self, x, training=False):
        return self.conv(x, training=training)

class Upsample(keras.layers.Layer):
    def build(self, input_shape):
        self.conv = kl.Conv2D(
            filters=input_shape[-1],
            kernel_size=3,
            strides=1,
            padding="same",
        )

    def call(self, x, training=False):
        # 最近傍補間で2倍にリサイズ
        input_shape = tf.shape(x)
        x = tf.image.resize(x, size=(input_shape[1] * 2, input_shape[2] * 2), method="nearest")
        return self.conv(x, training=training)

ResBLocks & ResBlock(Residual Block)

ResBLocksですが、複数のResBlockから成ります。

aa-ページ3.drawio.png

Skip Connection は Upsample にのみあり Downsample の出力に対応します。
次にResBlockです。

aa-ページ2.drawio.png

普通のResBlockと違う点は最後にAttentionレイヤー(SelfAttention)が入っている点です。
またConditionですが、DDPMではConv3x3層に直接足していましたが、DIAMONDでは AdaGN としてNorm層に追加しています。

Adaptive Group Normalization(AdaGN)

Normレイヤーですが、一般的には Group Normalization(GN)が使われます。

ss1.png

図はGroup Normalization (論文)より、他の正規化手法とGNの違いを表しており、Nがバッチ、Cがチャンネル、(H,W)が画像を表す軸となります。
簡単にまとめると以下です。

Norm バッチ軸 チャンネル軸
Batch Norm All 1 バッチに対して正規化
Layer Norm 1 All チャンネルに対して正規化
Instance Norm 1 1      画像チャンネル毎に正規化
Group Norm 1 グループで分割 チャンネル軸を分割し、それぞれのグループで正規化

DIAMONDの実装では GN ではなく、Adaptive Group Normalization(AdaGN)が使われていました。
これは GN に条件(例えばスタイル情報など)を適用できるように拡張した手法です。

GN は以下のように正規化されます。

$$
\hat{x}_i = \frac{x_i - \mu_g}{\sigma_g}
$$

$\mu_g$と$\sigma_g$がグループ$g$の平均と標準偏差を表します。
AdaGNではこれに条件 $s$ に対するスケール $\gamma$ とバイアス $\beta$ が追加されます。

$$
\hat{x}_i = \gamma(s)\cdot \frac{x_i - \mu_g}{\sigma_g} + \beta(s)
$$

主にスタイル変換(Style Transfer)や今回みたいな条件付画像生成で使われることが多いようです。

多分これが論文:Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization

SelfAttention

Transformerで有名になったAttentionについて簡単に説明します。
詳細に見たい場合は以下の動画を参考にしてください。
【深層学習】Attention - 全領域に応用され最高精度を叩き出す注意機構の仕組み【ディープラーニングの世界 vol. 24】(youtube; AIcia Solid Project)

※いろいろ調べたところの私の解釈となります。

Attention機構はQuery/Key/Valueの3つの値を入力して、その中で注目度の高い要素を出力する仕組みとなります。
Attention機構の代表的な Source-Target型 Attention の例だと以下のようになります。(Q≠K=Vの形)

aa-ページ5.drawio.png

縦の列ですが、翻訳の例だとトークン列となりそのまま1次元で表現できますが、画像は2次元なのでw*hの1次元に並べ替えて表現させます。
(各ドットに対して別のドットに対する注目度を見るイメージ)

SelfAttentionは Q=K=V と、3つとも同じ値を入力したものになり、その画像内で関係性の高いドットを強調する役割になります。

説明は以上で、残りは実装結果です。

実行結果

各数字を出力させてみました。
ちゃんとそれっぽい数字が出力されていますね。
ただモデルの大きさとlossの下がり具合的にまだ学習が足りていない気がします。

  • 0

0のみhisotryものせておきます。

q_plot_0.png
q_plot_0_history.png

  • 1
    q_plot_1.png

  • 2
    q_plot_2.png

  • 3
    q_plot_3.png

  • 4
    q_plot_4.png

  • 5
    q_plot_5.png

  • 6
    q_plot_6.png

  • 7
    q_plot_7.png

  • 8
    q_plot_8.png

  • 9
    q_plot_9.png

全体コード

version

  • Windows11
  • WSL2: Ubuntu24.04
  • Python3.12.3
  • Tensorflow 2.18.0
  • CUDA 12.5.1
  • cuDNN 9.3.0
  • GeForce RTX 3060 12GB
import math
from functools import partial
from pathlib import Path

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow import keras
from tqdm import tqdm

kl = keras.layers


# define
img_size = 32
img_shape = (img_size, img_size, 1)


def create_dataset():
    (x_train, y_train), (_, _) = keras.datasets.mnist.load_data()
    # x_train, y_train = x_train[y_train == 1], y_train[y_train == 1]
    x_train = tf.image.resize(tf.expand_dims(x_train, -1), [img_size, img_size]).numpy()  # 28 -> 32
    x_train = (x_train / 255.0) * 2 - 1  # [0,255] -> [-1,1]
    return x_train.astype(np.float32), y_train.astype(np.float32)


def decode_image(x):
    img = np.clip(x, -1.0, 1.0)
    # [-1,1] -> [0,255]
    img = (((img + 1) / 2) * 255).astype(np.uint8)
    return img


# ----------------------------------
# model
# ----------------------------------

# alias
Conv2D1x1 = partial(kl.Conv2D, kernel_size=1, strides=1, padding="valid")
Conv2D3x3 = partial(kl.Conv2D, kernel_size=3, strides=1, padding="same")
IdentityLayer = partial(kl.Lambda, function=lambda x: x)  # 何もしないレイヤー


class PositionalEmbedding(keras.layers.Layer):
    def __init__(self, dim: int, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim

    def call(self, time: tf.Tensor) -> tf.Tensor:
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = tf.exp(tf.range(half_dim, dtype=tf.float32) * -embeddings)
        embeddings = tf.expand_dims(time, axis=-1) * tf.expand_dims(embeddings, axis=0)
        embeddings = tf.concat([tf.math.sin(embeddings), tf.math.cos(embeddings)], axis=-1)
        return embeddings

    @staticmethod
    def plot(embedding_dim: int = 500, N: int = 100):  # for debug
        model = PositionalEmbedding(embedding_dim)
        timestep = EDM.create_timesptes(N)
        emb = model(tf.constant(timestep)).numpy()
        plt.pcolormesh(emb.T, cmap="RdBu")
        plt.ylabel("dimension")
        plt.xlabel("time step")
        plt.colorbar()
        plt.show()


class AdaGroupNorm(keras.layers.Layer):
    def __init__(self, group_size: int = 32, eps: float = 1e-5, **kwargs) -> None:
        super().__init__(**kwargs)
        self.group_size = group_size
        self.eps = eps

    def build(self, input_shape):
        in_channels = input_shape[-1]

        # group_sizeは割り切れる場合のみ指定、-1: LayerNorm, 1: InstanceNorm
        groups = self.group_size if in_channels % self.group_size == 0 else -1

        self.norm = kl.GroupNormalization(groups=groups, epsilon=self.eps)
        self.gamma = kl.Dense(in_channels, use_bias=False, kernel_initializer="zeros")
        self.beta = kl.Dense(in_channels, use_bias=False, kernel_initializer="zeros")

    def call(self, x, condition, training=False):
        x = self.norm(x, training=training)
        condition = tf.expand_dims(tf.expand_dims(condition, axis=1), axis=1)  # (b,c)->(b,1,1,c)
        gamma = self.gamma(condition, training=training)
        beta = self.beta(condition, training=training)
        return x * (1 + gamma) + beta


class SelfAttention2D(keras.layers.Layer):
    def __init__(self, head_dim: int = 8, **kwargs) -> None:
        super().__init__(**kwargs)
        self.head_dim = head_dim

    def build(self, input_shape):
        in_channels = input_shape[-1]
        self.n_head = max(1, in_channels // self.head_dim)
        assert in_channels % self.n_head == 0, f"入力チャンネル数はhead数で分割できる数(head={self.n_head}"
        self.norm = kl.GroupNormalization()
        self.qkv_proj = Conv2D1x1(in_channels * 3)
        self.out_proj = Conv2D1x1(in_channels, kernel_initializer="zeros", bias_initializer="zeros")
        self.softmax = kl.Softmax(axis=-1)

    def call(self, x, training=False):
        n, h, w, c = x.shape
        x = self.norm(x, training=training)
        qkv = self.qkv_proj(x)
        # chをヘッド数で分割、hとwはまとめてseq_lenとする
        # [batch, h, w, ch*3] -> [batch, h*w, ch//head, head*3]
        qkv = tf.reshape(qkv, (n, h * w, c // self.n_head, self.n_head * 3))
        # [batch, seq_len, head, d*3] -> [batch, head, seq_len, d*3]
        qkv = tf.transpose(qkv, perm=[0, 2, 1, 3])
        # -> [batch, head, seq_len, d] * 3
        q, k, v = tf.split(qkv, num_or_size_splits=3, axis=-1)
        attn = tf.matmul(q, k, transpose_b=True)  # q@k.T
        attn = attn / tf.math.sqrt(tf.cast(k.shape[-1], tf.float32))
        attn = tf.matmul(self.softmax(attn), v)
        # ヘッドを結合し、seq_len->h,wに分割
        # [batch, head, seq_len, d] -> [batch, seq_len, head, d] -> [batch, h, w, head*d]
        y = tf.transpose(attn, perm=[0, 2, 1, 3])
        y = tf.reshape(y, (n, h, w, c))
        return x + self.out_proj(y)


class Downsample(keras.layers.Layer):
    def build(self, input_shape):
        self.conv = kl.Conv2D(
            filters=input_shape[-1],
            kernel_size=3,
            strides=2,
            padding="same",
            kernel_initializer=keras.initializers.Orthogonal(),
        )

    def call(self, x, training=False):
        return self.conv(x, training=training)


class Upsample(keras.layers.Layer):
    def build(self, input_shape):
        self.conv = kl.Conv2D(
            filters=input_shape[-1],
            kernel_size=3,
            strides=1,
            padding="same",
        )

    def call(self, x, training=False):
        # 最近傍補間で2倍にリサイズ
        input_shape = tf.shape(x)
        x = tf.image.resize(x, size=(input_shape[1] * 2, input_shape[2] * 2), method="nearest")
        return self.conv(x, training=training)


class ResBlock(keras.Model):
    def __init__(self, channels: int, use_attention: bool, **kwargs) -> None:
        super().__init__(**kwargs)
        self.channels = channels
        self.use_attention = use_attention

    def build(self, input_shape):
        use_projection = input_shape[-1] != self.channels
        self.proj = Conv2D1x1(self.channels) if use_projection else IdentityLayer()
        self.norm1 = AdaGroupNorm()
        self.act1 = kl.Activation("silu")
        self.conv1 = Conv2D3x3(self.channels)
        self.norm2 = AdaGroupNorm()
        self.act2 = kl.Activation("silu")
        self.conv2 = Conv2D3x3(self.channels)
        self.attn = SelfAttention2D() if self.use_attention else IdentityLayer()

    def call(self, x, condition, training=False):
        r = self.proj(x, training=training)
        x = self.norm1(x, condition, training=training)
        x = self.act1(x, training=training)
        x = self.conv1(x, training=training)
        x = self.norm2(x, condition, training=training)
        x = self.act2(x, training=training)
        x = self.conv2(x, training=training)
        x = x + r
        x = self.attn(x, training=training)
        return x


class ResBlocks(keras.Model):
    def __init__(self, channels_list: list[int], use_attention: bool, **kwargs) -> None:
        super().__init__(**kwargs)
        self.resblocks = [ResBlock(c, use_attention) for c in channels_list]

    def call(self, x, condition=None, shortcut=None, training=False):
        outputs = []
        for i, resblock in enumerate(self.resblocks):
            if shortcut is not None:
                x = tf.concat([x, shortcut[i]], axis=-1)
            x = resblock(x, condition, training=training)
            outputs.append(x)
        return x, outputs


class UNet(keras.Model):
    def __init__(self, img_shape: tuple, category_num: int, **kwargs) -> None:
        super().__init__(**kwargs)

        # condition
        embedding_dim = 128
        self.time_embedding1 = PositionalEmbedding(embedding_dim)
        self.time_embedding2 = kl.Dense(embedding_dim, activation="gelu")
        self.time_embedding3 = kl.Dense(embedding_dim)
        self.category_embedding1 = kl.Embedding(category_num, embedding_dim)
        self.category_embedding2 = kl.Dense(embedding_dim, activation="gelu")
        self.category_embedding3 = kl.Dense(embedding_dim)

        # Note: SelfAttentionは画像サイズが大きい場合に適用するとメモリ不足になった

        # downsample
        self.down_block11 = ResBlocks([32, 32], use_attention=False)
        self.down_block12 = ResBlocks([32, 32], use_attention=False)
        self.down_block13 = ResBlocks([32, 32], use_attention=False)
        self.downsample1 = Downsample()
        self.down_block21 = ResBlocks([32, 32], use_attention=False)
        self.down_block22 = ResBlocks([32, 32], use_attention=False)
        self.down_block23 = ResBlocks([32, 32], use_attention=False)
        self.downsample2 = Downsample()

        # middle
        self.middle_block = ResBlocks([256], use_attention=True)

        # upsample
        self.up_block23 = ResBlocks([32, 32], use_attention=False)
        self.up_block22 = ResBlocks([32, 32], use_attention=False)
        self.up_block21 = ResBlocks([32, 32], use_attention=False)
        self.upsample2 = Upsample()
        self.up_block13 = ResBlocks([32, 32], use_attention=False)
        self.up_block12 = ResBlocks([32, 32], use_attention=False)
        self.up_block11 = ResBlocks([32, 32], use_attention=False)
        self.upsample1 = Upsample()

        # 出力層(正規分布のノイズ画像)
        self.out_layer = kl.Conv2D(1, (1, 1), padding="same")

        # build & init weight
        self(
            [
                np.zeros((1,) + img_shape),
                np.zeros((1,)),
                np.zeros((1,)),
            ]
        )

    @tf.function
    def call(self, inputs, training=False):
        # 入力: (ノイズ画像、時間ステップ、カテゴリ)
        x, t, category = inputs

        # 時間埋め込み
        t_emb = self.time_embedding1(t, training=training)
        t_emb = self.time_embedding2(t_emb, training=training)
        t_emb = self.time_embedding3(t_emb, training=training)

        # カテゴリの埋め込み
        c_emb = self.category_embedding1(category, training=training)
        c_emb = self.category_embedding2(c_emb, training=training)
        c_emb = self.category_embedding3(c_emb, training=training)

        # 条件付けはtimesteps+catetory
        condition = t_emb + c_emb

        # --- U-Net
        x, o11 = self.down_block11(x, condition, training=training)
        x, o12 = self.down_block12(x, condition, training=training)
        x, o13 = self.down_block13(x, condition, training=training)
        x = self.downsample1(x, training=training)  # 28x28 -> 14x14
        x, o21 = self.down_block21(x, condition, training=training)
        x, o22 = self.down_block22(x, condition, training=training)
        x, o23 = self.down_block23(x, condition, training=training)
        x = self.downsample2(x, training=training)  # 14x14 -> 7x7

        x, _ = self.middle_block(x, condition, training=training)

        x = self.upsample2(x, training=training)  # 7x7 -> 14x14
        x, _ = self.up_block23(x, condition, o23[::-1], training=training)
        x, _ = self.up_block22(x, condition, o22[::-1], training=training)
        x, _ = self.up_block21(x, condition, o21[::-1], training=training)
        x = self.upsample1(x, training=training)  # 14x14 -> 28x28
        x, _ = self.up_block13(x, condition, o13[::-1], training=training)
        x, _ = self.up_block12(x, condition, o12[::-1], training=training)
        x, _ = self.up_block11(x, condition, o11[::-1], training=training)

        x = self.out_layer(x, training=training)
        return x


class EDM:
    def __init__(self, img_shape, category_num, lr: float = 0.0001, sigma_data=0.5):
        self.sigma_data = sigma_data
        self.model = UNet(img_shape, category_num)
        self.optimizer = keras.optimizers.Adam(learning_rate=lr)

    def calc_c(self, sigma):
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / tf.sqrt(sigma**2 + self.sigma_data**2)
        c_in = 1 / tf.sqrt(self.sigma_data**2 + sigma**2)
        c_noise = tf.math.log(sigma) / 4
        return c_skip, c_out, c_in, c_noise

    def call(self, x, sigma, category, training=False):
        c_skip, c_out, c_in, c_noise = self.calc_c(sigma)
        c_noise = tf.reshape(c_noise, (-1,))
        fx = self.model([c_in * x, c_noise, category], training=training)
        dx = c_skip * x + c_out * fx
        return dx

    def train(self, img, category, p_mean=-1.2, p_std=1.2):
        # Noise distibution: ln(σ) ~ N(Pmena, Pstd^2)
        r = np.random.randn(img.shape[0], 1, 1, 1).astype(np.float32)
        sigma = np.exp(p_mean + r * p_std)
        c_skip, c_out, c_in, c_noise = self.calc_c(sigma)
        c_noise = tf.reshape(c_noise, (-1,))

        # λ(σ): (σ^2+σd^2) / (σ*σd)^2
        lambda_w = (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2)

        # ノイズ
        n = np.random.randn(*img.shape) * sigma

        # tensorflowに取り込める形に変換
        n = tf.convert_to_tensor(n, dtype=tf.float32)
        sigma = tf.convert_to_tensor(sigma.reshape((-1, 1, 1, 1)), dtype=tf.float32)

        with tf.GradientTape() as tape:
            output = self.model([c_in * (img + n), c_noise, category])
            target = (img - c_skip * (img + n)) / c_out
            loss = lambda_w * (c_out**2) * ((output - target) ** 2)
            loss = tf.reduce_mean(loss)
        grad = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grad, self.model.trainable_variables))

        return loss.numpy()

    def generate(
        self,
        size: int,
        N: int,
        category: int,
        sigma_min=0.002,
        sigma_max=80,
        rho=7,
        s_churn=0,
        s_min=0,
        s_max=float("inf"),
        s_noise=1,
    ):
        # Time steps
        timesteps = self.create_timesptes(N, sigma_min, sigma_max, rho)
        tf_category = tf.convert_to_tensor([category] * size, dtype=tf.float32)

        samples_history = []

        # 初期サンプル
        var = timesteps[0]
        x = np.random.randn(size, *img_shape).astype(np.float32) * var

        x = tf.convert_to_tensor(x, dtype=tf.float32)
        for t, next_t in tqdm(
            zip(timesteps[:-1], timesteps[1:]),
            total=N,
            desc="sampling loop",
        ):
            gamma = min(s_churn / N, np.sqrt(2) - 1) if s_min <= t <= s_max else 0
            t_hat = t + gamma * t
            e = np.random.randn(size, *img_shape).astype(np.float32) * s_noise
            x_hat = x + np.sqrt(t_hat**2 - t**2) * e

            # tf
            t_hat = tf.convert_to_tensor([t_hat] * size, dtype=tf.float32)
            t_hat = tf.reshape(t_hat, (-1, 1, 1, 1))
            denoise = self.call(x_hat, t_hat, tf_category)

            d = (x_hat - denoise) / t_hat
            x_next = x_hat + (next_t - t_hat) * d
            if next_t != 0:
                # tf
                t_next = tf.convert_to_tensor([next_t] * size, dtype=tf.float32)
                t_next = tf.reshape(t_next, (-1, 1, 1, 1))
                denoise = self.call(x_next, t_next, tf_category)

                d_dash = (x_next - denoise) / next_t
                x_next = x_hat + (next_t - t_hat) * (d / 2 + d_dash / 2)

            samples_history.append(x_next)
            x = x_next

        return x, samples_history

    @staticmethod
    def create_timesptes(N: int, sigma_min=0.002, sigma_max=80, rho=7):
        timesteps = []
        for i in range(N):
            t = (sigma_max ** (1 / rho) + i / (N - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
            timesteps.append(t)
        timesteps += [0]  # N=0
        return timesteps


# ----------------------------------
# main
# ----------------------------------
def plot_summary():
    # U-Net
    model = UNet(img_shape, category_num=10)
    model.summary()

    # Time steps
    timesptes = EDM.create_timesptes(20)
    plt.figure(figsize=(8, 5))
    plt.plot(range(len(timesptes)), timesptes, marker=".")
    plt.xlabel("i")
    plt.ylabel("Time step")
    plt.title("Visualization of Time steps")
    plt.grid()
    plt.show()

    # 位置エンコーディングの可視化
    PositionalEmbedding.plot()


def train(epochs: int, batch_size: int, lr: float):
    x_img, x_category = create_dataset()

    # モデル
    edm = EDM(img_shape, category_num=10, lr=lr)

    # 学習用にデータをバッチ化
    train_dataset = tf.data.Dataset.from_tensor_slices((x_img, x_category)).shuffle(len(x_img)).batch(batch_size)

    for epoch in range(epochs):
        with tqdm(train_dataset, desc=f"Epoch {epoch + 1}/{epochs}") as pbar:
            for img, category in pbar:
                loss = edm.train(img, category)
                pbar.set_postfix(loss=loss)  # 損失を進捗バーに表示
    edm.model.save_weights(Path(__file__).parent / "edm.weights.h5")


def generate(steps: int, category: int, **gen_kwargs):
    edm = EDM(img_shape, category_num=10)
    edm.model.load_weights(Path(__file__).parent / "edm.weights.h5")

    # 生成
    num_w = 8
    num_h = 2
    num_samples = num_w * num_h
    samples, samples_history = edm.generate(num_samples, steps, category, **gen_kwargs)

    samples = decode_image(samples)
    samples_history = decode_image(samples_history)

    # 結果
    plt.figure(figsize=(10, 4))
    for i in range(num_samples):
        plt.subplot(num_h, num_w, i + 1)
        plt.imshow(samples[i, :, :, 0], cmap="gray")
        plt.axis("off")
    plt.savefig(Path(__file__).parent / f"plot_{category}.png")
    # plt.show()

    # 作成過程
    index = 4
    img_list = np.array(samples_history)[:, index, :, :, 0]
    plt.figure(figsize=(20, 5))
    step_idxs = list(range(0, len(img_list), int(steps / 12)))  # 多いので一定間隔で抜き出し
    step_idxs += [len(img_list) - 1]  # 最後も追加
    for i, idx in enumerate(step_idxs):
        plt.subplot(1, len(step_idxs), i + 1)
        plt.imshow(img_list[idx], cmap="gray")
        plt.xticks([])
        plt.yticks([])
        plt.xlabel(f"step={idx}")
    plt.savefig(Path(__file__).parent / f"plot_{category}_history.png")
    # plt.show()


if __name__ == "__main__":
    # plot_summary()
    train(epochs=10, batch_size=128, lr=0.0002)
    for i in range(10):
        generate(steps=20, category=i)

モデルのsummary

TFは2階層以上のクラスのshapeがデフォルトで表示されません。(昔から解決策を調べてるんですが…、分かる方教えてください)

Model: "u_net"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ positional_embedding                 │ ?                           │               0 │
│ (PositionalEmbedding)                │                             │                 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense (Dense)                        │ (1, 128)                    │          16,512 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense_1 (Dense)                      │ (1, 128)                    │          16,512 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ embedding (Embedding)                │ (1, 128)                    │           1,280 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ res_blocks (ResBlocks)               │ ?                           │          53,154 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ res_blocks_1 (ResBlocks)             │ ?                           │          70,016 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ downsample (Downsample)              │ ?                           │           9,248 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ res_blocks_2 (ResBlocks)             │ ?                           │          70,016 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ res_blocks_3 (ResBlocks)             │ ?                           │          70,016 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ downsample_1 (Downsample)            │ ?                           │           9,248 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ res_blocks_4 (ResBlocks)             │ ?                           │         296,384 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ res_blocks_5 (ResBlocks)             │ ?                           │         164,608 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ res_blocks_6 (ResBlocks)             │ ?                           │         109,120 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ upsample (Upsample)                  │ ?                           │         147,584 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ res_blocks_7 (ResBlocks)             │ ?                           │         109,120 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ res_blocks_8 (ResBlocks)             │ ?                           │         109,120 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ upsample_1 (Upsample)                │ ?                           │           9,248 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ conv2d (Conv2D)                      │ (1, 32, 32, 1)              │              33 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
 Total params: 1,261,219 (4.81 MB)
 Trainable params: 1,261,219 (4.81 MB)
 Non-trainable params: 0 (0.00 B)

学習途中のloss

Epoch 1/10: 100%|███████████████████| 469/469 [09:29<00:00,  4.22s/it, loss=0.178]
Epoch 1/10: 100%|███████████████████| 469/469 [09:29<00:00,  1.21s/it, loss=0.178]
Epoch 2/10: 100%|███████████████████| 469/469 [08:44<00:00,  1.10s/it, loss=0.153]
Epoch 2/10: 100%|███████████████████| 469/469 [08:44<00:00,  1.12s/it, loss=0.153]
Epoch 3/10: 100%|███████████████████| 469/469 [08:15<00:00,  1.06s/it, loss=0.147]
Epoch 4/10: 100%|███████████████████| 469/469 [08:02<00:00,  1.03s/it, loss=0.119]
Epoch 4/10: 100%|███████████████████| 469/469 [08:02<00:00,  1.03s/it, loss=0.119]
Epoch 5/10: 100%|███████████████████| 469/469 [07:27<00:00,  1.05it/s, loss=0.128]
Epoch 6/10: 100%|███████████████████| 469/469 [07:15<00:00,  1.08it/s, loss=0.113]
Epoch 7/10: 100%|███████████████████| 469/469 [07:13<00:00,  1.08it/s, loss=0.104]
Epoch 8/10: 100%|███████████████████| 469/469 [07:17<00:00,  1.01it/s, loss=0.113]
Epoch 8/10: 100%|███████████████████| 469/469 [07:17<00:00,  1.07it/s, loss=0.113]
Epoch 9/10: 100%|███████████████████| 469/469 [07:07<00:00,  1.10it/s, loss=0.115]
Epoch 10/10: 100%|██████████████████| 469/469 [07:12<00:00,  1.08it/s, loss=0.12]

おわりに

これで拡散モデルは終わりな気がする(多分)
学習が少ないと前の小さいモデルの方がすぐ学習できて優秀に感じました。
このモデルもパラメータ数をもっと増やして長時間学習させれば精度が上がりそうな感じはしています。
後、学習率を徐々に下げていった方がいいかも。

誰かの参考になれば幸いです。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?