1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchでShake-Shake Regularization(ShakeNet)を実装してみた

Last updated at Posted at 2020-01-26

#Shake-Shake Regularizationとは?

正則化の1つです。
擬似的に学習データを増大させることで、長くゆっくり学習できる利点があります。

データ数が少ないときにdata augmentationとして有効なのかも?
とりあえず今回はCIFAR10で試してみたいと思います。

shake-shake-regularization.jpg

Shake-Shake Regularizationを簡単に説明すると、上のような図になります。
Resnetにおいてredisual blockを2つ並列に作り、residual blockの出力に対し、以下の操作を加えます。

  • 学習時の順伝播では0〜1の乱数αをかける
  • 誤差逆伝播させるときは0〜1の乱数βをかける(αとは別に生成する)
  • 推論時は乱数でなく、定数0.5をかける

詳しくは他の方の記事を参考にしてみてください。私はこちらの記事を読んで理解できました。
https://qiita.com/masataka46/items/fc7f31073c89b02f8a04

#その他、論文で書いてあった細かい点

論文には様々な工夫がされていました。

  • Plainアーキテクチャの流れはReLU→Conv→BN→ReLU→Conv→BN→(乱数αとの)Mul
  • 3ステージに分け、それぞれにResidual Blockが4つずつ
  • ステージごとに32,64,128チャンネルにする
  • ステージ1の前に3×3のConvをかける
  • ステージ3の次に8×8のaverage poolingをかける
  • 最後はfc層
  • 学習率は0.2で、コサインカーブで変動させる
  • 1800epoch学習させる
  • ダウンサンプリング時のショートカットが特殊
  • 画像はランダムフリップと標準化する
  • ミニバッチサイズは128

#residual blockの作成

resnetではplainアーキテクチャとbottleneckアーキテクチャの2つがありますが、今回は論文に倣ってPlainアーキテクチャを使いたいと思います。

test.py
class ResidualPlainBlock(nn.Module):

    def __init__(self, in_channels, out_channels, stride, padding=0):
        super(ResidualPlainBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.conv1 = nn.Conv2d(in_channels,  out_channels, kernel_size=3,  stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels,  kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv1_2 = nn.Conv2d(in_channels,  out_channels, kernel_size=3,  stride=stride, padding=1)
        self.bn1_2 = nn.BatchNorm2d(out_channels)

        self.conv2_2 = nn.Conv2d(out_channels, out_channels,  kernel_size=3, stride=1, padding=1)
        self.bn2_2 = nn.BatchNorm2d(out_channels)

        self.identity = nn.Identity()

        if in_channels != out_channels:
          self.down_avg1 = nn.AvgPool2d(kernel_size=1, stride=1)
          self.down_conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=stride, padding=0)
          self.down_pad1 = nn.ZeroPad2d((1,0,1,0))
          self.down_avg2 = nn.AvgPool2d(kernel_size=1, stride=1)
          self.down_conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=stride, padding=0)

    #down sampling時の処理が特殊 
    def shortcut(self,x):
      x = F.relu(x)
      h1 = self.down_avg1(x)
      h1 = self.down_conv1(h1)
      h2 = self.down_pad1(x[:,:,1:,1:])
      h2 = self.down_avg1(h2)
      h2 = self.down_conv2(h2)
      return torch.cat((h1,h2),axis=1)


    def forward(self, x):
      if self.training:
        #1つ目のResdual Block
          out = self.bn1(self.conv1(F.relu(x)))
          out = self.bn2(self.conv2(F.relu(out)))
          
        #2つ目のResidual Block
          out2 = self.bn1_2(self.conv1_2(F.relu(x)))
          out2 = self.bn2_2(self.conv2_2(F.relu(out2)))

          if self.in_channels != self.out_channels:
            output = self.shortcut(x) + ShakeShake.apply(out,out2)
          else:
            output = self.identity(x) + ShakeShake.apply(out,out2)
          
          return output
      else:
          out = self.bn1(self.conv1(F.relu(x)))
          out = self.bn2(self.conv2(F.relu(out)))
          
          out2 = self.bn1_2(self.conv1_2(F.relu(x)))
          out2 = self.bn2_2(self.conv2_2(F.relu(out2)))

          if self.in_channels != self.out_channels:
            output = self.shortcut(x) + (out+out2)*0.5
          else:
            output = self.identity(x) + (out+out2)*0.5
          
          return output

コンストラクタはごちゃごちゃしてますが、forward関数を見ればなんとなく理解できるのではないでしょうか。

forward関数の中身は、
1:受け取ったxを2つのblockに与える
2:out,out2を出力させる
3:out,out2を**ShakeShake.apply()**に処理してもらう
4:ショートカットと3を足し合わせて1つにまとめて出力する

##捕捉1:ShakeShake.apply()

ShakeShakeクラスというクラスを定義して、forwardとbackwardでの処理を定義することができます。

test.py
class ShakeShake(torch.autograd.Function):
  @staticmethod
  def forward(ctx, i1, i2):
    alpha = random.random()
    result = i1 * alpha + i2 * (1-alpha)

    return result
  @staticmethod
  def backward(ctx, grad_output):
    beta  = random.random()

    return grad_output * beta, grad_output * (1-beta)

forwardでは乱数alphaを生成してout,out2にかけています。

backwardでは新たに乱数betaを生成して、grad_output(誤差逆伝播により伝わってきた値)にかけています。

##捕捉2:学習率をコサインカーブで変動させる

PyTorchでは学習率をスケジューリングすることができます。

以下のように実装します。

test.py
learning_rate = 0.02
optimizer = optim.SGD(net.parameters(),lr=learning_rate,momentum=0.9,weight_decay=0.0001)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=0.001)

for i in range(200):
  #for 1epoch
    #ここで1epoch分学習...
  scheduler.step()

こうすることで、1エポックごとに学習率がコサインカーブに沿って変動していきます。

optimizerを定義したら、CosineAnnealingLRというものを呼び出します。

第一引数にoptimizer,第二引数(T_max)はコサインの半周期までのステップ数(エポック数)、第三引数は学習率の最小値です。

上の場合、学習率は50エポックで0.02から0.001まで下がり、その後50エポックで元に戻り、その後50エポックで下がり・・・という風になります。

#実行結果

**精度は89.43%**でした。

論文では95%越えなので、微妙ですね。

ですが、特に工夫してない通常のResNetは精度が80%程度なので、それよりかは精度が良さそうです。

青いのがtrain_accで、オレンジがtest_accです。
17-89.43.png

実は論文通りに実装してないところがいくつかあります。

  • 1800epochでなく200epochしか学習させてない
  • 論文だと学習率の最大値が0.2に設定されているが、それだと誤差がnanになってしまったので0.02に変更した
  • 論文の読み間違いもあるかも?

#最後に

Shake-Shake Regularizationは強力な正則化手法として注目されています。

最近ではShake Dropという新たな手法も考案されているっぽいので、そちらも実装してみます。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?