1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Pythonで高速に計算するためのあれやこれや

Posted at

0. 本稿について

私はこれまで,一般に処理速度が遅いとされているPythonでプログラムを書いて,データを処理してきました.本稿は,これまで積み上げてきた知見のメモ代わりとして投稿するものです.コードが洗練されていなかったり,説明が不明瞭だったりする部分もありますが,どうかお手柔らかに読んでいってください.

1. 問題設定と愚直な実装

本稿では,以下に示すライプニッツの公式を用いた円周率の計算を題材に,計算の高速化をしていきます.
$$
\sum_{n=0}^\infty \frac{(-1)^n}{2n+1} = \frac{\pi}{4}
$$
こちらの式を愚直に実装するとこんな感じになります.

no_speed_up.py
import time
import numpy as np

def sum_leibniz( ns, ne ) :

    # initial setting
    s = 0

    # main loop
    for n in range( ns, ne + 1 ) :
        s += (-1)**n/( 2*n + 1 )
    
    return s

if __name__ == '__main__' :

    # initial setting
    n_max = 10**7    # いくつまで総和を計算するか
    itr_num = 10     # 何回円周率の計算をするか

    # main loop
    etime_all = list()
    for i in range( itr_num ) :

        # main calculation
        stime = time.time()    
        qpi = sum_leibniz( 0, n_max )
        etime = time.time() - stime

        # add elapsed time
        etime_all.append( etime )

    # status output
    print( 'pi = {0}'.format( qpi*4 ) )
    print( 'Average elapsed time : {0:.2f} sec'.format( np.mean(etime_all) ) )

色々余計なことも書いていますが,計算自体は関数sum_leibnizで実行しています.また,計算にかかった時間は,円周率の計算を複数回実行して得られる,平均的な値を出力するようにしています.以降,プログラムを手元の環境で実行するときは,マシンスペックと相談して,変数n_maxの値を変更して実行してみてください.

2. JITコンパイラー「Numba」を使った高速化

Pythonプログラムを高速化する上で,最もお手軽なのはJIT(Just-In-Time)コンパイラNumbaを使った方法だと思います.Numbaの詳細な説明は他の記事(たとえばこちら)に譲るとして,本稿では並列化ありバージョンと並列化なしバージョンのプログラムを載せておきます.

numba_no_parallel.py
import time
import numpy as np
from numba import jit


@jit( 'f8(i8,i8)', nopython=True, parallel=False )
def sum_leibniz_numba_nop( ns, ne ) :

    # initial setting
    s = 0

    for n in range( ns, ne + 1 ) :
        s += (-1)**n/( 2*n + 1 )
    
    return s

if __name__ == '__main__' :

    # initial setting
    n_max = 10**8
    itr_num = 10

    # main loop
    etime_all = list()
    for i in range( itr_num ) :

        # main calculation
        stime = time.time()    
        qpi = sum_leibniz_numba_nop( 0, n_max )
        etime = time.time() - stime

        # add elapsed time
        etime_all.append( etime )

    # status output
    print( 'pi = {0}'.format( qpi*4 ) )
    print( 'Average elapsed time : {0:.2f} sec'.format( np.mean(etime_all) ) )
numba_parallel.py
import time
import numpy as np
from numba import jit, prange


@jit( 'f8(i8,i8)', nopython=True, parallel=True )
def sum_leibniz_numba_p( ns, ne ) :

    # initial setting
    s = 0

    for n in prange( ns, ne + 1 ) :
        s += (-1)**n/( 2*n + 1 )
    
    return s

if __name__ == '__main__' :

    # initial setting
    n_max = 10**8
    itr_num = 10

    # main loop
    etime_all = list()
    for i in range( itr_num ) :

        # main calculation
        stime = time.time()    
        qpi = sum_leibniz_numba_p( 0, n_max )
        etime = time.time() - stime

        # add elapsed time
        etime_all.append( etime )

    # status output
    print( 'pi = {0}'.format( qpi*4 ) )
    print( 'Average elapsed time : {0:.2f} sec'.format( np.mean(etime_all) ) )

3. Fortranの呼び出し:他言語との連携

ここからは結構マニアックな話です.私の働いている業界では,先人たちの遺産が重宝されていて,それらがFORTRAN/Fortranで書かれていることが結構あります.もちろん自分でPythonで実装しなおしてもいいのですが,例えば均質半無限弾性体中の食い違いによる変形を計算するコードを自分でPythonで書くのは,実行速度の面でも,労力の面でも現実的ではありません.そこで,この節ではPythonからFortranを呼び出す方法を紹介します.この節で使ったプログラムの作成には,こちらの記事を参考にしました(10年以上前の記事ですが,今でも通用します).

それでは,具体的な作業に入ります.まずは1節のプログラムで使用した関数sum_leibnizに対応するサブルーチンをFortran90で書きます.

subroutine_sum_leibniz.f90
subroutine sum_leibniz( ns, ne, s )
  implicit none
  integer(8), intent(in) :: ns, ne
  double precision, intent(inout) :: s
  integer(8) :: n

  ! main loop
  do n = ns, ne
    s = s + (-1)**n/dble( 2*n + 1 )
  end do

end subroutine

続いて,作成したサブルーチンを共有ライブラリとしてビルドします.私のノートPCがWindowsなのでそれ前提で書きます.Mac/Linuxユーザーの皆様は元記事を参考にビルドしてください.

gfortran -shared -O3 -o sum_leibniz.dll subroutine_sum_leibniz.f90

ここまででFortran側の準備は完了.では,Python側のプログラムを書いていきます.

call_fortran.py
import time
import numpy as np
from ctypes import *

def call_fsubroutine( ns, ne ) :

    # load 
    fmodule = cdll.LoadLibrary( './sum_leibniz.dll' )

    # set input variable type
    fmodule.sum_leibniz_.argtypes = [
        POINTER(c_int64), POINTER(c_int64), POINTER(c_double)
    ]

    # set output variable type
    fmodule.sum_leibniz_.restype = c_void_p

    # wrap variables
    ns = c_int64( ns )
    ne = c_int64( ne )
    s = 0.0
    s = c_double( s )

    # call subroutine
    fmodule.sum_leibniz_( byref(ns), byref(ne), byref(s) )

    return s.value 


if __name__ == '__main__' :

    # initial setting
    n_max = 10**10
    itr_num = 10

    # main loop
    etime_all = list()
    for i in range( itr_num ) :

        # main calculation
        stime = time.time()    
        qpi = call_fsubroutine( 0, n_max )
        etime = time.time() - stime

        # add elapsed time
        etime_all.append( etime )

    # status output
    print( 'pi = {0}'.format( qpi*4 ) )
    print( 'Average elapsed time : {0:.2f} sec'.format( np.mean(etime_all) ) )

このプログラムを実行すればOK.やってることの詳細が気になる人は元記事を読んでください.

4. Numbaに頼らない並列化

Numbaを使った並列化はお手軽で嬉しいのですが,状況によってはこのやり方が通用しない局面に出くわすこともあるでしょう(例えばSTLを使った季節調整を同時に複数の時系列に適用したい,とか...).そういった時のために,Numbaに頼らない並列化を覚えておくと良いでしょう.この節で使うプログラムの作成には,この記事(並列化初心者の私にはとても分かりやすかった!)とPython公式を参考にしました.詳細が気になる人はこれらの記事を読んでください.

まずは1節のプログラムを並列化してみます.

para_multiprocessing.py
import time
import numpy as np
from multiprocessing import Process, Pipe

def sum_leibniz_conn( ns, ne, conn ) :

    # initial setting
    s = 0

    # main loop
    for n in range( ns, ne + 1 ) :
        s += (-1)**n/( 2*n + 1 )
    
    conn.send( s )
    conn.close()


def run_multiprocessing( ns, ne ) :

    # set start & end points
    ns1, ne1 = ns, int( ne/4 )
    ns2, ne2 = int( ne/4 ) + 1, int( ne/2 )
    ns3, ne3 = int( ne/2 ) + 1, int( ne*3/4 )
    ns4, ne4 = int( 3*ne/4 ) + 1, ne

    # set variables for connection
    par_conn1, chi_conn1 = Pipe()
    par_conn2, chi_conn2 = Pipe()
    par_conn3, chi_conn3 = Pipe()
    par_conn4, chi_conn4 = Pipe()

    # set objects
    process1 = Process( target=sum_leibniz_conn, args=(ns1, ne1, chi_conn1) )
    process2 = Process( target=sum_leibniz_conn, args=(ns2, ne2, chi_conn2) )
    process3 = Process( target=sum_leibniz_conn, args=(ns3, ne3, chi_conn3) )
    process4 = Process( target=sum_leibniz_conn, args=(ns4, ne4, chi_conn4) )

    # start process
    process1.start()
    process2.start()
    process3.start()
    process4.start()

    # wait for ending all processes
    process1.join()
    process2.join()
    process3.join()
    process4.join()

    qpi1 = par_conn1.recv()
    qpi2 = par_conn2.recv()
    qpi3 = par_conn3.recv()
    qpi4 = par_conn4.recv()

    return qpi1 + qpi2 + qpi3 + qpi4
    

if __name__ == '__main__':

    # initial setting
    n_max = 10**7
    itr_num = 10

    # main loop
    etime_all = list()
    for i in range( itr_num ) :

        # main calculation
        stime = time.time()    
        qpi = run_multiprocessing( 0, n_max )
        etime = time.time() - stime

        # add elapsed time
        etime_all.append( etime )

    # status output
    print( 'pi = {0}'.format( qpi*4 ) )
    print( 'Average elapsed time : {0:.2f} sec'.format( np.mean(etime_all) ) )

並列化は初心者なのですが,これで上手くいきました(嬉しい!).続いて,3節のプログラムを並列化します.

para_multiprocessing_fcall.py
import time
import numpy as np
from ctypes import *
from multiprocessing import Process, Pipe

def call_fsubroutine_conn( ns, ne, conn ) :

    # load
    fmodule = cdll.LoadLibrary( './sum_leibniz.dll' )

    # set input variable type
    fmodule.sum_leibniz_.argtypes = [
        POINTER(c_int64), POINTER(c_int64), POINTER(c_double)
    ]

    # set output variable type
    fmodule.sum_leibniz_.restype = c_void_p

    # wrap variables
    ns = c_int64( ns )
    ne = c_int64( ne )
    s = 0.0
    s = c_double( s )

    # call subroutine
    fmodule.sum_leibniz_( byref(ns), byref(ne), byref(s) )

    conn.send( s.value )
    conn.close()


def run_multiprocessing( ns, ne ) :

    # set start & end points
    ns1, ne1 = ns, int( ne/4 )
    ns2, ne2 = int( ne/4 ) + 1, int( ne/2 )
    ns3, ne3 = int( ne/2 ) + 1, int( ne*3/4 )
    ns4, ne4 = int( 3*ne/4 ) + 1, ne

    # set variables for connection
    par_conn1, chi_conn1 = Pipe()
    par_conn2, chi_conn2 = Pipe()
    par_conn3, chi_conn3 = Pipe()
    par_conn4, chi_conn4 = Pipe()

    # set objects
    process1 = Process( target=call_fsubroutine_conn, args=(ns1, ne1, chi_conn1) )
    process2 = Process( target=call_fsubroutine_conn, args=(ns2, ne2, chi_conn2) )
    process3 = Process( target=call_fsubroutine_conn, args=(ns3, ne3, chi_conn3) )
    process4 = Process( target=call_fsubroutine_conn, args=(ns4, ne4, chi_conn4) )

    # start process
    process1.start()
    process2.start()
    process3.start()
    process4.start()

    # wait for ending all processes
    process1.join()
    process2.join()
    process3.join()
    process4.join()

    qpi1 = par_conn1.recv()
    qpi2 = par_conn2.recv()
    qpi3 = par_conn3.recv()
    qpi4 = par_conn4.recv()

    return qpi1 + qpi2 + qpi3 + qpi4
    

if __name__ == '__main__':

    # initial setting
    n_max = 10**10
    itr_num = 10

    # main loop
    etime_all = list()
    for i in range( itr_num ) :

        # main calculation
        stime = time.time()    
        qpi = run_multiprocessing( 0, n_max )
        etime = time.time() - stime

        # add elapsed time
        etime_all.append( etime )

    # status output
    print( 'pi = {0}'.format( qpi*4 ) )
    print( 'Average elapsed time : {0:.2f} sec'.format( np.mean(etime_all) ) )

こちらも上手くいきました.Fortranの呼び出しと並列化が共存できると何が嬉しいかというと,均質半無限弾性体中の食い違いによる変形を計算するコードのようなサブルーチンを並列化の枠組みの中で呼び出せちゃうということです.これができると,例えばこういうプログラムを自分で実装するときに,高速化ができちゃうわけですね.

5. おわりに

本稿は自分用のメモ代わりとして,解説少なめでお送りしました.各節で使っている手法の解説は,私よりもプログラミングが上手い人たちの書いた記事を参考にしてもらえれば良いと思います.

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?