JAXをNVIDIAのGPU環境で使えるようにする方法(debian 12)
備忘録とJAXの啓蒙活動。
- CUDA, cuDNN (NVIDIAドライバ)のインストール方法
- JAXのソースコードからのビルド
の2本立て。
Pytorchはほぼほぼ消えます。なぜならば、あらゆる種類の計算ハードウェア(CPU, GPU, TPU, FPGA, ASIC)に対して深層学習の計算グラフが記述される一方、その計算グラフモデルのパラメータは記録方法もキャッシュ方法もモデル・パラレリズムの観点でも、記録方法がハードウェアに束縛されてはならないからです。
御託を抜きにして使い方を説明します。
JAXはCPUで動かすのは簡単です。JAXをNVIDIAのGPUで動かすときには、Pytorchユーザーの皆さんの思っている、
- CUDAのバージョン
との格闘ではなく、
- CUDAのバージョン(nvcc, つまりCUDAコンパイラのバージョン)
- NVIDIAドライバのバージョン(上のnvccに対応したnvidia-smiが使えるようにする)
- cuDNNのバージョン(上のnvccに対応したcuDNNのバージョンを使えるようにする)
との格闘になります。残念ながらpipで配布されているバイナリは使い物にならないと心得てください。
すなわち、基本的にはあなたの環境にあった設定で、JAXをソースコードからコンパイルする、これが一番です。(私は配布されているバイナリをNVIDIAドライバ、CUDA、cuDNNのをさんざんバージョンを合わせたうえでpip installしたのですが、悉く失敗し、うまく行ったとしてもcuDNNのinitializationがfailしてしましました。あるいはnvidia-smi
がそもそも飛んでGPUが認識されないとか、お恥ずかしいながら)
前提知識
CUDAというのはNVIDIAのGPUの上でGeneral Purpose GPU(GPGPU)計算を行うためのフレームワークで、多くの場合はCUDA Cというプログラミング言語を使ってGPGPU計算を記述します。このコンパイルにはCUDA C用のクロスターゲット・コンパイラが必要なわけで、それがnvcc
です。PytorchやJAXのソースコードにはGPGPUの計算を利用するためのCUDA Cのコードが含まれており、自前でレポジトリをコンパイルするにはCUDAのコンパイラ(CUDA Toolkitでインストールできる)が必要です。
今使っているコンピュータにGPUが入っている場合、そのGPUが計算資源として使えるためにはNVIDIAドライバが入っている必要があります。そのドライバっていうのが入っているとnvidia-smi
というコマンドでどのGPUが入っているか確認できます。
cuDNNは「まじでよくわからん」です。要するにたぶん深層学習用のCUDAライブラリなのだと理解しているのですが、私自身はCUDAコーディングでこのcuDNNを使ったことが一切ないので、ほとほと分かりません。JAXのようなフレームワークをコンパイルするのにcuDNNのライブラリにコードが依存しているらしく、これをダウンロードしておかないとコンパイルできないので入れます。
各種コマンドお困りのときに使う
Debianのバージョンの確認
cat /etc/debian_version
CUDAのパッケージ(cuda-toolkit)がどれが入っているか確認する。
dpkg --list | grep cuda
cuDNNのバージョンの確認。
dpkg --list | grep cudnn
NVIDIAドライバ関連が失敗した場合に、特にnvidia-smi
が使えなくなります。そのような場合には
apt-get purge nvidia-*
で全て消します。
同じ要領でCUDA関連(CUDA Toolkit)を削除する場合は
apt-get purge cuda*
で全て消します。
同じ要領でcuDNNを削除する場合は
apt-get purge *cudnn*
で全て消します。
それでも残っていて、cuDNNのバージョンが気に食わない場合は
dpkg --purge [target package]
で消し飛ばしたいわけですが、
dpkg: warning: 'ldconfig' not found in PATH or not executable
dpkg: warning: 'start-stop-daemon' not found in PATH or not executable
dpkg: error: 2 expected programs not found in PATH or not executable
Note: root's PATH should usually contain /usr/local/sbin, /usr/sbin and /sbin
と出るので
export PATH=$PATH:/usr/local/sbin:/usr/sbin:/sbin
をしましょう。次にdpkg --purge [target package]
を行います。消したいパッケージはdpkg --list | grep cudnn
で調べて消します。
上記のコマンドを使っても、こびりついて残っているパッケージの亡霊を削除するには
apt-get autoremove
GPUドライバの設定をやっていると頻繁にドライバが壊れてバージョンが合わず、GUI画面が出せなくなります。古典的な計算機のログイン画面(login, userpasswordの組の奴)にするには
systemctl set-default multi-user.target
にするとドライバ処理をするときは安全です。グラフィカルに戻すときは
systemctl set-default graphical.target
でデスクトップ環境に戻ります。
開発環境
私の2024/11月現在の自宅の開発環境は次の通り。CPUとメモリは多分関係ない。自宅のコンピュータなので、単純にATXマザーボードで筐体に入れたコンピュータです。SSHだけは有効になっている気がする。
Debian 12.7
CPU: Xeon 24 threads
Memory: 64 GB
GPU: NVIDIA RTX 3060 (12 GB)
1)NVIDIAのドライバーをインストール、2)CUDA Toolkitをインストール、3)cuDNNをインストールとするのが普通なのでしょう。というのも、CUDA ToolkitはNVIDIAのドライバのバージョンに依存しているからです。だけど、これまでPytorchの環境を作ってきた経験上、CUDA Toolkitに自動でNVIDIAドライバが入っている傾向が強く、悉く2)CUDA Toolkitを入れる段階で、新規のNVIDIAドライバが入って計算機設定がぶっこわれます。だからもう私は、「先にCUDA Toolkit」をいれて、後からそのCUDAに対応できるよりも最新のNVIDIAドライバを入れます。
つまり、インストールは次の順で行います。
- CUDA Toolkit (
nvcc
がこれで使えるようになる) - NVIDIA driver (
nvidia-smi
が使えるようになる) - cuDNN (
dpkg --list | grep
で見えるようになる)
GPU関連クリーンインストール前の状態
何もドライバやCUDAが入っていない状態は次の通り。
root@domain:directory# dpkg --list | grep nvidia
root@domain:directory# dpkg --list | grep cuda
root@domain:directory# dpkg --list | grep cudnn
root@domain:directory#
CUDA Toolkitのインストール
使いたいCUDA Toolkitのバージョンをダウンロードする。CUDAのバージョンというのはGPUのCompute Capabilityというのに対応したものを利用する必要があります。
【参考文献】CUDA GPUs - Compute Capability
ここを読んで、私の場合はRTX 3060
を調べてCompute Capability
が8.6
と分かります。
さらに、使用するcuDNNのバージョンについては次のSupport Matrixを見ます。この中からCompute Capability
が8.6
と使いたいcuDNNのバージョン(今回は9.5.1
)との対応関係を見ます。ここでcuDNNの使いたいバージョンがCUDA Compute Capabilityを満たしていないとダメです。cuDNN 9.5.1
はCompute Capability
が8.6
のGPUをサポートしているので使えます。
【参考文献】Support Matrix — NVIDIA cuDNN
画像左下のところで、使いたいcuDNNのバージョンを選ぶことができるのに注意してください。ここからCUDA Toolkitとしては12.0
から12.6
までを使うことができることが分かります。
今回はCUDAは12.4
を使います。(Pytorchに単純にCUDA 12.4
のバイナリがありそうなので採用しただけ。JAXにとっては関係ない)
【参考文献】CUDA Toolkit 12.4 Downloads
2024/11月現在では
wget https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda-repo-debian12-12-4-local_12.4.0-550.54.14-1_amd64.deb
sudo dpkg -i cuda-repo-debian12-12-4-local_12.4.0-550.54.14-1_amd64.deb
sudo cp /var/cuda-repo-debian12-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/
sudo add-apt-repository contrib
sudo apt-get update
sudo apt-get -y install cuda-toolkit-12-4
でダウンロードできると記載されているので、このコマンドを粛々と実行します。(これはエラー出ないんだよ、うちの環境では……)ドライバについては、公式サイトの上記の続きに書いてある
apt-get install -y cuda-drivers
を実行すれば入ります。
ここで初心者殺しフェーズありです。このインストールをしても、NVCC(CUDA Compiler、CCっていうのはCUDA Compilerのことだと思われます)が動かんのだ!というのもインストールされる先のbin
がノーマルのパスに入っていないのです。そして、nvcc --version
でバージョンを確認しましょう。
you@domain:directory$ export PATH="$PATH:/usr/local/cuda-12.4/bin"
you@domain:directory$ nvcc # これは要するにコンパイラなので、コンパイルするファイル無しではエラー吐く
nvcc fatal : No input files specified; use option --help for more information
you@domain:directory$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Feb_27_16:19:38_PST_2024
Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0
you@domain:directory$
.bashrc
の末尾にCUDAへのbin
を追加するのがオススメです。
you@domain:directory$ tail .bashrc
if ! shopt -oq posix; then
if [ -f /usr/share/bash-completion/bash_completion ]; then
. /usr/share/bash-completion/bash_completion
elif [ -f /etc/bash_completion ]; then
. /etc/bash_completion
fi
fi
# add cuda bin
export PATH="$PATH:/usr/local/cuda-12.4/bin"
you@domain:directory$
NVIDIAドライバのインストール
たいがい、このCUDA Toolkit入れた時点でNVIDIAドライバは適切なバージョンが入ります。なお、私のDebianの場合は、クリーンインストールした直後に上記のCUDAを入れてもnvidia-smi
が動かないので、一度
sudo systemctl reboot
でリブートする必要がありました。その後であれば、次のコマンドが通ります。
you@domain:directory$ nvidia-smi
Thu Nov 14 07:18:18 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05 Driver Version: 550.127.05 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 Off | N/A |
| 33% 34C P0 N/A / 170W | 1MiB / 12288MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
you@domain:directory$
この場面で、もしもお使いのGPU環境でnvidia-smi
がエラーをはいている場合は、多くの場合経験上は「nvidia-smiのバージョン」と「ドライバのバージョン」が一致していないからです。その結果nvidia-smi
コマンドが使えなくなります。
そのようなトラブルが起こった場合は、次のページに飛んで、CUDA Toolkitとそのツールキットが最低限サポートしているNVIDIAドライバ『以上』のドライバを再インストールする必要があります。
【参考文献】NVIDIA CUDA Toolkit Release Notes
上の表から対応するNVIDIAドライバのバージョンを見つけたら、次の「NVIDA 公式ドライバー」サイトから、お使いのGPUのドライバをインストールします。ここまでで、「GPU」に「Compute Capability
」が依存していて、「Compute Capability
」に使える「CUDA」が依存していて、「CUDA」に「NVIDIAドライバ」は依存しています。
【参考文献】NVIDIA 公式ドライバーのダウンロード | NVIDIA
多くの場合は、お使いのGPUの「最新バージョンのドライバ」を入れればnvidia-smi
が復活することが期待されますが、
- CUDAの必要要件とする最低バージョンのドライバ
- cuDNNの必要要件とする最低バージョンのドライバ
ここからダウンロードしたファイルはNVIDIA-Linux-x86_64-550.127.05.run
のようなファイル名です。これをchmod +x
で実行可能状態として、実行することでドライバを「欲しいバージョン」に変えられます。
root@domain:directory$ chmod +x NVIDIA-Linux-x86_64-550.127.05.run
root@domain:directory$ export PATH=$PATH:/usr/local/sbin:/usr/sbin:/sbin # これがないとinitramfsのbuild/rebuildが失敗します
root@domain:directory$ ./NVIDIA-Linux-x86_64-550.127.05.run
再三ですが、基本的にDebian, Ubuntu系では「CUDA Toolkit」の時点で基本的には必要なドライバが入るので、CUDA→NVIDAドライバの順で入れないと、NVIDIAドライバのみを取り換えるのは難しい印象です。
「最新バージョンのドライバ」が動かない場合は、NVIDIAドライバのArchiveというのを探してきて、nvidia-smi
の言っているエラーに従って、ドライバのバージョンをnvidia-smi
のバージョンに合わせてください。
【参考文献】Linux AMD64 Display Driver Archive
cuDNNのインストール
cuDNNも同じ要領で最新版を探してきて(検索エンジンでもAIエンジンでもよくて)、それをダウンロード・インストールします。
【参考文献】cuDNN 9.5.1 Downloads
インストール・ページでは、次のようにLinuxのコマンドが並んでいるので、粛々とファイルをダウンロードして、インストールします。ただし、現状のcuDNNのコマンドはdebianで使うにはcp
コマンドが壊れているので使えません。修正版を示しているのでそちらを使ってください。(ディレクトリ名が違うだけですが)
wget https://developer.download.nvidia.com/compute/cudnn/9.5.1/local_installers/cudnn-local-repo-debian12-9.5.1_1.0-1_amd64.deb
sudo dpkg -i cudnn-local-repo-debian12-9.5.1_1.0-1_amd64.deb
# sudo cp /var/cuda-repo-debian12-9-5-local/cudnn-*-keyring.gpg # 公式レポジトリのこのコマンドはdebianでは動かず、次のコマンドにします
sudo cp /var/cudnn-local-repo-debian12-9.5.1/cudnn-*-keyring.gpg /usr/share/keyrings/
sudo apt-get update
sudo apt-get -y install cudnn
もしも、特定のメジャーバージョンのCUDAの違いに対応する必要がある場合(11
系と12
系など)の場合は、次のコマンドでインストールし分けます。
sudo apt-get -y install cudnn-cuda-11 # for cuda 11
sudo apt-get -y install cudnn-cuda-12 # for cuda 12
JAXのコンパイル
このページにたどり着いている人は、大方、普通のpip install
でGPU用のJAXが入らない人だと思われます。そういう皆さんのために、どのようなGPU環境でもとりあえずJAXをインストールしてみせるというのが本編になります。
JAXのソースコードからのビルドは次がオフィシャルサイトです。
【参考文献】Building from source - JAX documentation
さて、JAXのビルドに必要なのは
- CUDA (あるいはCUDA Toolkitと言ってもいい)
- cuDNN (Deep Neural Network用のライブラリ?)
- NCCL (GPUクラスターを使うために必要なライブラリ、らしい)
- 他のコンパイラ群(build-essentialで十分な気がするので
apt-get install build-essentials
しておいてくださいと思ったけど、実はJAXはCコンパイラはclangがデフォルトなので、apt-get install clang
をしておきます)
です。JAXのコンパイル・ビルドにはclang(Cコンパイラの一種、GCC:GNU C Compilerと双璧を成す最近のコンパイラ)がデフォルトで利用されるので。
apt-get install build-essential
apt-get install clang
をしておきましょう。(正直この辺りは試行錯誤しているうちにインストールしてしまったので忘れました)
個人のPCで(単一)GPUをターゲットにJAXをコンパイルする目的ではCUDAとcuDNNがあれば十分です。NCCLはGPUクラスタを使う場合に必要になりますが、その環境構築をするだけのセッティングが手元にないので割愛します。
(オフィシャルサイトにはNCCLを使ったGPUクラスタ向けのビルド方法も記載されています。例えば、
python build/build.py --enable_cuda \
--bazel_options=--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \
--bazel_options=--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \
--bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl"
これがそうです。)
それではJAXのビルド環境を作っていきましょう。 ディレクトリを作成して、git
でソースコードをダウンロードしてきます。git
が入っていない場合は
apt-get install git
で入れてください。この上で、次のコマンドを動かします。
you@domain:directory$ mkdir jax
you@domain:directory$ cd jax
you@domain:directory/jax$ python3 -m venv .venv
you@domain:directory/jax$ . .venv/bin/activate # source .venv/bin/activateに同じ
(.venv) you@domain:directory/jax$ git clone https://github.com/jax-ml/jax
Cloning into 'jax'...
remote: Enumerating objects: 159803, done.
remote: Counting objects: 100% (290/290), done.
remote: Compressing objects: 100% (162/162), done.
remote: Total 159803 (delta 163), reused 218 (delta 128), pack-reused 159513 (from 1)
Receiving objects: 100% (159803/159803), 98.28 MiB | 10.31 MiB/s, done.
Resolving deltas: 100% (126702/126702), done.
(.venv) you@domain:directory/jax$ ls
jax
(.venv) you@domain:directory/jax$ cd jax
(.venv) you@domain:directory/jax/jax$ ls
AUTHORS CHANGELOG.md cloud_tpu_colabs docs jax LICENSE README.md third_party
benchmarks ci conftest.py examples jaxlib platform_mappings setup.py WORKSPACE
build CITATION.bib CONTRIBUTING.md images jax_plugins pyproject.toml tests
ここまでくるといよいよJAXをビルドすることができるのですが、結論から言うと
(.venv) you@domain:directory/jax/jax$ python build/build.py --enable_cuda --build_gpu_plugin --cuda_version=[CUDA VERSION] --cudnn_version=[CUDNN VERSION]
というコマンドを実行する必要があります。--cuda_version=[CUDA VERSION]
と--cudnn_version=[CUDNN VERSION]
に何を書けばいいか分からないのですが、これはすごい適当な文字列を入れると、どのようなバージョンが許されているか分かります。試しに
(.venv) you@domain:directory/jax/jax$ python build/build.py --enable_cuda --build_gpu_plugin --cuda_version=0 --cudnn_version=0
_ _ __ __
| | / \ \ \/ /
_ | |/ _ \ \ /
| |_| / ___ \/ \
\___/_/ \/_/\_\
Downloading bazel
### former logs ###
ERROR: Error computing the main repository mapping: no such package '@cuda_redist_json//': The supported CUDA versions are ["11.8", "12.1.1", "12.2.0", "12.3.1", "12.3.2", "12.4.0", "12.4.1", "12.5.0", "12.5.1", "12.6.0", "12.6.1", "12.6.2"]. Please provide a supported version in HERMETIC_CUDA_VERSION environment variable or add JSON URL for CUDA version=0.
### latter logs ###
としてみると、CUDAのバージョンとして"11.8"
から"12.6.2"
までが使えることが分かります。
そこで--cuda_version=12.4.0
として実行します。
(.venv) you@domain:directory/jax/jax$ python build/build.py --enable_cuda --build_gpu_plugin --cuda_version=12.4.0 --cudnn_version=0
_ _ __ __
| | / \ \ \/ /
_ | |/ _ \ \ /
| |_| / ___ \/ \
\___/_/ \/_/\_\
Downloading bazel
### former logs ###
Error in fail: The supported CUDNN versions are ["8.6", "8.9.4.25", "8.9.6", "8.9.7.29", "9.1.1", "9.2.0", "9.2.1", "9.3.0", "9.4.0", "9.5.0"]. Please provide a supported version in HERMETIC_CUDNN_VERSION environment variable or add JSON URL for CUDNN version=0.
### latter logs ###
ここからcuDNNのバージョンは"8.6"
から"9.5.0"
までが使えます。
以上から
- CUDA
12.4
- cuDNN
9.5.1
においては、次のパラメータによるビルド・コマンドで、JAXをビルドすることができることが分かります。
python build/build.py --enable_cuda --build_gpu_plugin --cuda_version=12.4.0 --cudnn_version=9.5.0
このCUDAのバージョンやcuDNNのバージョンは正直なところよく互換性が分かっていないのですが、Pytorchみたいにぎっちぎちにあっていなくても動く印象があります。というのも、ビルドした環境はCUDAが
you@domain:directory$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Feb_27_16:19:38_PST_2024
Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0
you@domain:directory$
で、cuDNNが
you@domain:directory$ dpkg --list | grep cudnn
ii cudnn 9.5.1-1 amd64 NVIDIA CUDA Deep Neural Network library (cuDNN)
ii cudnn-local-repo-debian12-9.5.1 1.0-1 amd64 cudnn-local repository configuration files
ii cudnn9 9.5.1-1 amd64 NVIDIA CUDA Deep Neural Network library (cuDNN)
ii cudnn9-cuda-12 9.5.1.17-1 amd64 NVIDIA cuDNN for CUDA 12
ii cudnn9-cuda-12-6 9.5.1.17-1 amd64 NVIDIA cuDNN for CUDA 12.6
ii libcudnn9-cuda-12 9.5.1.17-1 amd64 cuDNN runtime libraries for CUDA 12.6
ii libcudnn9-dev-cuda-12 9.5.1.17-1 amd64 cuDNN development headers and symlinks for CUDA 12.6
ii libcudnn9-samples 9.5.1.17-1 all cuDNN samples
ii libcudnn9-static-cuda-12 9.5.1.17-1 amd64 cuDNN static libraries for CUDA 12.6
you@domain:directory$
というバージョンですが、ビルド時のパラメータを勘案すると
- CUDA:ビルド時
12.4
< 環境12.4.r12.4
- cuDNN:ビルド時
9.5.0
< 環境9.5.1
で動いてしまっているので、たぶん、インストールされているCUDA, cuDNNのバージョンのメジャー・マイナーのバージョンくらいまでが合っていれば大丈夫そうなのです。ただし、バージョン番号は[メジャー].[マイナー](.[マイナーマイナー]以降続く)
だと考えていっています。
さて、コンパイルには大変に時間がかかるので、閑話休題。このコンパイル成功すると思います? これが癪なことに時期に依るんですよ。私が最初にJAXをコンパイルしたときには
jaxlib/gpu/solver_kernels_ffi.cc:918:29: error: missing 'typename' prior to dependent type name 'solver::RealType<T>::value'
というのが長い長いエラーログの末に出ていて、もはや読めもしないのでログをChatGPTに食わせたのですよ。そしたらsolver_kernels_ffi.cc
っていうファイルの
auto s_data = static_cast<solver::RealType<T>::value*>(s->untyped_data());
これを
auto s_data = static_cast<typename solver::RealType<T>::value*>(s->untyped_data());
これに書き換える必要がありました。(しかもこれはC++
のコードなので、私は久々に書いたらコメントアウトさえ間違えて、元のコードを#
で消して、新規のこのコードを入れようとしたわけです。そしたら//
がコメントアウトであると文句を言われるっていうね……。しくしく)
何が言いたいかというと、JAXのgithubにあるコードがいつでもコンパイルできると思うなかれ! そして動かない時には、AI(ChatGPTでもClaudeでもGeminiでもいいです)を使ってファイルを修正する必要があるということです。
そうこう言っているうちに、2週間前はファイルの修正が必要だったのに、今度は修正せずにコンパイル・ビルドが通過。なんなんだよ、このクソは。私がこれを記事にしなくちゃいけないと思ったほとんどはCUDA, cuDNNのインストールじゃなくて、この「コード書き換え」さえ必要っていうところだったのにコンパイル通過してんじゃないかよ!! という愚痴はさておき。
(.venv) you@domain:directory/jax/jax$ python build/build.py --enable_cuda --build_gpu_plugin --cuda_version=12.4.0 --cudnn_version=9.5.0
_ _ __ __
| | / \ \ \/ /
_ | |/ _ \ \ /
| |_| / ___ \/ \
\___/_/ \/_/\_\
Downloading bazel
### bunch of logs ###
[11,431 / 11,449] Compiling llvm/lib/Target/AMDGPU/SIISelLowering.cpp; 12s local ... (9 actions, 8 running)
[11,434 / 11,449] Compiling llvm/lib/Target/AMDGPU/SIISelLowering.cpp; 13s local ... (7 actions running)
[11,435 / 11,449] Compiling llvm/lib/Target/AMDGPU/SIISelLowering.cpp; 15s local ... (7 actions, 6 running)
[11,439 / 11,449] Compiling llvm/lib/Target/AMDGPU/SIISelLowering.cpp; 17s local ... (5 actions, 4 running)
[11,444 / 11,449] Compiling xla/service/gpu/gpu_compiler.cc; 16s local ... (2 actions running)
[11,446 / 11,449] Compiling xla/service/gpu/gpu_compiler.cc; 18s local
[11,447 / 11,449] [Prepa] Linking external/xla/xla/service/gpu/libgpu_compiler.pic.a
[11,448 / 11,449] Linking jaxlib/tools/pjrt_c_api_gpu_plugin.so; 1s local
Target //jaxlib/tools:build_gpu_plugin_wheel up-to-date:
bazel-bin/jaxlib/tools/build_gpu_plugin_wheel
INFO: Elapsed time: 563.475s, Critical Path: 416.37s
INFO: 2008 processes: 450 internal, 1558 local.
INFO: Build completed successfully, 2008 total actions
INFO: Running command line: bazel-bin/jaxlib/tools/build_gpu_plugin_wheel '--output_path=/home/you/jax/jax/dist' '--jaxlib_git_hash=1471702adc286bcf40e87c42877d538b4d589f90' '--cpu=x86_64' '--enable-cuda=True' '--platform_version=12'
Output wheel: /home/you/jax/jax/dist/jax_cuda12_pjrt-0.4.36.dev20241115-py3-none-manylinux2014_x86_64.whl
To install the newly-built jax cuda plugin wheel on system Python, run:
pip install /home/you/jax/jax/dist/jax_cuda12_pjrt-0.4.36.dev20241115-py3-none-manylinux2014_x86_64.whl --force-reinstall
To install the newly-built jax cuda plugin wheel on hermetic Python, run:
echo -e "\n/home/you/jax/jax/dist/jax_cuda12_pjrt-0.4.36.dev20241115-py3-none-manylinux2014_x86_64.whl" >> build/requirements.in
bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.11
となってビルド完了です。一応、このコンパイルした環境を抜けるためには
(.venv) you@domain:directory/jax/jax$ deactivate
you@domain:directory/jax/jax$
で.venv
の環境から抜けることができるのでご安心を。
ビルドしたJAXのインストール方法
せっかくビルドしたJAXをあなたのpython venv
環境で使えるようにするためには、再度公式のページに行きましょう。
【参考文献】Building from source - JAX documentation
次のが公式サイトでの「一般にGPU, TPUなどでのビルドとインストール方法」です。
python build/build.py
pip install dist/*.whl # installs jaxlib (includes XLA)
私はですね、正直Pythonのことなんて何一つわからないんですけど、これでは動かなかったということだけは分かったんですよ。どうやら、JAXをPython環境に入れるときには
jax
jaxlib
という二つが必要なようで、さきほどまでビルドしていたのはCUDA用のjax
とjaxlib
なんです。そのため、pip install dist/*.whl
をしたあとにimport jax.numpy
としても通常は初期は「CPU」でJAXを動かそうとするため、CPU版のJAXが入っていなくてjax
にnumpy
なんてモジュールないぞ?って言われました。
そこで、ここでは
- CPU版
- GPU版(ここまでビルドしてきたもの)
を両方使える環境を提供する方法を示します。
まずは安直にCPU版のJAXをインストールします。でも古典的なpython setup.py install
の方法ですからね!?
you@domain:directory$ mkdir test
you@domain:directory$ cd test
you@domain:directory/test$ python3 -m venv .venv
you@domain:directory/test$ . .venv/bin/activate
(.venv) you@domain:directory/test$ cp -r [jaxのビルドしたディレクトリ] .
(.venv) you@domain:directory/test$ ls
jax
(.venv) you@domain:directory/test$ cd jax
(.venv) you@domain:directory/test/jax$ python setup.py install
### bunch of logs ###
Best match: jaxlib 0.4.35
Processing jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl
Installing jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl to /home/you/test/.venv/lib/python3.11/site-packages
Adding jaxlib 0.4.35 to easy-install.pth file
Installed /home/you/test/.venv/lib/python3.11/site-packages/jaxlib-0.4.35-py3.11-linux-x86_64.egg
Finished processing dependencies for jax==0.4.36.dev20241115+1471702ad
(.venv) you@domain:directory/test/jax$
こうして
-
jax
(ビルドしたGPU版jaxlib
と対応したバージョンのjax
) -
jaxlib
(ビルドしたGPU版jaxlib
と対応したCPU用バージョンのjaxlib
)
がインストールされます。
(.venv) you@domain:directory/test$ pip list | grep jax
jax 0.4.36.dev20241115+1471702ad
jax 0.4.36.dev20241115+1471702ad
jaxlib 0.4.35
jaxlib 0.4.35
(.venv) you@domain:directory/test$
ここで、気づいてください、jaxlib
が自前ビルド版になっていません!。要するにこのプロセスでインストールされるのはCPU版なんです……。
ここまできて、jax/dist/*.whl
をインストールするとこれまで作ってきたGPU版のビルドにjaxlib
が置き換わって、晴れて
- CPU版
- GPU版(ここまでビルドしてきたもの)
が使えるようになります。
さあ、最後に公式ドキュメント通りにpip install dist/*.whl
によってGPU版のJAXをインストールします。
(.venv) you@domain:directory/test$ ls
jax
(.venv) you@domain:directory/test$ ls jax/dist/
jax-0.4.36.dev20241115+1471702ad-py3.11.egg
jax_cuda12_pjrt-0.4.36.dev20241115-py3-none-manylinux2014_x86_64.whl
jax_cuda12_plugin-0.4.36.dev20241115-cp311-cp311-manylinux2014_x86_64.whl
jaxlib-0.4.36.dev20241115-cp311-cp311-manylinux2014_x86_64.whl
(.venv) you@domain:directory/test$ pip install jax/dist/*.whl
Processing ./jax/dist/jax_cuda12_pjrt-0.4.36.dev20241115-py3-none-manylinux2014_x86_64.whl
Processing ./jax/dist/jax_cuda12_plugin-0.4.36.dev20241115-cp311-cp311-manylinux2014_x86_64.whl
Processing ./jax/dist/jaxlib-0.4.36.dev20241115-cp311-cp311-manylinux2014_x86_64.whl
Requirement already satisfied: scipy>=1.10 in ./.venv/lib/python3.11/site-packages/scipy-1.14.1-py3.11-linux-x86_64.egg (from jaxlib==0.4.36.dev20241115) (1.14.1)
Requirement already satisfied: numpy>=1.24 in ./.venv/lib/python3.11/site-packages/numpy-2.1.3-py3.11-linux-x86_64.egg (from jaxlib==0.4.36.dev20241115) (2.1.3)
Requirement already satisfied: ml-dtypes>=0.2.0 in ./.venv/lib/python3.11/site-packages/ml_dtypes-0.5.0-py3.11-linux-x86_64.egg (from jaxlib==0.4.36.dev20241115) (0.5.0)
Installing collected packages: jax-cuda12-pjrt, jax-cuda12-plugin, jaxlib
Attempting uninstall: jaxlib
Found existing installation: jaxlib 0.4.35
Uninstalling jaxlib-0.4.35:
Successfully uninstalled jaxlib-0.4.35
Successfully installed jax-cuda12-pjrt-0.4.36.dev20241115 jax-cuda12-plugin-0.4.36.dev20241115 jaxlib-0.4.36.dev20241115
(.venv) you@domain:directory/test$
それでは、本当にGPU版のJAXもインストールされているか確認してみましょう。
(.venv) you@domain:directory/test$ pip list | grep jax
jax 0.4.36.dev20241115+1471702ad
jax 0.4.36.dev20241115+1471702ad
jax-cuda12-pjrt 0.4.36.dev20241115
jax-cuda12-plugin 0.4.36.dev20241115
jaxlib 0.4.36.dev20241115
(.venv) you@domain:directory/test$
パチパチ!完成です。
ここまで長い道のりでしたが、後学のため失敗例を示します。
pip install jax
pip install jax/dist/*.whl
これは最初のjaxでデフォルトのディストリビューションのjax
が入り、そのjax
のバージョンと、せっかくここまでビルドしてきたjaxlib
のバージョンが合わないために失敗します。(もちろん運が良ければ、配布のjax
とコンパイルしてきたjaxlib
のバージョンが一致するため、エラーなしで進みます)
実際にJAXコードを動かしてみる
次のコードは
- 現在の計算環境がCPUかGPUかを表示し
- 行列計算を行った結果を出力する
JAXの簡単なコードです。
import jax
import jax.numpy as jnp
# 使用デバイスの確認
device = jax.devices()[0]
print(f"Using device: {device.device_kind}")
# 行列の定義
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
# 行列の積
C = jnp.dot(A, B)
# 計算を強制実行して結果を取得
C_result = jax.device_get(C)
# 計算結果を表示
print("Matrix A:")
print(A)
print("Matrix B:")
print(B)
print("Result of A * B:")
print(C_result)
CPU実行
JAXのコードをCPUで実行するときは、デフォルトのまま何もしていないか、
export JAX_PLATFORM_NAME=cpu
によって実行環境をCPUにする必要があります。
(.venv) you@domain:directory/test$ export JAX_PLATFORM_NAME=cpu
(.venv) you@domain:directory/test$ python t.py
Using device: cpu
Matrix A:
[[1 2]
[3 4]]
Matrix B:
[[5 6]
[7 8]]
Result of A * B:
[[19 22]
[43 50]]
(.venv) you@domain:directory/test$
GPU実行
JAXのコードをGPUで実行するときは
export JAX_PLATFORM_NAME=gpu
によって実行環境をGPUにする必要があります。一方で、コードには変更を施す必要はありません。(もちろんコード側でCPUやGPU実行を変更することもできます)
(.venv) you@domain:directory/test$ export JAX_PLATFORM_NAME=gpu
(.venv) you@domain:directory/test$ python t.py
Using device: NVIDIA GeForce RTX 3060
Matrix A:
[[1 2]
[3 4]]
Matrix B:
[[5 6]
[7 8]]
Result of A * B:
[[19 22]
[43 50]]
(.venv) you@domain:directory/test$
とても長いチュートリアルでしたが、これで「JAXをNVIDIAのGPU環境で使えるようにする方法」は終了になります。(あとは御託)
啓蒙活動
オブジェクト指向というのは「本質的にすべてのオブジェクト・インスタンスが万能計算機(あるいはそのカスタム品)」であって、それ同士のコミュニケーションをメッセージ・パッシングで実現しているわけです。一方で、関数型言語というのは計算を計算資源に分解しており、式と項を分離させて処理するわけです。ここで、「式=モデル」、「項=パラメータ」となります。万能計算性を一つの計算機リソースで実装できない、巨大深層学習モデルでは関数型の指向で実装する以外ないのです。オブジェクト・インスタンスでは、根本的に「ちゃんと動く(ネットワークなどの不安定性も排して!)万能計算性」を保証できない巨大深層学習モデルを記述して、その学習・運用に耐えられないのです。理屈はこれだけ。
だから明日からは深層学習のモデルはスケールさせる目的で全て関数型の計算フレームワークで記述してください。
PyTorch is dead. Long live JAX.
Pytorchは今後十年をかけて衰退していきます。それに対して機械学習・深層学習の界隈ではXLA形式で
- モデル
- パラメータ
の定義を分離できる「関数型言語」のスタイルを採用したモデルを利用する機会が増加していきます。その実装の一つがJAXです。
【参考文献】PyTorch is dead. Long live JAX.
JAXはGoogle Deepmindの開発したプラットフォームで、現在多くの機械学習プラットフォームが依拠しているPytorch(Meta, 旧Facebook)性のシステムに対して近年存在感を増しています。
思い返せば、2012に「グーグルの猫」を皮切りに、私が覚えている限り、日本のコミュニティで深層学習のフレームワークとしては
- Tensorflow: 後にKerasが出てver1の問題点が改善したver2が主流となる
- Pytorch: Facebookが出してその後60%以上のシェアを持つアカデミアや開発研究の主流に
- Chainer: 2020年に更新が終了した日本のPrefered系のフレームワーク
- JAX: Google Deepmind制の数値計算一般に使える計算フレームワーク
が台頭しては消えてを繰り返しています。なお、昔話はこちらで。
【参考文献】2012年にAIの歴史が動いた!ついに猫認識に成功した「Googleの猫」
要するに栄枯盛衰を繰り返しているのですが、私見ではTensorflow ver1とJAXが同じような計算グラフのコンパイルを基盤に持っていて、Tensorflow ver2 (with Keras)とPytorchとChainerが計算グラフの実行時構成というか実行時インタープリターという方法で動くわけです。
計算資源が多様化するにしたがって、GPUでもCPUでもTPUでもクラスター化して、それぞれの計算資源で計算グラフを組んだうえで、それぞれのハードウェアに適した形でモデルのパラメータを配置しなくてはいけなくなりました。それでいうと、計算グラフはそれぞれのハードウェアに対してコンパイルしておいて、パラメータや入力/出力の教師データ・ペアは他の方法で与えるほうが都合がいいわけです。要するに、あるクラスターではGPUが主役を、あるクラスターではCPUが主役を、あるクラスターではTPUやFPGAやASICが主役の計算機構であるわけで、それにパラメータと教師データを流し込んで、パラメータ更新値を得る必要があるのです。すると、自ずと、再度Tensorflow ver1, JAXの世界に戻ってくる必要があるんですわ。と僕は啓蒙されました。