はじめに
jaxと、この記事の説明
- jaxは高速な並列計算が得意なpythonライブラリで、XLAでのコンパイル時に、いい感じに最適化できて高速になるらしい
- 高速化のためにはif文も関数に置き換えないといけない…
- この記事は、この記事の仲間
なぜjax?
私は他人のソースコードが全部jaxで書いてあったので、仕方なく手を出している
jax.lax.condの使い方
jax.lax.condは三項演算子とほぼ同じ使い方をする
(配列などには使えないので注意! 1 )
fori_loopの使い方
import jax
import jax.numpy as jnp
def true_fun(x, y):
return x + y
def false_fun(x, y):
return x - y
x = jnp.array(3.0) # 入力1
y = jnp.array(2.0) # 入力2
result = jax.lax.cond(
x > 0, # 判定式
true_fun, # 判定式がtrueのときに実行する関数
false_fun, # 判定式がfalseのときに実行するか関数
*(x, y), # 関数に渡す引数たち(わかりやすさのためにタプル形式を展開して与えているが、そのままx, yと書いてもOK)
)
print(result) # 5.0
公式サイトに書いてある動作の疑似コードは以下の通り:
def cond(pred, true_fun, false_fun, *operands): if pred: return true_fun(*operands) else: return false_fun(*operands) # 次の三項演算子の記述と同じ # return true_func(*operands) if pred else false_fun(*operands)
疑問と回答
Q. じゃあ、普通の次のようなif文はどう書けばいい?
if x > 0:
x = x*2
A. 次のように解釈しなおして書けばOK
x = x*2 if x > 0 else x
おわりに
- jax.lax.condは配列に使えないって、実際ほぼ出番無いのでは…
- 次は配列におけるif文操作の本題「jax.where」を書きます
-
じゃあ配列を一気に操作したい場合はどうするかというと、「jax.where」を使います。これは別記事で書きます。 ↩