LoginSignup
4
4

More than 5 years have passed since last update.

pythonで多次元配列の並列計算の紹介と比較

Last updated at Posted at 2019-05-08

初めに

pythonにおいて全く同じ処理を多次元配列の各グリッドに対して実行したいと思い,pythonの並列計算に手を出し始めたので大枠をまとめました.
検証モジュールはpython3.7のmultiprocessingとjoblibです.

実験設定・環境

時間軸(datetimes)×緯度(lats)×経度(lons)の3次元配列の処理を行います.
これら3つの情報を必要とするものとして,太陽高度の計算を行うこととします.
使用モジュールはpysolarという外部モジュールで,pipやcondaでインストールできます.

検証に必要な計算時間を考慮した結果,座標格子は2度×2度,時間は2006年1月1日の0時,6時,12時,18時の4つを選択しました.下図はその出力結果(単位は度)です.出力配列(sa_ar)は4×90×180となります.

qiita_sa_test.png

(描画に用いたコードは本記事末尾に載せました).

作業はすべてJupyter notebook上で行いました.
計算に要する時間計測は%%timeコマンドで行っています.

注意:今回用いるpysolarやdatetimeはpython2と3ではいろいろと変更点があるため,本実験をpython2で実行する際はコードの改変が必要です.

準備

緯度経度,時間の配列を作成します.

import numpy as np
import datetime as dt
import pysolar
import itertools

deg = 2.

lon_min = -180
lon_max =  180
lat_min = - 90
lat_max =   90

lon_tot = lon_max - lon_min
lat_tot = lat_max - lat_min
nhpx = int(lon_tot/deg)
nvpx = int(lat_tot/deg)

lons = np.linspace(lon_min+deg*0.5,lon_max-deg*0.5,nhpx)
lats = np.linspace(lat_min+deg*0.5,lat_max-deg*0.5,nvpx)

datetimes = [dt.datetime(2006,1,1,h,tzinfo=dt.timezone.utc) for h in (0,6,12,18)]
ndt = len(datetimes)

検証

1. 並列化なし

forループを重ねて,グリッドごとに計算→代入を繰り返させます.

%%time
sa_ar = np.zeros((ndt,nvpx,nhpx,))
for idt,dtime in enumerate(datetimes):
    for ilat,lat in enumerate(lats):
        for ilon,lon in enumerate(lons):
            sa_ar[idt,ilat,ilon] = pysolar.solar.get_altitude(lat, lon, dtime)

要した時間は01分42秒でした.

2. joblib

import joblib as jl

並列計算に用いるCPU数を変えて実験をしてみます.

2.1. n_jobs=-1

n_jobs=-1でCPUを最大使用可能数だけ使用する設定にしています.
関数に入力するdtime,lat,lonは,itertools.product()を用いて全組合せを生成し,joblibの機能に渡しています.出力は1次元リストなので,numpy.arrayに代入して成型して完了です.

%%time
def process(dtime,lat, lon):
    return pysolar.solar.get_altitude(lat, lon, dtime)

result = jl.Parallel(n_jobs=-1)([jl.delayed(process)(dtime,lat,lon) for dtime,lat,lon in itertools.product(datetimes,lats,lons,)])
sa_ar = np.array(result).reshape((ndt,nvpx,nhpx,))

要した時間は02分18秒でした.なぜか単純にやるよりも時間がかかってしまいました.
大型サーバで計算を行っているため,最大利用可能CPU数を設定すると結構な量を使うことになるのですが,何かに時間を割かれてしまっているようです.

2.2. n_jobs=2,4

CPU数に固定の値を指定しました.
結果は,n_jobs=2で01分07秒,n_jobs=4で00分31秒と,期待通りに計算時間の短縮が図れました.もう少し実験してもよさそうですが,まずまず計算時間の短縮はできているので,ここでやめておきます.

3. multiprocessing

import multiprocessing as mp

こちらも,並列計算に用いるCPU数を変えて検証してみます.

3.1. processes=cpu_count()

processes=mutiprocessing.cpu_count()でCPUを最大使用可能数だけ使用する設定にしています.
計算に必要な変数が1つのみであれば,map()メソッドを用いますが,今回のように複数である場合はstarmap()メソッドを使用します.それ以外はほとんどjoblibとやることは同様です.

%%time
def process(dtime, lat, lon):
    return pysolar.solar.get_altitude(lat, lon, dtime)

p = mp.Pool(processes=mp.cpu_count())
# result = p.map(process, single list)
result = p.starmap(process, itertools.product(datetimes,lats,lons,))
p.close()
sa_ar = np.array(result).reshape((ndt,nvpx,nhpx,))

結果は00分03秒でした.joblibとは対照的に,最大可能数指定でもCPUをうまく利用できているようです.

3.2. processes=2,4

processesを変えて実験してみました.結果は,processes=2で00分57秒,processes=4で00分28秒と,いずれもややjoblibよりも速い結果となりました.

まとめ

といった形で今回はmultiprocessingのほうがよさげな結果となりました.
が,2つでそこまで大差はないこと,検証していないオプションがあること,コードが複雑化した際の挙動を予測できていないことなどから,あくまで参考値として理解しています.
なにより計算がうまく回ってくれたことにほっとしているので,どんどん生かしていきたいと思います.

補足:マップ描画コード

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

# draw lines
def draw(ax,array,title):
    ax.set_extent((lon_min, lon_max, lat_min, lat_max),crs=ccrs.PlateCarree())
    ax.set_title(title)
    ax.coastlines(linewidth=0.4)
    gl = ax.gridlines(linewidth=0.5, draw_labels=True, linestyle='--')
    gl.xlabels_top = False
    gl.ylabels_right = False
    gl.xformatter = LONGITUDE_FORMATTER
    gl.yformatter = LATITUDE_FORMATTER

    # draw array (set extent)
    im = ax.imshow(array,extent=(lon_min, lon_max, lat_min, lat_max),vmin=-90,vmax=90)
    return im

fig = plt.figure(figsize=(16,8))
ax1 = fig.add_subplot(221,projection=ccrs.PlateCarree())
ax2 = fig.add_subplot(222,projection=ccrs.PlateCarree())
ax3 = fig.add_subplot(223,projection=ccrs.PlateCarree())
ax4 = fig.add_subplot(224,projection=ccrs.PlateCarree())
im1 = draw(ax1,sa_ar[0],datetimes[0].strftime('%Y-%m-%d %H:%M'))
im2 = draw(ax2,sa_ar[1],datetimes[1].strftime('%Y-%m-%d %H:%M'))
im3 = draw(ax3,sa_ar[2],datetimes[2].strftime('%Y-%m-%d %H:%M'))
im4 = draw(ax4,sa_ar[3],datetimes[3].strftime('%Y-%m-%d %H:%M'))
plt.colorbar(im1,fig.add_axes((0.92,0.13,0.02,0.75)))
4
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4