2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Flash Attentionってなんだ?〜GPUメモリを救う革命的アルゴリズムを完全理解〜

2
Posted at

この記事の対象読者

  • Pythonの基本文法(関数、クラス、import)を理解している方
  • PyTorchを使ったことがある、または興味がある方
  • LLM(大規模言語モデル)を動かそうとしてGPUメモリ不足に悩んだことがある方
  • Transformerアーキテクチャの概要を知っている方

この記事で得られること

  • Flash Attentionの概念と、なぜ「革命的」と呼ばれるのかの理解
  • GPUメモリ階層を意識したアルゴリズム設計の考え方
  • 実際にFlash Attentionを使うための環境構築とコード
  • バージョン別(1, 2, 3)の違いと使い分け

この記事で扱わないこと

  • Transformerの基礎理論(Self-Attentionの数式導出など)
  • CUDAプログラミングの詳細
  • 分散学習・マルチGPU環境の設定

1. Flash Attentionとの出会い

「CUDA out of memory」

このエラーメッセージを見て、何度絶望したことか。

私がLlama 2 7Bを自前のRTX 3090で動かそうとしたとき、シーケンス長を4096にしただけでこのエラーが出ました。24GBもVRAMがあるのに、たかだか70億パラメータのモデルが動かない。「GPUを買い足すしかないのか...」と諦めかけたそのとき、Flash Attentionの存在を知りました。

結論から言うと、Flash Attentionを有効にするだけで、同じモデル・同じシーケンス長が余裕で動くようになりました。しかも推論速度は2倍以上。「なんでこれがデフォルトじゃないんだ」と叫んだのを覚えています。

Flash Attentionは、GPUの「メモリ階層」を賢く使うことで、Attention計算を劇的に高速化・省メモリ化するアルゴリズムです。人間で例えるなら、「机の上の作業スペース(高速だが狭い)」と「本棚(遅いが広い)」を効率よく使い分ける整理術のようなものです。

ここまでで、Flash Attentionがどんなものか、なんとなくイメージできたでしょうか。次は、この技術が生まれた背景を見ていきましょう。


2. 前提知識の確認

本題に入る前に、この記事で使う用語を整理しておきます。

2.1 Attention(アテンション)とは

Transformerの中核をなす機構です。入力シーケンスの各要素が、他のすべての要素との「関連度」を計算します。「この単語は、どの単語に注目すべきか」を学習する仕組みと考えてください。

2.2 Self-Attention(自己注意機構)とは

入力シーケンス自身に対してAttentionを適用することです。Query(Q)、Key(K)、Value(V)という3つの行列を使い、以下の計算を行います。

Attention(Q, K, V) = softmax(QK^T / √d) × V

2.3 HBM(High Bandwidth Memory)とは

GPUのメインメモリです。容量は大きい(数十GB)ですが、アクセス速度は相対的に遅いです。RTX 3090なら24GB、A100なら40GB/80GBのHBMを持っています。

2.4 SRAM(Static RAM)とは

GPUのオンチップキャッシュです。容量は小さい(数MB)ですが、HBMより約100倍高速にアクセスできます。A100なら各SMに192KBのSRAMがあります。

2.5 Tiling(タイリング)とは

大きな計算を小さなブロックに分割して処理する手法です。一度にすべてのデータをメモリに載せるのではなく、ブロックごとに処理することでメモリ使用量を削減します。

これらの用語が押さえられたら、次に進みましょう。


3. Flash Attentionが生まれた背景

3.1 標準Attentionの問題点

Transformerのコンテキスト長を伸ばすことは、近年のLLM開発における最重要課題の一つでした。GPT-3の2Kトークンから、GPT-4の128K、さらにはLlama 3の1Mトークンへ。しかし、標準的なAttention実装には致命的な問題がありました。

計算量とメモリ使用量がシーケンス長の2乗に比例するという点です。

シーケンス長をNとすると、N×Nのアテンション行列を計算・保持する必要があります。シーケンス長が2倍になれば、メモリ使用量は4倍。8Kから32Kにすれば16倍です。これが「CUDA out of memory」の根本原因でした。

3.2 従来のアプローチの限界

この問題に対し、Sparse Attention、Linear Attention、Performer、Linformerなど、様々な「近似手法」が提案されてきました。これらは計算量を削減しますが、精度を犠牲にするというトレードオフがありました。

3.3 Tri Daoの着眼点

2022年、スタンフォード大学のTri Dao氏は、まったく異なるアプローチを提案しました。

「計算量を減らすのではなく、メモリアクセスを最適化する

Dao氏は、従来の実装が計算量ではなくメモリ帯域幅によってボトルネックになっていることに着目しました。GPUは行列計算は得意ですが、HBMとSRAM間のデータ転送が遅い。この「IO(入出力)」を意識したアルゴリズムを設計すれば、近似なしで高速化できると考えたのです。

背景がわかったところで、抽象的な概念から順に、具体的な仕組みを見ていきましょう。


4. Flash Attentionの基本概念

4.1 標準Attentionの何が遅いのか

標準的なAttention実装では、以下のステップで計算します。

  1. Q × K^T を計算し、N×N の行列 S をHBMに書き込む
  2. S をHBMから読み込み、softmaxを計算し、P をHBMに書き込む
  3. P をHBMから読み込み、P × V を計算し、出力 O をHBMに書き込む

問題は、N×Nという巨大な中間行列をHBMに何度も読み書きしていることです。シーケンス長が長くなるほど、この「メモリアクセスの無駄」が顕著になります。

4.2 Flash Attentionの核心:Tiling + Online Softmax

Flash Attentionの革新は、中間行列をHBMに書き込まないことです。

具体的には以下の戦略を取ります。

  1. Q, K, V をブロックに分割する
  2. 各ブロックをSRAM(高速キャッシュ)に読み込む
  3. ブロック単位でAttentionを計算する
  4. Online Softmaxで、全体を見なくても正しいsoftmaxを計算する
  5. 出力を逐次更新し、最終結果だけHBMに書き戻す

これにより、メモリ使用量がO(N²)からO(N)に削減されます。シーケンス長に対して線形!

4.3 Online Softmaxの魔法

「ブロックごとに計算して、正しいsoftmaxが出せるの?」と思うかもしれません。

softmaxの計算には「全要素の合計」が必要なので、通常は全データを見ないと計算できません。しかし、Online Softmax(2018年提案)という手法を使えば、逐次的に正しいsoftmaxを計算できます。

# 概念的なOnline Softmax
max_so_far = -inf
sum_so_far = 0
for block in blocks:
    new_max = max(max_so_far, block.max())
    # 前の結果を新しいmaxでrescale
    sum_so_far = sum_so_far * exp(max_so_far - new_max) + sum(exp(block - new_max))
    max_so_far = new_max

数値安定性を保ちながら、ブロックごとに結果を更新できるのがポイントです。

4.4 IO-Awarenessという設計思想

Flash Attentionの本質は「IO-Awareness(IO認識)」です。

従来のアルゴリズム設計はFLOPs(浮動小数点演算数)を最小化することに注力していました。しかしGPUでは、計算よりもメモリアクセスがボトルネックになることが多い。Flash Attentionは、HBMへの読み書き回数を明示的に最小化するよう設計されています。

基本概念が理解できたところで、これらの抽象的な概念を具体的なコードで実装していきましょう。


5. 実際に使ってみよう

5.1 環境構築

Flash Attentionを使うには、適切なハードウェアとソフトウェア環境が必要です。

ハードウェア要件:

  • NVIDIA GPU: Ampere(A100, RTX 30xx)、Ada(RTX 40xx)、Hopper(H100)以降
  • 注意: Turing世代(RTX 20xx)以前は非対応

ソフトウェア要件:

  • CUDA Toolkit 11.6以上(12.x推奨)
  • PyTorch 2.0以上
  • Linux推奨(Windowsは制限あり)
# 環境確認コマンド
nvidia-smi  # GPUとCUDAドライバのバージョン確認
nvcc -V     # CUDA Toolkitのバージョン確認
python -c "import torch; print(torch.__version__, torch.version.cuda)"

5.2 インストール方法

方法1: pipでインストール(推奨)

# 前提パッケージのインストール
pip install packaging ninja

# flash-attnのインストール(ビルドに時間がかかる場合あり)
pip install flash-attn --no-build-isolation

方法2: prebuilt wheelを使用(ビルド不要)

ビルドに失敗する場合や時間を節約したい場合は、prebuilt wheelを使いましょう。

# GitHubリリースページからダウンロード
# https://github.com/Dao-AILab/flash-attention/releases
# Python版、PyTorch版、CUDA版を確認して適切なwhlファイルを選択

# 例: Python 3.11, PyTorch 2.4, CUDA 12の場合
pip install flash_attn-2.7.0+cu12torch2.4cxx11abiTRUE-cp311-cp311-linux_x86_64.whl

5.3 設定ファイルの準備

以下の3種類の設定ファイルを用意しています。用途に応じて選択してください。

開発環境用(config.yaml)

# config.yaml - 開発環境用(このままコピーして使える)
model:
  name: "meta-llama/Llama-3.2-1B"
  torch_dtype: "bfloat16"
  device_map: "auto"
  attn_implementation: "flash_attention_2"

generation:
  max_new_tokens: 256
  do_sample: true
  temperature: 0.7
  top_p: 0.9

environment:
  cuda_visible_devices: "0"
  log_level: "DEBUG"

本番環境用(config.production.yaml)

# config.production.yaml - 本番環境用(このままコピーして使える)
model:
  name: "meta-llama/Llama-3.2-1B"
  torch_dtype: "bfloat16"
  device_map: "auto"
  attn_implementation: "flash_attention_2"
  # 本番向け: キャッシュ有効化
  use_cache: true

generation:
  max_new_tokens: 512
  do_sample: false  # 再現性重視
  temperature: 1.0
  num_beams: 1

environment:
  cuda_visible_devices: "0,1"  # マルチGPU
  log_level: "WARNING"

optimization:
  # 本番環境向け最適化
  torch_compile: true
  compile_mode: "reduce-overhead"

テスト環境用(config.test.yaml)

# config.test.yaml - テスト/CI用(このままコピーして使える)
model:
  name: "meta-llama/Llama-3.2-1B"
  torch_dtype: "float16"  # テスト用に軽量
  device_map: "auto"
  attn_implementation: "flash_attention_2"

generation:
  max_new_tokens: 32  # テスト用に短く
  do_sample: false    # 再現性確保
  temperature: 1.0

environment:
  cuda_visible_devices: "0"
  log_level: "INFO"

test:
  # テスト用設定
  sample_size: 10
  seed: 42
  deterministic: true

設定ローダー(config_loader.py)

"""
設定ファイルローダー
使い方: config = load_config("development")
"""
import yaml
from pathlib import Path


def load_config(env: str = "development") -> dict:
    """環境に応じた設定ファイルを読み込む"""
    config_map = {
        "development": "config.yaml",
        "production": "config.production.yaml",
        "test": "config.test.yaml",
    }
    
    config_file = config_map.get(env, "config.yaml")
    config_path = Path(__file__).parent / config_file
    
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")
    
    with open(config_path, "r") as f:
        return yaml.safe_load(f)


if __name__ == "__main__":
    import sys
    env = sys.argv[1] if len(sys.argv) > 1 else "development"
    config = load_config(env)
    print(f"Loaded {env} config:")
    print(yaml.dump(config, default_flow_style=False))

5.4 基本的な使い方

"""
Flash Attention 2を使ったLLM推論サンプル
使い方: python flash_attention_demo.py
必要なパッケージ: pip install transformers torch accelerate
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def main():
    """Flash Attention 2を使った推論のデモ"""
    model_id = "meta-llama/Llama-3.2-1B"
    
    # トークナイザーのロード
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    
    # Flash Attention 2を有効にしてモデルをロード
    # ポイント: attn_implementation="flash_attention_2" を指定
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,  # bf16またはfp16が必須
        device_map="auto",
        attn_implementation="flash_attention_2",  # ここが重要!
    )
    
    # 推論実行
    prompt = "Flash Attentionとは、"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"入力: {prompt}")
    print(f"出力: {response}")
    
    # メモリ使用量の確認
    if torch.cuda.is_available():
        print(f"\nGPUメモリ使用量: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")


if __name__ == "__main__":
    main()

5.5 実行結果

上記のコードを実行すると、以下のような出力が得られます:

入力: Flash Attentionとは、
出力: Flash Attentionとは、Transformerモデルの注意機構を高速化するために
開発されたアルゴリズムです。従来の実装と比較して、メモリ使用量を大幅に
削減しながら、同等の計算精度を維持できます。特に長いシーケンスを扱う
場合に効果を発揮します。

GPUメモリ使用量: 2.34 GB

5.6 よくあるエラーと対処法

エラー 原因 対処法
CUDA_HOME environment variable is not set CUDA Toolkitのパス未設定 export CUDA_HOME=/usr/local/cuda を設定
nvcc was not found nvccコンパイラが見つからない CUDA Toolkitをインストール、またはprebuilt wheelを使用
RuntimeError: FlashAttention only supports fp16 and bf16 データ型がfp32 torch_dtype=torch.bfloat16 を指定
flash_attn_2_cuda.cpython... undefined symbol PyTorchとflash-attnのバージョン不一致 両方を再インストール、バージョンを揃える
No module named 'flash_attn' インストール失敗 pip install flash-attn --no-build-isolation を再実行
ビルドが2時間以上かかる ninjaが無効 pip install ninja の後に再インストール

基本的な使い方をマスターしたので、次は応用例を見ていきましょう。


6. ユースケース別ガイド

6.1 ユースケース1: Hugging Face Transformersでの推論高速化

  • 想定読者: Hugging Faceのモデルを使って推論を行いたい方
  • 推奨構成: flash-attn + transformers + accelerate
  • サンプルコード:
"""
Hugging Face Transformersでの推論高速化
比較: Flash Attention ON vs OFF
"""
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def benchmark_inference(model_id: str, use_flash: bool, prompt: str, n_runs: int = 5):
    """Flash Attentionの有無で推論速度を比較"""
    
    attn_impl = "flash_attention_2" if use_flash else "eager"
    
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation=attn_impl,
    )
    
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # ウォームアップ
    with torch.no_grad():
        _ = model.generate(**inputs, max_new_tokens=50)
    
    torch.cuda.synchronize()
    
    # ベンチマーク
    times = []
    for _ in range(n_runs):
        torch.cuda.reset_peak_memory_stats()
        start = time.perf_counter()
        
        with torch.no_grad():
            _ = model.generate(**inputs, max_new_tokens=100)
        
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    
    avg_time = sum(times) / len(times)
    peak_memory = torch.cuda.max_memory_allocated() / 1e9
    
    return {
        "attention": attn_impl,
        "avg_time_sec": round(avg_time, 3),
        "peak_memory_gb": round(peak_memory, 2),
    }


if __name__ == "__main__":
    model_id = "meta-llama/Llama-3.2-1B"
    prompt = "人工知能の未来について、" * 10  # 長めのプロンプト
    
    print("=== Flash Attention ベンチマーク ===\n")
    
    result_flash = benchmark_inference(model_id, use_flash=True, prompt=prompt)
    print(f"Flash Attention ON:  {result_flash}")
    
    # GPUメモリをクリア
    torch.cuda.empty_cache()
    
    result_eager = benchmark_inference(model_id, use_flash=False, prompt=prompt)
    print(f"Flash Attention OFF: {result_eager}")
    
    speedup = result_eager["avg_time_sec"] / result_flash["avg_time_sec"]
    print(f"\n速度向上: {speedup:.2f}x")

6.2 ユースケース2: 量子化との組み合わせ

  • 想定読者: VRAMが限られた環境でLLMを動かしたい方
  • 推奨構成: flash-attn + bitsandbytes(4bit/8bit量子化)
  • サンプルコード:
"""
Flash Attention + 4bit量子化の組み合わせ
RTX 3060 (12GB) でも大きなモデルを動かす
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


def load_quantized_model_with_flash(model_id: str):
    """4bit量子化 + Flash Attention でモデルをロード"""
    
    # 4bit量子化の設定
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    
    # 量子化 + Flash Attention の組み合わせ
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
        attn_implementation="flash_attention_2",  # 量子化と併用可能!
    )
    
    return model, tokenizer


def generate_with_quantized_model():
    """量子化モデルでテキスト生成"""
    model_id = "meta-llama/Llama-3.2-3B"  # 3Bパラメータ
    
    print("モデルをロード中(4bit量子化 + Flash Attention)...")
    model, tokenizer = load_quantized_model_with_flash(model_id)
    
    prompt = "機械学習エンジニアになるためのロードマップは、"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    print(f"\n入力: {prompt}")
    print(f"出力: {response}")
    print(f"\nGPUメモリ使用量: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")


if __name__ == "__main__":
    generate_with_quantized_model()

6.3 ユースケース3: PyTorch SDPAとの使い分け

  • 想定読者: PyTorch標準のSDPAとFlash Attentionの違いを理解したい方
  • 推奨構成: PyTorch 2.0以上のSDPA機能を活用
  • サンプルコード:
"""
PyTorch SDPA vs Flash Attention の比較
SDPA: Scaled Dot-Product Attention(PyTorch標準)
"""
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer


def compare_attention_backends(model_id: str, prompt: str):
    """異なるAttentionバックエンドを比較"""
    
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    
    # 方法1: 明示的にFlash Attention 2を指定
    print("方法1: attn_implementation='flash_attention_2'")
    model_flash = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
    
    # 方法2: SDPAを指定(PyTorchが最適なバックエンドを自動選択)
    print("方法2: attn_implementation='sdpa'")
    model_sdpa = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="sdpa",
    )
    
    inputs = tokenizer(prompt, return_tensors="pt").to(model_sdpa.device)
    
    # SDPAの場合、コンテキストマネージャでバックエンドを指定可能
    print("\nSDPA + Flash Attentionバックエンドを明示:")
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        with torch.no_grad():
            outputs = model_sdpa.generate(**inputs, max_new_tokens=50)
    
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))


def show_attention_options():
    """利用可能なAttention実装の一覧"""
    print("""
=== Attention実装の選択肢 ===

1. "flash_attention_2" - Flash Attention 2を直接使用
   - 最も高速、最もメモリ効率が良い
   - flash-attnパッケージが必要
   - fp16/bf16のみ対応

2. "sdpa" - PyTorch標準のScaled Dot-Product Attention
   - PyTorch 2.0以上で利用可能
   - 自動で最適なバックエンドを選択
   - Flash Attention, xFormers, C++実装から選択

3. "eager" - 従来の実装
   - 互換性重視
   - output_attentions=True が必要な場合に使用

4. "kernels-community/flash-attn2" - Kernelsライブラリ経由
   - flash-attnパッケージ不要
   - Hugging Face Hubからカーネルをダウンロード
""")


if __name__ == "__main__":
    show_attention_options()

ユースケースが把握できたところで、この記事を読んだ後の学習パスを確認しましょう。


7. バージョン別の違い

Flash Attentionは進化を続けています。バージョンごとの特徴を整理します。

7.1 バージョン比較表

項目 Flash Attention 1 Flash Attention 2 Flash Attention 3
発表年 2022年 2023年 2024年
対応GPU Ampere以降 Ampere以降 Hopperのみ
最大head dim 128 256 256
FP8対応 × ×
理論性能比 25-40% 50-73% 75-85%
A100性能 〜120 TFLOPs/s 〜225 TFLOPs/s -
H100性能 - 〜350 TFLOPs/s 〜740 TFLOPs/s

7.2 どのバージョンを使うべきか

  • RTX 30xx / RTX 40xx / A100: Flash Attention 2(最も安定)
  • H100 / H800: Flash Attention 3(ただしベータ版)
  • それ以外: PyTorch SDPAを使用

8. 学習ロードマップ

この記事を読んだ後、次のステップとして以下をおすすめします。

初級者向け(まずはここから)

  1. 公式GitHubリポジトリを確認: Dao-AILab/flash-attention
  2. Hugging Faceのドキュメント: GPU inference
  3. 実際にコードを動かす: 本記事のサンプルコードを実行

中級者向け(実践に進む)

  1. 論文を読む: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
  2. ベンチマークを取る: 自分のユースケースで効果を測定
  3. 量子化と組み合わせる: bitsandbytesとの併用を試す

上級者向け(さらに深く)

  1. CUDAカーネルを読む: flash-attentionのソースコード解析
  2. Flash Attention 3の論文: FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
  3. 自分のモデルに適用: カスタムTransformerへのFlash Attention統合

9. まとめ

この記事では、Flash Attentionについて以下を解説しました:

  1. Flash Attentionの本質: GPUメモリ階層を意識した「IO-Aware」なアルゴリズム設計
  2. なぜ速いのか: Tiling + Online SoftmaxでHBMへの読み書きを最小化
  3. 実践的な使い方: Hugging Face Transformersでのattn_implementation="flash_attention_2"
  4. バージョンの違い: 1, 2, 3の特徴と使い分け

私の所感

Flash Attentionは、「計算量を減らす」のではなく「メモリアクセスを最適化する」という発想の転換が素晴らしいと思います。近似なしで高速化できるため、精度を犠牲にすることなく、長いコンテキストを扱えるようになりました。

正直、初めてFlash Attentionを使ったときの「動いた!しかも速い!」という感動は忘れられません。GPUメモリ不足に悩んでいる方は、ぜひ試してみてください。数行のコード変更で、世界が変わります。


参考文献

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?