0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ただただアウトプットを癖付けるためのAdvent Calendar 2024

Day 13

NeurIPS著者が今更JAXを使ってみた話

Last updated at Posted at 2024-12-12

はじめに

この記事は「ただただアウトプットを癖付けるための Advent Calendar 2024」に投稿した記事です。

最初の記事にも書いた通り、私は生物物理の実験を専門にしている研究者です。
最近はデータ解析のため機械学習のコード開発も行っており、幸いにもその成果がNeurIPSに採択されました

上記のコード開発では、pytorchをこねくり回すことで、なんとか目的のモデルを作成することができました。
正直動作時間などは考えられておらず、効率化が必須の状況です。
そこで、JAXを使ってみることにしました。
今回は、適当に機械学習モデルを作成してみることにしました。

関連記事

前の記事「生物物理屋がSIGNATEの初心者向け課題に挑戦してみた話

次の記事「生物物理屋がNishikaの初心者向け課題に挑戦してみた話

JAX

こちらを参考にしています。

JAXは、数値計算を高速におこなうためのライブラリです。
Deepmindが開発したものであり、こちらからアクセスできます。

JAXは、numpyの代わりに使うことができます。
線形代数の計算を高速におこなうコンパイラが内蔵されているほか、自動微分もサポートされています。

PyTorchからの移行には、こちらも参考になりそうです。

Flax

Flaxは、JAXを使った機械学習ライブラリです。
これを用いることで、簡単に機械学習モデルを作成することができます。

インストール

pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install flax

windows11環境です。python3.8では動作しませんでした。今回はpython3.11.7を使っています。

JAXを使うコード

import jax
import jax.numpy as jnp

def f(x):
    return x**2

grad_f = jax.grad(f)
print(grad_f(3.0))

Flaxを使うコード

モデルの作成


import jax
import jax.numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
    features: int

    def setup(self):
        self.dense1 = nn.Dense(features=self.features)
        self.dense2 = nn.Dense(features=1)

    def __call__(self, x):
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        return x

model = MLP(features=10)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 10)))

モデルの学習

上の続きです。

import optax

optimizer_def = optax.adam(learning_rate=0.1, b1=0.9, b2=0.999)
optimizer = optimizer_def.init(params)

@jax.jit
def loss_fn(params, x, y):
    pred = model.apply(params, x)
    return jnp.mean((pred - y)**2)

loss_grad_fn = jax.value_and_grad(loss_fn)

x = jax.random.normal(jax.random.PRNGKey(0), (100, 10))
y = x @ jnp.ones((10, 1)) + 0.1 * jax.random.normal(jax.random.PRNGKey(0), (100, 1))

for i in range(101):
    loss_val, grads = loss_grad_fn(params, x, y)
    updates, optimizer = optimizer_def.update(grads, optimizer)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)

jax.jitは、関数をコンパイルして高速化するデコレータです。
GPUを使う場合は、jax.jitを使うことで高速化が期待できます。

optaxは、最適化アルゴリズムを提供するライブラリです。

このコードで、ちゃんとlossが減少していくことが確認できました。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?