Tensorflow2 のチュートリアルの中にあるカスタムレイヤー・モデルの説明にResNetの残差ブロックのモデルを作る例が載っているので,これを参考に, ResNetの改良版(v2) の残差ブロックをkerasのsubclassing API を使って定義してみました.
追記 上記チュートリアルでは残差ブロックはtf.keras.Model
のサブクラスとしてモデルとして定義されていましたが,他の公式ガイドを読むと,それ単体でfit()
などを使わない場合は,tf.keras.layers.Layer
のサブクラスとして作っておけばよいようなので,カスタムモデルとしてではなく,カスタムレイヤーとして残差ブロックを作ります.
ResNet-v2の原論文はこちらです.
[He et al., Identity Mappings in Deep Residual Networks] (https://arxiv.org/pdf/1603.05027.pdf)
v2では,
- ショートカット経路は完全に恒等写像にする(入出力の間でreluを使わない)
- 残差計算のために分岐した後はBN-活性化-convolution-BN-活性化-convolutionの順になる
という改良がされ,特に深い層を重ねた場合に,性能が向上したことが報告されています.
ResNet-v2 残差ブロックのTensorflow2での実装 (Kerasカスタムモデル)
TensorFlowチュートリアル例は, オリジナル(v1)モデルのボトルネック・アーキテクチャに相当します.
v2用にこれを書き直してみます.
class ResnetBlockV2(tf.keras.layers.Layer):
def __init__(self, kernel_size=(3,3), filter_size=16,stride=1, dif_fsize=False):
''' args:
kernel_size: kernel_size. default is (3,3)
filter_size: numbers of output filters.
stride: scalar. If this is not 1, skip connection are replaced to 1x1 convolution.
dif_fsize: True if the numbers of input filter and output filter are different
strideを1以上にする場合は特にこれを指定する必要はない
'''
super(ResnetBlockV2, self).__init__(name='')
if stride==1:
strides=(1,1)
else:
strides = (stride,stride)
self.bn2a = BatchNormalization()
self.conv2a =Conv2D(filter_size, kernel_size, strides=strides, padding='same')
self.bn2b = BatchNormalization()
self.conv2b = Conv2D(filter_size, kernel_size, strides=(1,1), padding='same')
# stride が 1 でない(ダウンサンプリング)か,入出力フィルタ数を変化させる場合(通常この2条件は同時に実施する),
# skip結合を恒等写像から1x1 convolution に切り替える.
self.use_identity_shortcut = (stride==1) and not dif_fsize
if not self.use_identity_shortcut:
self.conv2_sc = tf.keras.layers.Conv2D(filter_size, (1,1), strides=strides, padding='same')
def call(self, input_tensor, training=False):
x = self.bn2a(input_tensor, training=training)
x1 = tf.nn.relu(x) # shortcutがidentityではない場合ここから分岐させる
x = self.conv2a(x1) # こちらは残差ブロック側
x = self.bn2b(x, training=training)
x = tf.nn.relu(x)
x = self.conv2b(x)
if self.use_identity_shortcut:
skip = input_tensor
else:
skip = self.conv2_sc(x1)
x += skip
return x
チュートリアルのコードは,入出力の特徴マップのサイズが同じ場合のみに対応していたので,マップサイズをダウンサンプリングしてフィルタ数を変える場合にも対応できる様,論文を再現するため細かい改変が必要でした.
スキップ結合は,通常恒等写像を使いますが,入出力間でフィルター数または画像サイズが変わるとき(strideが2以上の時)は, 1x1
の畳み込みを行い,strideを変えてダウンサンプリングするようにcall
のところで条件分岐させています.
論文からリンクされているの著者のコードgithubをみながら作りましたが,この場合,スキップ結合もBN-活性化-conv という処理が行われますが,BN-活性化の部分は残差を計算するほうの計算と全く同じなので,入力に対しBN-活性化を一度行なった後,処理を枝分かれさせるという形をとります.
実行例
MNIST用に適当に作って動くことを確認しています.
論文のappendixの注意書きに従い,最初のconvの後と最後の残差ブロックの直後にも活性化を入れておきます.
mnist = tf.keras.datasets.mnist
(x_train, t_train), (x_test, t_test) = mnist.load_data()
x_train = (x_train/255.0).astype(np.float32)
x_test = (x_test/255.0).astype(np.float32)
x_train = x_train.reshape([-1,28,28,1])
x_test = x_test.reshape([-1,28,28,1])
# %%
model = tf.keras.Sequential()
model.add( Conv2D(16,kernel_size=(3,3),strides=(1,1),
padding='same',activation='relu')) # 最初のconvの直後にactivationをいれる.
model.add( ResnetBlockV2(kernel_size=(3,3),filter_size=16))
model.add( ResnetBlockV2(kernel_size=(3,3),filter_size=16))
model.add( ResnetBlockV2(kernel_size=(3,3),filter_size=16))
model.add( ResnetBlockV2(kernel_size=(3,3),filter_size=32,stride=2))#14x14
model.add( ResnetBlockV2(kernel_size=(3,3),filter_size=32))
model.add( ResnetBlockV2(kernel_size=(3,3),filter_size=32))
model.add( ResnetBlockV2(kernel_size=(3,3),filter_size=64, stride=2))#7x7
model.add( ResnetBlockV2(kernel_size=(3,3),filter_size=64))
model.add( ResnetBlockV2(kernel_size=(3,3),filter_size=64))
model.add(ReLU()) # residual blockの最後にactivationを入れる.(v2の場合)
model.add( GlobalAveragePooling2D() )
model.add(Dense(10))
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam',
loss=loss_fn,
metrics=['accuracy'])
# %%
model.fit(x_train, t_train, epochs=5)
# %%
model.evaluate(x_test, t_test, verbose=2)
v2 ボトルネックアーキテクチャモデル
2層のconvを使う代わりに,(1x1),(3,3),(1x1)の畳み込みフィルタを使い,真ん中でフィルタ数を1/4に減らし最後に元のサイズに戻す,ボトルネックアーキテクチャの実装です.
class ResnetBottleneckBlockV2(tf.keras.layers.Layer):
''' Residual unit of resnet v2 model proposed by
He et al., Identity Mappings in Deep Residual Networks
https://arxiv.org/abs/1603.05027
characterized by full pre-activation and purely identity maping in shortcut connections
'''
def __init__(self, kernel_size=(3,3), filter_size=64, stride=1, w_l2reg=0, dif_fsize=False):
''' args:
kernel_size: kernel_size. default is (3,3)
filt: numbers of output filters.
stride: scalar. If this is not 1, skip connection are replaced to 1x1 convolution.
dif_fsize: True if the numbers of input filter and output filter are different
strideを1以上にする場合は特にこれを指定する必要はない
'''
super(ResnetBottleneckBlockV2, self).__init__(name='')
filter_size_bottle = filter_size//4 # ボトルネックのサイズは1/4
if stride==1:
strides=(1,1)
else:
strides = (stride,stride)
# 最初のフィルタはフィルタ数を落とす.1x1畳み込み
self.bn2a = BatchNormalization()
self.conv2a =Conv2D(filter_size_bottle, (1,1), strides=strides,
kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(w_l2reg))
# 次のフィルタはフィルタ数を落とす.3x3などの畳み込み
self.bn2b = BatchNormalization()
self.conv2b = Conv2D(filter_size_bottle, kernel_size, strides=(1,1),
kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(w_l2reg))
# 3番めの畳み込みはフィルタ数を出力サイズに戻す1x1.
self.bn2c = BatchNormalization()
self.conv2c = Conv2D(filter_size, (1,1), strides=(1,1),
kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(w_l2reg))
# stride が 1 でない(ダウンサンプリング)か,入出力フィルタ数を変化させる場合(通常この2条件は同時に実施する),
# skip結合を恒等写像から1x1 convolution に切り替える.
self.use_identity_shortcut = (stride==1) and not dif_fsize
if not self.use_identity_shortcut:
self.conv2_sc = Conv2D(filter_size, (1,1), strides=strides,
kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(w_l2reg))
def call(self, input_tensor, training=False):
x = self.bn2a(input_tensor, training=training)
x1 = tf.nn.relu(x) # shortcutがidentityではない場合ここから分岐させる
x = self.conv2a(x1) # こちらは残差ブロック側
x = self.bn2b(x, training=training)
x = tf.nn.relu(x)
x = self.conv2b(x)
x = self.bn2c(x, training=training)
x = tf.nn.relu(x)
x = self.conv2c(x)
if self.use_identity_shortcut:
skip = input_tensor
else:
skip = self.conv2_sc(x1)
x += skip
return x