LoginSignup
6
2

More than 1 year has passed since last update.

class で jit を使用する方法

Last updated at Posted at 2021-09-18

python高速化のためのjitをclassで使用したいとき
jitclassより@staticmethod を使ったほうがいいのではないかという個人メモ。

jitclass の使用例

jitclassを使えばクラスでもjitを使用することができる。

import numpy as np
import matplotlib.pyplot as plt
from numba import jit, f8
from numba.experimental import jitclass
import time
spec = [
    ('arr', f8[:,:,:]),               
    ('new_arr', f8[:,:,:]),        
]

@jitclass(spec)
class class_jit():
    def __init__(self,arr):
        self.arr = arr
        self.new_arr=np.zeros_like(arr, dtype=np.float64)
    def func(self):
        shape=self.arr.shape
        for x in range(shape[0]):
            for y in range(shape[1]):
                for z in range(shape[2]):
                    self.new_arr[x,y,z]=1e6*x+1e3*y+z
        return self.new_arr
    def plot(self):
        plt.imshow(self.new_arr[0])

if __name__=='__main__':
    n=256
    arr=np.zeros([n]*3)
    st=time.time()
    obj1=class_jit(arr)
    obj1.func()
    print(f'"class jit " elapsed_time {time.time()-st:2f}') 
"class jit " elapsed_time 0.512740

jitclass で困ったこと

jitclassではjit の対応していないメンバ関数を使用することができない。
例えばクラスのほかのメンバ関数でmatplotを使おうとすると

obj1.plot()

plt.imshowのところでエラーを返される。

TypingError: - Resolution failure for literal arguments:
Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'imshow' of type Module

@staticmethodを使って解決

jitを使用したいメンバ関数を@staticmethodで静的関数として、@jitを使用したら上手くいった。インスタンス変数を参照したい場合は、ラッパーにすればよい。

class class_jit2():
    def __init__(self):
        pass

    def func(self, arr): #wrapper
        self.arr=arr
        self.new_arr=self.__func(arr)

    @staticmethod
    @jit(f8[:,:,:](f8[:,:,:])) #型指定
    def __func(arr):
        new_arr=np.zeros_like(arr, dtype=np.float64)
        shape=arr.shape
        for x in range(shape[0]):
            for y in range(shape[1]):
                for z in range(shape[2]):
                    new_arr[x,y,z]=1e6*x+1e3*y+z
        return new_arr

    def plot(self):
        plt.imshow(self.new_arr[0])

if __name__=='__main__':
    n=256
    arr=np.zeros([n]*3)        
    st=time.time()
    obj2=class_jit2()
    obj2.func(arr)
    print(f'"class jit staticmethod " elapsed_time {time.time()-st:2f}') 
    obj2.plot()

これなら、jitmatplotも同一のクラスで使用できる。

"class jit staticmethod " elapsed_time 0.044842

image.png

申し訳程度の速度比較

import numpy as np
import matplotlib.pyplot as plt
from numba import jit, f8
from numba.experimental import jitclass
import time


def func(arr): #jitなし
    new_arr=np.zeros_like(arr)
    shape=arr.shape
    for x in range(shape[0]):
        for y in range(shape[1]):
            for z in range(shape[2]):
                new_arr[x,y,z]=1e6*x+1e3*y+z
    return new_arr

@jit(nopython=True) #jit 型指定なし
def func_jit(arr):
    new_arr=np.zeros_like(arr)
    shape=arr.shape
    for x in range(shape[0]):
        for y in range(shape[1]):
            for z in range(shape[2]):
                new_arr[x,y,z]=1e6*x+1e3*y+z   
    return new_arr

@jit(f8[:,:,:](f8[:,:,:])) #jit 型指定
def func_jit2(arr):
    new_arr=np.zeros_like(arr)
    shape=arr.shape
    for x in range(shape[0]):
        for y in range(shape[1]):
            for z in range(shape[2]):
                new_arr[x,y,z]=1e6*x+1e3*y+z   
    return new_arr


spec = [
    ('arr', f8[:,:,:]),               # a simple scalar field
    ('new_arr', f8[:,:,:]),          # an array field
]

@jitclass(spec)
class class_jit():
    def __init__(self,arr):
        self.arr = arr
        self.new_arr=np.zeros_like(arr, dtype=np.float64)
    def func(self):
        shape=self.arr.shape
        for x in range(shape[0]):
            for y in range(shape[1]):
                for z in range(shape[2]):
                    self.new_arr[x,y,z]=1e6*x+1e3*y+z
        return self.new_arr

class class_jit2():
    def __init__(self):
        pass

    def func(self, arr):
        self.arr=arr
        self.new_arr=self.__func(arr)

    @staticmethod
    @jit(f8[:,:,:](f8[:,:,:]))
    def __func(arr):
        new_arr=np.zeros_like(arr, dtype=np.float64)
        shape=arr.shape
        for x in range(shape[0]):
            for y in range(shape[1]):
                for z in range(shape[2]):
                    new_arr[x,y,z]=1e6*x+1e3*y+z
        return new_arr

if __name__=='__main__':

    n=256
    arr=np.zeros([n]*3)

    st=time.time()
    func(arr)
    print(f'"w/o_jit" elapsed_time {time.time()-st:2f}') 

    st=time.time()
    func_jit(arr)
    print(f'"jit + w/o type sepc elapsed_time {time.time()-st:2f}') 

    st=time.time()
    func_jit2(arr)
    print(f'"jit + type spec" elapsed_time {time.time()-st:2f}') 

    st=time.time()
    obj1=class_jit(arr)
    obj1.func()
    print(f'"class jit " elapsed_time {time.time()-st:2f}') 

    st=time.time()
    obj2=class_jit2()
    obj2.func(arr)
    print(f'"class jit staticmethod " elapsed_time {time.time()-st:2f}') 

"w/o_jit" elapsed_time 3.964147
"jit + w/o type sepc elapsed_time 0.163591
"jit + type spec" elapsed_time 0.047845
"class jit " elapsed_time 0.431932
"class jit staticmethod " elapsed_time 0.043882

@jitつけるだけで24倍
型指定すると83倍の高速化
(jitclass の速度がいまいちなのは謎。jitclassの型指定がうまくいってない...?)

6
2
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
2