0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

LIN(AdaLIN)レイヤーを作ってみた

Last updated at Posted at 2021-08-16

#前置き
最初にAdaIN(Adaptive Instance Normalization)[1]を説明します。これはStyle Transferで利用されるNormalization Layerです。
画像をA -> Bに変換するときに、std(B) * ((A - mean(A) / std(A) ) + mean(B)として、相手に近づけるNormalizationらしいです(あまり読んでいません)。
この後、U-GAT-ITというイメージ変換でAdaLIN(Adaptive Layer-Instance Normalization)という正規化が誕生しました[2](こちらもあまり読んでいません)。正規化はr : 1-r (0 <= r <= 1)でInstance NormalizationとLayer Normalizationの結果を掛けて足す手法です。AdaILNとも呼ばれるようです[3]。
また、Adaという部分は、Normalizationの結果をアフィン変換するときのβとγを事前のFully-Connented層で決めることから名づけられたと思います(不確実)。これを事前ではなく、Normalization Layer自体に学習させるのがILNみたいです。
今回は「メルアイコン変換器」[3]に触発されて自分でILNレイヤーを作ってみました。簡単にAdaILNにも変換できると思います。
ただ、そのままコピペして使ったらエラーが出たのと、Normalizationは事前に用意されているものを使った方が安全そうなのでその部分を書き直しました。
#Pytorchで実装しようとして現れた問題点
・Instance Normalizationはsizeが(1,1)の画像を安易に入れるとエラーが出る
・Layer Normalizationは他のNormalizationと__init__の形が違う
・InstanceNormalization2dは使えるのに自作ILNレイヤーだとnanになることがある
・clampを使うとなぜかnanが現れる
これらを解決して何とかそれっぽいものができました。

#解説と結果
・Instance Normalizationはsizeが(1,1)の画像を安易に入れるとエラーが出る
そもそもサイズが1x1の画像ならInstance Normalizationしなくていいじゃないというお気持ちを持たせた

class InstanceNorm2d_check_size_1(nn.InstanceNorm2d):
  def __init__(self,*args, **kwargs):
    super(InstanceNorm2d_check_size_1,self).__init__(*args, **kwargs)
  def forward(self,x):
    if x.shape[2] > 1:
      x = super(InstanceNorm2d_check_size_1,self).forward(x)
    return x
def instanceNorm2d(*args, **kwargs):
  layer = InstanceNorm2d_check_size_1(*args, **kwargs)
  return layer

・Layer Normalizationは他のNormalizationと__init__の形が違う
実はLayerNormalizationはGroupNormalizationの__init__の引数であるnum_groupsとnum_channelsを同じにしたものなのでこれを利用した。

・InstanceNormalization2dは使えるのに自作ILNレイヤーだとnanになることがある
Normalization系統はepsを基準の1e-5ではなく1e-4にするとエラーがでなくなる

・clampを使うとなぜかnanが現れる
パラメータそのものをclampするとエラーが出る?謎

というわけでLINレイヤーができました。gammaとbetaのパラメータをselfから外してforwardに入れればAdaLINになります。

class ILN(nn.Module):
    def __init__(self, num_features, eps=1e-4,affine=False):
      super(ILN, self).__init__() #epsは1e-4
      self.eps = eps
      self.num_features = num_features
      self.affine = affine
      self.rho = nn.Parameter(torch.Tensor(1, self.num_features, 1, 1))
      self.gamma = nn.Parameter(torch.Tensor(1, self.num_features, 1, 1))
      self.beta = nn.Parameter(torch.Tensor(1, self.num_features, 1, 1))
      self.ones = nn.Parameter(torch.ones(1, self.num_features, 1, 1),requires_grad = False)
      self.rho.data.fill_(0.1) #0にしても別に良さそう
      self.instanceNorm2d = InstanceNorm2d_check_size_1(self.num_features, eps=eps) #そのままだとNaN
      self.layernorm = nn.GroupNorm(num_groups = self.num_features,num_channels = self.num_features,affine=False,eps=eps) #Affineは最後にやるのでここではオフ

      if self.affine:
        self.gamma.data.fill_(1.0)
        self.beta.data.fill_(0.0)
 
    def forward(self, input):
      out = input
      self.rho.data.clamp(0.0,1.0) #そのままではなく.data.clamp()にする
      shape = input.shape
      if (self.training):
        #Normalizationなのでトレーニング時のみ
        out_in = self.instanceNorm2d(input).to(out.device)
        out_ln =  self.layernorm(input).view(shape).to(out.device)
        if not ( torch.isnan(out_in.mean()) ) and( not torch.isnan(out_ln.mean()) ):
          #本当はこのif文は必要なさそうだがデバッグ用、消してもよい
          out = self.rho * out_in + (1.0 -self.rho) * out_ln
          #assert not torch.isnan(out.mean())
          if self.affine:
            out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1)
          #assert not torch.isnan(out.mean())
        else:
          print("Nan")
      
      return out

せっかく作ったけれど自分のモデルではあまり役立っていない気がする……

#ライセンス
MIT

#引用
[1] Xun Huang, Serge Belongie (2017) Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization arXiv:1703.06868
[2] Junho Kim, Minjae Kim, Hyeonwoo Kang, Kwanghee Lee (2019) U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation arXiv:1907.10830
[3] @zassou65535 メルアイコン変換器を作った話 https://qiita.com/zassou65535/items/4bc42fa36203c13fe2d3

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?