概要
深層学習フレームワークPyTorchを用いて,Auto Encoder-Decoderを実装しました!
ネットワークは文献[1]のものを実装しています.高速に高精度なencoderなのでとても使いやすいと感じました.
追記:
2020/09/25 自作損失関数のinit内のsuper()の引数が間違っていたかもしれないので修正しました。
python3系でのこのあたりの挙動がまだ把握しきれていない...
2020/10/12 本記事でポイントとして紹介している箇所を含め,以下の記事にまとめなおしました(分類器として実装).↓
https://qiita.com/shun310/items/3fbac0a6cf87a0d70b78
ネットワーク全景:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from consts import * # H,W,N,qなどの定数は別ファイルから読み込んでいます
# 自作のデータセットを使う場合
class MyDataset(torch.utils.data.Dataset):
def __init__(self, data, label, transform=None):
self.transform = transform
self.data = data
self.data_num = len(data)
self.label = label
def __len__(self):
return self.data_num
def __getitem__(self, idx):
out_data = self.data[idx]
out_label = self.label[idx]
if self.transform:
out_data = self.transform(out_data)
return out_data, out_label
# 自作の損失関数を使う場合
class MyMSELoss(nn.MSELoss):
def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(MyMSELoss, self).__init__(size_average, reduce, reduction)
def forward(self, input, target):
return F.mse_loss(input, target, reduction=self.reduction)
# encoderの小ブロック
class encoder_SB(nn.Module):
def __init__(self):
super(encoder_SB, self).__init__()
self.e_layers = nn.ModuleList([nn.Conv2d(128,128,3,padding=1) for i in range(N)])
self.e_downsampling = nn.Conv2d(128,128,4,padding=1,stride=2)
def forward(self, x):
x_input = x.clone()
for c in self.e_layers:
x = F.leaky_relu(c(x))
x = x + x_input
return self.e_downsampling(x)
# decoderの小ブロック
class decoder_SB(nn.Module):
def __init__(self):
super(decoder_SB, self).__init__()
self.d_layers = nn.ModuleList([nn.Conv2d(128,128,3,padding=1) for i in range(N)])
self.d_upsampling = nn.ConvTranspose2d(128,128,4,padding=1,stride=2)
def forward(self, x):
x_input = x.clone()
for c in self.d_layers:
x = F.leaky_relu(c(x))
x = x + x_input
return self.d_upsampling(x)
# ネットワーク本体
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# encode
self.e_input = nn.Conv2d(channel,128,3,padding=1)
self.encoder_layer = nn.ModuleList([encoder_SB() for i in range(q)])
self.e_fc1 = nn.Linear(128*int((H/2**q))*int((W/2**q)),1024)
self.e_fc2 = nn.Linear(1024,len_of_latentVec)
# decode
self.d_fc1 = nn.Linear(len_of_latentVec,1024)
self.d_fc2 = nn.Linear(1024,128*int((H/2**q))*int((W/2**q)))
self.decoder_layer = nn.ModuleList([decoder_SB() for i in range(q)])
self.d_output = nn.Conv2d(128,channel,3,padding=1)
def forward(self, x):
# encode
x = F.leaky_relu(self.e_input(x))
for e in self.encoder_layer: x = e(x)
x = x.view(-1,128*int((H/2**q))*int((W/2**q))) # batch_size,channel,height,width
x = F.leaky_relu(self.e_fc1(x))
x = self.e_fc2(x)
latentVec = x.clone()
# decode
x = F.leaky_relu(self.d_fc1(x))
x = F.leaky_relu(self.d_fc2(x))
x = x.view(-1,128,int((H/2**q)),int((W/2**q))) # batch_size,channel,height,width
for d in self.decoder_layer: x = d(x)
x = self.d_output(x)
return x
def encode(self, x):
# encode
x = F.leaky_relu(self.e_input(x))
for e in self.encoder_layer: x = e(x)
x = x.view(-1,128*int((H/2**q))*int((W/2**q))) # batch_size,channel,height,width
x = F.leaky_relu(self.e_fc1(x))
x = self.e_fc2(x)
latentVec = x.clone()
return latentVec
def decode(self, x):
# decode
x = F.leaky_relu(self.d_fc1(x))
x = F.leaky_relu(self.d_fc2(x))
x = x.view(-1,128,int((H/2**q)),int((W/2**q))) # batch_size,channel,height,width
for d in self.decoder_layer: x = d(x)
x = self.d_output(x)
return x
ポイントは,
self.e_layers = nn.ModuleList([nn.Conv2d(128,128,3,padding=1) for i in range(N)])
のように,nn.MuduleListに内包表記で生成した層リストを渡すことで,類似構造のネットワークを纏めて管理できること.これは繰り返し構造を持つ大規模なネットワークを表現する際にとても有用です.これが無いと,ネットワークの定義だけでめちゃくちゃ長くなってしまうので...
同じ様に,
self.encoder_layer = nn.ModuleList([encoder_SB() for i in range(q)])
と書くことで,自作の部分ネットワークをnn.MuduleListでおなじように管理することもできます.これがすごく便利!今回は特に引数は取っていませんが,例えば入出力のノード数を引数で与えることで,一行でネットワークを構築する,といった芸当も可能です.
加えて,自作のデータセットと損失関数についても記述してあります.損失関数は通常,L1ノルムやMeanSquareなどがありますが,自身で設定することも可能です.物理現象を再現するなどの用途で深層学習を用いる場合,その制約を反映した項を損失関数に加える,といった工夫を行う場合があります.この機能はこう言った場合に使えると思います.
参考文献
[1] Kim, B., Azevedo, V. C., Thuerey, N., Kim, T., Gross, M., & Solenthaler, B., Deep Fluids: A Generative Network for Parameterized Fluid Simulations. Computer Graphics Forum, 38(2), 2019, 59–70.