はじめに
複雑な関数の微分を求めたい時に、自動微分が便利ですよね。そんな時、私は今までPytorchの自動微分を利用していました。しかし、自動微分だけ使いたいのにpytorchのパッケージがかなり重いので、もっと軽量なパッケージはないかなと探していたところ、JAXにたどりつきました。
JAXとは?
公式
Autograd(今はメンテナンスされていない)の更新バージョンです。GPUを使って高速に自動微分を計算できます(もちろんCPUでも動きます)。
インストール方法
CPU only version
pip install --upgrade pip
pip install --upgrade jax jaxlib # CPU-only version
GPUを使いたい場合は公式のpip installationガイダンスを参照ください。
JAXでの二階微分
それでは試しにlog関数の二階微分を求めてみます。
import jax.numpy as jnp
from jax import grad
# log関数の定義
fn = lambda x0: jnp.log(x0)
# x = 1の周りで微分
x = 1
# 代入
y0 = fn(x)
# 一回微分
y1 = grad(fn)(x)
# 二階微分
y2 = grad(grad(fn))(x)
実行結果
>>> float(y0), float(y1), float(y2)
(0.0, 1.0, -1.0)
最後に
JAXを使うとお手軽簡単に自動微分が使えます。