Hierarchical Reasoning Model (HRM)という新しい推論AIが俄かに注目を集めている。自然言語での推論能力の向上を目指してきたLLMとは異なり,自然言語によらず時間軸の異なる2種類の再帰モジュールで推論するAIである。
Hierarchical Reasoning Modelのコードが論文著者たちの公式レポジトリhttps://github.com/sapientinc/HRM にあるため,手元のWSL2環境で動かせるRTX 3060(VRAM 12GB)1個で再現した。オリジナルの論文・公式レポジトリでは,ARC-AGI-2,Sudoku 9x9 Extreme,Maze 30x30 Hardという3種類のデータそれぞれで,モデル構築を行い評価している。今回私は,このうちQuick Demoとして紹介されており,GPUを8個使い分散学習すればたった10分で事前学習できるというSudoku Extreme(平たく言えば難しい数独)の学習に実際に取り組んだ。
学習の結果はおおむね論文通り再現された。しかし,必ずしもVRAMの大きくないGPU1個では,元の論文で主張されているような高速・高効率な学習と推論はできなかった。
また,そもそも再現までの実装に結構手間取った。HRMの説明自体は日経XTECHはじめ日本語で書かれた各種媒体で目にするが,実際に実装した結果をまとめた日本語記事は見つけられなかった。そのため,この再現の取り組みをまとめて備忘とするとともに,皆さんにご批正賜りたく思う。
以下,本稿は「再現してみての結果と所感」と「再現実装」の大きく2節構成である。お手元で再現されたい方は,「再現実装」まで読み進められたい。
再現してみての結果と所感: 論文通りの学習結果だが,VRAMの小さいGPUではやはり実行に時間がかかる
- 手元で構築した事前学習済みモデルによる推論(評価)の結果は,正解率48.2%。元の報告を下回ったが,下回った理由も含めて想定内と思える結果
- ただし,VRAM 12GBのGPUでは,学習にも推論にもかなり時間がかかる。やはりクラウドコンピューティングサービスで,高性能GPUを使った方がいい
- 学習: 1日と7時間38分
- 推論: 1時間25分
- (そもそもHRMは,数独やARC-AGI[色板の並べ方のパターンを複数例見た後で,その規則に従い適切に色板を並べるタスク]のような非言語的記号の操作・座標を扱うタスクの推論を行うモデルのようである。ロボットや自動運転など座標が関わる用途には向いているかもしれないと思った。一方,通常のLLMに期待するような自然言語が関わるタスク,たとえば内容が充実したレポート生成や有効なネクストアクション・戦略作成などをHRMでも行えるようにするには,学習データをかなり工夫しなければならないと思う。ただ,本来はそのような自然言語が関わるタスクはHRMの範囲外かもしれない。)
再現実装
環境構築
WSL2環境について
- OS: Ubuntu 24.04
- ローカルレポジトリの場所:
/home/<user_name>/<HRMの再現レポジトリ>
=~/<HRMの再現レポジトリ>
- Pythonの環境: 「Python環境の構築」節にて後述
※WSL2でHRMを再現する方法として,https://github.com/sapientinc/HRM/pull/37 で紹介されている通りmnt/c/...
(WSL2環境から参照するWindowsネイティブ環境位置)に再現レポジトリを置くこともできる。ただ,この方法では,WSL2環境とWindowsネイティブ環境の間でファイルの読み取り・書き込みが生じる。この環境間をまたぐデータ入出力はタイムロスにつながる(参考: WindowsとLinux間でのファイルの読み取り・書き込み性能は,WSL2の方がWSL1より遅い)ため,WSL2環境にhttps://github.com/sapientinc/HRM をgit cloneすることで,WSL2環境単独で完結する再現環境を作った。
CUDA 12.6のインストール(HRMの手順通り)
# Install CUDA 12.6
CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run
wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer.run $CUDA_URL
sudo sh cuda_installer.run --silent --toolkit --override
export CUDA_HOME=/usr/local/cuda-12.6
※上記の操作はHRMの手順通りだが,実際にはこの後,HRMの手順にない内容として,
-
~/.bashrc
に下記の3行を追加し保存し,export CUDA_HOME=/usr/local/cuda-12.6 export LD_LIBRARY_PATH=/usr/local/cuda-12.6/lib64:$LD_LIBRARY_PATH export PATH=/usr/local/cuda-12.6/bin:$PATH
-
source ~/.bashrc
を実行し, -
nvcc --version
でCuda compilation tools, release 12.6, V12.6.85
と表示されることを確認した。
これらの操作により,確実にCUDA 12.6が使えることを確かめた。
Python環境の構築(HRMの手順通りに行わない)
パッケージのインストールには,HRMのREADME.mdで使われているpipではなく,uvを使った。uvの方がパッケージインストールが高速で,パッケージのバージョンコントロールも容易だからだ。uvによるパッケージインストールの際に,以下のようなpyproject.tomlを作った。こちらをuvとPython仮想環境(Python 3.13以上)がある環境にダウンロードすれば,uv sync --extra build; uv sync --extra build --extra compile
をターミナルで実行することでパッケージインストールができるはずだ。なお,単にuv sync
ではflash-attnなどが正常にインストールできないことに注意。
なお,HRM公式レポジトリのissueのひとつに「pyproject.tomlにより完全なパッケージバージョンコントロールを行わないか」旨提案があるが,却下されている。私個人はpyproject.tomlがあった方がパッケージインストール時の事故は防ぎやすいと考える。
[project]
name = "hrm-replicate"
version = "0.1.0"
description = "A uv project to replicate the implementation of Hierarchical Reasoning Model easily"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"adam-atan2-pytorch>=0.1.18",
"argdantic>=1.3.3",
"coolname>=2.2.0",
"einops>=0.8.1",
"flash-attn",
"huggingface-hub>=0.34.4",
"hydra-core>=1.3.2",
"ninja>=1.11.1.4",
"omegaconf>=2.3.0",
"packaging>=24.1",
"pydantic>=2.11.7",
"setuptools>=70.2.0",
"setuptools-scm>=8.3.1",
"torch<2.8.0",
"torchaudio<2.8.0",
"torchvision<0.23.0",
"tqdm>=4.66.5",
"wandb>=0.21.1",
"wheel>=0.45.1",
]
[tool.uv]
extra-index-url = ["https://download.pytorch.org/whl/cu126"]
no-build-isolation-package = ["flash-attn", "adam-atan2"]
[[tool.uv.index]]
url = "https://download.pytorch.org/whl/cu126"
[project.optional-dependencies]
build = ["torch", "setuptools", "packaging"]
compile = ["flash-attn"]
[[tool.uv.dependency-metadata]]
name = "flash-attn"
version = "2.8.2"
requires-dist = ["torch", "einops"]
[tool.uv.sources]
flash-attn = { url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.2/flash_attn-2.8.2+cu12torch2.7cxx11abiTRUE-cp313-cp313-linux_x86_64.whl" }
このpyproject.tomlでパッケージをインストールした仮想環境には,以下のようなパッケージが入る。
`uv pip list`の実行結果
Package Version
------------------------ ------------
adam-atan2-pytorch 0.1.18
annotated-types 0.7.0
antlr4-python3-runtime 4.9.3
argdantic 1.3.3
certifi 2022.12.7
charset-normalizer 2.1.1
click 8.2.1
coolname 2.2.0
einops 0.8.1
filelock 3.13.1
flash-attn 2.8.2
fsspec 2024.6.1
gitdb 4.0.12
gitpython 3.1.45
hf-xet 1.1.7
huggingface-hub 0.34.4
hydra-core 1.3.2
idna 3.4
jinja2 3.1.4
markupsafe 3.0.2
mpmath 1.3.0
networkx 3.3
ninja 1.11.1.4
numpy 2.1.2
nvidia-cublas-cu12 12.6.4.1
nvidia-cuda-cupti-cu12 12.6.80
nvidia-cuda-nvrtc-cu12 12.6.77
nvidia-cuda-runtime-cu12 12.6.77
nvidia-cudnn-cu12 9.5.1.17
nvidia-cufft-cu12 11.3.0.4
nvidia-cufile-cu12 1.11.1.6
nvidia-curand-cu12 10.3.7.77
nvidia-cusolver-cu12 11.7.1.2
nvidia-cusparse-cu12 12.5.4.2
nvidia-cusparselt-cu12 0.6.3
nvidia-nccl-cu12 2.26.2
nvidia-nvjitlink-cu12 12.6.85
nvidia-nvtx-cu12 12.6.77
omegaconf 2.3.0
packaging 24.1
pillow 11.0.0
platformdirs 4.3.8
protobuf 6.31.1
pydantic 2.11.7
pydantic-core 2.33.2
pydantic-settings 2.10.1
python-dotenv 1.1.1
pyyaml 6.0.2
requests 2.28.1
sentry-sdk 2.34.1
setuptools 70.2.0
setuptools-scm 8.3.1
smmap 5.0.2
sympy 1.13.3
torch 2.7.1+cu126
torchaudio 2.7.1+cu126
torchvision 0.22.1+cu126
tqdm 4.66.5
triton 3.3.1
typing-extensions 4.12.2
typing-inspection 0.4.1
urllib3 1.26.13
wandb 0.21.1
wheel 0.45.1
uvにした際のポイント
- flash-attnがtorch 2.7までしかサポートしていないため,torch<2.8.0にした
- flash-attnのインストールを高速化するため,flash-attnのレポジトリで配布されているwhlファイルのうち,環境に合うものを選んでインストールする設定にした(
flash-attn = { url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.2/flash_attn-2.8.2+cu12torch2.7cxx11abiTRUE-cp313-cp313-linux_x86_64.whl" }
の箇所)
参考: https://github.com/Dao-AILab/flash-attention/issues/945#issuecomment-2948520692 - オリジナルのHRMではadam-atan2というパッケージを使うが,事前学習の実行時に
No module named 'adam_atan2_backend'
というエラーが出る。そのため,下記の2操作を行った- adam-atan2ではなく,adam-atan2-pytorchをインストール
- ターミナル(bash)で
sed -i 's/adam_atan2/adam_atan2_pytorch/g' pretrain.py; sed -i 's/AdamATan2/AdamAtan2/g' pretrain.py; sed -i 's/lr=0,/lr=0.0001,/g' pretrain.py
の実行により,adam-atan2-pytorchが使えるようにpretrain.pyを書き換え
参考: https://github.com/sapientinc/HRM/issues/25#issuecomment-3162492484
sudoke-extreamの事前学習の実行
W&Bへのログイン(HRMの手順通り)
wandb login
Quick Demo用の学習データを準備(HRMの手順通り)
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000
Quick Demo用の学習を回す(HRMの手順通りに行わない)
OMP_NUM_THREADS=32 uv run python pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 global_batch_size=768 lr=7e-5 puzzle_emb_lr=7e-5 weight_decay=1.0 puzzle_emb_weight_decay=1.0
- オリジナルの
OMP_NUM_THREADS=8
から,手元環境のCPU数に合わせOMP_NUM_THREADS=32
に増加 - オリジナルの
global_batch_size=384
から,手元環境のGPU VRAMに合わせglobal_batch_size=768
に増加- 結果,GPU VRAM専用GPUメモリいっぱいの11.6GBまで学習に使えた
- だが,公式READMEには,single GPU, smaller batch sizeと案内されているため,おそらく正しくないかもしれない
- ただ,
global_batch_size
を384以下にすると,GPU VRAM専用GPUメモリにだいぶ空きができる(global_batch_size=384
で7.5GBしか使われない)ため,GPU VRAM専用GPUメモリがフル活用されていないように感じられてしまう…
HRM公式READMEには,8個のGPUで分散学習すればOMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
は10分で完了すると書かれていた。だが,実際には自分の環境で1日と7時間38分もかかった。VRAM 12GBのGPU1個では,ここまで時間がかかるのである。やはり27Mモデルであっても「学習の沙汰も金次第」なんだなあと悲しくなった。
なお,今回使用したパソコンは液冷式だが,1日と7時間38分もの間,GPUの温度は80℃前後だった。GPUが焼け付きそうな環境では実施するのは困難である。やはり,適切なクラウドコンピューティングサービスで大容量VRAMのGPUを使う方がいい。
学習完了後のターミナル画面
wandb: Run history:
wandb: num_params ▁
wandb: train/accuracy ▁▆▆▆▆▆▆▇▇▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇█▇▇▇████▇
wandb: train/count ▁▁▁█████████████████████████████████████
wandb: train/exact_accuracy ▁▂▁▁▁▁▁▁▃▂▃▁▃▂▅▃▄▄▄▄▅▇▇▅▆▆▇▇█▇▆▅▄▆▇██▆▇▇
wandb: train/lm_loss █▅▅▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▁▁▁
wandb: train/lr ▁▂▃▃▅███████████████████████████████████
wandb: train/q_continue_loss ▁▁▁▁▁▅▃▃▂▅▅▄▃▃▃▄▄▇▄▃▄▅▅▅▅▆▆▆▆▆█▇▆█▇█▇▆██
wandb: train/q_halt_accuracy ▁███████████████████████████████████████
wandb: train/q_halt_loss ▂▁▁▁▁▃▂▂▃▂▅▂▃▃▇▁▃▆▄▃▃▄▃▇▄▅▇▃▄▄▃▄▅█▄▅▆▄▃▃
wandb: train/steps ▁▁▁▁████████▇▇▇██▇▇▇▇▇▇▆▇▆▆▆▆▅▅█▇█▅██▅▆▅
wandb:
wandb: Run summary:
wandb: num_params 27275266
wandb: train/accuracy 0.91679
wandb: train/count 1
wandb: train/exact_accuracy 0.73973
wandb: train/lm_loss 0.59844
wandb: train/lr 7e-05
wandb: train/q_continue_loss 0.49648
wandb: train/q_halt_accuracy 1
wandb: train/q_halt_loss 0.01292
wandb: train/steps 8.47945
wandb:
wandb: 🚀 View run HierarchicalReasoningModel_ACTV1 <run_name> at: https://wandb.ai/<user_name>/Sudoku-extreme-1k-aug-1000%20ACT-torch/runs/<run_id>
wandb: ⭐️ View project at: https://wandb.ai/<user_name>/Sudoku-extreme-1k-aug-1000%20ACT-torch
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/<run-id>/logs
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉| 26040/26041 [31:38:55<00:04, 4.38s/it]
しつこいが,1日と7時間38分,学習に費やしている。
学習後のW&Bのダッシュボード
事前学習済みモデルの評価試行(推論の実行)
> OMP_NUM_THREADS=1 torchrun --nproc-per-node 1 evaluate.py checkpoint=checkpoints/Sudoku-extreme-1k-aug-1000\ ACT-torch/HierarchicalReasoningModel_ACTV1\ <自動で付与される試行の名称>/step_26040
上記コードにより,data/sudoku-extreme-1k-aug-1000/test/dataset.jsonのtotal_groups
を見る限り422,786問もの数独の問題を解く推論が始まる。
実行結果は,下記の通り。学習と同様,推論にも時間がかかっている。今回の推論は1時間25分だった。なお,H200(VRAM 141GB)1個でSudoku Extremeの評価を実行すると数分で完了するというコメントがあった。推論実行でも,「持つべきものは高性能なGPU」である。
Starting evaluation
[rank0]:W0811 10:05:34.779000 2170142 .venv/lib/python3.13/site-packages/torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode
{'all': {'accuracy': np.float32(0.8218188), 'exact_accuracy': np.float32(0.4816905), 'lm_loss': np.float32(0.41667807), 'q_halt_accuracy': np.float32(0.9927883), 'q_halt_loss': np.float32(0.044404175), 'steps': np.float32(16.0)}}
[rank0]:[W811 11:30:14.788350826 ProcessGroupNCCL.cpp:1479] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
著者のコメントによるとexact_accuracy
が正解率であるため,私の再現では正解率が48.2%だった。元の論文では下記の図の通りSudoku Extremeの正解率が55.0%と報告されているため,今回の結果はその成績を下回った。この点についてはHRM公式レポジトリでも疑義照会がされている。これに対し,論文著者たちは「1000件のトレーニングセットは小さいため、多少のばらつきが生じる可能性がある。トレーニング時間をやや長くし、過学習直前で早期終了するようにすればよい。(引用者注: 正解率の)標準偏差は約2%程度だった」旨説明がある。また,RTX5090を使って2回Sudoku Extremeの学習と評価を行った際に,2回とも55%程度の正解率だったというコメントもある。そのため,今回の正解率も,ある程度想定の範囲内と思われる。
参考: