LoginSignup
0
0

jax/flaxのインストール時のトラブルメモ

Last updated at Posted at 2023-01-17

インストール

GPUを使う時はpipインストール時に指定しなければいけない

pip3 install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

実際にはcudnnのバージョンに合わせてバージョン指定したほうがいい(後述)

pip3 install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip3 install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

2023/9/11追記
公式のgithub見るとcudnnのバージョン指定不要になったっぽい?

pip install --upgrade pip

# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

バージョンについて

cudnnバージョンは環境のcudnnのメジャーバージョンは合わせなければいけない
マイナーバージョンは環境のcudnnよりも古いjaxをインストールする

環境にインストールされているcudnnのバージョンはヘッダファイルを見る

$ cat /usr/include/cudnn_version.h

---省略---
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 5
#define CUDNN_PATCHLEVEL 0
---省略---

上の例の場合、jax[cuda11_cudnn82]をインストールしなければいけない

GPUのメモリ

デフォルトではGPUメモリの9割をプリアロケートする。
チュートリアルやサンプルコードが動かないとかの報告がある。

特に、Getting startedのコードはTFDSでMNISTをロードする際にTensorFlowがGPUメモリをアロケートしてしまうので、TensorFlow側も対処が必要

JAXの設定

下記のどちらかを.bashrcでexportするか、実行時に指定する

export XLA_PYTHON_CLIENT_PREALLOCATE=false # preallocateしない場合
export XLA_PYTHON_CLIENT_MEM_FRACTION=.XX # 80-85%でうまくいくらしい。デフォルトは90%

TensorFlowの設定

下記をコード内で実行しておく

tf.config.experimental.set_visible_devices([], "GPU")

参考

https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
https://github.com/google/jax/issues/8746
https://tech.yellowback.net/posts/jax-oom

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0