#【概要】
画像を入力にして画像を出力できるpix2pixに、系列データを取り扱えるRNNの一種であるLSTM(Long Short Term Memory)を組み合わせたアーキテクチャを構築し、非定常数値流体力学(CFD)シミュレーションの結果(具体的には、同一の物理条件に対して高解像度格子で計算した場合と低解像度格子で計算した場合の密度分布)の時間発展を学習させ、低解像度格子の計算結果から高解像度格子の計算結果を高精度で生成できることを示しました。この技術により、高コストな計算の結果を低コストな計算の結果から予測でき、とても多くのケースの計算が必要となるような最適化計算の計算コストを大幅に低減できる可能性を示しました。
#【論文】
2021年9月にFrontiers in Artificial intelligenceから出版されました。
[https://www.frontiersin.org/articles/10.3389/frai.2021.670208/full][1](図が小さく見にくいかも)
[1]:https://www.frontiersin.org/articles/10.3389/frai.2021.670208/full
arXivにも投稿してあります。こちらの方が図が見やすいかもしれません。
[https://arxiv.org/pdf/2109.10679.pdf][2]
[2]:https://arxiv.org/pdf/2109.10679.pdf
#【ソースコード】
Kerasで実装を行いました。Githubで公開しています。
[https://github.com/bright1998/pix2pix_LSTM][3]
[3]:https://github.com/bright1998/pix2pix_LSTM
#【アーキテクチャの実装(models.py)】
系列データを扱っていなかった時はpytorchで実装していましたが([前回の記事を参照][4])、今回はLSTMを使うということもあり、既にLSTMを実装したことがあるKerasを選択しました。LSTMレイヤーをまたいでskip connectionを行うのにとても苦労したのですが、解決方法はtf.keras.Modelを継承してモデルを作成することでした。
[4]:https://qiita.com/bright1998/items/da9e772ac206e0249cba
時系列の画像データ(密度分布図)が入力となりますが、各時間の分布図にエンコーダー(普通のCNN: pix2pixにおけるU-Netの前半部分)を作用させてベクトル化し、その時系列のベクトルデータをLSTMに入力し、各時系列での隠れ層のベクトルをデコーダー(アップサンプリングするネットワーク: pix2pixにおけるU-Netの後半部分)で時系列画像に変換するという流れです(論文の図1)。
Google Colaboratoryで計算を行ったのですが、Instance Normalizationを使うには
!pip install tensorflow-addons
をする必要がありました(2021年6月時点では)。
import tensorflow as tf
import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import TimeDistributed, Conv2D, Dropout, Reshape, UpSampling2D, Concatenate
from tensorflow.keras.layers import LeakyReLU, ReLU
from tensorflow_addons.layers import InstanceNormalization
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
from tensorflow.compat.v1.keras.initializers import glorot_normal
####【Generator】
Generatorの前半では8つのブロック(畳み込み、正規化、活性化などを行うブロック)を通して入力画像を512次元のベクトルに変換してますが、時系列データなのでTimeDistributed
を使って時間方向に1つのlayerを繰り返し作用させます。512次元の時系列ベクトルを順次LSTMに入力しますが、1系列毎にデコーダー側に隠れ層のベクトルを出力するため、return_sequences=True
とします。LSTMから出力された512次元のベクトルをまた8つのブロック(アップサンプリング、畳み込み、正規化、活性化などを行うブロック)を通して入力画像と同じピクセルサイズの画像に変換します。後半部分も前半部分と同様にTimeDistributed
を使います。skip-connectionを行っているのは、u1 = Concatenate()([x, d7])
のようにしている部分です。
class generator(tf.keras.Model):
def __init__(self, frames, out_channels, dropout_rate, name=None):
super(generator, self).__init__(name=name)
self.conv1 = TimeDistributed(conv2d(filters=64, f_size=4, normalization=False,
alpha=0.2, dropout_rate=0))
self.conv2 = TimeDistributed(conv2d(filters=128, f_size=4, normalization=True,
alpha=0.2, dropout_rate=0))
self.conv3 = TimeDistributed(conv2d(filters=256, f_size=4, normalization=True,
alpha=0.2, dropout_rate=0))
self.conv4 = TimeDistributed(conv2d(filters=512, f_size=4, normalization=True,
alpha=0.2, dropout_rate=dropout_rate))
self.conv5 = TimeDistributed(conv2d(filters=512, f_size=4, normalization=True,
alpha=0.2, dropout_rate=dropout_rate))
self.conv6 = TimeDistributed(conv2d(filters=512, f_size=4, normalization=True,
alpha=0.2, dropout_rate=dropout_rate))
self.conv7 = TimeDistributed(conv2d(filters=512, f_size=2, normalization=True,
alpha=0.2, dropout_rate=dropout_rate))
self.conv8 = TimeDistributed(conv2d(filters=512, f_size=2, normalization=False,
alpha=0.2, dropout_rate=dropout_rate))
self.resh1 = TimeDistributed(Reshape((512, )))
self.lstm = CuDNNLSTM(512, batch_input_shape=(None, frames, 512),
kernel_initializer=glorot_normal(seed=1),
return_sequences=True, stateful=False)
self.resh2 = TimeDistributed(Reshape((1, 1, 512, )))
self.deco1 = TimeDistributed(deconv2d(filters=512, f_size=2, dropout_rate=dropout_rate))
self.deco2 = TimeDistributed(deconv2d(filters=512, f_size=2, dropout_rate=dropout_rate))
self.deco3 = TimeDistributed(deconv2d(filters=512, f_size=4, dropout_rate=dropout_rate))
self.deco4 = TimeDistributed(deconv2d(filters=512, f_size=4, dropout_rate=dropout_rate))
self.deco5 = TimeDistributed(deconv2d(filters=256, f_size=4, dropout_rate=0))
self.deco6 = TimeDistributed(deconv2d(filters=128, f_size=4, dropout_rate=0))
self.deco7 = TimeDistributed(deconv2d(filters=64, f_size=4, dropout_rate=0))
self.upsa = TimeDistributed(UpSampling2D(size=2))
self.conv9 = TimeDistributed(Conv2D(out_channels, kernel_size=4, strides=1, padding='same', activation='tanh'))
def call(self, x):
d1 = self.conv1(x)
d2 = self.conv2(d1)
d3 = self.conv3(d2)
d4 = self.conv4(d3)
d5 = self.conv5(d4)
d6 = self.conv6(d5)
d7 = self.conv7(d6)
d8 = self.conv8(d7)
y = self.resh1(d8)
z = self.lstm(y)
x = self.resh2(z)
x = self.deco1(x)
u1 = Concatenate()([x, d7])
x = self.deco2(u1)
u2 = Concatenate()([x, d6])
x = self.deco3(u2)
u3 = Concatenate()([x, d5])
x = self.deco4(u3)
u4 = Concatenate()([x, d4])
x = self.deco5(u4)
u5 = Concatenate()([x, d3])
x = self.deco6(u5)
u6 = Concatenate()([x, d2])
x = self.deco7(u6)
u7 = Concatenate()([x, d1])
x = self.upsa(u7)
output_imgs = self.conv9(x)
return output_imgs
自作モデルをtf.keras.Model
を継承して作成した場合に、その中でTimeDistributed
を使うとエラーがでます([こちらの記事][5]を参照)。その対処として、conv2d
やdeconv2d
の中で以下のように関数compute_output_shape
を定義しておきます。
[5]:https://qiita.com/bee2/items/76fddc0bcfecb257d127
def compute_output_shape(self, input_shape):
input_batch = input_shape[0]
input_height = input_shape[1]
input_width = input_shape[2]
output_shape = [input_batch, input_height // 2, input_width // 2, self.filters]
return output_shape
####【Discriminator】
Discriminatorでは4つのブロック(畳み込み、正規化、活性化などを行うブロック)を通した後で、畳み込み(Conv2D
)のみ行います。時系列画像を処理するのでGeneratorと同様にTimeDistributed
を使います。
class discriminator(tf.keras.Model):
def __init__(self, name=None):
super(discriminator, self).__init__(name=name)
self.conv1 = TimeDistributed(conv2d(filters=64, f_size=4, normalization=False,
alpha=0.2, dropout_rate=0))
self.conv2 = TimeDistributed(conv2d(filters=128, f_size=4, normalization=True,
alpha=0.2, dropout_rate=0))
self.conv3 = TimeDistributed(conv2d(filters=256, f_size=4, normalization=True,
alpha=0.2, dropout_rate=0))
self.conv4 = TimeDistributed(conv2d(filters=512, f_size=4, normalization=True,
alpha=0.2, dropout_rate=0))
self.conv5 = TimeDistributed(Conv2D(1, kernel_size=4, strides=1, padding='same'))
def call(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
validity = self.conv5(x)
return validity
#【あとがき】
アイデアを思いついてから2年数か月ほどかけて、論文出版までやりきることができました。全てのきっかけはE資格の受験資格を得るため認定講座を受け、GANsの面白さに魅了されたことです。GANsに限らないですが、今後もAI x シミュレーションで自由研究を続けていけたらと思っています。