この記事は自身の学習用としての備忘録です。
内容の正確性を保証するものではありません。
修正点がありましたら、優しくご指摘いただけると幸いです。
JAXでは非同期処理で計算されるため、計算の最後にPython側でblock_until_ready()を追加しないのはアンチパターンです。すべてのサンプルコードで書き忘れていました。すみません。。。。
はじめに
JAXとは?
JAXは、Googleが開発した高性能な数値計算ライブラリです。NumPyライクなAPIを提供しながら、以下の強力な機能を備えています。
- 自動微分
- JITコンパイル
- 自動ベクトル化
- GPU/TPU対応
今回はJITコンパイルの挙動が?????ってなったので深掘りしていきます。
実行環境
本記事のコードは、以下の環境で動作確認を行っています。
| パッケージ | バージョン |
|---|---|
| Python | 3.11 |
| jax | 0.8.1 |
基本
JAXのJIT(Just-In-Time)コンパイル機能を使うと以下のような流れで、関数を高速なXLA(Accelerated Linear Algebra)コードに変換して実行できます。
-
トレーシング(今回の目玉)
関数が呼ばれると、JAXは引数を「Tracer(プレースホルダー)」に置き換えて関数を実行します。このとき、JAXの演算(jnp.addなど)のみが記録され、Pythonの副作用(printやリスト操作)は即座に実行されて消え去ります -
JAXPRの生成
トレーシングの結果、計算手順が「JAXPR」という純粋な計算グラフとして表現されます。これは静的なグラフであり、Pythonの制御フローや副作用は含まれません -
XLAコンパイル
JAXPRがXLAコンパイラに渡され、ハードウェア(GPU/TPU)に最適化されたマシンコードが生成されます。これにより、高速な実行が可能になります
これにより関数を呼び出すとき、初回はJITコンパイルのオーバーヘッドにより時間がかかりますが2回目以降は高速になります。
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
return x * 2
# 初回呼び出し: コンパイル → 実行
result1 = f(jnp.array(5))
# 2回目以降: キャッシュされたコンパイル済みコードを使用(高速!)
result2 = f(jnp.array(10))
制約
JAXのJITコンパイルでは計算グラフ全体を静的に解析し、極限まで最適化できるように設計されています。それにより、『副作用を持たない(純粋関数)』『動的な制御フローが動かない(Pythonのif文やfor文など)』『配列サイズを変更できない(静的形状)』といった厳格な制約を課します。
global_state = []
@jax.jit
def bad_func(x):
# 1. 副作用: 外部の状態を変更(実行時に反映されない、または予期しない挙動)
global_state.append(x)
print("Tracing...") # printも副作用
# 2. 動的な制御フロー: 値に基づくPythonのif文(エラーが発生)
if x > 0:
return x
# 3. 動的な配列サイズ: 入力の値に基づいて配列サイズを決める(エラー)
return jnp.arange(x) # JIT内では配列サイズは静的でなければならない
そのため、通常のPythonとは異なり以下のようなことを行う必要があります。
- 静的引数を使って、コンパイル時に制御フローを確定させる(これを深掘りする)
- JAXが提供する制御フロー関数(
cond,where,scan,fori_loopなど)を使う - 最大サイズでパディングして
mask_indicesなどを用いて計算する
静的引数
JAXのJITコンパイルでは、Pythonの制御フロー(if文、forループ)を使う際に問題が発生します。そのため、以下のようなコードは動きません。
@jax.jit
def f(flag, loop_count, x):
if flag: # if文は@jax.jitでは使えない
return x * 2
else:
for i in range(loop_count): # for文は@jax.jitでは使えない
x = x + 1
return x
ただし、static_argnumsやstatic_argnamesにより静的引数を指定することで、この問題を回避することができます。静的引数を指定することでコンパイル時にその値を確定するため、トレーシング時にその制御フローを確定させることができるためです。
静的引数として渡せる値には以下のような制約があります
- ハッシュ可能な値のみ(整数、文字列、タプル、boolなど)
- JAX配列は不可
- リストや辞書はハッシュ化できないため不可
このソースコードは静的引数の挙動を理解するためのものです。@jax.jit内でPythonのif文やfor文を使うのはアンチパターンなのでお勧めしません。whereやvmapなどを用いることを推奨します。
import jax
import jax.numpy as jnp
from functools import partial
# これでもいける @partial(jax.jit, static_argnames=("flag", "loop_count"),)
@partial(jax.jit, static_argnums=(0, 1,),)
def f(flag, loop_count, x):
# flagはコンパイル時に確定しているためエラーにならない
if flag:
return x * loop_count
else:
# loop_countはコンパイル時に確定しているためエラーにならない
for i in range(loop_count):
x = x + 1
return x
print(f(True, 10, jnp.array(5)))
print((f).trace(True, 10, jnp.array(5)).jaxpr)
print(f(False, 10, jnp.array(5)))
print((f).trace(False, 10, jnp.array(5)).jaxpr)
50
{ lambda ; a:i32[]. let b:i32[] = mul a 10:i32[] in (b,) }
15
{ lambda ; a:i32[]. let
b:i32[] = add a 1:i32[]
c:i32[] = add b 1:i32[]
d:i32[] = add c 1:i32[]
e:i32[] = add d 1:i32[]
f:i32[] = add e 1:i32[]
g:i32[] = add f 1:i32[]
h:i32[] = add g 1:i32[]
i:i32[] = add h 1:i32[]
j:i32[] = add i 1:i32[]
k:i32[] = add j 1:i32[]
in (k,) }
実行結果より、JAXPRを見るとif文は省略され、for文は全て展開されていることがわかります。また、flag=Trueである時、loop_countを掛けるのではなく、10を掛けるように指示しています。そのため、コンパイル時(特にトレーシング時)に静的引数を用いてPythonの制御フローが実行されていることがわかります。
注意点
静的引数の値を変更すると、その値に基づいて再コンパイルが発生してしまいます。静的引数の値を変更する場合以外にも以下のような場合で再コンパイルが発生します。
- 代入するJAXの配列のサイズが変わる
- 代入するJAXの配列のdtypeが変わる
- デバイス(CPU/GPU/TPU)が変わる
そのため、@jax.jitをすれば高速になるというわけではなく、使い方によってはコンパイルによるオーバヘッドにより実行速度を低下させてしまいます。また、コンパイルによって生成されたバイナリはメモリに確保されるため、その分だけメモリ使用量も増大することになります。
import jax
import jax.numpy as jnp
from functools import partial
@partial(jax.jit, static_argnums=(0,),)
def f(s, x):
return x * s
# 良い例 静的引数が変化していないため一度コンパイルすればキャッシュ化されて高速化される。
for x in range(100):
print(f(1, jnp.array(x)))
# コンパイルのオーバーヘッドにより遅くなる例↓
# 悪い例 静的引数が変化しているため100回コンパイルする
for x in range(100):
print(f(x, jnp.array(5)))
# 悪い例 JAXのarrayのサイズが毎回変化しているため100回コンパイルする
for i in range(100):
print(f(1, jnp.arange(i)))
# ↑それぞれ100回分のコンパイル結果がメモリに保持されるためメモリがいっぱいになる
副作用の挙動
先ほどJAXでは副作用禁止と書きましたがもし副作用が絡むとどうなるかを以下のソースコードから見てみます。
import jax
import jax.numpy as jnp
from functools import partial
stack = [jnp.array(1), jnp.array(2), jnp.array(3)]
@partial(jax.jit, static_argnums=(0,),)
def p(b, a):
stack.pop(b) # コンパイル時に実行
stack.append(a) # コンパイル時に実行
print((p).trace(1, jnp.array(4)).jaxpr)
p(1, jnp.array(4)) # 初回はコンパイルが走る
p(1, jnp.array(4)) # 意味ない
p(1, jnp.array(4)) # 意味ない
print(stack)
{ lambda ; a:i32[]. let in () }
[Array(1, dtype=int32, weak_type=True), Array(3, dtype=int32, weak_type=True), JitTracer<~int32[]>]
通常のPythonベースで考えるなら[Array(1, 略), Array(4, 略), Array(4, 略)]となりそうですが、popやappendは1っ回しか実行されておらず、Array(4, 略)の代わりにJitTracerなるものが出てきます。なぜこうなるかを考えると、最初にp関数が実行されるとコンパイルが走り、そのトレーシング時にPythonのappendやpopが実行されて消去されるためです。このとき引数bは静的引数であるため、正しくstack.pop(1)が走りますが、引数aはトレーシング時で具体的な値が確定していないため、JitTracerと呼ばれるプレースホルダが用いられます。それを無理やりappendしているためこのような実行結果になります。JAXPRはどうなっているかというと、引数aを受け取っているだけになっております。これはPythonのappendやpopは副作用であるためトレーシング時に消去されるためです。それにより、p関数を3回実行しても何も起きません。
classが絡む場合
データをひとまとめにして扱い場合、classはtree_flattenやtree_unflattenなどの記述がめんどくさいため、特別な理由がない限りNamedTupleやdataclassを使うことを推奨します。NamedTupleは特別な設定なしで使うことができ、dataclassはregister_dataclassで簡単に設定できます。
JAXでclassを使いたくなった時を想定して以下のようなソースコードを考えます。
import jax
import jax.numpy as jnp
class Counter:
def __init__(self, n):
self.n = n
def count(self):
self.n = self.n + 1
return self.n
counter = Counter(jnp.array(0))
fast_count = jax.jit(counter.count)
for _ in range(3):
print((fast_count).trace().jaxpr)
print(fast_count())
{ lambda a:i32[]; . let b:i32[] = add a 1:i32[] in (b,) }
1
{ lambda a:i32[]; . let b:i32[] = add a 1:i32[] in (b,) }
1
{ lambda a:i32[]; . let b:i32[] = add a 1:i32[] in (b,) }
1
通常であれば1、2、3と出力してほしいですが常に1が返ってきます。JAXPRを見ると何度も1を足しているように見えますが、1を足した結果をself.nに書き戻すという記述がありません。そのため、何度実行してもself.nの結果はリセットされてしまい、1が出力されるという事態になっています。これはコンパイル済みのコード(XLA)は、Python側のメモリ(オブジェクトの状態)に干渉できないためです。
どうすれば更新できるのか
これを解決するためには、メソッド内でself.nを書き換えるのではなく、更新された値を持つ新しいインスタンスを返すように設計を変更します。しかし、単にクラスを返すだけではエラーを吐きます。なぜなら、JAXはクラスをどのように扱えばよいかを知らないためです。そこで、register_pytree_node_classを使用して、JAXにクラスの構造(どの値をトレースして、どの値をトレースしないのか)を教えます。
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
@register_pytree_node_class
class Counter:
def __init__(self, n):
self.n = n
@jax.jit
def count(self):
# self.n を書き換えるのではなく、新しい値を持つ新しいインスタンスを返す
new_n = self.n + 1
return Counter(new_n), new_n
def tree_flatten(self):
# JAXがトレースする要素(children)と、しない要素(aux_data)に分ける
children = (self.n,)
aux_data = None
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
# 分解された要素からクラスを復元する
return cls(*children)
counter = Counter(jnp.array(0))
for i in range(3):
# 戻り値で新しいインスタンスを受け取る
# 返ってきた新しいインスタンスで変数を更新する
counter, val = counter.count()
print(val)
1
2
3
register_pytree_node_classにより、自作クラスがJAXのPyTree(JITやgradが扱えるデータ構造)として認識されます。クラス内ではtree_flattenメソッドにより、クラスの属性をJAXが計算追跡すべきデータ(children)とその他の静的データ(aux_data)に分離します。これにより、@jax.jit内でもreturn Counter(new_n)のように新しいインスタンスを返すことができます。呼び出し側ではcounter, val = counter.count()のように、返ってきた新しいインスタンスで変数を更新することで、状態が変化しているように扱えます。
tree_flatten、tree_unflattenの深掘り
tree_flatten、tree_unflattenがよくわからないためprintやjax.debug.printなどでその挙動を見ていきます。
出力結果の見方
-
printはPythonインタプリタとコンパイル時に実行される -
jax.debug.printはPythonインタプリタとXLAで実行される
↓そのため -
printのみが出力されればコンパイル時に実行している -
jax.debug.printのみが出力されればXLAで実行している - 両方出力されればPythonインタプリタで実行している
↑と判別できる
tree_flattenやtree_unflattenにprintなどを書いておりますがこれはアンチパターンです。これらはコストゼロであることを前提として呼び出されるため必要最低限の処理しか書いてはいけない。
import jax
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp
@register_pytree_node_class
class Counter:
def __init__(self, n, m):
self.n = n
self.m = m
@jax.jit
def count(self):
print("count")
jax.debug.print("jit-count")
return Counter(self.n + self.m, self.m), self.n + self.m
def tree_flatten(self):
print("flatten")
jax.debug.print("jit-flatten")
children = (self.n,)
aux_data = (self.m,)
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
print("unflatten")
jax.debug.print("jit-unflatten")
return cls(*children, *aux_data)
counter = Counter(jnp.array(0), 1)
# 1回目
counter, col = counter.count()
print("Output:", col, counter.n)
# 2回目
counter, col = counter.count()
print("Output:", col, counter.n)
flatten
jit-flatten
flatten
jit-flatten
flatten
jit-flatten
flatten
jit-flatten
unflatten
count
flatten
flatten
jit-unflatten
jit-count
jit-flatten
jit-flatten
unflatten
jit-unflatten
Output: 1 1
flatten
jit-flatten
flatten
jit-flatten
jit-unflatten
jit-count
jit-flatten
jit-flatten
unflatten
jit-unflatten
Output: 2 2
1回目
① 引数の解体
flatten
jit-flatten
flatten
jit-flatten
flatten
jit-flatten
flatten
jit-flatten
まず、jax.jitが引数self(counter)を受け取り、属性の「形」や「中身」を確認しています。tree_flattenが4回も呼ばれているのは、キャッシュ検索や引数の整合性チェックのために複数回検証が行われる仕様だからであると考えられます。
② トレーシングの実行
unflatten
count
キャッシュがないため、トレーシングが行われます。tree_unflattenにより解体された引数を元に戻し、print("count")のみが実行されていることからもわかるようにトレーシングが行われている(jax.debug.printが実行されていないため)ことがわかります。
③ 戻り値の解体
flatten
flatten
count関数が返すCounterオブジェクトを分解して、計算結果のグラフを回収します。
これもトレーシング時に行われていることがわかります。
④ XLA実行
jit-unflatten
jit-count
jit-flatten
jit-flatten
ここで初めて、コンパイル結果が実行されます。トレース開始時に記録された入力復元(jit-unflatten)と、関数内部で記録された命令(jit-count)が実行されます。また、戻り値のCounterオブジェクトも解体する必要があるため、そのための命令(jit-flatten)も実行されます。
⑤ 結果の復元
unflatten
jit-unflatten
Output: 1 1
XLAの計算が完了し、Pythonの世界に戻ってきているため、解体したCounterオブジェクトを復元します。
2回目
① 引数の解体
flatten
jit-flatten
flatten
jit-flatten
2回目で静的引数などが変更されていないため、キャッシュが見つかります(おそらくキャッシュを用いているためtree_flattenは2回しか呼ばれない)。
② トレーシング省略
キャッシュを用いているため、トレーシングが省略されます。
③ XLA実行
jit-unflatten
jit-function
jit-flatten
jit-flatten
1回目と同じ
④ 結果の復元
unflatten
jit-unflatten
Output: 2 2
1回目と同じ
selfは静的引数にすべきかどうか
import jax
from jax.tree_util import register_pytree_node_class
from functools import partial
import jax.numpy as jnp
@register_pytree_node_class
class Counter:
def __init__(self, n, m):
self.n = n
self.m = m
@partial(jax.jit, static_argnums=(0,))
def count(self):
print("count")
jax.debug.print("jit-count")
return Counter(self.n + self.m, self.m), self.n + self.m
def tree_flatten(self):
print("flatten")
jax.debug.print("jit-flatten")
children = (self.n,)
aux_data = (self.m,)
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
print("unflatten")
jax.debug.print("jit-unflatten")
return cls(*children, *aux_data)
counter = Counter(jnp.array(0), 1)
# 1回目
counter, col = counter.count()
print("Output:", col, counter.n)
# 2回目
counter, col = counter.count()
print("Output:", col, counter.n)
count
flatten
flatten
jit-count
jit-flatten
jit-flatten
unflatten
jit-unflatten
Output: 1 1
count
flatten
flatten
jit-count
jit-flatten
jit-flatten
unflatten
jit-unflatten
Output: 2 2
キャッシュがヒットしているように見えないので多分やらないほうがいいですね...
詳細は今度書きます...
inspectを使えばもっと詳細に分析できました。
それも今度書きます...
おわりに
@jax.jitすると具体的にどのような処理が行われるのかを書きました。私もJAXを使っていて大した処理じゃないのにメモリを15GBぐらい使っていてその時に静的引数を変えてみたら数百MBぐらいに収まった経験があるのでしっかり理解したほうがいいです。記事書いてる時に思ったのですがprintを使うことでいつコンパイルが走っているのかがわかるため、デバッグとかで使えそうです(printが常に呼ばれてたら良くないと判定できる)。