0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Sliding Tile Attentionについて

Posted at

以下の論文を読む。僅かな性能低下でHunyuanVideoのAttn高速化が述べられている。
論文にあるSliding Tile AttentionのAttention mapを描いてみたい。

3D Attentionの冗長性

Videoを作るHunyuanVideoモデルでは動画時間(生成フレーム)が長くなるとAttentionの計算時間がtoken数に対して2次で増えるのでAttentionの計算時間が支配的になる。例えば10kのtokenの計算時間を$100$とした時、115kの計算時間は$5400.4$となり、この5400.4のうちFFNが$43.2\cdot \frac{115}{10}=496.8$、QKVが$21.6\cdot \frac{115}{10}=248.4$、ATTNが$35.2\cdot (\frac{115}{10})^2=4655.2$となる。
ここで10kとはlatent次元で50x50x4(動画次元で400x400x17)、115kとはlatent次元で60x80x24(動画次元で480x640x97)であろう。

しかし、実のところ時間的や空間的に隣接するToken間でのAttention項は意味があるが、遠く離れたToken間でのAttentionを計算してもほとんど意味をなさない。
要するにText2Videoというタスクにおいてはこれらの遠く離れたAttention計算は計算してもその影響度はほとんどゼロに近くこれを省略しようが、性能の劣化はあまりないらしい。

image.png

token_numに対してFFNとQKVは一次関数的に増え、ATTNは二次関数的に増えるとし、10kの時の計算量を100とする。動画のサイズがある程度以上になるとATTN計算が支配的になる。

token_num FFN QKV ATTN 合計 latent_size videosize
10k 43.2
(43.2%)
21.6
(21.6%)
35.2
(35.2%)
100 50x50x4 400x400x17
15k 64.8
(36.7%)
32.4
(18.4%)
79.2
(44.9%)
176.4 50x75x4 400x600x17
19.2k 82.9
(32.6%)
41.5
(16.3%)
129.8
(51.1%)
254.2 60x80x4 480x640x17
20k 86.4
(32.0%)
43.2
(16.0%)
140.8
(52.1%)
270.4 50x50x8 400x400x33
38.4k 165.9
(21.6%)
82.9
(10.8%)
519.0
(67.6%)
767.9 60x80x8 480x640x33
57.6k 248.8
(16.1%)
124.4
(8.1%)
1167.9
(75.8%)
1541.1 60x80x12 480x640x49
76.8k 331.8
(12.9%)
165.9
(6.4%)
2076.2
(80.7%)
2573.8 60x80x16 480x640x65
115k 496.8
(9.2%)
248.4
(4.6%)
4655.2
(86.2%)
5400.4 60x80x24 480x640x97
153.6k 663.6
(7.1%)
331.8
(3.6%)
8304.7
(89.3%)
9300.0 60x80x32 480x640x129

NATTEN

このようなAttentionの冗長性を取り除くのにNeighborhood Attention(NATTEN)があり、以下のような図が知られている。しかし、この図を見ただけでは背景がよく分からないのでどうやってこの図を描いたのかがいま一つピンとこない。プログラム的にこの図をどうやって得ているのかを考えたい。

image.png

まず最初に画像サイズ(latentサイズ)が24x24とし、Attentionが有効となる距離を12x12とする。以下のようなプログラムを書けばNATTENのマップは書ける。

image.png

import numpy as np
import matplotlib.pyplot as plt

index = np.arange(24*24).reshape(24,24)
map = np.zeros((24*24,24*24))

for i in range(24*24):
    x, y = i%24, i//24
    x = np.clip(x, 6, 18)
    y = np.clip(y, 6, 18)
    window_index = index[y-6:y+6,x-6:x+6].flatten()
    map[i,window_index] = 1.0

u = np.arange(24*24)
v = np.arange(24*24)
X, Y = np.meshgrid(u, v)
plt.pcolormesh(X,Y,map)
plt.show()

Morton curve

次にFlexAttentionの図を見るとMorton Curveなる単語が見られる。
これは上述プログラムのindexカウントを改良するためのものだろう。現行のindexは画像端まで数えて折り返してるのでblock_window上のAttention範囲を考えると数字の範囲が飛びがちである。

image.png

参考までにspace-filling curveの資料を示しておくがZ-order_curve(Morton)は24x24をindex = np.arange(24*24).reshape(3,3,2,2,2,2,2,2).transpose(0,2,4,6,1,3,5,7).reshape(24,24)とするのに等しい。このように二次元空間的indexと空間距離を縮める並べ方の工夫について調べたが、別にTokenの並べ方で解決している訳ではなさそうである。(参考程度)

import numpy as np
index = np.arange(24*24).reshape(3,3,2,2,2,2,2,2).transpose(0,2,4,6,1,3,5,7).reshape(24,24)
print(index)
--------------------------------------------------
[[  0   1   4   5  16  17  20  21  64  65  68  69  80  81  84  85 128 129
  132 133 144 145 148 149]
 [  2   3   6   7  18  19  22  23  66  67  70  71  82  83  86  87 130 131
  134 135 146 147 150 151]
 [  8   9  12  13  24  25  28  29  72  73  76  77  88  89  92  93 136 137
  140 141 152 153 156 157]
 [ 10  11  14  15  26  27  30  31  74  75  78  79  90  91  94  95 138 139
  142 143 154 155 158 159]
 [ 32  33  36  37  48  49  52  53  96  97 100 101 112 113 116 117 160 161
  164 165 176 177 180 181]
 [ 34  35  38  39  50  51  54  55  98  99 102 103 114 115 118 119 162 163
  166 167 178 179 182 183]
 [ 40  41  44  45  56  57  60  61 104 105 108 109 120 121 124 125 168 169
  172 173 184 185 188 189]
 [ 42  43  46  47  58  59  62  63 106 107 110 111 122 123 126 127 170 171
  174 175 186 187 190 191]
  ...

このZ-order_curve(Morton curve)の時のAttention Maskを参考までに描けば、以下のようになる。

image.png

import numpy as np
import matplotlib.pyplot as plt

index = np.arange(24*24).reshape(3,3,2,2,2,2,2,2).transpose(0,2,4,6,1,3,5,7).reshape(24,24)
map = np.zeros((24*24,24*24))

for i in range(24*24):
    min_index = np.argmin(np.abs(index-i))
    x = min_index%24
    y = min_index//24
    x = np.clip(x, 6, 18)
    y = np.clip(y, 6, 18)
    window_index = index[y-6:y+6,x-6:x+6].flatten()
    map[i,window_index] = 1.0

u = np.arange(24*24)
v = np.arange(24*24)
X, Y = np.meshgrid(u, v)
plt.pcolormesh(X,Y,map)
plt.show()

Tiled NATTEN

とりあえずindexを4x4のタイルを考え、その4x4タイルを6x6で敷き詰めた図で12x12のblock_windowを考えたmapが以下である。この図がTiled NATTENと等しいのが確認できる。

image.png

import numpy as np
import matplotlib.pyplot as plt

index = np.arange(24*24).reshape(6,6,4,4).transpose(0,2,1,3).reshape(24,24)
map = np.zeros((24*24,24*24))

for i in range(24*24):
    block_index = i//16
    block_inner_index = i%16
    inner_x, inner_y = block_inner_index%4, block_inner_index//4
    block_x, block_y = 4*(block_index%6), 4*(block_index//6)
    x = block_x + inner_x
    y = block_y + inner_y
    x = np.clip(x, 6, 18)
    y = np.clip(y, 6, 18)
    window_index = index[y-6:y+6,x-6:x+6].flatten()
    map[i,window_index] = 1.0

u = np.arange(24*24)
v = np.arange(24*24)
X, Y = np.meshgrid(u, v)
plt.pcolormesh(X,Y,map)
plt.show()

STA(Sliding Tile Attention)

ここまで考えたらSTAが何なのかは何となく理解してきた。要するに前述のコードのinner_xとinner_yを無視すれば綺麗にAttention領域をblock状にして計算できる筈だ。(4以下のwindow幅への正確性はある程度減るが)。適当に4の倍数window取り出し範囲を-4~+8に変えれば以下のような図が描ける。

image.png

import numpy as np
import matplotlib.pyplot as plt

index = np.arange(24*24).reshape(6,6,4,4).transpose(0,2,1,3).reshape(24,24)
map = np.zeros((24*24,24*24))

for i in range(24*24):
    block_index = i//16
    block_x, block_y = 4*(block_index%6), 4*(block_index//6)
    x = block_x
    y = block_y
    x = np.clip(x, 4, 16)
    y = np.clip(y, 4, 16)
    window_index = index[y-4:y+8,x-4:x+8].flatten()
    map[i,window_index] = 1.0

u = np.arange(24*24)
v = np.arange(24*24)
X, Y = np.meshgrid(u, v)
plt.pcolormesh(X,Y,map)
plt.show()

多段STA

4x4のタイルが3x3に並んでいてこれが更に2x2で並んでいるとする。
周期的ではないが多少無視すれば簡単な図形に近似しやすく見える。後述のpatchfyはこれに近いのかと思う。
image.png

import numpy as np
import matplotlib.pyplot as plt

index = np.arange(24*24).reshape(2,2,3,3,4,4).transpose(0,2,4,1,3,5).reshape(24,24)
map = np.zeros((24*24,24*24))

for i in range(24*24):
    min_index = np.argmin(np.abs(index-i))
    x = min_index%24
    y = min_index//24
    x = 4*(x//4)
    y = 4*(y//4)
    x = np.clip(x, 4, 16)
    y = np.clip(y, 4, 16)
    window_index = index[y-4:y+8,x-4:x+8].flatten()
    map[i,window_index] = 1.0

u = np.arange(24*24)
v = np.arange(24*24)
X, Y = np.meshgrid(u, v)
plt.pcolormesh(X,Y,map)
plt.show()

Hilbert curve

Hilbert curveでAttention Mapを考えてみる。Hilbeltは$2^N$である必要があり、24x24を考えることが出来ないのでここでは32x32のindexの並び順を操作する。windowサイズを16とした時、このコードは以下のように書ける。STAのようにここでは4以下の空間座標を省略する。

image.png

import numpy as np
import matplotlib.pyplot as plt
from hilbertcurve.hilbertcurve import HilbertCurve

index = np.arange(32*32).reshape(32,32)
hilbert_curve = HilbertCurve(5, 2)
distances = list(range(32*32))
points = hilbert_curve.points_from_distances(distances)
for point, dist in zip(points, distances):
    index[point[0], point[1]] = int(dist)
print(index)
map = np.zeros((32*32,32*32))

for i in range(32*32):
    min_index = np.argmin(np.abs(index-i))
    x = min_index%32
    y = min_index//32
    x = 4*(x//4)
    y = 4*(y//4)
    x = np.clip(x, 8, 24)
    y = np.clip(y, 8, 24)
    window_index = index[y-8:y+8,x-8:x+8].flatten()
    map[i,window_index] = 1.0

u = np.arange(32*32)
v = np.arange(32*32)
X, Y = np.meshgrid(u, v)
plt.pcolormesh(X,Y,map)
plt.show()
-----------------------------------
[[   0    1   14 ...  339  340  341]
 [   3    2   13 ...  338  343  342]
 [   4    7    8 ...  349  344  345]
 ...
 [1019 1016 1015 ...  674  679  678]
 [1020 1021 1010 ...  685  680  681]
 [1023 1022 1009 ...  684  683  682]]

x = 4*(x//4),y = 4*(y//4)をコメントアウトすれば以下である。フラクタル的なAttentionMapが見れる。このHilbert curveは対角近傍のAttentionマップに近いので入力Tokenの適切な並べ替えのみで対角近傍計算でよい効率的なAttentionMapを生成する道もあるかもしれない。

image.png

image.png

patchfy(Lumina-video)

複雑なAttentionマップを考えずともLumina-videoではもっと簡単なpatchfyでAttentionの及ぶ長さを調整しているのかと思う。多分だけれども。

image.png
image.png

推論stepによる依存性

image.pngimage.png

Self-Attention(spatial,temporal Attention)の重要度は推論step(t)においても変動するのでこれも参考に設定を変えるべきなのかもしれない。

その他参考:

まとめ

動画生成モデルの論文を眺めた時、その生成のAttn計算を高速化する論文を幾つか見た。
FlashAttnetion2,3とかsdpa(Scaled Dot-Product Attention)とかsage attentionのAttention自体の高速化とは別口で、Attentionの計算領域を空間的や時間的に制限してそれを抑えるようだ。

その論文にある図が自分には直感的には理解できなかったので理解するためにAttntion Mapを自分で描いてみた。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?