LoginSignup
14
15

More than 3 years have passed since last update.

Vision TransformerでCIFAR10正解率99%を達成する方法

Last updated at Posted at 2021-02-15

はじめに

画像認識の新方式として期待されているVision Transformer(ViT)を使って、CIFAR10正解率99%に挑戦する。
公式のページでもCIFAR10の転移学習ができるColabのノートブックが提供されていて、さほど難しいことでもないが、そのまま実施しても面白くないので、ここではTensorFlow/Kerasの自作コードに学習済みの重みをロードして実行する。

環境

TensorFlow 2.4.0
tf.keras 2.3.0
Google Colab TPU

ViTのモデルコード

主に下記のコードを参考にした。
Google公式
Eunkwang Jeon氏によるPytorch移植

Googleの公式コードはフレームワークとしてJAXを使っている。
Pytorch移植版はその忠実な再現で、コードも簡潔でこちらのほうが処理がわかりやすいのでおすすめ。AttentionMapの視覚化もできて素晴らしい。(記事作成後に同じ人がすでにKeras版を作っていると気づいたが、こちらは参照してない。)

今回作成したtf.keras版のコードはこちら。(全文は最後のGoogle Colabのノートブックのリンク参照のこと)

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import  Dense, Dropout, Add, Concatenate, Reshape, Permute, Activation, Flatten, Conv2D, LayerNormalization
class PatchEmbedding(tf.keras.layers.Layer):
    def __init__(self, num_patches, d_model, **kwargs):
        self.num_patches = num_patches
        self.d_model = d_model
        super().__init__(**kwargs)
    def build(self, input_shape):
        self.pos_emb = self.add_weight("pos_emb", 
                                       shape=(1, self.num_patches + 1, self.d_model),
                                       initializer=tf.initializers.RandomNormal(stddev=0.02))
        self.class_emb = self.add_weight("class_emb", 
                                         shape=(1, 1, self.d_model),
                                         initializer=tf.initializers.Zeros())
        super().build(input_shape)
    def call(self,x):
        batch_size = tf.shape(x)[0]
        class_emb = tf.broadcast_to(
            self.class_emb, [batch_size, 1, self.d_model]
        )
        x = tf.concat([class_emb, x], axis=1)
        x = x + self.pos_emb
        return x
    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1]+1, input_shape[2], input_shape[3])


def vit(
    image_size = 224,
    patch_size = 16,
    channels = 3,
    d_model = 768,
    num_classes =21843,
    mlp_dim = 3072,
    num_layers = 12,
    num_heads = 12,
    dropout = 0.1,
    attention_dropout = 0.0):

    dense_kwargs = {
        'kernel_initializer':'glorot_uniform',
        'bias_initializer': tf.keras.initializers.RandomNormal(stddev=1e-6)
        }
    lnorm_kwargs = {'epsilon':1e-6}

    def MultiHead_SelfAttention(inputs, d_model, num_heads, dropout,prefix):
        projection_dim = d_model // num_heads

        query = Dense(d_model, name = prefix+"_Dense_query", **dense_kwargs)(inputs)
        key   = Dense(d_model, name = prefix+"_Dense_key"  , **dense_kwargs)(inputs)
        value = Dense(d_model, name = prefix+"_Dense_value", **dense_kwargs)(inputs)
        query = Reshape((-1, num_heads, projection_dim),name = prefix+"_Reshape_query")(query)
        key   = Reshape((-1, num_heads, projection_dim),name = prefix+"_Reshape_key"  )(key)
        value = Reshape((-1, num_heads, projection_dim),name = prefix+"_Reshape_value")(value)

        query = Permute((2, 1, 3), name = prefix+"_Permute_query")(query)
        key   = Permute((2, 1, 3), name = prefix+"_Permute_key")(key)
        value = Permute((2, 1, 3), name = prefix+"_Permute_value")(value)

        score = tf.matmul(query, key, transpose_b=True)
        score = score/K.sqrt(K.cast(projection_dim, 'float32'))
        attention_probs = Activation("softmax", name = prefix+"_Softmax_projection")(score) 
        if dropout != 0.0:
            attention_probs = Dropout(dropout, name=prefix+"_Dropout_projection")(attention_probs)

        attention = tf.matmul(attention_probs, value)
        attention = Permute((2, 1, 3), name = prefix+"_Permute_attention")(attention)
        attention = Reshape( (-1, d_model), name = prefix+"_Reshape_attention")(attention)
        output = Dense(d_model, name = prefix+"_Dense_out", **dense_kwargs)(attention)
        if dropout != 0.0:
            output = Dropout(dropout, name=prefix+"_Dropout_out")(output)

        return output

    x = inputs = tf.keras.layers.Input(shape=(image_size, image_size, channels), name="Input")

    num_patches = (image_size // patch_size) ** 2
    patch_dim = channels * patch_size ** 2
    x = Conv2D(d_model,patch_size,patch_size, padding='valid', name='Embedding_Conv2D')(x)
    x = Reshape((num_patches,d_model), name='Embedding_Reshape')(x)
    x = PatchEmbedding(num_patches,d_model, name="PatchEmbedding")(x)

    if dropout != 0.0:
       x = Dropout(dropout, name='Embedding_Dropout')(x)

    # Encoder
    for i in range(num_layers):
        prefix = f'Encoder_{i}'
        # Attention block.
        shortcut = x
        x = LayerNormalization(name = prefix+"_Attention_LayerNorm")(x)
        x = MultiHead_SelfAttention(x, d_model, num_heads, attention_dropout, prefix=prefix+"_MHSA" )
        x = Add(name=prefix+"_Attention_Add")([x,shortcut])
        # MLP block.
        shortcut = x
        x = LayerNormalization(name=prefix+"_MLP_LayerNorm",**lnorm_kwargs)(x)
        x = Dense(mlp_dim, name=prefix+"_MLP_Dense_0",**dense_kwargs)(x)
        x = Activation(tf.nn.gelu, name=prefix+"_MLP_GeLU")(x)
        if dropout != 0.0:
            x = Dropout(dropout, name=prefix+"_MLP_Dropout")(x)
        x = Dense(d_model, name=prefix+"_MLP_Dense_1", **dense_kwargs)(x)
        if dropout != 0.0:
            x = Dropout(dropout, name=prefix+"_MLP_Dropout_hidden")(x)
        x = Add(name=prefix+"_MLP_Add")([x,shortcut])

    x = LayerNormalization(name="Encoder_Norm", **lnorm_kwargs)(x)
    x = Dense(num_classes, kernel_initializer=tf.keras.initializers.Zeros() , name="Head_Dense")(x[:, 0])
    return tf.keras.Model(inputs,x) 

公式が提供する重みのデータをロード可能で、実際に精度が出ることから、コードに大きな瑕疵は無いと思われる。
以下コードを作成した際の個人的な感想など。

  • 最初にConv2Dが来る
    • 開発者側としては無理に畳み込みを排除する姿勢はない
  • クラスや位置情報のEmbeddingは学習可能な重みが必要なので、Layerを自作せざるを得ない模様
    • そのうちtf.kerasでも実装済みのLayerとして提供されるかもしれない
    • MultiHeadAttentionはtf.kerasで提供されるようなので、この辺ももっと簡単にかけるようになるはず  
  • Transformerの部分は本当に既存の言語処理用の処理とほぼ同じ
  • 最後はDense(pytorchでのLinear)がActivation無しで付け加えられているだけ

ファインチューニング

公式の重みをロードしてファインチューニングを行う。今回はベースモデルの重みを使用。
画像は224x224にリサイズして入力とする。最後のDenseのみクラス数が違うので重みはロードせず0で初期化して、その後は普通の学習と同じようにfitさせれば良い。フリーズ等は特にしなくても良いようだ。
パラメータ数は比較的多いが、計算量自体はさほど大きくないようで、割と早くEpochが終わる。メモリ消費もそれほどでもなく、BatchSizeは512でも可能。
今回の各種設定は下記の通り。

パラメータ
Model Name ViT-B_16
Batch Size 200
Warmup 3 epochs(Linear)
Cooldown 12 epochs(Cosine)
Horizontal/Vertical Shift 0.2
Horizontal Flip True
Cutout 0.5 x 1
Optimizer SGD(momentum=0.9)
Learning Rate 4e-2
Label Smoothing 0.1

15エポック(ColabのTPUで約20分)で大体99%までいける。もう少しエポック数を増やせば確実性が上がる。
以下、実際の学習の結果。

Epoch 1/15
250/250 [==============================] - 149s 321ms/step - loss: 1.6647 - acc: 0.8384 - val_loss: 0.5657 - val_acc: 0.9792
Epoch 2/15
250/250 [==============================] - 72s 284ms/step - loss: 0.6631 - acc: 0.9336 - val_loss: 0.5473 - val_acc: 0.9824
Epoch 3/15
250/250 [==============================] - 71s 283ms/step - loss: 0.6373 - acc: 0.9403 - val_loss: 0.5521 - val_acc: 0.9789
Epoch 4/15
250/250 [==============================] - 71s 282ms/step - loss: 0.6282 - acc: 0.9449 - val_loss: 0.5484 - val_acc: 0.9797
Epoch 5/15
250/250 [==============================] - 71s 283ms/step - loss: 0.6055 - acc: 0.9547 - val_loss: 0.5419 - val_acc: 0.9837
Epoch 6/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5913 - acc: 0.9616 - val_loss: 0.5344 - val_acc: 0.9868
Epoch 7/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5795 - acc: 0.9660 - val_loss: 0.5325 - val_acc: 0.9868
Epoch 8/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5699 - acc: 0.9690 - val_loss: 0.5329 - val_acc: 0.9873
Epoch 9/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5639 - acc: 0.9731 - val_loss: 0.5321 - val_acc: 0.9867
Epoch 10/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5584 - acc: 0.9743 - val_loss: 0.5318 - val_acc: 0.9867
Epoch 11/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5476 - acc: 0.9802 - val_loss: 0.5303 - val_acc: 0.9884
Epoch 12/15
250/250 [==============================] - 71s 282ms/step - loss: 0.5452 - acc: 0.9808 - val_loss: 0.5292 - val_acc: 0.9893
Epoch 13/15
250/250 [==============================] - 72s 286ms/step - loss: 0.5385 - acc: 0.9835 - val_loss: 0.5279 - val_acc: 0.9898
Epoch 14/15
250/250 [==============================] - 71s 283ms/step - loss: 0.5364 - acc: 0.9846 - val_loss: 0.5281 - val_acc: 0.9895
Epoch 15/15
250/250 [==============================] - 71s 283ms/step - loss: 0.5385 - acc: 0.9836 - val_loss: 0.5282 - val_acc: 0.9900

Google Colab Notebookへのリンク

まとめ

以前に別の記事でEfficentNetで99%達成する方法を書いたが、あちらでは2時間以上かかった。工夫すればもう少し短くできるが結局1時間は切れなかった。それに比べれば大幅に短い時間で達成できたので驚いた。
現状のViTの欠点としては、転移学習を使わないと低性能に終わる、ということが当初から指摘されている。Qiitaでもいくつか記事があるが、今回作成したモデルでも最初からCIFAR10を学習させると80%にすら到達できなかった1。この辺りは、T2T-ViTのような改良版が出てきているようなので、今後に期待。

参考

画像認識の大革命。AI界で話題爆発中の「Vision Transformer」を解説!
Vision Transformerのコードの写経


  1. その後、Cutmixでデータ拡張子し800エポック程度学習させて88%まで到達できた。 

14
15
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
15