Generative Dog Images
Generative Dog Images
kaggleのGANのコンペに参加したので、自分自身の振り返りも兼ねてまとめてみます。
GANの詳しい説明等は他の人の記事を御覧ください。
GANについて概念から実装まで ~DCGANによるキルミーベイベー生成~
今さら聞けないGAN(1) 基本構造の理解
また、僕自身機械学習初学者でkaggleが強いわけでは無く、残り2週間でのスプリント参加だったのでしっかりしたGANコンペの記事を見たい方は下の方がとても楽しくまとめてくれています。
参加動機
- 画像処理の勉強
- フレームワーク(PyTorch)への慣れの為
コンペ概要
- kaggleにとっておそらく初となるGANコンペ
- 期間2019年6月28日~8月14日
- 20579枚の犬の画像
- 120の犬種と犬のBounding Boxのannotationデータ
- kernel onlyコンペ
- 外部データ使用禁止
- 9時間以下の実行時間等
- 評価指標(MiFID)
- 基本的に画像の質と多様性を評価する。(Inception Score)
- 通常GANで使われる指標のFIDは実画像と生成画像の分布間の距離をFrechet距離にて測る。(低ければ良い)
- FIDでは配布された画像をそのまま提出するととてもいいスコアが出てしまうので、実画像と近すぎる画像の場合にはペナルティが設定される。普通にやればペナルティは無いので、FIDと同じ。
- Publicでは公開されている画像と、inceptionのモデルによってスコアが算出される。
- Privateでは未知のデータと未知のモデルによってスコアが算出される。
FIDの詳しい説明はこちら
Progressive/Big/StyleGANsの概要とGANsの性能評価尺度
カーネルコンペで外部データも使用不可能ということもあり、大きいGANは学習が難しい
時間制限のある中で、いかに効率よく学習させるかがポイント。
実際の画像
大変だった部分
D(Discriminator)とG(Generator)の学習のバランス
GANの学習はDとGがお互いに上手く騙し騙されあっていると、進んでいるとのことですがその具合が難しいです。
よくあるのはDが早くに学習が進んで、Gがいくら頑張っても完全に見破られてしまいます。
Dのlossがほぼ0になって、その後Gにいくら勾配が流れても意味のある画像を出力してくれません。
途中まで頑張っていたGもこうなるとやる気が無くなって、タイルみたいな画像しか出力しなくなります。
↓
↓
※このモデルは最終的には色々な模様のタイルを生成することが出来るようになりました。
パラメータ多すぎ&敏感過ぎ問題
GANはとてもパラメータに敏感で、軽い変更を加えただけで学習のバランスが崩れたり、画像の質に影響を及ぼします。
その癖に調整が必要なパラメータが多いです。
画像タスクに関わるパラメータ全ての他に、Gの初期値(z)の数や、その他GAN特有のテクニック(ノイジーラベル等)等キリがありません。
基本的には画像タスクと共通ですが、DとGをそれぞれ面倒見てあげなきゃならないし、かなり敏感でわがままです。
そして論文等でもなぜ上手くいったかはわからず経験的に上手くいったからこのパラメータという事が多いです。
評価指標(MiFID)が意味わからん問題
これもmodelのlossが下がればスコア(MiFID)が改善していく訳では無いので、どこをどう改善したらスコアにどんな影響があるのかが分かりづらいです。
その癖スコアの算出には5分程度かかり頻繁に算出出来ないので、検証しづらい。
下に24枚ずつGが作った画像がありますが、片方はMiFIDが80前後で、もう片方は50前後です。
私の目で見た限りではどちらも同じ様な画像に見えます。。。
このスコア(FID)は多様性に敏感との議論もあるので、質よりも多様性で差がついているかもしれないです。
前半の一週間でやったこと
ひたすらGANの勉強。
GANの知識はおろか画像分類の知識が皆無だったので、色々な論文や記事を調査。
同時にkaggleのカーネルを動かしたり、調べた手法を付け足したりした。
GANの基本的な知識はなんとなく身についたが、結局何をどうすればスコアが改善するのか見当がつかない。
試した手法
- LSGAN
- RaLSGAN
- Hingeloss
- Conditional GAN
- SAGAN
- SNGAN
- Adabound
- 2次元resize((64,64))
色々試したけど、使いこなせなかった。
犬っぽい画像を作ること自体がかなり難しく、上記のタイルや大量の虫みたいな画像を出力したりで全然うまく進みませんでした。
大量の虫
下の画像でようやく舌を出した可愛いワンちゃんが出来たと喜んでました。
後半はひたすら可視化と考察
- 可視化に重みを置いてカーネルを再作成
スコアとその他数値との関係性をより深く理解するため、ベンチマークのカーネルをなるべく丁寧に作り直しました。
具体的にはMiFID,DとGのloss,Dの出力(D(x), D(G(z)))を推移を記録して可視化
- loss関数の統一
GANには様々なloss関数が提唱されているが、ここを色々動かしてしまうと出力の意味が大きく変わってしまい手法の比較・検証がしづらい為、思い切って一つに絞り込みました。
具体的にはDの出力はsigmoidで、lossはBCEが一番安定的かつ、理解もしやすかったので統一。
- 変更箇所は最小限に
パラメータや手法を正確に理解する為に、変更箇所は1,2箇所に止めてその変更がどこにどの様に効いているのかを観察した。
コンペ終盤は特に焦りから色々な手法を同時に試しがちだが、心を鬼にして**「モデル向上ではなく、考察が目的」**を意識して丁寧に検証していく。
- 自分で理解出来るモデルのみに絞る
色んな手法がありあれもこれもと目移りしてしまうが、結局使いこなせないと調整も出来ないのでここも我慢。
それぞれこんな感じでひたすら可視化と考察を繰り返して、パラメータ・手法の理解を深めました。
2人チームで参加していたので、可視化したものをスラックで共有してそのまま考察してました。
ちなみにkaggleでは「Draft Session」(手元で編集しているデータ)と「Committed Session」(commit終了まで出力は見れない)をそれぞれ4つずつ同時に動かすことが出来るので、
- 長時間の検証(8時間程度)は「Committed Session」
- 短時間の検証(2時間程度)は「Draft Session」
と分けて、それぞれ微妙にパラメータを変えたものを4つずつ同時に回していました。
※最近kaggleの仕様も微妙に変わりこんな感じでGPUの使用状況等確認出来るみたいです。

※更に仕様変更があり、Draft Sessionは1つ、Committed Sessionは2つまでとなったみたいです。。。
今はkaggle側でもkernelの開発notebookの開発に力を入れていて、コロコロ変わるのでまた変更があるかもしれないです。
結果
最終日になってもその時点でのベストカーネル(スコア60)より低い数字が出ずに諦めモード。
今まで考察した内容をフル稼働させて、最後の最後に回したモデルが終了10分前くらいで回し終わり結構犬っぽい画像を出力していたので、ブザービーター的な感じギリギリ逆転劇あるかと思い、チームで祈りながらsubmit。
結果は90と惨敗。
最終的には何もわからないという結果となりました笑
※現時点(2019/08/23)ではまだ最終的な結果(privateスコア)は出ていないのでわかりませんが。
所感
- 「ベストスコアのカーネルを適当に色々イジってればbronzeくらいは取れるかな」とか甘い考えだったけど、全くそんなことはなく小手先の技術は通用しなかった気がする。
- 他人のカーネルを使うときは、とにかくリファクタリングをしっかりすることが重要。
- ベンチマークをいかに早く丁寧に作ることが、その後の試行錯誤やスコアにとって重要。
- NNはただでさえブラックボックス的な部分も多いので、より丁寧な可視化と考察が必要
- ベンチマークを決めた後は、信じてコロコロ変えない。
- 最終的には画像処理も、PyTorchにも慣れることが出来たし、とても楽しかったのでまあ良し
最近はkaggleもカーネルのコンペが増えてきているので、マシンスペックとか関係なしにフェアに挑戦出来て、とても勉強になるのでカーネルコンペは特におすすめです!!
Generatorのコード
class Generator(nn.Module):
def __init__(self, nz, nfeats, nchannels, num_classes=120):
super(Generator, self).__init__()
if use_label:
self.label_emb = nn.Embedding(num_classes, nz)
self.conv1 = spectral_norm(nn.ConvTranspose2d(nz*2, nfeats * 8, 4, 1, 0, bias=False))
self.nz = nz
else:
self.conv1 = spectral_norm(nn.ConvTranspose2d(nz, nfeats * 8, 4, 1, 0, bias=False))
# self.bn1 = nn.BatchNorm2d(nfeats * 8)
# state size. (nfeats*8) x 4 x 4
self.conv2 = spectral_norm(nn.ConvTranspose2d(nfeats * 8, nfeats * 8, 4, 2, 1, bias=False))
#self.bn2 = nn.BatchNorm2d(nfeats * 8)
# state size. (nfeats*8) x 8 x 8
self.conv3 = spectral_norm(nn.ConvTranspose2d(nfeats * 8, nfeats * 4, 4, 2, 1, bias=False))
#self.bn3 = nn.BatchNorm2d(nfeats * 4)
# state size. (nfeats*4) x 16 x 16
self.conv4 = spectral_norm(nn.ConvTranspose2d(nfeats * 4, nfeats * 2, 4, 2, 1, bias=False))
#self.bn4 = nn.BatchNorm2d(nfeats * 2)
# state size. (nfeats * 2) x 32 x 32
self.conv5 = spectral_norm(nn.ConvTranspose2d(nfeats * 2, nfeats, 4, 2, 1, bias=False))
#self.bn5 = nn.BatchNorm2d(nfeats)
# state size. (nfeats) x 64 x 64
if use_attention:
self.self_attn_book = Self_Attn(nfeats)
self.conv6 = spectral_norm(nn.ConvTranspose2d(nfeats, nchannels, 3, 1, 1, bias=False))
# state size. (nchannels) x 64 x 64
self.pixnorm = PixelwiseNorm()
def forward(self, inputs):
if use_label:
z, labels = inputs
enc = self.label_emb(labels).view((-1, self.nz, 1, 1))
enc = F.normalize(enc, p=2, dim=1)
x = torch.cat((z, enc), 1)
x = F.leaky_relu(self.conv1(x))
else:
x = F.leaky_relu(self.conv1(inputs))
x = F.leaky_relu(self.conv2(x))
x = self.pixnorm(x)
x = F.leaky_relu(self.conv3(x))
x = self.pixnorm(x)
x = F.leaky_relu(self.conv4(x))
x = self.pixnorm(x)
x = F.leaky_relu(self.conv5(x))
if use_attention:
x = self.self_attn_book(x)
x = self.pixnorm(x)
x = torch.tanh(self.conv6(x))
return x
Discriminatorのコード
class Discriminator(nn.Module):
def __init__(self, nchannels, nfeats, loss_calc, num_classes=120):
super(Discriminator, self).__init__()
if loss_calc == 'bce':
self.use_sigmoid = True
else:
self.use_sigmoid = False
if use_label:
self.num_classes = num_classes
self.label_emb = nn.Embedding(num_classes, 64*64)
self.conv1 = nn.Conv2d(nchannels + 1, nfeats, 4, 2, 1, bias=False)
else:
# input is (nchannels) x 64 x 64
self.conv1 = nn.Conv2d(nchannels, nfeats, 4, 2, 1, bias=False)
# state size. (nfeats) x 32 x 32
if use_attention:
self.self_attn_book = Self_Attn(nfeats)
self.conv2 = spectral_norm(nn.Conv2d(nfeats, nfeats * 2, 4, 2, 1, bias=False))
self.bn2 = nn.BatchNorm2d(nfeats * 2)
# state size. (nfeats*2) x 16 x 16
self.conv3 = spectral_norm(nn.Conv2d(nfeats * 2, nfeats * 4, 4, 2, 1, bias=False))
self.bn3 = nn.BatchNorm2d(nfeats * 4)
# state size. (nfeats*4) x 8 x 8
self.conv4 = spectral_norm(nn.Conv2d(nfeats * 4, nfeats * 8, 4, 2, 1, bias=False))
self.bn4 = nn.MaxPool2d(2)
# state size. (nfeats*8) x 4 x 4
self.batch_discriminator = MinibatchStdDev()
self.pixnorm = PixelwiseNorm()
self.conv5 = spectral_norm(nn.Conv2d(nfeats * 8 +1, 1, 2, 1, 0, bias=False))
# state size. 1 x 1 x 1
def forward(self, inputs):
if use_label:
imgs, labels = inputs
enc = self.label_emb(labels).view((-1, 1, 64, 64))
enc = F.normalize(enc, p=2, dim=1)
x = torch.cat((imgs, enc), 1) # 4 input feature maps(3rgb + 1label)
else:
x = inputs
x = F.leaky_relu(self.conv1(x), 0.2)
if use_attention:
x = self.self_attn_book(x)
x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2)
# x = self.pixnorm(x)
x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2)
# x = self.pixnorm(x)
x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2)
# x = self.pixnorm(x)
x = self.batch_discriminator(x)
x= self.conv5(x)
if self.use_sigmoid:
x = torch.sigmoid(x)
return x.view(-1, 1)
学習
nz = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_calc = 'bce' # 'rals'
lr_g = 0.0002
lr_d = 0.0002
beta1 = 0.5
epochs = 200
use_label =True
use_attention = True
netG = Generator(nz, 32, 3).to(device)
netD = Discriminator(3, 48,loss_calc=loss_calc).to(device)
criterion = nn.BCELoss()
#criterion = nn.MSELoss()
optimizerD = optim.Adam(netD.parameters(), lr=lr_d, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr_g, betas=(beta1, 0.999))
lr_schedulerG = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizerG,
T_0=8, eta_min=0.00005)
lr_schedulerD = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizerD,T_mult=2,
T_0=8, eta_min=0.00005)
fixed_noise = torch.randn(25, nz, 1, 1, device=device)
real_label = 1.0
fake_label = 0.0
real_uniform = 0.1
fake_uniform = 0.1
batch_size = train_loader.batch_size
num_train_G = 2
cls_info = True
aspect_info = True
fid_each_epoch = 50
### training here
errg_list = []
errd_list = []
dout_real = []
dout_fake1 = []
dout_fake2 = []
epoch_errg_list = []
epoch_errd_list = []
epoch_dout_real = []
epoch_dout_fake1 = []
epoch_dout_fake2 = []
fid_list = []
step = 0
for epoch in range(epochs):
end = time()
if (end -start) > run_time:
break
for ii, (real_images, classes, aspect_flags, dog_labels) in tqdm(enumerate(train_loader), total=len(train_loader)):
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
# train with real
netD.zero_grad()
real_images = real_images.to(device)
batch_size = real_images.size(0)
labels = torch.full((batch_size, 1), real_label) + np.random.uniform(0, real_uniform)
if cls_info:
labels = labels * (classes.reshape(-1,1).float())
if aspect_info:
labels = labels * (aspect_flags.reshape(-1,1).float())
real_labels = labels.to(device)
if use_label:
dog_labels = torch.tensor(dog_labels, device=device)
outputR = netD((real_images,dog_labels))
else:
outputR = netD(real_images)
if loss_calc == 'bce':
errD_real = criterion(outputR, real_labels)
errD_real.backward()
D_x = outputR.mean().item()
# train with fake
noise = torch.randn(batch_size, nz, 1, 1, device=device)
if use_label:
fake = netG((noise, dog_labels))
outputF1 = netD((fake.detach(), dog_labels))
else:
fake = netG(noise)
outputF1 = netD(fake.detach())
if loss_calc == 'bce':
fake_labels = torch.full((batch_size, 1), fake_label) + np.random.uniform(0, fake_uniform)
fake_labels = fake_labels.to(device)
errD_fake = criterion(outputF1, fake_labels)
errD_fake.backward()
errD = errD_real + errD_fake
elif loss_calc == 'rals':
errD = (torch.mean((outputR - torch.mean(outputF1) - real_labels) ** 2) +
torch.mean((outputF1 - torch.mean(outputR) + real_labels) ** 2))/2
errD.backward(retain_graph=True)
D_G_z1 = outputF1.mean().item()
optimizerD.step()
for g_iter in range(num_train_G):
############################
# (2) Update G network: maximize log(D(G(z)))
###########################
netG.zero_grad()
noise = torch.randn(batch_size, nz, 1, 1, device=device)
if use_label:
fake = netG((noise, dog_labels))
outputF2 = netD((fake, dog_labels))
else:
fake = netG(noise)
outputF2 = netD(fake)
if loss_calc == 'bce':
errG = criterion(outputF2, real_labels)
errG.backward(retain_graph=True)
elif loss_calc == 'rals':
errG = (torch.mean((outputR - torch.mean(outputF2) + real_labels) ** 2) +
torch.mean((outputF2 - torch.mean(outputR) - real_labels) ** 2)) / 2
errG.backward(retain_graph=True)
D_G_z2 = outputF2.mean().item()
optimizerG.step()
############################
# (3) 記録ゾーン
###########################
dout_fake2.append(D_G_z2)
errg_list.append(errG.item())
errd_list.append(errD.item())
lr_schedulerG.step(epoch)
lr_schedulerD.step(epoch)
dout_real.append(D_x)
dout_fake1.append(D_G_z1)
epoch_errg_list.append(np.mean(errg_list[epoch*len(train_loader)*num_train_G:]))
epoch_errd_list.append(np.mean(errd_list[epoch*len(train_loader)*num_train_G:]))
epoch_dout_real.append(np.mean(dout_real[epoch*len(train_loader):]))
epoch_dout_fake1.append(np.mean(dout_fake1[epoch*len(train_loader):]))
epoch_dout_fake2.append(np.mean(dout_fake2[epoch*len(train_loader)*num_train_G:]))
print('[%d/%d]\n Loss_D: %.4f D(x): %.4f D(G(z))1: %.4f'
% (epoch + 1, epochs,errd_list[-1], D_x, D_G_z1,))
for i in range(num_train_G):
print(' Loss_G: %.4f D(G(z))2: %.4f gtr[%d/%d]'
% (errg_list[-(num_train_G-i)],dout_fake2[-(num_train_G-i)],i+1, num_train_G))
print(' epoch_mean:\n Loss_D: %.4f Loss_G: %.4f \n D(x): %.4f D(G(z))1: %.4f D(G(z))2: %.4f'
% (epoch_errd_list[-1], epoch_errg_list[-1], epoch_dout_real[-1],epoch_dout_fake1[-1],epoch_dout_fake2[-1]))
if epoch < 100 and epoch % 5 == 0: # 最初の方のエポックは頻繁に画像出力する。
show_generated_img()
if epoch % fid_each_epoch == fid_each_epoch-1: # fid_each_epoch回数分、画像出力、loss、d_outプロット、画像を10000枚作成、FID計算まで行う。
show_generated_img()
loss_plot(epoch_errd_list,epoch_errg_list,xlabel='epoch',ylabel=loss_calc+'_loss')
dout_plot(epoch_dout_real, epoch_dout_fake1, xlabel='epoch', ylabel='dout(linear)')
# create_image(netG,nz=nz,threshold=1,fold_name='../output_images'+str(epoch),n_images=10000, im_batch_size=100)
# fid = fid_calc('../output_images'+str(epoch))
# fid_list.append(fid)
ここにコードをまとめていますが、整理していないのでとても分かりづらいです。
追記
- 終了後にハイスコアのカーネルを見て比較してみると、Gのチャンネル数が足りなくて表現力不足感がある。
- 犬種(ConditionalGAN)を追加してもあまりいいスコアが出ない。
- mode collapseが原因でスコアが伸びなかった様にも思える。
追追記
- kaggleの結果が出ましたが、310/927で惨敗。
- サブミットも、パブリックとプライベートの開きが大きくて結局は何だったのかよくわからない。。。
- GANの新しい記事書いたのでそちらも御覧ください。