どうも、しょっちゅうJAXとPyTorchどっちがいいのか迷っているtmparticleです。
コードをfrom scratchで書くときに、いっつもJAXで書こうかPyTorchで書こうか良く迷うのですが、いざ両ライブラリを使って書いてみると、「そういやこの不満以前も感じてたやつだな。。。」みたいなことが良くあるので、自分用メモも兼ねて、毎回してる「ぼやき」をまとめてみました。気休めに見ていただければと思います(これ間違ってる、こうすればよい、ここに書かれていないことのぼやき、等のコメント・ご意見歓迎です!)。
PyTorch と比較した時の JAX の「良い」ところ
-
numpy の API を模倣している
=> torchも似せる努力は多少してるけど、微妙にアレないじゃん、名称同じなのに処理微妙に違うじゃん、ってことが割とよくある。あとtorchのAPIドキュメントはnumpyに比べて分かりにくい。 -
jax.jit は optimizer の step 処理も jit 可能
=> torch.compile は現状だとoptimizer.stepに対するコンパイルはbeta扱いで、現状はデフォルトで用意されてるOptimizer(かつコンパイル対応の物)に対してしか動作しない。なので、自作Optimizerを使いたい場合はコンパイルは基本無理。あとOptimizerのstep処理におけるcpu <-> gpu間のやり取りオーバーヘッドが大きく、訓練回す際のボトルネックになりがち(GPU使用率を下げる最大の要因になりがちで、かつちょっとした工夫でどうこうできる問題でもない)。 -
Optimizerをコンパイル可能な形式で自作しやすい
=> これは上記でもふれたとおりなのと、optaxもあって作りやすい。pytorchも公式ドキュメントで「コンパイル可能な形式での」Optimizerの自作方法について書いてくれればいいのだが。
PyTorch と比較した時の JAX の「悪い」ところ
-
分散学習へと発展させにくい
=> PyTorchはLightning等を始め分散学習しやすいライブラリが豊富だが、jaxにはそういうのはない。jaxでも分散学習は可能だとは思うが、ドキュメントが親切じゃないので結構辛いと思う。 -
推論に限れば速度の差は殆どない
=> 詳細は私の前記事での実験内容参照。 -
コミュニティがやっぱりpytorchに比べて小さい
=> 先行研究は大体pytorchだったり、エコシステムの豊富さもpytorch。既存の研究資産を流用したい場合にやはり不便。 -
コードが大規模化してくると「状態」の管理が面倒になってくる
=> 筆者が関数型プログラミングのコツや書き方の答えを見つけられてないだけかもしれないが、どう頑張ってもコードが大規模化してくると、状態クラスAを内包する状態クラスBを内包する状態クラスCを内包する状態クラスD、みたいな巨大な状態ができあがりがち。Stateful(オブジェクト指向)の方がコードを書く量はすっきりするし大規模化にも向いていると思う。 -
乱数キーの扱いが面倒
=> これは一つ前のぼやきと被るかもだが、乱数キーも「状態」として切り分けて扱うことになるので、コードを書く量が増える。あとpytorch使う時に乱数の再現性で困ったことは現状ないので、不必要な面倒くささに思える。あと特にflax.nnxとかでNN書くときに思うことだが、nnx.scanなどを使う際の内部での乱数キーの扱いがどうなっているのかが非常に分かりずらい(ドキュメントが不親切)。 -
チェックポイント周りの管理が複雑・生焼け感
=> orbaxが複雑に感じる。
まとめ
以上のぼやきから、著者の場合、JAXを使って幸せになれるケースというのはざっくり言うと、(1)勾配ステップの時間が重要で、(2)できればOptimizerは自作したくて、(3)分散学習に移行する見込みは薄くて、(4)既存の研究資産の流用は不要な方で、(5)コードも大規模化する見込みが薄い、場合に限ると言えそうです。あんまりそんなケースないか(笑)。
でもな~numpy APIで書けるの気持ちいいんだよなぁ、、、と優柔不断な私。。。