1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

#0163(2025/06/04)JAXとはどんなライブラリか?

Posted at

JAXとは

1. JAX が注目される理由

  • NumPy ライクな記述性 ─ 既存の Python + NumPy コードをほぼそのまま移植でき、学習コストが低い。
  • 自動微分 (Autograd)grad, value_and_grad, jacfwd/jacrev, hessian で高次導関数まで一貫サポート。
  • JIT コンパイル (XLA)jax.jit で関数を GPU/TPU ネイティブコードへ変換し、10〜100× の高速化。
  • 関数変換の合成grad, jit, vmap, pmap後付けで合成してもコードが崩壊しない。
  • CPU⇄GPU⇄TPU の透過性 ─ 配列をデバイスへ明示的にコピーせずとも動作。研究段階から大規模分散まで同じコード。

2. コア機能の概要

2.1 自動微分 (grad 系 API)

  • 逆モード・順モードを両対応。ニューラルネットの誤差逆伝播から物理シミュレーションの高次微分まで幅広く適用。
  • 「関数=第一級オブジェクト」の思想で、副作用を最小化しバックエンド最適化を最大化。

2.2 JIT コンパイル (jax.jit)

  • XLA が計算グラフをステートフリーに最適化。実行時は AOT コンパイルされたバイナリがキャッシュされる。
  • Python の for ループif 文もトレースされるため、NumPy より高い抽象度でも速度を損なわない。

2.3 ベクトル化 (jax.vmap)

  • for ループを暗黙的に並列ベクトル化し、関数そのものの意味論を保ったままバッチ演算化。

2.4 デバイス並列化 (jax.pmap / pjit)

  • マルチ GPU・TPU へ自動 sharding。データ並列、モデル並列を単一 API で切り替え可能。

3. エコシステムと周辺ライブラリ

  • Flax ─ 公式 NN ライブラリ。flax.linen.Module でモジュラにモデル構築。
  • Haiku ─ DeepMind 製。関数型パラダイムを強調しつつ TensorFlow ライクな可読性。
  • Optax ─ 勾配降下アルゴリズム集。AdamW, LAMB, Lion, learning‑rate schedulers 等をチェーン合成。
  • Orbax ─ TPU/マルチ GPU 向けチェックポイント標準。分散大規模モデルの耐障害性を確保。
  • NumPyro / BlackJAX ─ ベイズ推論 (HMC, NUTS, SVI)。JIT + AD で MCMC が桁違いに高速。
  • RLax / BraX ─ 強化学習演算+微分可能物理シミュレーション。
  • Equinox / Diffrax ─ 完全関数型 NN と常微分方程式ソルバ。

4. クイックスタート:最小の学習ループ

import jax, jax.numpy as jnp
from jax import grad, jit, vmap

# ❶ モデル定義(線形回帰)
def model(params, x):
    w, b = params
    return jnp.dot(x, w) + b

# ❷ 損失関数
@jit  # JIT + 自動微分
def loss_fn(params, x, y):
    pred = model(params, x)
    return jnp.mean((pred - y) ** 2)

# ❸ 勾配計算
grad_fn = jit(grad(loss_fn))

# ❹ データ (バッチ化を vmap で後付け可能)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000, 3))
w_true = jnp.array([1.5, -2.0, 0.7])
y = jnp.dot(x, w_true) + 0.1

# ❺ 初期化 & SGD ループ
params = [jnp.zeros(3), 0.0]
learning_rate = 0.1
for step in range(500):
    grads = grad_fn(params, x, y)
    params = [p - learning_rate * g for p, g in zip(params, grads)]
print("学習後のパラメータ", params)

ポイント

  • @jit + grad を合成し、純 NumPy 風コードがそのまま高速化。
  • データ次元を増やしたい場合は vmap(model, in_axes=(None, 0)) を追加するだけ。

5. 実践 Tips

  • 乱数管理 ─ PRNGKey は関数境界で分割して明示的に受け渡す。
  • PyTree ─ ネスト辞書・タプルを再帰展開して一括操作。モデルパラメータや学習状態をシリアル化しやすい。
  • プロファイラjax.profiler + TensorBoard → ボトルネック解析。
  • Python オーバーヘッド ─ 頻回ループは vmap/scan/pmap で JIT 内に閉じ込める。

6. まとめ

JAX は 「NumPy の書き味」「XLA の性能」 を両立させ、研究コードをほぼ無改変でプロダクションスケールへ昇華できる稀有なプラットフォームです。関数変換というアイデアにより、自動微分・JIT・並列化を後付けできるため、機械学習エンジニアはアルゴリズム思考に集中できます。

次の一歩

  • Flax/Optax を使った Transformer 学習ループを試作してみる
  • NumPyro でベイズ線形回帰 ⇒ 不確実性推定を体験
  • BraX で TPU 強化学習を 15 分以内に完走
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?