14
14

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

論文の勉強17 「Vision Transformer(ViT)」

Last updated at Posted at 2023-03-25

背景などは説明はしません。
以下の論文を読んでいきます。
途中GoogleColabのユニット(?)を使い切ってしまい。動くことは確認していますが。
一部出力結果がありません。

自然言語処理で事実上の標準技術となっているtransformerを画像に適用したものです。
大規模データセットで学習させることでCNNベースの精度を超えることが可能となりました。
実装自体は通常のtransformerと変わらないので、比較的簡単ではないかと思います。
そのためか実装したという記事も多数あります。こちらは解説というより勉強のメモ書きなので、学習される方は他の記事を参考にしてください。

Method

Vision Transformer(ViT)

概要を図に示す。

image.png

標準的なTransformerはトークンembeddingの1次元の系列を入力として受け取ります。
2次元の画像を扱うために、画像を$\boldsymbol{x}\in R^{H×W×C}$から2次元のパッチの系列$\boldsymbol{x}_p\in R^{N×(P^2・C)}$に変形します。ここで、$(H,W)$は元の画像の解像度、$C$はチャネル数、$(P,P)$は各パッチの解像度です。そして$N=HW/P^2$はパッチの数であり、Transformerへの入力の系列長となります。Transformerは」$D$次元の潜在ベクトルを使用するため、パッチを平坦化して学習可能な線形返還により$D$次元に変換します。この出力をパッチembeddingと呼びます。
BERTのclassトークンのように、embeddingされたパッチの系列に学習可能なembeddingを追加します($\boldsymbol{z}_0^0=\boldsymbol{x}_c$)。$\boldsymbol{z}_0^0=\boldsymbol{x}_c$のTransformerからの出力($\boldsymbol{z}_0^0=\boldsymbol{z}_L^0$)は画像の表現$\boldsymbol{y}$となります。事前学習とファインチューニングの両方で、$\boldsymbol{z}_L^0$に分類ヘッドが付けられます。分類ヘッドは事前学習の時は1つの隠れ層を持つMLPであり、ファインチューニングのときは単一の線形変換の層となります。
位置埋め込みは、位置情報を保持するためにパッチ埋め込みに追加されます。ここでは、学習可能な 1D 位置埋め込みを使用します。
Transformer encoderはmultihead self-attention(MSA)ブロックとMLPブロックの交互レイヤで構成されます。Layernorm(LM)が各ブロックの前に、残差結合が各ブロックの後に適用されます。MLPは2層構造で活性化関数としてGELUを使用します。
式で表すと次のようになります。

\boldsymbol{z}_0=[\boldsymbol{x}_{c};\boldsymbol{x}_p^1\boldsymbol{E};\boldsymbol{x}_p^2\boldsymbol{E}; \cdots;\boldsymbol{x}_p^N\boldsymbol{E}]+\boldsymbol{E}_{pos},\ \boldsymbol{E}\in\mathbb{R}^{(P^2\cdot C)D,\ \boldsymbol{E}_{pos}\in\mathbb{R}^{(N+1)D}}\\
\boldsymbol{z}'_l=MSA(LN(\boldsymbol{z}_{l-1}))+\boldsymbol{z}_{l-1},\ l=1,\cdots,L\\
\boldsymbol{z}_l=MLP(LN(\boldsymbol{z}'_l))+\boldsymbol{z}'_l\\
\boldsymbol{y}=LN(\boldsymbol{z}_L^0)
Hybrid Architecture

生のパッチ画像の代わりに、CNNの特徴マップから入力系列を構成することができます。hybridモデルでは、パッチembeddingの射影$\boldsymbol{E}$はCNNの特徴マップから抽出されたパッチに適用されます。特別な場合として、パッチは1×1のサイズをとることができます。これは、特徴マップが平坦化されTarnsformerの次元に射影されることを意味します。

FINE-TUNING AND HIGHER RESOLUTION

通常、大規模なデータセットでViTを事前トレーニングし、より小さなタスクに合わせてファインチューニングします。
このために、事前トレーニング済みの予測ヘッドを削除し、ゼロで初期化された$D ×K$のフィードフォワード層を追加します。ここで、$K$はクラス数です。事前学習より高解像度にすることで有益になることがあります。高解像度の画像を扱う場合、パッチサイズを同じにするためには系列長が長くなります。ViTは任意の系列長を処理できますが、事前に学習されたpositional embeddingは意味をなさない可能性があります。そのため、基の画像の位置に合わせるため、事前に学習されたporsitional embeddingを2次元補完します。

Model

image.png

ViTの構造はBERTの構造をもとにしています。
BaseとLargeはBERTから直接採用され、ここではさらに大きなHugeを追加しています。
モデルを表す表記として、たとえばViT-L/16は、16×16のパッチサイズのLargeモデルを表します。
系列長はパッチサイズの2乗に反比例するため、パッチサイズが小さいモデルは計算コストが高くなります。

Training

最適化の手法としてはAdamを使用して、パラメータは$\beta_10.9,\beta_2=0.999$、weight decayは0.1とします。

実装(Keras)

データはKaggleのDogvsCatの一部を使用します。
kerasでの実装を行います。
MultiHeadAttentionなどは実装されたものが提供されていますがここでは使用しません。

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Layer, Input, Dense, Conv1D, Conv2D, Activation, Dropout, LayerNormalization, Reshape, RepeatVector, Concatenate
from tensorflow.keras import activations
from tensorflow.keras.optimizers import RMSprop, Adagrad, Adam
from tensorflow.keras.callbacks import LearningRateScheduler, ReduceLROnPlateau, ModelCheckpoint
from keras.preprocessing.image import ImageDataGenerator
 
from keras import backend as K

import numpy as np
import pandas as pd
import math
import time

import matplotlib.pyplot as plt
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import cv2
class patch_embbeding(Layer):
    """
    patch embeddingレイヤ
    
    画像をパッチに分割し線形変換を行いtransformerへの入力とする
    [cls]トークンをパッチの系列の先頭に追加し、これがencodingさたものを全結合層の入力とする
    学習可能な重みでpositional encoding
    """
    def __init__(self, img_size, patch_size=4, hidden_dim=8):
        """
        img_size : 画像のサイズ(height, width)
        patch_size : パッチサイズ
        hidden_dim : embedding後の次元
        """
        super().__init__()
        self.D = hidden_dim
        
        # 畳み込み層のkernelとstrideをpatch_sizeとすることで分割と線形変換を同時に行う
        self.patch_conv = Conv2D(filters = hidden_dim, kernel_size = (patch_size,patch_size), strides=patch_size, padding = 'same')
        
        # [class](クラストークン)追加
        # callの中でbatch_sizeに拡張
        self.cls_token = self.add_weight(
            shape=(1,1,hidden_dim), initializer="random_normal", trainable=True, name='class', dtype=tf.float32
        )

        # position encoding
        # クラストークンの分も入れたshapeを指定
        N = int(img_size[0]*img_size[1]/(patch_size*patch_size))+1 # パッチの数+1(クラストークン)
        self.position = self.add_weight(
            shape=(N,self.D), initializer="random_normal", trainable=True, name='position', dtype=tf.float32
        )
        
    def call(self, inputs):
        """
        inputs: [batch_size,height,width,channel]
        """
        # batch_sizeの取得
        batch_size = tf.shape(inputs)[0]
        
        # パッチへの分割→線形変換
        # 畳み込み層のkernelとstrideをpatch_sizeとすることで同時に行う
        # [batch_size,height,width,channel]→[batch_size,height/patch_size,width/patch_size,hidden_dim]
        out = self.patch_conv(inputs)
        
        # [batch_size,height/patch_size,width/patch_size,hidden_dim]→[batch_size,height*width/patch_size^2,hidden_dim]
        out = Reshape((-1,self.D))(out)
        
        # クラストークンの拡張
        # [1,1,hidden_dim]→[batch_size,1,hidden_dim]
        cls_token = tf.broadcast_to(self.cls_token, [batch_size, 1, self.D])
        
        # クラストークンをパッチ系列へ追加
        # [batch_size,height*width/patch_size^2,hidden_dim] + [batch_size,1,hidden_dim]
        #      → [batch_size,height*width/patch_size^2+1,hidden_dim]
        out = Concatenate(axis=1)([cls_token, out])
        
        # positional encoding
        out = out+self.position 
        return out

class SelfMultiHeadAttention(Layer):
    '''
    Multi-Head Attentionレイヤ

    model = MultiheadAttention(
        hidden_dim = 512,
        head_num = 8,
        drop_rate = 0.5
    )
    '''
    def __init__(self, hidden_dim, heads_num, drop_rate=0.5):
        '''
        Multi-Head Attentionレイヤ
    
        hidden_dim : Embeddingされた単語ベクトルの長さ
        heads_num : マルチヘッドAttentionのヘッド数
           ※hidden_numはheads_numで割り切れえる値とすること
        drop_rate : 出力のDropout率
        '''

        super(SelfMultiHeadAttention, self).__init__()
        # 入力の線形変換
        # 重み行列は[hidden_dim, hidden_dim]
        self.query = Conv1D(hidden_dim, kernel_size=1)
        self.key   = Conv1D(hidden_dim, kernel_size=1)
        self.value = Conv1D(hidden_dim, kernel_size=1)
        
        # 出力の線形変換
        self.projection = Conv1D(hidden_dim, kernel_size=1)
        
        # 出力のDropout
        self.drop = Dropout(drop_rate)
        
        self.nf = hidden_dim
        self.nh = heads_num
    
    def atten(self, query, key, value, training):
        """
        Attention
        
        query, key, value : クエリ、キー、バリュー
            query [batch_size, head_num, q_length, hidden_dim//head_num]
            key, value [batch_size, head_num, m_length, hidden_dim//head_num]
            ただし、encoder:q_length=m_length
        """
        # 各値を取得
        shape = query.shape.as_list() # batch_size, head_num, q_length, hidden_dim//head_num
        batch_size = -1 if shape[0] is None else shape[0]
        token_num = shape[2] # トークン列数(q_length)
        hidden_dim = shape[1]*shape[3] # 特徴ベクトルの長さ(head_num × hidden_dim//head_num = hidden_dim)
        
        # ここで q と k の内積を取ることで、query と key の単語間の関連度のようなものを計算します。
        # tf.matmulで最後の2成分について積を計算(それ以外は形がそろっている必要あり)
        # transpose_bで転置
        # [batch_size, head_num, q_length, hidden_dim/head_num] @ [batch_size, head_num, hidden_dim/head_num, m_length] = [batch_size, head_num, q_length, m_length]
        scores = tf.matmul(query, key, transpose_b=True)
        
        # scoreをhidden_dimの平方根割る
        scores = tf.multiply(scores, tf.math.rsqrt(tf.cast(hidden_dim, tf.float32)))
        
        # softmax を取ることで正規化します
        # input(query) の各単語に対して memory(key) の各単語のどこから情報を引いてくるかの重み
        atten_weight = tf.nn.softmax(scores, axis=-1)
        
        # 重みに従って value から情報を引いてきます
        # [batch_size, head_num, q_length, m_length] @ [batch_size, head_num, m_length, hidden_dim/head_num] = [batch_size, head_num, q_length, hidden_dim/head_num]
        # input(query) の単語ごとに memory(value)の各単語 に attention_weight を掛け合わせて足し合わせた ベクトル(分散表現の重み付き和)を計算
        context = tf.matmul(atten_weight, value)
        
        # 各ヘッドの結合(reshape)
        # 入力と同じ形に変換する
        # [batch_size, head_num, q_length, hidden_dim/head_num] -> [batch_size, q_length, head_num, hidden_dim/head_num]
        context = tf.transpose(context, [0, 2, 1, 3])
        # [batch_size, q_length, head_num, hidden_dim/head_num] -> [batch_size, q_length, hidden_dim]
        context = tf.reshape(context, (batch_size, token_num, hidden_dim))
        
        # 線形変換
        context = self.projection(context)
        
        return self.drop(context, training=training), atten_weight

    def _split(self, x):
        """
        query, key, valueを分割する
        
        入力 shape: [batch_size, length, hidden_dim]
        出力 shape: [batch_size, head_num, length, hidden_dim//head_num]
        """
        # 各値を取得
        hidden_dim = self.nf
        heads_num = self.nh
        shape = x.shape.as_list()
        batch_size = -1 if shape[0] is None else shape[0]
        token_num = shape[1] # トークン列数
        
        # [batch_size, (q|m)_length, hidden_dim] -> [batch_size, (q|m)_length, head_num, hidden_dim/head_num]
        # splitだが実際は次元を拡張する処理
        x = tf.reshape(x, (batch_size, token_num, heads_num, int(hidden_dim/heads_num)))
        
        # [batch_size, (q|m)_length, head_num, hidden_dim/head_num] -> [batch_size, head_num, (q|m)_length, hidden_dim/head_num]
        x = tf.transpose(x, [0, 2, 1, 3])
        return x
    
    def call(self, x, training, memory=None, return_attention_scores=False):
        """
        モデルの実行
        
        input : 入力(query) [batch_size, length, hidden_dim]
        memory : 入力(key, value) [batch_size, length, hidden_dim]
         ※memory(key, value)についてはqueryのtoken_numと異なる場合がある
        return_attention_scores : attention weightを出力するか
        """
        # memoryが入力されない場合、memory=input(Self Attention)とする
        if memory is None:
            memory = x

        # input -> query
        # memory -> key, value
        # [batch_size, (q|m)_length, hidden_dim] @ [hidden_dim, hidden_dim] -> [batch_size, (q|m)_length, hidden_dim] 
        query = self.query(x)
        key = self.key(memory)
        value = self.value(memory)
        
        # ヘッド数に分割する
        # 実際はreshapeで次数を1つ増やす
        # [batch_size, (q|m)_length, hidden_dim] -> [batch_size, head_num, (q|m)_length, hidden_dim/head_num]
        query = self._split(query)
        key = self._split(key)
        value = self._split(value)
        
        # attention
        # 入力と同じ形の出力
        # context: [batch_size, q_length, hidden_dim]
        context, attn_weights = self.atten(query, key, value, training)
        if not return_attention_scores:
            return context
        else:
            return context, attn_weights

class FeedForwardNetwork(Layer):
    '''
    Position-wise Feedforward Neural Network
    transformer blockで使用される全結合層
    '''
    def __init__(self, hidden_dim, drop_rate):
        '''
        hidden_dim : Embeddingされた単語ベクトルの長さ
        drop_rate : 出力のDropout率
        '''
        super().__init__()
        # 2層構造
        # 1層目:チャンネル数を増加させる
        self.filter_dense_layer = Dense(hidden_dim * 4, use_bias=True, activation='gelu')
        
        # 2層目:元のチャンネル数に戻す
        self.output_dense_layer = Dense(hidden_dim, use_bias=True)
        self.drop = Dropout(drop_rate)

    def call(self, x, training):
        '''
        入力と出力で形が変わらない
        x : 入力 [batch_size, length, hidden_dim]
        '''
        
        # [batch_size, (q|m)_length, hidden_dim] -> [batch_size, (q|m)_length, 4*hidden_dim]
        x = self.filter_dense_layer(x)
        x = self.drop(x, training=training)
        
        # [batch_size, (q|m)_length, 4*hidden_dim] -> [batch_size, (q|m)_length, hidden_dim]
        return self.output_dense_layer(x)

class ResidualNormalizationWrapper(Layer):
    '''
    残差接続
    output: input + SubLayer(input)
    '''
    def __init__(self, layer, drop_rate):
        '''
        layer : 残渣接続したいレイヤ(MultiHeadAttentionかFeedForwardNetwork)に適用
        drop_rate : 出力のDropout率
        '''
        super().__init__()
        self.layer = layer # SubLayer : ここではAttentionかFFN
        self.layer_normalization = LayerNormalization()
        self.drop = Dropout(drop_rate)

    def call(self, x, training, memory=None, return_attention_scores=None):
        """
        モデルの実行
        
        memory : 入力(key, value) [batch_size, length, hidden_dim]
         ※memory(key, value)についてはqueryのlengthと異なる場合がある
        return_attention_scores : attention weightを出力するか

        AttentionもFFNも入力と出力で形が変わらない
        output : [batch_size, length, hidden_dim]
        """
        
        params = {}
        if memory is not None:
            params['memory'] = memory
        if return_attention_scores:
            params['return_attention_scores'] = return_attention_scores
        
        out = self.layer_normalization(x)
        if return_attention_scores:
            # attention weightを返す
            out, attn_weights = self.layer(out, training, **params)
            out = self.drop(out, training=training)
            return x + out, attn_weights
        else:
            # attention weightを返さない
            out = self.layer(out, training, **params)
            out = self.drop(out, training=training)
            return x + out

class EncoderLayer(Layer):
    """
    Encoderレイヤ
     MultiHeadAttentionとFeedForwardNetworkの組み合わせ
      それぞれ残差接続されている
    """
    def __init__(self, hidden_dim, heads_num, drop_rate=0.2):
        """
        hidden_dim : Embeddingされた単語ベクトルの長さ
        heads_num : Multi-head Attentionのヘッド数
           ※hidden_numはheads_numで割り切れえる値とすること
        drop_rate : 出力のDropout率
        """
        super().__init__()
        # Multi-head attention
        self.atten = ResidualNormalizationWrapper(
            layer = SelfMultiHeadAttention(hidden_dim = hidden_dim,
                                           heads_num = heads_num,
                                           drop_rate = drop_rate),
            drop_rate = drop_rate)
        
        # Feed Forward Network
        self.ffn = ResidualNormalizationWrapper(
            layer = FeedForwardNetwork(hidden_dim = hidden_dim,
                                       drop_rate = drop_rate),
            drop_rate = drop_rate)
    
    def call(self, input, training, memory=None, return_attention_scores=False):
        """
        x : 入力(query) [batch_size, length, hidden_dim]
        memory : 入力(key, value) [batch_size, length, hidden_dim]
         ※memory(key, value)についてはqueryのtoken_numと異なる場合がある
        return_attention_scores : attention weightを出力するか

        AttentionもFFNも入力と出力で形が変わらない
        output : [batch_size, length, hidden_dim]
        
        入力と出力で形式が変わらない
        output : [batch_size, length, hidden_dim]
        """
        if return_attention_scores:
            x, attn_weights = self.atten(input,training, memory, return_attention_scores)
            x = self.ffn(x)
            return x, attn_weights
        else:
            x = self.atten(input, training, memory, return_attention_scores)
            x = self.ffn(x)
            return x

class Encoder(Layer):
    '''
    TransformerのEncoder
    '''
    def __init__(self, img_size, patch_size, hopping_num, heads_num, hidden_dim, drop_rate):
        '''
        img_size : 画像のサイズ
        patch_size : 画像を分割するサイズ
        hopping_num : Multi-head Attentionの繰り返し数
        hidden_dim : Embeddingされた特徴ベクトルの長さ
        heads_num : Multi-head Attentionのヘッド数
           ※hidden_numはheads_numで割り切れえる値とすること
        drop_rate : 出力のDropout率
        '''
        super().__init__()
        self.hopping_num = hopping_num
        
        # patch Embedding
        self.patch_embbeding = patch_embbeding(img_size=img_size, patch_size=patch_size, hidden_dim=hidden_dim)
        self.input_dropout_layer = Dropout(drop_rate)

        # Multi-head Attentionの繰り返し(hopping)のリスト
        self.attention_block_list = [EncoderLayer(hidden_dim, heads_num) for _ in range(hopping_num)]
        self.output_normalization = LayerNormalization()

    def call(self, input, training, return_attention_scores=False):
        '''
        input: 入力 [batch_size,height,width,channel]
        return_attention_scores : attention weightを出力するか
        出力 [batch_size, q_length, hidden_dim]
        '''
        # patch Embedding
        # [batch_size,height,width,channel] → [batch_size, q_length, hidden_dim]
        embedded_input = self.patch_embbeding(input)
        query = self.input_dropout_layer(embedded_input, training=training)

        # Encoderレイヤを繰り返し適用
        if return_attention_scores:
            for i in range(self.hopping_num):
                query, atten_weights = self.attention_block_list[i](query, training, query, return_attention_scores)

            # [batch_size, q_length, hidden_dim]
            return self.output_normalization(query), atten_weights
        else:
            for i in range(self.hopping_num):
                query = self.attention_block_list[i](query, training, query, return_attention_scores)
            # [batch_size, q_length, hidden_dim]
            return self.output_normalization(query)

class VisionTransformer(Model):
    """
    Vision Transformer
    
    """
    def __init__(self, img_size, patch_size, hopping_num, heads_num, hidden_dim, drop_rate):
        '''
        patch_size : 画像を分割するサイズ
        hopping_num : Multi-head Attentionの繰り返し数
        hidden_dim : Embeddingされた特徴ベクトルの長さ
        heads_num : Multi-head Attentionのヘッド数
           ※hidden_numはheads_numで割り切れえる値とすること
        drop_rate : 出力のDropout率
        '''
        super().__init__()
        self.encoder = Encoder(img_size, patch_size, hopping_num, heads_num, hidden_dim, drop_rate)
        
        # 全結合層
        self.fc1 = Dense(16, activation='tanh')
        self.dropout1 = Dropout(drop_rate)
   
        self.final_layer = Dense(1, activation='sigmoid')

    def call(self, inputs, return_attention_scores=False, training=False):
        '''
        inputs: 入力(encoder, decoder)
        return_attention_scores : attention weightを出力するか
        '''

        # enc_input : [batch_size,height,width,channel]
        if return_attention_scores:
            enc_output, enc_atten_weights = self.encoder(inputs, training, return_attention_scores=return_attention_scores)
        else:
            enc_output = self.encoder(inputs, training, return_attention_scores=return_attention_scores)
         
        # クラストークン部分のみ使用
        # [batch_size, enc_length, hidden_dim] -> [batch_size, hidden_dim]
        class_output = enc_output[:,0,:]
        
        fc_output = self.fc1(class_output)
        fc_output = self.dropout1(fc_output)
        final_output = self.final_layer(fc_output)

        if return_attention_scores:
            return final_output, enc_atten_weights
        else:
            return final_output

ViT-B/16を定義します。
パラメータ数は論文通りおよそ86Mとなりました。

model = VisionTransformer(img_size=(224,224), patch_size=16, hopping_num=12, heads_num=12, hidden_dim=768, drop_rate=0.1)
model.build((None, 224, 224, 3))  # build with input shape.
dummy_input = Input(shape=(224, 224, 3))  # declare without batch demension.
model_summary = Model(inputs=[dummy_input], outputs=model.call(dummy_input))
model_summary.summary()
Model: "model_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_5 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 encoder_4 (Encoder)         (None, 197, 768)          85798656  
                                                                 
 tf.__operators__.getitem_4   (None, 768)              0         
 (SlicingOpLambda)                                               
                                                                 
 dense_104 (Dense)           (None, 16)                12304     
                                                                 
 dropout_201 (Dropout)       (None, 16)                0         
                                                                 
 dense_105 (Dense)           (None, 1)                 17        
                                                                 
=================================================================
Total params: 85,810,977
Trainable params: 85,810,977
Non-trainable params: 0
_________________________________________________________________

次の画像(kaggleのdog vs cat)を学習させます。多クラス分類でうまくいかなかったため2値分類としました。
モデルについてはかなり小さくしたものを使用します。(こういったことができるのは自分で実装するメリットかと思います。)

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

batch_size=25
img_height = 224
img_width = 224

epochs = 120
initial_lrate = 0.001

classes = ['Dog', 'Cat']

train_datagen = ImageDataGenerator(rescale=1./255,
                                   horizontal_flip=True,
                                   channel_shift_range=True,
                                   zoom_range=[0.5, 2.0],
                                   rotation_range=10,
                                   height_shift_range = 0.1,
                                   width_shift_range = 0.1,
                                   fill_mode='nearest')

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        '/content/drive/MyDrive/dogcat_mini/train',
        target_size=(img_height, img_width),
        color_mode = 'rgb', #グレー:'grayscale'
        batch_size=batch_size,
        classes = classes, 
        class_mode='binary',#2つ'binary' 3つ以上:'categorical'
        save_format='jpeg'
)

validation_generator = test_datagen.flow_from_directory(
        '/content/drive/MyDrive/dogcat_mini/val',
        target_size=(img_height, img_width),
        batch_size=batch_size,
        classes = classes, 
        class_mode='binary',
        save_format='jpeg')

# ベストのモデルのみ保存
modelCheckpoint = ModelCheckpoint(filepath = "/content/drive/MyDrive/dogcat_mini/",
                                  monitor='val_accuracy',
                                  verbose=1,
                                  save_best_only=True,
                                  save_weights_only=False, # 重みのみ保存
                                  mode='max', # val_accuracyの場合
                                  period=1)

def decay(epoch, steps=50):
    initial_lrate = 0.001
    drop = 0.1
    epochs_drop = 50
    lrate = initial_lrate * math.pow(drop, math.floor((1+epoch)/epochs_drop))
    if lrate <= 1e-6:
        lrate=1e-6
    return lrate

adam = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.99, weight_decay=0.1, amsgrad=True)

lr_sc = LearningRateScheduler(decay, verbose=1)

model = VisionTransformer(img_size=(224,224), patch_size=4, hopping_num=4, heads_num=2, hidden_dim=32, drop_rate=0.1)
model.compile(loss=['binary_crossentropy'], optimizer=adam,  metrics=['accuracy'])

学習の実行

history=model.fit(train_generator,
                  epochs=epochs,
                  validation_data=validation_generator,
                  callbacks = [modelCheckpoint, lr_sc]
                  )

image.png

validationデータでaccuracyは0.7程度でした。google colab proを利用しましたが、1epoch1分半くらいでした。
結果の可視化を行います。

from tensorflow.keras.models import load_model

model = VisionTransformer(img_size=(224,224), patch_size=8, hopping_num=4, heads_num=2, hidden_dim=32, drop_rate=0.1)
model.build((None, 224, 224, 3))  # build with input shape.
model.load_weights("/content/drive/MyDrive/dogcat_mini/model1/model.hdf5")
import matplotlib.pyplot as plt
import cv2

imgs = next(validation_generator)
result = model.call(inputs=imgs[0], return_attention_scores=True)

fig, ax = plt.subplots(1, 2, figsize=(10,4))

ax[0].imshow(imgs[0][6])
ax[1].imshow(cv2.resize(result[1][6,:,1:,0].numpy().mean(axis=0).reshape((int(224/8),int(224/8))), (224, 224)))

image.png

実際には学習済みモデルを使用することになると思います。kerasで言えば、tensorflowhubかvit_kerasというライブラリなどからの利用となります。

import tensorflow_hub as hub

model2 = tf.keras.Sequential([
                             hub.KerasLayer("https://tfhub.dev/sayakpaul/vit_b16_fe/1", trainable=False),
                             Dense(16, activation='tanh'),
                             Dropout(0.5),
                             Dense(2, activation='sigmoid')
                             ])
model2.build((None, 224, 224, 3))  # build with input shape.
dummy_input = Input(shape=(224, 224, 3))  # declare without batch demension.
model_summary = Model(inputs=[dummy_input], outputs=model2.call(dummy_input))
model_summary.summary()
Model: "model_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_7 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 keras_layer_2 (KerasLayer)  (None, 768)               85798656  
                                                                 
 dense_106 (Dense)           (None, 16)                12304     
                                                                 
 dropout_202 (Dropout)       (None, 16)                0         
                                                                 
 dense_107 (Dense)           (None, 2)                 34        
                                                                 
=================================================================
Total params: 85,810,994
Trainable params: 85,810,994
Non-trainable params: 0
_________________________________________________________________
model2.compile(loss=['binary_crossentropy'], optimizer=sgd,  metrics=['accuracy'])

history2=model2.fit(train_generator,
                    epochs=epochs,
                    validation_data=validation_generator,
                    callbacks = [modelCheckpoint, lr_sc]
                   )

image.png

20epochですが0.99近くとなりました。

# pip install vit-keras
from vit_keras import vit

image_size = 224
model3 = vit.vit_b16(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=False,
    classes=2
)

# 学習させない
for layer in base_model.layers[:6]:
    layer.trainable = False

model3.compile(loss=['binary_crossentropy'], optimizer=sgd,  metrics=['accuracy'])
model3.build((None, 224, 224, 3))  # build with input shape.
dummy_input = Input(shape=(224, 224, 3))  # declare without batch demension.
model_summary = Model(inputs=[dummy_input], outputs=model3.call(dummy_input))
model_summary.summary()
Model: "model_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================    
 
略                                                                

=================================================================
Total params: 85,800,194
Trainable params: 85,800,194
Non-trainable params: 0
_________________________________________________________________
history3=model3.fit(train_generator,
                  epochs=epochs,
                  validation_data=validation_generator,
                  callbacks = [modelCheckpoint, lr_sc]
                  )

image.png

こちらも同様に少ないepoch数で精度を高くだせました。
可視化の関数もありますがうまくだせませんでした。調査中。

from vit_keras import vit, utils, visualize

imgs = next(validation_generator)

attention_map = visualize.attention_map(model=model3, image=imgs[0][6])

# Plot results
fig, (ax1, ax2) = plt.subplots(ncols=2)
ax1.axis('off')
ax2.axis('off')
ax1.set_title('Original')
ax2.set_title('Attention Map')
_ = ax1.imshow(imgs[0][5])
_ = ax2.imshow(attention_map)

実装(pytorch)

!pip install pytorch_lightning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimizers
import pytorch_lightning as pl
from torchmetrics import Accuracy as accuracy
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets

import numpy as np
import math
class patch_embbeding(nn.Module):
    """
    patch embeddingレイヤ
    
    画像をパッチに分割し線形変換を行いtransformerへの入力とする
    [cls]トークンをパッチの系列の先頭に追加し、これがencodingさたものを全結合層の入力とする
    学習可能な重みでpositional encoding
    """
    def __init__(self, img_size, patch_size=4, hidden_dim=8):
        """
        img_size : 画像のサイズ(height, width)
        patch_size : パッチサイズ
        hidden_dim : embedding後の次元
        """
        super().__init__()
        self.D = hidden_dim
        
        # 畳み込み層のkernelとstrideをpatch_sizeとすることで分割と線形変換を同時に行う
        self.patch_conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size = patch_size, stride=patch_size)
        
        # [class](クラストークン)追加
        # callの中でbatch_sizeに拡張
        self.cls_token = nn.Parameter(torch.randn(1,1,hidden_dim))

        # position encoding
        # クラストークンの分も入れたshapeを指定
        N = int(img_size[0]*img_size[1]/(patch_size*patch_size))+1 # パッチの数+1(クラストークン)
        self.position = nn.Parameter(torch.randn(1,N,hidden_dim))
        
    def forward(self, inputs):
        """
        inputs: [batch_size,channel,height,width]
        """
        # batch_sizeの取得
        batch_size = inputs.size()[0]
        
        # パッチへの分割→線形変換
        # 畳み込み層のkernelとstrideをpatch_sizeとすることで同時に行う
        # [batch_size,channel,height,width]→[batch_size,hidden_dim,height/patch_size,width/patch_size]
        out = self.patch_conv(inputs)
        
        # [batch_size,hidden_dim,height/patch_size,width/patch_size]
        #      →[batch_size,hidden_dim,height*width/patch_size^2]
        #           →[batch_size,height*width/patch_size^2,hidden_dim]
        out = out.flatten(2).transpose(1,2)
        
        # クラストークンの拡張
        # [1,1,hidden_dim]→[batch_size,1,hidden_dim]
        cls_token = self.cls_token.repeat(repeats=(batch_size,1,1))
        
        # クラストークンをパッチ系列へ追加
        # [batch_size,height*width/patch_size^2,hidden_dim] + [batch_size,1,hidden_dim]
        #      → [batch_size,height*width/patch_size^2+1,hidden_dim]
        out = torch.cat([cls_token, out], dim=1)
        
        # positional encoding
        out = out+self.position 
        return out


class SelfMultiHeadAttention(nn.Module):
    '''
    Multi-Head Attentionレイヤ

    model = MultiheadAttention(
        hidden_dim = 512,
        head_num = 8,
        drop_rate = 0.5
    )
    '''
    def __init__(self, hidden_dim, heads_num, drop_rate=0.5):
        '''
        Multi-Head Attentionレイヤ
    
        hidden_dim : Embeddingされた単語ベクトルの長さ
        heads_num : マルチヘッドAttentionのヘッド数
           ※hidden_numはheads_numで割り切れえる値とすること
        drop_rate : 出力のDropout率
        '''
        super(SelfMultiHeadAttention, self).__init__()
        # 入力の線形変換
        # 重み行列は[hidden_dim, hidden_dim]
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key   = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        
        # 出力の線形変換
        self.projection = nn.Linear(hidden_dim, hidden_dim)
        
        # 出力のDropout
        self.drop = nn.Dropout(drop_rate)
        
        self.nf = hidden_dim
        self.nh = heads_num
    
    def atten(self, query, key, value):
        """
        Attention
        
        query, key, value : クエリ、キー、バリュー
            query [batch_size, head_num, q_length, hidden_dim//head_num]
            key, value [batch_size, head_num, m_length, hidden_dim//head_num]
            ただし、encoder:q_length=m_length
        """
        # 各値を取得
        shape = query.shape
        batch_size = -1 if shape[0] is None else shape[0]
        token_num = shape[2] # 系列長
        hidden_dim = shape[1]*shape[3] # 入力チャンネル数
        
        # ここで q と k の内積を取ることで、query と key の単語間の関連度のようなものを計算します。
        # tf.matmulで最後の2成分について積を計算(それ以外は形がそろっている必要あり)
        # transpose_bで転置
        # [batch_size, head_num, q_length, hidden_dim/head_num] @ [batch_size, head_num, hidden_dim/head_num, m_length] = [batch_size, head_num, q_length, m_length]
        scores = torch.matmul(query, key.transpose(-2, -1))
        
        # scoreをhidden_dimの平方根割る
        scores = scores / math.sqrt(hidden_dim)

        # softmax を取ることで正規化します
        # input(query) の各単語に対して memory(key) の各単語のどこから情報を引いてくるかの重み
        atten_weight = F.softmax(scores, dim = -1)
        #atten_weight = scores / torch.sum(scores, dim=-1, keepdim=True)
        
        # 重みに従って value から情報を引いてきます
        # [batch_size, head_num, q_length, m_length] @ [batch_size, head_num, m_length, hidden_dim/head_num] = [batch_size, head_num, q_length, hidden_dim/head_num]
        # input(query) の単語ごとに memory(value)の各単語 に attention_weight を掛け合わせて足し合わせた ベクトル(分散表現の重み付き和)を計算
        context = torch.matmul(atten_weight, value)
        
        # 各ヘッドの結合(reshape)
        # 入力と同じ形に変換する
        # [batch_size, head_num, q_length, hidden_dim/head_num] -> [batch_size, q_length, head_num, hidden_dim/head_num]
        context = context.transpose(1, 2).contiguous()
        # [batch_size, q_length, head_num, hidden_dim/head_num] -> [batch_size, q_length, hidden_dim]
        context = context.view(batch_size, token_num, hidden_dim)
        
        # 線形変換
        context = self.projection(context)
        
        return self.drop(context), atten_weight

    def _split(self, x):
        """
        query, key, valueを分割する
        
        入力 shape: [batch_size, length, hidden_dim] の時
        出力 shape: [batch_size, head_num, length, hidden_dim//head_num]
        """
        # 各値を取得
        hidden_dim = self.nf
        heads_num = self.nh
        shape = x.shape
        batch_size = -1 if shape[0] is None else shape[0]
        length = shape[1] # 系列長
        
        # [batch_size, (q|m)_length, hidden_dim] -> [batch_size, (q|m)_length, head_num, hidden_dim/head_num]
        # splitだが実際は次元を拡張する処理
        x = x.view(batch_size, length, heads_num, int(hidden_dim/heads_num))
        
        # [batch_size, (q|m)_length, head_num, hidden_dim/head_num] -> [batch_size, head_num, (q|m)_length, hidden_dim/head_num]
        x = x.transpose(1, 2)
        return x
    
    def forward(self, x, memory=None, return_attention_scores=False):
        """
        モデルの実行
        
        input : 入力(query) [batch_size, length, hidden_dim]
        memory : 入力(key, value) [batch_size, length, hidden_dim]
         ※memory(key, value)についてはqueryのtoken_numと異なる場合がある
        return_attention_scores : attention weightを出力するか
        """
        # memoryが入力されない場合、memory=input(Self Attention)とする
        if memory is None:
            memory = x
        
        # input -> query
        # memory -> key, value
        # [batch_size, (q|m)_length, hidden_dim] @ [hidden_dim, hidden_dim] -> [batch_size, (q|m)_length, hidden_dim] 
        query = self.query(x)
        key = self.key(memory)
        value = self.value(memory)
        
        # ヘッド数に分割する
        # 実際はreshapeで次数を1つ増やす
        # [batch_size, (q|m)_length, hidden_dim] -> [batch_size, head_num, (q|m)_length, hidden_dim/head_num]
        query = self._split(query)
        key = self._split(key)
        value = self._split(value)
        
        # attention
        # 入力と同じ形の出力
        # context: [batch_size, q_length, hidden_dim]
        context, atten_weight = self.atten(query, key, value)
        
        if return_attention_scores:
            return context, atten_weight
        else:
            return context

class FeedForwardNetwork(nn.Module):
    '''
    Position-wise Feedforward Neural Network
    transformer blockで使用される全結合層
    '''
    def __init__(self, hidden_dim, drop_rate=0.1):
        '''
        hidden_dim : Embeddingされた単語ベクトルの長さ
        drop_rate : 出力のDropout率
        '''
        super().__init__()
        # 2層構造
        # 1層目:チャンネル数を増加させる
        self.filter_dense_layer = nn.Linear(hidden_dim, hidden_dim * 4)
        self.gelu = nn.GELU()
        
        # 2層目:元のチャンネル数に戻す
        self.output_dense_layer = nn.Linear(hidden_dim * 4, hidden_dim)
        self.drop = nn.Dropout(drop_rate)

    def forward(self, x):
        '''
        入力と出力で形が変わらない
        x : 入力 [batch_size, length, hidden_dim]
        '''
        
        # [batch_size, (q|m)_length, hidden_dim] -> [batch_size, (q|m)_length, 4*hidden_dim]
        x = self.filter_dense_layer(x)
        x = self.gelu(x)
        x = self.drop(x)
        
        # [batch_size, (q|m)_length, 4*hidden_dim] -> [batch_size, (q|m)_length, hidden_dim]
        return self.output_dense_layer(x)

class ResidualNormalizationWrapper(nn.Module):
    '''
    残差接続
    output: input + SubLayer(input)
    '''
    def __init__(self, layer, hidden_dim, drop_rate=0.1):
        '''
        layer : 残渣接続したいレイヤ(MultiHeadAttentionかFeedForwardNetwork)に適用
        drop_rate : 出力のDropout率
        '''
        super().__init__()
        self.layer = layer # SubLayer : ここではAttentionかFFN
        self.layer_normalization = nn.LayerNorm(hidden_dim)
        self.drop = nn.Dropout(drop_rate)

    def forward(self, x, memory=None, return_attention_scores=False):
        """
        モデルの実行
        
        memory : 入力(key, value) [batch_size, length, hidden_dim]
         ※memory(key, value)についてはqueryのlengthと異なる場合がある
        return_attention_scores : attention weightを出力するか

        AttentionもFFNも入力と出力で形が変わらない
        output : [batch_size, length, hidden_dim]
        """
        
        params = {}
        if memory is not None:
            params['memory'] = memory
        if return_attention_scores:
            params['return_attention_scores'] = return_attention_scores
        
        out = self.layer_normalization(x)
        if return_attention_scores:
            # attention weightを返す
            out, attn_weights = self.layer(out,**params)
            out = self.drop(out)
            return x + out, attn_weights
        else:
            # attention weightを返さない
            out = self.layer(out,**params)
            out = self.drop(out)
            return x + out

class EncoderLayer(nn.Module):
    """
    Encoderレイヤ
     MultiHeadAttentionとFeedForwardNetworkの組み合わせ
      それぞれ残差接続されている
    """
    def __init__(self, hidden_dim, heads_num, drop_rate=0.1):
        """
        hidden_dim : Embeddingされた単語ベクトルの長さ
        heads_num : Multi-head Attentionのヘッド数
           ※hidden_numはheads_numで割り切れえる値とすること
        drop_rate : 出力のDropout率
        """
        super().__init__()
        # Multi-head attention
        self.atten = ResidualNormalizationWrapper(
            hidden_dim=hidden_dim,
            layer = SelfMultiHeadAttention(hidden_dim = hidden_dim,
                                           heads_num = heads_num,
                                           drop_rate = drop_rate),
            drop_rate = drop_rate)
        
        # Feed Forward Network
        self.ffn = ResidualNormalizationWrapper(
            hidden_dim=hidden_dim,
            layer = FeedForwardNetwork(hidden_dim = hidden_dim,
                                       drop_rate = drop_rate),
            drop_rate = drop_rate)
    
    def forward(self, input, memory=None, return_attention_scores=False):
        """
        x : 入力(query) [batch_size, length, hidden_dim]
        memory : 入力(key, value) [batch_size, length, hidden_dim]
         ※memory(key, value)についてはqueryのtoken_numと異なる場合がある
        return_attention_scores : attention weightを出力するか

        AttentionもFFNも入力と出力で形が変わらない
        output : [batch_size, length, hidden_dim]
        
        入力と出力で形式が変わらない
        output : [batch_size, length, hidden_dim]
        """
        if return_attention_scores:
            x, attn_weights = self.atten(input, memory, return_attention_scores)
            x = self.ffn(x)
            return x, attn_weights
        else:
            x = self.atten(input, memory, return_attention_scores)
            x = self.ffn(x)
            return x

class Encoder(nn.Module):
    '''
    TransformerのEncoder
    '''
    def __init__(self, img_size, patch_size, hopping_num, heads_num, hidden_dim, drop_rate):
        '''
        img_size : 画像のサイズ
        patch_size : 画像を分割するサイズ
        hopping_num : Multi-head Attentionの繰り返し数
        hidden_dim : Embeddingされた特徴ベクトルの長さ
        heads_num : Multi-head Attentionのヘッド数
           ※hidden_numはheads_numで割り切れえる値とすること
        drop_rate : 出力のDropout率
        '''
        super().__init__()
        self.hopping_num = hopping_num
        
        # patch Embedding
        self.patch_embedding = patch_embbeding(img_size=img_size, patch_size=patch_size, hidden_dim=hidden_dim)
        self.input_dropout_layer = nn.Dropout(drop_rate)

        # Multi-head Attentionの繰り返し(hopping)のリスト
        self.attention_block_list = nn.ModuleList([EncoderLayer(hidden_dim, heads_num) for _ in range(hopping_num)])
        self.output_normalization = nn.LayerNorm(hidden_dim)

    def forward(self, input, return_attention_scores=False):
        '''
        input: 入力 [batch_size,height,width,channel]
        return_attention_scores : attention weightを出力するか
        出力 [batch_size, q_length, hidden_dim]
        '''
        # patch Embedding
        # [batch_size,height,width,channel] → [batch_size, q_length, hidden_dim]
        embedded_input = self.patch_embedding(input)
        query = self.input_dropout_layer(embedded_input)
        
        # Encoderレイヤを繰り返し適用
        if return_attention_scores:
            for i in range(self.hopping_num):
                query, atten_weights = self.attention_block_list[i](query, return_attention_scores=return_attention_scores)

            #  [batch_size, q_length, hidden_dim]
            return self.output_normalization(query), atten_weights
        else:
            for i in range(self.hopping_num):
                query = self.attention_block_list[i](query, return_attention_scores=return_attention_scores)
            # [batch_size, q_length, hidden_dim]
            return self.output_normalization(query)

class VisionTransformer(nn.Module):
    """
    Vision Transformer
    
    """
    def __init__(self, img_size, patch_size, hopping_num, heads_num, hidden_dim, drop_rate):
        '''
        patch_size : 画像を分割するサイズ
        hopping_num : Multi-head Attentionの繰り返し数
        hidden_dim : Embeddingされた特徴ベクトルの長さ
        heads_num : Multi-head Attentionのヘッド数
           ※hidden_numはheads_numで割り切れえる値とすること
        drop_rate : 出力のDropout率
        '''
        super().__init__()        
        self.encoder = Encoder(img_size, patch_size, hopping_num, heads_num, hidden_dim, drop_rate)
        self.dense1 = nn.Linear(hidden_dim, 16)
        self.act1 = nn.Tanh()
        self.dropout1 = nn.Dropout(drop_rate)   
        self.final_layer = nn.Linear(16, 2)
        
        nn.init.normal_(self.dense1.weight, std=0.02)
        nn.init.normal_(self.dense1.bias, std=0)
        nn.init.normal_(self.final_layer.weight, std=0.02)
        nn.init.normal_(self.final_layer.bias, std=0)

    def forward(self, inputs, return_attention_scores=False):
        '''
        inputs: 入力(encoder, decoder)
        return_attention_scores : attention weightを出力するか
        '''
        # enc_input : [batch_size,height,width,channel]
        if return_attention_scores:
            enc_output, atten_weights = self.encoder(inputs, return_attention_scores=return_attention_scores)
        else:
            enc_output = self.encoder(inputs, return_attention_scores=return_attention_scores)
        
        # クラストークン部分のみ使用
        # [batch_size, enc_length, hidden_dim] -> [batch_size, hidden_dim]
        enc_output = self.dense1(enc_output[:, 0, :])
        enc_output = self.act1(enc_output)
        enc_output = self.dropout1(enc_output)
        final_output = self.final_layer(enc_output)

        if return_attention_scores:
            return final_output, atten_weights
        else:
            return final_output
from torchsummary import summary

model = VisionTransformer(img_size=(224,224), patch_size=16, hopping_num=12, heads_num=12, hidden_dim=768, drop_rate=0.1)
summary(model, (3,224,224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 768, 14, 14]         590,592
   patch_embbeding-2             [-1, 197, 768]               0
 略
          Linear-225                    [-1, 2]              34
================================================================
Total params: 85,658,930
Trainable params: 85,658,930
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 379.76
Params size (MB): 326.76
Estimated Total Size (MB): 707.10
----------------------------------------------------------------
class MyDataModule(pl.LightningDataModule):
    def __init__(self, train_path=None, val_path=None, test_path=None, batch_size=32):
        super().__init__()
        self.batch_size=batch_size
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        
        self.train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3)),
            transforms.RandomRotation([-10, 10]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        self.val_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

    def setup(self, stage=None):
        if stage=='fit':
            self.train_dataset = datasets.ImageFolder(self.train_path, self.train_transform)
            self.val_dataset  = datasets.ImageFolder(self.val_path, self.val_transform)
        else:
            self.test_dataset  = datasets.ImageFolder(self.test_path, self.val_transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,num_workers=3, shuffle=True)
 
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,num_workers=3, shuffle=True)
 
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size,num_workers=3, shuffle=True)
class ViTTrainer(pl.LightningModule):
    def __init__(self):
        super().__init__()
        model = VisionTransformer(img_size=(224,224), patch_size=8, hopping_num=4, heads_num=2, hidden_dim=32, drop_rate=0.1)
        model.to(device)
        self.model = model
        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.test_step_outputs = []
        
    def forward(self, x):
        x = x.to(device)
        x = self.model(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        #loss = nn.BCEWithLogitsLoss()(y_hat, y)
        self.training_step_outputs.append({'loss': loss, 'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)})
        return {'loss': loss, 'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)}
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        #loss = nn.BCEWithLogitsLoss()(y_hat, y)
        self.validation_step_outputs.append({'loss': loss, 'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)})
        return {'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)}
    
    def test_step(self, batch, batch_nb):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        #loss = nn.BCEWithLogitsLoss()(y_hat, y)
        y_label = torch.argmax(y_hat, dim=1)
        acc = accuracy(task="binary")(y_label, y)
        self.test_step_outputs.append({'loss': loss, 'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)})
        return {'test_loss': loss, 'test_acc': acc}
    
    def on_train_epoch_end(self):
        y_hat = torch.cat([val['y_hat'] for val in self.training_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in self.training_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in self.training_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)
        preds = preds.cpu()
        y = y.cpu()
        acc = accuracy(task="binary")(preds, y)
        self.log('train_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('train_acc', acc, prog_bar=True, on_epoch=True)
        
        print('---------- Current Epoch {} ----------'.format(self.current_epoch + 1))
        print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loss, acc))
    
    def on_validation_epoch_end(self):
        y_hat = torch.cat([val['y_hat'] for val in self.validation_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in self.validation_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in self.validation_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)

        preds = preds.cpu()
        y = y.cpu()
        acc = accuracy(task="binary")(preds, y)
        self.log('val_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', acc, prog_bar=True, on_epoch=True)
        
        print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc))
    
    # New: テストデータに対するエポックごとの処理
    def on_test_epoch_end(self, test_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in self.test_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in self.test_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in self.test_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)
        preds = preds.cpu()
        y = y.cpu()
        acc = accuracy(task="binary")(preds, y)
        self.log('test_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('test_acc', acc, prog_bar=True, on_epoch=True)
        
        print('test Loss: {:.4f} test Acc: {:.4f}'.format(epoch_loss, acc))
        
    def configure_optimizers(self):
        learning_rate = 1e-3
        #optimizer = optimizers.Adam(self.model.parameters(), weight_decay=0.1)
        optimizer = optimizers.Adam(net.parameters(), lr=learning_rate, amsgrad=True, eps=1e-07)
        return optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dm = MyDataModule(train_path='/content/drive/MyDrive/dogcat_mini/train',
                  val_path='/content/drive/MyDrive/dogcat_mini/val',
                  test_path='/content/drive/MyDrive/dogcat_mini/test',
                  batch_size=25)
dm.setup()

net = ViTTrainer().to(device)

trainer = pl.Trainer(devices=1,max_epochs=200, accelerator="gpu")
trainer.fit(net, dm)

image.png

accuracyはkeras同様0.7程度でした。
こちらでも可視化を行います。GoogleColabのユニットを使いきってしまいだせませんでした。

import matplotlib.pyplot as plt
import cv2

fig, ax = plt.subplots(1,2,figsize=(10,5))

ax[0].imshow(imgs[20].detach().numpy().transpose(1,2,0))
ax[1].imshow(cv2.resize(attentino_map[20,:,1:,0].detach().numpy().mean(axis=0).reshape((int(224/8),int(224/8))), (224, 224)))

pytorchでもいくつか学習済みモデルが提供されています。
ここでは、torchvisionとtransformersで利用できるものを実装します。

!pip install transformers
import torchvision.models as models
from transformers import ViTForImageClassification

class CustomVisionTransformer(nn.Module):
    """
    Vision Transformer
    
    """
    def __init__(self, drop_rate=0.1, model_type='torchvision'):
        '''
        patch_size : 画像を分割するサイズ
        hopping_num : Multi-head Attentionの繰り返し数
        hidden_dim : Embeddingされた特徴ベクトルの長さ
        heads_num : Multi-head Attentionのヘッド数
           ※hidden_numはheads_numで割り切れえる値とすること
        drop_rate : 出力のDropout率
        '''
        super().__init__()
        fc=lambda hidden_dim: nn.Sequential(
          nn.Linear(hidden_dim, 16),
          nn.Tanh(),
          nn.Dropout(drop_rate),
          nn.Linear(16, 2)
          )

        if model_type=="torchvision":
            self.encoder = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
            for param in self.encoder.parameters():
                param.requires_grad = False
            hidden_dim = self.encoder.heads.head.in_features
            self.encoder.heads = fc(hidden_dim)
        else:
            self.encoder = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224",output_attentions=True)
            for param in self.encoder.parameters():
                param.requires_grad = False
            hidden_dim = self.encoder.classifier.in_features
            self.encoder.classifier = fc(hidden_dim)

    def forward(self, inputs):
        '''
        inputs: 入力(encoder, decoder)
        return_attention_scores : attention weightを出力するか
        '''
        final_output = self.encoder(inputs)
        return final_output
class ViTTrainer2(pl.LightningModule):
    def __init__(self, model_type='torchvision'):
        super().__init__()
        model = CustomVisionTransformer(model_type=model_type)
        model.to(device)
        self.model = model
        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.test_step_outputs = []
        
    def forward(self, x):
        x = x.to(device)
        x = self.model(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        #loss = nn.BCEWithLogitsLoss()(y_hat, y)
        self.training_step_outputs.append({'loss': loss, 'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)})
        return {'loss': loss, 'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)}
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        #loss = nn.BCEWithLogitsLoss()(y_hat, y)
        self.validation_step_outputs.append({'loss': loss, 'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)})
        return {'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)}
    
    def test_step(self, batch, batch_nb):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        #loss = nn.BCEWithLogitsLoss()(y_hat, y)
        y_label = torch.argmax(y_hat, dim=1)
        acc = accuracy(task="binary")(y_label, y)
        self.test_step_outputs.append({'loss': loss, 'y_hat':y_hat, 'y':y, 'batch_loss': loss.item()*x.size(0)})
        return {'test_loss': loss, 'test_acc': acc}
    
    def on_train_epoch_end(self):
        y_hat = torch.cat([val['y_hat'] for val in self.training_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in self.training_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in self.training_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)
        preds = preds.cpu()
        y = y.cpu()
        acc = accuracy(task="binary")(preds, y)
        self.log('train_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('train_acc', acc, prog_bar=True, on_epoch=True)
        
        print('---------- Current Epoch {} ----------'.format(self.current_epoch + 1))
        print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loss, acc))
    
    def on_validation_epoch_end(self):
        y_hat = torch.cat([val['y_hat'] for val in self.validation_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in self.validation_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in self.validation_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)

        preds = preds.cpu()
        y = y.cpu()
        acc = accuracy(task="binary")(preds, y)
        self.log('val_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', acc, prog_bar=True, on_epoch=True)
        
        print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc))
    
    # New: テストデータに対するエポックごとの処理
    def on_test_epoch_end(self, test_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in self.test_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in self.test_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in self.test_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)
        preds = preds.cpu()
        y = y.cpu()
        acc = accuracy(task="binary")(preds, y)
        self.log('test_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('test_acc', acc, prog_bar=True, on_epoch=True)
        
        print('test Loss: {:.4f} test Acc: {:.4f}'.format(epoch_loss, acc))
        
    def configure_optimizers(self):
        learning_rate = 1e-3
        #optimizer = optimizers.Adam(self.model.parameters(), weight_decay=0.1)
        optimizer = optimizers.Adam(net.parameters(), lr=learning_rate, amsgrad=True, eps=1e-07)
        return optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dm = MyDataModule(train_path='/content/drive/MyDrive/dogcat_mini/train',
                  val_path='/content/drive/MyDrive/dogcat_mini/val',
                  test_path='/content/drive/MyDrive/dogcat_mini/test',
                  batch_size=25)
dm.setup()

# torchvision
net = ViTTrainer2().to(device)

trainer = pl.Trainer(devices=1,max_epochs=5, accelerator="cpu")
trainer.fit(net, dm)
# transformers
net = ViTTrainer2(model_type='huggingface').to(device)

trainer = pl.Trainer(devices=1,max_epochs=10, accelerator="gpu")
trainer.fit(net, dm)
import matplotlib.pyplot as plt
import cv2

result = net.model.forward(imgs)

fig, ax = plt.subplots(1,2,figsize=(10,5)))

ax[0].imshow(imgs[1].detach().numpy().transpose(1,2,0))
ax[1].imshow(cv2.resize(result.attentions[0][1,:,1:,0].detach().numpy().mean(axis=0).reshape((int(224/16),int(224/16))), (224, 224)))

image.png

以上となります。
Chat-GPTつかえば簡単かもしれませんが勉強は続けています。

14
14
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
14

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?