2
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

posted at

updated at

Organization

JAXの自動微分で二階微分を求める

はじめに

複雑な関数の微分を求めたい時に、自動微分が便利ですよね。そんな時、私は今までPytorchの自動微分を利用していました。しかし、自動微分だけ使いたいのにpytorchのパッケージがかなり重いので、もっと軽量なパッケージはないかなと探していたところ、JAXにたどりつきました。

JAXとは?

公式
image.png
Autograd(今はメンテナンスされていない)の更新バージョンです。GPUを使って高速に自動微分を計算できます(もちろんCPUでも動きます)。

インストール方法

CPU only version

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

GPUを使いたい場合は公式のpip installationガイダンスを参照ください。

JAXでの二階微分

それでは試しにlog関数の二階微分を求めてみます。

import jax.numpy as jnp
from jax import grad

# log関数の定義
fn = lambda x0: jnp.log(x0)

# x = 1の周りで微分
x = 1

# 代入
y0 = fn(x)
# 一回微分
y1 = grad(fn)(x)
# 二階微分
y2 = grad(grad(fn))(x)
実行結果
>>> float(y0), float(y1), float(y2)
(0.0, 1.0, -1.0)

最後に

JAXを使うとお手軽簡単に自動微分が使えます。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Sign upLogin
2
Help us understand the problem. What are the problem?