はじめに
This Kanji doesn't exist (一部ありそうですが...)
diffusionモデルを作成して、漢字を生成した記録になります
- Diffusionモデルを実際にPythonで動かして、学習方法を理解することが目的です
- 詳しい理論や数式の導出は説明はしておりません
- 多様でありながら単純であり、かつ簡単に準備できる学習データとして"漢字"を利用しました
- コードはこちらを参考にしており、ネットワーク部分はそのまま利用しています
https://github.com/tcapelle/Diffusion-Models-pytorch
もくじ
- ライブラリのインポート
- 漢字データセットの作成
- 拡散過程
- 逆拡散過程
- 学習コード
- 結果
- 漢字生成
0. ライブラリのインポート
学習で利用するライブラリをインポートします。
学習はGoogle Colaboratoryで実施しました。
2023/11/25の時点では全てデフォルトで入っています。
import os
import math
import random
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
# ハイパーパラメータをdataclassとjsonを使って管理するため
import json
from dataclasses import dataclass, field
# 漢字の画像を生成するため
from fontTools import ttLib
from PIL import Image, ImageFont, ImageDraw
# UNetはこちらを利用しています
# ファイルをダウンロードしてインポートするか、コピペしてください
# https://github.com/tcapelle/Diffusion-Models-pytorch/blob/main/modules.py
from modules import UNet
1. 漢字データセットの作成
-
手順
PIL(Python Image Library)で真っ白なカンバスを準備しておいて、
16進数の常用漢字コード → 作成する漢字1文字を決定 → PILでカンバスに漢字を描画 → 図をnumpy arrayに変換という手順です。本題ではないので、詳細な説明は省略します。 -
フォントファイル
任意の文字サイズにするにはフォントファイルを指定する必要があったので、フォントファイルを渡しています(なくてもできる?)
フォントファイルの利用について
フォントファイルはMacならシステム→ライブラリ→Fontsに、windowsならC:\Windows\Fontsにありますのでそのファイルを指定してください。
ただし、フォントファイルには著作権や利用規約があリますのでご注意ください。
本件は、画像化して再生成しておりますが、画像を単体で表示するだけでフォントとして機能する(意味のある文字列を示す)形式では利用していないので、利用規約に違反していないと考えております。
再生成したフォント画像をフォントとして利用する場合は利用規約違反になりますのでご注意ください。
https://www.screen.co.jp/ga_product/sento/support/licensefaq.html#anchor1-5
# %%==========================================================================
# 漢字の画像ファイルを作成する
# ============================================================================
def create_kanji_images(pix, font_file):
"""漢字から画像データを作る関数
Args:
pix (int): 生成する画像のピクセル数
Returns:
numpy array: 生成した画像セット
"""
image_size = (pix, pix)
font_size = pix - 2 * 2
margin = (2, 2) # 外側に2ピクセルずつ空白をあける
image_font = ImageFont.truetype(font=font_file, size=font_size, index=0)
with ttLib.ttFont.TTFont(font_file, fontNumber=0) as font:
cmap = font.getBestCmap()
# 常用漢字コードは16進数で4E00〜9FFF
arr_list = []
for cid in range(0x4E00, 0x9FFF):
# FONT FILEに入っていない漢字は飛ばす
if cid not in cmap.keys():
continue
# 漢字コードを漢字に変換して、カンバスに描画する
letter = chr(cid)
im = Image.new(mode='1', size=image_size, color=0)
draw = ImageDraw.Draw(im)
draw.text(xy=margin, text=letter, font=image_font, fill=1)
arr_list.append(np.array(im)[np.newaxis, :, :])
return np.concatenate(arr_list, axis=0)
# %%==========================================================================
# Dataset
# ============================================================================
class DataSet(torch.utils.data.Dataset):
def __init__(self, pix, font_file):
img_np = create_kanji_images(pix, font_file)
self.img = torch.from_numpy(img_np).to(dtype=torch.float32)
self.img = self.img.unsqueeze(1)
def __getitem__(self, idx):
return self.img[idx, ...]
def __len__(self):
return self.img.shape[0]
def create_dataset(pix, font_file, exist_load=True):
"""漢字のデータセットを作る.
google colabではすごく時間がかかるので、保存しておき、保存データがある場合はロードする
Args:
pix(int): 画像のピクセル数
font_file(str): 画像のフォントファイル
exist_load(bool): Trueの場合、保存してあるdatasetがあればロードする. default True
Returns:
Dataset: kanji dataset
"""
dataset_path = f"kanji_dataset_{pix}.pt"
if exist_load and os.path.exists(dataset_path):
dataset = torch.load(dataset_path)
print("dataset has been loaded")
else:
dataset = DataSet(pix=pix, font_file=font_file)
torch.save(dataset, dataset_path)
print("dataset has been created")
return dataset
if __name__ == "__main__":
font_file = r"./ヒラギノ角ゴシック W5.ttc"
dataset = create_dataset(32, font_file)
img = dataset[3]
create_dataset関数を呼び出すことで、漢字のnumpy配列を返すdatasetが作成できます。
得られた画像を表示するとこのような画像が得られます。
手元のMacbookで実行すると1,2秒でdatasetが作成できますが、google colabでは数分かかるようです(なぜ?)。
生成したdatasetは保存して、2回目以降はloadして使えるようにしています。
2. 拡散過程
まずは、明瞭な画像に少しずつノイズを加えていく拡散過程のメソッドを作成します。
使う計算式はこちらです。
# %%==========================================================================
# Denoising Diffusion Probabilistic Models
# ============================================================================
class DDPM(nn.Module):
def __init__(self, T, device):
super().__init__()
self.device = device
self.T = T
# β1 and βt はオリジナルの ddpm reportに記載されている値を採用します
self.beta_1 = 1e-4
self.beta_T = 0.02
# β = [β1, β2, β3, ... βT] (length = T)
self.betas = torch.linspace(self.beta_1, self.beta_T, T, device=device)
# α = [α1, α2, α3, ... αT] (length = T)
self.alphas = 1.0 - self.betas
# α bar [α_bar_1, α_bar_2, ... , α_bar_T] (length = T)
self.alpha_bars = torch.cumprod(self.alphas, dim=0)
def diffusion_process(self, x0, t=None):
if t is None:
t = torch.randint(low=1, high=self.T, size=(x0.shape[0],), device=self.device)
noise = torch.randn_like(x0, device=self.device)
alpha_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1)
xt = torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * noise
return xt, t, noise
👉 __ init __()
変換に用いるαやβは何度も使うことになるので、イニシャライズの段階で計算して持っておきます。
- self.betas : 各時刻で付加するノイズの強度$\beta_t$を入れている配列です
$$x_t = \sqrt{1-\beta_t} x_{t-1} + \sqrt{\beta_t} \epsilon$$
diffusionモデルの論文では次の値が使われており、本記事でも同様にこの値を使います。$\beta_1から\beta_T$までの間の値は$torch.linspace$で線形に補間します。
\displaylines{
\beta_1 = 1e-4\\
\beta_T = 0.02
}
- self.alphas : $\alpha_t = 1 - \beta_t$ で計算される$\alpha_t$の配列です
- self.alpha_bars : $\bar\alpha_t$の配列です
$x_{t-1}$の画像から$x_t$の画像は
$$x_t = \sqrt{1-\beta_t} x_{t-1} + \sqrt{\beta_t} \epsilon$$
こちらの式で計算できますが、$x_{t-1}$は$x_{t-2}$から計算できるので、頑張って計算すると、
各時刻tのノイズ入り画像は明瞭な画像$x_0$から
$$x_t = \sqrt{\bar\alpha_t}x_0 + \sqrt{1-\bar\alpha_t}\epsilon$$
で計算できます。ここで、
$$\bar\alpha_t= \prod_{i=1}^t(1-\beta_i) $$
この$\bar\alpha_t$の配列を準備しておくことで、各時刻tの画像$x_t$を簡単に計算できます。
ここで準備した配列のサイズはいずれも$(T, )$となっていますので注意が必要です。
計算に用いるときはreshapeします。
👉 diffusion_process()
明瞭な画像$x_0$から各時刻tのノイズ入り画像$x_t$を生成するメソッドです。
こちらを使って、diffusionモデルの学習データを生成します。
- 時刻tの生成
引数にtをとっていますが、学習の過程ではNoneを受け取りますので、メソッドの中で生成します。
tは[523, 320, 427, 541, 43, ...] のように1~Tまでのランダムな整数値で、配列の大きさは(バッチサイズ, )です。つまり、バッチの各データについてランダムに時間を決めています。
- $x_t$の生成
$x_0$と同じサイズのノイズ$\epsilon$を生成して前述のこちらの式で$x_t$を求めます
$$x_t = \sqrt{\bar\alpha_t}x_0 + \sqrt{1-\bar\alpha_t}\epsilon$$
$\bar\alpha$はテンソル計算ができるようにreshapeしています。ch、H, Wの次元はtorchがブロードキャストしてくれます。
- return
このメソッドの返り値は$x_t$, $t$, $\epsilon$です。
時刻tの画像およびその時刻tがインプットで、その時加えられたノイズを教師データとしてdiffusionモデルを学習します。
試しに動かしてみます
if __name__ == "__main__":
x0 = dataset[3]
# バッチ次元がないので、追加します
x0 = x0.unsqueeze(0)
ddpm = DDPM(1000, "cpu")
xt, t, noise = ddpm.diffusion_process(x0)
print(f"{xt.shape=}", f"\n{t.shape=}", f"\n{noise.shape=}")
# 以下は描画のための関数です
fig, ax = plt.subplots(1,3)
ax[0].imshow(1. - x0[0, 0, :, :], cmap="gray")
ax[1].imshow(1. - xt[0, 0, :, :], cmap="gray")
ax[2].imshow(1. - noise[0, 0, :, :], cmap="gray")
plt.show()
plt.clf()
plt.close()
xt, t, noiseのshapeを確認しておきましょう。
漢字は白黒なので、ch数は1です。
xt.shape=torch.Size([1, 1, 32, 32]) # (batch, ch, H, W)
t.shape=torch.Size([1]) #(batch)
noise.shape=torch.Size([1, 1, 32, 32]) # (batch, ch, H, W)
こちらの様な画像ができます。
左) 明瞭な漢字、 (中)ノイズが加えられた漢字、(右)ノイズ
これで学習データを生成する準備ができました。
3. 逆拡散過程
ノイズを取り除いて画像を生成する逆拡散過程のメソッドを作ります。
こちらは先ほど作ったDDPMクラスの中に作ります。
ノイズを予測するモデルはDDPMクラスの外側で作って、引数に渡すコードとしています。
DDPMクラスに持たせてもいいと思うのですが、何かとモデルにアクセスする機会が多いので、メソッドの引数で渡す形にしました。
def denoising_process(self, model, img, ts):
batch_size = img.shape[0]
model.eval()
with torch.no_grad():
time_step_bar = tqdm(reversed(range(1, ts)), leave=False, position=0)
for t in time_step_bar: # ts, ts-1, .... 3, 2, 1
# 整数値のtをテンソルに変換。テンソルのサイズは(バッチサイズ, )
time_tensor = (torch.ones(batch_size, device=self.device) * t).long()
# 現在の画像からノイズを予測
prediction_noise = model(img, time_tensor)
# 現在の画像からノイズを少し取り除く
img = self._calc_denoising_one_step(img, time_tensor, prediction_noise)
model.train()
# 0~255のデータに変換して返す
img = img.clamp(-1, 1)
img = (img + 1) / 2
img = (img * 255).type(torch.uint8)
return img
def _calc_denoising_one_step(self, img, time_tensor, prediction_noise):
beta = self.betas[time_tensor].reshape(-1, 1, 1, 1)
sqrt_alpha = torch.sqrt(self.alphas[time_tensor].reshape(-1, 1, 1, 1))
alpha_bar = self.alpha_bars[time_tensor].reshape(-1, 1, 1, 1)
sigma_t = torch.sqrt(beta)
noise = torch.randn_like(img, device=self.device) if time_tensor[0].item() > 1 else torch.zeros_like(img, device=self.device)
img = 1 / sqrt_alpha * (img - (beta / (torch.sqrt(1 - alpha_bar))) * prediction_noise) + sigma_t * noise
return img
👉 denoising_process
-
引数
- model : ノイズを予測する学習済みのモデル
- img : ノイズを除去する画像
- ts : 引数imgに渡した画像の時刻
imgとtsは通常、imgが完全なノイズ、tsは完全なノイズ状態となる画像の時刻Tです。
後の漢字生成のところで色々と試すために引数として受け取れる様にしています。
-
繰り返し処理
繰り返し処理で変数tに渡される値は、時刻Tから1までの各時刻の値です。
$x_t$からノイズを取り除いた画像$x_{t-1}$の画像を生成しますので、tはTから0でなく、Tから1までです。 -
デノイズ処理
モデルを使って、$x_t$の画像から、加えられているノイズ$\epsilon$を予測します。
そして、$x_t$と$\epsilon$から、少しだけ明瞭な画像$x_{t-1}$を生成します。
この部分は数式が長いので別メソッドにしてあります。
👉 _calc_denoising_one_step
こちらの式を計算しています。
Diffusionモデルの学習がうまくいかないとしばらく悩んでいましたが、この数式を打ち間違えておりました。
お気をつけください。
x_{t-1}=\frac{1}{\sqrt{1-\beta_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}})\epsilon_\theta + \sigma_tZ
4. 学習コード
ここまでで、
- 明瞭な画像からノイズが付加された不明瞭な画像を生成する拡散過程
- ノイズ予測ができるモデルを渡せば、不明瞭な画像からノイズを取り除く逆拡散過程
のメソッドを作りました。
あとは、ノイズを予測するモデルを作ってあげればdiffusionモデルが出来上がります。
モデルの学習は、回帰モデルの学習とほぼ同じになりますが、いくつか注意点を記しておきます。
-
ハイパーパラメータ
ハイパーパラメータはdataclassで定義して、学習関数ddpm_trainに渡しています。
params = HyperParameters() -
ネットワーク
画像から画像を予測するネットワークになります。
基本はUNetがベースを作ります。
本記事では、ネットワークの検証があまりできていないため、こちらのgithubで公開されているコードを利用しています。
https://github.com/tcapelle/Diffusion-Models-pytorch -
loss関数
loss関数はMSE Lossで良いことが数学的に導出できるようです -
教師データ
datasetは明瞭な画像のみを返しますので、拡散過程で
xt, t, noise = ddpm.diffusion_process(x)で、
ノイズ入り画像、時刻、ノイズを作成して、教師データとします。 -
途中経過の表示
params.img_save_stepsで定義したepoch毎に、画像を生成して保存します。
diffusionモデルによる画像生成はGANやVAEと比べて、生成に非常に時間がかかるためご注意ください。
例えば、T=1000とした場合、UNetを使って1000回ノイズを予測することになります。
# %%==========================================================================
# ddpm training
# ============================================================================
def ddpm_train(params):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"{device=}")
log_dir = make_save_dir_and_save_params(params)
model_path = os.path.join(log_dir, f"model_weight_on_{device}")
# 必要なモデルなどを生成
dataset = create_dataset(params.pix, params.font_file, exist_load=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=params.batch_size, shuffle=True, drop_last=True)
ddpm = DDPM(params.time_steps, device)
model = UNet(params.image_ch, params.image_ch).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=params.lr)
loss_fn = torch.nn.MSELoss()
start_epoch = 1
loss_logger = []
loss_min = 9e+9
# 継続して計算する場合はロードする
if params.load_file and os.path.exists(model_path):
model, optimizer, start_epoch, loss_logger, loss_min = load_checkpoint(params, model, optimizer, model_path, device)
# training
model.train()
epoch_bar = tqdm(range(start_epoch, params.epochs+1))
for epoch in epoch_bar:
epoch_bar.set_description(f"Epoch:{epoch}")
loss_tmp = 0
iter_bar = tqdm(dataloader, leave=False)
for iter, x in enumerate(iter_bar):
x = x.to(device)
# xにノイズを加えて学習データを作成する
xt, t, noise = ddpm.diffusion_process(x)
# モデルによる予測〜誤差逆伝播
out = model(xt, t)
loss = loss_fn(noise, out)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# lossの経過記録
iter_bar.set_postfix({"loss=": f"{loss.item():.2e}"})
loss_tmp += loss.item()
loss_logger.append(loss_tmp / (iter + 1))
epoch_bar.set_postfix({"loss=": f"{loss_logger[-1]:.2e}"})
# 保存処理
# lossの経過グラフを出力
save_loss_logger_and_graph(log_dir, loss_logger)
# lossが最小の場合は重みデータを保存
if loss_min >= loss_logger[-1]:
torch.save({'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss_logger,
}, model_path)
loss_min = loss_logger[-1]
# 指定したstepで逆拡散過程による画像生成
if epoch % params.img_save_steps == 0:
x0 = torch.randn([32, params.image_ch, params.pix, params.pix], device=device)
img = ddpm.denoising_process(model, x0, params.time_steps).to("cpu")
save_images_plt(img, log_dir=log_dir, epoch=epoch)
def load_checkpoint(params, model, optimizer, model_path, device):
print(f"load model {model_path}")
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss_logger = checkpoint["loss"]
loss_min = min(loss_logger)
print(start_epoch)
return model, optimizer, start_epoch, loss_logger, loss_min
def make_save_dir_and_save_params(params):
# タスクの保存フォルダ
log_dir = os.path.join(r"./", "log", params.task_name)
os.makedirs(log_dir, exist_ok=True)
# epoch, iter毎のデータは多くなるので別フォルダを作る
log_dir_hist = os.path.join(log_dir, "hist")
os.makedirs(log_dir_hist, exist_ok=True)
# 設定ファイルの保存
with open(os.path.join(log_dir, "parameters.json"), 'w') as f:
json.dump(vars(params), f, indent=4)
return log_dir
def save_loss_logger_and_graph(log_dir, loss_logger):
# loss履歴情報を保管しつつ、グラフにして画像としても書き出す
torch.save(loss_logger, os.path.join(log_dir, "loss_logger.pt"))
fig, ax = plt.subplots(1,1)
epoch = range(len(loss_logger))
ax.plot(epoch, loss_logger, label="train_loss")
ax.set_ylim(0, loss_logger[-1]*5)
ax.legend()
fig.savefig(os.path.join(log_dir, "loss_history.jpg"))
plt.clf()
plt.close()
def save_images_plt(images, log_dir, epoch, s=2):
# 生成した画像を並べた図を作成して保存する
num_img = images.shape[0]
img_arr = 255 - images.detach().numpy()
num_row = int(num_img ** 0.5)
num_col = (num_img - 1) // num_row + 1
fig, ax = plt.subplots(num_row, num_col,
figsize=(num_col*s, num_row*s),
tight_layout=True,
sharex=True, sharey=True )
axs = ax.ravel() if num_img > 1 else [ax]
for i, img in enumerate(img_arr):
if img.shape[0] == 1:
img = img[0, :, :]
axs[i].imshow(img, cmap="gray")
else:
img = np.transpose(img, [1, 2, 0])
axs[i].imshow(img)
if isinstance(epoch, int):
epoch = f"{epoch:06d}"
fig.suptitle(f"epoch={epoch}", size=15)
fig.savefig(os.path.join(log_dir, "hist", f"kanji_imgs_epoch_{epoch}.jpg"))
fig.savefig(os.path.join(log_dir, f"kanji_imgs_epoch_latest.jpg"))
plt.clf()
plt.close()
@dataclass
class HyperParameters:
task_name: str = "kanji_diffusion"
epochs: int = 500
img_save_steps: int = 10
batch_size: int = 128
lr: float = 3e-4
time_steps: int = 1000 # T もう少し小さくても良いはず
load_file: bool = True
pix: int = 32
font_file: str = r"./ヒラギノ角ゴシック W5.ttc"
image_ch: int = 1
5. 結果
では、学習を実行します。
params = HyperParameters()
ddpm_train(params)
loss履歴
下の図は横軸がepoch, 縦軸がlossのグラフです。
徐々にlossが減少し学習が進んでいることがわかります。
epoch10
白のキャンバスに何か書いたような図が生成されています。
しかし、漢字には見えません。
epoch50
だいぶ漢字っぽくなってきましたが、無駄な線が多い印象を受けます。
epoch210
lossの減少がだいぶサチってきたepoch数です。
ほとんどの画像が"漢字っぽい"画像となっています。
6. 漢字生成
上記の漢字は完全なノイズから漢字っぽい画像を生成しました。
続いて、ベースの画像に途中までノイズを入れて、その画像からノイズを取り除いてみたいと思います。
Diffusionモデルを使って新しい画像を生成する場合、通常完全なノイズ画像(下図一番左)から徐々にノイズを除去します。ここで、例えば下の例では「七」がうっすらと見える状態でノイズ負荷を止めて、その画像からノイズを除去して画像を生成します。
「三」 ベース画像
まずは、「三」をベースにしてみます。
ここで表示している9つの図は全て同じ図です。
「三」 t=200
これにノイズを加えたt=200の画像がこちらです。
ランダムにノイズを加えているので、9つの画像は少しずつ異なります。
「三」 t=200からの生成
この画像をスタートとして、ノイズを除去した画像がこちらです。
漢字の「三」の特徴を引き継いだ画像が生成されています。
逆拡散過程も確率的な操作になりますので、ベースにするノイズ入り画像に同じものを用いても異なる結果になります。
もう少し、ノイズを多く加えてから画像を生成してみます。
「三」 t=250からの生成
「三」 t=300からの生成
「三」 t=400からの生成
t=400程度までノイズを加えると、ほとんど「三」の特徴がなくなりました
「口」 t=325からの生成
「国」 t=375からの生成
「宇」 t=350からの生成
「国」 の内側だけ t=300
「国」という漢字の"くにがまえ"を残して、内側だけにノイズを付与してこちらの画像を作ります。
(ノイズを加えると0~1の外側の値が含まれるため、グレースケールで表示すると背景色やくにがまえ部分の色が変わってしまいますが、配列として持っている値は変えていません)
この画像から生成するとくにがまえをもつ漢字が生成されやすくなります。
しかし、約半数が「門」や「冂」に変化している点が面白いです。
「宀」 うかんむり
うかんむりでも同様にやってみます
半数以上は宀ではなくなってしまいましたが、一部宀の漢字を生成できました。
所感
拡散、逆拡散過程のコードがちょっと面倒ですが、回帰と同様の簡単な学習コードでかなり安定して学習することができます。
生成に時間を要するものの、簡単なコードで明瞭な画像を生成でき驚きました。
次はネットワークの理解と条件付き生成に挑戦したいと思います。
参考文献
- シンプルなDiffusionモデルの実装です
最初は自分でコードを書いていたのですが、うまく学習が進まずこちらを参考にしました。
- こちらもシンプルなDiffusionモデルの実装解説です
- SonyさんによるDiffusionモデルの解説動画です
最初は難しくて理解できなかったのですが、何度も見て勉強させて貰いました。
基礎知識があれば、非常に分かりやすい動画だと思います。
- 丁寧に数式からDiffusionモデルを解説されています
- Diffusionモデルの論文です