入門①:DDPMの理論とMNISTの実装
入門②:SDE/ODEの基礎理論(Tensorflow実装付き)
入門③:ここ
入門④:条件付きU-Net(MNIST実装付き)
EDM(Elucidating the Design Space of Diffusion-Based Generative Models)
DIAMONDの論文でEDMの実装を使っているので元の論文を調べてみました。
論文: https://arxiv.org/abs/2206.00364
Github: https://github.com/NVlabs/edm/tree/main
立ち位置?
拡散モデルの流れは知らないのでこのEDMがどういう立ち位置か簡単に調べてみました。
(そこまでちゃんと調べていないので参考程度に…)
年 | 手法 | 概要 | 論文 |
---|---|---|---|
2019,2020 | SBM | データ分布のスコア関数(確率密度関数の勾配)を推定し、それを利用したデータ生成 | 1,2 |
2020 | DDPM | 生成過程を明示的に定義し、逆過程によりデータ生成 | 1 |
2021 | Improved DDPM | 分散の学習やスケジューリングの改善などDDPMを改良 | 1 |
2021 | Score-Based SDE | 拡散プロセスを確率微分方程式(SDE)として定式化した論文 | 1 |
2022 | LDM | 拡散プロセスを低次元の潜在空間(latent space)で実行することで、計算コストを大幅に削減した手法。(Stable Diffusionで使われてる?) | 1 |
2022 | EDM | 拡散モデルの要素(拡散プロセスなど)の影響を分類・整理し、各手法における性能への影響を明確化した論文 | 1 |
2023 | DiTs | UNetの代わりにTransformersを使い、大幅に精度向上 | 1 |
以下の記事によるとNVIDIAのチームが発表した論文らしく、NeurIPS 2022 で優秀論文賞を受賞しているとの事。
https://developer.nvidia.com/blog/generative-ai-research-spotlight-demystifying-diffusion-based-models/
論文の概要
この論文が主張している内容は以下です。
- 拡散モデルの理論を実用的な観点でまとめ、各手法の関連や影響を調査した(これによるシステム設計への貢献)
- ルンゲ・クッタ法(Runge–Kutta method)適用によるサンプリングプロセスの大幅な改善
- トレーニングを改善するためのベストプラクティスの提示
拡散モデルの共通フレームワーク
以下の表の縦軸をまとめたことが最初の主張になるかと思います。
縦軸が設計時に選択できるコンポーネントで、横軸が主要な手法を並べ、具体的に値や方法をのせています。
これらのコンポーネントは独立して変更可能であるそうです。(ただ性能向上には特定のコンポーネントの組み合わせが重要らしい)
本記事では論文で提案された新しい手法"EDM"に焦点をあてます。
他の手法との比較や特徴などの議論は論文を参照してください。
またサンプリングで使う Time steps を生成するコードを以下に書いておきます。
def create_timesptes(N: int, sigma_min=0.002, sigma_max=80, rho=7):
timesteps = []
for i in range(N):
t = (sigma_max ** (1 / rho) + i / (N - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
timesteps.append(t)
timesteps += [0] # N=0
return timesteps
以下みたいになります。
ノイズ除去モデルの学習
ノイズ除去モデル(Denoiser) $D(x,\sigma)$ はノイズ入り画像 $x$ を入力するとノイズを除去した後の画像を出力するモデルです。
図は論文より、Denoiserの理想的な出力結果で、分散 $\sigma$ が大きいほど元画像の平均値が出力されます。
Denoiserですが、ニューラルネットワークで直接学習することはあまり適切ではないそうです。
これは入力値 $x$ が元画像 $y$ とガウスノイズ $n$ の組み合わせのため、入力値が $\sigma$ の値によって大きく変動するためです。
なので一般的には代わりのネットワーク $F_{\theta}$ を用いて以下のように表現し、$F_{\theta}$ を学習します。
$c_{skip}(\sigma)$ はスキップ接続の調整、$c_{in}(\sigma)$と$c_{out}(\sigma)$は入出力値の調整、$c_{noise}(\sigma)$はノイズレベルの調整をします。
損失は以下です。
学習時のノイズレベル $\sigma$ がサンプリングされる確率を $p_{train}(\sigma)$ とした場合、対応する重みが $\lambda(\sigma)$ となります。
EDM版のコード例は以下です。
sigma_data = 0.5
f_model = ニューラルネットワーク
def calc_c(sigma):
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
c_out = sigma * sigma_data / np.sqrt(sigma**2 + sigma_data**2)
c_in = 1 / np.sqrt(sigma_data**2 + sigma**2)
c_noise = np.log(sigma) / 4
return c_skip, c_out, c_in, c_noise
def model_forward(x, sigma):
c_skip, c_out, c_in, c_noise = calc_c(sigma)
dx = c_skip * x + c_out * f_model(c_in * x, c_noise)
return dx
def model_compute_loss(y, p_mean = -1.2, p_std = 1.2):
# Noise distibution
r = np.random.randn(y.shape[0], 1, 1, 1)
sigma = np.exp(p_mean + r * p_std)
c_skip, c_out, c_in, c_noise = calc_c(sigma)
# λ(σ)
lambda_w = (sigma**2 + sigma_data**2) / ((sigma * sigma_data) ** 2)
# ノイズ
n = np.random.randn(*y.shape) * sigma
# 損失の計算
network_output = f_model(c_in * (y + n), c_noise)
training_target = (y - c_skip * (y + n)) / c_out
loss = lambda_w * (c_out**2) * ((network_output - training_target) ** 2)
lossでモデルを更新
決定的サンプリング
論文に記載があったアルゴリズム1を実装してみました。
コード例は以下です。(メインじゃない部分は削っています、動くコードは最後に)
ただ上手く生成できず、何か間違っているかもしれません…。
基本こちらは使わないので参考程度に
def generate_deterministic(size: int, N: int, sigma_min=0.002, sigma_max=80, rho=7):
# Time steps
timesteps = create_timesptes(N, sigma_min, sigma_max, rho)
# 初期サンプル
sigma_0 = timesteps[0] # Schedule
s_0 = 1 # Scaling
var = (sigma_0**2) * (s_0**2)
x = np.random.normal(0, var, size=(size,) + img_shape)
for i in range(N - 1):
# 各変数
t = timesteps[i]
t_next = timesteps[i + 1]
sigma_t = timesteps[i] # Schedule
sigma_dt = timesteps[i] - timesteps[i+1]
s_t = 1 # Scaling
s_dt = 1 - 1
denoise = model_forward(x / s_t, sigma_t)
coeff = (sigma_dt / sigma_t + s_dt / s_t)
d = coeff * x - (sigma_dt * s_t / sigma_t) * denoise
x_next = x + (t_next - t) * d
if t_next != 0:
sigma_t_next = timesteps[i + 1]
sigma_dt_next = timesteps[i+1] - timesteps[i + 2]
s_t_next = 1
s_dt_next = 1 - 1
denoise = model_forward(x_next / s_t_next, sigma_t_next)
coeff = (sigma_dt_next / sigma_t_next + s_dt_next / s_t_next)
d_dash = coeff * x_next - (sigma_dt_next * s_t_next / sigma_t_next) * denoise
x_next = x + (t_next - t) * (d / 2 + d_dash / 2)
x = x_next
return x
確率的サンプリング
同じく論文に記載があったアルゴリズム2の実装です。
こちらはちゃんと生成できました。
def generate_stochastic(
size: int,
N: int,
sigma_min=0.002,
sigma_max=80,
rho=7,
s_churn=0,
s_min=0,
s_max=float("inf"),
s_noise=1,
):
# Time steps
timesteps = create_timesptes(N, sigma_min, sigma_max, rho)
# 初期サンプル
var = timesteps[0]
x = np.random.normal(0, var, size=(size,) + img_shape)
for i in range(N - 1):
t = timesteps[i]
t_next = timesteps[i + 1]
gamma = min(s_churn / N, np.sqrt(2) - 1) if s_min <= t <= s_max else 0
t_hat = t + gamma * t
e = np.random.normal(0, s_noise, size=(size,) + img_shape)
x_hat = x + np.sqrt(t_hat**2 - t**2) * e
denoise = model_forward(x_hat, t_hat)
d = (x_hat - denoise) / t_hat
x_next = x_hat + (t_next - t_hat) * d
if t_next != 0:
denoise = model_forward(x_next, t_next)
d_dash = (x_next - denoise) / t_next
x_next = x_hat + (t_next - t_hat) * (d / 2 + d_dash / 2)
x = x_next
return x
結果
・学習過程
Epoch 1/5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:31<00:00, 14.66it/s, loss=0.258]
Epoch 2/5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:27<00:00, 16.97it/s, loss=0.188]
Epoch 3/5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:27<00:00, 17.00it/s, loss=0.197]
Epoch 4/5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:27<00:00, 17.11it/s, loss=0.176]
Epoch 5/5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:27<00:00, 17.00it/s, loss=0.174]
・生成結果
・生成過程
ノイズもなく綺麗な画像が生成されていますね。
全体コード
モデルは今まで通りUNetを使い、位置エンコーディングはサインコサインエンコードを使っています。
from pathlib import Path
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow import keras
from tqdm import tqdm
kl = keras.layers
img_shape = (28, 28, 1)
def create_dataset():
(x_train, y_train), (_, _) = keras.datasets.mnist.load_data()
# x_train = x_train[y_train == 1]
x_train = (x_train[..., np.newaxis] / 255.0) * 2 - 1 # [0,255] -> [-1,1]
x_train = x_train.astype(np.float32)
return x_train
def decode_image(x):
img = np.clip(x, -1.0, 1.0)
# [-1,1] -> [0,255]
img = (((img + 1) / 2) * 255).astype(np.uint8)
return img
# ----------------------------------
# model
# ----------------------------------
class PositionalEmbedding(kl.Layer):
def __init__(self, embedding_dim: int, max_position: int, **kwargs):
"""
サインコサイン位置埋め込みを計算するKerasレイヤー
:param embedding_dim: 埋め込み次元数
:param max_position: 最大の位置数
"""
super().__init__(**kwargs)
# 位置エンコーディングの計算
positions = np.arange(max_position)[:, np.newaxis] # (max_position, 1)
dims = np.arange(embedding_dim)[np.newaxis, :] # (1, embedding_dim)
# サイン・コサイン関数で位置埋め込みを計算
angle_rates = 1 / np.power(10000, (2 * (dims // 2)) / embedding_dim)
angle_rads = positions * angle_rates
# 偶数インデックスはsin、奇数インデックスはcos
positional_encoding = np.zeros_like(angle_rads)
positional_encoding[:, 0::2] = np.sin(angle_rads[:, 0::2])
positional_encoding[:, 1::2] = np.cos(angle_rads[:, 1::2])
self.positional_encoding = tf.constant(positional_encoding, dtype=tf.float32)
def call(self, inputs: tf.Tensor) -> tf.Tensor:
indices = tf.cast(tf.reshape(inputs, [-1]), tf.int32)
return tf.gather(self.positional_encoding, indices)
@staticmethod
def plot(embedding_dim: int = 500, max_position: int = 1000): # for debug
pos_emb = PositionalEmbedding(embedding_dim, max_position)
step = tf.constant(np.arange(0, max_position)[..., np.newaxis])
emb = pos_emb(step).numpy()
plt.pcolormesh(emb.T, cmap="RdBu")
plt.ylabel("dimension")
plt.xlabel("time step")
plt.colorbar()
plt.show()
def build_unet(img_shape):
x_t = kl.Input(shape=img_shape)
sigma = keras.Input(shape=(1, 1, 1))
# --- ノイズ情報の埋め込み
# (batch, 1) -> (batch, dim) -> (batch, 1, 1, dim) -> (batch, 28, 28, dim)
t_embedding = PositionalEmbedding(128, 10000)(sigma)
t_embedding = kl.Dense(128, activation="gelu")(t_embedding)
t_embedding = kl.Reshape((1, 1, 128))(t_embedding)
t_embedding = kl.UpSampling2D(img_shape[:2])(t_embedding)
# 埋め込み情報をチャンネルに追加
x = kl.Concatenate()([x_t, t_embedding])
# --- down sampling
c1 = kl.Conv2D(64, (3, 3), padding="same", activation="relu")(x)
c1 = kl.Conv2D(64, (3, 3), padding="same", activation="relu")(c1)
p1 = kl.MaxPooling2D((2, 2))(c1) # 28x28 -> 14x14
c2 = kl.Conv2D(128, (3, 3), padding="same", activation="relu")(p1)
c2 = kl.Conv2D(128, (3, 3), padding="same", activation="relu")(c2)
p2 = kl.MaxPooling2D((2, 2))(c2) # 14x14 -> 7x7
# --- ボトム
p2 = kl.Conv2D(256, (3, 3), activation="relu", padding="same")(p2)
# --- up sampling
u1 = kl.UpSampling2D((2, 2))(p2) # 7x7 -> 14x14
u1 = kl.Concatenate()([u1, c2])
u1 = kl.Conv2D(128, (3, 3), activation="relu", padding="same")(u1)
u1 = kl.Conv2D(128, (3, 3), activation="relu", padding="same")(u1)
u2 = kl.UpSampling2D((2, 2))(u1) # 14x14 -> 28x28
u2 = kl.Concatenate()([u2, c1])
u2 = kl.Conv2D(64, (3, 3), activation="relu", padding="same")(u2)
u2 = kl.Conv2D(64, (3, 3), activation="relu", padding="same")(u2)
y = kl.Conv2D(1, (1, 1), padding="same")(u2)
model = keras.Model(inputs=[x_t, sigma], outputs=y, name="u_net")
return model
class EDM:
def __init__(self, img_shape, sigma_data=0.5):
self.sigma_data = sigma_data
self.model = build_unet(img_shape)
self.optimizer = keras.optimizers.Adam(learning_rate=0.0005)
def calc_c(self, sigma):
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / tf.sqrt(sigma**2 + self.sigma_data**2)
c_in = 1 / tf.sqrt(self.sigma_data**2 + sigma**2)
c_noise = tf.math.log(sigma) / 4
return c_skip, c_out, c_in, c_noise
def call(self, x, sigma, training=False):
c_skip, c_out, c_in, c_noise = self.calc_c(sigma)
fx = self.model([c_in * x, c_noise], training=training)
dx = c_skip * x + c_out * fx
return dx
def train(self, y, p_mean=-1.2, p_std=1.2):
# Noise distibution: ln(σ) ~ N(Pmena, Pstd^2)
r = np.random.randn(y.shape[0], 1, 1, 1).astype(np.float32)
sigma = np.exp(p_mean + r * p_std)
c_skip, c_out, c_in, c_noise = self.calc_c(sigma)
# λ(σ): (σ^2+σd^2) / (σ*σd)^2
lambda_w = (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2)
# ノイズ
n = np.random.randn(*y.shape) * sigma
# tensorflowに取り込める形に変換
n = tf.convert_to_tensor(n, dtype=tf.float32)
sigma = tf.convert_to_tensor(sigma.reshape((-1, 1, 1, 1)), dtype=tf.float32)
with tf.GradientTape() as tape:
output = self.model([c_in * (y + n), c_noise])
target = (y - c_skip * (y + n)) / c_out
loss = lambda_w * (c_out**2) * ((output - target) ** 2)
loss = tf.reduce_mean(loss)
grad = tape.gradient(loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(grad, self.model.trainable_variables))
return loss.numpy()
def generate_deterministic(self, size: int, N: int, sigma_min=0.002, sigma_max=80, rho=7):
# Time steps
timesteps = self.create_timesptes(N, sigma_min, sigma_max, rho)
samples_history = []
# 初期サンプル
sigma_0 = timesteps[0] # Schedule
s_0 = 1 # Scaling
var = (sigma_0**2) * (s_0**2)
x = np.random.normal(0, var, size=(size,) + img_shape).astype(np.float32)
x = tf.convert_to_tensor(x, dtype=tf.float32)
for i in tqdm(range(N - 1), desc="sampling loop"):
t = timesteps[i]
t_next = timesteps[i + 1]
sigma_t = timesteps[i] # Schedule: 論文の表より t
sigma_dt = timesteps[i] - timesteps[i + 1]
s_t = 1 # Scaling: 論文の表より t
s_dt = 1 - 1
# tf
sigma_t = tf.convert_to_tensor([sigma_t] * size, dtype=tf.float32)
sigma_t = sigma_t[..., tf.newaxis, tf.newaxis, tf.newaxis]
denoise = self.call(x / s_t, sigma_t)
coeff = sigma_dt / sigma_t + s_dt / s_t
d = coeff * x - (sigma_dt * s_t / sigma_t) * denoise
x_next = x + (t_next - t) * d
if t_next != 0:
sigma_t_next = timesteps[i + 1]
sigma_dt_next = timesteps[i + 1] - timesteps[i + 2]
s_t_next = 1
s_dt_next = 1 - 1
# tf
sigma_t_next = tf.convert_to_tensor([sigma_t_next] * size, dtype=tf.float32)
sigma_t_next = sigma_t_next[..., tf.newaxis, tf.newaxis, tf.newaxis]
denoise = self.call(x_next / s_t_next, sigma_t_next)
coeff = sigma_dt_next / sigma_t_next + s_dt_next / s_t_next
d_dash = coeff * x_next - (sigma_dt_next * s_t_next / sigma_t_next) * denoise
x_next = x + (t_next - t) * (d / 2 + d_dash / 2)
samples_history.append(x_next)
x = x_next
return x, samples_history
def generate_stochastic(
self,
size: int,
N: int,
sigma_min=0.002,
sigma_max=80,
rho=7,
s_churn=0,
s_min=0,
s_max=float("inf"),
s_noise=1,
):
# Time steps
timesteps = self.create_timesptes(N, sigma_min, sigma_max, rho)
samples_history = []
# 初期サンプル
var = timesteps[0]
x = np.random.normal(0, var, size=(size,) + img_shape).astype(np.float32)
x = tf.convert_to_tensor(x, dtype=tf.float32)
for i in tqdm(range(N - 1), desc="sampling loop"):
t = timesteps[i]
t_next = timesteps[i + 1]
gamma = min(s_churn / N, np.sqrt(2) - 1) if s_min <= t <= s_max else 0
t_hat = t + gamma * t
e = np.random.normal(0, s_noise, size=(size,) + img_shape).astype(np.float32)
x_hat = x + np.sqrt(t_hat**2 - t**2) * e
# tf
t_hat = tf.convert_to_tensor([t_hat] * size, dtype=tf.float32)
t_hat = t_hat[..., tf.newaxis, tf.newaxis, tf.newaxis]
denoise = self.call(x_hat, t_hat)
d = (x_hat - denoise) / t_hat
x_next = x_hat + (t_next - t_hat) * d
if t_next != 0:
# tf
t_next = tf.convert_to_tensor([t_next] * size, dtype=tf.float32)
t_next = t_next[..., tf.newaxis, tf.newaxis, tf.newaxis]
denoise = self.call(x_next, t_next)
d_dash = (x_next - denoise) / t_next
x_next = x_hat + (t_next - t_hat) * (d / 2 + d_dash / 2)
samples_history.append(x_next)
x = x_next
return x, samples_history
@staticmethod
def create_timesptes(N: int, sigma_min=0.002, sigma_max=80, rho=7):
timesteps = []
for i in range(N):
t = (sigma_max ** (1 / rho) + i / (N - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
timesteps.append(t)
timesteps += [0] # N=0
return timesteps
# ----------------------------------
# main
# ----------------------------------
def sample():
# --- Time steps
timesptes = EDM.create_timesptes(20)
plt.figure(figsize=(8, 5))
plt.plot(range(len(timesptes)), timesptes, marker=".")
plt.xlabel("i")
plt.ylabel("Time step")
plt.title("Visualization of Time steps")
plt.grid()
plt.show()
# --- 位置エンコーディングの可視化
PositionalEmbedding.plot()
# --- Model
model = build_unet(img_shape)
model.summary()
def train(epochs: int, batch_size: int = 128):
x_train = create_dataset()
# モデル
edm = EDM(img_shape)
# 学習用にデータをバッチ化
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(len(x_train)).batch(batch_size)
for epoch in range(epochs):
with tqdm(train_dataset, desc=f"Epoch {epoch + 1}/{epochs}") as pbar:
for y in pbar:
loss = edm.train(y)
pbar.set_postfix(loss=loss) # 損失を進捗バーに表示
edm.model.save_weights(Path(__file__).parent / "edm.weights.h5")
def generate(steps, gen_type="stochastic", **gen_kwargs):
edm = EDM(img_shape)
edm.model.load_weights(Path(__file__).parent / "edm.weights.h5")
# 生成
num_samples = 16
if gen_type == "stochastic":
samples, samples_history = edm.generate_stochastic(num_samples, steps, **gen_kwargs)
else:
samples, samples_history = edm.generate_deterministic(num_samples, steps, **gen_kwargs)
samples = decode_image(samples)
samples_history = decode_image(samples_history)
# 結果
plt.figure(figsize=(10, 10))
for i in range(num_samples):
plt.subplot(4, 4, i + 1)
plt.imshow(samples[i, :, :, 0], cmap="gray")
plt.axis("off")
plt.show()
# 作成過程
index = 4
img_list = np.array(samples_history)[:, index, :, :, 0]
plt.figure(figsize=(20, 5))
step_idxs = list(range(0, len(img_list), int(steps / 12))) # 多いので一定間隔で抜き出し
step_idxs += [len(img_list) - 1] # 最後も追加
for i, idx in enumerate(step_idxs):
plt.subplot(1, len(step_idxs), i + 1)
plt.imshow(img_list[idx], cmap="gray")
plt.xticks([])
plt.yticks([])
plt.xlabel(f"step={idx}")
plt.show()
if __name__ == "__main__":
# sample()
train(epochs=5)
generate(steps=50, gen_type="stochastic")
# generate(steps=50, gen_type="deterministic")
# generate(steps=256, s_churn=40, s_min=0.05, s_max=50,s_noise=1.003)
最後に
ここにたどり着くまで長かった…。
深堀するとかなり難しい内容でした…。
特に日本語の資料が全然なかったのがきつかった。
誰かの参考になれば幸いです。