はじめに
普段大学の研究ではサーバーのGPU上でpythonのライブラリであるjaxを使っています。jaxは線形代数関連の演算を非常に速く処理することができ便利なので、できればローカルでも使用したかったのですが、Apple SliconであるM1 Macだと今までインストールが上手く行きませんでした。日本語で解説しているサイトがなかったためここでごく簡単にまとめておきます。
環境
- OS: macOS Monterey
- チップ: Apple M1
- python: 3.11
インストール方法
インストール方法で大事なことはただ一つ、miniforgeでパッケージをインストールするということでした。非常に簡単なのですが、明記されているWebサイトが少なく気づくまで苦労しました。こちらのstack overflowに記載があります。
インストール方法自体は他の方がすでにまとめてくれているので、参考文献に記載しておきます。
miniforgeをinstallし、condaのpathがminiforgeのcondaに通っていることを確認すれば、あとはpipでinstallするだけです。
pip install jaxlib jax
これでjaxがM1 mac上で動くようになりました。
試してみる
以下の非常に簡単なコードで、numpyとjaxの計算速度の比較、自動微分について計算してみました。
import time
import timeit
import jax.numpy as jnp
import numpy as np
# numpy array
x_np = np.random.rand(10000, 10000)
y_np = np.random.rand(10000, 10000)
# jax array
x_jax = jnp.array(x_np)
y_jax = jnp.array(y_np)
# numpy arrayの計算 (numpyの計算速度)
np_time = timeit.timeit(
"x_np * y_np", setup="from __main__ import x_np, y_np", number=100
)
print(f"Average numpy calculation time: {np_time / 100} seconds")
# jax arrayの計算 (jaxの計算速度)
jax_time = timeit.timeit(
"x_jax * y_jax", setup="from __main__ import x_jax, y_jax", number=100
)
print(f"Average jax calculation time: {jax_time / 100} seconds")
## jaxの自動微分
from jax import grad
def f(x):
return x**2 + 2 * x + 1
f_grad = grad(f)
print(f"gradient of function f at x=1.0: {f_grad(1.0)}")
出力
Average numpy calculation time: 0.19221712457947432 seconds
Average jax calculation time: 0.03303184084012173 seconds
gradient of function f at x=1.0: 4.0
計算時間がnumpyに比べて~x6になっていて、自動微分の計算もできています!
ローカルでの計算速度も早くなったので、これでより研究を加速させていきたいと思います。