JAXは、pipを使用すると簡単にインストールできますが、(現状)ARM系CPUが乗ったマシンではエラーを吐きます。
いろいろ試行錯誤した結果、ソースファイルからインストールすると上手くいったので、個人的メモです。
0. この記事の前提
- Ubuntuコンテナ(Docker)上にJAXとFlax環境を整備します。
- 検証マシン:OCI, Ampere A1 Compute
- GPU(CUDA)は使用しません
- Python3 / JupyterLabを使用
1. Dockerコンテナの作成
私が作成しているシェルスクリプトをもとに、Python3環境でJupyterLabが自動起動するUbuntuコンテナを作成します。
※自分の好みの環境が既にある方は、次の章にお進みください。
# リポジトリのクローン
$ git clone https://github.com/KaiSugahara/docker_template.git
$ cd docker_template/jupyter/ubuntu/
# コンテナを起動するシェルスクリプトの実行
$ bash run.sh
作成したいコンテナ名を入力してください(入力例「your_name_ubuntu」)
> jupyter
コンテナユーザのパスワードを入力してください(入力は非表示になっています)
> (パスワードを入力する)
Jupyterに接続するためのポート番号を入力してください
> 22000
HTTP(S)_PROXYを入力してください(指定しない場合はそのままEnter)
>
保存するイメージ名を入力してください(入力例「your_name/ubuntu」)
> jupyter
これだけで完了です。
(22000:localhost:22000
でポートフォワーディングすれば)http://localhost:22000 でJupyterLabに接続できます。
ここからは、このUbuntuコンテナ上にJAX/Flaxをインストールします。
$ docker exec -it container_name bash
により、コンテナ上でコマンドを実行していきます。
(もしくは、JupyterLabで開くターミナルでもOK)
2. jaxlibとjaxのインストール
JAXのドキュメント内 "Building from source" に従って進めます。
https://jax.readthedocs.io/en/latest/developer.html
まずは、JAXのソースコードを落とします。
$ sudo apt install git
$ git clone https://github.com/google/jax
$ cd jax
jaxlibをインストールします。(結構時間がかかるので気長に待ちます)
$ pip3 install numpy wheel
$ python3 build/build.py
$ pip3 install dist/*.whl
jaxlibのインストールができたので、次はjaxです。
$ pip3 install -e .
3. Flaxのインストール
まず、cmake
が入っていない場合は、このあとエラーになるのでインストール。
$ sudo apt install cmake
あとはpipでFlaxを入れます。
$ pip3 install flax
以上で、完了です。お疲れさまでした!