#Shake-Shake Regularizationとは?
正則化の1つです。
擬似的に学習データを増大させることで、長くゆっくり学習できる利点があります。
データ数が少ないときにdata augmentationとして有効なのかも?
とりあえず今回はCIFAR10で試してみたいと思います。
Shake-Shake Regularizationを簡単に説明すると、上のような図になります。
Resnetにおいてredisual blockを2つ並列に作り、residual blockの出力に対し、以下の操作を加えます。
- 学習時の順伝播では0〜1の乱数αをかける
- 誤差逆伝播させるときは0〜1の乱数βをかける(αとは別に生成する)
- 推論時は乱数でなく、定数0.5をかける
詳しくは他の方の記事を参考にしてみてください。私はこちらの記事を読んで理解できました。
https://qiita.com/masataka46/items/fc7f31073c89b02f8a04
#その他、論文で書いてあった細かい点
論文には様々な工夫がされていました。
- Plainアーキテクチャの流れはReLU→Conv→BN→ReLU→Conv→BN→(乱数αとの)Mul
- 3ステージに分け、それぞれにResidual Blockが4つずつ
- ステージごとに32,64,128チャンネルにする
- ステージ1の前に3×3のConvをかける
- ステージ3の次に8×8のaverage poolingをかける
- 最後はfc層
- 学習率は0.2で、コサインカーブで変動させる
- 1800epoch学習させる
- ダウンサンプリング時のショートカットが特殊
- 画像はランダムフリップと標準化する
- ミニバッチサイズは128
#residual blockの作成
resnetではplainアーキテクチャとbottleneckアーキテクチャの2つがありますが、今回は論文に倣ってPlainアーキテクチャを使いたいと思います。
class ResidualPlainBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, padding=0):
super(ResidualPlainBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv1_2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
self.bn1_2 = nn.BatchNorm2d(out_channels)
self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.bn2_2 = nn.BatchNorm2d(out_channels)
self.identity = nn.Identity()
if in_channels != out_channels:
self.down_avg1 = nn.AvgPool2d(kernel_size=1, stride=1)
self.down_conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=stride, padding=0)
self.down_pad1 = nn.ZeroPad2d((1,0,1,0))
self.down_avg2 = nn.AvgPool2d(kernel_size=1, stride=1)
self.down_conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=stride, padding=0)
#down sampling時の処理が特殊
def shortcut(self,x):
x = F.relu(x)
h1 = self.down_avg1(x)
h1 = self.down_conv1(h1)
h2 = self.down_pad1(x[:,:,1:,1:])
h2 = self.down_avg1(h2)
h2 = self.down_conv2(h2)
return torch.cat((h1,h2),axis=1)
def forward(self, x):
if self.training:
#1つ目のResdual Block
out = self.bn1(self.conv1(F.relu(x)))
out = self.bn2(self.conv2(F.relu(out)))
#2つ目のResidual Block
out2 = self.bn1_2(self.conv1_2(F.relu(x)))
out2 = self.bn2_2(self.conv2_2(F.relu(out2)))
if self.in_channels != self.out_channels:
output = self.shortcut(x) + ShakeShake.apply(out,out2)
else:
output = self.identity(x) + ShakeShake.apply(out,out2)
return output
else:
out = self.bn1(self.conv1(F.relu(x)))
out = self.bn2(self.conv2(F.relu(out)))
out2 = self.bn1_2(self.conv1_2(F.relu(x)))
out2 = self.bn2_2(self.conv2_2(F.relu(out2)))
if self.in_channels != self.out_channels:
output = self.shortcut(x) + (out+out2)*0.5
else:
output = self.identity(x) + (out+out2)*0.5
return output
コンストラクタはごちゃごちゃしてますが、forward関数を見ればなんとなく理解できるのではないでしょうか。
forward関数の中身は、
1:受け取ったxを2つのblockに与える
2:out,out2を出力させる
3:out,out2を**ShakeShake.apply()**に処理してもらう
4:ショートカットと3を足し合わせて1つにまとめて出力する
##捕捉1:ShakeShake.apply()
ShakeShakeクラスというクラスを定義して、forwardとbackwardでの処理を定義することができます。
class ShakeShake(torch.autograd.Function):
@staticmethod
def forward(ctx, i1, i2):
alpha = random.random()
result = i1 * alpha + i2 * (1-alpha)
return result
@staticmethod
def backward(ctx, grad_output):
beta = random.random()
return grad_output * beta, grad_output * (1-beta)
forwardでは乱数alphaを生成してout,out2にかけています。
backwardでは新たに乱数betaを生成して、grad_output(誤差逆伝播により伝わってきた値)にかけています。
##捕捉2:学習率をコサインカーブで変動させる
PyTorchでは学習率をスケジューリングすることができます。
以下のように実装します。
learning_rate = 0.02
optimizer = optim.SGD(net.parameters(),lr=learning_rate,momentum=0.9,weight_decay=0.0001)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=0.001)
for i in range(200):
#for 1epoch
#ここで1epoch分学習...
scheduler.step()
こうすることで、1エポックごとに学習率がコサインカーブに沿って変動していきます。
optimizerを定義したら、CosineAnnealingLRというものを呼び出します。
第一引数にoptimizer,第二引数(T_max)はコサインの半周期までのステップ数(エポック数)、第三引数は学習率の最小値です。
上の場合、学習率は50エポックで0.02から0.001まで下がり、その後50エポックで元に戻り、その後50エポックで下がり・・・という風になります。
#実行結果
**精度は89.43%**でした。
論文では95%越えなので、微妙ですね。
ですが、特に工夫してない通常のResNetは精度が80%程度なので、それよりかは精度が良さそうです。
青いのがtrain_accで、オレンジがtest_accです。
実は論文通りに実装してないところがいくつかあります。
- 1800epochでなく200epochしか学習させてない
- 論文だと学習率の最大値が0.2に設定されているが、それだと誤差がnanになってしまったので0.02に変更した
- 論文の読み間違いもあるかも?
#最後に
Shake-Shake Regularizationは強力な正則化手法として注目されています。
最近ではShake Dropという新たな手法も考案されているっぽいので、そちらも実装してみます。