0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

TritonによるPyTorch最適化とLLM処理速度比較実験 (Part 1)

Last updated at Posted at 2025-10-07

こんにちは、アミフィアブル株式会社のAI研究部エンジニアのブイ・クアンです。

この投稿は英語で書いたものをAIで和訳したものをベースにしていますのでご了承ください。

LLMの処理速度という課題(PyTorchとTriton)

大規模言語モデル(LLM)の発展において、推論速度/処理速度が現在ボトルネックの一つとなっています。
MetaのLlama-3.1-8B-InstructのようなLLMの処理をGPU上でいかに効率化するかが、実社会での利用において不可欠になります。

LLM処理には、ディープラーニングの標準フレームワークPyTorchが用いられますが、これをカスタム最適化することにより大幅にパフォーマンスが向上します。それを実現するのがTritonです。

TritonはOpenAIによって開発されたオープンソースのプログラミング言語とコンパイラで、効率的なGPUカーネルを記述するために設計されています。CUDA (C++)のPython風代替品として考えばわかりやすいでしょう。Tritonは、開発者がNVIDIA GPU(およびますます他のハードウェア)上で直接実行されるカスタム操作を定義することを可能にし、並列化、メモリアクセス、計算融合を最適化し、それにより推論を大幅に強化します。
そして、この投稿では、Llama-3.1-8B-Instructモデルにおける標準PyTorch実装とTritonで書かれた高速カーネル群(Liger-Kernel経由)を使った最適化を比較します。

ネタバレ: Tritonによる最適化はバニラPyTorchを上回り、顕著な速度向上を実現します。

データセット: 長文要約のためのBookSum

推論パフォーマンスを評価するためには、特にメモリと計算の限界を押し上げる長いシーケンスを反映した実世界の課題を模倣したデータセットが必要です。ここではデ長文要約のために設計されたBookSumというデータセットを利用します。

BookSumは、複数の粒度(段落レベル、章レベル、全書籍レベル)で人間による高品質な要約を提供します。
章レベルは約12,630サンプルが公開され、訓練、検証、テストセットに分割されています

BookSumは、LLMを拡張テキストでテストするのに理想的です。なぜならば、モデルは数千トークンにわたる非自明な依存関係を扱う必要があり、また、その長いシーケンスは、アテンションとフィードフォワード計算中にGPUリソースをストレスを与えるため、Tritonの最適化(融合操作やメモリ効率の高いカーネル)が輝くからです。短いプロンプトではこれらの利点が明らかにならないかもしれませんが、BookSumの長い入力は、ドキュメント要約や分析のような生産環境に似たシナリオでの効率性を強調します。

私たちの実験では、章レベルでの検証分割から最初の10サンプルを使用しました。各サンプルには、しばしば3,000から8,000トークンの長い完全な書籍の章と、簡潔な要約が含まれています。例えば、「The Last of the Mohicans」のような古典の章は、森を通じた追跡やキャラクター分析などの詳細なナラティブ、対話、描写、プロット展開を特徴とします。
私たちは「Summarize the following chapter: [chapter text]」のようなプロンプトを作成し、入力が4,000トークンを超える可能性があり、さらに最大512生成トークンを追加します。

ベースライン: LlamaでのPyTorch推論

私たちのテストでは、Llama-3.1-8B-Instructモデルを使ったバニラPyTorchセットアップを使用しました。説明したように、BookSumデータセットから10の要約プロンプトを作成し、それぞれで最大512の新しいトークンを生成しました。

torch.float16精度でCUDA対応GPU上で実行した標準PyTorchの処理速度の結果は以下の通りです:

  • プロンプトあたりの平均時間: 25.7004秒
  • 10プロンプトの合計時間: 257.00秒
  • 生成トークン合計: 5120トークン
  • トークン/秒: 19.9トークン

これは堅実なパフォーマンスですが、PyTorchの汎用オペレータは、特にアテンションやフィードフォワードネットワークのような計算集約的なレイヤーで最適化の余地を残します。

Tritonの紹介: 定義とその重要性

冒頭で紹介したTritonの本質は、GPUプログラミングの複雑さを抽象化してシンプルに書けるようにすることです。

  • カーネルを書く: GPUで「どうやってデータを読むか、計算するか、保存するか」を、Pythonライクな書き方で関数として書けます。

  • 自動で最適化してくれる: GPUで速く動かすためには、通常なら「並列化の方法」「メモリの使い方」「処理の分割方法」など細かい調整が必要です。Tritonはそれを自動で調整してくれるので、CUDAの専門知識がなくてもGPUをほぼ限界性能で使えます。

  • PyTorchとのつながり: Tritonで作ったカーネルは、PyTorchのtorch.compileを通して呼び出したり、直接オペレーションとして組み込んだりできます。

つまり Triton は、ディープラーニングでよく使う基本的な処理を、GPUに合わせて超効率的に実行できるコードを自動で作ってくれるツールです。

大規模言語モデル(LLaMA など)では、推論の大半の時間が「行列の掛け算」「正規化処理」「活性化関数」に費やされます。Triton を使うと、これらの処理を GPU に最適化して速く実行できる、というのが大きな強みです。

使用方法: Ligerを使ったLlamaへのTritonカーネル統合

Tritonを実践的に活用するため、私たちはLiger-Kernelというオープンソースライブラリを利用しました。これは、Llamaモデル向けに事前記述されたTritonカーネルを提供します。Ligerは、RoPE(Rotary Position Embeddings)、SwiGLU活性化、RMSNorm、クロスエントロピー損失などの重要な操作を最適化し、すべてTritonで実装されて最大効率を実現します。

コードでの使用方法はシンプルです:

  1. 依存関係のインストール: Tritonをインストール(pip install triton)し、transformersとliger-kernelもいれます。
  1. カーネルの適用: Ligerのapply_liger_kernel_to_llama関数を使ってモデルをパッチ:
    from liger_kernel.transformers import apply_liger_kernel_to_llama
    
    apply_liger_kernel_to_llama(
        rope=True,      # Optimize rotary embeddings
        swiglu=True,    # Optimize SwiGLU activations
        cross_entropy=True,  # Faster loss computation
        rms_norm=True   # Optimized normalization
    )
    
  2. モデルのロードと実行: Hugging FaceのAutoModelForCausalLMgenerateメソッドを通常通り進めます。カーネルがPyTorchのデフォルトを自動的に置き換えます。

私たちのスクリプトでは、Ligerをオンに切り替え、同じ要約タスクを再実行しました。大きなコード変更はなく—最適化を適用する数行だけです。

ベンチマークスクリプト: 比較のための完全セットアップ

比較を再現可能にするため、私たちが使用したコアベンチマークスクリプトを紹介します。これにはインポート、データセットロード、プロンプト準備、推論関数、メインストリクトが含まれます。フラグusing_ligerを設定して、標準PyTorchベースラインとLiger経由のTriton最適化バージョンを切り替えます。これにより、最小限の変更で両方のセットアップを実行できます。

import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Function to run inference and measure time + tokens
def run_inference(model, tokenizer, prompts, device):
    model.eval()
    total_time = 0
    total_generated_tokens = 0
    with torch.no_grad():
        for prompt in prompts:
            inputs = tokenizer(prompt, return_tensors="pt").to(device)  
            print(f"Input ids shape: {inputs.input_ids.shape}")
            start_time = time.time()
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id
            )
            end_time = time.time()
            total_time += (end_time - start_time)

            # Calculate generated tokens
            input_len = inputs.input_ids.shape[1]   
            generated_seq = outputs[0, input_len:]  
            generated_number = generated_seq.shape[0]

            total_generated_tokens += generated_number

    avg_time_per_prompt = total_time / len(prompts)
    return avg_time_per_prompt, total_time, total_generated_tokens

# Load a sample dataset for inference
dataset = load_dataset("kmfoda/booksum", split="validation[:10]")
prompts = [f"Summarize the following chapter:\n{chapter}" for chapter in dataset["chapter"]]  # Long inputs for summarization

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
using_liger = False
if using_liger:
    print("Using Liger kernel for optimized inference...")
    # Triton version with Liger kernel
    apply_liger_kernel_to_llama(
        rope=True,
        swiglu=True,
        cross_entropy=True,
        fused_linear_cross_entropy=False,
        rms_norm=True
        )
# Standard Meta-Llama-3.1-8B-Instruct (no Liger)    
model_id = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer_standard = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer_standard.pad_token = tokenizer_standard.eos_token
model_standard = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
avg_time_standard, total_time_standard, total_generated_tokens = run_inference(model_standard, tokenizer_standard, prompts, device)
if using_liger:
    print(f"Liger Model - Avg time per prompt: {avg_time_standard:.4f}s | Total time: {total_time_standard:.2f}s | Total generated tokens: {total_generated_tokens}")
else:
    print(f"Standard Model - Avg time per prompt: {avg_time_standard:.4f}s | Total time: {total_time_standard:.2f}s | Total generated tokens: {total_generated_tokens}")

主要部分の説明

  • インポートとセットアップ: 必要なライブラリをインポートします。これにはtransformers(モデルロード用)、datasets(BookSum用)、liger_kernel(最適化用)が含まれます。デバイスはCUDAが利用可能ならCUDAに設定されます。
  • 推論関数 (run_inference): ベンチマークの中心です。プロンプトをバッチで処理(デフォルトbatch_size=1で長いシーケンスのOOMエラーを回避)、入力をトークナイズ、model.generateを使って生成時間を測定、生成トークンを計算します。関数はプロンプトあたりの平均時間、合計時間、合計生成トークンを返します。
  • データセットとプロンプト: BookSumの検証分割から10の章サンプルをロード。プロンプトは章テキストの前に「Summarize the following chapter:\n」を付けて構築し、モデルを要約に向けます。これらの長い入力(しばしば>4kトークン)は実世界の効率をテストします。
  • Ligerトグル: using_liger = Trueを設定すると、モデルロード前にTritonカーネルを適用します。これによりRoPE、SwiGLU、RMSNormなどの操作を最適化バージョンにパッチします。

このスクリプトは自己完結型で、直接実行可能です(依存関係がインストールされている場合)。結果を出力し、私たちはこれを使ってトークン/秒を計算しました(total_generated_tokens / total_time)。

TritonカーネルがPyTorchより速い理由

Tritonの速度優位性は、GPU効率のための設計に由来し、PyTorchの標準実装をいくつかの点で上回ります:

  • 操作融合: PyTorchは各操作(例: 行列乗算の後に活性化)ごとに別々のカーネルを起動し、複数のGPU呼び出しによるオーバーヘッドが発生します。Tritonはこれらを単一のカーネルに融合し、起動遅延とメモリ帯域使用を削減します。Llamaでは、融合RMSNorm + SwiGLUレイヤーで輝き、重複データ移動を削減します。

  • メモリ最適化: Tritonはメモリアクセスパターンを精密に制御可能で、データを高速共有メモリにフィットさせるタイル分割や、コアレスロードでキャッシュミスを最小化します。PyTorchのオペレータはより汎用的で、大規模テンソル中のアテンションやフィードフォワードパスでサブオプティマルなメモリ使用を引き起こします。

  • ハードウェアに合わせた並列化: Tritonのブロックベースプログラミングで、スレッドとワープの使用を明示的に定義し、GPUアーキテクチャ(例: AmpereやHopper)と一致します。これにより、PyTorchのワンサイズフィットオールアプローチに比べて高い占有率とスループットを実現します。

  • カスタムオペのオーバーヘッド削減: LLMではRoPEのような操作が複雑な三角関数と回転を伴います。TritonのコンパイルカーネルはPyTorchの解釈版やJITコンパイル版より速く実行され、特にスケールで顕著です。

ベンチマークでは、LigerのようなTritonベース最適化はモデルサイズ、シーケンス長、ハードウェアに応じて5-20%の速度向上をもたらします。私たちの結果はこれに一致します:

  • プロンプトあたりの平均時間: 24.4397秒(vs. 25.7004s)
  • 10プロンプトの合計時間: 244.40秒(vs. 257.00s)
  • 生成トークン合計: 5120(生成によるわずかな変動)
  • トークン/秒: 21(vs 19.9トークン)

これは、より多くのプロンプトを速く処理することを意味し—生産環境のバッチ推論で重要です。長いシーケンスや大きなモデルでは、利益が複合します。Tritonの効率が計算強度とともにスケールするからです。

指標 PyTorchベースライン Triton (via Liger) 改善率
プロンプトあたりの平均時間 (s) 25.7004 24.4397 +5%
合計時間 (s) 257.00 244.40 +5.1%
トークン/秒 19.9 21 +5.24%

注: テストはfloat16精度の単一GPUで実行。結果はハードウェアによって異なりますが、相対的な速度向上は保持されます。

Introducing Triton: Open-source GPU programming for ...
Welcome to Triton's documentation! — Triton documentation
Introduction to torch.compile
torch.compiler
BookSum: A Collection of Datasets for Long-form Narrative Summarization
kmfoda/booksum · Datasets at Hugging Face
meta-llama/Llama-3.1-8B-Instruct - Hugging Face
linkedin/Liger-Kernel: Efficient Triton Kernels for LLM ...
liger-kernel

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?