概要
すごく良いモデルらしいGoogleのNASNetは既にKerasにも実装されています。
にもかかわらずXceptionのように使われている例をあまり見かけません。
(自分が知らないだけかもしれませんが)
その理由は思うにKerasで用意されているモデルの大きさが両極端だからじゃないでしょうか。
NASNetLarge(大きいサイズ)とNASNetMobile(小さいサイズ)がありますが、中間のサイズ(Xceptionぐらいのサイズ)がありません。
ちなみにKerasには用意されてないですがNASNetにもXceptionぐらいのサイズのモデルも当然あります。
モデルサイズと精度
そこで先人の方が書いたgithubにあるNASNetのKeras実装コードを読み、NASNetを自分なりに写経してみました。実装例としては参考①と参考②(主に①)を参考に写経致しました。
論文の内容を読んだわけではないので目的やら経緯に関する理解は浅いです。
レイヤーの名前とかCIFAR-10での対応とか削除しても動作に支障ない部分はできる限り省略しました。
NASNetの基礎構造
NASNetにはNormalCellとReductionCellという構造があり、これは通常のCNNモデルでいうところのConv2d層と(2x2)pooling層に相当します(詳しくは後述します)。
ReductionCellを一度掛けるごとに画像サイズは1/2になり、同時にfilter数を2倍にしてよいです。
ImageNetのような入力サイズが大きめのモデルを作成する場合、NormalCellを掛ける前にReductionCellによって画像サイズを小さくしておきます。
ImageNetを入力とするNASNetLargeのモデルではInput=>Conv2d (3,3), stride=2=>ReductionCell×2=>NormalCell×6=>ReductionCell×1=>NormalCell×6=>ReductionCell×1=>NormalCell×6=>GlobalAveragePooling2D=>Denseという順に掛けていきます。
CIFAR-10のように入力サイズが小さい場合のモデルはサイズを縮小するConv2d (3,3), stride=2=>ReductionCell×2の部分を省略する必要があります。
inputs = Input(input_shape)
# Conv2d (3,3), stride=2
x = Stem(inputs, stem_filters)
# ReductionCell 1time
cur, prev = x, None
prev, cur = cur, prev
cur = ReductionCell(cur, prev, filters//2)
# (ReductionCell 1time+NormalCell 6time)*3
for j in range(3):
# ReductionCell 1time
filters *= 2
prev, cur = cur, prev
if j == 0:
cur = ReductionCell(cur, prev, filters//2)
else:
cur = ReductionCell(cur, prev, filters)
# NormalCell 6time
for i in range(num_cell_repeats):
prev, cur = cur, prev
cur = NormalCell(cur, prev, filters)
x = Activation('relu')(cur)
x = GlobalAveragePooling2D()(x)
outputs = Dense(num_classes, activation='softmax')(x)
NormalCellとReductionCellの構造
さてNormalCellとReductionCellの実態は畳み込みの大きさを変えたいろいろな畳み込みやpoolingを掛けたのを足し合わせようというアイデアですが、現在の層(cur)だけではなく、もう一つ前の層(prev)から畳み込みやpoolingを掛けた結果も足し合わせます。
NormalCellではSeparableConv2DやAveragePooling2Dを使って足し合わせConcatenate()でチャンネル方向に結合します。ここでConcatenate()で結合後のチャンネル数はfilters6になっています。
また、この時strides=1のAveragePooling2Dは画像サイズを小さくするわけではなく、単なる平滑化フィルタの畳み込み演算として処理されます(よね?)。
ReductionCellではSeparableConv2D、AveragePooling2DやMaxPooling2Dのstrides=2の設定を用いて画像サイズを1/2にして足し合わせConcatenate()でチャンネル方向に結合します。ここでConcatenate()で結合後のチャンネル数はfilters4になっています。
下の図だと結合数が5や3に見えるかもしれませんが現在の層(cur)以外のidentityの結合は別途Concatenateに飛ぶようです。従って実際の結合数は6や4になります。
また、Separable関数において実はSeparableConv2Dは二回繰り返して行われます。
def NormalCell(prev, cur, filters):
cur = SqueezeChannels(cur, filters)
prev = Fit(prev, filters, cur)
add_0 = Add()([Separable(filters, 5, strides=1)(cur),
Separable(filters, 3, strides=1)(prev)])
add_1 = Add()([Separable(filters, 5, strides=1)(prev),
Separable(filters, 3, strides=1)(prev)])
add_2 = Add()([AveragePooling2D(pool_size=3, strides=1, padding='same')(cur), prev])
add_3 = Add()([AveragePooling2D(pool_size=3, strides=1, padding='same')(prev),
AveragePooling2D(pool_size=3, strides=1, padding='same')(prev)])
add_4 = Add()([Separable(filters, 3, strides=1)(cur), cur])
return Concatenate()([prev, add_0, add_1, add_2, add_3, add_4])
def ReductionCell(prev, cur, filters):
prev = Fit(prev, filters, cur)
cur = SqueezeChannels(cur, filters)
add_0 = Add()([Separable(filters, 5, strides=2)(cur),
Separable(filters, 7, strides=2)(prev)])
add_1 = Add()([MaxPooling2D(3, strides=2, padding='same')(cur),
Separable(filters, 7, strides=2)(prev)])
add_2 = Add()([AveragePooling2D(3, strides=2, padding='same')(cur),
Separable(filters, 5, strides=2)(prev)])
add_3 = Add()([AveragePooling2D(3, strides=1, padding='same')(add_0), add_1])
add_4 = Add()([Separable(filters, 3, strides=1)(add_0),
MaxPooling2D(3, strides=2, padding='same')(cur)])
return Concatenate()([add_1, add_2, add_3, add_4])
(※NormalCellのadd_3の足し合わせ、同じ(3x3)のAveragePoolingで意味なくないですか?)
SqueezeChannelsとFit
NormalCellとReductionCellのコード中にあるSqueezeChannelsとFitに関して簡単に触れておきます。
SqueezeChannelsは要するに(1x1)畳み込みでNormalCellのConcatenate()の結合によってfilters*6の大きさになったチャンネル数を元のfiltersのチャンネル数まで圧縮する処理になります。
Fitは一つ前の層(prev)に対して処理され、現在の層(cur)と一つ前の層(prev)がReductionCellのせいで縦横のサイズが異なる場合、一つ前の層(prev)をPoolingを使って縦横のサイズを1/2にして現在の層(cur)と同じ大きさにしてから(1x1)畳み込みを掛けチャンネル数を圧縮します。なお、処理中にほかにZeroPadding2DやCropping2Dが使われますが(画像を左上に1ずらす?)何のメリットがあるかは分かりません。一つ前の層(prev)のチャンネル数が半分だから(1x1)畳み込みに計算余力が生じているのではと想像します。
一つ前の層(prev)と現在の層(cur)と同じ大きさの場合は一つ前の層(prev)にもSqueezeChannelsが掛けられます。
model.summary()の確認
model.summary()でモデルのパラメータ数を確認してみるとNASNetLargeとNASNetMobileのパラメータ数に完全に一致しています。モデルの畳み込みは取り合えず写経した内容で間違いはないのかなと思います。
# stem_filters=96, num_cell_repeats=6, penultimate_filters=4032 の場合
==============================================================================================
Total params: 88,949,818
Trainable params: 88,753,150
Non-trainable params: 196,668
______________________________________________________________________________________________
...
# stem_filters=32, num_cell_repeats=4, penultimate_filters=1056 の場合
==============================================================================================
Total params: 5,326,716
Trainable params: 5,289,978
Non-trainable params: 36,738
______________________________________________________________________________________________
XceptionぐらいのサイズのNASNet
論文見るとNASNet-A(7 @ 1920)のモデルがXceptionと同じくらいのパラメータ数です。
NormalCellの繰り返し数:num_cell_repeats=7、 最終的な出力チャンネル数:penultimate_filters=1920でNASNetのモデルを作るとよいのかなと思います。試してませんが。
全コード
# original:https://github.com/johannesu/NASNet-keras/blob/master/nasnet.py
from keras.layers.convolutional import Conv2D, MaxPooling2D, AveragePooling2D, SeparableConv2D, ZeroPadding2D, Cropping2D
from keras.layers import Input, Concatenate, Add, BatchNormalization, Activation, GlobalAveragePooling2D, Dense
from keras.models import Model
class Separable:
def __init__(self, filters, kernel_size, strides=1):
self.filters = filters
self.kernel_size = kernel_size
self.strides = strides
def __call__(self, x):
x = Activation('relu')(x)
x = SeparableConv2D(self.filters,
kernel_size=self.kernel_size,
kernel_initializer='he_normal',
strides=self.strides,
padding='same',
use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(self.filters,
kernel_size=self.kernel_size,
kernel_initializer='he_normal',
strides=1,
padding='same',
use_bias=False)(x)
x = BatchNormalization()(x)
return x
def SqueezeChannels(x, filters):
x = Activation('relu')(x)
x = Conv2D(filters, 1, kernel_initializer='he_normal', use_bias=False)(x)
x = BatchNormalization()(x)
return x
def Fit(x, filters, target):
if x is None:
return target
if int(x.shape[2]) != int(target.shape[2]):
x = Activation('relu')(x)
p1 = AveragePooling2D(pool_size=1, strides=2, padding='valid')(x)
p1 = Conv2D(filters // 2, 1, kernel_initializer='he_normal', padding='same', use_bias=False)(p1)
p2 = ZeroPadding2D(padding=((0, 1), (0, 1)))(x)
p2 = Cropping2D(cropping=((1, 0), (1, 0)))(p2)
p2 = AveragePooling2D(pool_size=1, strides=2, padding='valid')(p2)
p2 = Conv2D(filters // 2, 1, kernel_initializer='he_normal', padding='same', use_bias=False)(p2)
x = Concatenate()([p1, p2])
x = BatchNormalization()(x)
return x
else:
return SqueezeChannels(x, filters)
def NormalCell(prev, cur, filters):
cur = SqueezeChannels(cur, filters)
prev = Fit(prev, filters, cur)
add_0 = Add()([Separable(filters, 5, strides=1)(cur),
Separable(filters, 3, strides=1)(prev)])
add_1 = Add()([Separable(filters, 5, strides=1)(prev),
Separable(filters, 3, strides=1)(prev)])
add_2 = Add()([AveragePooling2D(pool_size=3, strides=1, padding='same')(cur), prev])
add_3 = Add()([AveragePooling2D(pool_size=3, strides=1, padding='same')(prev),
AveragePooling2D(pool_size=3, strides=1, padding='same')(prev)])
add_4 = Add()([Separable(filters, 3, strides=1)(cur), cur])
return Concatenate()([prev, add_0, add_1, add_2, add_3, add_4])
def ReductionCell(prev, cur, filters):
prev = Fit(prev, filters, cur)
cur = SqueezeChannels(cur, filters)
add_0 = Add()([Separable(filters, 5, strides=2)(cur),
Separable(filters, 7, strides=2)(prev)])
add_1 = Add()([MaxPooling2D(3, strides=2, padding='same')(cur),
Separable(filters, 7, strides=2)(prev)])
add_2 = Add()([AveragePooling2D(3, strides=2, padding='same')(cur),
Separable(filters, 5, strides=2)(prev)])
add_3 = Add()([AveragePooling2D(3, strides=1, padding='same')(add_0), add_1])
add_4 = Add()([Separable(filters, 3, strides=1)(add_0),
MaxPooling2D(3, strides=2, padding='same')(cur)])
return Concatenate()([add_1, add_2, add_3, add_4])
def Stem(x, stem_filters):
x = Conv2D(stem_filters, 3, strides=2,
kernel_initializer='he_normal', padding='valid', use_bias=False)(x)
x = BatchNormalization()(x)
return x
def NASNetA(input_shape=None,
stem_filters=96,
num_cell_repeats=6,
penultimate_filters=4032,
num_classes=1000,
num_reduction_cells=3):
filters = int(penultimate_filters / ((2 ** num_reduction_cells) * 6))
inputs = Input(input_shape)
# Conv2d (3,3), stride=2
x = Stem(inputs, stem_filters)
# ReductionCell 1time
cur, prev = x, None
prev, cur = cur, prev
cur = ReductionCell(cur, prev, filters//2)
# (ReductionCell 1time+NormalCell 6time)*3
for j in range(3):
# ReductionCell 1time
filters *= 2
prev, cur = cur, prev
if j == 0:
cur = ReductionCell(cur, prev, filters//2)
else:
cur = ReductionCell(cur, prev, filters)
# NormalCell 6time
for i in range(num_cell_repeats):
prev, cur = cur, prev
cur = NormalCell(cur, prev, filters)
x = Activation('relu')(cur)
x = GlobalAveragePooling2D()(x)
outputs = Dense(num_classes, activation='softmax')(x)
return Model(inputs, outputs)
# NASNetLarge
model = NASNetA(input_shape=(331, 331, 3), stem_filters=96, num_cell_repeats=6, penultimate_filters=4032)
model.summary()
# NASNetMobile
# model = NASNetA(input_shape=(331, 331, 3), stem_filters=32, num_cell_repeats=4, penultimate_filters=1056)
# model.summary()