0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

jax.lax.condの使い方

Posted at

はじめに

jaxと、この記事の説明

  • jaxは高速な並列計算が得意なpythonライブラリで、XLAでのコンパイル時に、いい感じに最適化できて高速になるらしい
  • 高速化のためにはif文も関数に置き換えないといけない…
  • この記事は、この記事の仲間

なぜjax?

私は他人のソースコードが全部jaxで書いてあったので、仕方なく手を出している

jax.lax.condの使い方

jax.lax.condは三項演算子とほぼ同じ使い方をする
配列などには使えないので注意! 1

fori_loopの使い方
import jax
import jax.numpy as jnp

def true_fun(x, y):
    return x + y

def false_fun(x, y):
    return x - y

x = jnp.array(3.0)  # 入力1
y = jnp.array(2.0)  # 入力2
result = jax.lax.cond(
    x > 0,     # 判定式
    true_fun,  # 判定式がtrueのときに実行する関数
    false_fun, # 判定式がfalseのときに実行するか関数
    *(x, y),   # 関数に渡す引数たち(わかりやすさのためにタプル形式を展開して与えているが、そのままx, yと書いてもOK)
)
print(result)  # 5.0

公式サイトに書いてある動作の疑似コードは以下の通り:

def cond(pred, true_fun, false_fun, *operands):
 if pred:
   return true_fun(*operands)
 else:
   return false_fun(*operands)

# 次の三項演算子の記述と同じ
# return true_func(*operands) if pred else false_fun(*operands) 

疑問と回答

Q. じゃあ、普通の次のようなif文はどう書けばいい?

if x > 0:
    x = x*2

A. 次のように解釈しなおして書けばOK

x = x*2 if x > 0 else x

おわりに

  • jax.lax.condは配列に使えないって、実際ほぼ出番無いのでは…
  • 次は配列におけるif文操作の本題「jax.where」を書きます
  1. じゃあ配列を一気に操作したい場合はどうするかというと、「jax.where」を使います。これは別記事で書きます。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?