はじめに
TransformerではLayerNormalizationを使用することが基本となっているが、ReZeroというそれを置き換える構成が提案されている。本記事ではReZeroをVisionTransformer(ViT)に応用してその効果を確認する。
ReZeroとは
下記論文で提案されている。(Submitted on 10 Mar 2020)
"X Is All You Need"というタイトルが付く論文は、有名なAttention Is All You Needが発表されて以降は結構あるようで、これもその仲間。
効果としては学習の収束が速くなることが主張されている。
ReZeroと既存の構成との違いを論文からの切り抜きで示す。
ReZeroではNormがなく、Trainableな$\alpha_i$が出てくる。
ReZeroは「Residual with Zero initialization」の略でこの$\alpha_i$をZeroから始めるということに由来する。有名な日本のライトノベル/アニメ作品の「Re:ゼロから始める異世界生活」に関係があるかどうかは、論文には書いてなかった。
それはともかく、NFNetやResMLPのようにNormalizationをなくす方向性については現在研究が進んでいるようで、ReZeroはそのうちの一つと位置付けられるだろう。
論文中にはTransformerへのReZeroの組み込み方が書いてあり、公式実装もある。
本記事はこれをViTに応用してみたが、記事タイトルの"ReZero-ViT"というのは筆者の命名で、このように呼ばれるViTは筆者が確認した限りでは発表されていないと思う。
論文にはCNNでの応用として、ResNetでBatchNormalizarionをReZeroに置き換えたモデルでのCIFAR-10の実験結果も掲載されており、ReZeroの方が収束がはやく、正解率も良くなったとされている。
実装
TensorFlow(2.5.0)のtf.kerasで実装する。
ReZeroは単独のレイヤーとして実装すると以下のようになる。
class ReZero(tf.keras.layers.Layer):
def build(self, input_shape):
self.rezero_alpha = self.add_weight("rezero_a",
initializer=tf.keras.initializers.Zeros(),
trainable=Truedtype=tf.float32)
super().build(input_shape)
def call(self, inputs):
return inputs*self.rezero_alpha
TrainableなWeightを含むLayerとしては、もっともシンプルな形では無いかと思われる。self.rezero_alphaという1つのパラメータしかなく、パラメータ数としても1増えるだけで、計算量もほとんど無視できる程度だろう。
Transformerブロックとしては、下記のような実装になる。
class SkipConnection(tf.keras.layers.Layer):
def __init__(self, drop_rate=0.0,**kwargs):
super().__init__(**kwargs)
self.drop_rate = drop_rate
def build(self, input_shape):
super().build(input_shape)
def get_config(self):
config = {'drop_rate': self.drop_rate}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
shortcut, x = inputs
if self.drop_rate != 0.0:
x = layers.Dropout(rate=self.drop_rate, noise_shape=(None,1,1))(x)
x = layers.Add()([shortcut,x])
return x
class Transformer(tf.keras.layers.Layer):
def __init__(self, d_model, num_heads, mlp_dim,dropout=0.0,layer_drop_rate=0.0,l2_reg=0.0,use_rezero=False,**kwargs):
super().__init__(**kwargs)
self.dropout = dropout
self.d_model = d_model
self.mlp_dim = mlp_dim
self.num_heads = num_heads
self.layer_drop_rate = layer_drop_rate
self.l2_reg = l2_reg
self.use_rezero = use_rezero
def get_config(self):
config = {'d_model': self.d_model, 'num_heads': self.num_heads, 'mlp_dim': self.mlp_dim,
'l2_reg': self.l2_reg, 'dropout': self.dropout, 'layer_drop_rate': self.layer_drop_rate,
'use_rezero': self.use_rezero}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
dense_kwargs = {
'kernel_initializer':'glorot_normal',
'bias_initializer': tf.keras.initializers.RandomNormal(stddev=1e-2),
'kernel_regularizer':tf.keras.regularizers.l2(self.l2_reg)
}
lnorm_kwargs = {'epsilon':1e-6}
self.mlp1 = layers.Dense(self.mlp_dim,**dense_kwargs)
self.mlp2= layers.Dense(self.d_model, **dense_kwargs)
self.activation = layers.Activation(tf.nn.gelu)
self.norm_attn = layers.LayerNormalization(**lnorm_kwargs)
self.norm_mlp = layers.LayerNormalization(**lnorm_kwargs)
self.self_attention = layers.MultiHeadAttention(self.num_heads, self.d_model//self.num_heads, **dense_kwargs)
self.rezero = ReZero()
super().build(input_shape)
def call(self, inputs):
# Attention block
shortcut = x = inputs
if not self.use_rezero:
x = self.norm_attn(x)
x = self.self_attention(x,x)
else:
x = self.self_attention(x,x)
x = self.rezero(x)
x = SkipConnection(self.layer_drop_rate)([shortcut,x])
# MLP block.
shortcut = x
if not self.use_rezero:
x = self.norm_mlp(x)
x = self.mlp2(self.activation(self.mlp1(x)))
else:
x = self.mlp2(self.activation(self.mlp1(x)))
x = self.rezero(x)
x = SkipConnection(self.layer_drop_rate)([shortcut,x])
return x
従来との違いがわかりやすいようにcallメソッド内の処理を記述した。
ブロック内にResidualが2つあるが、論文ではブロック内でReZeroの係数を共用するように書かれているので、この実装でもReZeroのLayerは同じものを使っている。
実験
CIFAR-10で実験を行った。
OptimizerはAdamW+SAMを使用。CNNではSGDを使うのが一般的と思うが、ViT系はAdamWが好まれるようだ。
AdamWはTensorFlow Addonsでも提供されているが、実装に問題があると思われるので、今回はpytorchにおけるAdamW実装に近いものを自作した。SAMも自作。
ViTはViT-Tiと呼ばれる、パラメータ数が5M程度のかなり小さいモデルで行った。
その他実験の詳細はGoogleColabのノート参照のこと。
以下結果。通常のViTはLayerNorm-Vitと表記しておく。
モデル | 正解率 |
---|---|
LayerNorm-ViT-Ti | 91.57 |
ReZero-ViT-Ti | 94.92 |
大幅に正解率が向上している。差が大きすぎて逆に理解に苦しむが、何度やっても結果がこうなので仕方がない。ViTは学習が遅いので、ReZeroで学習が早くなった結果、モデル本来の性能が出やすくなったのかもしれない。
この辺は実験の条件に左右されるはずで、常にこのような性能向上があるわけではないだろうが、驚きの結果となった。ConvolutionとハイブリッドにしたViTならばこの程度の正解率にするのはそれほど難しくないが、ConvolutionなしのViTとしてはかなり高い方だと思う。CIFAR-10のような小規模のデータセットでの学習をメインにした論文としては「Escaping the Big Data Paradigm with Compact Transformers」があるので興味のある方は確認してもらいたい。
ResNetへの応用
論文ではPreActタイプのResNetV2のBatchNormalizationをReZeroに置き換えても性能向上したとされている。論文には詳しく実装法が書いていないが、類似(ほぼ同じ)技術であるSkipInitの論文には詳細が書かれているので、そちらを参照すると以下のように変更するようだ。
- SkipConnectionの直前にReZero層を追加
- BatchNormalization層削除
- Conv2Dでuse_biasをTrueに変更
筆者も上記のような変更を適用して実験してみたが性能向上は見られなかった(むしろ悪化した)。
SkipInitの論文によれば、Dropoutを追加したり学習率を変えたりしないと、BatchNormalizationに匹敵するような性能は得られないようだ。Residual Blockの後にActivationが来るタイプのResNetではSkipInit(ReZero)はそのまま適用できないも書いてあるようだ。
BatchNormalizationの置き換えは効果が薄いが、LayerNormalizationやGroupNormalizationの置き換えには性能向上の効果があるのかもしれない。
まとめ
ReZeroをViTに組み込み、CIFAR10での学習において大幅に性能向上することを確認した。ただしCNNでの性能向上は未確認。
"ReZero is All You Need"は流石に言い過ぎのような気がするが、Normalization Layer無しのモデルの可能性を感じられた。