LoginSignup
5
2

More than 1 year has passed since last update.

JAXだけでつくるシンプルなニューラルネットワーク

Last updated at Posted at 2022-07-03

はじめに

JAXはGoogleが開発した自動微分、GPU(TPU)、Numpyのような機能を持った超便利なライブラリです。
詳しいところは、

がとても参考になります。

本記事では、このJAXを使ってシンプルなフィードフォワードNNの学習をしています。
同じくGoogleの開発チームが提供しているFlaxを組み合わせて使うことがポピュラーのようですが、まずは入門ということでJAXのみで実現してみます。

また(初心者なので)順を追った説明は控えて、コードをバンッ!と提示しながら説明します。
とりあえずコードだけで良いという方は、ぜひ

を参照してください。

おことわり)理論の説明、モデルの妥当性の考慮は完全に無視しています。実装の参考になれば幸いです。

コードを見ながら

使用ライブラリ

########################################
# ライブラリ
########################################

import jax
import jax.numpy as jnp
from jax import nn
from jax.nn.initializers import glorot_normal, normal

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

from tqdm import trange as tqdm_range

JAXのインストールはとりあえず済んでいるということで、最初の4行が今回使うJAX関連のものです。import jax.numpy as jnp がいわゆるNumpyのようなものです。
その他、今回は実験データとしてirisを使います。

ネットワークの定義

########################################
# ニューラルネットワーク(フィードフォワード)
########################################

@jax.jit
def Linear(params, x):
    return jnp.dot(x, params["W"]) + params["b"]

@jax.jit
def ANN(params, x):
    y1 = Linear(params["linear1"], x)
    z1 = nn.relu(y1)
    y2 = Linear(params["linear2"], z1)
    z2 = nn.softmax(y2)
    return z2

ここでは順伝搬の計算を定義しています。
x ---線形変換---> y1 ---ReLUで活性化---> z1 ---線形変換---> y2 ---ReLUで活性化---> z2 と変換されていきますので、今回は1層ネットワークです。またユニット数はこの後定義するparamsで明らかになります。
単純に行列計算をしているだけなので、直感的ですね。

また @jax.jit は、関数をJITコンパイルするためのデコレータです。付けると速くなるそうです。

訓練用関数の定義

########################################
# 訓練
########################################

# クロスエントロピー誤差
@jax.jit
def cross_entropy_loss(params, X, y):
    logits = ANN(params, X)
    return jnp.mean(-jnp.sum(y * jnp.log(logits), axis=1))

# パラメータの更新
@jax.jit
def update_params(params, grad, lr = 0.01):
    params["linear1"]["W"] -= lr * grad["linear1"]["W"]
    params["linear2"]["W"] -= lr * grad["linear2"]["W"]
    params["linear1"]["b"] -= lr * grad["linear1"]["b"]
    params["linear2"]["b"] -= lr * grad["linear2"]["b"]
    return params

# 与えられたXとyで勾配を計算&更新
@jax.jit
def train(params, X, y):
    grad = jax.grad(cross_entropy_loss)(params, X, y)
    return update_params(params, grad)

# バッチ毎に訓練
@jax.jit
def train_for_each_batch(batch_idx, params):
    target_train_indices = jax.lax.dynamic_slice(index, [batch_idx*batch_size], [batch_size])
    params = train(params, X_train[target_train_indices], y_train[target_train_indices])
    return params

まず、今回は3値分類になるので、クロスエントロピー誤差を使用します。cross_entropy_lossは、計算されたパラメータを使用して、損失を計算します。つまり、これを使ってSGDしていくわけです。

次に、update_paramsは勾配を基にパラメータを更新します。Pytorchや今回使わないFlaxでは、便利なクラスで簡単に更新してくれますが、今回はJAXのみでいきたいので書きます。

trainは、与えられた入力 $X$ とその真値 $y$ を使って、勾配を計算し、更新済みのパラメータを返します。
そして、train_for_each_batch は、入力の一部を使ってtrainに投げます。つまりはミニバッチ処理です。

データセットの読み込み

########################################
# データセットの読み込み
########################################

iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], test_size=0.25,  random_state=0)
X_train, X_test, y_train, y_test = jax.device_put(X_train), jax.device_put(X_test), jax.device_put(y_train), jax.device_put(y_test)
y_train = jnp.eye(3)[y_train] # ワン・ホットに変換
y_test = jnp.eye(3)[y_test] # ワン・ホットに変換

sklearnを用いて、irisデータセットを用意します。今回は、75%を訓練、残りの25%をテスト用としました。また、JAXでの計算にGPUを使う場合は、ここでdevice_putして、GPU上に乗せておきます。

実行!

########################################
# 学習してみる
########################################

# パラメータの初期化
rng = jax.random.PRNGKey(0)
rng1, rng2 = jax.random.split(rng)
rng1w, rng1b = jax.random.split(rng1)
rng2w, rng2b = jax.random.split(rng2)

params = {
    "linear1": {
        "W": glorot_normal()(rng1w, (4, 100)),
        "b": normal()(rng1b, (100,))
    },
    "linear2": {
        "W": glorot_normal()(rng2w, (100, 3)),
        "b": normal()(rng2b, (3,))
    }
}

まずはパラメータを初期化します。Wが重み、bがバイアスです。
入力次元 4 -> 100 -> 3 と変換される形です。

# バッチサイズ
batch_size = 50
# エポック数
epoch_nums = 100

# 訓練を回す
for epoch_id in range(epoch_nums):

    key = jax.random.PRNGKey(epoch_id+1)
    
    # 訓練データのインデックスをシャッフル
    index = jax.random.permutation(key, X_train.shape[0])
    # バッチ数
    batch_length = jnp.ceil(X_train.shape[0] / batch_size)

    # バッチ毎にパラメータを更新していく
    params = jax.lax.fori_loop(0, int(batch_length), train_for_each_batch, params)

    # 誤差の確認
    print(
        "訓練誤差:",
        '{:.3f}'.format(cross_entropy_loss(params, X_train, y_train)),
        "汎化誤差:",
        '{:.3f}'.format(cross_entropy_loss(params, X_test, y_test)),
        f"【 Epoch: {epoch_id} / {epoch_nums}"
    )

バッチサイズ、エポック数を指定して、学習を実行します。
1エポックの始めに、インデックス番号をシャッフルしておいて、関数train_for_each_batchの中で、使用する入力データを決めさせます。
fori_loopは、JAXにおいて高速にfor文を回すことができる魔法の関数(?)です。これは、

for batch_id in range(0, int(batch_length)):
    params = train_for_each_batch(batch_id, params)

と同じです。

おわりに

以上、JAXを使ってシンプルなNNを実装してみました。
初学者でいろいろ間違っているかもしれませんが、気づいた点があればご指摘いただければと思います。
TensorFlowやPyTorchのように日本語の解説記事が十分になく、この記事に辿り着いた方は(おそらく)苦労されているかと思います...
情報共有のためにも、わかったことがあればどんどん書いていきたいです。

5
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
2