1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

論文の勉強4 PyramidNet

Posted at

論文の勉強をメモ書きレベルですがのせていきます。あくまでも自分の勉強目的です。
構造部分に注目し、その他の部分は書いていません。ご了承ください。
本当にいい加減・不正確な部分が多数あると思いますのでご理解ください。

今回は、以下の論文のPyramid Netの実装を行います。

タイトル:Deep Pyramidal Residual Networks

ほとんどResNetと変わりません。

PyramidNets

ResNet

ResNetなどでは、特徴マップのサイズが減少するタイミングでチャネル数を増加させていました。
ResNetでは、$n$番目のグループの$k$番目のresidualユニットにおける特徴マップのチャネル数$D_k$は、

D_k=\left\{\begin{array}{ll}
16&n(k)=1\\
16・2^{n(k)-2}&n(k)\geq1
\end{array}\right.

となります。ここで、$n(k)\in{1,2,3,4}$ は、$k$番目のresidualユニットが属するグループのインデックスを表しています。
同じグループに属しているユニットの特徴マップは同じサイズであり、$n$番目のグループは$N_n$個のユニットを含むものとします。
1番目のグループは、RGB画像を複数のチャネルに変換する1つの畳み込み層からなります。
n番目のグループでは、$N_n$個のユニットを通して特徴マップのサイズは半分に、チャネル数は2倍となります。

additive PyramidNet

image.png

D_k=\left\{\begin{array}{ll}
16&k=1\\
\lfloor D_{k-1}+\alpha/N\rfloor &2\leq k\leq N+1
\end{array}\right.

ここで、$N$はresidualユニットの総数であり、
$$
N=\sum_{n=2}^4N_n
$$
となります。
チャネル数は$\alpha / N$ずつ増加、各グループのユニット数が同じであれば各グループの最後のユニットでのチャネル数は$16+(n-1)\alpha/3$と計算できます。$\alpha$はどの程度チャネル数を増加させるかを表す因子です。
また、$\lfloor \rfloor$は床関数と呼ばれるもので、その値を超えない整数を表します。

multiplicative PyramidNet

image.png

D_k=\left\{\begin{array}{ll}
16&k=1\\
\lfloor D_{k-1}・\alpha^{1/N}\rfloor &2\leq k\leq N+1
\end{array}\right.

additiveのものと比べ、input側ではチャネル数はゆっくり変化するが、outputに近くなるほど急に変化します。
すべてのユニットでチャネル数が増加するので、shortcut connectionではzero-paddingでチャネル数を合わせます。

image.png

ここでは、bottleneckで、$N=272(N_n=(272-2)/9=30)$で$\alpha=200$のものを実装します。
inputは1層目で16チャネルに変換され、ユニットを通るごとに$200/90$ずつ増加していきます。
各グループで(66~67)×4チャネルずつ増加し、最終的には(16+200)×4=864チャネルとなります。

さらに、baseのもので$N=110(N_n=(110-2)/6=18)$で$\alpha=270$のものを実装します。

image.png

学習

最適化手法はSGD(Nesterov)で学習率は0.1、エポック数が150と225のときに0.1を掛けて小さくしていきます。
weight decayは0.0001、 momentumを0.9に設定し、バッチサイズは128とします。

実装

keras

必要なライブラリのインポートを行います。

import tensorflow.keras as keras
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Conv2D, Activation, MaxPooling2D, Flatten, Dense, Dropout, GlobalAveragePooling2D, BatchNormalization, Add
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.callbacks import LearningRateScheduler
from keras.datasets import cifar10
import numpy as np
import cv2

bottleneckブロックの実装をします。

class res_bottleneck_block(Model):
    """wide residual block"""
    def __init__(self, block_num, layer_num, Nn, N, alpha):
        super(res_bottleneck_block, self).__init__(name='block'+'_'+str(block_num)+'_'+str(layer_num))        
        block_name = '_'+str(block_num)+'_'+str(layer_num)
        
        # shortcutとstrideの設定
        self._is_change=False
        if (layer_num == 0):
            # 最初のresblockは(W、 H)は変更しないのでstrideは1にする
            if (block_num==1):
                stride = 1
            else:
                stride = 2
                self._is_change=True
        else:
            stride = 1

        self.maxpooling=MaxPooling2D(pool_size=(stride, stride), strides=(stride, stride), padding='valid')

        bneck_channels = np.ceil(16+(Nn*(block_num-1)+layer_num+1)*alpha/N)
        out_channels = bneck_channels*4
        
        self.bn0 = BatchNormalization(name='bn0'+block_name)
        
        # 1層目 1×1 畳み込み処理は行わず(線形変換)、チャネル数をbneck_channelsにします
        self.conv1 = Conv2D(bneck_channels, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv1'+block_name)
        self.bn1 = BatchNormalization(name='bn1'+block_name)
        self.act1 = Activation('relu', name='act1'+block_name)
        
        # 2層目 3×3 畳み込み処理を行います
        self.conv2 = Conv2D(bneck_channels, kernel_size=3, strides=stride, padding='same', use_bias=False, name='conv2'+block_name)
        self.bn2 = BatchNormalization(name='bn2'+block_name)
        self.act2 = Activation('relu', name='act2'+block_name)
        
        # 3層目 1×1 畳み込み処理は行わず(線形変換)、チャネル数をout_channelsにします
        self.conv3 = Conv2D(out_channels, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv3'+block_name)
        self.bn3 = BatchNormalization(name='bn3'+block_name)
        
        self.add = Add(name='add'+block_name)
        
    def call(self, x):
        out = self.bn0(x)
        
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.act1(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.act2(out)
        
        out = self.conv3(out)
        out = self.bn3(out)

        if K.int_shape(x) != K.int_shape(out):
            if self._is_change:
                x = self.maxpooling(x)
            if K.int_shape(x)[3] != K.int_shape(out)[3]:                
                x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, K.int_shape(out)[3] - K.int_shape(x)[3]]])

            shortcut = x
        else:
            shortcut = x
        
        out = self.add([out, shortcut])
        return out

base blockの実装もします。

class res_base_block(Model):
    """wide residual block"""
    def __init__(self, block_num, layer_num, Nn, N, alpha):
        super(res_base_block, self).__init__(name='block'+'_'+str(block_num)+'_'+str(layer_num))        
        block_name = '_'+str(block_num)+'_'+str(layer_num)
        
        # shortcutとstrideの設定
        self._is_change=False
        if (layer_num == 0):
            # 最初のresblockは(W、 H)は変更しないのでstrideは1にする
            if (block_num==1):
                stride = 1
            else:
                stride = 2
                self._is_change=True
        else:
            stride = 1

        self.maxpooling=MaxPooling2D(pool_size=(stride, stride), strides=(stride, stride), padding='valid')

        out_channels = np.ceil(16+(Nn*(block_num-1)+layer_num+1)*alpha/N)
        
        self.bn0 = BatchNormalization(name='bn0'+block_name)
        
        # 1層目 3×3 畳み込み処理は行わず(線形変換)、チャネル数をbneck_channelsにします
        self.conv1 = Conv2D(out_channels, kernel_size=3, strides=stride, padding='same', use_bias=False, name='conv1'+block_name)
        self.bn1 = BatchNormalization(name='bn1'+block_name)
        self.act1 = Activation('relu', name='act1'+block_name)
        
        # 2層目 3×3 畳み込み処理を行います
        self.conv2 = Conv2D(out_channels, kernel_size=3, strides=1, padding='same', use_bias=False, name='conv2'+block_name)
        self.bn2 = BatchNormalization(name='bn2'+block_name)
        
        self.add = Add(name='add'+block_name)
        
    def call(self, x):
        out = self.bn0(x)
        
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.act1(out)
        
        out = self.conv2(out)
        out = self.bn2(out)

        if K.int_shape(x) != K.int_shape(out):
            if self._is_change:
                x = self.maxpooling(x)
            if K.int_shape(x)[3] != K.int_shape(out)[3]:                
                x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, K.int_shape(out)[3] - K.int_shape(x)[3]]])

            shortcut = x
        else:
            shortcut = x
        
        out = self.add([out, shortcut])
        return out

PyramidNetの実装をします。
引数でbottlenackとbaseを指定できるようにします。
ブロックの数の計算が変わってきます。

class PyramidNet(Model):
    def __init__(self,alpha = 200, N = 272, block=res_base_block):
        super(PyramidNet, self).__init__()
        self._layers = []
        
        if block == res_bottleneck_block:
            Nn = (N-2)//9
            N_all = N//3
        else:
            Nn = (N-2)//6
            N_all = N//2
            
        # 入力層
        self._layers += [
            Conv2D(filters = 16, kernel_size = (3,3), strides = 1, padding = 'same', name='conv_input'),
            BatchNormalization(name='bn_input'),
            Activation('relu', name='act_input')
        ]

        # Residualブロック
        self._layers += [block(block_num=1, layer_num=k, Nn=Nn, N=N_all, alpha=alpha) for k in range(Nn)]
        self._layers += [block(block_num=2, layer_num=k, Nn=Nn, N=N_all, alpha=alpha) for k in range(Nn)]
        self._layers += [block(block_num=3, layer_num=k, Nn=Nn, N=N_all, alpha=alpha) for k in range(N_all-2*Nn)]
        
        # 出力層
        self._layers += [
            GlobalAveragePooling2D(name='pool_output'),
            Dense(10, activation='softmax', name='output')
        ]
    
    def call(self, x):
        for layer in self._layers:
            #print(layer)
            x = layer(x)
        return x

モデルの確認をします。

model = PyramidNet(alpha = 270, N = 110, block=res_base_block)
#model = PyramidNet(block=res_bottleneck_block)
model.build((None, 32, 32, 3))  # build with input shape.
dummy_input = Input(shape=(32, 32, 3))  # declare without batch demension.
model_summary = Model(inputs=[dummy_input], outputs=model.call(dummy_input), name="pretrained")
model_summary.summary()

ますは、bottlenackバージョンです。
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_9 (InputLayer) [(None, 32, 32, 3)] 0
_________________________________________________________________
conv_input (Conv2D) (None, 32, 32, 16) 448
_________________________________________________________________
bn_input (BatchNormalization (None, 32, 32, 16) 64
_________________________________________________________________
act_input (Activation) (None, 32, 32, 16) 0
_________________________________________________________________
block_1_0 (res_bottleneck_bl (None, 32, 32, 76) 5517
_________________________________________________________________
block_1_1 (res_bottleneck_bl (None, 32, 32, 84) 8137
_________________________________________________________________
中略
_________________________________________________________________

block_3_27 (res_bottleneck_b (None, 8, 8, 848)         770800    
_________________________________________________________________
block_3_28 (res_bottleneck_b (None, 8, 8, 856)         785348    
_________________________________________________________________
block_3_29 (res_bottleneck_b (None, 8, 8, 864)         800032    
_________________________________________________________________
pool_output (GlobalAveragePo (None, 864)               0         
_________________________________________________________________
output (Dense)               (None, 10)                8650      
=================================================================
Total params: 26,574,858
Trainable params: 26,364,922
Non-trainable params: 209,936
_________________________________________________________________

次に、baseバージョンです。
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_8 (InputLayer) [(None, 32, 32, 3)] 0
_________________________________________________________________
conv_input (Conv2D) (None, 32, 32, 16) 448
_________________________________________________________________
bn_input (BatchNormalization (None, 32, 32, 16) 64
_________________________________________________________________
act_input (Activation) (None, 32, 32, 16) 0
_________________________________________________________________
block_1_0 (res_base_block) (None, 32, 32, 21) 7225
_________________________________________________________________
block_1_1 (res_base_block) (None, 32, 32, 26) 11290
_________________________________________________________________
中略
_________________________________________________________________

block_3_16 (res_base_block)  (None, 8, 8, 277)         1371961   
_________________________________________________________________
block_3_17 (res_base_block)  (None, 8, 8, 282)         1422106   
_________________________________________________________________
block_3_18 (res_base_block)  (None, 8, 8, 286)         1465448   
_________________________________________________________________
pool_output (GlobalAveragePo (None, 286)               0         
_________________________________________________________________
output (Dense)               (None, 10)                2870      
=================================================================
Total params: 29,198,857
Trainable params: 29,148,575
Non-trainable params: 50,282
_________________________________________________________________

image.png

どちらも論文のもの(PyramidNet($\alpha=270$ ),PyramidNet(bottleneck $\alpha=200$ ))と比べパラメータ数が大きく差はありません。

学習用の設定までを実装して終わります。

# 学習率を返す関数を用意する
def lr_schedul(epoch):
    x = 0.1
    if epoch >= 150:
        x = 0.1*0.1
    if epoch >= 225:
        x = 0.1*(0.1**2)
    return x


lr_decay = LearningRateScheduler(
    lr_schedul,
    verbose=1,
)

sgd = SGD(lr=0.1, momentum=0.9, decay=1e-4, nesterov=True)
model.compile(loss=['categorical_crossentropy'], optimizer=sgd, metrics=['accuracy'])

pytorch

次はpytorchです。
ほとんどkerasと変わりません。

import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary

import pytorch_lightning as pl
from torchmetrics import Accuracy as accuracy
class res_bottleneck_block(nn.Module):
    """residual block"""
    def __init__(self, block_num, layer_num, Nn, N, alpha):
        super(res_bottleneck_block, self).__init__()
        
        # 1番目のブロック以外はチャンネル数がinputとoutputで変わる(output=4×input)
        input_channels = int(np.floor(16+(Nn*(block_num-1)+layer_num)*alpha/N)*4)
        
        bneck_channels = int(np.floor(16+(Nn*(block_num-1)+layer_num+1)*alpha/N))
        out_channels = int(bneck_channels*4)
        
        # shortcutとstrideの設定
        self._is_change = False
        if (layer_num == 0):
            input_channels=int(np.floor(16+(Nn*(block_num-2)+Nn)*alpha/N)*4)
            # 最初のresblockは(W、 H)は変更しないのでstrideは1にする
            if (block_num==1):
                stride = 1
                input_channels=16
            else:
                self._is_change = True
                self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
                stride = 2
        else:
            stride = 1
        
        self.pad_channels = out_channels - input_channels
        
        self.bn0 = nn.BatchNorm2d(input_channels)
        
        # 1層目 1×1 畳み込み処理は行わず(線形変換)、チャネル数をbneck_channelsにします
        self.conv1 = nn.Conv2d(input_channels, bneck_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(bneck_channels)
        
        # 2層目 3×3 畳み込み処理を行います
        self.conv2 = nn.Conv2d(bneck_channels, bneck_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(bneck_channels)
        
        # 3層目 1×1 畳み込み処理は行わず(線形変換)、チャネル数をout_channelsにします
        self.conv3 = nn.Conv2d(bneck_channels, out_channels, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(out_channels)
                
        self.relu = nn.ReLU(inplace=True)
        

    def forward(self, x):
        shortcut = x
        
        out = self.bn0(x)
        
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        
        # Projection shortcutの場合
        if self._is_change:
            shortcut = self.maxpool(shortcut)

        p3d = (0, 0, 0, 0, 0, self.pad_channels) # pad by (0, 1), (2, 1), and (3, 3)
        shortcut = torch.nn.functional.pad(shortcut, p3d, "constant", 0)
        
        out += shortcut
        out = self.relu(out)

        return out
class res_base_block(nn.Module):
    """residual block"""
    def __init__(self, block_num, layer_num, Nn, N, alpha):
        super(res_base_block, self).__init__()
        
        # 1番目のブロック以外はチャンネル数がinputとoutputで変わる(output=4×input)
        input_channels = int(np.ceil(16+(Nn*(block_num-1)+layer_num)*alpha/N))
        
        out_channels = int(np.ceil(16+(Nn*(block_num-1)+layer_num+1)*alpha/N))
        
        # shortcutとstrideの設定
        self._is_change = False
        if (layer_num == 0):
            input_channels=int(np.ceil(16+(Nn*(block_num-2)+Nn)*alpha/N))
            # 最初のresblockは(W、 H)は変更しないのでstrideは1にする
            if (block_num==1):
                stride = 1
                input_channels=16
            else:
                self._is_change = True
                self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
                stride = 2
        else:
            stride = 1
        
        self.pad_channels = out_channels - input_channels
        print(block_num,input_channels,out_channels)
        
        self.bn0 = nn.BatchNorm2d(input_channels)
        
        # 1層目 1×1 畳み込み処理は行わず(線形変換)、チャネル数をbneck_channelsにします
        self.conv1 = nn.Conv2d(input_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        # 2層目 3×3 畳み込み処理を行います
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
                
        self.relu = nn.ReLU(inplace=True)
        

    def forward(self, x):
        shortcut = x
        
        out = self.bn0(x)
        
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Projection shortcutの場合
        if self._is_change:
            shortcut = self.maxpool(shortcut)

        p3d = (0, 0, 0, 0, 0, self.pad_channels) # pad by (0, 1), (2, 1), and (3, 3)
        shortcut = torch.nn.functional.pad(shortcut, p3d, "constant", 0)
        out += shortcut
        out = self.relu(out)

        return out
class PyramidNet(nn.Module):
    def __init__(self,alpha = 200, N = 272, num_classes=10, block=res_base_block):
        super(PyramidNet, self).__init__()
        
        if block == res_bottleneck_block:
            Nn = (N-2)//9
            N_all = N//3
            in_features = int((16 + alpha)*4)
        else:
            Nn = (N-2)//6
            N_all = N//2
            in_features = 16 + alpha
        
        conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        bn1 = nn.BatchNorm2d(16)
        relu1 = nn.ReLU(inplace=True)
        
        self.conv1 = nn.Sequential(*[conv1, bn1, relu1])
        
        self.conv2_x = nn.Sequential(*[block(block_num=1, layer_num=k, Nn=Nn, N=N_all, alpha=alpha) for k in range(Nn)])
        self.conv3_x = nn.Sequential(*[block(block_num=2, layer_num=k, Nn=Nn, N=N_all, alpha=alpha) for k in range(Nn)])
        self.conv4_x = nn.Sequential(*[block(block_num=3, layer_num=k, Nn=Nn, N=N_all, alpha=alpha) for k in range(N_all-2*Nn)])
        
        pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Sequential(*[pool])
        #int(np.floor(16+(Nn*(block_num-1)+layer_num)*alpha/N))
        self.linear = nn.Linear(in_features=in_features, out_features=num_classes)
    
    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2_x(out)
        out = self.conv3_x(out)
        out = self.conv4_x(out)
        out = self.fc(out)
        out = out.view(out.shape[0], -1)
        out = self.linear(out)
        
        return out

まず、baseモデルの詳細を見てみます。

summary(PyramidNet(alpha = 270, N = 110, block=res_base_block), (3,32,32))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 32, 32]             448
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0

中略

  res_base_block-445            [-1, 286, 8, 8]               0
AdaptiveAvgPool2d-446            [-1, 286, 1, 1]               0
          Linear-447                   [-1, 10]           2,870
================================================================
Total params: 29,165,505
Trainable params: 29,165,505
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 132.18
Params size (MB): 111.26
Estimated Total Size (MB): 243.45
----------------------------------------------------------------

次に、bottleneckモデルの詳細を見てみます。

summary(PyramidNet(block=res_bottleneck_block), (3,32,32))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 32, 32]             448
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0

中略


          Conv2d-992            [-1, 864, 8, 8]         187,488
     BatchNorm2d-993            [-1, 864, 8, 8]           1,728
            ReLU-994            [-1, 864, 8, 8]               0
res_bottleneck_block-995            [-1, 864, 8, 8]               0
AdaptiveAvgPool2d-996            [-1, 864, 1, 1]               0
          Linear-997                   [-1, 10]           8,650
================================================================
Total params: 26,110,850
Trainable params: 26,110,850
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 555.40
Params size (MB): 99.60
Estimated Total Size (MB): 655.02
----------------------------------------------------------------

学習用のclassを定義して終わります。

class PRNTrainer(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = PyramidNet(block=res_bottleneck_block)
        
    def forward(self, x):
        x = self.model(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch 
        #x, y = x.to(device), y.to(device)
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        return {'loss': loss, 'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)}
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        #x, y = x.to(device), y.to(device)
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        return {'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)}
    
    def test_step(self, batch, batch_nb):
        x, y = batch
        #x, y = x.to(device), y.to(device)
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        y_label = torch.argmax(y_hat, dim=1)
        acc = accuracy()(y_label, y)
        return {'test_loss': loss, 'test_acc': acc}
    
    def training_epoch_end(self, train_step_output):
        y_hat = torch.cat([val['y_hat'] for val in train_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in train_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in train_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy()(preds, y)
        self.log('train_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('train_acc', acc, prog_bar=True, on_epoch=True)
        
        print('---------- Current Epoch {} ----------'.format(self.current_epoch + 1))
        print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loass, acc))
    
    def validation_epoch_end(self, val_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in val_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in val_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in val_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy()(preds, y)
        self.log('val_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', acc, prog_bar=True, on_epoch=True)
        
        print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc))
    
    def test_epoch_end(self, test_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in test_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in test_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in test_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy()(preds, y)
        self.log('test_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('test_acc', acc, prog_bar=True, on_epoch=True)
        
        print('test Loss: {:.4f} test Acc: {:.4f}'.format(epoch_loss, acc))
        
    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 225], gamma=0.1)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}

不正確・わかりずらい部分が多々あり申し訳ありません。
自分の勉強のための活動ですのでご了承ください。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?