90
48

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

GANで国旗を作る

Last updated at Posted at 2021-07-12

対象読者:将来国を建設したいが、国旗を考えるのは面倒くさいという人

 他に誰かやっていると思ったら、ちょっと検索した範囲では見つからなかったのでやってみました。国旗をGANで作れないかという実験です。GANが何かという説明については省略します。

データセット

 全ての国の国旗を取得したとしても n = 206 なのでデータセットとしては正直心許ないですが、国の数は限られているので仕方ありません。おとなしく全取得します。以下のようなコードで、全国旗のpngファイルを取得します。

import os
import re
from time import sleep

import requests

url = 'https://ja.wikipedia.org/wiki/%E5%9B%BD%E6%97%97%E3%81%AE%E4%B8%80%E8%A6%A7'

get = requests.get(url)
txt = get.text

pattern = 'Flag_of_(.*?).svg"'
result = re.findall(pattern,txt)

save_folder = 'flags'

for r in result:
    r = "Flag_of_" + r
    if os.path.exists(os.path.join('flags',r+'.png')):
        continue
    print(f'searching {r}')
    url = f'https://ja.wikipedia.org/wiki/%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB:{r}.svg'
    txt = requests.get(url).text
    pattern = '//upload.wikimedia.org/wikipedia/commons/thumb/(.*?).png'
    result = re.findall(pattern,txt)
    url = f'https://upload.wikimedia.org/wikipedia/commons/thumb/{result[0]}.png'
    print(url)
    path = os.path.join(save_folder,f'{r}.png')
    isTrying = True
    while isTrying:
        response = requests.get(url)
        if response.status_code == 200:
            image = response.content
            isTrying = False
        print('failure')
        sleep(2)
    open(path, 'wb').write(image)

 以下のような結果が得られます。

image.png

 このままでは扱いづらいので正方形かつ2冪になるようにリサイズします。こういった画像一括処理にはOpenCVが便利です。

import os

import cv2
import numpy as np

LD = os.listdir('flags')

for L in LD:
    target = os.path.join('flags',L)
    img = cv2.imread(target)
    sq_img = cv2.resize(img,(128,128))
    sq_path = os.path.join('flags_squared',L)
    cv2.imwrite(sq_path,sq_img)

 以下のような結果が得られます。

image.png

 今更ながら、国旗には色々なアスペクト比があって、こうした処理を行うとそれらの違いを無視してしまうことになりますが、この段階でデータを加工しないと後のニューラルネットワークを構築する段階で死ぬほど面倒くさい事態になるので、やむなくこうしています。

 ちなみに、一番横に短い国旗は、

 ネパール

nepal.png

 一番横に長い国旗は、

 Flag_of_Qatar.png

 カタールらしいです。今回の解析の過程で初めて知りました。

学習

 以下のようなコードでGANを回します。

import os
import pickle

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from matplotlib import pyplot as plt
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm


def make_dataset(root):
    paths = os.listdir(root)
    imgs = []
    for path in paths:
        img_path = os.path.join(root, path)
        img = Image.open(img_path)
        img = torchvision.transforms.functional.to_tensor(img)
        label = 0
        imgs.append((img, label))
    return imgs


z_dim = 20
epoch = 10000
batch_size = 32

pickle.dump(make_dataset('flags_squared'),open('train_data.pickle','wb'))
train_data = pickle.load(open('train_data.pickle', 'rb'))
train_loader = DataLoader(train_data, batch_size, True)


class Generator(nn.Module):

    def __init__(self):
        super().__init__()
        self.tran1 = nn.ConvTranspose2d(z_dim, 256, 4, 4)
        self.tran2 = nn.ConvTranspose2d(256, 256, 4, 4)
        self.tran3 = nn.ConvTranspose2d(256, 3, 8, 8)

    def forward(self, x):
        x = self.tran1(x)
        x = torch.relu(x)
        x = nn.BatchNorm2d(256).to('cuda')(x)
        x = self.tran2(x)
        x = torch.relu(x)
        x = nn.BatchNorm2d(256).to('cuda')(x)
        x = self.tran3(x)
        x = torch.sigmoid(x)
        return x


class Discriminator(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 4, 4, 1)
        self.conv2 = nn.Conv2d(64, 128, 4, 4, 1)
        self.conv3 = nn.Conv2d(128, 256, 4, 4, 1)
        self.conv4 = nn.Conv2d(256, 1, 4, 4, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.BatchNorm2d(64).to('cuda')(x)
        x = nn.LeakyReLU(0.01, inplace=True)(x)
        x = self.conv2(x)
        x = nn.BatchNorm2d(128).to('cuda')(x)
        x = nn.LeakyReLU(0.01, inplace=True)(x)
        x = self.conv3(x)
        x = nn.BatchNorm2d(256).to('cuda')(x)
        x = nn.LeakyReLU(0.01, inplace=True)(x)
        x = self.conv4(x)
        x = torch.flatten(x, 1)
        x = torch.sigmoid(x)
        return x

G = Generator()
D = Discriminator()
G_opt = torch.optim.Adam(G.parameters(), 0.0005, [0.0, 0.9])
D_opt = torch.optim.Adam(D.parameters(), 0.0005, [0.0, 0.9])

sample = torch.rand((1,z_dim,1,1))
summary(G,input_data=sample)
sample = torch.rand((1,3,128,128))
summary(D,input_data=sample)

G.train()
D.train()
G.to('cuda')
D.to('cuda')

G_losses = []
D_losses = []

criterion = nn.BCELoss()

cnt = 0

for e in tqdm(range(epoch)):
    final_loss = 0
    for data in train_loader:
        x = data[0].to('cuda')
        t = data[1].to('cuda')

        # G_training
        noise = torch.randn((batch_size, z_dim, 1, 1)).to('cuda')
        fake_images = G(noise)
        fake_disc = D(fake_images)
        loss_fake_is_real = criterion(fake_disc, torch.ones_like(fake_disc))
        G_loss = loss_fake_is_real

        G_opt.zero_grad()
        G_loss.backward()
        G_opt.step()

        # D_training
        noise = torch.randn((batch_size, z_dim, 1, 1)).to('cuda')
        fake_images = G(noise)
        real_disc = D(x)
        fake_disc = D(fake_images)
        loss_fake_is_fake = criterion(fake_disc, torch.zeros_like(fake_disc))
        loss_real_is_real = criterion(real_disc, torch.ones_like(real_disc))
        D_loss = loss_fake_is_fake + loss_real_is_real

        D_opt.zero_grad()
        D_loss.backward()
        D_opt.step()

        G_losses.append(G_loss.item())
        D_losses.append(D_loss.item())

    noise = torch.randn((25, z_dim, 1, 1)).to('cuda')

    if e%50 != 0:
        continue

    print("")
    print("G_loss", G_losses[-1])
    print("D_Loss", D_losses[-1])
    pred = G(noise)
    pred = pred.to('cpu').detach().numpy()
    fig, ax = plt.subplots(5, 5, figsize=(10, 10))
    ar = ax.ravel()
    for i in range(5):
        for j in range(5):
            img = pred[i*5+j]
            img = img.transpose(1,2,0)
            print(img.shape)
            ar[i*5+j].imshow(img)
    plt.savefig(f'images/image_at_epoch{e}.png')
    torch.save(G.state_dict(),f'models/model_at_epoch{e}.pth')
    plt.close()
    cnt += 1

plt.plot(D_losses)
plt.plot(G_losses)
plt.savefig('loss_graph.png')

※ネットワークの組み方や、ハイパーパラメータの設定は、とりあえず動いて絵が出ればいいやの精神で設定したので、詳しい人から見たらツッコミどころがあるかもしれません。

 一般的に、GANの学習は、epochが進めば進む程よくなるというものではありません。Discriminatorが強すぎればGeneratorは学習正無気力症候群に陥ってしまい、逆にDiscriminatorが甘いとGeneratorがワンパターン戦法に頼り切りになってしまい(これを専門用語でモード崩壊といいます)、学習が進まなくなってしまいます。うまくGeneratorを訓練するためには、アメとムチがそのどちらかに偏らないようにする必要があります。それで、一時的に上手くいったように見えても、時間が経つとやっぱりダメになってしまうことが多いです。なんだか人間みたいだなあ……。

 そのため、定期的に出力結果と重みを保存し続け、結果がいい感じになっている時の重みを生成に使う、という戦法でいきます。

学習過程

 エポック 50 ごとの生成結果は以下のようになりました。

image.png

loss_graph.png

 ロスのグラフはこんな感じです。(諸事情で、横軸はエポック数ではなく、エポック数 x 7 ※になっています) ※ ceil(206 / 32)

 まず、最初はおぼろげながら国旗らしきものを返していきます。

image.png
(epoch = 300)

 しかし、その後 epoch 800 ぐらいで暗黒期に入り、同じような暗いパターンしか返さなくなります。

image.png
(epoch = 1000)

Generatorにしてみれば、みんなが本物と見分けがつかない「必勝パターン」を見つけたので、もうそれでいいじゃん、という感じでしょうか。しかし、そこでDiscriminatorが進化したのか、このワンパターン戦法を喝破するに至ります。

image.png
(epoch = 1600)

 そしてGeneratorは次なる境地に至ります。が、なんか現代芸術みたいな模様が。素朴だけどいい音楽を作っていたバンドが、いつの間にか実験的な音楽しか演奏しない集団になっていたみたいな「顧客が求めているのはそうじゃない」感があります。

 その後も何回か暗黒期と復活を経ますが、どんどん抽象的になっていき、最終的には「必勝パターン」をみつけて、そこから永遠に動かなくなってしまいます。

image.png

 上のグラフで言うと、Generatorロス(オレンジの線)が大きく上に振れたあたりです。モード崩壊に陥った時点でこの学習は失敗です。

生成

 学習過程において、epoch 300 あたりがいい感じだったので、そのモデルを元に適当に国旗を50枚ぐらい生成します。モデルだと正方形の画像が返されるので、そこはリサイズして2:3にしていきます。

import os

import cv2
import numpy as np
import torch
from torch import nn


class Generator(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.tran1 = nn.ConvTranspose2d(z_dim, 256, 4, 4)
        self.tran2 = nn.ConvTranspose2d(256, 256, 4, 4)
        self.tran3 = nn.ConvTranspose2d(256, 3, 8, 8)

    def forward(self, x):
        x = self.tran1(x)
        x = torch.relu(x)
        x = nn.BatchNorm2d(256).to('cuda')(x)
        x = self.tran2(x)
        x = torch.relu(x)
        x = nn.BatchNorm2d(256).to('cuda')(x)
        x = self.tran3(x)
        x = torch.sigmoid(x)
        return x

z_dim = 20

G = Generator().to('cuda')
G.load_state_dict(torch.load('models/model_at_epoch300.pth'))

noise = torch.randn((50, z_dim, 1, 1)).to('cuda')

pred = G(noise)
pred = pred.to('cpu').detach().numpy()
for i in range(50):
    img = pred[i]*255
    img = img.transpose(1,2,0)
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    img = cv2.resize(img,((192,128)))
    print(img.shape)
    cv2.imwrite(f'generate/{i}.png',img)

image.png

image.png

 あいにく、鮮明な図形とはいきませんでしたが、それとなくインスピレーションを得ることはできるのではないでしょうか。

 個人的には、

49.png

 これなんかがお気に入りです。大英帝国によるアフリカ再植民地化を思わせるスリリングな感じです。

 しかし、現実の国旗制定においては、許可なく(?)ユニオンジャックやスカンディナヴィア十字をシンボルとして使用することは重大な国際問題を引き起こす可能性がある気がします。十分に注意しましょう。

続編?

 この生成結果を途中で見せたところ、GANを研究している人からこれを使ってみてはどうかという提案がされました。

 今回使ったのはGANの中でもかなり初歩的な手法で、そもそも206のデータセットから鮮明な画像を生成しようというのは結構無理のある話だったりします。しかし、少ないデータセットでもデータ拡張やネットワークを工夫することによって性能をあげようという研究も当然ある訳で、それらを実装すればもう少しマシな画像が生成される可能性があります。

 こういう最新の研究を実装するのは、git clone一発で済む場合もあれば、環境構築に四苦八苦することもあるので、またやる気が出たら続編を出そうと思います。

90
48
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
90
48

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?