LoginSignup
0
0

JAX と PyTorch の計算速度の比較

Last updated at Posted at 2024-02-18

挨拶

はじめまして、tmparticleです。
Qiita記事を書いてみたくなり、ひとまずネタとして題目のものを書きます。

計算速度に惹かれてjaxを2年間ほど使って来ましたが、色々と思うところがあり、pytorchに戻ろうかとふと思い、雑にではありますが計算速度の比較を行ってみました。

測定方法

環境: Ubuntu22.04, python3.10, NVIDIA GeForce RTX 4070, pytorch2.2.0+cu121, jax-0.4.24

測定方法:下記2コードの通り。かなり適当な行列演算を行う関数で時間測定。jaxはjitなし・あり、pytorchはtorch.compileなし・ありで比較。あと興味があったのでgradient計算もチェック。

test_pytorch.py
import torch
from torch import func
from timeit import timeit

def run(dtype, use_compile, use_grad):
    size=10
    n = 10000
    m = 1000
    w1 = torch.ones((n,n), dtype=dtype, device='cuda:0')/n
    b1 = torch.ones((1,n), dtype=dtype, device='cuda:0')/n
    w2 = torch.ones((n,n), dtype=dtype, device='cuda:0')/n
    b2 = torch.ones((1,n), dtype=dtype, device='cuda:0')/n
    def _f(x1): 
        o = x1
        for i in range(10):
            o = torch.matmul(o,w1)-b1
            o = o@w2-b2
        o = o.mean()
        return o
    if use_grad: _f = func.grad(_f)
    if use_compile: _f = torch.compile(_f)
    def f(x1):
        out = _f(x1)
        torch.cuda.synchronize()
        return out
    x1 = torch.ones((m,n), dtype=dtype, device='cuda:0')
    for i in range(3):
        y1 = f(x1) #for compile and warming up gpu.
    return timeit(lambda :f(x1), number=size)/size

if __name__=='__main__':
    print('bf16: ', run(torch.bfloat16, False, False))
    print('bf16 (compile): ', run(torch.bfloat16, True, False))
    print('bf16 grad: ', run(torch.bfloat16, False, True))
    print('bf16 grad (compile): ', run(torch.bfloat16, True, True))
test_jax.py
import jax
from jax import numpy as jnp
from timeit import timeit

def run(dtype, use_compile, use_grad):
    size=10
    n = 10000
    m = 1000
    w1 = jnp.ones((n,n), dtype=dtype)/n
    b1 = jnp.ones((1,n), dtype=dtype)/n
    w2 = jnp.ones((n,n), dtype=dtype)/n
    b2 = jnp.ones((1,n), dtype=dtype)/n
    def _f(x1): 
        o = x1
        o = jax.lax.fori_loop(0,10,lambda i,o:(o@w1-b1)@w2-b2, o)
        o = o.mean()
        return o
    if use_grad: _f = jax.grad(_f)
    if use_compile: _f = jax.jit(_f)
    def f(x1):
        out = _f(x1)
        jax.block_until_ready(out)
        return out
    x1 = jnp.ones((m,n), dtype=dtype)
    for i in range(3):
        y1 = f(x1) #for compile and warming up gpu.
    return timeit(lambda :f(x1), number=size)/size

if __name__=='__main__':
    print('bf16: ', run(jnp.bfloat16, False, False))
    print('bf16 (compile): ', run(jnp.bfloat16, True, False))
    print('bf16 grad: ', run(jnp.bfloat16, False, True))
    print('bf16 grad (compile): ', run(jnp.bfloat16, True, True))

(dtype = bfloat16はお好みで。一応float16, float32, 64もチェックしたところ、精度が一つ上がる度に計算時間が約2倍になるという結果が見られました。)

結果

jax:
bf16: 0.10227765199997521
bf16 (compile): 0.07033122960001492
bf16 grad: 0.17630739489995903
bf16 grad (compile): 0.06858054019994597

pytorch:
bf16: 0.06928130240012251
bf16 (compile): 0.06842002920002414
bf16 grad: 0.13468127169999206
bf16 grad (compile): 0.13464321860010386

まとめ

ひとまずcompileありで比較すると、勾配計算はjaxが2倍速い、非勾配計算は微妙にpytorchが速い、というところ。個人的には以前よりpytorchとjaxとの速度差がかなり縮まったと感じる。またちょっとpytorch使ってみようかな。
あと一つ思うところは、pytorchはcompileの恩恵がほぼ無い点である。どういうときにtorch.compileの恩恵が大きく出るのか気になるところ。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0