10
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

一番簡単(で一番適当な)なVision Transformer実装例

Last updated at Posted at 2022-07-30

初めに

ICLR2021にてViTのポスター発表ありましたね。 
なので遅ればせながらViTの解説とその実装をします。

色々実装例を見たところスクラッチから書いてる例かViT専用のライブラリを使ってる例しか見当たりませんでした。
やっぱりプログラムは少数の有名ライブラリを上手く使って書くものだと思うんですよね。

というわけで以下のレギュレーションに則りプログラムを書こうということです。

  • pytorch以外禁止
  • pytorchのライブラリで代用できる処理は代用する

あ、組んだコードは以下です。

ViTの特徴

実装に入る前にViTの特徴について話したいと思います。

特徴としては何よりトランスフォーマーを使っていることでしょうか。
トランスフォーマーはもともと自然言語処理とかのシーケンスに使われていた技術でした。

当たり前ながらトランスフォーマーには畳み込みは使いません。
しかしこれが画期的な点になってきます。

ViTが出る前は畳み込みを使わない手はないくらい畳み込みの時代でした。
それなのにViTは畳み込みを使わずに既存の記録を多数塗り替えてしまいました。
界隈がひっくり返る大発見ですよほんとに。

しかもトランスフォーマーは畳み込み処理より計算コストが安いおまけつき。

数年経った今でもViTの仕組みを応用した論文出てきてます。
例えばAttention機構をmlpにした奴とかmlpじゃなくて単なるプーリング層にしたろとかが個人的に面白い論文でした。

それだけ画期的な技術だったということですね。

ViTの構造

こちらのFigure1がわかりやすいです。

ViTは自然言語処理としてのトランスフォーマーを意識しています。
実際ViT作者も既存のトランスフォーマーを壊したくないみたいなことを言ってました。

we wanted the model to be "exactly Transformer, but on image patches"

なのでBERTを学んでからViTをやったほうがわかりやすいと思います。

BERT未履修でも何とかなるようにはするつもりですが...

ViTには5つの段階があります。

  1. 画像パッチ化処理
  2. CLSトークン埋め込み
  3. ポジション埋め込み
  4. トランスフォーマー入力
  5. mlp入力

これら5つについて、実装しながら説明したいと思います。

実装してみよう

上述の5つのステップを順に実装してみましょう。
まずViTのクラスから、

vit.py
import torch
import torch.nn as nn
from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer
class ViT(nn.Module):
    def __init__(self):
        super(ViT, self).__init__()
    def forward(self,x):
        return x

importは以降使うクラスを呼び出してます。

画像のパッチ化処理

トランスフォーマーはそもそも自然言語処理用のモデルでした。
つまり言語っぽい扱いをしてあげなきゃいけません。

やってることは以下2つ、

  • 画像を任意の数に当分
  • それぞれを単語とみなす

これで無理やりトランスフォーマーに入力している感じです。

今回cifar10(3x32x32の画像)を使おうと思ってるので3x8x8の画像に16分割しましょうか。
patchifyでグリッド状に画像を分けて、flattenで画像を潰しました。

vit.py
#import 省略
#以下4行追加
imageWH = 32
patchWH = 8
splitRow=imageWH//patchWH #32 / 8 = 4
splitCol=imageWH//patchWH #32 / 8 = 4
#以上
class ViT(nn.Module):
    def __init__(self):
        super(ViT, self).__init__()
    #以下5行追加
    def patchify(img):
        horizontal = torch.stack(torch.chunk(img,splitRow,dim=2),dim=1)
        patches = torch.cat(torch.chunk(horizontal,splitCol,dim=4),dim=1)
        return patches
    #以上
    def forward(self,x):
        x=self.patchify(x) # 追加 [batch_size, 16, 3, 8, 8]
        x=torch.flatten(x,start_dim=2) # 追加 [batch_size, 16, 3x8x8]
        return x

潰した画像はまだ生の値で、特徴量ではありません。
とりあえず全結合層に入れてトランスフォーマーへねじ込めるようにしましょう。

vit.py
#省略
patchWH = 8
#以下3行追加
channel=3
patchVectorLen=channel*(patchWH**2) #3 * (8 ** 8) = 192
embedVectorLen=int(patchVectorLen/2) #192 / 2 = 96
#以上
class ViT(nn.Module):
    def __init__(self):
        super(ViT, self).__init__()
        self.patchEmbedding = nn.Linear(patchVectorLen,embedVectorLen) #追加
    def patchify(img):
        #省略
    def forward(self,x):
        x=self.patchify(x)
        x=torch.flatten(x,start_dim=2)
        x=self.patchEmbedding(x) #追加 [batch_size, 16, embedVectorLen(3x8x8/2)]
        return x

CLSトークン埋め込み

ここはBERTやってない人がつまずきやすいので少し長めに説明します

CLSとはクラスの略です。
画像全体の特徴量を集める役割を持たせます。

するとCLSの出力のみを使って画像分類ができるようになりました。
やったね!

実装方法としては学習可能な変数をCLSトークンとして定義します。
このトークンをベクトル化されたパッチ列の先頭に配置します。

vit.py
#省略
class ViT(nn.Module):
    def __init__(self):
        super(ViT, self).__init__()
        self.patchEmbedding = nn.Linear(patchVectorLen,embedVectorLen)
        self.cls = nn.Parameter(torch.zeros(1, 1, embedVectorLen)) #追加
    def patchify(img):
        #省略
    def forward(self,x):
        x=self.patchify(x)
        x=torch.flatten(x,start_dim=2)
        x=self.patchEmbedding(x)
        #以下2行追加
        clsToken = self.cls.repeat_interleave(x.shape[0],dim=0) #[batch_size, 1, embedVectorLen(3x8x8/2)]
        x=torch.cat((clsToken,x),dim=1) #[batch_size, 1+16, embedVectorLen(3x8x8/2)]
        #以上
        return x

特徴量を集めるとか大層な気がしますがこれは逆で集まるのは必然的なんですよね。

ViTの最終段階(以下mlpとかmlpヘッダー)への入力はCLSトークンのみです。
するとViT全体がCLSトークンに特徴が集まるように学習してしまうのは自明ですね。

ここは重要な点ですが、どのような入力にもトランスフォーマー第1層のCLSトークンは一定です。
「トランスフォーマー最終層のCLSトークンに特徴量が一番集まるように第1層のCLSトークン(定数)を決める」
と言えばわかりやすいでしょうか。

余談1

先ほど「CLSトークンのみをmlpへ入力しているので、CLSに特徴が集まる」と言いました。
これは逆に「CLSトークン以外をmlpへ入力にすればCLSはいらない」ということでもあります。

というわけでCLSトークンが気持ち悪い諸兄への朗報です。
CLSを使わなくてもViTはViT足りえます。

ViT作者もCLSはViTにそこまで関係しないと明言していますね。

Different from the common ways to use feature maps to obtain classifcation prediction (with fc or GAP layers), VIT employs an extra class embedding to do this without using feature maps explicitly. Wonder the meanings of this unusual design?

Great question. It is not really important. However, we wanted the model to be "exactly Transformer, but on image patches", so we kept this design from Transformer, where a token is always used.

できるだけNLP用のトランスフォーマーをそのまま使いたいため、CLSを採用しているみたいです。

CLSを使わない方式は余談2で紹介しています。

ポジション埋め込み

トランスフォーマーは入力順という概念がありません。
不便ですがどの入力が画像のどこに対応するか明示する必要があります。

元論文を見ると1次元の学習可能な変数を加算でいいらしいですね。

vit.py
#省略
imageWH = 32
patchWH = 8
patchTotal=(imageWH//patchWH)**2 #追加 (32 / 8)^2 = 16
class ViT(nn.Module):
    def __init__(self):
        # 省略
        self.positionEmbedding = nn.Parameter(torch.zeros(1, patchTotal + 1, embedVectorLen)) #追加
    def patchify(img):
        #省略
    def forward(self,x):
        # 省略
        x=torch.cat((clsToken,x),dim=1)
        x+=self.positionEmbedding #追加 [batch_size, 17, embedVectorLen(3x8x8/2)]
        return x

self.positionEmbeddingのpatchTotalに1が加算されているのはCLSトークンが入っているためです。

2次元の正弦関数とか高度なの使ったみたいですけど言うて結果は良くならなかったみたいですね。

We use standard learnable 1D position embeddings, since we have not observed significant performance gains from using more advanced 2D-aware position embeddings

トランスフォーマー入力

ここは腐るほど説明されていると思うのですぐ終えます。

pre-norm、GELU関数を採用していることに注意。

vit.py
#省略
#以下4行追加
head=12
dim_feedforward=embedVectorLen
activation="gelu"
layers=12
#以上
class ViT(nn.Module):
    def __init__(self):
        # 省略
        self.positionEmbedding = nn.Parameter(torch.zeros(1, patchTotal + 1, embedVectorLen))
        #以下9行追加
        encoderLayer = TransformerEncoderLayer(
            d_model=embedVectorLen,
            nhead=head,
            dim_feedforward=dim_feedforward,
            activation=activation,
            batch_first=True,
            norm_first=True
        )
        self.transformerEncoder = TransformerEncoder(encoderLayer,layers)
        #以上
    def patchify(img):
        #省略
    def forward(self,x):
        # 省略
        x+=self.positionEmbedding
        x=self.transformerEncoder(x) #追加 [batch_size, 17, embedVectorLen(3x8x8/2)]
        return x

mlp入力

10次元出力のmlpを作ってトランスフォーマー最終層のCLSトークンを入力します。

vit.py
#省略
class ViT(nn.Module):
    def __init__(self):
        # 省略
        self.transformerEncoder = TransformerEncoder(encoderLayer,layers)
        self.mlpHead=nn.Linear(embedVectorLen,10) #追加
    def patchify(img):
        #省略
    def forward(self,x):
        # 省略
        x=self.transformerEncoder(x)
        x=self.mlpHead(x[:,0,:]) # 追加[batch_size, 17, 10]
        return x

余談2

余談1で言いましたCLSを使わない手法について検討します。

まずCLSトークンが存在しない世界線のViTを召喚します。

vit_noCLS.py
#省略
class ViT(nn.Module):
    def __init__(self):
        super(ViT, self).__init__()
        self.patchEmbedding = nn.Linear(patchVectorLen,embedVectorLen)
        # self.clsを削除
        self.positionEmbedding = nn.Parameter(torch.zeros(1, patchTotal, embedVectorLen)) #ポジション埋め込みの次元を[batch_size, 16, embedVectorLen(3x8x8/2)]に変更
        encoderLayer = TransformerEncoderLayer(
            d_model=embedVectorLen,
            nhead=head,
            dim_feedforward=dim_feedforward,
            activation=activation,
            batch_first=True,
            norm_first=True
        )
        self.transformerEncoder = TransformerEncoder(encoderLayer,layers)
        self.mlpHead=nn.Linear(embedVectorLen,10)
    def patchify(img):
        #省略
    def forward(self,x):
        # 省略
        #clsToken 削除
        #torch.cat((clsToken,x),dim=1) 削除
        x+=self.positionEmbedding
        x=self.transformerEncoder(x)
        x=self.mlpHead(x[:,0,:]) #ここに注目
        return x

正味これだけでも何とかなりはします。
mlpの入力を色々変えてみても良いですね。

例えば

vit_noCLS.py
class ViT(nn.Module):
    #省略
    def forward(self,x):
        # 省略
        x=self.transformerEncoder(x)
        x=self.mlpHead(x.mean(dim=1)) # 変更
        return x

とかでもよろしいかと。

あとはself.mlpHeadを単純な全結合じゃなくてLSTMとかに変更も面白そうです。

できた!!

というわけで全体像ぽん

vit.py
import torch
import torch.nn as nn
from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer

#image param
imageWH = 32
channel=3

#vit hyperparam
patchWH=8
splitRow=imageWH//8
splitCol=imageWH//8
patchTotal=(imageWH//patchWH)**2 #(32 / 8)^2 = 16
patchVectorLen=channel*(patchWH**2) #3 * 64 = 192
embedVectorLen=int(patchVectorLen/2)

#transformer layer hyperparam
head=12
dim_feedforward=embedVectorLen
activation="gelu"
layers=12

class ViT(nn.Module):
    def __init__(self):
        super(ViT, self).__init__()
        self.patchEmbedding = nn.Linear(patchVectorLen,embedVectorLen)
        self.cls = nn.Parameter(torch.zeros(1, 1, embedVectorLen))
        self.positionEmbedding = nn.Parameter(torch.zeros(1, patchTotal + 1, embedVectorLen))
        encoderLayer = TransformerEncoderLayer(
            d_model=embedVectorLen,
            nhead=head,
            dim_feedforward=dim_feedforward,
            activation=activation,
            batch_first=True,
            norm_first=True
        )
        self.transformerEncoder = TransformerEncoder(encoderLayer,layers)
        self.mlpHead=nn.Linear(embedVectorLen,10)

    def patchify(self,img):
        horizontal = torch.stack(torch.chunk(img,splitRow,dim=2),dim=1)
        patches = torch.cat(torch.chunk(horizontal,splitCol,dim=4),dim=1)
        return patches

    def forward(self,x):
        x=self.patchify(x)
        x=torch.flatten(x,start_dim=2)
        x=self.patchEmbedding(x)
        clsToken = self.cls.repeat_interleave(x.shape[0],dim=0)
        x=torch.cat((clsToken,x),dim=1)
        x+=self.positionEmbedding
        x=self.transformerEncoder(x)
        x=self.mlpHead(x[:,0,:])
        return x

テスト

cifar10でやってみます。比較対象として適当に作ったCNNも学習させます。
image.png

image.png

CNNもViTも特別な操作をしていないのですぐに頭打ちになっちゃいますね。

とはいえやっぱりCNNのほうがいい結果を出します。
それもそのはずデータを食わせば食わせるだけ精度が上がるのがトランスフォーマーです。
SOTAに届くには億単位のデータが必用です。我々には手の届かない世界ですね...

ということでどんでん返しになっちゃいますが、大人しく既存のライブラリと事前学習データを使うことをお勧めします。

終わりに

やっぱり自分で書いてみると勉強になりますね。

なんか腑抜けた感じに終わりましたが、ViTへの理解の助けになったら幸いです。

参考サイト

10
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?