はじめに
この記事は「ただただアウトプットを癖付けるための Advent Calendar 2024」に投稿した記事です。
最初の記事にも書いた通り、私は生物物理の実験を専門にしている研究者です。
最近はデータ解析のため機械学習のコード開発も行っており、幸いにもその成果がNeurIPSに採択されました。
上記のコード開発では、pytorchをこねくり回すことで、なんとか目的のモデルを作成することができました。
正直動作時間などは考えられておらず、効率化が必須の状況です。
そこで、JAXを使ってみることにしました。
今回は、適当に機械学習モデルを作成してみることにしました。
関連記事
前の記事「生物物理屋がSIGNATEの初心者向け課題に挑戦してみた話」
次の記事「生物物理屋がNishikaの初心者向け課題に挑戦してみた話」
JAX
こちらを参考にしています。
JAXは、数値計算を高速におこなうためのライブラリです。
Deepmindが開発したものであり、こちらからアクセスできます。
JAXは、numpyの代わりに使うことができます。
線形代数の計算を高速におこなうコンパイラが内蔵されているほか、自動微分もサポートされています。
PyTorchからの移行には、こちらも参考になりそうです。
Flax
Flaxは、JAXを使った機械学習ライブラリです。
これを用いることで、簡単に機械学習モデルを作成することができます。
インストール
pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install flax
windows11環境です。python3.8では動作しませんでした。今回はpython3.11.7を使っています。
JAXを使うコード
import jax
import jax.numpy as jnp
def f(x):
return x**2
grad_f = jax.grad(f)
print(grad_f(3.0))
Flaxを使うコード
モデルの作成
import jax
import jax.numpy as jnp
from flax import linen as nn
class MLP(nn.Module):
features: int
def setup(self):
self.dense1 = nn.Dense(features=self.features)
self.dense2 = nn.Dense(features=1)
def __call__(self, x):
x = self.dense1(x)
x = nn.relu(x)
x = self.dense2(x)
return x
model = MLP(features=10)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 10)))
モデルの学習
上の続きです。
import optax
optimizer_def = optax.adam(learning_rate=0.1, b1=0.9, b2=0.999)
optimizer = optimizer_def.init(params)
@jax.jit
def loss_fn(params, x, y):
pred = model.apply(params, x)
return jnp.mean((pred - y)**2)
loss_grad_fn = jax.value_and_grad(loss_fn)
x = jax.random.normal(jax.random.PRNGKey(0), (100, 10))
y = x @ jnp.ones((10, 1)) + 0.1 * jax.random.normal(jax.random.PRNGKey(0), (100, 1))
for i in range(101):
loss_val, grads = loss_grad_fn(params, x, y)
updates, optimizer = optimizer_def.update(grads, optimizer)
params = optax.apply_updates(params, updates)
if i % 10 == 0:
print(f'Loss step {i}: ', loss_val)
jax.jitは、関数をコンパイルして高速化するデコレータです。
GPUを使う場合は、jax.jitを使うことで高速化が期待できます。
optaxは、最適化アルゴリズムを提供するライブラリです。
このコードで、ちゃんとlossが減少していくことが確認できました。