はじめに
LLM 推論において、Attention は計算コストとメモリアクセスの両面で大きな割合を占める処理です。特に、長いトークンを扱う場合、Attention の効率は推論性能全体に大きく影響します。そのため、SIMD 命令を活用した Vector 化やメモリアクセス最適化による高速化は、推論基盤を設計する上で重要なテーマとなります。
MLIR は、このような高性能化を行うために、多段階のローワリングを通じて Attention 関連の計算をベクトル演算へ変換することができます。しかし、その最適化過程は中間言語 (IR) として段階的に記述されるため、どのように低レベルなIRに変換されるのかは把握しづらくなっています。
本稿では、MLIR が Attention 演算をどのように Vector Dialect や低レベル IR へローワリングしていくのかを追跡し、その変換過程を見たいと思います。
自己学習と備忘録を兼ねて書いています。複数回に分けて調査します。
Attentionの基礎
Attention は、 Query(Q) と Key(K) の類似度を計算し、その結果を使って Value(V) を重み付き平均する計算です。
Q と K の積
まず、Query 行列と Key 行列の転置の積を計算します。
\begin{align}
S = QK^T
\end{align}
ここで、 $S$ は各 $Query$ と各 $Key$ の類似度(スコア)を表します。
この段階での主要な計算は、行列積(Matrix Multiplication)です。
softmax による正規化
各 Query に対するスコアを確率分布に変換します。
\begin{align}
A = softmax(S)
\end{align}
softmax は各行について
\begin{align}
A_{ij} = \frac{exp(S_{ij})}{\sum_{k} exp(S_{ik})}
\end{align}
を計算します。
この段階での主要な計算は、要素ごとの(Elementwise)のexpおよび除算、リダクション(Reduction)和算です。
V の重みつき和
得られた重み C を使って Value を集約します。
\begin{align}
O = CV
\end{align}
出力Oの各行は、
\begin{align}
O_i = \sum_j A_{ij}V_j
\end{align}
となり、Valueベクトルの重み付き和になっています。
この段階での主要な計算は、行列積(Matrix Multiplication)とリダクション(Reduction)重み付き和算です。
MLIR とは
MLIR (Multi-Level Intermediate Representation) は、コンパイラ基盤である LLVM プロジェクトで開発されている中間表現 (IR) です。
従来の LLVM IR は CPU 向けの低レベルな表現を対象としていましたが、MLIR はより高い抽象度から低い抽象度までを一つのフレームワークで扱えるよう設計されています。
MLIR の大きな特徴は、Dialect と呼ばれる中間表現を段階的に変換しながら最終的な機械語へ近づけていく点です。
例えば行列積であれば、
linalg.matmul
のような高レベル演算から始まり、
linalg
↓
scf / affine
↓
vector
↓
llvm
というように徐々に低レイヤ命令へ変換されます。それぞれの Dialect は異なる抽象度を表現しています。
| Dialect | 役割 |
|---|---|
| linalg | 行列演算など高レベル演算 |
| affine | ループ・メモリアクセス最適化 |
| scf | 一般的なループ構造 |
| vector | SIMD演算 |
| arith | 算術演算 |
| memref | メモリ操作 |
| llvm | LLVM IR 相当 |
このような段階的な Lowering によって、各段階ごとに最適化を行えることが MLIR の大きな特徴です。
対象
Attention の中でも行列積部分(QK^T および V の重み付き和)に着目します。
まず、単純な matmul が表現された MLIR の変換を確認し、その後 Attention について同様に考察したいと思います。