挨拶
はじめまして、tmparticleです。
Qiita記事を書いてみたくなり、ひとまずネタとして題目のものを書きます。
計算速度に惹かれてjaxを2年間ほど使って来ましたが、色々と思うところがあり、pytorchに戻ろうかとふと思い、雑にではありますが計算速度の比較を行ってみました。
測定方法
環境: Ubuntu22.04, python3.10, NVIDIA GeForce RTX 4070, pytorch2.2.0+cu121, jax-0.4.24
測定方法:下記2コードの通り。かなり適当な行列演算を行う関数で時間測定。jaxはjitなし・あり、pytorchはtorch.compileなし・ありで比較。あと興味があったのでgradient計算もチェック。
import torch
from torch import func
from timeit import timeit
def run(dtype, use_compile, use_grad):
size=10
n = 10000
m = 1000
w1 = torch.ones((n,n), dtype=dtype, device='cuda:0')/n
b1 = torch.ones((1,n), dtype=dtype, device='cuda:0')/n
w2 = torch.ones((n,n), dtype=dtype, device='cuda:0')/n
b2 = torch.ones((1,n), dtype=dtype, device='cuda:0')/n
def _f(x1):
o = x1
for i in range(10):
o = torch.matmul(o,w1)-b1
o = o@w2-b2
o = o.mean()
return o
if use_grad: _f = func.grad(_f)
if use_compile: _f = torch.compile(_f)
def f(x1):
out = _f(x1)
torch.cuda.synchronize()
return out
x1 = torch.ones((m,n), dtype=dtype, device='cuda:0')
for i in range(3):
y1 = f(x1) #for compile and warming up gpu.
return timeit(lambda :f(x1), number=size)/size
if __name__=='__main__':
print('bf16: ', run(torch.bfloat16, False, False))
print('bf16 (compile): ', run(torch.bfloat16, True, False))
print('bf16 grad: ', run(torch.bfloat16, False, True))
print('bf16 grad (compile): ', run(torch.bfloat16, True, True))
import jax
from jax import numpy as jnp
from timeit import timeit
def run(dtype, use_compile, use_grad):
size=10
n = 10000
m = 1000
w1 = jnp.ones((n,n), dtype=dtype)/n
b1 = jnp.ones((1,n), dtype=dtype)/n
w2 = jnp.ones((n,n), dtype=dtype)/n
b2 = jnp.ones((1,n), dtype=dtype)/n
def _f(x1):
o = x1
o = jax.lax.fori_loop(0,10,lambda i,o:(o@w1-b1)@w2-b2, o)
o = o.mean()
return o
if use_grad: _f = jax.grad(_f)
if use_compile: _f = jax.jit(_f)
def f(x1):
out = _f(x1)
jax.block_until_ready(out)
return out
x1 = jnp.ones((m,n), dtype=dtype)
for i in range(3):
y1 = f(x1) #for compile and warming up gpu.
return timeit(lambda :f(x1), number=size)/size
if __name__=='__main__':
print('bf16: ', run(jnp.bfloat16, False, False))
print('bf16 (compile): ', run(jnp.bfloat16, True, False))
print('bf16 grad: ', run(jnp.bfloat16, False, True))
print('bf16 grad (compile): ', run(jnp.bfloat16, True, True))
(dtype = bfloat16はお好みで。一応float16, float32, 64もチェックしたところ、精度が一つ上がる度に計算時間が約2倍になるという結果が見られました。)
結果
jax:
bf16: 0.10227765199997521
bf16 (compile): 0.07033122960001492
bf16 grad: 0.17630739489995903
bf16 grad (compile): 0.06858054019994597
pytorch:
bf16: 0.06928130240012251
bf16 (compile): 0.06842002920002414
bf16 grad: 0.13468127169999206
bf16 grad (compile): 0.13464321860010386
まとめ
ひとまずcompileありで比較すると、勾配計算はjaxが2倍速い、非勾配計算は微妙にpytorchが速い、というところ。個人的には以前よりpytorchとjaxとの速度差がかなり縮まったと感じる。またちょっとpytorch使ってみようかな。
あと一つ思うところは、pytorchはcompileの恩恵がほぼ無い点である。どういうときにtorch.compileの恩恵が大きく出るのか気になるところ。