LoginSignup
12
12

More than 1 year has passed since last update.

MLP-Mixerについての考察

Posted at

最近は深層モデルにおいてMLP-Mixerというものが流行っているらしい。
これはパッチとチャンネルの次元を入れ替えてMLPを行うというものである。ここでMLPとは単なる全結合を2回行うだけのものを指している。(MLP…。Multi-layer perceptron(多層パーセプトロン)…?CNNが流行る前の単純な全結合モデルを表す随分昔の名前じゃないか。百歩譲ってMLP-Mixerはそれでいいとして、後続モデルが全部ResMLPやらS2-MLPやら、MixerではなくMLPの呼称のみなのはどうなんだろう。最新ゲーム機全てに〇〇ファミコンって名前がついてるような違和感を感じる)

目的

ViTにおける「パッチ分割+全結合」が「strideを使った畳み込み」に等しいというのは以前示したことがあるがMLP-Mixerにおける「パッチ方向の全結合」(token-mixing MLPs)についてこれが一体何者なのか考える。

image.png

KerasでMLP-Mixerのモデルを考えてみる。これはChannels lastである。
最初にモデルの入力は$(batch,224,224,3)$である。
これをパッチ分割すると$(batch,14*14,16*16*3)=(batch,196,768)$
これをパッチ方向の全結合を考えた場合、重み行列の次元は$(196,196)$
これをチャンネル方向の全結合を考えた場合、重み行列の次元は$(768,768)$
パッチ方向の全結合(token-mixing MLPs)とチャンネル方向の全結合(channel-mixing MLPs)を交互に繰り返していくというのがMLP-Mixerである。
典型的なモデルは以下の様に書ける。skip-connectionと活性化関数geluについては省略した。

    # MLP-Mixer like model
    input = Input(shape=(224, 224, 3))
    x = Reshape((14,16,14,16,3))(input) #(224,224,3)=>(14,16,14,16,3)
    x = Permute((1,3,2,4,5))(x)         #(14,16,14,16,3)=>(14,14,16,16,3)
    x = Reshape((14*14,16*16*3))(x)     #(14,14,16,16,3)=>(14*14,16*16*3)=>(196,768)
    x = Dense(768, activation='relu')(x)
    for i in range(6):
        x = LayerNormalization()(x)
        x = Permute((2,1))(x)           #(196,768)=>(768,196)
        x = Dense(196, activation='relu')(x)
        x = Dense(196, activation='relu')(x)
        x = Permute((2,1))(x)           #(768,196)=>(196,768)
        x = LayerNormalization()(x)
        x = Dense(768, activation='relu')(x)
        x = Dense(768, activation='relu')(x)
    x = GlobalAveragePooling1D()(x)
    y = Dense(10, activation='softmax')(x)
    model = Model(input, y)

説①:Self-Attention重み

MLP-Mixerの前身であるViT(Vision Transformer)を考えた時、パッチ分割した後のデータ$(batch,196,768)$とすれば、Attention重みの次元は$(batch,196,196)$で与えられることになる。
さて、MLP-Mixerにおけるパッチ方向の全結合の重みは$(196,196)$であるから、これはViTのAttention重み$(batch,196,196)$と次元が似ている。

QK^T=(batch,196,768)*(batch,768,196)=AttentionWeight=(batch,196,196) \\
AttentionWeight*V=(batch,196,196)*(batch,196,768)=(batch,196,768)
PatchwiseFullyconnectWeight=(196,196)\\
V^T * PatchwiseFullyconnectWeight=(batch,768,196)*(196,196)=(batch,768,196)

実際にはViTのAttention重みは入力に依存する重み行列であり、MLP-Mixerにおけるパッチ方向の全結合の重みは入力に依存しない重みに過ぎない。しかし、MLP-MixerはViTからAttention計算が除かれているのを考えるとパッチ方向の全結合はViTのSelf-Attention重みの代替なのではないかという考えが浮かぶ。
MLP-MixerにおいてAttentionはさほど重要でないので消えたという意見もあるが、パッチ方向の全結合の重みがAttention重みの代替だと考えるなら、簡略化されたAttention構造はパッチ方向の全結合としてまだ残っているという考えを持つことも可能である。

説②:ShuffleNet

MNISTの(従来の意味での)MLPではflatten関数によって入力$(batch,28,28)$を$(batch,784)$次元にして全結合を計算していく。これの入出力が同じの全結合重みは$(784,784)$次元なのでパラメータ数はMLPにおいても高々知れている。
一方、入力$(batch,224,224,3)$のMLPではflatten関数によって$(batch,150528)$にして入出力が同じの全結合を考えると$(150528,150528)$の次元になり1パラメータ1バイトとしても$150528*150528=22,658,678,784=21.1GB$となってメモリに全結合の重みを保持できない。

そこで$(batch,196,768)$と分割して、パッチ方向の全結合の重みは$(196,196)$、チャンネル方向の全結合の重みは$(768,768)$と分割すればメモリの使用量を抑えることができる。全結合を行う次元にいわゆるグループ化して、異なるグループ内で全結合を行うことで全結合のサイズを小さくしてパラメータを節約する。

これに既視感がある。ShuffleNetである。
例えば$(batch,H,W,144)$を全結合(1x1Conv)すると全結合重みは$(144,144)$必要だが、$(batch,H,W,4,36)$と変形して全結合重み$(36,144)$で全結合計算(グループ数$4$)。次に次元を入れ替えて$(batch,H,W,36,4)$として全結合重み$(4,144)$で全結合計算(グループ数$36$)とした場合、必要となるパラメータ数が減る。
ShuffleNetにおけるチャンネルシャッフルは$(batch,H,W,4,36)$次元に並んでるデータを$(batch,H,W,36,4)$次元に並べ替える処理に等しく、グループ畳み込みはそれぞれの畳み込み(1x1の場合は全結合)の数をグループの数分の1に減らす。

image.png
さて、この次元を分割、全結合、次元をシャッフル、シャッフル後次元で全結合という手順はShuffleNetとMLP-Mixerの類似性を思い浮かべさせる。ShuffleNetはあくまでチャンネル方向の全結合の計算量を減らすためで、空間方向の全結合には進まなかったが、次元を落とした空間方向に全結合を行うのがMLP-Mixerと考えることができる。しかし、ShuffleNetのパラメータ減らし方が本当に万能であればもっと後続のモデルが出来たはずだが、…現実には後追いモデルは作られなかった。
チャンネルシャッフルが効果を発揮する状況は限られる、という事である。

説③:DepthwiseConv2D

ConvMixerではパッチ方向の全結合の代わりにDepthwiseConv2Dが使用される。
$(batch,224,224,3)$を$(batch,14,14,768)$にパッチ分割した状態を考える。
ここでカーネルサイズ$(14,14)$の入力チャンネル数$768$、出力チャンネル数$768$の通常のConv2Dを考えた場合、パラメータ数は$196*768*768$となる。
一方、カーネルサイズ$(14,14)$のDepthwiseConv2Dを考えた場合、パラメータ数は$196*768$となる。
カーネルサイズ$(9,9)$のDepthwiseConv2Dを考えた場合、パラメータ数は$81*768$となる。
MLP-Mixerのパッチ方向の全結合のパラメータ数は$196*196$である。
これはパッチ方向の全結合はDepthwiseConv2Dのパラメータ数とは次元が異なるが、Conv2DよりはDepthwiseConv2Dの方が近くなる。従ってこれをDepthwiseConv2Dで代替しようというのがConvMixerであると思われる。
image.png
ConvMixerではパッチ方向の全結合の代わりに、DepthwiseConv2Dを用いているが、これは処理的に等しくはないが、パラメータ数の大きさ的に同じぐらいなので代替として使っても問題ないのかもしれない。
image.png
ところでConvMixerは画像サイズ約1/8までパッチ分割で小さくしてから以降、DepthwiseConv2DとPointwiseConv2Dを交互に用いるモデルである事が示されている。さて、この構造を見て丁度思い出すモデルがEfficientNetV2である。
入力から1/8までstrideありのConv2Dと全結合(1x1畳み込み)のみで作られ、1/8サイズ以降はDepthwiseConv2DとPointwiseConv2Dの交互の繰り返しであるというのはConvMixerとEfficientNetV2の両方に言える事である(最初のパッチ分割がstride=7、カーネルサイズ=(7,7)の畳み込みと見なせる)。ConvMixerはEfficientNetV2を参考にはしていないのだろうが、結果的にはEfficientNetV2の構造に似ているように思う。逆に言えば、EfficientNetV2の「DepthwiseConv2D」を「パッチ方向の全結合」に置き換えればMLP-Mixerに近くなるということである。
image.png
image.png

説④:入力チャンネル1のConv2D

チャンネル方向の大きさを1とした行列$(196,768,1)$とした場合、パッチ方向の全結合を実装するConv2Dを考えると以下の様になる。この計算によって$(196,768,1)$が$(1,768,196)$となり、この時のパラメータ数は$196*196$である。

    x = Reshape((14*14,16*16*3,1))(x)   #(14,14,16,16,3)=>(196,768,1)
    x = Conv2D(196, (196,1), activation='relu')(x)    # (196,768,1)=>(1,768,196)
    x = Permute((3,2,1))(x)                           # (1,768,196)=>(196,768,1)
    x = Conv2D(196, (196,1), activation='relu')(x)    # (196,768,1)=>(1,768,196)

また、MLP-Mixerの論文中に..., and single-channel depth-wise convolutions of a full receptive field and parameter sharing for token mixing.と書かれているのだが、自分にはパッチ方向の全結合と等価なレイヤーを作成するのにわざわざDepthwiseConv2Dを用いる必要性が分からなかった。DepthwiseConv2Dならチャンネル数の変化はなく、入力チャンネルが1なら出力チャンネルも1となる筈で全結合を示すには多数のレイヤーの結合が必要になってしまう。
上に示すようにチャンネル1の入力ならばConv2D演算の方が全結合と等価なレイヤーになる筈である。

まとめ

「パッチ方向の全結合」(token-mixing MLPs)について、Self-Attention重み、チャンネルシャッフル、DepthwiseConv2Dなどの関連性について考察した。
MLP-Mixerの論文にそう書いてあるわけではなくあくまで自分の勝手な考察である。

参考

12
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
12