SDG Hubで合成データが作れそうだったので、Training Hub の OSFT も試そうと思っています。が、Runnable なサンプルを動かすのに手間取りました。
いろいろと知識不足なところがあったのでメモしておきます。
環境
-
GCP(GCE)
-
GPU: A100を2枚刺し(VRAM: 80GB)
-
CUDA 12.8 インストール済みの deeplearning-platform-release のイメージを拝借
.tf ファイル(セキュリティの設定などは除外しています)
locals { vm_name = "training-hub-vm" } variable "project_id" { description = "The Google Cloud project ID" type = string } data "google_compute_image" "deeplearning" { name = "common-cu128-ubuntu-2404-nvidia-570-v20251217" project = "deeplearning-platform-release" } # 1. サービスアカウント resource "google_service_account" "training_hub" { project = var.project_id account_id = local.vm_name display_name = "VM (${local.vm_name})用Service Account" } # VM resource "google_compute_instance" "default" { project = var.project_id name = local.vm_name machine_type = "a2-highgpu-2g" # A100*2(vCPUs: 24, RAM: 170 GiB, 80GB) zone = "asia-northeast1-c" boot_disk { initialize_params { image = data.google_compute_image.deeplearning.id size = 256 type = "pd-balanced" } } network_interface { network = "default" access_config { # Ephemeral IPを自動割当 } } metadata = { # ドライバの自動インストール install-nvidia-driver = "True" } scheduling { # GPUを使う場合は TERMINATE が必須 # https://registry.terraform.io/providers/hashicorp/google/latest/docs/resources/compute_instance#guest_accelerator-1 on_host_maintenance = "TERMINATE" # Preemptible VM 用の設定(安価だがGCP側の都合により突然停止することあり) # また、provisioning_model を SPOT にしないと値引きされない罠 # https://registry.terraform.io/providers/hashicorp/google/latest/docs/resources/compute_instance#provisioning_model-1 preemptible = "true" provisioning_model = "SPOT" automatic_restart = "false" } service_account { email = google_service_account.training_hub.email scopes = ["cloud-platform"] } }
1. インストールにめっちゃ時間がかかる
flash-attn, causal-conv1d, mamba-ssm のインストールにとても時間がかかります。なので、python, cuda, pytorch等のバージョンがマッチしたビルド済みの *.whl をありがたく使わせていただくのが吉(私もビルドして貢献しろという話ですけど…)
昨日言ってたflash-attentionのpre-buildやっておいた
— もりりん (@mjun0812) 2024年10月28日
これで地獄のようなbuild時間から逃れられる......
(なお、Github Actionsで各ビルド2h、計11hかかる)https://t.co/c8QBcjQMwl
都合よくマッチした *.whl が無いことも多いので、先にビルド済みの *.whl に合わせて環境構築したほうが早いと思います。今回はこんな感じでインストールしました
uv init -p 3.11
uv sync
# cuda: 12系, torch: 2.7, python: 3.11,
uv add ninja torch==2.7 psutil
uv add "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp311-cp311-linux_x86_64.whl"
uv add "causal-conv1d @ https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.4/causal_conv1d-1.5.4+cu12torch2.7cxx11abiTRUE-cp311-cp311-linux_x86_64.whl"
uv add "mamba-ssm @ https://github.com/state-spaces/mamba/releases/download/v2.2.6.post3/mamba_ssm-2.2.6.post3+cu12torch2.7cxx11abiTRUE-cp311-cp311-linux_x86_64.whl"
uv add training-hub[cuda] --no-build-isolation
2. GPUのアーキテクチャとの不整合
実は最初は A100 ではなく安価な T4 を使って動作確認をしていました
ですが flash-attn はAmpere以降のA100, L4などを対象としており、T4などの古いアーキテクチャはサポートしていません。ちゃんと README.mdにも書いてありますし、エラーログも分かりやすいのでハマるところでは無いですが、長時間のインストール後に発覚すると悲しい気持ちになれます
3. network gIB not found で落ちた
これはGCEを使っているからかもしれませんが、そもそも NCCL の理解もなかったので結構ハマりました
NCCL のログが出ない
ログがほぼ無いまま落ちるのでサンプルを動かすときには NCCL のログレベルを変更しておいたほうが良いです
os.environ["NCCL_DEBUG"] = "INFO"
NCCL_IB_DISABLE=1 でハマった
今回はVMを1つしか使わないので NCCL/gIB を使いません。Geminiに相談しつつそれらしい設定を追加してみました
# 注意: これはGCPでは設定しちゃだめなやつでした
os.environ["NCCL_IB_DISABLE"] = "1"
結果として以下のような warning は出るものの gIB not found の問題は解消されました
/nccl-shim/src/guest_config_checker.cc:101 NCCL WARN NCCL/NET (shim) mismatch recommended: NCCL_IB_DISABLE=1 (expected unset)
一方でその後の処理で Segmentation Fault (exitcode: -11) が発生するようになり、ハマってしまいました
そうなんです。NCCL_IB_DISABLE=1 は罠だったんですね。
実は NCCL/gIB を無効化するには NCCL_NET を Socket にするのが正しいようです。
os.environ["NCCL_NET"] = "Socket"
GCE では良しなに NCCL の最適化をしています。細かいことは分かっていませんが warning で教えてくれていたように NCCL_IB_DISABLE は設定しちゃだめな変数だったみたいです。nvidia のドキュメントを見ると NCCL_IB_DISABLE でも良さそうな気がしちゃったのですけどね。というか未だによく分かっていないところがあるので勉強が必要ですね。
4. torch 2.9以降じゃないと動かない
pyproject.toml を見ると torch>=2.6 となっています
でも、処理の途中で呼ばれる mini-trainer の osft_utils.py では torch.distributed.send_object_list() や torch.distributed.recv_object_list() に use_batch=True を渡しています。これは torch>=2.9 以降に入ったものでした。
で、torch>=2.9 だと flash-attn などのビルド済みファイルがなく、また今さらビルドする気にもなれず…
今回は use_batch を削除しても問題なさそうだったので、パッチをあてて対応することにしました
これでようやく学習処理が進みました。GPU2台をしっかり使ってくれています
以上、本当は学習結果を紹介して Advent Calendar 2025 を終わろうとしたら想定外にハマってしまったお話でした
