LoginSignup
2
2

LLMを3倍高速にする手法「FlashAttention」を解説!

Posted at

概要

FLashAttentionはLLMの学習スピードを3倍も高速にすることができると話題のようです。その後もFlashAttentionを改良したFlashAttention2が出てきたり、FlashDecodingが出てきたり、これからますます注目が集まると思います。しかし、日本語で詳しく解説されている記事は見つからず、、、自分で色々調べてたので、その内容をまとめてみます。

簡単なまとめ

  • 最近のGPUでAttentionを計算する際のボトルネックはGPUメモリへのアクセス
  • 上記問題を解決するためにAttentionのアルゴリズムを2つの方法で改良
  • 1つ目はTileing。Q,K,Vの行列を分割して順番に計算
  • 2つ目はRecomputation, backpropagationのためにQKの行列積をGPUメモリに保存する代わりに、backpropagationの時に再度計算する
  • その結果、GPT-2の学習時間が3倍高速に
  • その後も改良版が提案されたり、llamaの学習に使われたり、ますます重要な存在に

Attention計算時のメモリアクセス

SparseAttention, ReFormer, PerFormerなどAttentionの計算を高速化する手法が存在しています。これらの方式はAttentionの計算量を減らす手法となっています。しかし、最近の非常に計算が高速なGPUでは、計算そのものではなくて、メモリアクセスがボトルネックとなります。

Attentionの計算は次に示すとおりです。 :の後に[]で囲って書いてあるのは各行列のサイズです。sはsequence length, dは特徴量のサイズを示します。

\begin{align}
& S = QK^T 
\quad : \quad
[s, s] = [s,d] * [d,s]
\\

& A = Softmax(S)
\quad : \quad
[s, s] = [s, s]
\\

& O = AV
\quad : \quad
[s,d] = [s,s] * [s,d]
\\
\end{align}

この計算の途中の結果であるSとAがsequence length * sequence lengthサイズの行列です。この行列をメモリ読み書きするのがボトルネックになるようです。

もちろん、sequence lengthが短い時はボトルネックにならないですが、sequence lengthを増やすのは最近のトレンド(GPT4は32k, Claude2は100k)でもあるし、今後ますますボトルネックになっていくようです。

FlashAttention

FlashAttentionは、TilingとRecomputationという二つの手法を使って、$S$と$A$をメモリに読み書きしなくて良いように改良されたAttentionの計算アルゴリズムです。

それぞれ、独立した手法であり、特に重要なのはTilingの方なので、まずはTilingから説明していきます。

Tiling

Tilingでは$QKV$を小さいブロックに分割して、Attentionの計算($Softmax(QK^T)V$)を行います。小さいブロックにすることで、$S$や$A$をメモリに保存することなく、SRAMに保持したまま計算をすることが可能です。

はじめに、$QKV$をそれぞれ、行方向(sequence length)で分割します。ここでは、簡単のために2分割にしています。

\begin{align}
& Q = 
\begin{bmatrix}
 Q_1\\ Q_2
\end{bmatrix}
\quad:\quad
[s,d] = 
\begin{bmatrix}
 [s/2,d]\\ [s/2,d]
\end{bmatrix}
\\
& K = 
\begin{bmatrix}
 K_1\\ K_2
\end{bmatrix}
\quad:\quad
[s,d] = 
\begin{bmatrix}
 [s/2,d]\\ [s/2,d]
\end{bmatrix}
\\& V = 
\begin{bmatrix}
 V_1\\ V_2
\end{bmatrix}
\quad:\quad
[s,d] = 
\begin{bmatrix}
 [s/2,d]\\ [s/2,d]
\end{bmatrix}
\\
\end{align}

次に、$Q$と$KV$の組み合わせごとに$S$と$A$と$O$をそれぞれ計算していきます。ここでは$K_1, V_1$に対する$Q_1, Q_2$の計算例のみ載せます。

\begin{align}
& S_{11} = Q_1 K_1^T
\quad : \quad
[s/2, s/2] = [s/2, d]*[d,s/2]
\\

& A_{11} = Softmax(S_{11})
\quad : \quad
[s/2, s/2] = [s/2, s/2]
\\

& O_{11} = A_{11} * V_{11}
\quad : \quad
[s/2, d] = [s/2, s/2]*[s/2,d]
\\
\end{align}

これを$O_{11}$だけでなく、$O_{12}, O_{21}, O_{22}$についても同様に計算してあげます。その後、同じ$Q$を用いて計算された$O$同士を足してあげます。

\begin{align}
& O_1 = O_{11} + O_{12}
\quad : \quad
[s/2, d] = [s/2, d] + [s/2,d]
\\

& O_2 = O_{21} + O_{22}
\quad : \quad
[s/2, d] = [s/2, d] + [s/2,d]
\\
\end{align}

最後に$O_1$と$O_2$を行方向にくっつけてあげれば完了です。

\begin{align}
& O = stack(O_1, O_2) 
\quad : \quad
[s, d] = stack([s/2, d], [s/2,d])
\\
\end{align}

これで完了と思いきや、実はこの計算は間違っています。それは、Softmaxの箇所です。Softmaxは各行に対して実行されるので、$Softmax(Q_1K_1^T)$のように一部の列しかない状態では適用できず、各行の全ての列に足しして$QK^T$の結果が揃わないと本来は適用できないはずです。

この問題を解決するのがOnline Softmaxという手法です。
Online softmaxでは各ブロックはそのブロック内のみでsoftmaxを適用します。

\begin{align}
& S_{11} = Q_1 K_1^T\\
& l_{11} = \sum_j exp(S_{1j})\\
& A_{11} = \frac{exp(S_{11})}{l_{11}}\\
& O_{11} = A_{11} * V_{11}\\
\end{align}

その後同じ$Q$のブロックごとに足し合わせる時に、$O$の値を$l$値でスケールしてから足してあげます。

\begin{align}
& O_1 = O_{11}*\frac{l_{11}}{l_{22}} + O_{12}
\end{align}

こうすることで、途中の計算結果である$S$や$A$を保存せず、ただのスカラーである$l$を保存して置くだけで、Attentionの計算ができるようになります。

下記に、FlashAttention2の論文に載っている上記の内容について説明した図を貼っておきますので、参考にしてください。この図では簡略化のために、$Q$を分割していないので注意してください。
スクリーンショット 2023-10-21 16.01.16.png

Recomputation

Recomputationは全く難しくないです。通常DNNのモデルをトレーニングするときはbackpropagationの時に使用するために、各層の出力を保存しておきす。これはAttentionでも例外ではなく、SとAも保存されます。しかし, せっかくTilingでSとはAをメモリに書き込まなくて良くなったのに、backpropagationのためにわざわざ保存するのはアホらしいです。そこで単純にSとAは保存せずに、backpropagationの時にもう一度計算しようというのがRecomputationです。もう一度計算するのは遅くなりそうな気がしますが、メモリに大きな行列を何回も読み書きするよりはマシなようです。

評価結果

簡単に評価結果を紹介します。

スクリーンショット 2023-10-20 17.11.26.png
こちらの表は、GPT2モデルの学習がHuggingfaceのTransformersやMegatorn-LMに比べて高速になっていることを示しています。3倍だと一瞬しょぼいような気もしますが、学習に3週間かかるのが、1週間になると考えるととても早くなっていると思います。

スクリーンショット 2023-10-20 17.12.12.png
こちらの表はsequence lengthを変えた時の結果ですね。sequence lengthが1kの時はMegaton-LM
に比べて、1.7倍程度ですが、それが2k,4kと増えた時にもなお1kのMegatronよりも早くなっています。つまりsequence lengthが伸びても学習がそれほど遅くならないということですね。個人的にはMegatron-LMの2kと4kの時の結果も載せて欲しかったです。載せてないと実はMegatron-LMはsequece lenghtが伸びると高速化したりするのかなとか思ってしまします。

最後に

最後まで読んでいただいてありがとうございます。間違いのご指摘やわからないことへの質問、もっとこう説明し方が良いなどのアドバイスなど、なんでも気軽にコメントしていただければと思います。解説してほしい論文も募集中です。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2