はじめに
しばしばVision Transformerの解説でVision Transformerには全く畳み込みが使われていなくて凄いんだというような解説を見かける事があります。一方、自分としては「パッチ分割+全結合」は「畳み込み」に等しいんだから、**Vision Transformerも実質畳み込みを使っているように考えられるので、別にそこまで凄いわけではなくない?**と思う事があります。
自分の過去記事などで、その事はそれとなく書いたりしていたわけですが、あんまり自分の主張が伝わっている気がしないので記事に起こしてみます。
畳み込みの分割
畳み込みは主に二つの処理、im2col関数(pytorchではUnfold関数、tensorflowではextract_image_patches関数)と全結合(pytorchではnn.Linear、tensorflowではDense)に分割する事が出来ます。
例えば3x3畳み込みを考えた場合、im2col関数は$(Batch,H,W,C_{in})$を$(Batch,H,W,9×C_{in})$に変換する関数です。これは単なる処理関数で保持する重みなどはありません。これに全結合として$(9×C_{in}, C_{out})$の行列を掛けると$(Batch,H,W,C_{out})$が出力になります。
Vision Transformerで用いられるパッチ分割はカーネルサイズとストライドサイズが等しいim2col関数と見なすことができますからカーネルサイズとストライドサイズが$16$とした場合、$(Batch,H,W,C_{in})$にim2col関数を掛けると$(Batch,H/16,W/16,256×C_{in})$になり、これに全結合として$(256×C_{in}, C_{out})$を掛けると$(Batch,H/16,W/16,C_{out})$が出力になります。
これはカーネルサイズとストライドサイズが等しい畳み込み関数と等しいと見なすことができます。また、このカーネルサイズとストライドサイズが等しいパッチ分割はデータの大きさは変わらないのでデータのReshapeとTransposeのみでも表すことができます。
「畳み込み」と「パッチ分割+全結合」の比較
以下の様に3通りのモデルを作ってみます。
1番目のモデルはカーネルサイズとストライドサイズ$16$のConv2Dの畳み込みを行うモデル
2番目のモデルはパッチ分割のim2col関数を掛け、全結合を行うモデル
3番目のモデルはパッチ分割を形状変形と軸入れ替えで行い、全結合を行うモデル
作成したモデルの畳み込みの重みを全結合の重み変換して移しておきます。
from tensorflow.keras import Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Reshape, Permute
import tensorflow as tf
import numpy as np
def make_model():
#1. Conv2D model
input = Input(shape=(224, 224, 3))
x1 = Conv2D(32, (16,16), strides=(16,16), use_bias=False)(input)
model1 = Model(input, x1)
#2. im2col model
x2 = tf.compat.v1.extract_image_patches(input,
ksizes=(1,16,16,1), strides=(1,16,16,1), rates=(1,1,1,1), padding="VALID")
x2 = Dense(32, use_bias=False)(x2)
model2 = Model(input, x2)
#3. reshape & transpose model
x3 = Reshape((14,16,14,16,3))(input) #(224,224,3)=>(14,16,14,16,3)
x3 = Permute((1,3,2,4,5))(x3) #(14,16,14,16,3)=>(14,14,16,16,3)
x3 = Reshape((14,14,16*16*3))(x3) #(14,14,16,16,3)=>(14,14,16*16*3)
x3 = Dense(32, use_bias=False)(x3)
model3 = Model(input, x3)
return model1, model2, model3
model1, model2, model3 = make_model()
model1_weight = model1.layers[1].weights[0]
model2_weight = model2.layers[2].weights[0]
model3_weight = model3.layers[4].weights[0]
print("model1_weight.shape=", model1_weight.shape)
print("model2_weight.shape=", model2_weight.shape)
print("model3_weight.shape=", model3_weight.shape)
model1_weight = np.reshape(model1_weight, (768, 32))
model2.layers[2].set_weights([model1_weight])
model3.layers[4].set_weights([model1_weight])
x_train = np.random.randn(10, 224, 224, 3)
print(model1.predict(x_train)[0,1,1,:6])
print(model2.predict(x_train)[0,1,1,:6])
print(model3.predict(x_train)[0,1,1,:6])
上記を実行した結果は下記のようになりました。
モデルの出力値が等しい事から「畳み込み」と「パッチ分割+全結合」が等価であることを示せました。また、パッチ分割はim2col関数でも実装できるし、ReshapeとTransposeでも代替できることが分かります。
model1_weight.shape= (16, 16, 3, 32)
model2_weight.shape= (768, 32)
model3_weight.shape= (768, 32)
[-0.6510169 0.38975465 0.05311207 -0.4468006 -0.07156172 0.6943207 ]
[-0.6510169 0.38975465 0.05311207 -0.4468006 -0.07156172 0.6943207 ]
[-0.6510169 0.38975465 0.05311207 -0.4468006 -0.07156172 0.6943207 ]
まとめ
Vision Transformerの「パッチ分割+全結合」は「畳み込み」と等価であることを示した。
例えばViTの実装例では一番最初にConv2Dが使われている例が見受けられるが、これは上記が等価である事を利用したものである。
ViTのパッチ分割部分はカーネルサイズの大きい畳み込みを掛けていると見なせる。