はじめに
LLM推論の高速化において、FlashAttentionは広く用いられる重要な最適化手法の一つとなっています。一般に「IO-awareな設計によってメモリアクセスを削減し、softmax attentionを高速に計算する」と説明されることが多いですが、実際の推論性能を考えるうえでは、Key-Value Cache(KVキャッシュ)の存在は無視できません。
FlashAttentionの自己回帰型の推論では、過去トークンのKey/Valueを保持・参照するKVキャッシュの読み書きが支配的なコストとなる場合があります。特に、K/Vのメモリレイアウトやアクセスパターンは、メモリ帯域利用効率やキャッシュヒット率に直接影響し、最終的なスループットに大きく関係します。
本記事では、「KVキャッシュとは何か」と「なぜレイアウト最適化が重要なのか」をCPUでの例に基づいて考え、記載します。
個人的な学習と整理を兼ねて書いていますが、よろしくお願いします。
KVキャッシュとは何か
デコーダ型 Transformer では、自己回帰的に次のトークンを1つずつ生成します。このとき各ステップで、現在のトークンは過去全てのトークンに対して attention を計算します。
具体的には各層において、
- $ Query(Q) $: 現在のトークンから計算される
- $ Key(K) $と$ Value(V) $: 過去を含む各トークンに対して必要となる
という計算が行われます。
ここで、ステップ$ t $において長さ$ t $の系列全体を入力として各層の$ K, V $を再計算する場合、全体の計算量は生成長に対して$ O(n^2) $となります。このため、シーケンスが長くなるほど推論時間は大きく増加します。
そこで
そこで以下のような、KVキャッシュと呼ばれる手法が使われます。
- 各層における$ Key(K) $と$ Value(V) $を、生成済みトークンごとにキャッシュとして保存する
- 新しいトークンを生成する際には、そのトークンに対応する$ Q $(および正規の$ K, V $)のみを計算し、過去の$ K, V $はキャッシュから再利用する
これにより、各ステップで必要な計算は、「新規トークンに対する計算」と「キャッシュされた$ K, V $との内積計算」のみに限定されます。結果として、生成長$ n $に対する全体の計算量は$ O(n) $のままですが、各ステップでの再計算が不要になるため、実際の計算量を$ O(n) $まで削減でき、長いシーケンスを高速に推論できます。
メモリ帯域のボトルネック
KVキャッシュは通常、以下のようなレイアウトで保存されます。
[batch, seq_len, num_heads, head_dim]
推論では毎ステップ、過去すべての$ K, V $を読み出す必要があるため、メモリ帯域・キャッシュ効率・アクセスパターンが性能を大きく左右します。特に、連続アクセスできるかどうかが速度に直結しますので、レイアウト最適化が重要になります。
FlashAttentionとレイアウト最適化
推論では、次のように$ K, V $を読み出します。
for seq in 0..n-1:
load K[seq, :, :]
load V[seq, :, :]
これは、一見すると問題なさそうに見えますが、実際には「head方向がバラバラで連続していない」という問題があります。例えば、KVキャッシュのレイアウト
[batch, seq_len, num_heads, head_dim]
において、head_dim=128, num_heads=32の場合、1トークン分の$ K $は32個のヘッドに分かれており、一般的な実装では、headごとにメモリが飛び飛びの値になります。結果として、SIMDロードが非効率になります。FlashAttentionは、計算よりもメモリI/Oを最適化することに特化したアルゴリズムです。そのため、K/Vのレイアウトが最適化されているかどうかが重要になります。
FlashAttentionは、以下のようにQ/K/Vをチャンク(ブロック)単位で処理します。
Q: [seq_q_block, head_dim]
K: [seq_k_block, head_dim]
V: [seq_k_block, head_dim]
GPUの shared memory (SRAM) に
- Qのブロック
- K/Vのブロック
を載せて、softmax attentionを部分的に計算し、結果を累積します。そのため、K/Vをseq方向に連続したブロックとして読める、より良いレイアウトになります。
[batch, num_heads, seq_len, head_dim]
なぜこのレイアウトが最適なのかというと、FlashAttention内部では以下のようなループが走るからです。
for k_block in range(0, seq_len, BLOCK_SIZE):
load K[k_block: k_block + BLOCK_SIZE]
load V[k_block: k_block + BLOCK_SIZE]
compute attention(Q_block, K_block)
このとき、
- seq方向が連続していれば、1回のロードで済む
- head_dimも連続していればvectorized loadが効く
- ブロックサイズ(64など)に合わせて最適化できる
つまり、帯域効果が最大化されます。
また、FlashAttentionはheadごとに独立してattentionを計算するため、head方向が最も外側にあるとアクセスが規則的になります。
[batch, head, seq, dim]
のように、headを外側に置くことで
- headごとに連続したK/Vブロックで読み込める
- warp/threadの並列化がしやすい
- shared memoryの配列がシンプルになる
というメリットがあります。
ただし、この内側ループが連続となるレイアウトが有利なのはCPUの場合で、GPUでは「warp内のスレッドが連続アドレスにアクセスできるか(coalescing)」というスレッドとの組み合わせが重要となります。そのため、必ずしもCPUと同じレイアウトで速くなるとは限りません。
データレイアウトによる性能影響
KVキャッシュのレイアウト差が性能に与える影響を、ベンチマークプログラムで確認します。
デコード時の典型パターンとして、固定された1 headのクエリベクトルq[head_dim]と、KVキャッシュに格納された全シーケンス長Sに対するキーk[s, h, d]との内積計算を対象とします。
score[s] = Σ_d q[d] * k[s, h, d]
ベンチマークテスト(layout_benchmark.cu)
ベンチマークでは、以下の2種類のメモリレイアウトを比較しました。
レイアウトSHD(最適寄り)
[seq_len, n_heads, head_dim]の順で配置し、head_dim(d)が最も内側となるレイアウト
レイアウトDSH(非最適)
[head_dim, seq_len, n_heads]の順で配置し、head(h)が最も内側となるレイアウト
それぞれのレイアウトについて、実行時間を測定した結果は以下のとおりです。
$ clang++ -O3 -ffast-math -march=native -Rpass=loop-vectorize benchmark_layout_shd_dsh.cpp -o bench
$ ./bench
Layout SHD : 62.3495 ms
Layout DSH : 523.455 ms
SHDのレイアウトでは、内積計算のループ(d方向)が連続メモリアクセスとなるため、コンパイラによるSIMDベクトル化が有効に働きます。一方、DSHのレイアウトでは、d方向のループにおけるメモリアクセスが、n_headsをストライドとする非連続アクセスとなります。そのため、ベクトル化はされてもgatherロードを伴う非効率なコード生成となります。
※ CPU: Apple M1, clang: 22.1.0
※ 実際のデータの値はレイアウトに揃えていないので、結果outの値は異なります
ベンチマークコード
#include <iostream>
#include <vector>
#include <chrono>
#include <random>
using namespace std;
// --- layout: k[s][h][d] ---
void dot_shd(
const float* q, // [D]
const float* k, // [S][H][D]
float* out, // [S]
int S, int H, int D, int head) {
for(int s = 0; s < S; ++s) {
const float* k_sh = k + (s * H + head) * D;
float acc = 0.0f;
#pragma clang loop vectorize(enable)
for(int d = 0; d < D; ++d) {
acc += q[d] * k_sh[d];
}
out[s] = acc;
}
}
// --- layout: k[d][s][h] ---
void dot_dsh(
const float* q, // [D]
const float* k, // [D][S][H]
float* out, // [S]
int S, int H, int D, int head) {
for(int s = 0; s < S; ++s) {
float acc = 0.0f;
#pragma clang loop vectorize(enable)
for(int d = 0; d < D; ++d) {
const float k_val = k[(d * S + s) * H + head];
acc += q[d] * k_val;
}
out[s] = acc;
}
}
// --- initialize ---
void init(vector<float>& v) {
std::mt19937 rng(0);
std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
for(auto& x: v) x = dist(rng);
}
// --- benchmark ---
template<typename F>
double bench(F func, int iters = 10000) {
auto start = chrono::high_resolution_clock::now();
for(int i = 0; i < iters; ++i) {
func();
}
auto end = chrono::high_resolution_clock::now();
return chrono::duration<double, milli>(end - start).count();
}
int main() {
constexpr int S = 1024; // seq_len
constexpr int H = 16; // n_heads
constexpr int D = 64; // head_dim
const int head = 3;
vector<float> q(D);
vector<float> k_shd(S * H * D);
vector<float> k_dsh(D * S * H);
vector<float> out(S);
init(q);
init(k_shd);
init(k_dsh);
dot_shd(q.data(), k_shd.data(), out.data(), S, H, D, head);
dot_dsh(q.data(), k_dsh.data(), out.data(), S, H, D, head);
double t_shd = bench([&]() {
dot_shd(q.data(), k_shd.data(), out.data(), S, H, D, head);
});
cout << "out[0]: " << out[0] << endl;
double t_dsh = bench([&]() {
dot_dsh(q.data(), k_dsh.data(), out.data(), S, H, D, head);
});
cout << "out[0]: " << out[0] << endl;
cout << "Layout SHD : " << t_shd << " ms" << endl;
cout << "Layout DSH : " << t_dsh << " ms" << endl;
return 0;
}