はじめに
Python環境でのnetCDFファイルの処理に関し,かつてよりnetCDF4というパッケージが用いられてきた.その後,機能性や応用性を増強したarray処理用パッケージxarrayが開発され,シェアが広がりつつある(たぶん).
今回は,ひとつの有用な機能である内挿処理を用いるコードの例を紹介する.Pythonのバージョンは3.7系.
使用モジュール
import numpy as np
import xarray as xr
import multiprocessing as mp
準備:Look Up Tableの作成
準備段階として,まず,サンプルとなる配列をnetCDF形式で保存する.
本記事では,内挿補間により一つの値を取得するための土台として,この配列をLook up table (LUT)と名付ける.
# 各次元のサイズ
sz_1 = 10
sz_2 = 9
sz_3 = 8
# 各次元の配列(今回は0-1間で等間隔に値をとった配列を使用)
ax_1 = np.linspace(0,1,sz_1)
ax_2 = np.linspace(0,1,sz_2)
ax_3 = np.linspace(0,1,sz_3)
# 3種類の3次元データ(適当に作成)
data1 = np.arange(sz_1*sz_2*sz_3
).reshape(sz_1,sz_2,sz_3)
data2 = np.linspace(0,10,sz_1*sz_2*sz_3
).reshape(ln_0,ln_1,ln_2)
data3 = np.linspace(-1,1,sz_1*sz_2*sz_3
).reshape(sz_1,sz_2,sz_3)
# xarray配列の生成
LUT = xr.Dataset({'data1': (['dim_1', 'dim_2', 'dim_3'], data1),
'data2': (['dim_1', 'dim_2', 'dim_3'], data2),
'data3': (['dim_1', 'dim_2', 'dim_3'], data3),},
coords={'dim_1': ax_1,
'dim_2': ax_2,
'dim_3': ax_3,})
# netCDFで保存
LUT.to_netcdf('sample_interp_3d.nc')
内挿処理用クラスの定義
内挿のイメージは下図の通り.近傍のLUTの値を用いて加重平均をとることで,指定した次元値における配列値を取得する.
クラスを定義する.役割は2つで,
-
sample_lut()
が呼ばれた直後(__init__
)に自動でLUTが読み込まれる. - 各次元の値を指定することで内挿保管した値を返す(
Extract()
).
xarrayには,interp()
というメソッドが備わっており,次元の値を指定した配列値の取得がとても簡単かつ直感的な形でできる(scipyの関数を用いているらしい.詳細はこちら→公式).
class sample_lut(object):
def __init__(self):
self.LUT = self.Open()
def Open(self):
'''
Open LUT as xarray dataset.
'''
return xr.open_dataset('sample_interp_3d.nc').load()
def Extract(self,dict_items):
'''
Extract result.
'''
extracted = self.LUT.interp(dim_1=dict_items['v_1'],
dim_2=dict_items['v_2'],
dim_3=dict_items['v_3'],)
r_1 = extracted.data1.data
r_2 = extracted.data2.data
r_3 = extracted.data3.data
return [r_1, r_2, r_3]
上記コードですこし補足を加える.
-
open_dataset()
でnetCDFファイルの読込が可能.詳細はこちら→公式 -
load()
メソッドは,デフォルトでは部分的にのみにしか読み込まない配列を,一度すべてメモリに書き出すために含めている.公式ドキュメントには下記のような記述がある.実際,このあと行う並列処理の際にload()
なしだと予期せぬ挙動を示したため,今回は含めている.
xarray’s lazy loading of remote or on-disk datasets is often but not always desirable. Before performing computationally intense operations, it is often a good idea to load a Dataset (or DataArray) entirely into memory by invoking the load() method.
(Google翻訳)xarrayのリモートデータセットまたはディスク上のデータセットの遅延読み込みは、常に望ましいとは限りません。計算量の多い操作を実行する前に、load()メソッドを呼び出してDataset(またはDataArray)を完全にメモリにロードすることをお勧めします。
内挿処理の実行
# テーブル読込
LUT = sample_lut()
# 入力(各次元の値を指定)
items = dict(v_1 = 0.1, v_2 = 0.2, v_3 = 0.3)
# 内挿補間
outp = LUT.Extract(items)
print(outp) # [79.7 1.10848401 -0.7783032 ]
multiprocessingを用いた並列処理
内挿処理の速度はそこまで早くないため,入力の組み合わせが大量にある場合,コードを並列化するのがよい.
# 各次元の値を適当に100個作成
ar_v1, ar_v2, ar_v3 = np.random.rand(3,100)
# 先にLUTを読み込んでおく
LUT = sample_lut()
# 内挿処理を関数化
def process(v1, v2, v3):
items = dict(v_1 = v1, v_2 = v2, v_3 = v3)
# 関数外部で読み込んだLUTを用いて内挿
outp = LUT.Extract(items)
return outp
# 並列処理(3変数を入力に使用するためにstarmap関数を使用)
num_proc = 4
p = mp.Pool(processes=num_proc)
ar_out = p.starmap(process, zip(ar_v1, ar_v2, ar_v3))
p.close()
print(ar_out)
注意点:読み込み時にload()
を含めなかった際の挙動
クラス定義のところでload()
に関する補足説明を加えたが,これを使用しなかった(xr.open_dataset('sample_interp_3d.nc')
だけにした)場合,次のような挙動が発生する.
# 先にLUTを読み込んでおく
LUT = sample_lut()
def process(n):
# 入力を1組だけに固定
items = dict(v_1 = 0.1, v_2 = 0.2, v_3 = 0.3)
# 関数外部で読み込んだLUTを用いて内挿
outp = LUT.Extract(items)
return outp
# 並列処理(テストとして,同じ入力での計算を100回行う)
num_proc = 4
p = mp.Pool(processes=num_proc)
ar_out = p.map(process, range(100))
p.close()
print(ar_out)
''' 出力が一致しない
[[ 2.62198422e+017 6.41668327e-309 -7.78303199e-001]
[ 2.62198422e+017 6.41668327e-309 -7.78303199e-001]
: : :
[ 2.62198422e+017 6.41668327e-309 -7.78303199e-001]
[ 7.97000000e+001 6.41668327e-309 0.00000000e+000]
[ 7.97000000e+001 6.41668327e-309 0.00000000e+000]
: : :
[ 7.97000000e+001 6.41668327e-309 0.00000000e+000]
[ 7.97000000e+001 1.10848401e+000 -7.78303199e-001]
: : :
'''
挙動発生の仕組みは不明だが,load()
を含めることによってこのようなエラーは解消される.
おわりに
ものすごく便利そうなパッケージxarrayの記事がもう少し増えてくれることを祈る.
補足:説明に用いた図の描画コード
準備
import numpy as np
import itertools
# 内挿する各次元の値
v_1 = 0.71
v_2 = 0.52
v_3 = 0.24
# LUTのscatterに用いる点群の作成
xs = []
ys = []
zs = []
for x,y,z in itertools.product(ax_1, ax_2, ax_3):
xs.append(x)
ys.append(y)
zs.append(z)
# 各次元に垂直な平面の描画に用いる配列
edge_1, edge_2 = np.meshgrid([0,1],[0,1])
sl_x = np.ones((2,2)) * v_1
sl_y = np.ones((2,2)) * v_2
sl_z = np.ones((2,2)) * v_3
# 内挿する値の近傍右側の次元値を取得するindex
rind_1 = np.searchsorted(ax_1, v_1)
rind_2 = np.searchsorted(ax_2, v_2)
rind_3 = np.searchsorted(ax_3, v_3)
# 内挿する値の近傍の次元値
rect_1 = ax_1[rind_1-1:rind_1+1]
rect_2 = ax_2[rind_2-1:rind_2+1]
rect_3 = ax_3[rind_3-1:rind_3+1]
# 近傍でのscatterに用いる点群の作成
mag_xs = []
mag_ys = []
mag_zs = []
for x,y,z in itertools.product(rect_1, rect_2, rect_3):
mag_xs.append(x)
mag_ys.append(y)
mag_zs.append(z)
描画
(右の拡大図については,下記コードからいくつか描画項目をコメントアウトすれば作れる.)
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# 3dを指定した図の作成
fig = plt.figure(figsize = (12,8))
ax = fig.add_subplot(111, projection='3d')
# LUTのscatter
ax.scatter(xs, ys, zs, color='k', s=4, linewidths=0)
# 各次元に垂直な平面(面と外枠を独立して描画)
ax.plot_surface(edge_1, edge_2, sl_z, alpha=0.1, color='r')
ax.plot_surface(edge_1, sl_y, edge_2, alpha=0.1, color='g')
ax.plot_surface(sl_x, edge_1, edge_2, alpha=0.1, color='b')
ax.plot_wireframe(edge_1, edge_2, sl_z, color='r', linewidths=1)
ax.plot_wireframe(edge_1, sl_y, edge_2, color='g', linewidths=1)
ax.plot_wireframe(sl_x, edge_1, edge_2, color='b', linewidths=1)
# 直方体の外枠
for s, e in itertools.combinations(np.array(list(itertools.product(rect_1,rect_2,rect_3))), 2):
# (x1,y1,z1), (x2,y2,z2)の組合せのうち,2要素が一致するもの(=辺)のみ選択
if len(set(s) & set(e)) == 2:
ax.plot3D(*zip(s,e), color="k", linewidth=1)
# 各次元に垂直な平面(近傍)
ax.plot_wireframe(edge_xy1, edge_xy2, sl_z, color='r', linewidths=1, linestyle='--')
ax.plot_wireframe(edge_zx2, sl_y, edge_zx1, color='g', linewidths=1, linestyle='--')
ax.plot_wireframe(sl_x, edge_yz1, edge_yz2, color='b', linewidths=1, linestyle='--')
# 平面同士の境界線
ax.plot(rect_1, [v_2, v_2], [v_3, v_3], color='y', linewidth=1, linestyle='--')
ax.plot([v_1, v_1], rect_2, [v_3, v_3], color='m', linewidth=1, linestyle='--')
ax.plot([v_1, v_1], [v_2, v_2], rect_3, color='c', linewidth=1, linestyle='--')
# 内挿地点
ax.scatter(v_1, v_2, v_3, color='orange', marker='s', s=100, linewidths=0)
ax.set_xlabel('Axis 1')
ax.set_ylabel('Axis 2')
ax.set_zlabel('Axis 3')
fig.savefig('xxxxxxxxx.png')
plt.show()