2025/08/28 修正
コードの不具合を修正し、NumbaがCPU/GPU性能で最速になりました。
実験コードの中で複素数の2乗をz**2
としていたことで、NumbaのGPU処理が遅くなっている1ことが発覚しました(34jさんコメントありがとうございました)。該当箇所z*z
に変更し、実験しなおして修正しました。
はじめに
Pythonで数値計算や物理シミュレーションを行う際、処理速度が課題になります。
本記事では、Pythonコードを大幅に高速化できる4つの主要フレームワークの特徴と性能を比較し、用途に応じた最適な選択肢を紹介します。
マンデルブロ集合計算のベンチマークを通じて、各フレームワークのCPU/GPU性能を検証した結果、用途によって最適なフレームワークが異なることがわかりました。特にNunbaがCPUおよびGPUの両方で最高速を示しました。またTaichiは、今回比較した中で、CPUとGPUのコードを共通化しやすくポータブル性が最も高いと感じました(ti.init()
でデフォルトの数値精度を指定できるのが便利)。WarpはGPU処理が高速ですが、CPUでの性能はNumbaやTaichiには劣るという結果でした。
各種フレームワークの特徴
Numba
シンプルで使いやすいJITコンパイラ
NumbaはPython関数をLLVMコンパイラを用いて機械語に変換し、処理を大幅に高速化するフレームワークです。特に既存のNumPyコードへの適用が容易で、単純なデコレータの追加だけで効果を得られる点が大きな特徴です。
- デコレータ(@jit)で関数を簡単に最適化
- NumPyとの高い互換性により既存コードの移行が容易
- CPUマルチコア並列化およびCUDA GPUアクセラレーションに対応
- 純粋なPythonコードと比較して劇的な速度向上が可能
Taichi
クロスプラットフォーム対応の高性能フレームワーク
Taichiは物理シミュレーション向けに設計された言語で、複数のハードウェアプラットフォームをサポートしています。CPUとGPUの両方で効率的に動作するコードを生成する能力を持っています。
- CPU/GPU向けに自動最適化されたコードを生成
- CUDA, Vulkan, Metal, OpenGLなど複数のGPUバックエンドをサポート
- 独自の階層的データ構造(SNode)により複雑なシミュレーションに対応2
- 物理ベースのシミュレーションに最適化された機能を提供
Warp
NVIDIA謹製の高性能計算エンジン
Warpは、NVIDIA社が開発したフレームワークで、特にNVIDIA GPU上での高速計算に最適化されています。物理シミュレーションと機械学習の統合に強みを持っています。
- NVIDIA GPUで特に高いパフォーマンスを発揮
- CPUおよびGPUで実行可能なコード
- PyTorchやJAXと連携可能な微分可能カーネルを提供
- シミュレーションのための特化した関数ライブラリを内蔵
JAX
自動微分と並列計算に特化したフレームワーク
JAXはGoogleが開発した、NumPyの関数をGPU/TPUで加速しつつ自動微分機能を提供するフレームワークです。機械学習の研究開発でよく使用されています。
- NumPyの関数をGPU/TPUで高速実行できるよう拡張
- 自動微分機能により機械学習モデルの最適化に適している
- XLA(Accelerated Linear Algebra)によるJITコンパイルで高速化
- SPMD3(単一プログラム、複数データ)モデルによるマルチデバイス並列処理機能
ベンチマークコード
マンデルブロ集合の計算は、反復計算と条件分岐を含む典型的な数値計算タスクであり、各フレームワークの性能を評価するのに適しています。4
以下のGitHubリポジトリで完全なコードを公開しています。
ベンチマークコードのサンプル(一部抜粋)
# Numbaのコード例
@numba.vectorize([numba.int64(numba.complex128)], fastmath=True, target="parallel")
def run_numba_cpu_ufunc(c):
counter = 0
z = 0j
for _ in range(1000):
z = z * z + c
if z.real * z.real + z.imag * z.imag >= 4.0:
break
counter += 1
return counter
x, y = np.meshgrid(
np.linspace(-2, 1.2, 7680, dtype=np.float64),
np.linspace(-1.2, 1.2, 4320, dtype=np.float64),
)
complex_array = x + y * 1j
z = run_numba_cpu_ufunc(complex_array)
# Taichiのコード例
@ti.func
def taichi_func(c: tm.vec2) -> int:
counter = 0
z = tm.vec2(0.0, 0.0)
for i in range(ti.int32(1000)):
z = tm.cmul(z, z) + c
if z.x * z.x + z.y * z.y >= 4.0:
break
counter += 1
return counter
@ti.kernel
def run_taichi_kernel(field: ti.template(), out: ti.template()):
for i, j in field:
out[i, j] = taichi_func(field[i, j])
def calculate_taichi_mandelbrot(complex_array):
field = ti.Vector.field(n=2, dtype=float, shape=(4320, 7680))
field.from_numpy(complex_array)
out = ti.field(dtype=int, shape=(4320, 7680))
run_taichi_kernel(field, out)
return out.to_numpy()
ti.init(arch=ti.cpu, default_ip=ti.i64, default_fp=ti.f64)
x, y = np.meshgrid(
np.linspace(-2, 1.2, 7680, dtype=np.float64),
np.linspace(-1.2, 1.2, 4320, dtype=np.float64),
)
complex_array = np.stack([x, y], axis=2)
z = calculate_taichi_mandelbrot(complex_array)
# Warpのコード例
wp.init()
@wp.func
def warp_func(c: wp.vec2d) -> wp.int64:
counter = wp.int64(0)
z = wp.vec2d(wp.float64(0), wp.float64(0))
for i in range(1000):
z = wp.vec2d(z[0] * z[0] - z[1] * z[1], wp.float64(2.0) * z[0] * z[1]) + c
if z[0] * z[0] + z[1] * z[1] >= wp.float64(4.0):
break
counter += wp.int64(1)
return counter
@wp.kernel
def warp_kernel(field: wp.array2d(dtype=wp.vec2d), out: wp.array2d(dtype=wp.int64)):
i, j = wp.tid()
out[i, j] = warp_func(field[i, j])
def calculate_warp_cpu_mandelbrot(complex_array, device):
field = wp.array(complex_array, dtype=wp.vec2d, device=device)
out = wp.zeros(shape=(4320, 7680), dtype=wp.int64, device=device)
wp.launch(
kernel=warp_kernel,
dim=(4320, 7680),
inputs=[field],
outputs=[out],
device=device,
)
return out.numpy()
x, y = np.meshgrid(
np.linspace(-2, 1.2, 7680, dtype=np.float64),
np.linspace(-1.2, 1.2, 4320, dtype=np.float64),
)
complex_array = np.stack([x, y], axis=2)
z = calculate_warp_cpu_mandelbrot(complex_array, "cpu")
# JAXのコード例
def jax_func(c):
def body(i, carry):
counter, z = carry
z = jnp.where(z.real**2 + z.imag**2 >= 4, z, z * z + c)
counter = jnp.where(z.real**2 + z.imag**2 >= 4, counter, counter + 1)
return counter, z
final_counter, _ = jax.lax.fori_loop(0, 1000, body, (0, 0j))
return final_counter
x, y = jnp.meshgrid(
jnp.linspace(-2, 1.2, 7680, dtype=np.float32),
jnp.linspace(-1.2, 1.2, 4320, dtype=np.float32),
)
complex_array = x + y * 1j
calculate_jax_mandelbrot = jax.jit(jnp.vectorize(jax_func), backend="cpu")
z = calculate_jax_mandelbrot(complex_array).block_until_ready()
速度比較結果
Kaggleノートブック上で、 NVIDIA Tesla P100のアクセラレータを選択して実行しました。
- 最速: Numba(CPU/GPUともに)
- 最遅: JAX(CPU/GPUともに)
Numba | Taichi | Warp | JAX | Numpy | |
---|---|---|---|---|---|
CPU | 7.04s | 16.0s | 32.9s | 74.6s | 1486s |
GPU | 0.151s | 0.708s | 0.222s | 2.05s |
フレームワーク選択の指針
- 既存コードがNumPyベースの場合→Numba
- Numbaが最適です。簡単な変更で大幅な高速化が可能です
- NVIDIA GPU専用環境の微分可能シミュレータを開発する場合→Warp
- Warpが最も効率的です。特にPyTorchやJAXとの連携が必要な場合にも有効です5
- ポータビリティ重視の場合→Taichi
- 巨大テンソル計算や機械学習モデルの場合→JAX
- JAXは今回のベンチマークでは遅かったものの、自動微分やSPMD並列処理が強みとなるため、特定用途では有効だと思います
- ただし、本ベンチマークのようなステートフルな計算(かつ遷移回数も多い場合)には不向きと考えられます8
関連資料
Mandelbrot on all accelerators
Benchmark Numba, Taichi, Warp, JAX using Mandelbrot set
おまけ(実装上の注意点)
共通の最適化ポイント
- どのような演算が裏でおこなわれているのかを理解して実装する(自分への戒め)
- データ転送の最小化: CPU-GPU間のデータ転送は極力減らす
- メモリアクセスパターン: 連続したメモリアクセスを心がける
ただし、可読性とのトレードオフも考慮する必要があります。
フレームワーク別のヒント
- Numba: parallel=Trueオプションでマルチコア活用、fastmath=Trueで数値精度と引き換えに速度向上(するかも)
- Taichi: 適切なデータ構造(sparse/denseなど)選択が重要
- Warp: NVIDIAの最新GPUドライバー利用でパフォーマンス向上
- JAX: 関数型スタイルを採用し、状態変化を最小限に
-
CPU処理だと実質
z**2
→z*z
になるけど、GPU処理だとz**2
は複素数の複素数乗として処理されているかも。https://github.com/NVIDIA/numba-cuda/blob/75fb24cb091110050db3baddba74ed8978f0cf4c/numba_cuda/numba/cuda/mathimpl.py#L340 ↩ -
「Single Program, Multiple Data」の略称。複数のプロセッサが同じプログラムを実行しながら、異なるデータに対して処理を行う並列計算のモデル。 ↩
-
例えば、The Computer Language Benchmarks Gameで使われています。 ↩
-
https://gist.github.com/jpivarski/da343abd8024834ee8c5aaba691aafc7?permalink_comment_id=5112454#gistcomment-5112454 ↩