はじめに
- jaxという並列実行をするためのpythonライブラリがあるようである
- 特に一番基本のvmapの使い方を簡単にまとめておく(ほぼ自分用)
vmapのメリット
- 基本的にはnp.vectorizeみたいに使える。
- CPU, GPU, TPUで並列実行できるようにコンパイルされる。
基本的な使い方
使い方はほぼnp.vectorizeと同様。
import jax
import jax.numpy as jnp
# 並列実行したい関数
def multiply(x, y):
return x * y
# jax.vmapを使用してmultiply関数をベクトル化
# in_axesは各入力に対するバッチ次元を定義する
vectorized_multiply = jax.vmap(multiply, in_axes=(0, 0))
x = jnp.array([1, 2, 3, 4])
y = jnp.array([10, 20, 30, 40])
result = vectorized_multiply(x, y) # 並列実行
print(result)
# 出力結果は[10,40,90,160]
@jax.vmapを使った書き方
並列化する軸はデフォルトで0が指定される。
軸を指定する必要がない場合は@jax.vmapでデコレートするだけでOK。
関数「vectorized_multiply」を別に作る必要がないのでスマート。
import jax
import jax.numpy as jnp
# 並列実行したい関数
@jax.vmap
def multiply(x, y):
return x * y
# 入力データ
x = jnp.array([1, 2, 3, 4])
y = jnp.array([10, 20, 30, 40])
# ベクトル化された関数を適用
result = multiply(x, y) # 並列実行
print(result)
# 出力結果は[10, 40, 90, 160]
複雑な構造に関する並列化
次のような複雑な構造に関する並列実行もできる。これすごい。
import jax
import jax.numpy as jnp
import chex
# 並列実行したい関数
def multiply(x, y):
return x * y.c.u
# jax.vmapを使用してmultiply関数をベクトル化
# in_axesは各入力に対するバッチ次元を定義する
vectorized_multiply = jax.vmap(multiply, in_axes=(0, 0))
@chex.dataclass(frozen=True)
class DataClass1:
u: chex.Array
v: chex.Array
@chex.dataclass(frozen=True)
class DataClass2:
a: chex.Array
b: chex.Array
c: DataClass1
x = jnp.array([1, 2, 3, 4])
y = DataClass2(
a=jnp.array([10,20,30,40]),
b=jnp.array([10,20,30,40]),
c=DataClass1(
u=jnp.array([1,2,3,4]),
v=jnp.array([10,20,30,40]),
),
)
result = vectorized_multiply(x, y) # 並列実行
print(result)
# 出力結果は[1,4,9,16]
おわりに
- jax.vmapの使い方は簡単。
- ただjaxを使いこなすのは結構難しい。。覚えないといけないことは次の通り。
- 型付け:コンパイルされるので、それなりに型付けしないといけない
- ループ構造:jax.lax.fori_loop, jax.lax.while_loop, jax.lax.scan
- プロセス並列化:jax.pmap
- 並列化の引数固定:functools.partial, @functools.partial
- その他:通常の関数・バッチ処理の関数が混ざるので、関数が普通の1個の値をとるのか、配列をとるのか、引数なのかをよく見わけないといけない(ソースコードにコメントを積極的に書き残そう)