2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

ReZero-ViTの実装と評価

Last updated at Posted at 2021-07-13

はじめに

TransformerではLayerNormalizationを使用することが基本となっているが、ReZeroというそれを置き換える構成が提案されている。本記事ではReZeroをVisionTransformer(ViT)に応用してその効果を確認する。

ReZeroとは

下記論文で提案されている。(Submitted on 10 Mar 2020)

"X Is All You Need"というタイトルが付く論文は、有名なAttention Is All You Needが発表されて以降は結構あるようで、これもその仲間。
効果としては学習の収束が速くなることが主張されている。

ReZeroと既存の構成との違いを論文からの切り抜きで示す。

スクリーンショット 2021-07-12 14.57.54.png

ReZeroではNormがなく、Trainableな$\alpha_i$が出てくる。
ReZeroは「Residual with Zero initialization」の略でこの$\alpha_i$をZeroから始めるということに由来する。有名な日本のライトノベル/アニメ作品の「Re:ゼロから始める異世界生活」に関係があるかどうかは、論文には書いてなかった。

それはともかく、NFNetResMLPのようにNormalizationをなくす方向性については現在研究が進んでいるようで、ReZeroはそのうちの一つと位置付けられるだろう。

論文中にはTransformerへのReZeroの組み込み方が書いてあり、公式実装もある。
本記事はこれをViTに応用してみたが、記事タイトルの"ReZero-ViT"というのは筆者の命名で、このように呼ばれるViTは筆者が確認した限りでは発表されていないと思う。

論文にはCNNでの応用として、ResNetでBatchNormalizarionをReZeroに置き換えたモデルでのCIFAR-10の実験結果も掲載されており、ReZeroの方が収束がはやく、正解率も良くなったとされている。

実装

TensorFlow(2.5.0)のtf.kerasで実装する。

ReZeroは単独のレイヤーとして実装すると以下のようになる。

class ReZero(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.rezero_alpha = self.add_weight("rezero_a", 
                                initializer=tf.keras.initializers.Zeros(),
                                trainable=Truedtype=tf.float32)
        super().build(input_shape)
    def call(self, inputs):
        return inputs*self.rezero_alpha

TrainableなWeightを含むLayerとしては、もっともシンプルな形では無いかと思われる。self.rezero_alphaという1つのパラメータしかなく、パラメータ数としても1増えるだけで、計算量もほとんど無視できる程度だろう。

Transformerブロックとしては、下記のような実装になる。

class SkipConnection(tf.keras.layers.Layer):
    def __init__(self, drop_rate=0.0,**kwargs):
        super().__init__(**kwargs)
        self.drop_rate = drop_rate

    def build(self, input_shape):
        super().build(input_shape)

    def get_config(self):
        config = {'drop_rate': self.drop_rate}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs):
        shortcut, x = inputs
        if self.drop_rate != 0.0:
            x = layers.Dropout(rate=self.drop_rate, noise_shape=(None,1,1))(x)
        x = layers.Add()([shortcut,x])
        return x

class Transformer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, mlp_dim,dropout=0.0,layer_drop_rate=0.0,l2_reg=0.0,use_rezero=False,**kwargs):
        super().__init__(**kwargs)
        self.dropout = dropout
        self.d_model = d_model
        self.mlp_dim = mlp_dim
        self.num_heads = num_heads
        self.layer_drop_rate = layer_drop_rate
        self.l2_reg = l2_reg
        self.use_rezero = use_rezero

    def get_config(self):
        config = {'d_model': self.d_model, 'num_heads': self.num_heads, 'mlp_dim': self.mlp_dim,
                  'l2_reg': self.l2_reg, 'dropout': self.dropout, 'layer_drop_rate': self.layer_drop_rate,
                  'use_rezero': self.use_rezero}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    
    def build(self, input_shape):
        dense_kwargs = {
            'kernel_initializer':'glorot_normal',
            'bias_initializer': tf.keras.initializers.RandomNormal(stddev=1e-2),
            'kernel_regularizer':tf.keras.regularizers.l2(self.l2_reg)
            }
        lnorm_kwargs = {'epsilon':1e-6}
        self.mlp1 = layers.Dense(self.mlp_dim,**dense_kwargs)
        self.mlp2= layers.Dense(self.d_model, **dense_kwargs)
        self.activation = layers.Activation(tf.nn.gelu)
        self.norm_attn = layers.LayerNormalization(**lnorm_kwargs)
        self.norm_mlp = layers.LayerNormalization(**lnorm_kwargs)
        self.self_attention  = layers.MultiHeadAttention(self.num_heads, self.d_model//self.num_heads, **dense_kwargs)
        self.rezero = ReZero()
        super().build(input_shape)

    def call(self, inputs):
        # Attention block
        shortcut = x = inputs
        if not self.use_rezero:
            x = self.norm_attn(x)
            x = self.self_attention(x,x)
        else:
            x = self.self_attention(x,x)
            x = self.rezero(x)
        x = SkipConnection(self.layer_drop_rate)([shortcut,x])

        # MLP block.
        shortcut = x
        if not self.use_rezero:
            x = self.norm_mlp(x)
            x = self.mlp2(self.activation(self.mlp1(x)))
        else:
            x = self.mlp2(self.activation(self.mlp1(x)))
            x = self.rezero(x)
        x = SkipConnection(self.layer_drop_rate)([shortcut,x])

        return x

従来との違いがわかりやすいようにcallメソッド内の処理を記述した。
ブロック内にResidualが2つあるが、論文ではブロック内でReZeroの係数を共用するように書かれているので、この実装でもReZeroのLayerは同じものを使っている。

実験

CIFAR-10で実験を行った。
OptimizerはAdamW+SAMを使用。CNNではSGDを使うのが一般的と思うが、ViT系はAdamWが好まれるようだ。
AdamWはTensorFlow Addonsでも提供されているが、実装に問題があると思われるので、今回はpytorchにおけるAdamW実装に近いものを自作した。SAMも自作
ViTはViT-Tiと呼ばれる、パラメータ数が5M程度のかなり小さいモデルで行った。
その他実験の詳細はGoogleColabのノート参照のこと。

以下結果。通常のViTはLayerNorm-Vitと表記しておく。

モデル 正解率
LayerNorm-ViT-Ti 91.57
ReZero-ViT-Ti 94.92

ReZero.png

大幅に正解率が向上している。差が大きすぎて逆に理解に苦しむが、何度やっても結果がこうなので仕方がない。ViTは学習が遅いので、ReZeroで学習が早くなった結果、モデル本来の性能が出やすくなったのかもしれない。
この辺は実験の条件に左右されるはずで、常にこのような性能向上があるわけではないだろうが、驚きの結果となった。ConvolutionとハイブリッドにしたViTならばこの程度の正解率にするのはそれほど難しくないが、ConvolutionなしのViTとしてはかなり高い方だと思う。CIFAR-10のような小規模のデータセットでの学習をメインにした論文としては「Escaping the Big Data Paradigm with Compact Transformers」があるので興味のある方は確認してもらいたい。

ResNetへの応用

論文ではPreActタイプのResNetV2のBatchNormalizationをReZeroに置き換えても性能向上したとされている。論文には詳しく実装法が書いていないが、類似(ほぼ同じ)技術であるSkipInitの論文には詳細が書かれているので、そちらを参照すると以下のように変更するようだ。

  • SkipConnectionの直前にReZero層を追加
  • BatchNormalization層削除
  • Conv2Dでuse_biasをTrueに変更

筆者も上記のような変更を適用して実験してみたが性能向上は見られなかった(むしろ悪化した)。

SkipInitの論文によれば、Dropoutを追加したり学習率を変えたりしないと、BatchNormalizationに匹敵するような性能は得られないようだ。Residual Blockの後にActivationが来るタイプのResNetではSkipInit(ReZero)はそのまま適用できないも書いてあるようだ。

BatchNormalizationの置き換えは効果が薄いが、LayerNormalizationやGroupNormalizationの置き換えには性能向上の効果があるのかもしれない。

まとめ

ReZeroをViTに組み込み、CIFAR10での学習において大幅に性能向上することを確認した。ただしCNNでの性能向上は未確認。
"ReZero is All You Need"は流石に言い過ぎのような気がするが、Normalization Layer無しのモデルの可能性を感じられた。

2
3
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?