LoginSignup
0
0

jaxの各ループ(fori_loop, while_loop, scan)の使い方

Last updated at Posted at 2024-05-03

はじめに

jaxと、この記事の説明

  • jaxは高速な並列計算が得意なpythonライブラリ
  • jaxはpythonの通常のfor文やwhile文を使わず
    fori_loop, while_loop, scanを使うことが推奨されている
    • XLAでのコンパイル時に、いい感じに最適化できて高速になるらしい
    • 中間状態の効率化が行われ、メモリ効率も良くなるらしい
  • ここではこれらの使い方を紹介します(自分用メモ)

なぜjax?

私は他人のソースコードが全部jaxで書いてあったので、仕方なく手を出している

fori_loopの使い方

基本的な使い方は次の通りである。for文を単に関数に押し込めたような形だ。

fori_loopの使い方
import jax

def body_fun(i, val):
    return val + i  # i番目のループで構成した値を返却 → 次のループで使用されるvalとなる

# 最終ループで構成したvalがresultに格納される
result = jax.lax.fori_loop(
    0,        # i=0から始める
    10,       # i<10まで実行する
    body_fun, # for文の中で実行する関数
    0,        # 実行する関数に渡すvalの初期値
) 

print(result)  # 0 + 1 + 2 + ... + 9 = 45

公式サイトに書いてある動作の疑似コードは以下の通り:

def fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

while_loopの使い方

基本的な使い方は次の通りである。while文を単に関数に押し込めたような形だ。

while_loopの使い方
import jax

def condition_fun(val):
    return val < 10 # whileの繰り返し条件の評価値を返却する

def body_fun(val):
    return val + 1 # このループで構成した値を返却 → 次のループで使用されるvalとなる

# 最終ループで構成したvalがresultに格納される
result = jax.lax.while_loop(
    condition_fun, # 繰り返し条件をvalから判定する関数
    body_fun,      # while文の中で実行する関数
    0,             # 実行する関数に渡すvalの初期値
)

print(result)  # 1 + 1 + 1 ... + 1 = 10 (10未満のとき1ずつ加算した結果)

公式サイトに書いてある動作の疑似コードは以下の通り:

def while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

scanの使い方

リストから1つずつ取り出して処理して、結果のリストにappendするような操作はよくすると思う。つまり次のような処理だ:

ys = []
for x_i in xs:
   y_i = func(x_i)
   ys.append(y_i)

# ys = [func(x_i) for x_i in xs] というリスト内包記法と同じ処理

この操作をまとめたのがscanだ。
基本的な使い方は次の通りで、body_funの構成がfori_loopやwhile_loopと比べて少し複雑になる。

scanの使い方
import jax
import jax.numpy as jnp

def body_fun(val, x):
    return val + x, val + x  # (次のvalの値, appendする出力)

# 最終ループで構成したvalと、構成したリストysが返却される
final_val, ys = jax.lax.scan(
    body_fun,      # ループの中で実行する関数
    0,             # 実行する関数に渡すvalの初期値
    jnp.arange(5), # [0, 1, 2, 3, 4]; 1つずつ処理するデータx
)

print(final_val)  # 10 (0 + 1 + 2 + 3 + 4)
print(ys)  # [0, 1, 3, 6, 10]; 累積値

公式サイトに書いてある動作の疑似コードは以下の通り:

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

おわりに

  • 最初慣れるのは大変だが、そんなに難しくもない
  • if文も使うべきではなく、jax.lax.condやjnp.whereを使えと言われている(そのうち記事書くかも)
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0