LoginSignup
3
4

More than 3 years have passed since last update.

Non-Local(Self-Attention)の亜種

Posted at

目的

Non-Local(Self-Attention)の亜種についてどのような立場から軽量化しているかいくつかまとめてみる。
適当に論文から図を引っ張ってくるだけなので内容の正確性に関しては保証しません。

1.Self-Attention

Q Q Q K K K V V V
header row column header row column header row column
Self-attention B len emb B len emb B len emb
multi-head Self-attention B×h len emb/h B×h len emb/h B×h len emb/h
QK QK QK QKV QKV QKV
header row column header row column
Self-attention B len len B len emb
multi-head Self-attention B×h len len B×h len emb/h

自然言語系のデータ並びを$(B,len,emb)$とすれば、$Q=(B,len,emb),K^T=(B,emb,len)$であるからAttentionマップ$(QK^T)$は$(B,len,len)$の形で、各単語間の相関を示す。これに$V=(B,len,emb)$に掛ければ$QK^TV$は元の$(B,len,emb)$と同じ形になる。
一般に自然言語系において文章の長さ$len$は高々知れており$(len<emb)$、計算量は大したことない。
multi-head Self-attentionはheaderで括る1header当たりの計算量は1/hになる代わりに、headerの数はh倍になるのでAttention計算のmulti-headによる演算総量は変わりない。headerで括る方が結果が良くなるのが知られている。

2.Non-local

Q Q Q K K K V V V
header row column header row column header row column
Non-local B H×W C B H×W C B H×W C
multi-head Non-local B×h H×W C/h B×h H×W C/h B×h H×W C/h
QK QK QK QKV QKV QKV
header row column header row column
Non-local B H×W H×W B H×W C
multi-head Non-local B×h H×W H×W B×h H×W C/h

Non-localではデータを$(B,H,W,C)$とすれば、Attentionマップ$(QK^T)$は$(B,H×W,H×W)$の形で空間相関を示す。これに$V$に掛ければ$QK^TV$は元の$(B,H,W,C)$と同じ形になる。
しかし、Self-Attentionと比較した場合、画像のサイズが大きいと$(B,H×W,H×W)$が非常に大きな値になる問題がある。
この為、U-netの谷位置などの画像サイズが小さい場所$(H×W< C)$にしかNon-localを挿入することができない。画像サイズが大きい場所では$(H×W>>C)$Non-localのAttentionマップが大きすぎる。
Non-local Neural Networks

3.Q,Kのチャンネル数を減らす

Q Q Q K K K V V V
header row column header row column header row column
Non-local B H×W C B H×W C B H×W C
reduce channel Non-local B H×W C/g B H×W C/g B H×W C
QK QK QK QKV QKV QKV
header row column header row column
Non-local B H×W H×W B H×W C
reduce channel Non-local B H×W H×W B×h H×W C

最も簡単な計算量を減らす処理は全結合処理において$Q,K$のチャンネル数を減らす処理である。$QK^T$を計算する場合、$Q$と$K$のチャンネル数は等しい必要があるので等しく減らす必要がある。
また、$V$のチャンネル数は減らす必要がない。

4.KVを先に計算する(H,WとCを入れ替える)

Q Q Q K K K V V V
header row column header row column header row column
channel non-local B H×W C B H×W C B H×W C
Double Attention B H×W N B H×W N B H×W M
KV KV KV QKV QKV QKV
header row column header row column
channel non-local B C C B H×W C
Double Attention B N M B H×W M

Non-localでは$QK^T$を計算して、$softmax(QK^T)$を求め、$softmax(QK^T)V$を計算する。
しかし、スケーリングの$softmax$を$1/N$などで代替できるなら$QK^TV$の演算は先に$K^TV$を求めて、これに$Q$を掛けて$QK^TV$と計算してもよい。
以下の論文はnon-localの$QK^T$にsoftmaxを掛けないので、演算の順番を変えることができると考えられる。$(H×W>>C)$でも計算量は多くならない。
また、これは$Q,K,V$の$C$と$H×W$を入れ替えたと見なしても等価である。

Efficient Attention
image.png

Dual Attention
image.png

Double Attention
image.png
image.png

5.HかWのみのAttentionを得る

Q Q Q K K K V V V
header row column header row column header row column
CrissCrossAttention1 B×W H C/8 B×W H C/8 B×W H C
CrissCrossAttention2 B×H W C/8 B×H W C/8 B×H W C
AxialAttention1 B×W×g H C/g/2 B×W×g H C/g/2 B×W×g H C/g
AxialAttention2 B×H×g W C/g/2 B×H×g W C/g/2 B×H×g W C/g
QK QK QK QKV QKV QKV
header row column header row column
CrissCrossAttention1 B×W H H B×W H C
CrissCrossAttention2 B×H W W B×H W C
AxialAttention1 B×W×g H H B×W×g H C/g
AxialAttention2 B×H×g W W B×H×g W C/g

CrissCrossAttentionとAxialAttentionは非常によく似ている。
つまりH×W=>HかWにして二通りのnon-localを求めている。
この場合のAttentionマップは$(B×W,H,H)$と$(B×H,W,W)$である。
違いは二個のAttentionの掛け方であり、CrissCrossAttentionは二個のAttentionを並列に足し、AxialAttentionは二個のAttentionを直列に掛ける。
CrissCrossAttention
AxialAttention

自分のイメージ:
image.png

6.K,Vに同じpooling関数を掛ける

Asymmetric Non-localは$K,V$にpyramid poolingというDown sampleが入っている。例えば出力形状を1×1にするのはglobal poolingである。おそらくglobal poolingだけだと空間相関を伝えるには情報量が少なすぎるので、複数のpooling処理を使ったNon-localが考えられる。
Asymmetric Non-localは以下のAsymmetric Non-local1~4の和である。

Q Q Q K K K V V V
header row column header row column header row column
Asymmetric Non-local1 B H×W C B 1×1 C B 1×1 C
Asymmetric Non-local2 B H×W C B 3×3 C B 3×3 C
Asymmetric Non-local3 B H×W C B 6×6 C B 6×6 C
Asymmetric Non-local4 B H×W C B 8×8 C B 8×8 C
QK QK QK QKV QKV QKV
header row column header row column
Asymmetric Non-local1 B H×W 1×1 B H×W C
Asymmetric Non-local2 B H×W 3×3 B H×W C
Asymmetric Non-local3 B H×W 6×6 B H×W C
Asymmetric Non-local4 B H×W 8×8 B H×W C

Asymmetric Non-local
image.png
image.png

7.1.Self-Attention (restricted)

TransformerではAttentionマップは$(len,len)$であるが、$len$が長くなってくると実際にはそんなに長距離の単語の相関を求める必要があるか?という疑問が生まれる。あまりに離れた単語間の相関は通常は非常に薄いからである。文章の単語の相関を単語から$k$長さに制限するならAttentionマップは$(len,k)$次元でもいいのではという発想が生じる。
image.png
"Attention Is All You Need"にも以下のテーブルがある。
Self-Attention (restricted)は計算する相関距離を制限したものと考えられる。
image.png
(ただ、このテーブルからSelf-Attention (restricted)がConvolutionより優れていると決めつけることは出来ない。何故ならDepthwiseConvは$Ο(k\cdot n \cdot d)$であるからである)

7.2.Unfold関数を使う

Unfold関数(im2col関数)を$(B,H,W,C1)$に使うとフィルターサイズが$k=3$なら出力は$(B,H,W,9×C1)$となる。Unfold関数自体は単なる処理関数(サンプリング関数)で重みの定義は必要ない。これに全結合(1×1畳み込み)掛ければ$(9×C1,C2)$の重み行列を掛けて出力は$(B,H,W,C2)$となる。
これは$3×3$の畳み込みフィルターが$(C1,C2)$個あるのと等しい。
すなわちUnfold関数(im2col関数)+全結合はConv2D関数と等価であることが分かる。
(ちなみにUnfold関数はpytorchでのim2col関数の呼び方。tensorflowではtf.image.extract_image_patches関数)

このUnfold関数をNon-localに使用する。
入力$(B,H,W,C)$にフィルターサイズ$k$のUnfold関数を掛けた場合、出力は$(B,H,W,k×k×C)$という形状になる。$K,V$の形状を変形させてNon-localを考えるとこれは、$K,V$の$H×W$が$k×k$と制限されるNon-localを考えるのと等しくなるだろう。

Q Q Q K K K V V V
header row column header row column header row column
Non-local B H×W C B H×W C B H×W C
Non-Local Recurrent Network B×H×W 1 C/m B×H×W k×k C/m B×H×W k×k C
QK QK QK QKV QKV QKV
header row column header row column
Non-local B H×W H×W B H×W C
Non-Local Recurrent Network B×H×W 1 k×k B×H×W 1 C

Non-Local Recurrent Network
image.png
Local Relation Networks
image.png

8.Aggregation

Aggregationの定義
実装の仕様を見た場合、入力$(B,C,H,W)$、重み$(B,1,m,k×k,H×W)$とした時、
Unfold関数で入力$(B,C,H,W)$を$(B,C/m,m,k×k,H×W)$に変換して、重み$(B,1,m,k×k,H×W)$を掛けて$k×k$の軸を足して出力は$(B,C,H,W)$となる。
ここで重み$(B,1,m,k×k,H×W)$は一般に$QK^T$の結果であるから以下の変換になる。

Aggregation(x,w)=w*Unfold(x)=(QK^T)*Unfold(V)

また、$K$にもUnfold関数は掛かっているから、このようなNon-localは$K,V$に同じUnfold関数を掛けるパターンと見なせる。

    x = torch.randn(n, c_x, in_height, in_width, requires_grad=True).double().cuda()
    w = torch.randn(n, c_w, pow(kernel_size, 2), out_height * out_width, requires_grad=True).double().cuda()

    y1 = aggregation_zeropad(x, w, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
    unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)
    x2 = unfold_j(x).view(n, c_x // c_w, c_w, pow(kernel_size, 2), out_height * out_width)
    y2 = (w.unsqueeze(1) * x2).sum(-2).view(n, c_x, out_height, out_width)
    assert (y1 - y2).abs().max() < 1e-9

image.png
Exploring Self-attention for Image Recognition

9.パッチ分割(Vision Transformer)

Vision Transformerはパッチ分割という処理が入る。
これは$(B,H,W,C)$を$(B,H/p,p,W/p,p,C)$から$(B,H×W/p/p,C×p×p)$とするReshapeとtransposeの恒等変換と等しい。
また、パッチ分割はフィルターサイズとストライドサイズが$p$であるUnfold関数を掛ける処理とも等しい。Vision Transformerでは一番最初にパッチ分割が行われ、途中のNon-local処理に対して入らない。
Unfold関数+全結合がConv2D関数と等価であることを思い出せば、一番最初のパッチ分割+全結合はフィルターサイズ、ストライドサイズが大きいConv2Dが掛けられるのと等しいようにも考えられる。

Q Q Q K K K V V V
header row column header row column header row column
Non-local B H×W C B H×W C B H×W C
ViT B×h H×W/p/p C×p×p/h B×h H×W/p/p C×p×p/h B×h H×W/p/p C×p×p/h
QK QK QK QKV QKV QKV
header row column header row column
Non-local B H×W H×W B H×W C
ViT B×h H×W/p/p H×W/p/p B×h H×W/p/p C×p×p/h

Vision Transformer
image.png

10.Simplified NL

Simplified NLはシンプルなNon-localであるが、これがNon-localの一種であるのかはよく分からない。
relationが$concat[q,k]$な場合か、後述の11.position embeddingsに近いのかと解釈できないかとも考えてみたが微妙に違ってる気がする為、Simplified NLがどういう立ち位置なのかうまく表現できない。
image.png
GCNet

11.位置の埋め込み(position embeddings)

Transformerでは位置情報はsin,cos的なembeddings方法だったが、画像系においては相対位置情報が良く使われる。
これの扱いは論文によって異なる。以下の論文によれば$q^T r$だったり$\mu ^T k$だったり$E_{nm}V^T$だったりする。
いずれにしても$QK^T$(lambda layerでは$KV^T$)と並列に足し合わせる重みであるという事は共通している。

Stand-Alone Self-Attention in Vision Models
Disentangled Non-Local
Lambda Network
image.png
image.png
image.png

Local Relation NetworksやAxialAttention、Patchwise SANでは$q^T r$ではなく更に異なるが、それについては以下の記事が詳しい。
画像認識でもConvolutionの代わりにAttentionが使われ始めたので、論文まとめ
Local Relation Networksでは$(\Phi (q,k)+r)v$、
AxialAttentionでは$(qk^T + qr_k^T + r_qk^T)(v+r_v)$となる。

一方でViTでは最初の入力にposition embeddingを行う代わりにNon-local計算部分では$q^Tr$を考えない。これは$q,k,v$それぞれにposition embeddingを行う結果の$((q+r)^T(k+r))(v+r)$を包含していると自分は考えている。
従来の$q^Tr$は$k$のみposition embeddingを行う結果の$(q^T(k+r))v$と考えられる。

12.Lambda Network

Lambda NetworkではNon-Local重み$\lambda c$は$K^TV$で、これと並列に$\lambda p$は$Conv2D(V)$または$EV^T$を足している。
後者の場合はposition embeddingsの項がNon-Localと並列に挟まるのは11.で説明したが、$EV^T$を$Conv2D(V)$に置き換えて計算できる理由は理解していない。
Lambda Network

class LambdaLayer(nn.Module):
    def __init__(self, dim, *, dim_k, n=None, r=None, heads=4, dim_out=None, dim_u=1):
        ...
        if self.local_contexts:
            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2))
        else:
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))

    def forward(self, x):
        ...
        λc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b n h v', q, λc)

        if self.local_contexts:
            v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)
            λp = self.pos_conv(v)
            Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
        else:
            λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b n h v', q, λp)
        Y = Yc + Yp

https://gist.github.com/PistonY/ad33ab9e3d5f9a6a38345eb184e68cb4
image.png

13.AutoNL

4.で示したようにsoftmaxが掛からない場合は$QK^TV$を$(QK^T)$に$V$を掛けても$(K^TV)$に$Q$を計算してもよい。計算の順によって計算量が変わるがどちらの方が小さくなるかは条件による。

AutoNL
image.png
入力が$(B,H×W,C)$の場合、
$(QK^T)V$は演算量が$Ο(2*(H×W)^2C)$
$Q(K^TV)$は演算量が$Ο(2*H×W(C)^2)$である。

$Q,K$のチャンネル数を$C/g$に削減し、$stride=s$によって$K,V$の画像サイズを小さくすれば
$Q=(B,H×W,C/g),K=(B,H×W/s/s,C/g),V=(B,H×W/s/s,C)$の場合

$(QK^T)V$は演算量が$Ο((H×W)(H×W/s/s)C/g+(H×W)(H×W/s/s)C)$
$Q(K^TV)$は演算量が$Ο((H×W/s/s)(C)(C/g)+(H×W)(C)(C/g))$である。
AutoNLは演算量によって$(QK^T)V$か$Q(K^TV)$のどちらから計算するか選択する。

    if (H * W) * reduced_HW * n_in * (1 + nl_ratio) < (
            H * W) * n_in ** 2 * nl_ratio + reduced_HW * n_in ** 2 * nl_ratio or softmax:
        f = tf.einsum('nabi,ncdi->nabcd', theta, phi)
        f = tf.einsum('nabcd,ncdi->nabi', f, g)
        macs = (H * W) * reduced_HW * n_in * (1 + nl_ratio)
    else:
        f = tf.einsum('nhwi,nhwj->nij', phi, g)
        f = tf.einsum('nij,nhwi->nhwj', f, theta)
        macs = (H * W) * n_in ** 2 * nl_ratio + reduced_HW * n_in ** 2 * nl_ratio

まとめ:

Non-local(Self-Attention)の亜種の整理をした。これら以外にも多く存在する。
Transformerが自然言語方面で成功し、Vision Transformerが当時のSoTAを更新し、LambdaNetはEfficientNetよりも高速だったことからNon-local(Self-Attention)が注目された。

3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4