はじめに
画像認識の新方式として期待されているVision Transformer(ViT)を使って、CIFAR10正解率99%に挑戦する。
公式のページでもCIFAR10の転移学習ができるColabのノートブックが提供されていて、さほど難しいことでもないが、そのまま実施しても面白くないので、ここではTensorFlow/Kerasの自作コードに学習済みの重みをロードして実行する。
環境
TensorFlow 2.4.0
tf.keras 2.3.0
Google Colab TPU
ViTのモデルコード
主に下記のコードを参考にした。
Google公式
Eunkwang Jeon氏によるPytorch移植
Googleの公式コードはフレームワークとしてJAXを使っている。
Pytorch移植版はその忠実な再現で、コードも簡潔でこちらのほうが処理がわかりやすいのでおすすめ。AttentionMapの視覚化もできて素晴らしい。(記事作成後に同じ人がすでにKeras版を作っていると気づいたが、こちらは参照してない。)
今回作成したtf.keras版のコードはこちら。(全文は最後のGoogle Colabのノートブックのリンク参照のこと)
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Dense, Dropout, Add, Concatenate, Reshape, Permute, Activation, Flatten, Conv2D, LayerNormalization
class PatchEmbedding(tf.keras.layers.Layer):
def __init__(self, num_patches, d_model, **kwargs):
self.num_patches = num_patches
self.d_model = d_model
super().__init__(**kwargs)
def build(self, input_shape):
self.pos_emb = self.add_weight("pos_emb",
shape=(1, self.num_patches + 1, self.d_model),
initializer=tf.initializers.RandomNormal(stddev=0.02))
self.class_emb = self.add_weight("class_emb",
shape=(1, 1, self.d_model),
initializer=tf.initializers.Zeros())
super().build(input_shape)
def call(self,x):
batch_size = tf.shape(x)[0]
class_emb = tf.broadcast_to(
self.class_emb, [batch_size, 1, self.d_model]
)
x = tf.concat([class_emb, x], axis=1)
x = x + self.pos_emb
return x
def compute_output_shape(self, input_shape):
return (input_shape[0], input_shape[1]+1, input_shape[2], input_shape[3])
def vit(
image_size = 224,
patch_size = 16,
channels = 3,
d_model = 768,
num_classes =21843,
mlp_dim = 3072,
num_layers = 12,
num_heads = 12,
dropout = 0.1,
attention_dropout = 0.0):
dense_kwargs = {
'kernel_initializer':'glorot_uniform',
'bias_initializer': tf.keras.initializers.RandomNormal(stddev=1e-6)
}
lnorm_kwargs = {'epsilon':1e-6}
def MultiHead_SelfAttention(inputs, d_model, num_heads, dropout,prefix):
projection_dim = d_model // num_heads
query = Dense(d_model, name = prefix+"_Dense_query", **dense_kwargs)(inputs)
key = Dense(d_model, name = prefix+"_Dense_key" , **dense_kwargs)(inputs)
value = Dense(d_model, name = prefix+"_Dense_value", **dense_kwargs)(inputs)
query = Reshape((-1, num_heads, projection_dim),name = prefix+"_Reshape_query")(query)
key = Reshape((-1, num_heads, projection_dim),name = prefix+"_Reshape_key" )(key)
value = Reshape((-1, num_heads, projection_dim),name = prefix+"_Reshape_value")(value)
query = Permute((2, 1, 3), name = prefix+"_Permute_query")(query)
key = Permute((2, 1, 3), name = prefix+"_Permute_key")(key)
value = Permute((2, 1, 3), name = prefix+"_Permute_value")(value)
score = tf.matmul(query, key, transpose_b=True)
score = score/K.sqrt(K.cast(projection_dim, 'float32'))
attention_probs = Activation("softmax", name = prefix+"_Softmax_projection")(score)
if dropout != 0.0:
attention_probs = Dropout(dropout, name=prefix+"_Dropout_projection")(attention_probs)
attention = tf.matmul(attention_probs, value)
attention = Permute((2, 1, 3), name = prefix+"_Permute_attention")(attention)
attention = Reshape( (-1, d_model), name = prefix+"_Reshape_attention")(attention)
output = Dense(d_model, name = prefix+"_Dense_out", **dense_kwargs)(attention)
if dropout != 0.0:
output = Dropout(dropout, name=prefix+"_Dropout_out")(output)
return output
x = inputs = tf.keras.layers.Input(shape=(image_size, image_size, channels), name="Input")
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
x = Conv2D(d_model,patch_size,patch_size, padding='valid', name='Embedding_Conv2D')(x)
x = Reshape((num_patches,d_model), name='Embedding_Reshape')(x)
x = PatchEmbedding(num_patches,d_model, name="PatchEmbedding")(x)
if dropout != 0.0:
x = Dropout(dropout, name='Embedding_Dropout')(x)
# Encoder
for i in range(num_layers):
prefix = f'Encoder_{i}'
# Attention block.
shortcut = x
x = LayerNormalization(name = prefix+"_Attention_LayerNorm")(x)
x = MultiHead_SelfAttention(x, d_model, num_heads, attention_dropout, prefix=prefix+"_MHSA" )
x = Add(name=prefix+"_Attention_Add")([x,shortcut])
# MLP block.
shortcut = x
x = LayerNormalization(name=prefix+"_MLP_LayerNorm",**lnorm_kwargs)(x)
x = Dense(mlp_dim, name=prefix+"_MLP_Dense_0",**dense_kwargs)(x)
x = Activation(tf.nn.gelu, name=prefix+"_MLP_GeLU")(x)
if dropout != 0.0:
x = Dropout(dropout, name=prefix+"_MLP_Dropout")(x)
x = Dense(d_model, name=prefix+"_MLP_Dense_1", **dense_kwargs)(x)
if dropout != 0.0:
x = Dropout(dropout, name=prefix+"_MLP_Dropout_hidden")(x)
x = Add(name=prefix+"_MLP_Add")([x,shortcut])
x = LayerNormalization(name="Encoder_Norm", **lnorm_kwargs)(x)
x = Dense(num_classes, kernel_initializer=tf.keras.initializers.Zeros() , name="Head_Dense")(x[:, 0])
return tf.keras.Model(inputs,x)
公式が提供する重みのデータをロード可能で、実際に精度が出ることから、コードに大きな瑕疵は無いと思われる。
以下コードを作成した際の個人的な感想など。
- 最初にConv2Dが来る
- 開発者側としては無理に畳み込みを排除する姿勢はない
- クラスや位置情報のEmbeddingは学習可能な重みが必要なので、Layerを自作せざるを得ない模様
- そのうちtf.kerasでも実装済みのLayerとして提供されるかもしれない
- MultiHeadAttentionはtf.kerasで提供されるようなので、この辺ももっと簡単にかけるようになるはず
- Transformerの部分は本当に既存の言語処理用の処理とほぼ同じ
- 参照(言語理解のためのTransformerモデル)
- 最後はDense(pytorchでのLinear)がActivation無しで付け加えられているだけ
ファインチューニング
公式の重みをロードしてファインチューニングを行う。今回はベースモデルの重みを使用。
画像は224x224にリサイズして入力とする。最後のDenseのみクラス数が違うので重みはロードせず0で初期化して、その後は普通の学習と同じようにfitさせれば良い。フリーズ等は特にしなくても良いようだ。
パラメータ数は比較的多いが、計算量自体はさほど大きくないようで、割と早くEpochが終わる。メモリ消費もそれほどでもなく、BatchSizeは512でも可能。
今回の各種設定は下記の通り。
パラメータ | 値 |
---|---|
Model Name | ViT-B_16 |
Batch Size | 200 |
Warmup | 3 epochs(Linear) |
Cooldown | 12 epochs(Cosine) |
Horizontal/Vertical Shift | 0.2 |
Horizontal Flip | True |
Cutout | 0.5 x 1 |
Optimizer | SGD(momentum=0.9) |
Learning Rate | 4e-2 |
Label Smoothing | 0.1 |
15エポック(ColabのTPUで約20分)で大体99%までいける。もう少しエポック数を増やせば確実性が上がる。
以下、実際の学習の結果。
Epoch 1/15
250/250 [==============================] - 149s 321ms/step - loss: 1.6647 - acc: 0.8384 - val_loss: 0.5657 - val_acc: 0.9792
Epoch 2/15
250/250 [==============================] - 72s 284ms/step - loss: 0.6631 - acc: 0.9336 - val_loss: 0.5473 - val_acc: 0.9824
Epoch 3/15
250/250 [==============================] - 71s 283ms/step - loss: 0.6373 - acc: 0.9403 - val_loss: 0.5521 - val_acc: 0.9789
Epoch 4/15
250/250 [==============================] - 71s 282ms/step - loss: 0.6282 - acc: 0.9449 - val_loss: 0.5484 - val_acc: 0.9797
Epoch 5/15
250/250 [==============================] - 71s 283ms/step - loss: 0.6055 - acc: 0.9547 - val_loss: 0.5419 - val_acc: 0.9837
Epoch 6/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5913 - acc: 0.9616 - val_loss: 0.5344 - val_acc: 0.9868
Epoch 7/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5795 - acc: 0.9660 - val_loss: 0.5325 - val_acc: 0.9868
Epoch 8/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5699 - acc: 0.9690 - val_loss: 0.5329 - val_acc: 0.9873
Epoch 9/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5639 - acc: 0.9731 - val_loss: 0.5321 - val_acc: 0.9867
Epoch 10/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5584 - acc: 0.9743 - val_loss: 0.5318 - val_acc: 0.9867
Epoch 11/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5476 - acc: 0.9802 - val_loss: 0.5303 - val_acc: 0.9884
Epoch 12/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5452 - acc: 0.9808 - val_loss: 0.5292 - val_acc: 0.9893
Epoch 13/15
250/250 [==============================] - 72s 286ms/step - loss: 0.5385 - acc: 0.9835 - val_loss: 0.5279 - val_acc: 0.9898
Epoch 14/15
250/250 [==============================] - 71s 283ms/step - loss: 0.5364 - acc: 0.9846 - val_loss: 0.5281 - val_acc: 0.9895
Epoch 15/15
250/250 [==============================] - 71s 283ms/step - loss: 0.5385 - acc: 0.9836 - val_loss: 0.5282 - val_acc: 0.9900
まとめ
以前に別の記事でEfficentNetで99%達成する方法を書いたが、あちらでは2時間以上かかった。工夫すればもう少し短くできるが結局1時間は切れなかった。それに比べれば大幅に短い時間で達成できたので驚いた。
現状のViTの欠点としては、転移学習を使わないと低性能に終わる、ということが当初から指摘されている。Qiitaでもいくつか記事があるが、今回作成したモデルでも最初からCIFAR10を学習させると80%にすら到達できなかった1。この辺りは、T2T-ViTのような改良版が出てきているようなので、今後に期待。
参考
画像認識の大革命。AI界で話題爆発中の「Vision Transformer」を解説!
Vision Transformerのコードの写経
-
その後、Cutmixでデータ拡張子し800エポック程度学習させて88%まで到達できた。 ↩