はじめに
JAXはGoogleが開発した自動微分、GPU(TPU)、Numpyのような機能を持った超便利なライブラリです。
詳しいところは、
- JAX入門~高速なNumPyとして使いこなすためのチュートリアル~ - Qiita
https://qiita.com/koshian2/items/44a871386576b4f80aff?utm_campaign=popular_items&utm_medium=feed&utm_source=popular_items - JAX/Flaxを使ってMNISTを学習させてみる | TC3株式会社
https://www.tc3.co.jp/jaxflax-introduction-with-mnist/
がとても参考になります。
本記事では、このJAXを使ってシンプルなフィードフォワードNNの学習をしています。
同じくGoogleの開発チームが提供しているFlaxを組み合わせて使うことがポピュラーのようですが、まずは入門ということでJAXのみで実現してみます。
また(初心者なので)順を追った説明は控えて、コードをバンッ!と提示しながら説明します。
とりあえずコードだけで良いという方は、ぜひ
を参照してください。
おことわり)理論の説明、モデルの妥当性の考慮は完全に無視しています。実装の参考になれば幸いです。
コードを見ながら
使用ライブラリ
########################################
# ライブラリ
########################################
import jax
import jax.numpy as jnp
from jax import nn
from jax.nn.initializers import glorot_normal, normal
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from tqdm import trange as tqdm_range
JAXのインストールはとりあえず済んでいるということで、最初の4行が今回使うJAX関連のものです。import jax.numpy as jnp
がいわゆるNumpyのようなものです。
その他、今回は実験データとしてirisを使います。
ネットワークの定義
########################################
# ニューラルネットワーク(フィードフォワード)
########################################
@jax.jit
def Linear(params, x):
return jnp.dot(x, params["W"]) + params["b"]
@jax.jit
def ANN(params, x):
y1 = Linear(params["linear1"], x)
z1 = nn.relu(y1)
y2 = Linear(params["linear2"], z1)
z2 = nn.softmax(y2)
return z2
ここでは順伝搬の計算を定義しています。
x ---線形変換---> y1 ---ReLUで活性化---> z1 ---線形変換---> y2 ---ReLUで活性化---> z2
と変換されていきますので、今回は1層ネットワークです。またユニット数はこの後定義するparamsで明らかになります。
単純に行列計算をしているだけなので、直感的ですね。
また @jax.jit
は、関数をJITコンパイルするためのデコレータです。付けると速くなるそうです。
訓練用関数の定義
########################################
# 訓練
########################################
# クロスエントロピー誤差
@jax.jit
def cross_entropy_loss(params, X, y):
logits = ANN(params, X)
return jnp.mean(-jnp.sum(y * jnp.log(logits), axis=1))
# パラメータの更新
@jax.jit
def update_params(params, grad, lr = 0.01):
params["linear1"]["W"] -= lr * grad["linear1"]["W"]
params["linear2"]["W"] -= lr * grad["linear2"]["W"]
params["linear1"]["b"] -= lr * grad["linear1"]["b"]
params["linear2"]["b"] -= lr * grad["linear2"]["b"]
return params
# 与えられたXとyで勾配を計算&更新
@jax.jit
def train(params, X, y):
grad = jax.grad(cross_entropy_loss)(params, X, y)
return update_params(params, grad)
# バッチ毎に訓練
@jax.jit
def train_for_each_batch(batch_idx, params):
target_train_indices = jax.lax.dynamic_slice(index, [batch_idx*batch_size], [batch_size])
params = train(params, X_train[target_train_indices], y_train[target_train_indices])
return params
まず、今回は3値分類になるので、クロスエントロピー誤差を使用します。cross_entropy_loss
は、計算されたパラメータを使用して、損失を計算します。つまり、これを使ってSGDしていくわけです。
次に、update_params
は勾配を基にパラメータを更新します。Pytorchや今回使わないFlaxでは、便利なクラスで簡単に更新してくれますが、今回はJAXのみでいきたいので書きます。
train
は、与えられた入力 $X$ とその真値 $y$ を使って、勾配を計算し、更新済みのパラメータを返します。
そして、train_for_each_batch
は、入力の一部を使ってtrain
に投げます。つまりはミニバッチ処理です。
データセットの読み込み
########################################
# データセットの読み込み
########################################
iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], test_size=0.25, random_state=0)
X_train, X_test, y_train, y_test = jax.device_put(X_train), jax.device_put(X_test), jax.device_put(y_train), jax.device_put(y_test)
y_train = jnp.eye(3)[y_train] # ワン・ホットに変換
y_test = jnp.eye(3)[y_test] # ワン・ホットに変換
sklearnを用いて、irisデータセットを用意します。今回は、75%を訓練、残りの25%をテスト用としました。また、JAXでの計算にGPUを使う場合は、ここでdevice_put
して、GPU上に乗せておきます。
実行!
########################################
# 学習してみる
########################################
# パラメータの初期化
rng = jax.random.PRNGKey(0)
rng1, rng2 = jax.random.split(rng)
rng1w, rng1b = jax.random.split(rng1)
rng2w, rng2b = jax.random.split(rng2)
params = {
"linear1": {
"W": glorot_normal()(rng1w, (4, 100)),
"b": normal()(rng1b, (100,))
},
"linear2": {
"W": glorot_normal()(rng2w, (100, 3)),
"b": normal()(rng2b, (3,))
}
}
まずはパラメータを初期化します。Wが重み、bがバイアスです。
入力次元 4 -> 100 -> 3 と変換される形です。
# バッチサイズ
batch_size = 50
# エポック数
epoch_nums = 100
# 訓練を回す
for epoch_id in range(epoch_nums):
key = jax.random.PRNGKey(epoch_id+1)
# 訓練データのインデックスをシャッフル
index = jax.random.permutation(key, X_train.shape[0])
# バッチ数
batch_length = jnp.ceil(X_train.shape[0] / batch_size)
# バッチ毎にパラメータを更新していく
params = jax.lax.fori_loop(0, int(batch_length), train_for_each_batch, params)
# 誤差の確認
print(
"訓練誤差:",
'{:.3f}'.format(cross_entropy_loss(params, X_train, y_train)),
"汎化誤差:",
'{:.3f}'.format(cross_entropy_loss(params, X_test, y_test)),
f"【 Epoch: {epoch_id} / {epoch_nums} 】"
)
バッチサイズ、エポック数を指定して、学習を実行します。
1エポックの始めに、インデックス番号をシャッフルしておいて、関数train_for_each_batch
の中で、使用する入力データを決めさせます。
fori_loop
は、JAXにおいて高速にfor文を回すことができる魔法の関数(?)です。これは、
for batch_id in range(0, int(batch_length)):
params = train_for_each_batch(batch_id, params)
と同じです。
おわりに
以上、JAXを使ってシンプルなNNを実装してみました。
初学者でいろいろ間違っているかもしれませんが、気づいた点があればご指摘いただければと思います。
TensorFlowやPyTorchのように日本語の解説記事が十分になく、この記事に辿り着いた方は(おそらく)苦労されているかと思います...
情報共有のためにも、わかったことがあればどんどん書いていきたいです。