LoginSignup
0
0

jaxのループの中でprintする方法

Last updated at Posted at 2024-05-03

はじめに

  • jaxではscan, fori_loop, while_loopなど、独自のループ文の書き方がある
  • このscanなどの中でprintをしても、中身が表示されずに困る

問題の詳細

次のプログラムのようにscanの中でprintしても、Traced<...>といったデータしか表示されない。。

プログラム
import jax
import jax.numpy as jnp

def body_fun(val, x):
    print(val) # デバッグのためにprintしたい
    return val + x, val + x  

final_val, ys = jax.lax.scan(body_fun, 0, jnp.arange(5))
実行結果
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/1)>

👆 なんとか表示させたい

解決方法

jax.debug.printを使えば解決!

プログラム
import jax
import jax.numpy as jnp

def body_fun(val, x):
    jax.debug.print('{}', val) # デバッグのためにprintしたい
    return val + x, val + x  

final_val, ys = jax.lax.scan(body_fun, 0, jnp.arange(5))
実行結果
0
0
1
3
6

👆 ちゃんと表示できている!

おわりに

結構この情報を探し当てるのに、結構な時間がかかった。。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0