LoginSignup
6
2

More than 3 years have passed since last update.

JAXの自動微分で二階微分を求める

Last updated at Posted at 2020-06-17

はじめに

複雑な関数の微分を求めたい時に、自動微分が便利ですよね。そんな時、私は今までPytorchの自動微分を利用していました。しかし、自動微分だけ使いたいのにpytorchのパッケージがかなり重いので、もっと軽量なパッケージはないかなと探していたところ、JAXにたどりつきました。

JAXとは?

公式
image.png
Autograd(今はメンテナンスされていない)の更新バージョンです。GPUを使って高速に自動微分を計算できます(もちろんCPUでも動きます)。

インストール方法

CPU only version

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

GPUを使いたい場合は公式のpip installationガイダンスを参照ください。

JAXでの二階微分

それでは試しにlog関数の二階微分を求めてみます。

import jax.numpy as jnp
from jax import grad

# log関数の定義
fn = lambda x0: jnp.log(x0)

# x = 1の周りで微分
x = 1

# 代入
y0 = fn(x)
# 一回微分
y1 = grad(fn)(x)
# 二階微分
y2 = grad(grad(fn))(x)
実行結果
>>> float(y0), float(y1), float(y2)
(0.0, 1.0, -1.0)

最後に

JAXを使うとお手軽簡単に自動微分が使えます。

6
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
2