LoginSignup
4
5

More than 1 year has passed since last update.

【高速なnumpy】WindowsでJAXを使おう【Python】

Last updated at Posted at 2022-04-19

この記事ではJust-In-Timeコンパイルにより高速計算を実現するライブラリJAXWindowsに導入する手順を説明します。
JITコンパイル、JAXの説明はしません。
【参考】https://github.com/cloudhan/jax-windows-builder

1. jaxlibのwhlファイルをダウンロードする

https://whls.blob.core.windows.net/unstable/index.htmlから自分の環境にあったjaxlibのwhlファイルをダウンロードします。
該当ファイルがどれかわからない!という方は次のコードを実行することでCPU/cuda、jaxlibのバージョン、cp等を確認してください。

pip3 debug --verbose 

2. whlファイルをインストールする

ダウンロードしたwhlファイルのパスをコピーしpip installします。
※Windowsではshift+右クリックでパスのコピーを選択できる。

pip install "C:\Users\~~~~~.whl"

3. jaxをpip installする

状況に応じてjax[gpu]やバージョンの指定が必要です。詳しくは

pip install jax

4. jaxをimportする

import jax 

実際にはimport jax.numpy as npjnpと使うことが多いと思いますが、np.randomのようにnumpy同様に使えない処理もあります。

今後に期待しましょう!

4
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
5