挨拶
はじめまして、tmparticleです。
この記事の初投稿時、計算速度に惹かれてjaxを2年間ほど使って来ましたが、コミュニティの小ささがネックに思うことがよくあり、pytorchに戻ろうかとふと思い、雑にではありますが計算速度の比較を行っていました。
後述するように、当時はpytorchのが勾配計算をする場合2倍遅い、非勾配計算は同等、という結果だったので、勾配計算が2倍くらいならpytorchでもいいかと思って、それから半年くらいはpytorchを使っていました。
なんですけど、新規プロジェクトを開始するフェーズになり、改めてjaxもいいところあったよなと思い返し、再度計測してpytorchからjaxに戻る意義があるのか確認することにしました。
測定方法
環境(2025年9月16時): Ubuntu22.04, python3.12, NVIDIA GeForce RTX 4070, torch==2.8.0, jax==0.7.1
環境(初投稿時): Ubuntu22.04, python3.10, NVIDIA GeForce RTX 4070, pytorch2.2.0+cu121, jax-0.4.24
測定方法:下記2コードの通り。行列演算を行う関数で時間測定。jaxはjitなし・あり、pytorchはtorch.compileなし・ありで比較。あと興味があったのでgradient計算もチェック。
是非自分の環境下でコードをコピペして実行してみてください、よかったら自分環境では〇〇の結果になった、などのコメントを頂けたら嬉しいです。
import torch
from torch import func
from timeit import timeit
def run(dtype, use_compile, use_grad):
size=10
n = 10000
m = 1000
w1 = torch.ones((n,n), dtype=dtype, device='cuda:0')/n
b1 = torch.ones((1,n), dtype=dtype, device='cuda:0')/n
w2 = torch.ones((n,n), dtype=dtype, device='cuda:0')/n
b2 = torch.ones((1,n), dtype=dtype, device='cuda:0')/n
def _f(x1):
o = x1
for i in range(10):
o = torch.matmul(o,w1)-b1
o = o@w2-b2
o = o.mean()
return o
if use_grad: _f = func.grad(_f)
if use_compile: _f = torch.compile(_f)
def f(x1):
out = _f(x1)
torch.cuda.synchronize()
return out
x1 = torch.ones((m,n), dtype=dtype, device='cuda:0')
for i in range(3):
y1 = f(x1) #for compile and warming up gpu.
return timeit(lambda :f(x1), number=size)/size
if __name__=='__main__':
print('bf16: ', run(torch.bfloat16, False, False))
print('bf16 (compile): ', run(torch.bfloat16, True, False))
print('bf16 grad: ', run(torch.bfloat16, False, True))
print('bf16 grad (compile): ', run(torch.bfloat16, True, True))
import jax
from jax import numpy as jnp
from timeit import timeit
def run(dtype, use_compile, use_grad):
size=10
n = 10000
m = 1000
w1 = jnp.ones((n,n), dtype=dtype)/n
b1 = jnp.ones((1,n), dtype=dtype)/n
w2 = jnp.ones((n,n), dtype=dtype)/n
b2 = jnp.ones((1,n), dtype=dtype)/n
def _f(x1):
o = x1
o = jax.lax.fori_loop(0,10,lambda i,o:(o@w1-b1)@w2-b2, o)
o = o.mean()
return o
if use_grad: _f = jax.grad(_f)
if use_compile: _f = jax.jit(_f)
def f(x1):
out = _f(x1)
jax.block_until_ready(out)
return out
x1 = jnp.ones((m,n), dtype=dtype)
for i in range(3):
y1 = f(x1) #for compile and warming up gpu.
return timeit(lambda :f(x1), number=size)/size
if __name__=='__main__':
print('bf16: ', run(jnp.bfloat16, False, False))
print('bf16 (compile): ', run(jnp.bfloat16, True, False))
print('bf16 grad: ', run(jnp.bfloat16, False, True))
print('bf16 grad (compile): ', run(jnp.bfloat16, True, True))
(dtype = bfloat16はお好みで。一応float16, float32, 64もチェックしたところ、精度が一つ上がる度に計算時間が約2倍になるという結果が見られました。)
結果 (2025年9月16日計測時)
jax:
bf16: 0.12038824840001325
bf16 (compile): 0.07432408749991737
bf16 grad: 0.18919248960000914
bf16 grad (compile): 0.06634759519993168
pytorch:
bf16: 0.07027425759997641
bf16 (compile): 0.06929735049998272
bf16 grad: 0.1354508929000076
bf16 grad (compile): 0.06666142720005155
補足:結果(初投稿時)
jax:
bf16: 0.10227765199997521
bf16 (compile): 0.07033122960001492
bf16 grad: 0.17630739489995903
bf16 grad (compile): 0.06858054019994597
pytorch:
bf16: 0.06928130240012251
bf16 (compile): 0.06842002920002414
bf16 grad: 0.13468127169999206
bf16 grad (compile): 0.13464321860010386
まとめ
まず2025年9月16日計測時の結果についてまとめると、compileなしだとpytorch優勢。compileありで比較すると、非勾配計算・勾配計算ともにほぼ差がないという結果になりました。なおこの記事の初投稿時(2024年)では、勾配計算はjaxが2倍速い、非勾配計算は微妙にpytorchが速い、という結果でした。
1年経ってる間にpytorchはこっそり勾配計算の方が改善されていたわけですね。
(jaxはちょっと遅くなってる、機能追加の影響?)
他のベンチではどうなるか分からないものの、速度面でもはやjaxの恩恵はないかもしれません。あと当時はtorch.compileしてもなんも変わらなくない?って状態だったのが、ちゃんと勾配計算の時に効果が発揮されるようになりました。
筆者はひとまずpytorchに留まろうかと思います、それではまた。