はじめに
jaxと、この記事の説明
- jaxは高速な並列計算が得意なpythonライブラリ
- jaxはpythonの通常のfor文やwhile文を使わず、
fori_loop, while_loop, scanを使うことが推奨されている- XLAでのコンパイル時に、いい感じに最適化できて高速になるらしい
- 中間状態の効率化が行われ、メモリ効率も良くなるらしい
- ここではこれらの使い方を紹介します(自分用メモ)
なぜjax?
私は他人のソースコードが全部jaxで書いてあったので、仕方なく手を出している
fori_loopの使い方
基本的な使い方は次の通りである。for文を単に関数に押し込めたような形だ。
fori_loopの使い方
import jax
def body_fun(i, val):
return val + i # i番目のループで構成した値を返却 → 次のループで使用されるvalとなる
# 最終ループで構成したvalがresultに格納される
result = jax.lax.fori_loop(
0, # i=0から始める
10, # i<10まで実行する
body_fun, # for文の中で実行する関数
0, # 実行する関数に渡すvalの初期値
)
print(result) # 0 + 1 + 2 + ... + 9 = 45
公式サイトに書いてある動作の疑似コードは以下の通り:
def fori_loop(lower, upper, body_fun, init_val): val = init_val for i in range(lower, upper): val = body_fun(i, val) return val
while_loopの使い方
基本的な使い方は次の通りである。while文を単に関数に押し込めたような形だ。
while_loopの使い方
import jax
def condition_fun(val):
return val < 10 # whileの繰り返し条件の評価値を返却する
def body_fun(val):
return val + 1 # このループで構成した値を返却 → 次のループで使用されるvalとなる
# 最終ループで構成したvalがresultに格納される
result = jax.lax.while_loop(
condition_fun, # 繰り返し条件をvalから判定する関数
body_fun, # while文の中で実行する関数
0, # 実行する関数に渡すvalの初期値
)
print(result) # 1 + 1 + 1 ... + 1 = 10 (10未満のとき1ずつ加算した結果)
公式サイトに書いてある動作の疑似コードは以下の通り:
def while_loop(cond_fun, body_fun, init_val): val = init_val while cond_fun(val): val = body_fun(val) return val
scanの使い方
リストから1つずつ取り出して処理して、結果のリストにappendするような操作はよくすると思う。つまり次のような処理だ:
ys = []
for x_i in xs:
y_i = func(x_i)
ys.append(y_i)
# ys = [func(x_i) for x_i in xs] というリスト内包記法と同じ処理
この操作をまとめたのがscanだ。
基本的な使い方は次の通りで、body_funの構成がfori_loopやwhile_loopと比べて少し複雑になる。
scanの使い方
import jax
import jax.numpy as jnp
def body_fun(val, x):
return val + x, val + x # (次のvalの値, appendする出力)
# 最終ループで構成したvalと、構成したリストysが返却される
final_val, ys = jax.lax.scan(
body_fun, # ループの中で実行する関数
0, # 実行する関数に渡すvalの初期値
jnp.arange(5), # [0, 1, 2, 3, 4]; 1つずつ処理するデータx
)
print(final_val) # 10 (0 + 1 + 2 + 3 + 4)
print(ys) # [0, 1, 3, 6, 10]; 累積値
公式サイトに書いてある動作の疑似コードは以下の通り:
def scan(f, init, xs, length=None): if xs is None: xs = [None] * length carry = init ys = [] for x in xs: carry, y = f(carry, x) ys.append(y) return carry, np.stack(ys)
おわりに
- 最初慣れるのは大変だが、そんなに難しくもない
- if文も使うべきではなく、jax.lax.condやjnp.whereを使えと言われている(そのうち記事書くかも)