論文の勉強をメモ書きレベルですがのせていきます。あくまでも自分の勉強目的です。
構造部分に注目し、その他の部分は書いていません。ご了承ください。
本当にいい加減・不正確な部分が多数あると思いますのでご理解ください。
今回は、以下の論文のPyramid Netの実装を行います。
タイトル:Deep Pyramidal Residual Networks
ほとんどResNetと変わりません。
PyramidNets
ResNet
ResNetなどでは、特徴マップのサイズが減少するタイミングでチャネル数を増加させていました。
ResNetでは、$n$番目のグループの$k$番目のresidualユニットにおける特徴マップのチャネル数$D_k$は、
D_k=\left\{\begin{array}{ll}
16&n(k)=1\\
16・2^{n(k)-2}&n(k)\geq1
\end{array}\right.
となります。ここで、$n(k)\in{1,2,3,4}$ は、$k$番目のresidualユニットが属するグループのインデックスを表しています。
同じグループに属しているユニットの特徴マップは同じサイズであり、$n$番目のグループは$N_n$個のユニットを含むものとします。
1番目のグループは、RGB画像を複数のチャネルに変換する1つの畳み込み層からなります。
n番目のグループでは、$N_n$個のユニットを通して特徴マップのサイズは半分に、チャネル数は2倍となります。
additive PyramidNet
D_k=\left\{\begin{array}{ll}
16&k=1\\
\lfloor D_{k-1}+\alpha/N\rfloor &2\leq k\leq N+1
\end{array}\right.
ここで、$N$はresidualユニットの総数であり、
$$
N=\sum_{n=2}^4N_n
$$
となります。
チャネル数は$\alpha / N$ずつ増加、各グループのユニット数が同じであれば各グループの最後のユニットでのチャネル数は$16+(n-1)\alpha/3$と計算できます。$\alpha$はどの程度チャネル数を増加させるかを表す因子です。
また、$\lfloor \rfloor$は床関数と呼ばれるもので、その値を超えない整数を表します。
multiplicative PyramidNet
D_k=\left\{\begin{array}{ll}
16&k=1\\
\lfloor D_{k-1}・\alpha^{1/N}\rfloor &2\leq k\leq N+1
\end{array}\right.
additiveのものと比べ、input側ではチャネル数はゆっくり変化するが、outputに近くなるほど急に変化します。
すべてのユニットでチャネル数が増加するので、shortcut connectionではzero-paddingでチャネル数を合わせます。
ここでは、bottleneckで、$N=272(N_n=(272-2)/9=30)$で$\alpha=200$のものを実装します。
inputは1層目で16チャネルに変換され、ユニットを通るごとに$200/90$ずつ増加していきます。
各グループで(66~67)×4チャネルずつ増加し、最終的には(16+200)×4=864チャネルとなります。
さらに、baseのもので$N=110(N_n=(110-2)/6=18)$で$\alpha=270$のものを実装します。
学習
最適化手法はSGD(Nesterov)で学習率は0.1、エポック数が150と225のときに0.1を掛けて小さくしていきます。
weight decayは0.0001、 momentumを0.9に設定し、バッチサイズは128とします。
実装
keras
必要なライブラリのインポートを行います。
import tensorflow.keras as keras
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Conv2D, Activation, MaxPooling2D, Flatten, Dense, Dropout, GlobalAveragePooling2D, BatchNormalization, Add
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.callbacks import LearningRateScheduler
from keras.datasets import cifar10
import numpy as np
import cv2
bottleneckブロックの実装をします。
class res_bottleneck_block(Model):
"""wide residual block"""
def __init__(self, block_num, layer_num, Nn, N, alpha):
super(res_bottleneck_block, self).__init__(name='block'+'_'+str(block_num)+'_'+str(layer_num))
block_name = '_'+str(block_num)+'_'+str(layer_num)
# shortcutとstrideの設定
self._is_change=False
if (layer_num == 0):
# 最初のresblockは(W、 H)は変更しないのでstrideは1にする
if (block_num==1):
stride = 1
else:
stride = 2
self._is_change=True
else:
stride = 1
self.maxpooling=MaxPooling2D(pool_size=(stride, stride), strides=(stride, stride), padding='valid')
bneck_channels = np.ceil(16+(Nn*(block_num-1)+layer_num+1)*alpha/N)
out_channels = bneck_channels*4
self.bn0 = BatchNormalization(name='bn0'+block_name)
# 1層目 1×1 畳み込み処理は行わず(線形変換)、チャネル数をbneck_channelsにします
self.conv1 = Conv2D(bneck_channels, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv1'+block_name)
self.bn1 = BatchNormalization(name='bn1'+block_name)
self.act1 = Activation('relu', name='act1'+block_name)
# 2層目 3×3 畳み込み処理を行います
self.conv2 = Conv2D(bneck_channels, kernel_size=3, strides=stride, padding='same', use_bias=False, name='conv2'+block_name)
self.bn2 = BatchNormalization(name='bn2'+block_name)
self.act2 = Activation('relu', name='act2'+block_name)
# 3層目 1×1 畳み込み処理は行わず(線形変換)、チャネル数をout_channelsにします
self.conv3 = Conv2D(out_channels, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv3'+block_name)
self.bn3 = BatchNormalization(name='bn3'+block_name)
self.add = Add(name='add'+block_name)
def call(self, x):
out = self.bn0(x)
out = self.conv1(out)
out = self.bn1(out)
out = self.act1(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.act2(out)
out = self.conv3(out)
out = self.bn3(out)
if K.int_shape(x) != K.int_shape(out):
if self._is_change:
x = self.maxpooling(x)
if K.int_shape(x)[3] != K.int_shape(out)[3]:
x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, K.int_shape(out)[3] - K.int_shape(x)[3]]])
shortcut = x
else:
shortcut = x
out = self.add([out, shortcut])
return out
base blockの実装もします。
class res_base_block(Model):
"""wide residual block"""
def __init__(self, block_num, layer_num, Nn, N, alpha):
super(res_base_block, self).__init__(name='block'+'_'+str(block_num)+'_'+str(layer_num))
block_name = '_'+str(block_num)+'_'+str(layer_num)
# shortcutとstrideの設定
self._is_change=False
if (layer_num == 0):
# 最初のresblockは(W、 H)は変更しないのでstrideは1にする
if (block_num==1):
stride = 1
else:
stride = 2
self._is_change=True
else:
stride = 1
self.maxpooling=MaxPooling2D(pool_size=(stride, stride), strides=(stride, stride), padding='valid')
out_channels = np.ceil(16+(Nn*(block_num-1)+layer_num+1)*alpha/N)
self.bn0 = BatchNormalization(name='bn0'+block_name)
# 1層目 3×3 畳み込み処理は行わず(線形変換)、チャネル数をbneck_channelsにします
self.conv1 = Conv2D(out_channels, kernel_size=3, strides=stride, padding='same', use_bias=False, name='conv1'+block_name)
self.bn1 = BatchNormalization(name='bn1'+block_name)
self.act1 = Activation('relu', name='act1'+block_name)
# 2層目 3×3 畳み込み処理を行います
self.conv2 = Conv2D(out_channels, kernel_size=3, strides=1, padding='same', use_bias=False, name='conv2'+block_name)
self.bn2 = BatchNormalization(name='bn2'+block_name)
self.add = Add(name='add'+block_name)
def call(self, x):
out = self.bn0(x)
out = self.conv1(out)
out = self.bn1(out)
out = self.act1(out)
out = self.conv2(out)
out = self.bn2(out)
if K.int_shape(x) != K.int_shape(out):
if self._is_change:
x = self.maxpooling(x)
if K.int_shape(x)[3] != K.int_shape(out)[3]:
x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, K.int_shape(out)[3] - K.int_shape(x)[3]]])
shortcut = x
else:
shortcut = x
out = self.add([out, shortcut])
return out
PyramidNetの実装をします。
引数でbottlenackとbaseを指定できるようにします。
ブロックの数の計算が変わってきます。
class PyramidNet(Model):
def __init__(self,alpha = 200, N = 272, block=res_base_block):
super(PyramidNet, self).__init__()
self._layers = []
if block == res_bottleneck_block:
Nn = (N-2)//9
N_all = N//3
else:
Nn = (N-2)//6
N_all = N//2
# 入力層
self._layers += [
Conv2D(filters = 16, kernel_size = (3,3), strides = 1, padding = 'same', name='conv_input'),
BatchNormalization(name='bn_input'),
Activation('relu', name='act_input')
]
# Residualブロック
self._layers += [block(block_num=1, layer_num=k, Nn=Nn, N=N_all, alpha=alpha) for k in range(Nn)]
self._layers += [block(block_num=2, layer_num=k, Nn=Nn, N=N_all, alpha=alpha) for k in range(Nn)]
self._layers += [block(block_num=3, layer_num=k, Nn=Nn, N=N_all, alpha=alpha) for k in range(N_all-2*Nn)]
# 出力層
self._layers += [
GlobalAveragePooling2D(name='pool_output'),
Dense(10, activation='softmax', name='output')
]
def call(self, x):
for layer in self._layers:
#print(layer)
x = layer(x)
return x
モデルの確認をします。
model = PyramidNet(alpha = 270, N = 110, block=res_base_block)
#model = PyramidNet(block=res_bottleneck_block)
model.build((None, 32, 32, 3)) # build with input shape.
dummy_input = Input(shape=(32, 32, 3)) # declare without batch demension.
model_summary = Model(inputs=[dummy_input], outputs=model.call(dummy_input), name="pretrained")
model_summary.summary()
ますは、bottlenackバージョンです。
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_9 (InputLayer) [(None, 32, 32, 3)] 0
_________________________________________________________________
conv_input (Conv2D) (None, 32, 32, 16) 448
_________________________________________________________________
bn_input (BatchNormalization (None, 32, 32, 16) 64
_________________________________________________________________
act_input (Activation) (None, 32, 32, 16) 0
_________________________________________________________________
block_1_0 (res_bottleneck_bl (None, 32, 32, 76) 5517
_________________________________________________________________
block_1_1 (res_bottleneck_bl (None, 32, 32, 84) 8137
_________________________________________________________________
中略
_________________________________________________________________
block_3_27 (res_bottleneck_b (None, 8, 8, 848) 770800
_________________________________________________________________
block_3_28 (res_bottleneck_b (None, 8, 8, 856) 785348
_________________________________________________________________
block_3_29 (res_bottleneck_b (None, 8, 8, 864) 800032
_________________________________________________________________
pool_output (GlobalAveragePo (None, 864) 0
_________________________________________________________________
output (Dense) (None, 10) 8650
=================================================================
Total params: 26,574,858
Trainable params: 26,364,922
Non-trainable params: 209,936
_________________________________________________________________
次に、baseバージョンです。
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_8 (InputLayer) [(None, 32, 32, 3)] 0
_________________________________________________________________
conv_input (Conv2D) (None, 32, 32, 16) 448
_________________________________________________________________
bn_input (BatchNormalization (None, 32, 32, 16) 64
_________________________________________________________________
act_input (Activation) (None, 32, 32, 16) 0
_________________________________________________________________
block_1_0 (res_base_block) (None, 32, 32, 21) 7225
_________________________________________________________________
block_1_1 (res_base_block) (None, 32, 32, 26) 11290
_________________________________________________________________
中略
_________________________________________________________________
block_3_16 (res_base_block) (None, 8, 8, 277) 1371961
_________________________________________________________________
block_3_17 (res_base_block) (None, 8, 8, 282) 1422106
_________________________________________________________________
block_3_18 (res_base_block) (None, 8, 8, 286) 1465448
_________________________________________________________________
pool_output (GlobalAveragePo (None, 286) 0
_________________________________________________________________
output (Dense) (None, 10) 2870
=================================================================
Total params: 29,198,857
Trainable params: 29,148,575
Non-trainable params: 50,282
_________________________________________________________________
どちらも論文のもの(PyramidNet($\alpha=270$ ),PyramidNet(bottleneck $\alpha=200$ ))と比べパラメータ数が大きく差はありません。
学習用の設定までを実装して終わります。
# 学習率を返す関数を用意する
def lr_schedul(epoch):
x = 0.1
if epoch >= 150:
x = 0.1*0.1
if epoch >= 225:
x = 0.1*(0.1**2)
return x
lr_decay = LearningRateScheduler(
lr_schedul,
verbose=1,
)
sgd = SGD(lr=0.1, momentum=0.9, decay=1e-4, nesterov=True)
model.compile(loss=['categorical_crossentropy'], optimizer=sgd, metrics=['accuracy'])
pytorch
次はpytorchです。
ほとんどkerasと変わりません。
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
import pytorch_lightning as pl
from torchmetrics import Accuracy as accuracy
class res_bottleneck_block(nn.Module):
"""residual block"""
def __init__(self, block_num, layer_num, Nn, N, alpha):
super(res_bottleneck_block, self).__init__()
# 1番目のブロック以外はチャンネル数がinputとoutputで変わる(output=4×input)
input_channels = int(np.floor(16+(Nn*(block_num-1)+layer_num)*alpha/N)*4)
bneck_channels = int(np.floor(16+(Nn*(block_num-1)+layer_num+1)*alpha/N))
out_channels = int(bneck_channels*4)
# shortcutとstrideの設定
self._is_change = False
if (layer_num == 0):
input_channels=int(np.floor(16+(Nn*(block_num-2)+Nn)*alpha/N)*4)
# 最初のresblockは(W、 H)は変更しないのでstrideは1にする
if (block_num==1):
stride = 1
input_channels=16
else:
self._is_change = True
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
stride = 2
else:
stride = 1
self.pad_channels = out_channels - input_channels
self.bn0 = nn.BatchNorm2d(input_channels)
# 1層目 1×1 畳み込み処理は行わず(線形変換)、チャネル数をbneck_channelsにします
self.conv1 = nn.Conv2d(input_channels, bneck_channels, kernel_size=1)
self.bn1 = nn.BatchNorm2d(bneck_channels)
# 2層目 3×3 畳み込み処理を行います
self.conv2 = nn.Conv2d(bneck_channels, bneck_channels, kernel_size=3, stride=stride, padding=1)
self.bn2 = nn.BatchNorm2d(bneck_channels)
# 3層目 1×1 畳み込み処理は行わず(線形変換)、チャネル数をout_channelsにします
self.conv3 = nn.Conv2d(bneck_channels, out_channels, kernel_size=1)
self.bn3 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
shortcut = x
out = self.bn0(x)
out = self.conv1(out)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
# Projection shortcutの場合
if self._is_change:
shortcut = self.maxpool(shortcut)
p3d = (0, 0, 0, 0, 0, self.pad_channels) # pad by (0, 1), (2, 1), and (3, 3)
shortcut = torch.nn.functional.pad(shortcut, p3d, "constant", 0)
out += shortcut
out = self.relu(out)
return out
class res_base_block(nn.Module):
"""residual block"""
def __init__(self, block_num, layer_num, Nn, N, alpha):
super(res_base_block, self).__init__()
# 1番目のブロック以外はチャンネル数がinputとoutputで変わる(output=4×input)
input_channels = int(np.ceil(16+(Nn*(block_num-1)+layer_num)*alpha/N))
out_channels = int(np.ceil(16+(Nn*(block_num-1)+layer_num+1)*alpha/N))
# shortcutとstrideの設定
self._is_change = False
if (layer_num == 0):
input_channels=int(np.ceil(16+(Nn*(block_num-2)+Nn)*alpha/N))
# 最初のresblockは(W、 H)は変更しないのでstrideは1にする
if (block_num==1):
stride = 1
input_channels=16
else:
self._is_change = True
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
stride = 2
else:
stride = 1
self.pad_channels = out_channels - input_channels
print(block_num,input_channels,out_channels)
self.bn0 = nn.BatchNorm2d(input_channels)
# 1層目 1×1 畳み込み処理は行わず(線形変換)、チャネル数をbneck_channelsにします
self.conv1 = nn.Conv2d(input_channels, out_channels, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
# 2層目 3×3 畳み込み処理を行います
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
shortcut = x
out = self.bn0(x)
out = self.conv1(out)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
# Projection shortcutの場合
if self._is_change:
shortcut = self.maxpool(shortcut)
p3d = (0, 0, 0, 0, 0, self.pad_channels) # pad by (0, 1), (2, 1), and (3, 3)
shortcut = torch.nn.functional.pad(shortcut, p3d, "constant", 0)
out += shortcut
out = self.relu(out)
return out
class PyramidNet(nn.Module):
def __init__(self,alpha = 200, N = 272, num_classes=10, block=res_base_block):
super(PyramidNet, self).__init__()
if block == res_bottleneck_block:
Nn = (N-2)//9
N_all = N//3
in_features = int((16 + alpha)*4)
else:
Nn = (N-2)//6
N_all = N//2
in_features = 16 + alpha
conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
bn1 = nn.BatchNorm2d(16)
relu1 = nn.ReLU(inplace=True)
self.conv1 = nn.Sequential(*[conv1, bn1, relu1])
self.conv2_x = nn.Sequential(*[block(block_num=1, layer_num=k, Nn=Nn, N=N_all, alpha=alpha) for k in range(Nn)])
self.conv3_x = nn.Sequential(*[block(block_num=2, layer_num=k, Nn=Nn, N=N_all, alpha=alpha) for k in range(Nn)])
self.conv4_x = nn.Sequential(*[block(block_num=3, layer_num=k, Nn=Nn, N=N_all, alpha=alpha) for k in range(N_all-2*Nn)])
pool = nn.AdaptiveAvgPool2d((1,1))
self.fc = nn.Sequential(*[pool])
#int(np.floor(16+(Nn*(block_num-1)+layer_num)*alpha/N))
self.linear = nn.Linear(in_features=in_features, out_features=num_classes)
def forward(self, x):
out = self.conv1(x)
out = self.conv2_x(out)
out = self.conv3_x(out)
out = self.conv4_x(out)
out = self.fc(out)
out = out.view(out.shape[0], -1)
out = self.linear(out)
return out
まず、baseモデルの詳細を見てみます。
summary(PyramidNet(alpha = 270, N = 110, block=res_base_block), (3,32,32))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 16, 32, 32] 448
BatchNorm2d-2 [-1, 16, 32, 32] 32
ReLU-3 [-1, 16, 32, 32] 0
中略
res_base_block-445 [-1, 286, 8, 8] 0
AdaptiveAvgPool2d-446 [-1, 286, 1, 1] 0
Linear-447 [-1, 10] 2,870
================================================================
Total params: 29,165,505
Trainable params: 29,165,505
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 132.18
Params size (MB): 111.26
Estimated Total Size (MB): 243.45
----------------------------------------------------------------
次に、bottleneckモデルの詳細を見てみます。
summary(PyramidNet(block=res_bottleneck_block), (3,32,32))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 16, 32, 32] 448
BatchNorm2d-2 [-1, 16, 32, 32] 32
ReLU-3 [-1, 16, 32, 32] 0
中略
Conv2d-992 [-1, 864, 8, 8] 187,488
BatchNorm2d-993 [-1, 864, 8, 8] 1,728
ReLU-994 [-1, 864, 8, 8] 0
res_bottleneck_block-995 [-1, 864, 8, 8] 0
AdaptiveAvgPool2d-996 [-1, 864, 1, 1] 0
Linear-997 [-1, 10] 8,650
================================================================
Total params: 26,110,850
Trainable params: 26,110,850
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 555.40
Params size (MB): 99.60
Estimated Total Size (MB): 655.02
----------------------------------------------------------------
学習用のclassを定義して終わります。
class PRNTrainer(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = PyramidNet(block=res_bottleneck_block)
def forward(self, x):
x = self.model(x)
return x
def training_step(self, batch, batch_idx):
x, y = batch
#x, y = x.to(device), y.to(device)
y_hat = self.forward(x)
loss = nn.CrossEntropyLoss()(y_hat, y)
return {'loss': loss, 'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)}
def validation_step(self, batch, batch_idx):
x, y = batch
#x, y = x.to(device), y.to(device)
y_hat = self.forward(x)
loss = nn.CrossEntropyLoss()(y_hat, y)
return {'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)}
def test_step(self, batch, batch_nb):
x, y = batch
#x, y = x.to(device), y.to(device)
y_hat = self.forward(x)
loss = nn.CrossEntropyLoss()(y_hat, y)
y_label = torch.argmax(y_hat, dim=1)
acc = accuracy()(y_label, y)
return {'test_loss': loss, 'test_acc': acc}
def training_epoch_end(self, train_step_output):
y_hat = torch.cat([val['y_hat'] for val in train_step_outputs], dim=0)
y = torch.cat([val['y'] for val in train_step_outputs], dim=0)
epoch_loss = sum([val['batch_loss'] for val in train_step_outputs]) / y_hat.size(0)
preds = torch.argmax(y_hat, dim=1)
acc = accuracy()(preds, y)
self.log('train_loss', epoch_loss, prog_bar=True, on_epoch=True)
self.log('train_acc', acc, prog_bar=True, on_epoch=True)
print('---------- Current Epoch {} ----------'.format(self.current_epoch + 1))
print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loass, acc))
def validation_epoch_end(self, val_step_outputs):
y_hat = torch.cat([val['y_hat'] for val in val_step_outputs], dim=0)
y = torch.cat([val['y'] for val in val_step_outputs], dim=0)
epoch_loss = sum([val['batch_loss'] for val in val_step_outputs]) / y_hat.size(0)
preds = torch.argmax(y_hat, dim=1)
acc = accuracy()(preds, y)
self.log('val_loss', epoch_loss, prog_bar=True, on_epoch=True)
self.log('val_acc', acc, prog_bar=True, on_epoch=True)
print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc))
def test_epoch_end(self, test_step_outputs):
y_hat = torch.cat([val['y_hat'] for val in test_step_outputs], dim=0)
y = torch.cat([val['y'] for val in test_step_outputs], dim=0)
epoch_loss = sum([val['batch_loss'] for val in test_step_outputs]) / y_hat.size(0)
preds = torch.argmax(y_hat, dim=1)
acc = accuracy()(preds, y)
self.log('test_loss', epoch_loss, prog_bar=True, on_epoch=True)
self.log('test_acc', acc, prog_bar=True, on_epoch=True)
print('test Loss: {:.4f} test Acc: {:.4f}'.format(epoch_loss, acc))
def configure_optimizers(self):
optimizer = optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 225], gamma=0.1)
return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}
不正確・わかりずらい部分が多々あり申し訳ありません。
自分の勉強のための活動ですのでご了承ください。