以下の論文を読む。僅かな性能低下でHunyuanVideoのAttn高速化が述べられている。
論文にあるSliding Tile AttentionのAttention mapを描いてみたい。
3D Attentionの冗長性
Videoを作るHunyuanVideoモデルでは動画時間(生成フレーム)が長くなるとAttentionの計算時間がtoken数に対して2次で増えるのでAttentionの計算時間が支配的になる。例えば10kのtokenの計算時間を$100$とした時、115kの計算時間は$5400.4$となり、この5400.4のうちFFNが$43.2\cdot \frac{115}{10}=496.8$、QKVが$21.6\cdot \frac{115}{10}=248.4$、ATTNが$35.2\cdot (\frac{115}{10})^2=4655.2$となる。
ここで10kとはlatent次元で50x50x4(動画次元で400x400x17)、115kとはlatent次元で60x80x24(動画次元で480x640x97)であろう。
しかし、実のところ時間的や空間的に隣接するToken間でのAttention項は意味があるが、遠く離れたToken間でのAttentionを計算してもほとんど意味をなさない。
要するにText2Videoというタスクにおいてはこれらの遠く離れたAttention計算は計算してもその影響度はほとんどゼロに近くこれを省略しようが、性能の劣化はあまりないらしい。
token_numに対してFFNとQKVは一次関数的に増え、ATTNは二次関数的に増えるとし、10kの時の計算量を100とする。動画のサイズがある程度以上になるとATTN計算が支配的になる。
token_num | FFN | QKV | ATTN | 合計 | latent_size | videosize |
---|---|---|---|---|---|---|
10k | 43.2 (43.2%) |
21.6 (21.6%) |
35.2 (35.2%) |
100 | 50x50x4 | 400x400x17 |
15k | 64.8 (36.7%) |
32.4 (18.4%) |
79.2 (44.9%) |
176.4 | 50x75x4 | 400x600x17 |
19.2k | 82.9 (32.6%) |
41.5 (16.3%) |
129.8 (51.1%) |
254.2 | 60x80x4 | 480x640x17 |
20k | 86.4 (32.0%) |
43.2 (16.0%) |
140.8 (52.1%) |
270.4 | 50x50x8 | 400x400x33 |
38.4k | 165.9 (21.6%) |
82.9 (10.8%) |
519.0 (67.6%) |
767.9 | 60x80x8 | 480x640x33 |
57.6k | 248.8 (16.1%) |
124.4 (8.1%) |
1167.9 (75.8%) |
1541.1 | 60x80x12 | 480x640x49 |
76.8k | 331.8 (12.9%) |
165.9 (6.4%) |
2076.2 (80.7%) |
2573.8 | 60x80x16 | 480x640x65 |
115k | 496.8 (9.2%) |
248.4 (4.6%) |
4655.2 (86.2%) |
5400.4 | 60x80x24 | 480x640x97 |
153.6k | 663.6 (7.1%) |
331.8 (3.6%) |
8304.7 (89.3%) |
9300.0 | 60x80x32 | 480x640x129 |
NATTEN
このようなAttentionの冗長性を取り除くのにNeighborhood Attention(NATTEN)があり、以下のような図が知られている。しかし、この図を見ただけでは背景がよく分からないのでどうやってこの図を描いたのかがいま一つピンとこない。プログラム的にこの図をどうやって得ているのかを考えたい。
まず最初に画像サイズ(latentサイズ)が24x24とし、Attentionが有効となる距離を12x12とする。以下のようなプログラムを書けばNATTENのマップは書ける。
import numpy as np
import matplotlib.pyplot as plt
index = np.arange(24*24).reshape(24,24)
map = np.zeros((24*24,24*24))
for i in range(24*24):
x, y = i%24, i//24
x = np.clip(x, 6, 18)
y = np.clip(y, 6, 18)
window_index = index[y-6:y+6,x-6:x+6].flatten()
map[i,window_index] = 1.0
u = np.arange(24*24)
v = np.arange(24*24)
X, Y = np.meshgrid(u, v)
plt.pcolormesh(X,Y,map)
plt.show()
Morton curve
次にFlexAttentionの図を見るとMorton Curveなる単語が見られる。
これは上述プログラムのindexカウントを改良するためのものだろう。現行のindexは画像端まで数えて折り返してるのでblock_window上のAttention範囲を考えると数字の範囲が飛びがちである。
参考までにspace-filling curveの資料を示しておくがZ-order_curve(Morton)は24x24をindex = np.arange(24*24).reshape(3,3,2,2,2,2,2,2).transpose(0,2,4,6,1,3,5,7).reshape(24,24)
とするのに等しい。このように二次元空間的indexと空間距離を縮める並べ方の工夫について調べたが、別にTokenの並べ方で解決している訳ではなさそうである。(参考程度)
import numpy as np
index = np.arange(24*24).reshape(3,3,2,2,2,2,2,2).transpose(0,2,4,6,1,3,5,7).reshape(24,24)
print(index)
--------------------------------------------------
[[ 0 1 4 5 16 17 20 21 64 65 68 69 80 81 84 85 128 129
132 133 144 145 148 149]
[ 2 3 6 7 18 19 22 23 66 67 70 71 82 83 86 87 130 131
134 135 146 147 150 151]
[ 8 9 12 13 24 25 28 29 72 73 76 77 88 89 92 93 136 137
140 141 152 153 156 157]
[ 10 11 14 15 26 27 30 31 74 75 78 79 90 91 94 95 138 139
142 143 154 155 158 159]
[ 32 33 36 37 48 49 52 53 96 97 100 101 112 113 116 117 160 161
164 165 176 177 180 181]
[ 34 35 38 39 50 51 54 55 98 99 102 103 114 115 118 119 162 163
166 167 178 179 182 183]
[ 40 41 44 45 56 57 60 61 104 105 108 109 120 121 124 125 168 169
172 173 184 185 188 189]
[ 42 43 46 47 58 59 62 63 106 107 110 111 122 123 126 127 170 171
174 175 186 187 190 191]
...
このZ-order_curve(Morton curve)の時のAttention Maskを参考までに描けば、以下のようになる。
import numpy as np
import matplotlib.pyplot as plt
index = np.arange(24*24).reshape(3,3,2,2,2,2,2,2).transpose(0,2,4,6,1,3,5,7).reshape(24,24)
map = np.zeros((24*24,24*24))
for i in range(24*24):
min_index = np.argmin(np.abs(index-i))
x = min_index%24
y = min_index//24
x = np.clip(x, 6, 18)
y = np.clip(y, 6, 18)
window_index = index[y-6:y+6,x-6:x+6].flatten()
map[i,window_index] = 1.0
u = np.arange(24*24)
v = np.arange(24*24)
X, Y = np.meshgrid(u, v)
plt.pcolormesh(X,Y,map)
plt.show()
Tiled NATTEN
とりあえずindexを4x4のタイルを考え、その4x4タイルを6x6で敷き詰めた図で12x12のblock_windowを考えたmapが以下である。この図がTiled NATTENと等しいのが確認できる。
import numpy as np
import matplotlib.pyplot as plt
index = np.arange(24*24).reshape(6,6,4,4).transpose(0,2,1,3).reshape(24,24)
map = np.zeros((24*24,24*24))
for i in range(24*24):
block_index = i//16
block_inner_index = i%16
inner_x, inner_y = block_inner_index%4, block_inner_index//4
block_x, block_y = 4*(block_index%6), 4*(block_index//6)
x = block_x + inner_x
y = block_y + inner_y
x = np.clip(x, 6, 18)
y = np.clip(y, 6, 18)
window_index = index[y-6:y+6,x-6:x+6].flatten()
map[i,window_index] = 1.0
u = np.arange(24*24)
v = np.arange(24*24)
X, Y = np.meshgrid(u, v)
plt.pcolormesh(X,Y,map)
plt.show()
STA(Sliding Tile Attention)
ここまで考えたらSTAが何なのかは何となく理解してきた。要するに前述のコードのinner_xとinner_yを無視すれば綺麗にAttention領域をblock状にして計算できる筈だ。(4以下のwindow幅への正確性はある程度減るが)。適当に4の倍数window取り出し範囲を-4~+8に変えれば以下のような図が描ける。
import numpy as np
import matplotlib.pyplot as plt
index = np.arange(24*24).reshape(6,6,4,4).transpose(0,2,1,3).reshape(24,24)
map = np.zeros((24*24,24*24))
for i in range(24*24):
block_index = i//16
block_x, block_y = 4*(block_index%6), 4*(block_index//6)
x = block_x
y = block_y
x = np.clip(x, 4, 16)
y = np.clip(y, 4, 16)
window_index = index[y-4:y+8,x-4:x+8].flatten()
map[i,window_index] = 1.0
u = np.arange(24*24)
v = np.arange(24*24)
X, Y = np.meshgrid(u, v)
plt.pcolormesh(X,Y,map)
plt.show()
多段STA
4x4のタイルが3x3に並んでいてこれが更に2x2で並んでいるとする。
周期的ではないが多少無視すれば簡単な図形に近似しやすく見える。後述のpatchfyはこれに近いのかと思う。
import numpy as np
import matplotlib.pyplot as plt
index = np.arange(24*24).reshape(2,2,3,3,4,4).transpose(0,2,4,1,3,5).reshape(24,24)
map = np.zeros((24*24,24*24))
for i in range(24*24):
min_index = np.argmin(np.abs(index-i))
x = min_index%24
y = min_index//24
x = 4*(x//4)
y = 4*(y//4)
x = np.clip(x, 4, 16)
y = np.clip(y, 4, 16)
window_index = index[y-4:y+8,x-4:x+8].flatten()
map[i,window_index] = 1.0
u = np.arange(24*24)
v = np.arange(24*24)
X, Y = np.meshgrid(u, v)
plt.pcolormesh(X,Y,map)
plt.show()
Hilbert curve
Hilbert curveでAttention Mapを考えてみる。Hilbeltは$2^N$である必要があり、24x24を考えることが出来ないのでここでは32x32のindexの並び順を操作する。windowサイズを16とした時、このコードは以下のように書ける。STAのようにここでは4以下の空間座標を省略する。
import numpy as np
import matplotlib.pyplot as plt
from hilbertcurve.hilbertcurve import HilbertCurve
index = np.arange(32*32).reshape(32,32)
hilbert_curve = HilbertCurve(5, 2)
distances = list(range(32*32))
points = hilbert_curve.points_from_distances(distances)
for point, dist in zip(points, distances):
index[point[0], point[1]] = int(dist)
print(index)
map = np.zeros((32*32,32*32))
for i in range(32*32):
min_index = np.argmin(np.abs(index-i))
x = min_index%32
y = min_index//32
x = 4*(x//4)
y = 4*(y//4)
x = np.clip(x, 8, 24)
y = np.clip(y, 8, 24)
window_index = index[y-8:y+8,x-8:x+8].flatten()
map[i,window_index] = 1.0
u = np.arange(32*32)
v = np.arange(32*32)
X, Y = np.meshgrid(u, v)
plt.pcolormesh(X,Y,map)
plt.show()
-----------------------------------
[[ 0 1 14 ... 339 340 341]
[ 3 2 13 ... 338 343 342]
[ 4 7 8 ... 349 344 345]
...
[1019 1016 1015 ... 674 679 678]
[1020 1021 1010 ... 685 680 681]
[1023 1022 1009 ... 684 683 682]]
x = 4*(x//4),y = 4*(y//4)
をコメントアウトすれば以下である。フラクタル的なAttentionMapが見れる。このHilbert curveは対角近傍のAttentionマップに近いので入力Tokenの適切な並べ替えのみで対角近傍計算でよい効率的なAttentionMapを生成する道もあるかもしれない。
patchfy(Lumina-video)
複雑なAttentionマップを考えずともLumina-videoではもっと簡単なpatchfyでAttentionの及ぶ長さを調整しているのかと思う。多分だけれども。
推論stepによる依存性
Self-Attention(spatial,temporal Attention)の重要度は推論step(t)においても変動するのでこれも参考に設定を変えるべきなのかもしれない。
その他参考:
まとめ
動画生成モデルの論文を眺めた時、その生成のAttn計算を高速化する論文を幾つか見た。
FlashAttnetion2,3とかsdpa(Scaled Dot-Product Attention)とかsage attentionのAttention自体の高速化とは別口で、Attentionの計算領域を空間的や時間的に制限してそれを抑えるようだ。
その論文にある図が自分には直感的には理解できなかったので理解するためにAttntion Mapを自分で描いてみた。