#前置き
最初にAdaIN(Adaptive Instance Normalization)[1]を説明します。これはStyle Transferで利用されるNormalization Layerです。
画像をA -> Bに変換するときに、std(B) * ((A - mean(A) / std(A) ) + mean(B)として、相手に近づけるNormalizationらしいです(あまり読んでいません)。
この後、U-GAT-ITというイメージ変換でAdaLIN(Adaptive Layer-Instance Normalization)という正規化が誕生しました[2](こちらもあまり読んでいません)。正規化はr : 1-r (0 <= r <= 1)でInstance NormalizationとLayer Normalizationの結果を掛けて足す手法です。AdaILNとも呼ばれるようです[3]。
また、Adaという部分は、Normalizationの結果をアフィン変換するときのβとγを事前のFully-Connented層で決めることから名づけられたと思います(不確実)。これを事前ではなく、Normalization Layer自体に学習させるのがILNみたいです。
今回は「メルアイコン変換器」[3]に触発されて自分でILNレイヤーを作ってみました。簡単にAdaILNにも変換できると思います。
ただ、そのままコピペして使ったらエラーが出たのと、Normalizationは事前に用意されているものを使った方が安全そうなのでその部分を書き直しました。
#Pytorchで実装しようとして現れた問題点
・Instance Normalizationはsizeが(1,1)の画像を安易に入れるとエラーが出る
・Layer Normalizationは他のNormalizationと__init__の形が違う
・InstanceNormalization2dは使えるのに自作ILNレイヤーだとnanになることがある
・clampを使うとなぜかnanが現れる
これらを解決して何とかそれっぽいものができました。
#解説と結果
・Instance Normalizationはsizeが(1,1)の画像を安易に入れるとエラーが出る
そもそもサイズが1x1の画像ならInstance Normalizationしなくていいじゃないというお気持ちを持たせた
class InstanceNorm2d_check_size_1(nn.InstanceNorm2d):
def __init__(self,*args, **kwargs):
super(InstanceNorm2d_check_size_1,self).__init__(*args, **kwargs)
def forward(self,x):
if x.shape[2] > 1:
x = super(InstanceNorm2d_check_size_1,self).forward(x)
return x
def instanceNorm2d(*args, **kwargs):
layer = InstanceNorm2d_check_size_1(*args, **kwargs)
return layer
・Layer Normalizationは他のNormalizationと__init__の形が違う
実はLayerNormalizationはGroupNormalizationの__init__の引数であるnum_groupsとnum_channelsを同じにしたものなのでこれを利用した。
・InstanceNormalization2dは使えるのに自作ILNレイヤーだとnanになることがある
Normalization系統はepsを基準の1e-5ではなく1e-4にするとエラーがでなくなる
・clampを使うとなぜかnanが現れる
パラメータそのものをclampするとエラーが出る?謎
というわけでLINレイヤーができました。gammaとbetaのパラメータをselfから外してforwardに入れればAdaLINになります。
class ILN(nn.Module):
def __init__(self, num_features, eps=1e-4,affine=False):
super(ILN, self).__init__() #epsは1e-4
self.eps = eps
self.num_features = num_features
self.affine = affine
self.rho = nn.Parameter(torch.Tensor(1, self.num_features, 1, 1))
self.gamma = nn.Parameter(torch.Tensor(1, self.num_features, 1, 1))
self.beta = nn.Parameter(torch.Tensor(1, self.num_features, 1, 1))
self.ones = nn.Parameter(torch.ones(1, self.num_features, 1, 1),requires_grad = False)
self.rho.data.fill_(0.1) #0にしても別に良さそう
self.instanceNorm2d = InstanceNorm2d_check_size_1(self.num_features, eps=eps) #そのままだとNaN
self.layernorm = nn.GroupNorm(num_groups = self.num_features,num_channels = self.num_features,affine=False,eps=eps) #Affineは最後にやるのでここではオフ
if self.affine:
self.gamma.data.fill_(1.0)
self.beta.data.fill_(0.0)
def forward(self, input):
out = input
self.rho.data.clamp(0.0,1.0) #そのままではなく.data.clamp()にする
shape = input.shape
if (self.training):
#Normalizationなのでトレーニング時のみ
out_in = self.instanceNorm2d(input).to(out.device)
out_ln = self.layernorm(input).view(shape).to(out.device)
if not ( torch.isnan(out_in.mean()) ) and( not torch.isnan(out_ln.mean()) ):
#本当はこのif文は必要なさそうだがデバッグ用、消してもよい
out = self.rho * out_in + (1.0 -self.rho) * out_ln
#assert not torch.isnan(out.mean())
if self.affine:
out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1)
#assert not torch.isnan(out.mean())
else:
print("Nan")
return out
せっかく作ったけれど自分のモデルではあまり役立っていない気がする……
#ライセンス
MIT
#引用
[1] Xun Huang, Serge Belongie (2017) Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization arXiv:1703.06868
[2] Junho Kim, Minjae Kim, Hyeonwoo Kang, Kwanghee Lee (2019) U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation arXiv:1907.10830
[3] @zassou65535 メルアイコン変換器を作った話 https://qiita.com/zassou65535/items/4bc42fa36203c13fe2d3