7
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

NikuGan ~おいしそうなお肉をたくさん見たい!!

Last updated at Posted at 2020-12-31

はじめに

実家に帰省した大晦日,特にやることがなかったので思いつきでDCGANによるおいしそうなお肉の画像生成をやってみました.
(4時間でやったので,クオリティーはめちゃくちゃ低いです...)

本物のおいしそうなお肉の画像収集していく

おいしそうなお肉の画像データを収集するのに, 次のサイトから集めました.

onikuimages

お肉

前にTwitterでこのサイトが流れてきた時に,使ってみたいとおもったので,今回使いました.
このサイトから60枚画像を取得して,本物の画像として利用しました.

(ちゃんとやるならスクレイピングをして,たくさん画像を集めるべきですが,時間がないので行なっていないです.)

DCGANを実装していく

自分の実装はここに置いてあります.きちんとみたい方は参考にしてください.
Pytorchの公式のDCGANチュートリアルhkthiranoさんの記事を参考に実装を行なっています.

データセットの前処理

前処理では,画像のサイズをすべて64*64になるように切り取りを行います.
公式チュートリアルに書いてあるのにしたがって前処理も一緒に行いました.

image_size = 64
batch_size = 2
workers = 0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 前処理も行う
dataset = datasets.ImageFolder(IMG_DIR,
                                transform=transforms.Compose([
                                    transforms.Resize(image_size),
                                    transforms.CenterCrop(image_size),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                ]))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

前処理した画像をちょこっと覗いてみると

# dataloaderを見てみる
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis('off')
plt.title('Training image')
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

dataloader

きちんと前処理が行われていい感じです.ただ,解像度を落としたのでこの時点でおいしそうなお肉にはもう見えなくなってますね...

Generator

Generatorは,Discriminatorが判別しにくい偽物の画像を作ることが目的です.
潜在変数を100次元とおいて,そこから64643な画像を生成するようになっています.

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(

            nn.ConvTranspose2d(
                in_channels=100, 
                out_channels=256, 
                kernel_size=4, 
                stride=1, 
                padding=0, 
                bias=False
            ),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

Discriminator

DiscriminatorはGeneratorが生成した偽画像と本物画像をきちんと判別することが目的です.

64643な画像から本物かどうか(1or0)のスカラー値を出力します.

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(

            nn.Conv2d(3, 32, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        return self.main(x).squeeze()

敵対的に学習していく

Generatorは判別されにくいような偽画像を作るように,Discriminatorはきちんと判別できるように学習をしていきます.

次のようなループを何回も行うことで,敵対的に学習していきます.

  1. 潜在特徴zから偽画像を生成,うまく騙せるようにGeneratorを学習 (← Generatorが成長する)
  2. 本物画像と偽物画像を使い,きちんと識別できるようにDiscriminatorを学習 (← Discriminatorが成長する)
# 訓練関数
def train_dcgan(model_G, model_D, params_G, params_D, dataloader):
    log_loss_G = []
    log_loss_D = []
    for real_img, _ in dataloader:
        batch_len = len(real_img)

        # === Generatorを訓練する ===
        # 偽画像を生成

        z = torch.randn(batch_len, nz, 1, 1).to(device)
        fake_img = model_G(z)

        # 偽画像を一時的に保存する
        # 偽画像生成を二回行わないため
        fake_img_tensor = fake_img.detach()

        # 偽画像を本物と騙せるように計算を行う
        out = model_D(fake_img)
        loss_G = loss_f(out, ones[: batch_len])
        log_loss_G.append(loss_G.item())

        # 更新していく
        model_D.zero_grad()
        model_G.zero_grad()
        loss_G.backward()
        params_G.step()

        # == Discriminatorの訓練 ===
        # 本物の画像
        real_img = real_img.to(device)

        #本物画像を計算できるようにロスを求める
        real_out = model_D(real_img)
        loss_D_real = loss_f(real_out, ones[:batch_len])

        #さっき保存した偽画像
        fake_img = fake_img_tensor

        #偽画像を偽であると識別できるようにロスを求める
        fake_out = model_D(fake_img_tensor)
        loss_D_fake = loss_f(fake_out, zeros[:batch_len])

        # 本物,偽物のロスを合計する
        loss_D = loss_D_real + loss_D_fake
        log_loss_D.append(loss_D.item())

        # 更新してく
        model_D.zero_grad()
        model_G.zero_grad()
        loss_D.backward()
        params_D.step()
    
    return mean(log_loss_G), mean(log_loss_D)

実際に訓練して生成していく!

バッチサイズ2, 1000エポックで学習を行いました.
下に,100エポックごとの学習結果をまとめたgifをおきます.
first_gif.gif
どうですか?お肉っぽいの見えませんか??

初めはノイズみたいのから,300~500エポックでは,白い背景に白いお皿の上に乗っているお肉ができていると思います.
しかし,500エポック以降では,黒背景にお肉があるだけに戻ってしまいました...
(500エポックの時が一番本物の画像に似ている?)
image.png

追記(2021/1/1)
バッチサイズ8, 5000エポックで学習を行いました.
前回よりは,お肉がお肉らしくできてるように見えます.しかし,同じような画像ばかり生成されてモード崩壊が起きています.
原因としては潜在特徴のベクトルが100次元で弱いから?
結果として多様性のないお肉はできてませんが,リアルなお肉に近づけた気がします.
8_5000.gif

反省と感想

きれいに生成できていない理由として,画像のクオリティーと枚数が考えられます.せっかくおいしそうな画像を集めたのに,学習を行う関係上,解像度を落としたのがもったいないと思いました.また枚数も1サイトから収集したので圧倒的に足りていないと思います.

思いつきからはじまり,2020年が終わるまでにお肉生成ができてとても良かったです.
たった60枚の画像からお肉っぽいものを生成できるDCGANすごい!!

時間があったら,これ以上のクオリティなおいしそうなお肉を生成していきたいと思います!

7
4
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?