Investigating and Resolving a PyTorch cuBLAS Error (CUBLAS_STATUS_INVALID_VALUE) on the RTX 5090 (Blackwell Architecture)


Verification Report: PyTorch cuBLAS Error (CUBLAS_STATUS_INVALID_VALUE) on the RTX 5090

1. Overview

On an NVIDIA GeForce RTX 5090 (Blackwell architecture) system, the process crashed during LLM (Qwen-3) inference. The error log recorded CUBLAS_STATUS_INVALID_VALUE, suggesting either invalid parameters being passed to a cuBLAS library call or a compatibility problem.

The investigation found that the issue occurs specifically during low-precision (FP16/BF16) matrix multiplication under PyTorch 2.10.0. As a fix, changing the PyTorch version to 2.8.0 was confirmed to restore normal inference.
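
For reference, the failure can be reproduced without Transformers. The sketch below is a minimal illustration assuming a CUDA-capable setup; the matrix shapes follow the verification script in the appendix.

```python
import torch

# Minimal illustration of the failing operation (shapes follow the appendix script).
# On the affected PyTorch 2.10.0 + RTX 5090 setup, the FP16/BF16 cases raised
# CUBLAS_STATUS_INVALID_VALUE; on PyTorch 2.8.0 all three complete normally.
m, k, n = 128, 4096, 4096
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    a = torch.randn(m, k, device="cuda", dtype=dtype)
    b = torch.randn(k, n, device="cuda", dtype=dtype)
    c = torch.matmul(a, b)      # FP16/BF16 go through cublasGemmEx internally
    torch.cuda.synchronize()    # surface asynchronous CUDA errors right here
    print(dtype, "OK", tuple(c.shape))
```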

2. Environment

  • GPU: NVIDIA GeForce RTX 5090
  • CUDA Driver: 12.8
  • Python: 3.12
  • PyTorch version exhibiting the issue: 2.10.0+cu128
  • Model under test: Qwen/Qwen3-4B-Instruct-2507

3. Verification Process

To identify where the failure occurs, the problem was isolated with the following three-stage set of scripts (a quick environment sanity check is sketched after the list).

  1. Core Matrix Operations (cuBLAS Check)
    • Run matrix multiplications directly on the GPU with PyTorch's torch.matmul, bypassing the Transformers library.
    • Compare behavior across data types (FP32, FP16, BFloat16).
  2. Transformers Execution (Real-world Load)
    • Verify the actual LLM inference flow.
    • Test both Eager mode and SDPA (Scaled Dot Product Attention).
  3. BitsAndBytes Workarounds
    • Verify behavior with the BitsAndBytes quantization library (4-bit NF4 / 8-bit Int8).
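
Before running the three phases, it can also help to confirm that the installed PyTorch wheel actually ships Blackwell kernels. The following sketch is not part of the original verification script and assumes the RTX 5090 reports compute capability 12.0 (sm_120).

```python
import torch

# Sanity check: does this PyTorch build include kernels for the GPU's architecture?
print("PyTorch    :", torch.__version__, "| CUDA:", torch.version.cuda)
print("Device     :", torch.cuda.get_device_name(0))
print("Capability :", torch.cuda.get_device_capability(0))  # expected (12, 0) on RTX 5090
print("Arch list  :", torch.cuda.get_arch_list())           # 'sm_120' should appear here
```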

4. Detailed Results

The execution results are compared across PyTorch versions below.

4.1. Result Summary by Version

| Phase | Test case | PyTorch 2.10.0 (affected) | PyTorch 2.8.0 (working) |
| --- | --- | --- | --- |
| 1. Matrix ops | FP32 (TF32) | PASS | ✅ PASS |
| 1. Matrix ops | FP16 (Half) | FAIL | ✅ PASS |
| 1. Matrix ops | BF16 (BFloat16) | FAIL | ✅ PASS |
| 2. LLM inference | BF16 + Eager | FAIL | ✅ PASS |
| 2. LLM inference | BF16 + SDPA | FAIL | ✅ PASS |
| 3. Quantization | 4-bit / 8-bit | FAIL | ✅ PASS |

4.2. Error Analysis (PyTorch 2.10.0)

Under PyTorch 2.10.0, every operation test except FP32 failed with the following error.

Error Message:

CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasGemmEx(...)

This error indicates that cublasGemmEx received invalid arguments. It occurred only for matrix multiplications using the CUDA_R_16F (FP16) and CUDA_R_16BF (BF16) data types and never for FP32 (CUDA_R_32F), which suggests a defect in the binding or implementation between PyTorch 2.10.0 and the Blackwell-targeted CUDA kernels (particularly the low-precision paths).

Because the Transformers library and BitsAndBytes also perform FP16/BF16 matrix multiplications and accumulation internally, the crash cascades into those code paths as well.
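
As an illustration of why the crash propagates, any BF16 linear layer ends up on the same cuBLAS GEMM path as the raw matmul above. The snippet below is a hypothetical minimal example, not taken from the verification script.

```python
import torch

# A bf16 nn.Linear forward is lowered to a cuBLAS GEMM (the same cublasGemmEx
# path hit by torch.matmul), so the failure resurfaces inside Transformers.
layer = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
y = layer(x)
torch.cuda.synchronize()
print(y.shape)
```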

4.3. Confirmation of Normal Operation (PyTorch 2.8.0)

After switching the PyTorch version to 2.8.0+cu128, every test case (FP16/BF16 matrix multiplication, text generation via Transformers, and quantized inference) completed successfully.

5. Conclusion and Recommendation

Conclusion

The CUBLAS_STATUS_INVALID_VALUE error observed on the RTX 5090 is caused by a defect in low-precision (FP16/BF16) matrix multiplication in PyTorch 2.10.0.

Recommendation

To run LLM inference such as Qwen-3 stably on the current RTX 5090 environment, the following measure is recommended.

  • Downgrade PyTorch: switch from 2.10.0+cu128 to 2.8.0+cu128, the version verified in this report.

With this change, inference in FP16/BF16 precision and quantized inference via BitsAndBytes are expected to work normally.
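
For reference, the verified build can typically be pinned with a command along the lines of `pip install torch==2.8.0 --index-url https://download.pytorch.org/whl/cu128`; the exact package index and companion packages (e.g. torchvision) depend on your environment, so treat this as an illustration rather than a prescription.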

Appendix

The script used for the verification and its output are shown below.

Verification script

import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# --- Configuration ---
MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
DEVICE = "cuda"


def print_header(title):
    print("\n" + "=" * 70)
    print(f" {title}")
    print("=" * 70)


def log_result(test_name, status, error_msg=None):
    if status:
        print(f"[{test_name:<35}] -> ✅ PASS")
    else:
        print(f"[{test_name:<35}] -> ❌ FAIL")
        if error_msg:
            # Shorten the error message for display (strip newlines)
            short_msg = str(error_msg).replace("\n", " ")[:120]
            print(f"    Error: {short_msg}...")


def clear_cache():
    """GPUメモリを強制解放"""
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()


# ==========================================
# Phase 1: Core Matrix Operations (cuBLAS Check)
# ==========================================
def run_matmul_test(dtype, label):
    try:
        # Minimal reproduction with power-of-two matrix sizes
        m, k, n = 128, 4096, 4096
        a = torch.randn(m, k, device=DEVICE, dtype=dtype)
        b = torch.randn(k, n, device=DEVICE, dtype=dtype)

        # Run the matrix multiplication
        c = torch.matmul(a, b)
        torch.cuda.synchronize()

        log_result(label, True)
        return True
    except Exception as e:
        log_result(label, False, e)
        return False


# ==========================================
# Phase 2: Transformers Execution (Real-world Load)
# ==========================================
def run_transformers_test(dtype, attn_impl, label):
    print(f"   Running Transformers load & gen ({label})...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

        # Use device_map=None and move to DEVICE manually to keep placement explicit
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            dtype=dtype,
            attn_implementation=attn_impl,
            trust_remote_code=True,
            device_map=None,
        )
        model.to(DEVICE)

        inputs = tokenizer("Hello", return_tensors="pt").to(DEVICE)
        with torch.inference_mode():
            _ = model.generate(**inputs, max_new_tokens=10)

        log_result(f"Transformers ({label})", True)
        del model, tokenizer
        return True
    except Exception as e:
        log_result(f"Transformers ({label})", False, e)
        return False
    finally:
        clear_cache()


# ==========================================
# Phase 3: BitsAndBytes Workarounds
# ==========================================
def run_bnb_test(load_in_4bit, load_in_8bit, label):
    print(f"   Running BitsAndBytes test ({label})...")
    try:
        kwargs = {}
        if load_in_4bit:
            kwargs["load_in_4bit"] = True
            kwargs["bnb_4bit_compute_dtype"] = torch.bfloat16
        if load_in_8bit:
            kwargs["load_in_8bit"] = True
            kwargs["llm_int8_threshold"] = 6.0

        bnb_config = BitsAndBytesConfig(**kwargs)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

        # device_map="auto" is recommended when using BitsAndBytes
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            quantization_config=bnb_config,
            trust_remote_code=True,
            device_map="auto",
        )

        inputs = tokenizer("Hello", return_tensors="pt").to(DEVICE)
        with torch.inference_mode():
            _ = model.generate(**inputs, max_new_tokens=10)

        log_result(f"BnB ({label})", True)
        del model, tokenizer
        return True
    except Exception as e:
        log_result(f"BnB ({label})", False, e)
        return False
    finally:
        clear_cache()


# ==========================================
# Main Execution Flow
# ==========================================
def main():
    print_header("RTX 5090 (Blackwell) Compatibility Verification Report Check")
    print(f"PyTorch Version : {torch.__version__}")
    print(f"CUDA Version    : {torch.version.cuda}")
    print(
        f"Device Name     : {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}"
    )
    print(f"Test Model      : {MODEL_ID}")

    # --- Phase 1: Core Matrix Operations ---
    print_header("1. Core Matrix Operations (cuBLAS Check)")
    torch.backends.cuda.matmul.allow_tf32 = True

    # Float32
    run_matmul_test(torch.float32, "FP32 (TF32=True)")

    # Float16
    run_matmul_test(torch.float16, "FP16 (Half)")

    # BFloat16
    run_matmul_test(torch.bfloat16, "BFloat16 (GemmEx)")

    # --- Phase 2: Transformers Execution ---
    print_header("2. Transformers Execution (Real-world Load)")

    # Eager mode (SDPA disabled)
    run_transformers_test(torch.bfloat16, "eager", "BF16 + Eager Mode")

    # SDPA mode (SDPA enabled)
    run_transformers_test(torch.bfloat16, "sdpa", "BF16 + SDPA Mode")

    # --- Phase 3: BitsAndBytes Workarounds ---
    print_header("3. BitsAndBytes Workarounds")

    # 4-bit
    run_bnb_test(load_in_4bit=True, load_in_8bit=False, label="4-bit (NF4)")

    # 8-bit
    run_bnb_test(load_in_4bit=False, load_in_8bit=True, label="8-bit (Int8)")

    print("\n" + "=" * 70)
    print(" Verification Complete.")
    print("=" * 70)


if __name__ == "__main__":
    main()

Execution results (PyTorch 2.8.0)

test.py

======================================================================
 RTX 5090 (Blackwell) Compatibility Verification Report Check
======================================================================
PyTorch Version : 2.8.0+cu128
CUDA Version    : 12.8
Device Name     : NVIDIA GeForce RTX 5090
Test Model      : Qwen/Qwen3-4B-Instruct-2507

======================================================================
 1. Core Matrix Operations (cuBLAS Check)
======================================================================
[FP32 (TF32=True)                   ] -> ✅ PASS
[FP16 (Half)                        ] -> ✅ PASS
[BFloat16 (GemmEx)                  ] -> ✅ PASS

======================================================================
 2. Transformers Execution (Real-world Load)
======================================================================
   Running Transformers load & gen (BF16 + Eager Mode)...
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 255.88it/s]
[Transformers (BF16 + Eager Mode)   ] -> ✅ PASS
   Running Transformers load & gen (BF16 + SDPA Mode)...
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 260.07it/s]
[Transformers (BF16 + SDPA Mode)    ] -> ✅ PASS

======================================================================
 3. BitsAndBytes Workarounds
======================================================================
   Running BitsAndBytes test (4-bit (NF4))...
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.15s/it]
[BnB (4-bit (NF4))                  ] -> ✅ PASS
   Running BitsAndBytes test (8-bit (Int8))...
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.02it/s]
[BnB (8-bit (Int8))                 ] -> ✅ PASS

======================================================================
 Verification Complete.
======================================================================

Execution results (PyTorch 2.10.0)

test.py

======================================================================
 RTX 5090 (Blackwell) Compatibility Verification Report Check
======================================================================
PyTorch Version : 2.10.0+cu128
CUDA Version    : 12.8
Device Name     : NVIDIA GeForce RTX 5090
Test Model      : Qwen/Qwen3-4B-Instruct-2507

======================================================================
 1. Core Matrix Operations (cuBLAS Check)
======================================================================
[FP32 (TF32=True)                   ] -> ✅ PASS
[FP16 (Half)                        ] -> ❌ FAIL
    Error: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasGemmEx( handle, opa, opb, m, n, k, alpha_ptr, a, CUDA_R_16F,...
[BFloat16 (GemmEx)                  ] -> ❌ FAIL
    Error: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16BF, ...

======================================================================
 2. Transformers Execution (Real-world Load)
======================================================================
   Running Transformers load & gen (BF16 + Eager Mode)...
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 227.64it/s]
[Transformers (BF16 + Eager Mode)   ] -> ❌ FAIL
    Error: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16BF, ...
   Running Transformers load & gen (BF16 + SDPA Mode)...
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 263.69it/s]
[Transformers (BF16 + SDPA Mode)    ] -> ❌ FAIL
    Error: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16BF, ...

======================================================================
 3. BitsAndBytes Workarounds
======================================================================
   Running BitsAndBytes test (4-bit (NF4))...
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.18s/it]
[BnB (4-bit (NF4))                  ] -> ❌ FAIL
    Error: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasGemmEx( handle, opa, opb, m, n, k, alpha_ptr, a, CUDA_R_16F,...
   Running BitsAndBytes test (8-bit (Int8))...
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.06s/it]
[BnB (8-bit (Int8))                 ] -> ❌ FAIL
    Error: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasGemmEx( handle, opa, opb, m, n, k, alpha_ptr, a, CUDA_R_16F,...

======================================================================
 Verification Complete.
======================================================================