はじめに
- jaxではscan, fori_loop, while_loopなど、独自のループ文の書き方がある
- このscanなどの中でprintをしても、中身が表示されずに困る
問題の詳細
次のプログラムのようにscanの中でprintしても、Traced<...>
といったデータしか表示されない。。
プログラム
import jax
import jax.numpy as jnp
def body_fun(val, x):
print(val) # デバッグのためにprintしたい
return val + x, val + x
final_val, ys = jax.lax.scan(body_fun, 0, jnp.arange(5))
実行結果
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/1)>
👆 なんとか表示させたい
解決方法
jax.debug.print
を使えば解決!
プログラム
import jax
import jax.numpy as jnp
def body_fun(val, x):
jax.debug.print('{}', val) # デバッグのためにprintしたい
return val + x, val + x
final_val, ys = jax.lax.scan(body_fun, 0, jnp.arange(5))
実行結果
0
0
1
3
6
👆 ちゃんと表示できている!
もっと複雑な状況でprintしたい場合
コールバックという、ホスト側にデータを移してから、データ操作する方法があるらしい。
jax.debug.printもコールバックを使った一つの関数とのこと。
https://docs.jax.dev/en/latest/external-callbacks.html
コールバックを利用するれば、もっと複雑な状況でもprintができる。
おわりに
結構この情報を探し当てるのに、結構な時間がかかった。。