7
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Vision Transformer(ViT)を自作してcifar10を解いてみた

Posted at

はじめに

こんにちにゃんです。
水色桜(みずいろさくら)です。
今回はVision Transformerを自作してcifar10を解いてみようと思います。
解説にあたっては極力数式を用いずに解説するつもりです。
もし記事中で間違い・不明点などあればコメントまたはTwitterまでお寄せいただけると嬉しいです。
まず今回作成したモデルの精度を見てみましょう。

Epoch 1/25
loss: 1.8223 - accuracy: 0.3932 - val_loss: 1.5787 - val_accuracy: 0.4319 - lr: 0.0010
Epoch 2/25
loss: 1.4452 - accuracy: 0.4893 - val_loss: 1.3778 - val_accuracy: 0.5180 - lr: 5.0000e-04
Epoch 3/25
loss: 1.3357 - accuracy: 0.5285 - val_loss: 1.3702 - val_accuracy: 0.5183 - lr: 5.0000e-04
Epoch 4/25
loss: 1.2748 - accuracy: 0.5517 - val_loss: 1.2847 - val_accuracy: 0.5486 - lr: 5.0000e-04
Epoch 5/25
loss: 1.2258 - accuracy: 0.5702 - val_loss: 1.2863 - val_accuracy: 0.5469 - lr: 5.0000e-04
Epoch 6/25
loss: 1.1839 - accuracy: 0.5831 - val_loss: 1.2715 - val_accuracy: 0.5566 - lr: 5.0000e-04
Epoch 7/25
loss: 1.0910 - accuracy: 0.6197 - val_loss: 1.1777 - val_accuracy: 0.5874 - lr: 2.5000e-04
Epoch 8/25
loss: 1.0564 - accuracy: 0.6314 - val_loss: 1.1737 - val_accuracy: 0.5845 - lr: 2.5000e-04
Epoch 9/25
loss: 1.0312 - accuracy: 0.6383 - val_loss: 1.1962 - val_accuracy: 0.5852 - lr: 2.5000e-04
Epoch 10/25
loss: 1.0119 - accuracy: 0.6471 - val_loss: 1.1988 - val_accuracy: 0.5813 - lr: 2.5000e-04
Epoch 11/25
loss: 0.9926 - accuracy: 0.6527 - val_loss: 1.1468 - val_accuracy: 0.5983 - lr: 2.5000e-04
Epoch 12/25
loss: 0.9290 - accuracy: 0.6785 - val_loss: 1.1308 - val_accuracy: 0.6060 - lr: 1.2500e-04
Epoch 13/25
loss: 0.9078 - accuracy: 0.6846 - val_loss: 1.1492 - val_accuracy: 0.6051 - lr: 1.2500e-04
Epoch 14/25
loss: 0.8959 - accuracy: 0.6886 - val_loss: 1.1521 - val_accuracy: 0.6069 - lr: 1.2500e-04
Epoch 15/25
loss: 0.8833 - accuracy: 0.6929 - val_loss: 1.1577 - val_accuracy: 0.6059 - lr: 1.2500e-04
Epoch 16/25
loss: 0.8725 - accuracy: 0.6972 - val_loss: 1.1945 - val_accuracy: 0.5940 - lr: 1.2500e-04
Epoch 17/25
loss: 0.8349 - accuracy: 0.7096 - val_loss: 1.1708 - val_accuracy: 0.6039 - lr: 6.2500e-05
Epoch 18/25
loss: 0.8244 - accuracy: 0.7120 - val_loss: 1.1720 - val_accuracy: 0.6055 - lr: 6.2500e-05
Epoch 19/25
loss: 0.8158 - accuracy: 0.7163 - val_loss: 1.1699 - val_accuracy: 0.6067 - lr: 6.2500e-05
Epoch 20/25
loss: 0.8078 - accuracy: 0.7202 - val_loss: 1.1993 - val_accuracy: 0.6008 - lr: 6.2500e-05
Epoch 21/25
loss: 0.8013 - accuracy: 0.7221 - val_loss: 1.1870 - val_accuracy: 0.6079 - lr: 6.2500e-05
Epoch 22/25
loss: 0.7777 - accuracy: 0.7306 - val_loss: 1.1791 - val_accuracy: 0.6100 - lr: 3.1250e-05
Epoch 23/25
loss: 0.7743 - accuracy: 0.7311 - val_loss: 1.1938 - val_accuracy: 0.6062 - lr: 3.1250e-05
Epoch 24/25
loss: 0.7676 - accuracy: 0.7342 - val_loss: 1.1927 - val_accuracy: 0.6088 - lr: 3.1250e-05
Epoch 25/25
loss: 0.7659 - accuracy: 0.7327 - val_loss: 1.1937 - val_accuracy: 0.6077 - lr: 3.1250e-05

25回の学習で正解率60%となっています。通常Vision transformer(ViT)はImagenetなどで事前学習を行うため、事前学習なしではこのくらいの精度しか出ません。
現在(2023年7月21日)、SoTAを達成しているモデルはTransformerとCNN(Convoluional Neural Network)の組み合わせで出来ています。
Vision Transformer(ViT)を理解することはこれらのモデルを理解する下地になってくれると考えます。
では早速Vision Transformer(ViT)について解説していきます。

Vision Transformer(ViT)とは

2020年にGoogleから発表されたモデル。Vision Transformersの特徴は以下の4つです。

  1. SoTA(State of The Art)を上回る精度を従来の約1/15の計算量で達成したこと。
  2. 畳み込みを用いずにTransformerのみを利用していること。
  3. 画像パッチ(画像を分割したピースのようなもの)を単語のように扱うこと。
  4. アーキテクチャ(モデルの構造)はTransformerのエンコーダ部分であること。

Vision Transformer(ViT)では入力画像をパッチに分割し、Flatten(複数次元を持つ要素を一次元の要素に変換する処理)することで、一つ一つのパッチを単語のように扱います。このパッチに位置エンコーディング(パッチの位置情報)を付加したものが入力になります。

Vision Transformer(ViT)のアーキテクチャ

image.png

Vision Transformer(ViT)の論文より引用

Vision Transformer(ViT)のアーキテクチャは上図のようになっています。まずパッチをFlattenし、線型射影(Dense:通常のニューラルネットワークのように全結合層をかませる)します。これをTransformerのエンコーダ部分に入力し、最後にMLP Headと呼ばれるモデルに入力します。なおここでMLPとMLP Headが登場しますが、二つは似て非なるものなので注意して下さい。

Transformer Encoder

Transformerのエンコーダ部分はLayerNormalization(上図のNormに当たります)、MultiHeadAttention、MLPという3つから構成されます。LayerNormalizationは1つのサンプルにおける各レイヤーの隠れ層の値の平均・分散で正規化します。詳しい解説はこちらの記事をご覧ください。MLPは簡単に言えば全結合層を2つ繋げたものです。特に難しい点はないので割愛します。

Attentionとは

Attentionは「深層学習モデルに入力されたデータのどの部分に注目するのかを学習し、利用する仕組み」のことです。Seq2Seqなどの従来手法では入力全体を最終的に1つの固定長ベクトルに詰め込んで表現するため、入力が長くなると内容を伝えるのが難しくなるという問題がありました。それに対して、Attentionはデコーダーにおいて、入力系列の情報を直接参照できるようにすることで、入力が長くなっても適切に内容を伝えることができるようにしました。入力全体の内容に加えて、パッチを1つ1つ出力する際に毎回、対応する入力系列のパッチを逐次的に考慮しながら変換します。また、注目度も含めて深層学習の誤差逆伝播によって学習できます。

Attentionの利点(こちらを参考にしました)

  1. 高い性能が期待できる(現在の世界最高精度クラスのモデルの多くはAttentionを用いている)
  2. 高速に学習できる(RNN は時刻tの計算が終わるまで時刻t+1の計算をできず、GPU をフルに使えません。Transformer は推論時の Decoder を除いて、すべての時刻の計算を同時に行えるため GPU をフルに使いやすいです。)
  3. 構造が単純

Attentionの構造(こちらを参考にしました)

image.png
Attentionの基本はqueryとmemory(key, value) です。Attentionとはqueryを用いてmemoryから必要な情報を選択的に抽出する仕組みです。memoryから情報を抽出する際、queryはkeyによって取得するmemoryを決定し、対応するvalueを取得します。

image.png

例えば、食べ物というQueryに対して、「『いちご』が80%、『が』が5%、『好き』が15%くらい」という風に、Keyはどこにどれくらい注目するのかを決定します。
具体的な計算はQueryとKeyの行列積をとります。
行列積をとった後、softmax関数にかけることでAttention_weightが得られます。
Attention_weightはvalueから情報を取得する際に、どこにどれだけ注目するのかを示しています。
Attention_weightとvalueの行列積を計算することで、Inputのどこにどれだけ注目すればいいかという情報を持ったoutputを得ることができます。

MLP Headとは

LayerNormalizationと全結合層(Dense)を繋いだものです。こちらも難しい点はないので割愛します。

Vision Transformer(ViT)の実装

ここからは実際に実装を行なっていきます。まず今回実装したコード全体を示します。お忙しい方はコピペして使ってみて下さい。

VisionTransformer.py
from tensorflow.keras.callbacks import LearningRateScheduler
# 学習率のスケジュールを定義する
def step_decay(epoch):
    x = 0.001
    if epoch > 0:
        x /= 2
    if epoch > 5:
        x /= 2
    if epoch > 10:
        x /= 2
    if epoch > 15:
        x /= 2
    if epoch > 20:
        x /= 2
    return x
lr_decay = LearningRateScheduler(step_decay)

from tensorflow.keras.datasets import cifar10
from tensorflow.keras import utils
import tensorflow as tf
# cifar10をロードし、ラベルをone-hotベクトル化する
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = X_train.reshape(50000, 32, 32, 3).astype('float32')
X_test = X_test.reshape(10000, 32, 32, 3).astype('float32')
X_train /= 255
X_test /= 255

y_train = utils.to_categorical(y_train, 10)
y_test = utils.to_categorical(y_test, 10)

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input, Flatten,  MultiHeadAttention, LayerNormalization, BatchNormalization, Embedding,  Dropout
from tensorflow.keras.activations import gelu
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 各種変数を定義する
batch_size = 10  # バッチサイズ
epochs = 25  # エポック数
num_heads = 4  # Attention Headの数
layer_number = 4  # Transformer Encoderのレイヤー数
layer_width = 128  # Transformer Encoderのレイヤーの厚さ
patch_size = 48  # パッチサイズ
dim = 128  # 線型射影後のサイズ
steps_per_epoch = X_train.shape[0] // batch_size  # 1エポックあたりのステップ数
validation_steps = X_test.shape[0] // batch_size  # 1エポックあたりのステップ数(テスト)
def create_bench_model():
    source_target = Input(shape = (32, 32, 3))
    st_patch = tf.reshape(source_target,[batch_size, 3072//patch_size,patch_size])  # パッチに分割する
    st_vec = Dense(dim)(st_patch)  # 線型射影
    st_encoding = Embedding(input_dim=3072//patch_size,output_dim = dim)(tf.range(0,3072//patch_size))+st_vec  # 位置エンコーディング
    source_target_norm = tf.expand_dims(LayerNormalization()(st_encoding),0)
    layer = MultiHeadAttention(num_heads=num_heads, key_dim=3)  # MultiHeadAttention
    output_tensor, weights = layer(source_target_norm, source_target_norm, return_attention_scores=True)
    output_tensor = tf.expand_dims(st_encoding,0)+output_tensor  # 残差接続
    mlp_input = LayerNormalization()(output_tensor[0])
    # MLP
    mlp_hidden = Dense(layer_width)(mlp_input)
    mlp_hidden = gelu(mlp_hidden)
    mlp_hidden = Dropout(0.25)(BatchNormalization()(mlp_hidden))
    mlp_hidden = Dense(layer_width)(mlp_hidden)
    mlp_hidden = Dropout(0.25)(BatchNormalization()(mlp_hidden))
    mlp_hidden = mlp_hidden+Dense(layer_width)(output_tensor[0])
    
    for i in range(layer_number-1):
        mlp_hidden = tf.expand_dims(LayerNormalization()(mlp_hidden),0)
        layer2 = MultiHeadAttention(num_heads=num_heads, key_dim=2)
        output_tensor, weights = layer2(mlp_hidden, mlp_hidden, return_attention_scores=True)
        output_tensor = mlp_hidden+output_tensor
        mlp_input = LayerNormalization()(output_tensor[0])
        mlp_hidden = Dense(layer_width)(mlp_input)
        mlp_hidden = gelu(mlp_hidden)
        mlp_hidden = Dropout(0.25)(BatchNormalization()(mlp_hidden))
        mlp_hidden = Dense(layer_width)(mlp_hidden)
        mlp_hidden = Dropout(0.25)(BatchNormalization()(mlp_hidden))
        mlp_hidden = mlp_hidden+Dense(layer_width)(output_tensor[0])
    
    # MLP Head
    mlp_head_input = LayerNormalization()(mlp_hidden)
    mlp_head_input = Flatten()(mlp_head_input)
    mlp_head_output = Dense(10, activation = "softmax")(mlp_head_input)
    return Model(inputs=source_target, outputs=mlp_head_output)

model = create_bench_model()
model.summary()
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
model.compile(loss=loss_fn, optimizer=Adam(), metrics=["accuracy"])
val_gen = ImageDataGenerator().flow(X_test, y_test, batch_size=batch_size)
train_gen = ImageDataGenerator().flow(X_train, y_train, batch_size=batch_size)

history = model.fit(train_gen, epochs=epochs, validation_data=val_gen, 
                        steps_per_epoch=steps_per_epoch, validation_steps=validation_steps, callbacks=[lr_decay])

まず準備として学習率のスケジュールを設定します。

learningratescheduler.py
from tensorflow.keras.callbacks import LearningRateScheduler
# 学習率のスケジュールを定義する
def step_decay(epoch):
    x = 0.001
    if epoch > 0:
        x /= 2
    if epoch > 5:
        x /= 2
    if epoch > 10:
        x /= 2
    if epoch > 15:
        x /= 2
    if epoch > 20:
        x /= 2
    return x
lr_decay = LearningRateScheduler(step_decay)

次にcifar10をロードし、ラベルをone-hotベクトル化(一つの要素が1でそれ以外の要素が0であるようなベクトル)します。

load_cifar10.py
from tensorflow.keras.datasets import cifar10
from tensorflow.keras import utils
import tensorflow as tf
# cifar10をロードし、ラベルをone-hotベクトル化する
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = X_train.reshape(50000, 32, 32, 3).astype('float32')
X_test = X_test.reshape(10000, 32, 32, 3).astype('float32')
X_train /= 255
X_test /= 255

y_train = utils.to_categorical(y_train, 10)
y_test = utils.to_categorical(y_test, 10)

次に各種変数を定義します。

variable.py
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input, Flatten,  MultiHeadAttention, LayerNormalization, BatchNormalization, Embedding,  Dropout
from tensorflow.keras.activations import gelu
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 各種変数を定義する
batch_size = 10  # バッチサイズ
epochs = 25  # エポック数
num_heads = 4  # Attention Headの数
layer_number = 4  # Transformer Encoderのレイヤー数
layer_width = 128  # Transformer Encoderのレイヤーの厚さ
patch_size = 48  # パッチサイズ
dim = 128  # 線型射影後のサイズ
steps_per_epoch = X_train.shape[0] // batch_size  # 1エポックあたりのステップ数
validation_steps = X_test.shape[0] // batch_size  # 1エポックあたりのステップ数(テスト)

そしてモデルを定義していきます。まず画像をパッチに分解し、線型射影します。

patch.py
def create_bench_model():
    source_target = Input(shape = (32, 32, 3))
    st_patch = tf.reshape(source_target,[batch_size, 3072//patch_size,patch_size])  # パッチに分割する
    st_vec = Dense(dim)(st_patch)  # 線型射影

次に位置エンコーディングを行います。

encoding.py
st_encoding = Embedding(input_dim=3072//patch_size,output_dim = dim)(tf.range(0,3072//patch_size))+st_vec  # 位置エンコーディング

次にMaltiHeadAttentionに入力します。ここで次元の問題で次元拡張を行っています。また残差接続も行っています

MultiHeadAttention.py
source_target_norm = tf.expand_dims(LayerNormalization()(st_encoding),0)
    layer = MultiHeadAttention(num_heads=num_heads, key_dim=3)  # MultiHeadAttention
    output_tensor, weights = layer(source_target_norm, source_target_norm, return_attention_scores=True)
    output_tensor = tf.expand_dims(st_encoding,0)+output_tensor  # 残差接続

次にMLPに入力します。元の論文とは異なりBatchNormalizationrとDropoutを挿入しています。これによりより高い精度で回答することが可能になります。

MLP.py
mlp_input = LayerNormalization()(output_tensor[0])
    # MLP
    mlp_hidden = Dense(layer_width)(mlp_input)
    mlp_hidden = gelu(mlp_hidden)
    mlp_hidden = Dropout(0.25)(BatchNormalization()(mlp_hidden))
    mlp_hidden = Dense(layer_width)(mlp_hidden)
    mlp_hidden = Dropout(0.25)(BatchNormalization()(mlp_hidden))
    mlp_hidden = mlp_hidden+Dense(layer_width)(output_tensor[0])

以上の操作をレイヤー数繰り返します。

Layer_rep.py
for i in range(layer_number-1):
        mlp_hidden = tf.expand_dims(LayerNormalization()(mlp_hidden),0)
        layer2 = MultiHeadAttention(num_heads=num_heads, key_dim=2)
        output_tensor, weights = layer2(mlp_hidden, mlp_hidden, return_attention_scores=True)
        output_tensor = mlp_hidden+output_tensor
        mlp_input = LayerNormalization()(output_tensor[0])
        mlp_hidden = Dense(layer_width)(mlp_input)
        mlp_hidden = gelu(mlp_hidden)
        mlp_hidden = Dropout(0.25)(BatchNormalization()(mlp_hidden))
        mlp_hidden = Dense(layer_width)(mlp_hidden)
        mlp_hidden = Dropout(0.25)(BatchNormalization()(mlp_hidden))
        mlp_hidden = mlp_hidden+Dense(layer_width)(output_tensor[0])

最後にMLP Headを実装します。

MLP_Head.py
# MLP Head
    mlp_head_input = LayerNormalization()(mlp_hidden)
    mlp_head_input = Flatten()(mlp_head_input)
    mlp_head_output = Dense(10, activation = "softmax")(mlp_head_input)
    return Model(inputs=source_target, outputs=mlp_head_output)

モデルの学習を行います。

model_learn.py
model = create_bench_model()
model.summary()
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
model.compile(loss=loss_fn, optimizer=Adam(), metrics=["accuracy"])
val_gen = ImageDataGenerator().flow(X_test, y_test, batch_size=batch_size)
train_gen = ImageDataGenerator().flow(X_train, y_train, batch_size=batch_size)

history = model.fit(train_gen, epochs=epochs, validation_data=val_gen, 
                        steps_per_epoch=steps_per_epoch, validation_steps=validation_steps, callbacks=[lr_decay])

終わりに

ここまでVision Transformerを実装する方法について書いてきました。この記事が皆さんのお力になれば幸いです。
では、ばいにゃん〜。

参考

Vision Transformerの論文

omiitさんの記事

7
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?