10
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchによるAuto Encoder-Decoderの実装

Last updated at Posted at 2020-05-19

概要

深層学習フレームワークPyTorchを用いて,Auto Encoder-Decoderを実装しました!
ネットワークは文献[1]のものを実装しています.高速に高精度なencoderなのでとても使いやすいと感じました.

追記:
2020/09/25 自作損失関数のinit内のsuper()の引数が間違っていたかもしれないので修正しました。
python3系でのこのあたりの挙動がまだ把握しきれていない...
2020/10/12 本記事でポイントとして紹介している箇所を含め,以下の記事にまとめなおしました(分類器として実装).↓
https://qiita.com/shun310/items/3fbac0a6cf87a0d70b78

ネットワーク全景:

encoder_class.py
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from consts import * # H,W,N,qなどの定数は別ファイルから読み込んでいます

# 自作のデータセットを使う場合
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data, label, transform=None):
        self.transform = transform
        self.data = data
        self.data_num = len(data)
        self.label = label

    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        out_data = self.data[idx]
        out_label =  self.label[idx]
        if self.transform:
            out_data = self.transform(out_data)

        return out_data, out_label

# 自作の損失関数を使う場合
class MyMSELoss(nn.MSELoss):
    def __init__(self, size_average=None, reduce=None, reduction='mean'):
        super(MyMSELoss, self).__init__(size_average, reduce, reduction)

    def forward(self, input, target):
        return F.mse_loss(input, target, reduction=self.reduction)

# encoderの小ブロック
class encoder_SB(nn.Module):
    def __init__(self):
        super(encoder_SB, self).__init__()
        self.e_layers = nn.ModuleList([nn.Conv2d(128,128,3,padding=1) for i in range(N)])
        self.e_downsampling = nn.Conv2d(128,128,4,padding=1,stride=2)
    def forward(self, x):
        x_input = x.clone()
        for c in self.e_layers:
            x = F.leaky_relu(c(x))
        x = x + x_input
        return self.e_downsampling(x)

# decoderの小ブロック
class decoder_SB(nn.Module):
    def __init__(self):
        super(decoder_SB, self).__init__()
        self.d_layers = nn.ModuleList([nn.Conv2d(128,128,3,padding=1) for i in range(N)])
        self.d_upsampling = nn.ConvTranspose2d(128,128,4,padding=1,stride=2)
    def forward(self, x):
        x_input = x.clone()
        for c in self.d_layers:
            x = F.leaky_relu(c(x))
        x = x + x_input
        return self.d_upsampling(x)

# ネットワーク本体
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # encode
        self.e_input = nn.Conv2d(channel,128,3,padding=1)
        self.encoder_layer = nn.ModuleList([encoder_SB() for i in range(q)])
        self.e_fc1 = nn.Linear(128*int((H/2**q))*int((W/2**q)),1024)
        self.e_fc2 = nn.Linear(1024,len_of_latentVec)
        # decode
        self.d_fc1 = nn.Linear(len_of_latentVec,1024)
        self.d_fc2 = nn.Linear(1024,128*int((H/2**q))*int((W/2**q)))
        self.decoder_layer = nn.ModuleList([decoder_SB() for i in range(q)])
        self.d_output = nn.Conv2d(128,channel,3,padding=1)


    def forward(self, x):
        # encode
        x = F.leaky_relu(self.e_input(x))

        for e in self.encoder_layer: x = e(x)

        x = x.view(-1,128*int((H/2**q))*int((W/2**q))) # batch_size,channel,height,width
        x = F.leaky_relu(self.e_fc1(x))
        x = self.e_fc2(x)
        latentVec = x.clone()

        # decode
        x = F.leaky_relu(self.d_fc1(x))
        x = F.leaky_relu(self.d_fc2(x))
        x = x.view(-1,128,int((H/2**q)),int((W/2**q))) # batch_size,channel,height,width

        for d in self.decoder_layer: x = d(x)

        x = self.d_output(x)

        return x

    def encode(self, x):
        # encode
        x = F.leaky_relu(self.e_input(x))

        for e in self.encoder_layer: x = e(x)

        x = x.view(-1,128*int((H/2**q))*int((W/2**q))) # batch_size,channel,height,width
        x = F.leaky_relu(self.e_fc1(x))
        x = self.e_fc2(x)
        latentVec = x.clone()

        return latentVec

    def decode(self, x):
        # decode
        x = F.leaky_relu(self.d_fc1(x))
        x = F.leaky_relu(self.d_fc2(x))
        x = x.view(-1,128,int((H/2**q)),int((W/2**q))) # batch_size,channel,height,width

        for d in self.decoder_layer: x = d(x)

        x = self.d_output(x)

        return x

ポイントは,

.py
self.e_layers = nn.ModuleList([nn.Conv2d(128,128,3,padding=1) for i in range(N)])

のように,nn.MuduleListに内包表記で生成した層リストを渡すことで,類似構造のネットワークを纏めて管理できること.これは繰り返し構造を持つ大規模なネットワークを表現する際にとても有用です.これが無いと,ネットワークの定義だけでめちゃくちゃ長くなってしまうので...

同じ様に,

.py
self.encoder_layer = nn.ModuleList([encoder_SB() for i in range(q)])

と書くことで,自作の部分ネットワークをnn.MuduleListでおなじように管理することもできます.これがすごく便利!今回は特に引数は取っていませんが,例えば入出力のノード数を引数で与えることで,一行でネットワークを構築する,といった芸当も可能です.

加えて,自作のデータセットと損失関数についても記述してあります.損失関数は通常,L1ノルムやMeanSquareなどがありますが,自身で設定することも可能です.物理現象を再現するなどの用途で深層学習を用いる場合,その制約を反映した項を損失関数に加える,といった工夫を行う場合があります.この機能はこう言った場合に使えると思います.

参考文献

[1] Kim, B., Azevedo, V. C., Thuerey, N., Kim, T., Gross, M., & Solenthaler, B., Deep Fluids: A Generative Network for Parameterized Fluid Simulations. Computer Graphics Forum, 38(2), 2019, 59–70.

10
6
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?