2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

JAXをNVIDIAのGPU環境で使えるようにする方法(debian 12)

Last updated at Posted at 2024-11-15

JAXをNVIDIAのGPU環境で使えるようにする方法(debian 12)

 備忘録とJAXの啓蒙活動。

  • CUDA, cuDNN (NVIDIAドライバ)のインストール方法
  • JAXのソースコードからのビルド

の2本立て。

 Pytorchはほぼほぼ消えます。なぜならば、あらゆる種類の計算ハードウェア(CPU, GPU, TPU, FPGA, ASIC)に対して深層学習の計算グラフが記述される一方、その計算グラフモデルのパラメータは記録方法もキャッシュ方法もモデル・パラレリズムの観点でも、記録方法がハードウェアに束縛されてはならないからです。

御託を抜きにして使い方を説明します。

 JAXはCPUで動かすのは簡単です。JAXをNVIDIAのGPUで動かすときには、Pytorchユーザーの皆さんの思っている、

  • CUDAのバージョン

との格闘ではなく、

  • CUDAのバージョン(nvcc, つまりCUDAコンパイラのバージョン)
  • NVIDIAドライバのバージョン(上のnvccに対応したnvidia-smiが使えるようにする)
  • cuDNNのバージョン(上のnvccに対応したcuDNNのバージョンを使えるようにする)

との格闘になります。残念ながらpipで配布されているバイナリは使い物にならないと心得てください。
 すなわち、基本的にはあなたの環境にあった設定で、JAXをソースコードからコンパイルする、これが一番です。(私は配布されているバイナリをNVIDIAドライバ、CUDA、cuDNNのをさんざんバージョンを合わせたうえでpip installしたのですが、悉く失敗し、うまく行ったとしてもcuDNNのinitializationがfailしてしましました。あるいはnvidia-smiがそもそも飛んでGPUが認識されないとか、お恥ずかしいながら)

前提知識

 CUDAというのはNVIDIAのGPUの上でGeneral Purpose GPU(GPGPU)計算を行うためのフレームワークで、多くの場合はCUDA Cというプログラミング言語を使ってGPGPU計算を記述します。このコンパイルにはCUDA C用のクロスターゲット・コンパイラが必要なわけで、それがnvccです。PytorchやJAXのソースコードにはGPGPUの計算を利用するためのCUDA Cのコードが含まれており、自前でレポジトリをコンパイルするにはCUDAのコンパイラ(CUDA Toolkitでインストールできる)が必要です。
 今使っているコンピュータにGPUが入っている場合、そのGPUが計算資源として使えるためにはNVIDIAドライバが入っている必要があります。そのドライバっていうのが入っているとnvidia-smiというコマンドでどのGPUが入っているか確認できます。
 cuDNNは「まじでよくわからん」です。要するにたぶん深層学習用のCUDAライブラリなのだと理解しているのですが、私自身はCUDAコーディングでこのcuDNNを使ったことが一切ないので、ほとほと分かりません。JAXのようなフレームワークをコンパイルするのにcuDNNのライブラリにコードが依存しているらしく、これをダウンロードしておかないとコンパイルできないので入れます。

各種コマンドお困りのときに使う

 Debianのバージョンの確認

cat /etc/debian_version

 CUDAのパッケージ(cuda-toolkit)がどれが入っているか確認する。

dpkg --list | grep cuda

 cuDNNのバージョンの確認。

dpkg --list | grep cudnn

 NVIDIAドライバ関連が失敗した場合に、特にnvidia-smiが使えなくなります。そのような場合には

apt-get purge nvidia-*

で全て消します。
 同じ要領でCUDA関連(CUDA Toolkit)を削除する場合は

apt-get purge cuda*

で全て消します。
 同じ要領でcuDNNを削除する場合は

apt-get purge *cudnn*

で全て消します。

 それでも残っていて、cuDNNのバージョンが気に食わない場合

dpkg --purge [target package]

で消し飛ばしたいわけですが、

dpkg: warning: 'ldconfig' not found in PATH or not executable
dpkg: warning: 'start-stop-daemon' not found in PATH or not executable
dpkg: error: 2 expected programs not found in PATH or not executable
Note: root's PATH should usually contain /usr/local/sbin, /usr/sbin and /sbin

と出るので

export PATH=$PATH:/usr/local/sbin:/usr/sbin:/sbin

をしましょう。次にdpkg --purge [target package]を行います。消したいパッケージはdpkg --list | grep cudnnで調べて消します。

 上記のコマンドを使っても、こびりついて残っているパッケージの亡霊を削除するには

apt-get autoremove

 GPUドライバの設定をやっていると頻繁にドライバが壊れてバージョンが合わず、GUI画面が出せなくなります。古典的な計算機のログイン画面(login, userpasswordの組の奴)にするには

systemctl set-default multi-user.target

にするとドライバ処理をするときは安全です。グラフィカルに戻すときは

systemctl set-default graphical.target

でデスクトップ環境に戻ります。

開発環境

 私の2024/11月現在の自宅の開発環境は次の通り。CPUとメモリは多分関係ない。自宅のコンピュータなので、単純にATXマザーボードで筐体に入れたコンピュータです。SSHだけは有効になっている気がする。

Debian 12.7
CPU: Xeon 24 threads
Memory: 64 GB
GPU: NVIDIA RTX 3060 (12 GB)

 1)NVIDIAのドライバーをインストール、2)CUDA Toolkitをインストール、3)cuDNNをインストールとするのが普通なのでしょう。というのも、CUDA ToolkitはNVIDIAのドライバのバージョンに依存しているからです。だけど、これまでPytorchの環境を作ってきた経験上、CUDA Toolkitに自動でNVIDIAドライバが入っている傾向が強く、悉く2)CUDA Toolkitを入れる段階で、新規のNVIDIAドライバが入って計算機設定がぶっこわれます。だからもう私は、「先にCUDA Toolkit」をいれて、後からそのCUDAに対応できるよりも最新のNVIDIAドライバを入れます。
 つまり、インストールは次の順で行います。

  • CUDA Toolkit (nvccがこれで使えるようになる)
  • NVIDIA driver (nvidia-smiが使えるようになる)
  • cuDNN (dpkg --list | grepで見えるようになる)

GPU関連クリーンインストール前の状態

 何もドライバやCUDAが入っていない状態は次の通り。

root@domain:directory# dpkg --list | grep nvidia
root@domain:directory# dpkg --list | grep cuda
root@domain:directory# dpkg --list | grep cudnn
root@domain:directory# 

CUDA Toolkitのインストール

 使いたいCUDA Toolkitのバージョンをダウンロードする。CUDAのバージョンというのはGPUのCompute Capabilityというのに対応したものを利用する必要があります。

CUDAのCompute Capabilityの一覧を確認する
【参考文献】CUDA GPUs - Compute Capability

ここを読んで、私の場合はRTX 3060を調べてCompute Capability8.6と分かります。

 さらに、使用するcuDNNのバージョンについては次のSupport Matrixを見ます。この中からCompute Capability8.6と使いたいcuDNNのバージョン(今回は9.5.1)との対応関係を見ます。ここでcuDNNの使いたいバージョンがCUDA Compute Capabilityを満たしていないとダメです。cuDNN 9.5.1Compute Capability8.6のGPUをサポートしているので使えます。

cuDNNと
【参考文献】Support Matrix — NVIDIA cuDNN

画像左下のところで、使いたいcuDNNのバージョンを選ぶことができるのに注意してください。ここからCUDA Toolkitとしては12.0から12.6までを使うことができることが分かります。

 今回はCUDAは12.4を使います。(Pytorchに単純にCUDA 12.4のバイナリがありそうなので採用しただけ。JAXにとっては関係ない)
CUDA Toolkit 12.4のダウンロード仕方
【参考文献】CUDA Toolkit 12.4 Downloads
 2024/11月現在では

wget https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda-repo-debian12-12-4-local_12.4.0-550.54.14-1_amd64.deb
sudo dpkg -i cuda-repo-debian12-12-4-local_12.4.0-550.54.14-1_amd64.deb
sudo cp /var/cuda-repo-debian12-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/
sudo add-apt-repository contrib
sudo apt-get update
sudo apt-get -y install cuda-toolkit-12-4

でダウンロードできると記載されているので、このコマンドを粛々と実行します。(これはエラー出ないんだよ、うちの環境では……)ドライバについては、公式サイトの上記の続きに書いてある

apt-get install -y cuda-drivers

を実行すれば入ります。
 ここで初心者殺しフェーズありです。このインストールをしても、NVCC(CUDA Compiler、CCっていうのはCUDA Compilerのことだと思われます)が動かんのだ!というのもインストールされる先のbinがノーマルのパスに入っていないのです。そして、nvcc --versionでバージョンを確認しましょう。

you@domain:directory$ export PATH="$PATH:/usr/local/cuda-12.4/bin"
you@domain:directory$ nvcc # これは要するにコンパイラなので、コンパイルするファイル無しではエラー吐く
nvcc fatal   : No input files specified; use option --help for more information
you@domain:directory$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Feb_27_16:19:38_PST_2024
Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0
you@domain:directory$ 

 .bashrcの末尾にCUDAへのbinを追加するのがオススメです。

you@domain:directory$ tail .bashrc
if ! shopt -oq posix; then
  if [ -f /usr/share/bash-completion/bash_completion ]; then
    . /usr/share/bash-completion/bash_completion
  elif [ -f /etc/bash_completion ]; then
    . /etc/bash_completion
  fi
fi

# add cuda bin
export PATH="$PATH:/usr/local/cuda-12.4/bin"
you@domain:directory$

NVIDIAドライバのインストール

 たいがい、このCUDA Toolkit入れた時点でNVIDIAドライバは適切なバージョンが入ります。なお、私のDebianの場合は、クリーンインストールした直後に上記のCUDAを入れてもnvidia-smiが動かないので、一度

sudo systemctl reboot

でリブートする必要がありました。その後であれば、次のコマンドが通ります。

you@domain:directory$ nvidia-smi
Thu Nov 14 07:18:18 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3060        Off |   00000000:01:00.0 Off |                  N/A |
| 33%   34C    P0             N/A /  170W |       1MiB /  12288MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
you@domain:directory$

 この場面で、もしもお使いのGPU環境でnvidia-smiがエラーをはいている場合は、多くの場合経験上は「nvidia-smiのバージョン」と「ドライバのバージョン」が一致していないからです。その結果nvidia-smiコマンドが使えなくなります。
 そのようなトラブルが起こった場合は、次のページに飛んで、CUDA Toolkitとそのツールキットが最低限サポートしているNVIDIAドライバ『以上』のドライバを再インストールする必要があります。

CUDAのバージョンとそれに対応するNVIDIAドライバの一覧
【参考文献】NVIDIA CUDA Toolkit Release Notes

 上の表から対応するNVIDIAドライバのバージョンを見つけたら、次の「NVIDA 公式ドライバー」サイトから、お使いのGPUのドライバをインストールします。ここまでで、「GPU」に「Compute Capability」が依存していて、「Compute Capability」に使える「CUDA」が依存していて、「CUDA」に「NVIDIAドライバ」は依存しています。
【参考文献】NVIDIA 公式ドライバーのダウンロード | NVIDIA
最新ドライバの検索方法
 多くの場合は、お使いのGPUの「最新バージョンのドライバ」を入れればnvidia-smiが復活することが期待されますが、

  • CUDAの必要要件とする最低バージョンのドライバ
  • cuDNNの必要要件とする最低バージョンのドライバ

の二つを満たしたドライバを使うことが期待されます。
例えば2024/11現在はRTX 3060のドライバはこれです

 ここからダウンロードしたファイルはNVIDIA-Linux-x86_64-550.127.05.runのようなファイル名です。これをchmod +xで実行可能状態として、実行することでドライバを「欲しいバージョン」に変えられます。

root@domain:directory$ chmod +x NVIDIA-Linux-x86_64-550.127.05.run
root@domain:directory$ export PATH=$PATH:/usr/local/sbin:/usr/sbin:/sbin # これがないとinitramfsのbuild/rebuildが失敗します
root@domain:directory$ ./NVIDIA-Linux-x86_64-550.127.05.run

 再三ですが、基本的にDebian, Ubuntu系では「CUDA Toolkit」の時点で基本的には必要なドライバが入るので、CUDA→NVIDAドライバの順で入れないと、NVIDIAドライバのみを取り換えるのは難しい印象です。
 「最新バージョンのドライバ」が動かない場合は、NVIDIAドライバのArchiveというのを探してきて、nvidia-smiの言っているエラーに従って、ドライバのバージョンnvidia-smiのバージョン合わせてください。

【参考文献】Linux AMD64 Display Driver Archive

cuDNNのインストール

 cuDNNも同じ要領で最新版を探してきて(検索エンジンでもAIエンジンでもよくて)、それをダウンロード・インストールします。

【参考文献】cuDNN 9.5.1 Downloads
cuDNN 9.5.1 Downloadsのページ

 インストール・ページでは、次のようにLinuxのコマンドが並んでいるので、粛々とファイルをダウンロードして、インストールします。ただし、現状のcuDNNのコマンドはdebianで使うにはcpコマンドが壊れているので使えません。修正版を示しているのでそちらを使ってください。(ディレクトリ名が違うだけですが)

wget https://developer.download.nvidia.com/compute/cudnn/9.5.1/local_installers/cudnn-local-repo-debian12-9.5.1_1.0-1_amd64.deb
sudo dpkg -i cudnn-local-repo-debian12-9.5.1_1.0-1_amd64.deb
# sudo cp /var/cuda-repo-debian12-9-5-local/cudnn-*-keyring.gpg # 公式レポジトリのこのコマンドはdebianでは動かず、次のコマンドにします
sudo cp /var/cudnn-local-repo-debian12-9.5.1/cudnn-*-keyring.gpg /usr/share/keyrings/
sudo apt-get update
sudo apt-get -y install cudnn

 もしも、特定のメジャーバージョンのCUDAの違いに対応する必要がある場合(11系と12系など)の場合は、次のコマンドでインストールし分けます。

sudo apt-get -y install cudnn-cuda-11 # for cuda 11
sudo apt-get -y install cudnn-cuda-12 # for cuda 12

JAXのコンパイル

 このページにたどり着いている人は、大方、普通のpip installでGPU用のJAXが入らない人だと思われます。そういう皆さんのために、どのようなGPU環境でもとりあえずJAXをインストールしてみせるというのが本編になります。

 JAXのソースコードからのビルドは次がオフィシャルサイトです。

【参考文献】Building from source - JAX documentation
JAXをソースコードからコンパイルするには

 さて、JAXのビルドに必要なのは

  • CUDA (あるいはCUDA Toolkitと言ってもいい)
  • cuDNN (Deep Neural Network用のライブラリ?)
  • NCCL (GPUクラスターを使うために必要なライブラリ、らしい)
  • 他のコンパイラ群(build-essentialで十分な気がするのでapt-get install build-essentialsしておいてくださいと思ったけど、実はJAXはCコンパイラはclangがデフォルトなので、apt-get install clangをしておきます)

です。JAXのコンパイル・ビルドにはclang(Cコンパイラの一種、GCC:GNU C Compilerと双璧を成す最近のコンパイラ)がデフォルトで利用されるので。

apt-get install build-essential
apt-get install clang

をしておきましょう。(正直この辺りは試行錯誤しているうちにインストールしてしまったので忘れました)

 個人のPCで(単一)GPUをターゲットにJAXをコンパイルする目的ではCUDAcuDNNがあれば十分です。NCCLはGPUクラスタを使う場合に必要になりますが、その環境構築をするだけのセッティングが手元にないので割愛します。
(オフィシャルサイトにはNCCLを使ったGPUクラスタ向けのビルド方法も記載されています。例えば、

python build/build.py --enable_cuda \
--bazel_options=--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \
--bazel_options=--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \
--bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl"

これがそうです。)

 それではJAXのビルド環境を作っていきましょう。 ディレクトリを作成して、gitでソースコードをダウンロードしてきます。gitが入っていない場合は

apt-get install git

で入れてください。この上で、次のコマンドを動かします。

you@domain:directory$ mkdir jax
you@domain:directory$ cd jax
you@domain:directory/jax$ python3 -m venv .venv
you@domain:directory/jax$ . .venv/bin/activate # source .venv/bin/activateに同じ
(.venv) you@domain:directory/jax$ git clone https://github.com/jax-ml/jax
Cloning into 'jax'...
remote: Enumerating objects: 159803, done.
remote: Counting objects: 100% (290/290), done.
remote: Compressing objects: 100% (162/162), done.
remote: Total 159803 (delta 163), reused 218 (delta 128), pack-reused 159513 (from 1)
Receiving objects: 100% (159803/159803), 98.28 MiB | 10.31 MiB/s, done.
Resolving deltas: 100% (126702/126702), done.
(.venv) you@domain:directory/jax$ ls
jax
(.venv) you@domain:directory/jax$ cd jax
(.venv) you@domain:directory/jax/jax$ ls
AUTHORS     CHANGELOG.md  cloud_tpu_colabs  docs      jax          LICENSE            README.md  third_party
benchmarks  ci            conftest.py       examples  jaxlib       platform_mappings  setup.py   WORKSPACE
build       CITATION.bib  CONTRIBUTING.md   images    jax_plugins  pyproject.toml     tests

 ここまでくるといよいよJAXをビルドすることができるのですが、結論から言うと

(.venv) you@domain:directory/jax/jax$ python build/build.py --enable_cuda --build_gpu_plugin --cuda_version=[CUDA VERSION] --cudnn_version=[CUDNN VERSION]

というコマンドを実行する必要があります。--cuda_version=[CUDA VERSION]--cudnn_version=[CUDNN VERSION]に何を書けばいいか分からないのですが、これはすごい適当な文字列を入れると、どのようなバージョンが許されているか分かります。試しに

(.venv) you@domain:directory/jax/jax$ python build/build.py --enable_cuda --build_gpu_plugin --cuda_version=0 --cudnn_version=0

     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\


Downloading bazel

### former logs ###

ERROR: Error computing the main repository mapping: no such package '@cuda_redist_json//': The supported CUDA versions are ["11.8", "12.1.1", "12.2.0", "12.3.1", "12.3.2", "12.4.0", "12.4.1", "12.5.0", "12.5.1", "12.6.0", "12.6.1", "12.6.2"]. Please provide a supported version in HERMETIC_CUDA_VERSION environment variable or add JSON URL for CUDA version=0.

### latter logs ###

としてみると、CUDAのバージョンとして"11.8"から"12.6.2"までが使えることが分かります。

 そこで--cuda_version=12.4.0として実行します。

(.venv) you@domain:directory/jax/jax$ python build/build.py --enable_cuda --build_gpu_plugin --cuda_version=12.4.0 --cudnn_version=0

     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\


Downloading bazel

### former logs ###

Error in fail: The supported CUDNN versions are ["8.6", "8.9.4.25", "8.9.6", "8.9.7.29", "9.1.1", "9.2.0", "9.2.1", "9.3.0", "9.4.0", "9.5.0"]. Please provide a supported version in HERMETIC_CUDNN_VERSION environment variable or add JSON URL for CUDNN version=0.

### latter logs ###

ここからcuDNNのバージョンは"8.6"から"9.5.0"までが使えます。

 以上から

  • CUDA 12.4
  • cuDNN 9.5.1
    においては、次のパラメータによるビルド・コマンドで、JAXをビルドすることができることが分かります。
python build/build.py --enable_cuda --build_gpu_plugin --cuda_version=12.4.0 --cudnn_version=9.5.0

 このCUDAのバージョンやcuDNNのバージョンは正直なところよく互換性が分かっていないのですがPytorchみたいにぎっちぎちにあっていなくても動く印象があります。というのも、ビルドした環境はCUDAが

you@domain:directory$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Feb_27_16:19:38_PST_2024
Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0
you@domain:directory$

で、cuDNNが

you@domain:directory$ dpkg --list | grep cudnn
ii  cudnn                                         9.5.1-1                             amd64        NVIDIA CUDA Deep Neural Network library (cuDNN)
ii  cudnn-local-repo-debian12-9.5.1               1.0-1                               amd64        cudnn-local repository configuration files
ii  cudnn9                                        9.5.1-1                             amd64        NVIDIA CUDA Deep Neural Network library (cuDNN)
ii  cudnn9-cuda-12                                9.5.1.17-1                          amd64        NVIDIA cuDNN for CUDA 12
ii  cudnn9-cuda-12-6                              9.5.1.17-1                          amd64        NVIDIA cuDNN for CUDA 12.6
ii  libcudnn9-cuda-12                             9.5.1.17-1                          amd64        cuDNN runtime libraries for CUDA 12.6
ii  libcudnn9-dev-cuda-12                         9.5.1.17-1                          amd64        cuDNN development headers and symlinks for CUDA 12.6
ii  libcudnn9-samples                             9.5.1.17-1                          all          cuDNN samples
ii  libcudnn9-static-cuda-12                      9.5.1.17-1                          amd64        cuDNN static libraries for CUDA 12.6
you@domain:directory$

というバージョンですが、ビルド時のパラメータを勘案すると

  • CUDA:ビルド時12.4 < 環境12.4.r12.4
  • cuDNN:ビルド時9.5.0 < 環境9.5.1
    で動いてしまっているので、たぶん、インストールされているCUDA, cuDNNのバージョンのメジャー・マイナーのバージョンくらいまでが合っていれば大丈夫そうなのです。ただし、バージョン番号は[メジャー].[マイナー](.[マイナーマイナー]以降続く)だと考えていっています。

 さて、コンパイルには大変に時間がかかるので、閑話休題。このコンパイル成功すると思います? これが癪なことに時期に依るんですよ。私が最初にJAXをコンパイルしたときには

jaxlib/gpu/solver_kernels_ffi.cc:918:29: error: missing 'typename' prior to dependent type name 'solver::RealType<T>::value'

というのが長い長いエラーログの末に出ていて、もはや読めもしないのでログをChatGPTに食わせたのですよ。そしたらsolver_kernels_ffi.ccっていうファイルの

auto s_data = static_cast<solver::RealType<T>::value*>(s->untyped_data());

これを

auto s_data = static_cast<typename solver::RealType<T>::value*>(s->untyped_data());

これに書き換える必要がありました。(しかもこれはC++のコードなので、私は久々に書いたらコメントアウトさえ間違えて、元のコードを#で消して、新規のこのコードを入れようとしたわけです。そしたら//がコメントアウトであると文句を言われるっていうね……。しくしく)
 何が言いたいかというと、JAXのgithubにあるコードがいつでもコンパイルできると思うなかれ! そして動かない時には、AI(ChatGPTでもClaudeでもGeminiでもいいです)を使ってファイルを修正する必要があるということです。


 そうこう言っているうちに、2週間前はファイルの修正が必要だったのに、今度は修正せずにコンパイル・ビルドが通過。なんなんだよ、このクソは。私がこれを記事にしなくちゃいけないと思ったほとんどはCUDA, cuDNNのインストールじゃなくて、この「コード書き換え」さえ必要っていうところだったのにコンパイル通過してんじゃないかよ!! という愚痴はさておき。

(.venv) you@domain:directory/jax/jax$ python build/build.py --enable_cuda --build_gpu_plugin --cuda_version=12.4.0 --cudnn_version=9.5.0

     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\


Downloading bazel

### bunch of logs ###

[11,431 / 11,449] Compiling llvm/lib/Target/AMDGPU/SIISelLowering.cpp; 12s local ... (9 actions, 8 running)
[11,434 / 11,449] Compiling llvm/lib/Target/AMDGPU/SIISelLowering.cpp; 13s local ... (7 actions running)
[11,435 / 11,449] Compiling llvm/lib/Target/AMDGPU/SIISelLowering.cpp; 15s local ... (7 actions, 6 running)
[11,439 / 11,449] Compiling llvm/lib/Target/AMDGPU/SIISelLowering.cpp; 17s local ... (5 actions, 4 running)
[11,444 / 11,449] Compiling xla/service/gpu/gpu_compiler.cc; 16s local ... (2 actions running)
[11,446 / 11,449] Compiling xla/service/gpu/gpu_compiler.cc; 18s local
[11,447 / 11,449] [Prepa] Linking external/xla/xla/service/gpu/libgpu_compiler.pic.a
[11,448 / 11,449] Linking jaxlib/tools/pjrt_c_api_gpu_plugin.so; 1s local
Target //jaxlib/tools:build_gpu_plugin_wheel up-to-date:
  bazel-bin/jaxlib/tools/build_gpu_plugin_wheel
INFO: Elapsed time: 563.475s, Critical Path: 416.37s
INFO: 2008 processes: 450 internal, 1558 local.
INFO: Build completed successfully, 2008 total actions
INFO: Running command line: bazel-bin/jaxlib/tools/build_gpu_plugin_wheel '--output_path=/home/you/jax/jax/dist' '--jaxlib_git_hash=1471702adc286bcf40e87c42877d538b4d589f90' '--cpu=x86_64' '--enable-cuda=True' '--platform_version=12'
Output wheel: /home/you/jax/jax/dist/jax_cuda12_pjrt-0.4.36.dev20241115-py3-none-manylinux2014_x86_64.whl

To install the newly-built jax cuda plugin wheel on system Python, run:
  pip install /home/you/jax/jax/dist/jax_cuda12_pjrt-0.4.36.dev20241115-py3-none-manylinux2014_x86_64.whl --force-reinstall

To install the newly-built jax cuda plugin wheel on hermetic Python, run:
  echo -e "\n/home/you/jax/jax/dist/jax_cuda12_pjrt-0.4.36.dev20241115-py3-none-manylinux2014_x86_64.whl" >> build/requirements.in
  bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.11

となってビルド完了です。一応、このコンパイルした環境を抜けるためには

(.venv) you@domain:directory/jax/jax$ deactivate
you@domain:directory/jax/jax$

.venvの環境から抜けることができるのでご安心を。

ビルドしたJAXのインストール方法

 せっかくビルドしたJAXをあなたのpython venv環境で使えるようにするためには、再度公式のページに行きましょう。

【参考文献】Building from source - JAX documentation
JAXをインストールするには?

 次のが公式サイトでの「一般にGPU, TPUなどでのビルドとインストール方法」です。

python build/build.py
pip install dist/*.whl  # installs jaxlib (includes XLA)

 私はですね、正直Pythonのことなんて何一つわからないんですけど、これでは動かなかったということだけは分かったんですよ。どうやら、JAXをPython環境に入れるときには

  • jax
  • jaxlib

という二つが必要なようで、さきほどまでビルドしていたのはCUDA用のjaxjaxlibなんです。そのため、pip install dist/*.whlをしたあとにimport jax.numpyとしても通常は初期は「CPU」でJAXを動かそうとするため、CPU版のJAXが入っていなくてjaxnumpyなんてモジュールないぞ?って言われました。
 そこで、ここでは

  • CPU版
  • GPU版(ここまでビルドしてきたもの)

を両方使える環境を提供する方法を示します。


 まずは安直にCPU版のJAXをインストールします。でも古典的なpython setup.py installの方法ですからね!?

you@domain:directory$ mkdir test
you@domain:directory$ cd test
you@domain:directory/test$ python3 -m venv .venv
you@domain:directory/test$ . .venv/bin/activate
(.venv) you@domain:directory/test$ cp -r [jaxのビルドしたディレクトリ] .
(.venv) you@domain:directory/test$ ls
jax
(.venv) you@domain:directory/test$ cd jax
(.venv) you@domain:directory/test/jax$ python setup.py install

### bunch of logs ###

Best match: jaxlib 0.4.35
Processing jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl
Installing jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl to /home/you/test/.venv/lib/python3.11/site-packages
Adding jaxlib 0.4.35 to easy-install.pth file

Installed /home/you/test/.venv/lib/python3.11/site-packages/jaxlib-0.4.35-py3.11-linux-x86_64.egg
Finished processing dependencies for jax==0.4.36.dev20241115+1471702ad
(.venv) you@domain:directory/test/jax$

 こうして

  • jax (ビルドしたGPU版jaxlibと対応したバージョンのjax)
  • jaxlib (ビルドしたGPU版jaxlibと対応したCPU用バージョンのjaxlib)

がインストールされます。

(.venv) you@domain:directory/test$ pip list | grep jax
jax        0.4.36.dev20241115+1471702ad
jax        0.4.36.dev20241115+1471702ad
jaxlib     0.4.35
jaxlib     0.4.35
(.venv) you@domain:directory/test$

 ここで、気づいてください、jaxlibが自前ビルド版になっていません!。要するにこのプロセスでインストールされるのはCPU版なんです……。


 ここまできて、jax/dist/*.whlをインストールするとこれまで作ってきたGPU版のビルドにjaxlibが置き換わって、晴れて

  • CPU版
  • GPU版(ここまでビルドしてきたもの)

が使えるようになります。
 さあ、最後に公式ドキュメント通りにpip install dist/*.whlによってGPU版のJAXをインストールします。

(.venv) you@domain:directory/test$ ls
jax
(.venv) you@domain:directory/test$ ls jax/dist/
jax-0.4.36.dev20241115+1471702ad-py3.11.egg
jax_cuda12_pjrt-0.4.36.dev20241115-py3-none-manylinux2014_x86_64.whl
jax_cuda12_plugin-0.4.36.dev20241115-cp311-cp311-manylinux2014_x86_64.whl
jaxlib-0.4.36.dev20241115-cp311-cp311-manylinux2014_x86_64.whl
(.venv) you@domain:directory/test$ pip install jax/dist/*.whl
Processing ./jax/dist/jax_cuda12_pjrt-0.4.36.dev20241115-py3-none-manylinux2014_x86_64.whl
Processing ./jax/dist/jax_cuda12_plugin-0.4.36.dev20241115-cp311-cp311-manylinux2014_x86_64.whl
Processing ./jax/dist/jaxlib-0.4.36.dev20241115-cp311-cp311-manylinux2014_x86_64.whl
Requirement already satisfied: scipy>=1.10 in ./.venv/lib/python3.11/site-packages/scipy-1.14.1-py3.11-linux-x86_64.egg (from jaxlib==0.4.36.dev20241115) (1.14.1)
Requirement already satisfied: numpy>=1.24 in ./.venv/lib/python3.11/site-packages/numpy-2.1.3-py3.11-linux-x86_64.egg (from jaxlib==0.4.36.dev20241115) (2.1.3)
Requirement already satisfied: ml-dtypes>=0.2.0 in ./.venv/lib/python3.11/site-packages/ml_dtypes-0.5.0-py3.11-linux-x86_64.egg (from jaxlib==0.4.36.dev20241115) (0.5.0)
Installing collected packages: jax-cuda12-pjrt, jax-cuda12-plugin, jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.35
    Uninstalling jaxlib-0.4.35:
      Successfully uninstalled jaxlib-0.4.35
Successfully installed jax-cuda12-pjrt-0.4.36.dev20241115 jax-cuda12-plugin-0.4.36.dev20241115 jaxlib-0.4.36.dev20241115
(.venv) you@domain:directory/test$

 それでは、本当にGPU版のJAXもインストールされているか確認してみましょう。

(.venv) you@domain:directory/test$ pip list | grep jax
jax               0.4.36.dev20241115+1471702ad
jax               0.4.36.dev20241115+1471702ad
jax-cuda12-pjrt   0.4.36.dev20241115
jax-cuda12-plugin 0.4.36.dev20241115
jaxlib            0.4.36.dev20241115
(.venv) you@domain:directory/test$

 パチパチ!完成です。


 ここまで長い道のりでしたが、後学のため失敗例を示します。

pip install jax
pip install jax/dist/*.whl

 これは最初のjaxでデフォルトのディストリビューションのjaxが入り、そのjaxのバージョンと、せっかくここまでビルドしてきたjaxlibのバージョンが合わないために失敗します。(もちろん運が良ければ、配布のjaxとコンパイルしてきたjaxlibのバージョンが一致するため、エラーなしで進みます)

実際にJAXコードを動かしてみる

 次のコードは

  • 現在の計算環境がCPUかGPUかを表示し
  • 行列計算を行った結果を出力する

JAXの簡単なコードです。

t.py
import jax
import jax.numpy as jnp

# 使用デバイスの確認
device = jax.devices()[0]
print(f"Using device: {device.device_kind}")

# 行列の定義
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])

# 行列の積
C = jnp.dot(A, B)

# 計算を強制実行して結果を取得
C_result = jax.device_get(C)

# 計算結果を表示
print("Matrix A:")
print(A)
print("Matrix B:")
print(B)
print("Result of A * B:")
print(C_result)

CPU実行

 JAXのコードをCPUで実行するときは、デフォルトのまま何もしていないか、

export JAX_PLATFORM_NAME=cpu

によって実行環境をCPUにする必要があります。

(.venv) you@domain:directory/test$ export JAX_PLATFORM_NAME=cpu
(.venv) you@domain:directory/test$ python t.py
Using device: cpu
Matrix A:
[[1 2]
 [3 4]]
Matrix B:
[[5 6]
 [7 8]]
Result of A * B:
[[19 22]
 [43 50]]
(.venv) you@domain:directory/test$

GPU実行

 JAXのコードをGPUで実行するときは

export JAX_PLATFORM_NAME=gpu

によって実行環境をGPUにする必要があります。一方で、コードには変更を施す必要はありません。(もちろんコード側でCPUやGPU実行を変更することもできます)

(.venv) you@domain:directory/test$ export JAX_PLATFORM_NAME=gpu
(.venv) you@domain:directory/test$ python t.py
Using device: NVIDIA GeForce RTX 3060
Matrix A:
[[1 2]
 [3 4]]
Matrix B:
[[5 6]
 [7 8]]
Result of A * B:
[[19 22]
 [43 50]]
(.venv) you@domain:directory/test$

とても長いチュートリアルでしたが、これで「JAXをNVIDIAのGPU環境で使えるようにする方法」は終了になります。(あとは御託)

啓蒙活動

 オブジェクト指向というのは「本質的にすべてのオブジェクト・インスタンスが万能計算機(あるいはそのカスタム品)」であって、それ同士のコミュニケーションをメッセージ・パッシングで実現しているわけです。一方で、関数型言語というのは計算を計算資源に分解しており、式と項を分離させて処理するわけです。ここで、「式=モデル」、「項=パラメータ」となります。万能計算性を一つの計算機リソースで実装できない、巨大深層学習モデルでは関数型の指向で実装する以外ないのです。オブジェクト・インスタンスでは、根本的に「ちゃんと動く(ネットワークなどの不安定性も排して!)万能計算性」を保証できない巨大深層学習モデル記述して、その学習・運用に耐えられないのです。理屈はこれだけ。
 だから明日からは深層学習のモデルはスケールさせる目的で全て関数型の計算フレームワークで記述してください

PyTorch is dead. Long live JAX.

 Pytorchは今後十年をかけて衰退していきます。それに対して機械学習・深層学習の界隈ではXLA形式で

  • モデル
  • パラメータ

の定義を分離できる「関数型言語」のスタイルを採用したモデルを利用する機会が増加していきます。その実装の一つがJAXです。

【参考文献】PyTorch is dead. Long live JAX.

JAXはGoogle Deepmindの開発したプラットフォームで、現在多くの機械学習プラットフォームが依拠しているPytorch(Meta, 旧Facebook)性のシステムに対して近年存在感を増しています。
 思い返せば、2012に「グーグルの猫」を皮切りに、私が覚えている限り、日本のコミュニティで深層学習のフレームワークとしては

  • Tensorflow: 後にKerasが出てver1の問題点が改善したver2が主流となる
  • Pytorch: Facebookが出してその後60%以上のシェアを持つアカデミアや開発研究の主流に
  • Chainer: 2020年に更新が終了した日本のPrefered系のフレームワーク
  • JAX: Google Deepmind制の数値計算一般に使える計算フレームワーク

が台頭しては消えてを繰り返しています。なお、昔話はこちらで。
【参考文献】2012年にAIの歴史が動いた!ついに猫認識に成功した「Googleの猫」

 要するに栄枯盛衰を繰り返しているのですが、私見ではTensorflow ver1とJAXが同じような計算グラフのコンパイルを基盤に持っていて、Tensorflow ver2 (with Keras)とPytorchとChainerが計算グラフの実行時構成というか実行時インタープリターという方法で動くわけです。
 計算資源が多様化するにしたがって、GPUでもCPUでもTPUでもクラスター化して、それぞれの計算資源で計算グラフを組んだうえで、それぞれのハードウェアに適した形でモデルのパラメータを配置しなくてはいけなくなりました。それでいうと、計算グラフはそれぞれのハードウェアに対してコンパイルしておいて、パラメータや入力/出力の教師データ・ペアは他の方法で与えるほうが都合がいいわけです。要するに、あるクラスターではGPUが主役を、あるクラスターではCPUが主役を、あるクラスターではTPUやFPGAやASICが主役の計算機構であるわけで、それにパラメータと教師データを流し込んで、パラメータ更新値を得る必要があるのです。すると、自ずと、再度Tensorflow ver1, JAXの世界に戻ってくる必要があるんですわ。と僕は啓蒙されました。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?