16
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

NASNetのKerasコードを写経する

Posted at

概要

すごく良いモデルらしいGoogleのNASNetは既にKerasにも実装されています。
にもかかわらずXceptionのように使われている例をあまり見かけません。
(自分が知らないだけかもしれませんが)
その理由は思うにKerasで用意されているモデルの大きさが両極端だからじゃないでしょうか。
NASNetLarge(大きいサイズ)とNASNetMobile(小さいサイズ)がありますが、中間のサイズ(Xceptionぐらいのサイズ)がありません。
ちなみにKerasには用意されてないですがNASNetにもXceptionぐらいのサイズのモデルも当然あります。
nasnet0.png
モデルサイズと精度
nasnet1.png

そこで先人の方が書いたgithubにあるNASNetのKeras実装コードを読み、NASNetを自分なりに写経してみました。実装例としては参考①参考②(主に①)を参考に写経致しました。
論文の内容を読んだわけではないので目的やら経緯に関する理解は浅いです。
レイヤーの名前とかCIFAR-10での対応とか削除しても動作に支障ない部分はできる限り省略しました。

NASNetの基礎構造

NASNetにはNormalCellReductionCellという構造があり、これは通常のCNNモデルでいうところのConv2d層と(2x2)pooling層に相当します(詳しくは後述します)。
ReductionCellを一度掛けるごとに画像サイズは1/2になり、同時にfilter数を2倍にしてよいです。
ImageNetのような入力サイズが大きめのモデルを作成する場合、NormalCellを掛ける前にReductionCellによって画像サイズを小さくしておきます。
ImageNetを入力とするNASNetLargeのモデルではInput=>Conv2d (3,3), stride=2=>ReductionCell×2=>NormalCell×6=>ReductionCell×1=>NormalCell×6=>ReductionCell×1=>NormalCell×6=>GlobalAveragePooling2D=>Denseという順に掛けていきます。
CIFAR-10のように入力サイズが小さい場合のモデルはサイズを縮小するConv2d (3,3), stride=2=>ReductionCell×2の部分を省略する必要があります。
nasnet3.png

nasnet.py
    inputs = Input(input_shape)
    # Conv2d (3,3), stride=2
    x = Stem(inputs, stem_filters)

    # ReductionCell 1time
    cur, prev = x, None
    prev, cur = cur, prev
    cur = ReductionCell(cur, prev, filters//2)

    # (ReductionCell 1time+NormalCell 6time)*3
    for j in range(3):
        # ReductionCell 1time
        filters *= 2 
        prev, cur = cur, prev
        if j == 0:
            cur = ReductionCell(cur, prev, filters//2)
        else:
            cur = ReductionCell(cur, prev, filters)

        # NormalCell 6time
        for i in range(num_cell_repeats):
            prev, cur = cur, prev
            cur = NormalCell(cur, prev, filters)

    x = Activation('relu')(cur)
    x = GlobalAveragePooling2D()(x)
    outputs = Dense(num_classes, activation='softmax')(x)

NormalCellとReductionCellの構造

さてNormalCellReductionCellの実態は畳み込みの大きさを変えたいろいろな畳み込みやpoolingを掛けたのを足し合わせようというアイデアですが、現在の層(cur)だけではなく、もう一つ前の層(prev)から畳み込みやpoolingを掛けた結果も足し合わせます。
NormalCellではSeparableConv2DやAveragePooling2Dを使って足し合わせConcatenate()でチャンネル方向に結合します。ここでConcatenate()で結合後のチャンネル数はfilters6になっています。
また、この時strides=1のAveragePooling2Dは画像サイズを小さくするわけではなく、単なる平滑化フィルタの畳み込み演算として処理されます(よね?)。
ReductionCellではSeparableConv2D、AveragePooling2DやMaxPooling2Dのstrides=2の設定を用いて画像サイズを1/2にして足し合わせConcatenate()でチャンネル方向に結合します。ここでConcatenate()で結合後のチャンネル数はfilters
4になっています。
下の図だと結合数が5や3に見えるかもしれませんが現在の層(cur)以外のidentityの結合は別途Concatenateに飛ぶようです。従って実際の結合数は6や4になります。
また、Separable関数において実はSeparableConv2Dは二回繰り返して行われます。
nasnet4.png

nasnet.py
def NormalCell(prev, cur, filters):

    cur = SqueezeChannels(cur, filters)
    prev = Fit(prev, filters, cur)

    add_0 = Add()([Separable(filters, 5, strides=1)(cur),
                   Separable(filters, 3, strides=1)(prev)])
    add_1 = Add()([Separable(filters, 5, strides=1)(prev),
                   Separable(filters, 3, strides=1)(prev)])
    add_2 = Add()([AveragePooling2D(pool_size=3, strides=1, padding='same')(cur), prev])
    add_3 = Add()([AveragePooling2D(pool_size=3, strides=1, padding='same')(prev),
                   AveragePooling2D(pool_size=3, strides=1, padding='same')(prev)])
    add_4 = Add()([Separable(filters, 3, strides=1)(cur), cur])

    return Concatenate()([prev, add_0, add_1, add_2, add_3, add_4])

def ReductionCell(prev, cur, filters):

    prev = Fit(prev, filters, cur)
    cur = SqueezeChannels(cur, filters)

    add_0 = Add()([Separable(filters, 5, strides=2)(cur),
                   Separable(filters, 7, strides=2)(prev)])
    add_1 = Add()([MaxPooling2D(3, strides=2, padding='same')(cur),
                   Separable(filters, 7, strides=2)(prev)])
    add_2 = Add()([AveragePooling2D(3, strides=2, padding='same')(cur),
                   Separable(filters, 5, strides=2)(prev)])
    add_3 = Add()([AveragePooling2D(3, strides=1, padding='same')(add_0), add_1])
    add_4 = Add()([Separable(filters, 3, strides=1)(add_0),
                   MaxPooling2D(3, strides=2, padding='same')(cur)])

    return Concatenate()([add_1, add_2, add_3, add_4])

(※NormalCellのadd_3の足し合わせ、同じ(3x3)のAveragePoolingで意味なくないですか?)

SqueezeChannelsとFit

NormalCellReductionCellのコード中にあるSqueezeChannelsとFitに関して簡単に触れておきます。
SqueezeChannelsは要するに(1x1)畳み込みでNormalCellのConcatenate()の結合によってfilters*6の大きさになったチャンネル数を元のfiltersのチャンネル数まで圧縮する処理になります。
Fitは一つ前の層(prev)に対して処理され、現在の層(cur)と一つ前の層(prev)がReductionCellのせいで縦横のサイズが異なる場合、一つ前の層(prev)をPoolingを使って縦横のサイズを1/2にして現在の層(cur)と同じ大きさにしてから(1x1)畳み込みを掛けチャンネル数を圧縮します。なお、処理中にほかにZeroPadding2DやCropping2Dが使われますが(画像を左上に1ずらす?)何のメリットがあるかは分かりません。一つ前の層(prev)のチャンネル数が半分だから(1x1)畳み込みに計算余力が生じているのではと想像します。
一つ前の層(prev)と現在の層(cur)と同じ大きさの場合は一つ前の層(prev)にもSqueezeChannelsが掛けられます。

model.summary()の確認

model.summary()でモデルのパラメータ数を確認してみるとNASNetLargeとNASNetMobileのパラメータ数に完全に一致しています。モデルの畳み込みは取り合えず写経した内容で間違いはないのかなと思います。

# stem_filters=96, num_cell_repeats=6, penultimate_filters=4032 の場合
==============================================================================================
Total params: 88,949,818
Trainable params: 88,753,150
Non-trainable params: 196,668
______________________________________________________________________________________________
...
# stem_filters=32, num_cell_repeats=4, penultimate_filters=1056 の場合
==============================================================================================
Total params: 5,326,716
Trainable params: 5,289,978
Non-trainable params: 36,738
______________________________________________________________________________________________

nasnet5.png

XceptionぐらいのサイズのNASNet

論文見るとNASNet-A(7 @ 1920)のモデルがXceptionと同じくらいのパラメータ数です。
NormalCellの繰り返し数:num_cell_repeats=7、 最終的な出力チャンネル数:penultimate_filters=1920でNASNetのモデルを作るとよいのかなと思います。試してませんが。
nasnet2.png

全コード

nasnet.py
# original:https://github.com/johannesu/NASNet-keras/blob/master/nasnet.py

from keras.layers.convolutional import Conv2D, MaxPooling2D, AveragePooling2D, SeparableConv2D, ZeroPadding2D, Cropping2D
from keras.layers import Input, Concatenate, Add, BatchNormalization, Activation, GlobalAveragePooling2D, Dense
from keras.models import Model

class Separable:
    def __init__(self, filters, kernel_size, strides=1):
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides

    def __call__(self, x):
        x = Activation('relu')(x)
        x = SeparableConv2D(self.filters,
                            kernel_size=self.kernel_size,
                            kernel_initializer='he_normal',
                            strides=self.strides,
                            padding='same',
                            use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = SeparableConv2D(self.filters,
                            kernel_size=self.kernel_size,
                            kernel_initializer='he_normal',
                            strides=1,
                            padding='same',
                            use_bias=False)(x)
        x = BatchNormalization()(x)

        return x

def SqueezeChannels(x, filters):

    x = Activation('relu')(x)
    x = Conv2D(filters, 1, kernel_initializer='he_normal', use_bias=False)(x)
    x = BatchNormalization()(x)

    return x

def Fit(x, filters, target):
    if x is None:
        return target
    if int(x.shape[2]) != int(target.shape[2]):
        x = Activation('relu')(x)

        p1 = AveragePooling2D(pool_size=1, strides=2, padding='valid')(x)
        p1 = Conv2D(filters // 2, 1, kernel_initializer='he_normal', padding='same', use_bias=False)(p1)

        p2 = ZeroPadding2D(padding=((0, 1), (0, 1)))(x)
        p2 = Cropping2D(cropping=((1, 0), (1, 0)))(p2)
        p2 = AveragePooling2D(pool_size=1, strides=2, padding='valid')(p2)

        p2 = Conv2D(filters // 2, 1, kernel_initializer='he_normal', padding='same', use_bias=False)(p2)

        x = Concatenate()([p1, p2])
        x = BatchNormalization()(x)

        return x
    else:
        return SqueezeChannels(x, filters)

def NormalCell(prev, cur, filters):

    cur = SqueezeChannels(cur, filters)
    prev = Fit(prev, filters, cur)

    add_0 = Add()([Separable(filters, 5, strides=1)(cur),
                   Separable(filters, 3, strides=1)(prev)])
    add_1 = Add()([Separable(filters, 5, strides=1)(prev),
                   Separable(filters, 3, strides=1)(prev)])
    add_2 = Add()([AveragePooling2D(pool_size=3, strides=1, padding='same')(cur), prev])
    add_3 = Add()([AveragePooling2D(pool_size=3, strides=1, padding='same')(prev),
                   AveragePooling2D(pool_size=3, strides=1, padding='same')(prev)])
    add_4 = Add()([Separable(filters, 3, strides=1)(cur), cur])

    return Concatenate()([prev, add_0, add_1, add_2, add_3, add_4])

def ReductionCell(prev, cur, filters):

    prev = Fit(prev, filters, cur)
    cur = SqueezeChannels(cur, filters)

    add_0 = Add()([Separable(filters, 5, strides=2)(cur),
                   Separable(filters, 7, strides=2)(prev)])
    add_1 = Add()([MaxPooling2D(3, strides=2, padding='same')(cur),
                   Separable(filters, 7, strides=2)(prev)])
    add_2 = Add()([AveragePooling2D(3, strides=2, padding='same')(cur),
                   Separable(filters, 5, strides=2)(prev)])
    add_3 = Add()([AveragePooling2D(3, strides=1, padding='same')(add_0), add_1])
    add_4 = Add()([Separable(filters, 3, strides=1)(add_0),
                   MaxPooling2D(3, strides=2, padding='same')(cur)])

    return Concatenate()([add_1, add_2, add_3, add_4])

def Stem(x, stem_filters):
    x = Conv2D(stem_filters, 3, strides=2,
               kernel_initializer='he_normal', padding='valid', use_bias=False)(x)
    x = BatchNormalization()(x)
    return x

def NASNetA(input_shape=None, 
            stem_filters=96,
            num_cell_repeats=6,
            penultimate_filters=4032,
            num_classes=1000,
            num_reduction_cells=3):

    filters = int(penultimate_filters / ((2 ** num_reduction_cells) * 6))
    
    inputs = Input(input_shape)
    # Conv2d (3,3), stride=2
    x = Stem(inputs, stem_filters)

    # ReductionCell 1time
    cur, prev = x, None
    prev, cur = cur, prev
    cur = ReductionCell(cur, prev, filters//2)

    # (ReductionCell 1time+NormalCell 6time)*3
    for j in range(3):
        # ReductionCell 1time
        filters *= 2 
        prev, cur = cur, prev
        if j == 0:
            cur = ReductionCell(cur, prev, filters//2)
        else:
            cur = ReductionCell(cur, prev, filters)

        # NormalCell 6time
        for i in range(num_cell_repeats):
            prev, cur = cur, prev
            cur = NormalCell(cur, prev, filters)

    x = Activation('relu')(cur)
    x = GlobalAveragePooling2D()(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    return Model(inputs, outputs)
# NASNetLarge
model = NASNetA(input_shape=(331, 331, 3), stem_filters=96, num_cell_repeats=6, penultimate_filters=4032)
model.summary()

# NASNetMobile
# model = NASNetA(input_shape=(331, 331, 3), stem_filters=32, num_cell_repeats=4, penultimate_filters=1056)
# model.summary()
16
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?