3
3

More than 1 year has passed since last update.

世界最先端の画像処理モデルCoAtNETを自作してcifar10を解いてみた

Last updated at Posted at 2023-07-27

こんにちにゃんです。
水色桜(みずいろさくら)です。
今回は現在(2023年7月27日)、世界最先端の画像処理モデルであるCoAtNETを自作してみようと思います。
記事中で何か不明点・間違いなどありましたら、コメントまたはTwitterまでお寄せいただけると嬉しいです(≧▽≦)

はじめに

まず今回作成したモデルの精度を示します。

Epoch 1/25
loss: 2.0099 - accuracy: 0.2414 - val_loss: 1.7703 - val_accuracy: 0.3466
Epoch 2/25
loss: 1.6918 - accuracy: 0.3677 - val_loss: 1.5752 - val_accuracy: 0.4208
Epoch 3/25
loss: 1.5531 - accuracy: 0.4261 - val_loss: 1.4964 - val_accuracy: 0.4468
Epoch 4/25
loss: 1.4712 - accuracy: 0.4577 - val_loss: 1.4931 - val_accuracy: 0.4606
Epoch 5/25
loss: 1.4146 - accuracy: 0.4808 - val_loss: 1.3975 - val_accuracy: 0.4918
Epoch 6/25
loss: 1.3647 - accuracy: 0.5006 - val_loss: 1.3852 - val_accuracy: 0.4948
Epoch 7/25
loss: 1.3264 - accuracy: 0.5162 - val_loss: 1.3907 - val_accuracy: 0.4892
Epoch 8/25
loss: 1.2816 - accuracy: 0.5337 - val_loss: 1.3266 - val_accuracy: 0.5152
Epoch 9/25
loss: 1.2508 - accuracy: 0.5464 - val_loss: 1.2732 - val_accuracy: 0.5422
Epoch 10/25
loss: 1.2196 - accuracy: 0.5582 - val_loss: 1.2856 - val_accuracy: 0.5399
Epoch 11/25
loss: 1.1904 - accuracy: 0.5707 - val_loss: 1.2138 - val_accuracy: 0.5617
Epoch 12/25
loss: 1.1599 - accuracy: 0.5820 - val_loss: 1.2302 - val_accuracy: 0.5568
Epoch 13/25
loss: 1.1387 - accuracy: 0.5916 - val_loss: 1.2757 - val_accuracy: 0.5485
Epoch 14/25
loss: 1.1146 - accuracy: 0.5981 - val_loss: 1.1763 - val_accuracy: 0.5812
Epoch 15/25
loss: 1.0946 - accuracy: 0.6094 - val_loss: 1.1786 - val_accuracy: 0.5816
Epoch 16/25
loss: 1.0697 - accuracy: 0.6174 - val_loss: 1.1557 - val_accuracy: 0.5900
Epoch 17/25
loss: 1.0533 - accuracy: 0.6240 - val_loss: 1.1711 - val_accuracy: 0.5844
Epoch 18/25
loss: 1.0328 - accuracy: 0.6332 - val_loss: 1.1556 - val_accuracy: 0.5946
Epoch 19/25
loss: 1.0139 - accuracy: 0.6414 - val_loss: 1.1116 - val_accuracy: 0.6072
Epoch 20/25
loss: 0.9952 - accuracy: 0.6471 - val_loss: 1.1222 - val_accuracy: 0.6067
Epoch 21/25
loss: 0.9785 - accuracy: 0.6527 - val_loss: 1.0959 - val_accuracy: 0.6117
Epoch 22/25
loss: 0.9639 - accuracy: 0.6590 - val_loss: 1.1171 - val_accuracy: 0.6071
Epoch 23/25
loss: 0.9496 - accuracy: 0.6635 - val_loss: 1.1461 - val_accuracy: 0.5996
Epoch 24/25
loss: 0.9336 - accuracy: 0.6676 - val_loss: 1.1017 - val_accuracy: 0.6181
Epoch 25/25
loss: 0.9203 - accuracy: 0.6748 - val_loss: 1.0789 - val_accuracy: 0.6257

25回の学習で62%の正解率となっています。Vision Transformer(ViT)と同様に本来の能力を発揮させるためにはImagenetなどで事前学習を行う必要があるため、事前学習なしではこのくらいの精度しか出ません。
今回作成したモデルのサイズは以下のような感じです。

Model: "model"
_______________________________________________________________________
Total params: 1462538 (5.58 MB)
Trainable params: 1458442 (5.56 MB)
Non-trainable params: 4096 (16.00 KB)

では早速CoAtNETについて解説していきます。

CoAtNET

CoAtNETは畳み込みとtransformer(self-attention)を組み合わせ、いいとこ取りをしたモデルです。"CoAtNet: Marrying Convolution and Attention for All Data Sizes", Dai, Z., et al., (2021)という論文で発表されました。発表当時、ImageNetでSoTA(Top-1精度:90.88%)を達成しました。CoAtNETの構造は下の図に示す通りです。

image.png

image.png

CoAtNETの論文より引用

CoAtNETの主な特徴は以下の2つです。

  1. 前段に畳み込み層、後段にattention層を配置することで、局所的な特徴と、大域的な特徴を把握できるようにした点。
  2. DepthWiseConvolution(フィルターが1枚の畳み込み演算)を用いたMBConv-Blockを用いることで計算量を削減している点。

Attentionとは

Attentionは「深層学習モデルに入力されたデータのどの部分に注目するのかを学習し、利用する仕組み」のことです。Seq2Seqなどの従来手法では入力全体を最終的に1つの固定長ベクトルに詰め込んで表現するため、入力が長くなると内容を伝えるのが難しくなるという問題がありました。それに対して、Attentionはデコーダーにおいて、入力系列の情報を直接参照できるようにすることで、入力が長くなっても適切に内容を伝えることができるようにしました。入力全体の内容に加えて、パッチを1つ1つ出力する際に毎回、対応する入力系列のパッチを逐次的に考慮しながら変換します。また、注目度も含めて深層学習の誤差逆伝播によって学習できます。

Attentionの利点(こちらを参考にしました)

  1. 高い性能が期待できる(現在の世界最高精度クラスのモデルの多くはAttentionを用いている)
  2. 高速に学習できる(RNN は時刻tの計算が終わるまで時刻t+1の計算をできず、GPU をフルに使えません。Transformer は推論時の Decoder を除いて、すべての時刻の計算を同時に行えるため GPU をフルに使いやすいです。)
  3. 構造が単純

Attentionの構造(こちらを参考にしました)

image.png
Attentionの基本はqueryとmemory(key, value) です。Attentionとはqueryを用いてmemoryから必要な情報を選択的に抽出する仕組みです。memoryから情報を抽出する際、queryはkeyによって取得するmemoryを決定し、対応するvalueを取得します。

image.png

例えば、食べ物というQueryに対して、「『いちご』が80%、『が』が5%、『好き』が15%くらい」という風に、Keyはどこにどれくらい注目するのかを決定します。
具体的な計算はQueryとKeyの行列積をとります。
行列積をとった後、softmax関数にかけることでAttention_weightが得られます。
Attention_weightはvalueから情報を取得する際に、どこにどれだけ注目するのかを示しています。
Attention_weightとvalueの行列積を計算することで、Inputのどこにどれだけ注目すればいいかという情報を持ったoutputを得ることができます。

CoAtNETの実装

ではさっそくCoAtNETを実装していきたいと思います。
まず今回作成したコード全体を示します。

CoAtNET.py
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras import utils

# cifar10をロードし、ラベルをone-hotベクトル化する
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = X_train.reshape(50000, 32, 32, 3).astype('float32')
X_test = X_test.reshape(10000, 32, 32, 3).astype('float32')
X_train /= 255
X_test /= 255

y_train = utils.to_categorical(y_train, 10)
y_test = utils.to_categorical(y_test, 10)

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input, Flatten, MaxPooling2D, Conv2D, LayerNormalization, Dropout, DepthwiseConv2D, BatchNormalization
from tensorflow.keras.activations import gelu
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from official.nlp.modeling.layers import MultiHeadRelativeAttention, RelativePositionEmbedding

# 各種変数を定義する
batch_size = 10  # バッチサイズ
epochs = 25  # エポック数
num_heads = 8  # Attention Headの数
At_layer_number = 8  # Rel Attention層のレイヤー数
layer_width = 128  # Rel Attention層のレイヤーの厚さ
dim = 2048  # 相対位置エンコーディング後のサイズ
conv_layer = 16  # 畳み込み層のレイヤー数
steps_per_epoch = X_train.shape[0] // batch_size  # 1エポックあたりのステップ数
validation_steps = X_test.shape[0] // batch_size  # 1エポックあたりのステップ数(テスト)
def create_bench_model():
    inputs = Input(shape = (32, 32, 3))
    conv_hidden = Conv2D(filters=32, kernel_size=3, activation='gelu')(inputs)
    conv_hidden = Conv2D(filters=32, kernel_size=3, activation='gelu')(conv_hidden)
    conv_hidden = MaxPooling2D()(conv_hidden)
    conv_hidden = LayerNormalization()(conv_hidden)
    for i in range(conv_layer):
        conv_hidden_DW = Conv2D(filters=32, kernel_size=1,activation = 'gelu',padding='same')(conv_hidden)
        conv_hidden_DW = DepthwiseConv2D(kernel_size=3, strides=1, padding='same', activation = 'gelu')(conv_hidden_DW)
        conv_hidden_DW = Conv2D(filters=32, kernel_size=1, activation = 'gelu',padding='same')(conv_hidden_DW)
        conv_hidden_DW = LayerNormalization()(conv_hidden_DW)
        conv_hidden = conv_hidden_DW+conv_hidden
    conv_hidden = MaxPooling2D()(conv_hidden)
    for i in range(At_layer_number):
        conv_hidden = tf.reshape(conv_hidden, [batch_size, conv_hidden.shape[1]*conv_hidden.shape[2]*conv_hidden.shape[3]])
        relative_position_encoding=RelativePositionEmbedding(hidden_size=dim,min_timescale=1.0,max_timescale=10000)(conv_hidden)
        conv_hidden = tf.expand_dims(conv_hidden,0)
        layer = MultiHeadRelativeAttention(num_heads=num_heads, key_dim=3)  # MultiHeadAttention
        output_tensor = layer(conv_hidden, conv_hidden, content_attention_bias = 0.1, positional_attention_bias = 0.1, relative_position_encoding=tf.expand_dims(relative_position_encoding,0))
        output_tensor = conv_hidden+output_tensor  # 残差接続
        mlp_input = LayerNormalization()(output_tensor)
        # MLP
        mlp_hidden = Dense(layer_width)(mlp_input)
        mlp_hidden = gelu(mlp_hidden)
        mlp_hidden = Dropout(0.25)(BatchNormalization()(mlp_hidden))
        mlp_hidden = Dense(layer_width)(mlp_hidden)
        mlp_hidden = Dropout(0.25)(BatchNormalization()(mlp_hidden))
        mlp_hidden = mlp_hidden+Dense(layer_width)(output_tensor[0])
        conv_hidden = tf.squeeze(mlp_hidden)
        conv_hidden = tf.expand_dims(tf.expand_dims(conv_hidden,-1),-1)
    
    # MLP Head
    conv_hidden = tf.squeeze(conv_hidden)
    mlp_head_input = LayerNormalization()(conv_hidden)
    mlp_head_input = Flatten()(mlp_head_input)
    mlp_head_output = Dense(10, activation = "softmax")(mlp_head_input)
    
    return Model(inputs, mlp_head_output)

model = create_bench_model()
model.summary()
loss_fn = tf.keras.losses.CategoricalCrossentropy()
model.compile(loss=loss_fn, optimizer=Adam(lr=0.00001), metrics=["accuracy"])
val_gen = ImageDataGenerator().flow(X_test, y_test, batch_size=batch_size)
train_gen = ImageDataGenerator().flow(X_train, y_train, batch_size=batch_size)

history = model.fit(train_gen, epochs=epochs, validation_data=val_gen, 
                        steps_per_epoch=steps_per_epoch, validation_steps=validation_steps)

ここからはコードの解説をしていきます。
まずcifar10のロードと、各種変数の定義を行います。

preparation.py
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras import utils

# cifar10をロードし、ラベルをone-hotベクトル化する
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = X_train.reshape(50000, 32, 32, 3).astype('float32')
X_test = X_test.reshape(10000, 32, 32, 3).astype('float32')
X_train /= 255
X_test /= 255

y_train = utils.to_categorical(y_train, 10)
y_test = utils.to_categorical(y_test, 10)

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input, Flatten, MaxPooling2D, Conv2D,  LayerNormalization, Dropout, DepthwiseConv2D, BatchNormalization
from tensorflow.keras.activations import gelu
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from official.nlp.modeling.layers import MultiHeadRelativeAttention, RelativePositionEmbedding

# 各種変数を定義する
batch_size = 10  # バッチサイズ
epochs = 25  # エポック数
num_heads = 8  # Attention Headの数
At_layer_number = 8  # Rel Attention層のレイヤー数
layer_width = 128  # Rel Attention層のレイヤーの厚さ
dim = 2048  # 相対位置エンコーディング後のサイズ
conv_layer = 16  # 畳み込み層のレイヤー数
steps_per_epoch = X_train.shape[0] // batch_size  # 1エポックあたりのステップ数
validation_steps = X_test.shape[0] // batch_size  # 1エポックあたりのステップ数(テスト)

つぎにモデルを作成していきます。

image.png

まずS0:Stem stageを実装します。

Stem_stage.py
def create_bench_model():
    inputs = Input(shape = (32, 32, 3))
    conv_hidden = Conv2D(filters=32, kernel_size=3, activation='gelu')(inputs)
    conv_hidden = Conv2D(filters=32, kernel_size=3, activation='gelu')(conv_hidden)
    conv_hidden = MaxPooling2D()(conv_hidden)
    conv_hidden = LayerNormalization()(conv_hidden)

つぎにMBConv-BlockであるS1とS2を実装します。

MBConv.py
for i in range(conv_layer):
        conv_hidden_DW = Conv2D(filters=32, kernel_size=1,activation = 'gelu',padding='same')(conv_hidden)
        conv_hidden_DW = DepthwiseConv2D(kernel_size=3, strides=1, padding='same', activation = 'gelu')(conv_hidden_DW)
        conv_hidden_DW = Conv2D(filters=32, kernel_size=1, activation = 'gelu',padding='same')(conv_hidden_DW)
        conv_hidden_DW = LayerNormalization()(conv_hidden_DW)
        conv_hidden = conv_hidden_DW+conv_hidden
    conv_hidden = MaxPooling2D()(conv_hidden)

つぎに、MultiHeadRelativeAttentionを用いて、S3、S4層を実装します。

attention.py
 for i in range(At_layer_number):
        conv_hidden = tf.reshape(conv_hidden, [batch_size, conv_hidden.shape[1]*conv_hidden.shape[2]*conv_hidden.shape[3]])
        relative_position_encoding=RelativePositionEmbedding(hidden_size=dim,min_timescale=1.0,max_timescale=10000)(conv_hidden)
        conv_hidden = tf.expand_dims(conv_hidden,0)
        layer = MultiHeadRelativeAttention(num_heads=num_heads, key_dim=3)  # MultiHeadAttention
        output_tensor = layer(conv_hidden, conv_hidden, content_attention_bias = 0.1, positional_attention_bias = 0.1, relative_position_encoding=tf.expand_dims(relative_position_encoding,0))
        output_tensor = conv_hidden+output_tensor  # 残差接続
        mlp_input = LayerNormalization()(output_tensor)
        # MLP
        mlp_hidden = Dense(layer_width)(mlp_input)
        mlp_hidden = gelu(mlp_hidden)
        mlp_hidden = Dropout(0.25)(BatchNormalization()(mlp_hidden))
        mlp_hidden = Dense(layer_width)(mlp_hidden)
        mlp_hidden = Dropout(0.25)(BatchNormalization()(mlp_hidden))
        mlp_hidden = mlp_hidden+Dense(layer_width)(output_tensor[0])
        conv_hidden = tf.squeeze(mlp_hidden)
        conv_hidden = tf.expand_dims(tf.expand_dims(conv_hidden,-1),-1)

最後にFC層を実装します。

FC.py
# MLP Head
    conv_hidden = tf.squeeze(conv_hidden)
    mlp_head_input = LayerNormalization()(conv_hidden)
    mlp_head_input = Flatten()(mlp_head_input)
    mlp_head_output = Dense(10, activation = "softmax")(mlp_head_input)
    
    return Model(inputs, mlp_head_output)

終わりに

今回はCoAtNETを自作して、cifar10を解いてみました。
ぜひ参考にしてもらえると嬉しいです。
では、ばいにゃん~。

参考

Omiitaさんの記事

CoAtNETの論文

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3