3
1

More than 3 years have passed since last update.

jax.vmap, vectorize でサクっと SPMD な並列処理するメモ

Last updated at Posted at 2021-07-17

背景

数値計算を Python(numpy, scipy)でやっている. SPMD(e.g. 配列要素ごとに同じコードを実行)的なコードが実行できれば十分.

e.g.

ret = numpy.zeros((nx, ny))

for x in range(nx):
  for y in range(ny):
    r = numpy.sqrt(x*x + y*y)
    ret[x, y] = r

計算速くするためにマルチスレッド化とかしたいが, Python では並列処理スレッド関係とかでめんどい... multiprocessing の Pool や starmap などあるが, 処理するデータのサイズが大きいとメモリ不足になったりする(内部で一時的なバッファをたくさんつくるようでメモリ消費が多い)

CPU, GPU(CUDA) 両方でうごいてほしい

cupy の vectorize だと python コードでは使える cupy 関数が少なくてつらい(cupy.sum とか未対応) :cry:

jax.vmap, jax.numpy.vectorize

CPU, GPU の切り替え

jax は cpu backend と gpu backend 両方をインストールできます.

cpu backend を強制的に使いたい場合はたとえば python スクリプトで

import jax
# force use CPU
jax.config.update('jax_platform_name', 'cpu')

とします.

注意点

vmap, vectorize で呼び出される関数は jit 処理されるため, if 文などはつかえません.

where 関数で代用するか, jax.lax.cond https://jax.readthedocs.io/en/latest/jax.lax.html#lax-control-flow あたりを使いましょう.

complex128 は対応していないようです(complex64 に変換されてしまう)

kernel 内で sum など使うと, 2D を入力してもなぜか 1D に reduction されてしまいます(1D 方向の総和を取ると解釈されてしまうため?). 必要に応じて入力を 1D にしておく(e.g. flatten で)とよいでしょう.

もしくは, jax.numpy.vectorize を使うとよいでしょう

メモリ不足に陥りやすい

128x128 kernel で画像を convolve(フィルタをかける)みたいな処理だと CPU でメモリ 128 GB とかあってもメモリ不足になりやすいです(これは numpy.vectorize などもそうであるが)

jax では, メモリを考慮して入力のデータサイズに応じてよろしく並列数を設定するなどはできないようで, 計算ドメイン全体を並列で処理しようとしている感じでした.

自前で入力データを分割して少しづつ処理させる... みたいなのを記述しないとだめそうです.

たとえば 2D(e.g. 画像データ)だと, vsplit で分割して 1 行づつ処理するなど

def kernel(...):
  ...

fn = vectorize(kernel)

ret = []
for line in jnp.vsplit(img, img.shape[1]):
  ret.append(fn(line))

out_img = jnp.vstack(ret)

GPU memory

jax では GPU 実行の場合 90% メモリを pre-allocate します.
他のタスクと一緒に動かしたいとかの場合

を参考に環境変数などでメモリ設定します.

pmap

pmap というのもあります. ただこれは batch を GPU(CPU の場合はコア?)で並列で処理する用がメインっぽそうです. 複数 GPU で処理したいときは pmap の形式にできるとよいでしょう.

pytorch.vmap (experimental)

pytorch 1.9 時点で experimental ですが, vmap(vectorized map)があります!
(nightly build のみ)

2.0 あたりで標準対応してほしいですね. 少しいじれば pmap っぽく使えるでしょう.

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1