はじめに
Pythonで数値計算や物理シミュレーションを行う際、処理速度が課題になります。
本記事では、Pythonコードを大幅に高速化できる4つの主要フレームワークの特徴と性能を比較し、用途に応じた最適な選択肢を紹介します。
マンデルブロ集合計算のベンチマークを通じて、各フレームワークのCPU/GPU性能を検証した結果、用途によって最適なフレームワークが異なることがわかりました。特にWarpはGPU処理で最高速を示し、NumbaはCPU処理と既存コード互換性に優れていることが確認できました。
各種フレームワークの特徴
Numba
シンプルで使いやすいJITコンパイラ
NumbaはPython関数をLLVMコンパイラを用いて機械語に変換し、処理を大幅に高速化するフレームワークです。特に既存のNumPyコードへの適用が容易で、単純なデコレータの追加だけで効果を得られる点が大きな特徴です。
- デコレータ(@jit)で関数を簡単に最適化
- NumPyとの高い互換性により既存コードの移行が容易
- CPUマルチコア並列化およびCUDA GPUアクセラレーションに対応
- 純粋なPythonコードと比較して劇的な速度向上が可能
Taichi
クロスプラットフォーム対応の高性能フレームワーク
Taichiは物理シミュレーション向けに設計された言語で、複数のハードウェアプラットフォームをサポートしています。CPUとGPUの両方で効率的に動作するコードを生成する能力を持っています。
- CPU/GPU向けに自動最適化されたコードを生成
- CUDA, Vulkan, Metal, OpenGLなど複数のGPUバックエンドをサポート
- 独自の階層的データ構造(SNode)により複雑なシミュレーションに対応1
- 物理ベースのシミュレーションに最適化された機能を提供
Warp
NVIDIA謹製の高性能計算エンジン
Warpは、NVIDIA社が開発したフレームワークで、特にNVIDIA GPU上での高速計算に最適化されています。物理シミュレーションと機械学習の統合に強みを持っています。
- NVIDIA GPUで特に高いパフォーマンスを発揮
- CPUおよびGPUで実行可能なコード
- PyTorchやJAXと連携可能な微分可能カーネルを提供
- シミュレーションのための特化した関数ライブラリを内蔵
JAX
自動微分と並列計算に特化したフレームワーク
JAXはGoogleが開発した、NumPyの関数をGPU/TPUで加速しつつ自動微分機能を提供するフレームワークです。機械学習の研究開発でよく使用されています。
- NumPyの関数をGPU/TPUで高速実行できるよう拡張
- 自動微分機能により機械学習モデルの最適化に適している
- XLA(Accelerated Linear Algebra)によるJITコンパイルで高速化
- SPMD2(単一プログラム、複数データ)モデルによるマルチデバイス並列処理機能
ベンチマークコード
マンデルブロ集合の計算は、反復計算と条件分岐を含む典型的な数値計算タスクであり、各フレームワークの性能を評価するのに適しています。3
以下のGitHubリポジトリで完全なコードを公開しています。
ベンチマークコードのサンプル(一部抜粋)
# Numbaのコード例
@numba.vectorize([numba.int32(numba.complex64)], target="parallel")
def run_numba_cpu_ufunc(c):
counter = 0
z = 0j
for _ in range(1000):
z = z**2 + c
if z.real**2+z.imag**2 >= 4:
break
counter += 1
return counter
x, y = np.meshgrid(np.linspace(-2, 1.2, 7680, dtype=np.float32), np.linspace(-1.2, 1.2, 4320, dtype=np.float32))
complex_array = x + y * 1j
z = run_numba_cpu_ufunc(complex_array)
# Taichiのコード例
@ti.func
def taichi_func(c: tm.vec2) -> ti.i32:
counter = 0
z = tm.vec2(0.0, 0.0)
for i in range(1000):
z = tm.vec2(z.x**2 - z.y**2, 2*z.x*z.y) + c
if z.x**2+z.y**2 >= 4:
break
counter += 1
return counter
@ti.kernel
def run_taichi_kernel(field: ti.template(), out: ti.template()):
for i, j in field:
out[i, j] = taichi_func(field[i, j])
def calculate_taichi_mandelbrot(complex_array):
field = ti.Vector.field(n=2, dtype=ti.f32, shape=(4320, 7680))
field.from_numpy(complex_array)
out = ti.field(dtype=ti.i32, shape=(4320, 7680))
run_taichi_kernel(field, out)
return out.to_numpy()
ti.init(arch=ti.cpu)
x, y = np.meshgrid(np.linspace(-2, 1.2, 7680), np.linspace(-1.2, 1.2, 4320))
complex_array = np.stack([x, y], axis=2)
z=calculate_taichi_mandelbrot(complex_array)
# Warpのコード例
@wp.func
def warp_func(c: wp.vec2) -> wp.int32:
counter = wp.int32(0)
z = wp.vec2(0.0, 0.0)
for i in range(1000):
z = wp.vec2(z[0]*z[0] - z[1]*z[1], 2.0*z[0]*z[1]) + c
if z[0]*z[0]+z[1]*z[1] >= 4.0:
break
counter += 1
return counter
@wp.kernel
def warp_kernel(field: wp.array2d(dtype=wp.vec2), out: wp.array2d(dtype=wp.int32)):
i, j = wp.tid()
out[i, j] = warp_func(field[i, j])
def calculate_warp_mandelbrot(complex_array, device):
field=wp.array(complex_array, dtype=wp.vec2, device=device)
out=wp.zeros(shape=(4320, 7680), dtype=wp.int32, device=device)
wp.launch(kernel=warp_kernel, dim=(4320, 7680), inputs=[field], outputs=[out], device=device)
return out.numpy()
device = "cpu"
z=calculate_warp_mandelbrot(complex_array, device)
# JAXのコード例
def jax_func(c):
def body(i, carry):
counter, z = carry
z = jnp.where(z.real**2 + z.imag**2 >= 4, z, z**2 + c)
counter = jnp.where(z.real**2 + z.imag**2 >= 4, counter, counter + 1)
return counter, z
final_counter, _ = jax.lax.fori_loop(0, 1000, body, (0, 0.+0j))
return final_counter
x, y = jnp.meshgrid(jnp.linspace(-2, 1.2, 7680, dtype=np.float32), jnp.linspace(-1.2, 1.2, 4320, dtype=np.float32))
complex_array = x + y * 1j
calculate_jax_mandelbrot = jax.jit(jnp.vectorize(jax_func), backend="cpu")
z = calculate_jax_mandelbrot(complex_array).block_until_ready()
速度比較結果
Kaggleノートブック上で、 NVIDIA Tesla P100のアクセラレータを選択して実行しました。
- CPU最速: Numba
- GPU最速: Warp
- 最遅: JAX(CPU/GPUともに)
Numba | Taichi | Warp | JAX | |
---|---|---|---|---|
CPU | 7.80s | 13.6s | 31.8s | 63.5s |
GPU | 0.837s | 0.708s | 0.310s | 2.34s |
フレームワーク選択の指針
- 既存コードがNumPyベースの場合→Numba
- Numbaが最適です。簡単な変更で大幅な高速化が可能です
- NVIDIA GPU専用環境の場合→Warp
- Warpが最も効率的です。特にPyTorchやJAXとの連携が必要な場合にも有効です
- ポータビリティ重視の場合→Taichi
- Taichiは複数GPUバックエンドをサポートしており、幅広い環境で利用可能です
- 巨大テンソル計算や機械学習モデルの場合→JAX
- JAXは今回のベンチマークでは遅かったものの、自動微分やSPMD並列処理が強みとなるため、特定用途では有効だと思います
- ただし、本ベンチマークのようなステートフルな計算(かつ遷移回数も多い場合)には不向きと考えられます
おまけ(実装上の注意点)
共通の最適化ポイント
- データ転送の最小化: CPU-GPU間のデータ転送は極力減らす
- メモリアクセスパターン: 連続したメモリアクセスを心がける
フレームワーク別のヒント
- Numba: parallel=Trueオプションでマルチコア活用、fastmath=Trueで数値精度と引き換えに速度向上(するかも)
- Taichi: 適切なデータ構造(sparse/denseなど)選択が重要
- Warp: NVIDIAの最新GPUドライバー利用でパフォーマンス向上
- JAX: 関数型スタイルを採用し、状態変化を最小限に
-
「Single Program, Multiple Data」の略称。複数のプロセッサが同じプログラムを実行しながら、異なるデータに対して処理を行う並列計算のモデル。 ↩
-
例えば、The Computer Language Benchmarks Gameで使われています。 ↩