1. はじめに
GAN(敵対的生成ネットワーク)を用いて、グレースケール画像に自動彩色をしていきたいと思います。
技術的には「pix2pix」と呼ぶらしいです。
このグレースケール画像が
自動で以下のように彩色できました!!
ところどころ、おかしなところもあったり、うまくいかない画像もあったりしますが、結構自然な着色になっています。
ちなみに元画像を一番下の段だけ示すと、
こんな感じ。電車やベッドの色は違ったりしますが、全体的には同じような色合いで塗れている気がします。
2.ざっくりとした今回の学習のイメージ
ざっくりとした今回の学習のイメージは以下のようになります。
GANなのでGenerator, Discriminatorの2つのネットワークを用いてます。
(1)
このようにGeneratorとDiscriminator、 2つのネットワークを交互にだまし合うように学習させます。
3.学習のネットワークについて
今回はpytorch 1.1, torchvision 0.30を用いてます。とりあえず、使用するライブラリをimport
import glob
import os
import pickle
import torch
import torch.nn.functional as F
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np #1.16.4
import matplotlib.pyplot as plt
from PIL import Image
from torch import nn
from skimage import io
環境は
windows10, Anaconda1.9.7,
core-i3 8100, RAN 16.0 GB
GEFORCE GTX 1060
結構、学習時間がかかるのでGPU推奨です。
3-1.Generator
Generatorにはセマンティックセグメンテーションに用いられているU-netを用いています。
Encoder-Decoderのネットワークで、入力画像と同じ形状の出力画像を得ることができます。
入力画像はGray画像,出力画像はカラー画像(Fake画像)です。
このU-netの特徴はCopy and Cropの部分になります。
入力層に近い出力を出力層に近い層にも加えて、元画像の形を崩さないようにする工夫(らしい)です。
このCopy and Cropをpytorchで実現するのは結構簡単で、
・torch.catを用いて、入力を結合する。
・Conv2dやBatchNorm2dの入力のチャンネル数を倍にする。
だけです。初めて見たときは結構感動しました。
ただしtorch.catで結合するtensorの形状を合わせておく必要があります。
このU-netをそのまま用いると、かなり巨大なネットワークになります。
(CNNが18個ぐらいあるように見える)
そのため、ネットワークを小さくして、入出力画像のサイズも3×128×128まで小さくします。
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2)
self.bn1 = nn.BatchNorm2d(32)
self.av2 = nn.AvgPool2d(kernel_size=4)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.av3 = nn.AvgPool2d(kernel_size=2)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2d(128)
self.av4 = nn.AvgPool2d(kernel_size=2)
self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
self.bn4 = nn.BatchNorm2d(256)
self.av5 = nn.AvgPool2d(kernel_size=2)
self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.bn5 = nn.BatchNorm2d(256)
self.un6 = nn.UpsamplingNearest2d(scale_factor=2)
self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.bn6 = nn.BatchNorm2d(256)
#conv7にはconv6の出力とconv4の出力を流す, input channelが2倍
self.un7 = nn.UpsamplingNearest2d(scale_factor=2)
self.conv7 = nn.Conv2d(256 * 2, 128, kernel_size=3, stride=1, padding=1)
self.bn7 = nn.BatchNorm2d(128)
#conv8にはconv7の出力とconv3の出力を流す, input channelが2倍
self.un8 = nn.UpsamplingNearest2d(scale_factor=2)
self.conv8 = nn.Conv2d(128 * 2, 64, kernel_size=3, stride=1, padding=1)
self.bn8 = nn.BatchNorm2d(64)
#conv9にはconv8の出力とconv2の出力を流す, input channelが2倍
self.un9 = nn.UpsamplingNearest2d(scale_factor=4)
self.conv9 = nn.Conv2d(64 * 2, 32, kernel_size=3, stride=1, padding=1)
self.bn9 = nn.BatchNorm2d(32)
self.conv10 = nn.Conv2d(32 * 2, 3, kernel_size=5, stride=1, padding=2)
self.tanh = nn.Tanh()
def forward(self, x):
#x1-x4はtorch.catする必要があるので,残しておく
x1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
x2 = F.relu(self.bn2(self.conv2(self.av2(x1))), inplace=True)
x3 = F.relu(self.bn3(self.conv3(self.av3(x2))), inplace=True)
x4 = F.relu(self.bn4(self.conv4(self.av4(x3))), inplace=True)
x = F.relu(self.bn5(self.conv5(self.av5(x4))), inplace=True)
x = F.relu(self.bn6(self.conv6(self.un6(x))), inplace=True)
x = torch.cat([x, x4], dim=1)
x = F.relu(self.bn7(self.conv7(self.un7(x))), inplace=True)
x = torch.cat([x, x3], dim=1)
x = F.relu(self.bn8(self.conv8(self.un8(x))), inplace=True)
x = torch.cat([x, x2], dim=1)
x = F.relu(self.bn9(self.conv9(self.un9(x))), inplace=True)
x = torch.cat([x, x1], dim=1)
x = self.tanh(self.conv10(x))
return x
3-2.Discriminator
Discriminatorは普通の画像識別ネットワークに近い構成です。
ただし、出力は1次元ではなく、n×n個の数字になります。
この分割された領域毎にTrue or Falseを出力します。下の画像の場合は4×4ですね。
後は活性化関数にGANの定番のLeakly Relu、
BatchNorm2dの代わりにInstanceNorm2dを用いています。
InstanceNorm2dとBatchNorm2d,両方試しましたが、実はあまり結果に差を感じられませんでした。
Pix2PixではInstanceNorm2dがよいときいたので、今回はこちらを採用しています。
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2)
self.in1 = nn.InstanceNorm2d(16)
self.av2 = nn.AvgPool2d(kernel_size=2)
self.conv2_1 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.in2_1 = nn.InstanceNorm2d(32)
self.conv2_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
self.in2_2 = nn.InstanceNorm2d(32)
self.av3 = nn.AvgPool2d(kernel_size=2)
self.conv3_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.in3_1 = nn.InstanceNorm2d(64)
self.conv3_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
self.in3_2 = nn.InstanceNorm2d(64)
self.av4 = nn.AvgPool2d(kernel_size=2)
self.conv4_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.in4_1 = nn.InstanceNorm2d(128)
self.conv4_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
self.in4_2 = nn.InstanceNorm2d(128)
self.av5 = nn.AvgPool2d(kernel_size=2)
self.conv5_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
self.in5_1 = nn.InstanceNorm2d(256)
self.conv5_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.in5_2 = nn.InstanceNorm2d(256)
self.av6 = nn.AvgPool2d(kernel_size=2)
self.conv6 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
self.in6 = nn.InstanceNorm2d(512)
self.conv7 = nn.Conv2d(512, 1, kernel_size=1)
def forward(self, x):
x = F.leaky_relu(self.in1(self.conv1(x)), 0.2, inplace=True)
x = F.leaky_relu(self.in2_1(self.conv2_1(self.av2(x))), 0.2, inplace=True)
x = F.leaky_relu(self.in2_2(self.conv2_2(x)), 0.2, inplace=True)
x = F.leaky_relu(self.in3_1(self.conv3_1(self.av3(x))), 0.2, inplace=True)
x = F.leaky_relu(self.in3_2(self.conv3_2(x)), 0.2, inplace=True)
x = F.leaky_relu(self.in4_1(self.conv4_1(self.av4(x))), 0.2, inplace=True)
x = F.leaky_relu(self.in4_2(self.conv4_2(x)), 0.2, inplace=True)
x = F.leaky_relu(self.in5_1(self.conv5_1(self.av5(x))), 0.2, inplace=True)
x = F.leaky_relu(self.in5_2(self.conv5_2(x)), 0.2, inplace=True)
x = F.leaky_relu(self.in6(self.conv6(self.av6(x))), 0.2, inplace=True)
x = self.conv7(x)
return x
3-3.確認
torch.randnを用いて擬似的な画像を生成し、
Generator, Discriminatorの出力サイズを確認します。
ここでは3×128×128のサイズの画像を2枚生成、Generator、Discriminatorの2つに入力しています。
g, d = Generator(), Discriminator()
#乱数による疑似画像
test_imgs = torch.randn([2, 3, 128, 128])
test_imgs = g(test_imgs)
test_res = d(test_imgs)
print("Generator_output", test_imgs.size())
print("Discriminator_output",test_res.size())
出力は以下のようになりました。
Generator_output torch.Size([2, 3, 128, 128])
Discriminator_output torch.Size([2, 1, 4, 4])
Generatorのアウトプットサイズが、入力と同じです。
Discriminatorのアウトプットサイズが 4×4になっています。
4. データローダーについて
b.の部分のデータ拡張
class DataAugment():
#PIL imageをデータオーギュメンテーション, PILをreturn
def __init__(self, resize):
self.data_transform = transforms.Compose([
transforms.RandomResizedCrop(resize, scale=(0.9, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip()])
def __call__(self, img):
return self.data_transform(img)
dのtensorに変換する部分では,データの正規化も同時に行います。
class ImgTransform():
#PILのimageをresize,正規化してtensorをreturn
def __init__(self, resize, mean, std):
self.data_transform = transforms.Compose([
transforms.Resize(resize),
transforms.ToTensor(),
transforms.Normalize(mean, std)])
def __call__(self, img):
return self.data_transform(img)
PytorchのDatasetクラスを継承をしたクラスで、a-dまでのフローは getitemの場所に書きます。
getitemの部分に1画像の入力と出力のフローを作ることで、簡単にデータローダーを作成できます。
class MonoColorDataset(data.Dataset):
"""
PytorchのDatasetクラスを継承
"""
def __init__(self, file_list, transform_tensor, augment=None):
self.file_list = file_list
self.augment = augment #PIL to PIL
self.transform_tensor = transform_tensor #PIL to Tensor
def __len__(self):
return len(self.file_list)
def __getitem__(self, index):
#index番号のファイルパスを取得
img_path = self.file_list[index]
img = Image.open(img_path)
img = img.convert("RGB")
if self.augment is not None:
img = self.augment(img)
#モノクロ画像用のコピー
img_gray = img.copy()
#カラー画像をモノクロ画像に変換
img_gray = transforms.functional.to_grayscale(img_gray,
num_output_channels=3)
#PILをtensorに変換
img = self.transform_tensor(img)
img_gray = self.transform_tensor(img_gray)
return img, img_gray
augment=Noneとすることで、データ拡張をしない、すなわちテストデータ用のデータセットになります。
データローダを作る関数は以下のようにしました。
def load_train_dataloader(file_path, batch_size):
"""
Input
file_path 取得したい画像のファイルパスのリスト
batch_size データローダのバッチサイズ
return
train_loader, RGB_images and Gray_images
"""
size = 128 #画像の1辺のサイズ
mean = (0.5, 0.5, 0.5) #画像の正規化した際のチャンネル毎の平均値
std = (0.5, 0.5, 0.5) #画像の正規化した際のチャンネル毎の標準偏差
#データセット
train_dataset = MonoColorDataset(file_path_train,
transform=ImgTransform(size, mean, std),
augment=DataAugment(size))
#データローダー
train_dataloader = data.DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True)
return train_dataloader
5.可視化方法
5.1 可視化する関数
複数の画像をタイル状に並べるには「torchvision.utils.make_grid」を使うと便利です。
tensorでタイル状の画像を生成した後にnumpyに変換して、matplotlibで描画します。
def mat_grid_imgs(imgs, nrow, save_path = None):
"""
pytorchのtensor(imgs)をタイル状に描画する関数
nrowでタイルの1辺の数を決定
"""
imgs = torchvision.utils.make_grid(
imgs[0:(nrow**2), :, :, :], nrow=nrow, padding=5)
imgs = imgs.numpy().transpose([1,2,0])
imgs -= np.min(imgs) #最小値を0
imgs /= np.max(imgs) #最大値を1
plt.imshow(imgs)
plt.xticks([])
plt.yticks([])
plt.show()
if save_path is not None:
io.imsave(save_path, imgs)
テスト画像をロードして、gray画像とfake画像をタイル状に描画する関数です。
def evaluate_test(file_path_test, model_G, device="cuda:0", nrow=4):
"""
test画像をロード,gray画像とfake画像をタイル状に描画
"""
model_G = model_G.to(device)
size = 128
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
test_dataset = MonoColorDataset(file_path_test,
transform=ImgTransform(size, mean, std),
augment=None)
test_dataloader = data.DataLoader(test_dataset,
batch_size=nrow**2,
shuffle=False)
#データローダーごとに画像を描画
for img, img_gray in test_dataloader:
mat_grid_imgs(img_gray, nrow=nrow)
img = img.to(device)
img_gray = img_gray.to(device)
#img_grayからGeneratorを用いて,FakeのRGB画像
img_fake = model_G(img_gray)
img_fake = img_fake.to("cpu")
img_fake = img_fake.detach()
mat_grid_imgs(img_fake, nrow=nrow)
5.2 可視化結果(学習前)
g = Generator()
file_path_test = glob.glob("test/*")
evaluate_test(file_path_test, g)
6.学習データの取得方法
今回はとりあえず、大量の画像データを集めればよいということで、COCO2014, PASCAL Voc2007, Labeled Faces in the Wild etc.をちゃんぽんで入力しています。
これらのデータにはGray画像が結構な割合で含まれています。今回は白黒画像をカラーにしたいのに、お手本となるべき画像がGray画像では示しがつきません(?)。なので、Gray画像は除去したいと思います。
Gray画像の場合,R channelとG channelとB channelの色が等しいはずなので、それを利用して除去したいと思います。
同時に白すぎる画像、暗すぎる画像、あまり色の濃淡がない画像(標準偏差が小さい)画像も抜きました。
from skimage import io, color, transform
def color_mono(image, threshold=150):
#3chnnelの入力画像がカラーか否かを判別
#thresholdを大きく設定すると微妙にカラーが混じっている写真もMonoに設定できる
image_size = image.shape[0] * image.shape[1]
#channelの組み合わせは(0, 1),(0, 2),(1, 2)の3通り,チャネル毎の差分を見る
diff = np.abs(np.sum(image[:,:, 0] - image[:,:, 1])) / image_size
diff += np.abs(np.sum(image[:,:, 0] - image[:,:, 2])) / image_size
diff += np.abs(np.sum(image[:,:, 1] - image[:,:, 2])) / image_size
if diff > threshold:
return "color"
else:
return "mono"
def bright_check(image, ave_thres = 0.15, std_thres = 0.1):
try:
#明るすぎる画像,暗すぎる画像,同じような明るさばかりの画像 False
#白黒に変換
image = color.rgb2gray(image)
if image.shape[0] < 144:
return False
#明るすぎる画像の場合
if np.average(image) > (1.-ave_thres):
return False
#暗すぎる画像の場合
if np.average(image) < ave_thres:
return False
#同じような明るさばかりの場合
if np.std(image) < std_thres:
return False
return True
except:
return False
paths = glob.glob("./test2014/*")
for i, path in enumerate(paths):
image = io.imread(path)
save_name = "./trans\\mscoco_" + str(i) +".png"
x = image.shape[0] #x軸方向のピクセル数
y = image.shape[1] #y軸方向のピクセル数
try:
#xとy軸の内、短い方の1/2
clip_half = min(x, y)/2
#画像の正方形の切り出し
image = image[int(x/2 -clip_half): int(x/2 + clip_half),
int(y/2 -clip_half): int(y/2 + clip_half), :]
if color_mono(image) == "color":
if bright_check(image):
image = transform.resize(image, (144, 144, 3),
anti_aliasing = True)
image = np.uint8(image*255)
io.imsave(save_name, image)
except:
pass
正方形に画像を切り取って全部、一つのフォルダに画像をいれました。
データ拡張できるように128×128でなく,144×144の画像になっています。
これで大体okなのですが、なぜか除去漏れやセピア色の画像などもあったりしたので、それは手動で削除しました。
大体11万枚画像を「trans」のフォルダに突っ込みました。
globを用いて、画像のパスのリストを作成して、ロードします。
7.学習
7.1 学習の関数
学習は大体1 epochが20分ぐらいかかりました。
Generatorの学習,Discriminatorの学習の両方させているので、コードが長くなっています。
注意点はlossを計算するためのラベルで、先ほど4.の確認でDiscriminatorのアウトプットのサイズが
[batch_size, 1, 4, 4]になることを確認しましたので、それに合わせて
true_labelsと false_labelsを生成します。
def train(model_G, model_D, epoch, epoch_plus):
device = "cuda:0"
batch_size = 32
model_G = model_G.to(device)
model_D = model_D.to(device)
params_G = torch.optim.Adam(model_G.parameters(),
lr=0.0002, betas=(0.5, 0.999))
params_D = torch.optim.Adam(model_D.parameters(),
lr=0.0002, betas=(0.5, 0.999))
#lossを計算するためのラベル, Discriminatorのsizeに注意
true_labels = torch.ones(batch_size, 1, 4, 4).to(device) #True
false_labels = torch.zeros(batch_size, 1, 4, 4).to(device) #False
#loss_function
bce_loss = nn.BCEWithLogitsLoss()
mae_loss = nn.L1Loss()
#エラーの推移を記録
log_loss_G_sum, log_loss_G_bce, log_loss_G_mae = list(), list(), list()
log_loss_D = list()
for i in range(epoch):
#temporaryのエラーを記録
loss_G_sum, loss_G_bce, loss_G_mae = list(), list(), list()
loss_D = list()
train_dataloader = load_train_dataloader(file_path_train, batch_size)
for real_color, input_gray in train_dataloader:
batch_len = len(real_color)
real_color = real_color.to(device)
input_gray = input_gray.to(device)
#Generatorの訓練
#偽のカラー画像を生成
fake_color = model_G(input_gray)
#偽画像を一時保存
fake_color_tensor = fake_color.detach()
# 偽画像を本物と騙せるようにロスを計算
LAMBD = 100.0 # BCEとMAEの係数
#fake画像を識別器に入れたときのout, Dは0に近づけようとする.
out = model_D(fake_color)
#Dの出力に対するLoss, Gを本物に近づけたいのでtargetはtrue_labels
loss_G_bce_tmp = bce_loss(out, true_labels[:batch_len])
#Gの出力に対するLoss
loss_G_mae_tmp = LAMBD * mae_loss(fake_color, real_color)
loss_G_sum_tmp = loss_G_bce_tmp + loss_G_mae_tmp
loss_G_bce.append(loss_G_bce_tmp.item())
loss_G_mae.append(loss_G_mae_tmp.item())
loss_G_sum.append(loss_G_sum_tmp.item())
#勾配を計算,Gの重みの更新
params_D.zero_grad()
params_G.zero_grad()
loss_G_sum_tmp.backward()
params_G.step()
#Discriminatorの訓練
real_out = model_D(real_color)
fake_out = model_D(fake_color_tensor)
#損失関数の計算
loss_D_real = bce_loss(real_out, true_labels[:batch_len])
loss_D_fake = bce_loss(fake_out, false_labels[:batch_len])
loss_D_tmp = loss_D_real + loss_D_fake
loss_D.append(loss_D_tmp.item())
#勾配を計算,Dの重みの更新
params_D.zero_grad()
params_G.zero_grad()
loss_D_tmp.backward()
params_D.step()
i = i + epoch_plus
print(i, "loss_G", np.mean(loss_G_sum), "loss_D", np.mean(loss_D))
log_loss_G_sum.append(np.mean(loss_G_sum))
log_loss_G_bce.append(np.mean(loss_G_bce))
log_loss_G_mae.append(np.mean(loss_G_mae))
log_loss_D.append(np.mean(loss_D))
file_path_test = glob.glob("test/*")
evaluate_test(file_path_test, model_G, device)
return model_G, model_D, [log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D]
学習を実行します。
file_path_train = glob.glob("trans/*")
model_G = Generator()
model_D = Discriminator()
model_G, model_D, logs = train(model_G, model_D, 40)
## 7.2 学習結果 学習データのLossはこんな感じです。 ![loss.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/141993/59885ca0-eb16-7c5b-4e6c-c413ed27e49c.png)
あれ?
飛行機画像が全く塗れていない以外、結構良い感じ??
案外、2 epoch終了後の画像が良いような気がしてきた…
他の画像も載せてみます。
11 epoch終了後です。失敗気味の画像を多めに選んでいます。
ひどい画像は本当にひどくて、色がほとんど塗れていないとか、
野球の画像みたいに、境界線無視で塗っていたりします。
草とか木の緑系、空とかの青系は得意な気がします。
これは元のデータセットの偏りや、塗りやすさ(認識しやすさ)に依存してそうです。
8.マトメ、感想
pix2pixを用いてGray画像のカラー化を行いました。
今回は何でもかんでも手当たり次第、画像を入れてみてカラー画像を作るということをしましたが、
さすがにネットワークが浅い分、表現力が低いので、
画像の種類を絞りこんだほうが上手くいくような気がします。
参考文献
正直、こちらのほう自分が書いたものより、分かりやすくまとまっている気もします。
U-Net: Convolutional Networks for Biomedical Image Segmentation
https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/
pix2pixを1から実装して白黒画像をカラー化してみた(PyTorch)
https://blog.shikoan.com/pytorch_pix2pix_colorization/
pix2pixを理解したい
https://qiita.com/mine820/items/36ffc3c0aea0b98027fd
画像
CoCo https://cocodataset.org/#home
Labeled Faces in the Wild http://vis-www.cs.umass.edu/lfw/
The PASCAL Visual Object Classes Homepage http://host.robots.ox.ac.uk/pascal/VOC/