0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

jax.vmapの使い方

Last updated at Posted at 2024-01-18

はじめに

  • jaxという並列実行をするためのpythonライブラリがあるようである
  • 特に一番基本のvmapの使い方を簡単にまとめておく(ほぼ自分用)

vmapのメリット

  • 基本的にはnp.vectorizeみたいに使える。
  • CPU, GPU, TPUで並列実行できるようにコンパイルされる。

基本的な使い方

使い方はほぼnp.vectorizeと同様。

import jax
import jax.numpy as jnp

# 並列実行したい関数
def multiply(x, y):
    return x * y

# jax.vmapを使用してmultiply関数をベクトル化
# in_axesは各入力に対するバッチ次元を定義する
vectorized_multiply = jax.vmap(multiply, in_axes=(0, 0))

x = jnp.array([1, 2, 3, 4])
y = jnp.array([10, 20, 30, 40])

result = vectorized_multiply(x, y) # 並列実行

print(result)
# 出力結果は[10,40,90,160]

@jax.vmapを使った書き方

並列化する軸はデフォルトで0が指定される。
軸を指定する必要がない場合は@jax.vmapでデコレートするだけでOK。
関数「vectorized_multiply」を別に作る必要がないのでスマート。

import jax
import jax.numpy as jnp

# 並列実行したい関数
@jax.vmap
def multiply(x, y):
    return x * y

# 入力データ
x = jnp.array([1, 2, 3, 4])
y = jnp.array([10, 20, 30, 40])

# ベクトル化された関数を適用
result = multiply(x, y)  # 並列実行

print(result)
# 出力結果は[10, 40, 90, 160]

複雑な構造に関する並列化

次のような複雑な構造に関する並列実行もできる。これすごい。

import jax
import jax.numpy as jnp
import chex

# 並列実行したい関数
def multiply(x, y):
    return x * y.c.u

# jax.vmapを使用してmultiply関数をベクトル化
# in_axesは各入力に対するバッチ次元を定義する
vectorized_multiply = jax.vmap(multiply, in_axes=(0, 0))


@chex.dataclass(frozen=True)
class DataClass1:
  u: chex.Array  
  v: chex.Array  

@chex.dataclass(frozen=True)
class DataClass2:
  a: chex.Array  
  b: chex.Array  
  c: DataClass1  

x = jnp.array([1, 2, 3, 4])
y = DataClass2(
    a=jnp.array([10,20,30,40]), 
    b=jnp.array([10,20,30,40]), 
    c=DataClass1(
        u=jnp.array([1,2,3,4]),
        v=jnp.array([10,20,30,40]),
    ),
)

result = vectorized_multiply(x, y) # 並列実行


print(result)
# 出力結果は[1,4,9,16]

おわりに

  • jax.vmapの使い方は簡単。
  • ただjaxを使いこなすのは結構難しい。。覚えないといけないことは次の通り。
    • 型付け:コンパイルされるので、それなりに型付けしないといけない
    • ループ構造:jax.lax.fori_loop, jax.lax.while_loop, jax.lax.scan
    • プロセス並列化:jax.pmap
    • 並列化の引数固定:functools.partial, @functools.partial
    • その他:通常の関数・バッチ処理の関数が混ざるので、関数が普通の1個の値をとるのか、配列をとるのか、引数なのかをよく見わけないといけない(ソースコードにコメントを積極的に書き残そう)
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?