3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

JAX と PyTorch の計算速度の比較(2025年9月16日更新)

Last updated at Posted at 2024-02-18

挨拶

はじめまして、tmparticleです。

この記事の初投稿時、計算速度に惹かれてjaxを2年間ほど使って来ましたが、コミュニティの小ささがネックに思うことがよくあり、pytorchに戻ろうかとふと思い、雑にではありますが計算速度の比較を行っていました。
後述するように、当時はpytorchのが勾配計算をする場合2倍遅い、非勾配計算は同等、という結果だったので、勾配計算が2倍くらいならpytorchでもいいかと思って、それから半年くらいはpytorchを使っていました。

なんですけど、新規プロジェクトを開始するフェーズになり、改めてjaxもいいところあったよなと思い返し、再度計測してpytorchからjaxに戻る意義があるのか確認することにしました。

測定方法

環境(2025年9月16時): Ubuntu22.04, python3.12, NVIDIA GeForce RTX 4070, torch==2.8.0, jax==0.7.1

環境(初投稿時): Ubuntu22.04, python3.10, NVIDIA GeForce RTX 4070, pytorch2.2.0+cu121, jax-0.4.24

測定方法:下記2コードの通り。行列演算を行う関数で時間測定。jaxはjitなし・あり、pytorchはtorch.compileなし・ありで比較。あと興味があったのでgradient計算もチェック。
是非自分の環境下でコードをコピペして実行してみてください、よかったら自分環境では〇〇の結果になった、などのコメントを頂けたら嬉しいです。

test_pytorch.py
import torch
from torch import func
from timeit import timeit

def run(dtype, use_compile, use_grad):
    size=10
    n = 10000
    m = 1000
    w1 = torch.ones((n,n), dtype=dtype, device='cuda:0')/n
    b1 = torch.ones((1,n), dtype=dtype, device='cuda:0')/n
    w2 = torch.ones((n,n), dtype=dtype, device='cuda:0')/n
    b2 = torch.ones((1,n), dtype=dtype, device='cuda:0')/n
    def _f(x1): 
        o = x1
        for i in range(10):
            o = torch.matmul(o,w1)-b1
            o = o@w2-b2
        o = o.mean()
        return o
    if use_grad: _f = func.grad(_f)
    if use_compile: _f = torch.compile(_f)
    def f(x1):
        out = _f(x1)
        torch.cuda.synchronize()
        return out
    x1 = torch.ones((m,n), dtype=dtype, device='cuda:0')
    for i in range(3):
        y1 = f(x1) #for compile and warming up gpu.
    return timeit(lambda :f(x1), number=size)/size

if __name__=='__main__':
    print('bf16: ', run(torch.bfloat16, False, False))
    print('bf16 (compile): ', run(torch.bfloat16, True, False))
    print('bf16 grad: ', run(torch.bfloat16, False, True))
    print('bf16 grad (compile): ', run(torch.bfloat16, True, True))
test_jax.py
import jax
from jax import numpy as jnp
from timeit import timeit

def run(dtype, use_compile, use_grad):
    size=10
    n = 10000
    m = 1000
    w1 = jnp.ones((n,n), dtype=dtype)/n
    b1 = jnp.ones((1,n), dtype=dtype)/n
    w2 = jnp.ones((n,n), dtype=dtype)/n
    b2 = jnp.ones((1,n), dtype=dtype)/n
    def _f(x1): 
        o = x1
        o = jax.lax.fori_loop(0,10,lambda i,o:(o@w1-b1)@w2-b2, o)
        o = o.mean()
        return o
    if use_grad: _f = jax.grad(_f)
    if use_compile: _f = jax.jit(_f)
    def f(x1):
        out = _f(x1)
        jax.block_until_ready(out)
        return out
    x1 = jnp.ones((m,n), dtype=dtype)
    for i in range(3):
        y1 = f(x1) #for compile and warming up gpu.
    return timeit(lambda :f(x1), number=size)/size

if __name__=='__main__':
    print('bf16: ', run(jnp.bfloat16, False, False))
    print('bf16 (compile): ', run(jnp.bfloat16, True, False))
    print('bf16 grad: ', run(jnp.bfloat16, False, True))
    print('bf16 grad (compile): ', run(jnp.bfloat16, True, True))

(dtype = bfloat16はお好みで。一応float16, float32, 64もチェックしたところ、精度が一つ上がる度に計算時間が約2倍になるという結果が見られました。)

結果 (2025年9月16日計測時)

jax:
bf16: 0.12038824840001325
bf16 (compile): 0.07432408749991737
bf16 grad: 0.18919248960000914
bf16 grad (compile): 0.06634759519993168

pytorch:
bf16: 0.07027425759997641
bf16 (compile): 0.06929735049998272
bf16 grad: 0.1354508929000076
bf16 grad (compile): 0.06666142720005155

補足:結果(初投稿時)

jax:
bf16: 0.10227765199997521
bf16 (compile): 0.07033122960001492
bf16 grad: 0.17630739489995903
bf16 grad (compile): 0.06858054019994597

pytorch:
bf16: 0.06928130240012251
bf16 (compile): 0.06842002920002414
bf16 grad: 0.13468127169999206
bf16 grad (compile): 0.13464321860010386

まとめ

まず2025年9月16日計測時の結果についてまとめると、compileなしだとpytorch優勢。compileありで比較すると、非勾配計算・勾配計算ともにほぼ差がないという結果になりました。なおこの記事の初投稿時(2024年)では、勾配計算はjaxが2倍速い、非勾配計算は微妙にpytorchが速い、という結果でした。

1年経ってる間にpytorchはこっそり勾配計算の方が改善されていたわけですね。
(jaxはちょっと遅くなってる、機能追加の影響?)

他のベンチではどうなるか分からないものの、速度面でもはやjaxの恩恵はないかもしれません。あと当時はtorch.compileしてもなんも変わらなくない?って状態だったのが、ちゃんと勾配計算の時に効果が発揮されるようになりました。

筆者はひとまずpytorchに留まろうかと思います、それではまた。

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?