JAXとは
1. JAX が注目される理由
- NumPy ライクな記述性 ─ 既存の Python + NumPy コードをほぼそのまま移植でき、学習コストが低い。
-
自動微分 (Autograd) ─
grad
,value_and_grad
,jacfwd/jacrev
,hessian
で高次導関数まで一貫サポート。 -
JIT コンパイル (XLA) ─
jax.jit
で関数を GPU/TPU ネイティブコードへ変換し、10〜100× の高速化。 -
関数変換の合成 ─
grad
,jit
,vmap
,pmap
を後付けで合成してもコードが崩壊しない。 - CPU⇄GPU⇄TPU の透過性 ─ 配列をデバイスへ明示的にコピーせずとも動作。研究段階から大規模分散まで同じコード。
2. コア機能の概要
2.1 自動微分 (grad 系 API)
- 逆モード・順モードを両対応。ニューラルネットの誤差逆伝播から物理シミュレーションの高次微分まで幅広く適用。
- 「関数=第一級オブジェクト」の思想で、副作用を最小化しバックエンド最適化を最大化。
2.2 JIT コンパイル (jax.jit)
- XLA が計算グラフをステートフリーに最適化。実行時は AOT コンパイルされたバイナリがキャッシュされる。
- Python の for ループやif 文もトレースされるため、NumPy より高い抽象度でも速度を損なわない。
2.3 ベクトル化 (jax.vmap)
- for ループを暗黙的に並列ベクトル化し、関数そのものの意味論を保ったままバッチ演算化。
2.4 デバイス並列化 (jax.pmap / pjit)
- マルチ GPU・TPU へ自動 sharding。データ並列、モデル並列を単一 API で切り替え可能。
3. エコシステムと周辺ライブラリ
-
Flax ─ 公式 NN ライブラリ。
flax.linen.Module
でモジュラにモデル構築。 - Haiku ─ DeepMind 製。関数型パラダイムを強調しつつ TensorFlow ライクな可読性。
- Optax ─ 勾配降下アルゴリズム集。AdamW, LAMB, Lion, learning‑rate schedulers 等をチェーン合成。
- Orbax ─ TPU/マルチ GPU 向けチェックポイント標準。分散大規模モデルの耐障害性を確保。
- NumPyro / BlackJAX ─ ベイズ推論 (HMC, NUTS, SVI)。JIT + AD で MCMC が桁違いに高速。
- RLax / BraX ─ 強化学習演算+微分可能物理シミュレーション。
- Equinox / Diffrax ─ 完全関数型 NN と常微分方程式ソルバ。
4. クイックスタート:最小の学習ループ
import jax, jax.numpy as jnp
from jax import grad, jit, vmap
# ❶ モデル定義(線形回帰)
def model(params, x):
w, b = params
return jnp.dot(x, w) + b
# ❷ 損失関数
@jit # JIT + 自動微分
def loss_fn(params, x, y):
pred = model(params, x)
return jnp.mean((pred - y) ** 2)
# ❸ 勾配計算
grad_fn = jit(grad(loss_fn))
# ❹ データ (バッチ化を vmap で後付け可能)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000, 3))
w_true = jnp.array([1.5, -2.0, 0.7])
y = jnp.dot(x, w_true) + 0.1
# ❺ 初期化 & SGD ループ
params = [jnp.zeros(3), 0.0]
learning_rate = 0.1
for step in range(500):
grads = grad_fn(params, x, y)
params = [p - learning_rate * g for p, g in zip(params, grads)]
print("学習後のパラメータ", params)
ポイント
-
@jit
+grad
を合成し、純 NumPy 風コードがそのまま高速化。 - データ次元を増やしたい場合は
vmap(model, in_axes=(None, 0))
を追加するだけ。
5. 実践 Tips
- 乱数管理 ─ PRNGKey は関数境界で分割して明示的に受け渡す。
- PyTree ─ ネスト辞書・タプルを再帰展開して一括操作。モデルパラメータや学習状態をシリアル化しやすい。
-
プロファイラ ─
jax.profiler
+ TensorBoard → ボトルネック解析。 -
Python オーバーヘッド ─ 頻回ループは
vmap
/scan
/pmap
で JIT 内に閉じ込める。
6. まとめ
JAX は 「NumPy の書き味」 と 「XLA の性能」 を両立させ、研究コードをほぼ無改変でプロダクションスケールへ昇華できる稀有なプラットフォームです。関数変換というアイデアにより、自動微分・JIT・並列化を後付けできるため、機械学習エンジニアはアルゴリズム思考に集中できます。
次の一歩
- Flax/Optax を使った Transformer 学習ループを試作してみる
- NumPyro でベイズ線形回帰 ⇒ 不確実性推定を体験
- BraX で TPU 強化学習を 15 分以内に完走