1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

KVキャッシュのデータレイアウト最適化(CPU版)

1
Posted at

はじめに

LLM推論の高速化において、FlashAttentionは広く用いられる重要な最適化手法の一つとなっています。一般に「IO-awareな設計によってメモリアクセスを削減し、softmax attentionを高速に計算する」と説明されることが多いですが、実際の推論性能を考えるうえでは、Key-Value Cache(KVキャッシュ)の存在は無視できません。
FlashAttentionの自己回帰型の推論では、過去トークンのKey/Valueを保持・参照するKVキャッシュの読み書きが支配的なコストとなる場合があります。特に、K/Vのメモリレイアウトやアクセスパターンは、メモリ帯域利用効率やキャッシュヒット率に直接影響し、最終的なスループットに大きく関係します。

本記事では、「KVキャッシュとは何か」と「なぜレイアウト最適化が重要なのか」をCPUでの例に基づいて考え、記載します。
個人的な学習と整理を兼ねて書いていますが、よろしくお願いします。

KVキャッシュとは何か

デコーダ型 Transformer では、自己回帰的に次のトークンを1つずつ生成します。このとき各ステップで、現在のトークンは過去全てのトークンに対して attention を計算します。
具体的には各層において、

  • $ Query(Q) $: 現在のトークンから計算される
  • $ Key(K) $と$ Value(V) $: 過去を含む各トークンに対して必要となる

という計算が行われます。
ここで、ステップ$ t $において長さ$ t $の系列全体を入力として各層の$ K, V $を再計算する場合、全体の計算量は生成長に対して$ O(n^2) $となります。このため、シーケンスが長くなるほど推論時間は大きく増加します。
そこで
そこで以下のような、KVキャッシュと呼ばれる手法が使われます。

  • 各層における$ Key(K) $と$ Value(V) $を、生成済みトークンごとにキャッシュとして保存する
  • 新しいトークンを生成する際には、そのトークンに対応する$ Q $(および正規の$ K, V $)のみを計算し、過去の$ K, V $はキャッシュから再利用する

これにより、各ステップで必要な計算は、「新規トークンに対する計算」と「キャッシュされた$ K, V $との内積計算」のみに限定されます。結果として、生成長$ n $に対する全体の計算量は$ O(n) $のままですが、各ステップでの再計算が不要になるため、実際の計算量を$ O(n) $まで削減でき、長いシーケンスを高速に推論できます。

メモリ帯域のボトルネック

KVキャッシュは通常、以下のようなレイアウトで保存されます。

[batch, seq_len, num_heads, head_dim]

推論では毎ステップ、過去すべての$ K, V $を読み出す必要があるため、メモリ帯域・キャッシュ効率・アクセスパターンが性能を大きく左右します。特に、連続アクセスできるかどうかが速度に直結しますので、レイアウト最適化が重要になります。

FlashAttentionとレイアウト最適化

推論では、次のように$ K, V $を読み出します。

for seq in 0..n-1:
  load K[seq, :, :]
  load V[seq, :, :]

これは、一見すると問題なさそうに見えますが、実際には「head方向がバラバラで連続していない」という問題があります。例えば、KVキャッシュのレイアウト

[batch, seq_len, num_heads, head_dim]

において、head_dim=128, num_heads=32の場合、1トークン分の$ K $は32個のヘッドに分かれており、一般的な実装では、headごとにメモリが飛び飛びの値になります。結果として、SIMDロードが非効率になります。FlashAttentionは、計算よりもメモリI/Oを最適化することに特化したアルゴリズムです。そのため、K/Vのレイアウトが最適化されているかどうかが重要になります。

FlashAttentionは、以下のようにQ/K/Vをチャンク(ブロック)単位で処理します。

Q: [seq_q_block, head_dim]
K: [seq_k_block, head_dim]
V: [seq_k_block, head_dim]

GPUの shared memory (SRAM) に

  • Qのブロック
  • K/Vのブロック

を載せて、softmax attentionを部分的に計算し、結果を累積します。そのため、K/Vをseq方向に連続したブロックとして読める、より良いレイアウトになります。

[batch, num_heads, seq_len, head_dim]

なぜこのレイアウトが最適なのかというと、FlashAttention内部では以下のようなループが走るからです。

for k_block in range(0, seq_len, BLOCK_SIZE):
    load K[k_block: k_block + BLOCK_SIZE]
    load V[k_block: k_block + BLOCK_SIZE]
    compute attention(Q_block, K_block)

このとき、

  • seq方向が連続していれば、1回のロードで済む
  • head_dimも連続していればvectorized loadが効く
  • ブロックサイズ(64など)に合わせて最適化できる

つまり、帯域効果が最大化されます。
また、FlashAttentionはheadごとに独立してattentionを計算するため、head方向が最も外側にあるとアクセスが規則的になります。

[batch, head, seq, dim]

のように、headを外側に置くことで

  • headごとに連続したK/Vブロックで読み込める
  • warp/threadの並列化がしやすい
  • shared memoryの配列がシンプルになる
    というメリットがあります。

ただし、この内側ループが連続となるレイアウトが有利なのはCPUの場合で、GPUでは「warp内のスレッドが連続アドレスにアクセスできるか(coalescing)」というスレッドとの組み合わせが重要となります。そのため、必ずしもCPUと同じレイアウトで速くなるとは限りません。

データレイアウトによる性能影響

KVキャッシュのレイアウト差が性能に与える影響を、ベンチマークプログラムで確認します。

デコード時の典型パターンとして、固定された1 headのクエリベクトルq[head_dim]と、KVキャッシュに格納された全シーケンス長Sに対するキーk[s, h, d]との内積計算を対象とします。

score[s] = Σ_d q[d] * k[s, h, d]

ベンチマークテスト(layout_benchmark.cu)

ベンチマークでは、以下の2種類のメモリレイアウトを比較しました。

レイアウトSHD(最適寄り)
[seq_len, n_heads, head_dim]の順で配置し、head_dim(d)が最も内側となるレイアウト

レイアウトDSH(非最適)
[head_dim, seq_len, n_heads]の順で配置し、head(h)が最も内側となるレイアウト

それぞれのレイアウトについて、実行時間を測定した結果は以下のとおりです。

$ clang++ -O3 -ffast-math -march=native -Rpass=loop-vectorize benchmark_layout_shd_dsh.cpp -o bench
$ ./bench
Layout SHD : 62.3495 ms
Layout DSH : 523.455 ms

SHDのレイアウトでは、内積計算のループ(d方向)が連続メモリアクセスとなるため、コンパイラによるSIMDベクトル化が有効に働きます。一方、DSHのレイアウトでは、d方向のループにおけるメモリアクセスが、n_headsをストライドとする非連続アクセスとなります。そのため、ベクトル化はされてもgatherロードを伴う非効率なコード生成となります。

※ CPU: Apple M1, clang: 22.1.0
※ 実際のデータの値はレイアウトに揃えていないので、結果outの値は異なります

ベンチマークコード

#include <iostream>
#include <vector>
#include <chrono>
#include <random>

using namespace std;

// --- layout: k[s][h][d] ---

void dot_shd(
  const float* q, // [D]
  const float* k, // [S][H][D]
  float* out,     // [S]
  int S, int H, int D, int head) {
  for(int s = 0; s < S; ++s) {
    const float* k_sh = k + (s * H + head) * D;

    float acc = 0.0f;

    #pragma clang loop vectorize(enable)
    for(int d = 0; d < D; ++d) {
      acc += q[d] * k_sh[d];
    }
    out[s] = acc;
  }
}

// --- layout: k[d][s][h] ---

void dot_dsh(
  const float* q, // [D]
  const float* k, // [D][S][H]
  float* out,     // [S]
  int S, int H, int D, int head) {
  for(int s = 0; s < S; ++s) {
    
    float acc = 0.0f;
    
    #pragma clang loop vectorize(enable)
    for(int d = 0; d < D; ++d) {
      const float k_val = k[(d * S + s) * H + head];
      acc += q[d] * k_val;
    }
    out[s] = acc;
  }
}

// --- initialize ---

void init(vector<float>& v) {
  std::mt19937 rng(0);
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for(auto& x: v) x = dist(rng);
}

// --- benchmark ---

template<typename F>
double bench(F func, int iters = 10000) {
  auto start = chrono::high_resolution_clock::now();
  for(int i = 0; i < iters; ++i) {
    func();
  }
  auto end = chrono::high_resolution_clock::now();
  return chrono::duration<double, milli>(end - start).count();
}

int main() {
  constexpr int S = 1024; // seq_len
  constexpr int H = 16;   // n_heads
  constexpr int D = 64;   // head_dim
  const int head = 3;

  vector<float> q(D);
  vector<float> k_shd(S * H * D);
  vector<float> k_dsh(D * S * H);
  vector<float> out(S);

  init(q);
  init(k_shd);
  init(k_dsh);

  dot_shd(q.data(), k_shd.data(), out.data(), S, H, D, head);
  dot_dsh(q.data(), k_dsh.data(), out.data(), S, H, D, head);
  
  double t_shd = bench([&]() {
    dot_shd(q.data(), k_shd.data(), out.data(), S, H, D, head);
  });
  cout << "out[0]: " << out[0] << endl;

  double t_dsh = bench([&]() {
    dot_dsh(q.data(), k_dsh.data(), out.data(), S, H, D, head);
  });
  cout << "out[0]: " << out[0] << endl;

  cout << "Layout SHD : " << t_shd << " ms" << endl;
  cout << "Layout DSH : " << t_dsh << " ms" << endl;
  return 0;
}
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?