1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

論文の勉強3 WideResNet

Last updated at Posted at 2021-12-16

論文の勉強をメモ書きレベルですがのせていきます。
構造部分に注目し、その他の部分は書いていません。ご了承ください。

今回は、以下の論文のWide ResNetの実装を行います。

実装はほとんどResNetと変わりません。

wide residual network

恒等変換を使用したresidual blockは、

x_{l+1}=x_l+F(x_l,W_l)

で表されます。
ここで、$x_l,x_{l+1}$はネットワークの$l$番目のブロックの出力で、$F$は残差関数、そして$W_l$はそのネットワークでの重みを表します。
residuak networkはこのブロックの積み重ねとなります。

residual networkには2種類のブロックがあります。

  • basic 2つの3×3の畳み込み層で構成されます
  • bottleneck 3×3の畳み込み層を1×1の畳み込み層ではさみこんだ構成となります

今回は、basicに注目して話を進めていきます。
このbasicnの2つの3×3の畳み込み層をB(3,3)、bottlebeckはB(1,3,1)とここでは表すこととします。
ここで、この論文で出てくるresidual blockのイメージを示します。

image.png

今回はResNetのときより、よりワイドな構造をもったresidual blockを実装していきます。
ResNetでBatchNormalizationとReLUは、「conv-BN-ReLU」という順番で挿入していましたが、「BN-ReLU-conv」という順番に変更します。
畳み込み層をReLUの間にドロップアウト層を入れることで過学習を防いでいます。

構造

本来、さまざまな層の深さやブロック内での畳み込み層の数などを比較していますが、ここでは割愛します。
詳細は論文を参照してください。

今回実装する構造は以下に示すものです。

image.png

ここで、$k$はチャンネル数の大きさを表すもので$k=1$は「薄い」ネットワークであり、$k>1$のときに「ワイドな」ネットワークであるとします。
また、$N$は各ブロックの繰り返し数を決める値で、全体の層の数$n(depth)$から求められます。
ここでは、全体の層の数が$n$、チャンネルの大きさを決める因子が$k$のresidualネットワークをWRN-n-kと表すこととします。
さらに、ブロックの種類がB(3,3)であるときは、WRN-n-k-B(3,3)と表します。

学習

今回は、WRN-28-10モデルを実装していきます。

入力は、32×32のカラー画像とします。
学習には、Nesterovの加速勾配降下法を使用したSGDと、交差エントロピー誤差を使用します。
学習率は0.1で、weight decayを0.0005、momentumを0.9とします。
学習率は、エポック数が60、120、160のときに0.2倍して減衰させていきます。

実装

keras

必要なライブラリのインポートを行います。

import tensorflow.keras as keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Conv2D, Activation, MaxPooling2D, Flatten, Dense, Dropout, GlobalAveragePooling2D, BatchNormalization, Add
from keras import backend as K
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.callbacks import LearningRateScheduler
from keras.datasets import cifar10
import numpy as np
import cv2

wide residual blockの実装をします。B(3,3)タイプのものです。
実装はResNetとほぼ同じです。

class wide_res_block(Model):
    """wide residual block"""
    def __init__(self, out_channels, block_num, layer_num):
        super(wide_res_block, self).__init__(name='block'+'_'+str(block_num)+'_'+str(layer_num))        
        block_name = '_'+str(block_num)+'_'+str(layer_num)
        
        # shortcutとstrideの設定
        if (layer_num == 0):
            # 最初のresblockは(W、 H)は変更しないのでstrideは1にする
            if (block_num==1):
                self._is_change = False
                stride = 1
            else:
                self._is_change = True
                stride = 2
            self.conv_sc = Conv2D(out_channels, kernel_size=1, strides=stride, padding='same', use_bias=False, name='conv_sc'+block_name)
            self.bn_sc = BatchNormalization(name='bn_sc'+block_name)
        else:
            self._is_change = False
            stride = 1

        # 1層目 3×3 畳み込み処理を行います
        self.bn1 = BatchNormalization(name='bn1'+block_name)
        self.act1 = Activation('relu', name='act1'+block_name)
        self.drop1 = Dropout(rate=0.3, name='drop1'+block_name)
        self.conv1 = Conv2D(out_channels, kernel_size=3, strides=stride, padding='same', use_bias=False, name='conv1'+block_name)
        
        # 2層目 3×3 畳み込み処理を行います
        self.bn2 = BatchNormalization(name='bn2'+block_name)
        self.act2 = Activation('relu', name='act2'+block_name)
        self.drop2 = Dropout(rate=0.3, name='drop2'+block_name)
        self.conv2 = Conv2D(out_channels, kernel_size=3, strides=1, padding='same', use_bias=False, name='conv2'+block_name)
        
        self.add = Add(name='add'+block_name)
        
    def call(self, x):
        out = self.bn1(x)
        out = self.act1(out)
        out = self.drop1(out)
        out = self.conv1(out)
        
        out = self.bn2(out)
        out = self.act2(out)
        out = self.drop2(out)
        out = self.conv2(out)
        
        if K.int_shape(x) != K.int_shape(out):
            shortcut = self.conv_sc(x)
            shortcut = self.bn_sc(shortcut)
        else:
            shortcut = x
        
        out = self.add([out, shortcut])
        return out

WideResNetを実装していきます。
depthとkを引数に与えて構築します。

class WideResNet(Model):
    def __init__(self, depth, k, num_classes=10):
        super().__init__()
        self._layers = []
        
        # 各ネットワークでのブロックの数
        N = (depth - 4) // 6
        
        # 入力層
        self._layers += [
            Conv2D(filters = 16, kernel_size = 3, strides = 1, padding = 'same', name='conv_input')
        ]

        # Residualブロック
        self._layers += [wide_res_block(out_channels=16*k, block_num=1, layer_num=i) for i in range(N)]
        self._layers += [wide_res_block(out_channels=32*k, block_num=2, layer_num=i) for i in range(N)]
        self._layers += [wide_res_block(out_channels=64*k, block_num=3, layer_num=i) for i in range(N)]

        # 出力層
        self._layers += [
            BatchNormalization(name='bn_input'),
            Activation('relu', name='act_input'),
            GlobalAveragePooling2D(name='pool_output'),
            Dense(num_classes, activation='softmax', name='output')
        ]
    
    def call(self, x):
        for layer in self._layers:
            x = layer(x)
        return x

ここで、構造の確認をします。
WideResNetでは、入力は32×32×です。

model = WideResNet(depth=28, k=10, num_classes=10)
model.build((None, 32, 32, 3))  # build with input shape.
dummy_input = Input(shape=(32, 32, 3))  # declare without batch demension.
model_summary = Model(inputs=[dummy_input], outputs=model.call(dummy_input))
model_summary.summary()

image.png

今回の学習ではエポック数に応じて学習率を小さくしていきます。

# 学習率を返す関数を用意する
def lr_schedul(epoch):
    x = 0.1
    if epoch >= 60:
        x = 0.1*0.2
    if epoch >= 120:
        x = 0.1*(0.2**2)
    if epoch >= 160:
        x = 0.1*(0.2**3)
    return x


lr_decay = LearningRateScheduler(
    lr_schedul,
    verbose=1,
)

sgd = SGD(lr=0.1, momentum=0.9, decay=1e-4, nesterov=True)
model.compile(loss=['categorical_crossentropy'], optimizer=sgd, metrics=['accuracy'])

pytorch

次にpytorchで実装を行います。
ほとんどResNetと同じため説明は割愛します。

import torch
import torch.nn as nn
import torch.optim as optim

import pytorch_lightning as pl
from torchmetrics import Accuracy as accuracy
class wide_res_block(nn.Module):
    """wide residual block"""
    def __init__(self, out_channels, block_num, layer_num):
        super(wide_res_block, self).__init__()
        
        # 1番目のブロック以外はチャンネル数がinputとoutputで変わる(output=4×input)
        if (layer_num==0):
            if (block_num==1):
                input_channels = 16
            else:
                input_channels = out_channels//2
        else:
            input_channels = out_channels

        # shortcutとstrideの設定
        if (layer_num == 0):
            self._is_change = True
            # 最初のresblockは(W、 H)は変更しないのでstrideは1にする
            if (block_num==1):
                stride = 1
            else:
                stride = 2
            
            self.conv_sc = nn.Conv2d(input_channels, out_channels, kernel_size=1, stride=stride)
            self.bn_sc = nn.BatchNorm2d(out_channels)
        else:
            self._is_change = False
            stride = 1

        # 1層目 3×3 畳み込み処理を行います
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.drop1 = nn.Dropout(p=0.3)
        self.conv1 = nn.Conv2d(input_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        
        
        # 2層目 3×3 畳み込み処理を行います
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.drop2 = nn.Dropout(p=0.3)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        
        self.relu = nn.ReLU(inplace=True)
        

    def forward(self, x):
        shortcut = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.drop1(out)
        out = self.conv1(out)
        
        out = self.bn2(out)
        out = self.relu(out)
        out = self.drop2(out)
        out = self.conv2(out)
        
        # Projection shortcutの場合
        if self._is_change:
            shortcut = self.conv_sc(shortcut)
            shortcut = self.bn_sc(shortcut)
        
        out += shortcut
        return out
class WideResNet(nn.Module):
    def __init__(self, depth, k, num_classes=10):
        super(WideResNet, self).__init__()
        
        # 各ネットワークでのブロックの数
        N = (depth - 4) // 6
        
        conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Sequential(*[conv1])
        
        self.conv2_x = nn.Sequential(*[wide_res_block(out_channels=16*k, block_num=1, layer_num=i) for i in range(N)])
        self.conv3_x = nn.Sequential(*[wide_res_block(out_channels=32*k, block_num=2, layer_num=i) for i in range(N)])
        self.conv4_x = nn.Sequential(*[wide_res_block(out_channels=64*k, block_num=3, layer_num=i) for i in range(N)])
        
        bn = nn.BatchNorm2d(64*k)
        relu = nn.ReLU(inplace=True)
        pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Sequential(*[bn, relu, pool])
        
        self.linear = nn.Linear(in_features=640, out_features=num_classes)
    
    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2_x(out)
        out = self.conv3_x(out)
        out = self.conv4_x(out)
        out = self.fc(out)
        out = out.view(out.shape[0], -1)
        out = self.linear(out)
        
        return out
from torchsummary import summary

summary(WideResNet(depth=28, k=10, num_classes=10), (3,32,32))

image.png

エポック数に応じて、学習率を下げるときにはMultiStepLRという関数が用意されています。

class WRNTrainer(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = WideResNet(depth=28, k=10, num_classes=10)
        
    def forward(self, x):
        x = self.model(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch 
        #x, y = x.to(device), y.to(device)
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        return {'loss': loss, 'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)}
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        #x, y = x.to(device), y.to(device)
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        return {'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)}
    
    def test_step(self, batch, batch_nb):
        x, y = batch
        #x, y = x.to(device), y.to(device)
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        y_label = torch.argmax(y_hat, dim=1)
        acc = accuracy()(y_label, y)
        return {'test_loss': loss, 'test_acc': acc}
    
    def training_epoch_end(self, train_step_output):
        y_hat = torch.cat([val['y_hat'] for val in train_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in train_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in train_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy()(preds, y)
        self.log('train_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('train_acc', acc, prog_bar=True, on_epoch=True)
        
        print('---------- Current Epoch {} ----------'.format(self.current_epoch + 1))
        print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loass, acc))
    
    def validation_epoch_end(self, val_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in val_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in val_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in val_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy()(preds, y)
        self.log('val_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', acc, prog_bar=True, on_epoch=True)
        
        print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc))
    
    # New: テストデータに対するエポックごとの処理
    def test_epoch_end(self, test_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in test_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in test_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in test_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy()(preds, y)
        self.log('test_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('test_acc', acc, prog_bar=True, on_epoch=True)
        
        print('test Loss: {:.4f} test Acc: {:.4f}'.format(epoch_loss, acc))
        
    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}

これでWideResNetの実装を終わります。
ResNetの派生形をしばらく扱っていきます。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?