0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

jaxのループの中でprintする方法

Last updated at Posted at 2024-05-03

はじめに

  • jaxではscan, fori_loop, while_loopなど、独自のループ文の書き方がある
  • このscanなどの中でprintをしても、中身が表示されずに困る

問題の詳細

次のプログラムのようにscanの中でprintしても、Traced<...>といったデータしか表示されない。。

プログラム
import jax
import jax.numpy as jnp

def body_fun(val, x):
    print(val) # デバッグのためにprintしたい
    return val + x, val + x  

final_val, ys = jax.lax.scan(body_fun, 0, jnp.arange(5))
実行結果
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/1)>

👆 なんとか表示させたい

解決方法

jax.debug.printを使えば解決!

プログラム
import jax
import jax.numpy as jnp

def body_fun(val, x):
    jax.debug.print('{}', val) # デバッグのためにprintしたい
    return val + x, val + x  

final_val, ys = jax.lax.scan(body_fun, 0, jnp.arange(5))
実行結果
0
0
1
3
6

👆 ちゃんと表示できている!

もっと複雑な状況でprintしたい場合

コールバックという、ホスト側にデータを移してから、データ操作する方法があるらしい。
jax.debug.printもコールバックを使った一つの関数とのこと。
https://docs.jax.dev/en/latest/external-callbacks.html

コールバックを利用するれば、もっと複雑な状況でもprintができる。

おわりに

結構この情報を探し当てるのに、結構な時間がかかった。。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?