python高速化のためのjit
をclassで使用したいとき
jitclass
より@staticmethod
を使ったほうがいいのではないかという個人メモ。
jitclass の使用例
jitclass
を使えばクラスでもjit
を使用することができる。
import numpy as np
import matplotlib.pyplot as plt
from numba import jit, f8
from numba.experimental import jitclass
import time
spec = [
('arr', f8[:,:,:]),
('new_arr', f8[:,:,:]),
]
@jitclass(spec)
class class_jit():
def __init__(self,arr):
self.arr = arr
self.new_arr=np.zeros_like(arr, dtype=np.float64)
def func(self):
shape=self.arr.shape
for x in range(shape[0]):
for y in range(shape[1]):
for z in range(shape[2]):
self.new_arr[x,y,z]=1e6*x+1e3*y+z
return self.new_arr
def plot(self):
plt.imshow(self.new_arr[0])
if __name__=='__main__':
n=256
arr=np.zeros([n]*3)
st=time.time()
obj1=class_jit(arr)
obj1.func()
print(f'"class jit " elapsed_time {time.time()-st:2f}')
"class jit " elapsed_time 0.512740
##jitclass で困ったこと
jitclass
ではjit
の対応していないメンバ関数を使用することができない。
例えばクラスのほかのメンバ関数でmatplot
を使おうとすると
obj1.plot()
plt.imshowのところでエラーを返される。
TypingError: - Resolution failure for literal arguments:
Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'imshow' of type Module
@staticmethodを使って解決
jit
を使用したいメンバ関数を@staticmethod
で静的関数として、@jit
を使用したら上手くいった。インスタンス変数を参照したい場合は、ラッパーにすればよい。
class class_jit2():
def __init__(self):
pass
def func(self, arr): #wrapper
self.arr=arr
self.new_arr=self.__func(arr)
@staticmethod
@jit(f8[:,:,:](f8[:,:,:])) #型指定
def __func(arr):
new_arr=np.zeros_like(arr, dtype=np.float64)
shape=arr.shape
for x in range(shape[0]):
for y in range(shape[1]):
for z in range(shape[2]):
new_arr[x,y,z]=1e6*x+1e3*y+z
return new_arr
def plot(self):
plt.imshow(self.new_arr[0])
if __name__=='__main__':
n=256
arr=np.zeros([n]*3)
st=time.time()
obj2=class_jit2()
obj2.func(arr)
print(f'"class jit staticmethod " elapsed_time {time.time()-st:2f}')
obj2.plot()
これなら、jit
もmatplot
も同一のクラスで使用できる。
"class jit staticmethod " elapsed_time 0.044842
申し訳程度の速度比較
import numpy as np
import matplotlib.pyplot as plt
from numba import jit, f8
from numba.experimental import jitclass
import time
def func(arr): #jitなし
new_arr=np.zeros_like(arr)
shape=arr.shape
for x in range(shape[0]):
for y in range(shape[1]):
for z in range(shape[2]):
new_arr[x,y,z]=1e6*x+1e3*y+z
return new_arr
@jit(nopython=True) #jit 型指定なし
def func_jit(arr):
new_arr=np.zeros_like(arr)
shape=arr.shape
for x in range(shape[0]):
for y in range(shape[1]):
for z in range(shape[2]):
new_arr[x,y,z]=1e6*x+1e3*y+z
return new_arr
@jit(f8[:,:,:](f8[:,:,:])) #jit 型指定
def func_jit2(arr):
new_arr=np.zeros_like(arr)
shape=arr.shape
for x in range(shape[0]):
for y in range(shape[1]):
for z in range(shape[2]):
new_arr[x,y,z]=1e6*x+1e3*y+z
return new_arr
spec = [
('arr', f8[:,:,:]), # a simple scalar field
('new_arr', f8[:,:,:]), # an array field
]
@jitclass(spec)
class class_jit():
def __init__(self,arr):
self.arr = arr
self.new_arr=np.zeros_like(arr, dtype=np.float64)
def func(self):
shape=self.arr.shape
for x in range(shape[0]):
for y in range(shape[1]):
for z in range(shape[2]):
self.new_arr[x,y,z]=1e6*x+1e3*y+z
return self.new_arr
class class_jit2():
def __init__(self):
pass
def func(self, arr):
self.arr=arr
self.new_arr=self.__func(arr)
@staticmethod
@jit(f8[:,:,:](f8[:,:,:]))
def __func(arr):
new_arr=np.zeros_like(arr, dtype=np.float64)
shape=arr.shape
for x in range(shape[0]):
for y in range(shape[1]):
for z in range(shape[2]):
new_arr[x,y,z]=1e6*x+1e3*y+z
return new_arr
if __name__=='__main__':
n=256
arr=np.zeros([n]*3)
st=time.time()
func(arr)
print(f'"w/o_jit" elapsed_time {time.time()-st:2f}')
st=time.time()
func_jit(arr)
print(f'"jit + w/o type sepc elapsed_time {time.time()-st:2f}')
st=time.time()
func_jit2(arr)
print(f'"jit + type spec" elapsed_time {time.time()-st:2f}')
st=time.time()
obj1=class_jit(arr)
obj1.func()
print(f'"class jit " elapsed_time {time.time()-st:2f}')
st=time.time()
obj2=class_jit2()
obj2.func(arr)
print(f'"class jit staticmethod " elapsed_time {time.time()-st:2f}')
"w/o_jit" elapsed_time 3.964147
"jit + w/o type sepc elapsed_time 0.163591
"jit + type spec" elapsed_time 0.047845
"class jit " elapsed_time 0.431932
"class jit staticmethod " elapsed_time 0.043882
@jit
つけるだけで24倍
型指定すると83倍の高速化
(jitclass
の速度がいまいちなのは謎。jitclass
の型指定がうまくいってない...?)