GPUの性能を最大限引き出す!FlashRNNによる50倍高速化
はじめに
近年、自然言語処理や時系列予測などの分野では、Transformerに代表されるシーケンス並列化可能なニューラルネットワークが大きな成功を収めています。しかし、これらのモデルは状態追跡能力に欠けるという課題があります。一方、RNN(Recurrent Neural Network)は状態追跡に優れていますが、逐次的な処理が必要なため、GPUの性能を十分に引き出せないという問題がありました。
RNNの逐次処理は、GPUの大規模並列処理能力を活用できない。
そこで本研究では、FlashRNNというRNNの高速化手法を提案します。FlashRNNは、GPUのキャッシュ階層を考慮した最適化により、RNNの計算を高速化します。
技術の基礎知識
RNNは、各ステップの隠れ状態が前のステップに依存する逐次的なモデルです。LSTMやGRUなどのRNNでは、ゲートと呼ばれる制御機構により、長期的な依存関係を学習できます。しかし、RNNの実装では、行列積と要素ごとの非線形演算が交互に現れるため、メモリアクセスがボトルネックになります。
GPUは、レジスタ、SRAM、HBMの階層的なメモリ構成を持つ。
単純なRNNの実装では、これらのメモリ階層を効果的に活用できず、GPUの性能を十分に引き出せません。
新しく提案する手法
FlashRNNは、RNNの行列積と非線形演算を1つのカーネルに融合し、レジスタとSRAMを活用してメモリアクセスを最小化します。時間方向のループをカーネル内に組み込むことで、重み行列や状態をレジスタにキャッシュし、HBMへのアクセスを大幅に削減します。
さらに、FlashRNNでは、RNNを複数のヘッドに分割し、ヘッドごとに並列処理を行います。ヘッドサイズとバッチサイズを適切に調整することで、メモリ使用量とパフォーマンスのトレードオフを制御します。
最適なタイリングサイズは、ConstrINTという整数制約充足ソルバーにより自動的に求められます。
$\text{Tiling Size} = \arg\max_{\text{size}} \text{Performance}(\text{size}) \text{ s.t. } \text{Memory Constraints}$
実験結果
FlashRNNの性能を評価するため、言語モデルの学習と状態追跡タスクにおいて、FlashRNN LSTMとTransformerを比較しました。
8つのH100 GPUを用いた学習では、LSTMは単一ヘッドでTransformerより40%程度遅くなりましたが、12ヘッド構成ではその差は25%程度に縮まりました。状態追跡タスクでは、LSTMは完璧な精度を達成し、Transformerを大きく上回りました。
FlashRNNの融合カーネルは、PyTorchの50倍、cuDNNの3倍高速。
バッチサイズが小さい場合は、融合カーネルの方が交互カーネルより2-4倍高速でしたが、大きいバッチでは交互カーネルがスケールしました。
今後の可能性
FlashRNNにより、RNNの推論と学習の高速化が実現しました。これにより、状態追跡が必要な時系列予測やロジカル推論など、Transformerの苦手な問題への応用が期待できます。
今後は、非同期メモリ操作やSRAM間通信など、より新しいGPUの機能を活用した最適化が課題です。FlashRNNのソースコードは公開されており、RNN関連研究の新たな基盤となることが期待されます。