1
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Python高速化!数値計算向けライブラリ比較【Numba, Taichi, Warp, JAX】

Last updated at Posted at 2025-04-06

はじめに

Pythonで数値計算や物理シミュレーションを行う際、処理速度が課題になります。

本記事では、Pythonコードを大幅に高速化できる4つの主要フレームワークの特徴と性能を比較し、用途に応じた最適な選択肢を紹介します。

マンデルブロ集合計算のベンチマークを通じて、各フレームワークのCPU/GPU性能を検証した結果、用途によって最適なフレームワークが異なることがわかりました。特にWarpはGPU処理で最高速を示し、NumbaはCPU処理と既存コード互換性に優れていることが確認できました。

download.png

各種フレームワークの特徴

Numba

シンプルで使いやすいJITコンパイラ

NumbaはPython関数をLLVMコンパイラを用いて機械語に変換し、処理を大幅に高速化するフレームワークです。特に既存のNumPyコードへの適用が容易で、単純なデコレータの追加だけで効果を得られる点が大きな特徴です。

  • デコレータ(@jit)で関数を簡単に最適化
  • NumPyとの高い互換性により既存コードの移行が容易
  • CPUマルチコア並列化およびCUDA GPUアクセラレーションに対応
  • 純粋なPythonコードと比較して劇的な速度向上が可能

Taichi

クロスプラットフォーム対応の高性能フレームワーク

Taichiは物理シミュレーション向けに設計された言語で、複数のハードウェアプラットフォームをサポートしています。CPUとGPUの両方で効率的に動作するコードを生成する能力を持っています。

  • CPU/GPU向けに自動最適化されたコードを生成
  • CUDA, Vulkan, Metal, OpenGLなど複数のGPUバックエンドをサポート
  • 独自の階層的データ構造(SNode)により複雑なシミュレーションに対応1
  • 物理ベースのシミュレーションに最適化された機能を提供

Warp

NVIDIA謹製の高性能計算エンジン

Warpは、NVIDIA社が開発したフレームワークで、特にNVIDIA GPU上での高速計算に最適化されています。物理シミュレーションと機械学習の統合に強みを持っています。

  • NVIDIA GPUで特に高いパフォーマンスを発揮
  • CPUおよびGPUで実行可能なコード
  • PyTorchやJAXと連携可能な微分可能カーネルを提供
  • シミュレーションのための特化した関数ライブラリを内蔵

JAX

自動微分と並列計算に特化したフレームワーク

JAXはGoogleが開発した、NumPyの関数をGPU/TPUで加速しつつ自動微分機能を提供するフレームワークです。機械学習の研究開発でよく使用されています。

  • NumPyの関数をGPU/TPUで高速実行できるよう拡張
  • 自動微分機能により機械学習モデルの最適化に適している
  • XLA(Accelerated Linear Algebra)によるJITコンパイルで高速化
  • SPMD2(単一プログラム、複数データ)モデルによるマルチデバイス並列処理機能

ベンチマークコード

マンデルブロ集合の計算は、反復計算と条件分岐を含む典型的な数値計算タスクであり、各フレームワークの性能を評価するのに適しています。3

以下のGitHubリポジトリで完全なコードを公開しています。

ベンチマークコードのサンプル(一部抜粋)
# Numbaのコード例
@numba.vectorize([numba.int32(numba.complex64)], target="parallel")
def run_numba_cpu_ufunc(c):
    counter = 0
    z = 0j
    for _ in range(1000):
        z = z**2 + c
        if z.real**2+z.imag**2 >= 4:
            break
        counter += 1
    return counter

x, y = np.meshgrid(np.linspace(-2, 1.2, 7680, dtype=np.float32), np.linspace(-1.2, 1.2, 4320, dtype=np.float32))
complex_array = x + y * 1j
z = run_numba_cpu_ufunc(complex_array)

# Taichiのコード例
@ti.func
def taichi_func(c: tm.vec2) -> ti.i32:
    counter = 0
    z = tm.vec2(0.0, 0.0)
    for i in range(1000):
        z = tm.vec2(z.x**2 - z.y**2, 2*z.x*z.y) + c
        if z.x**2+z.y**2 >= 4:
            break
        counter += 1
    return counter

@ti.kernel
def run_taichi_kernel(field: ti.template(), out: ti.template()):
    for i, j in field:
        out[i, j] = taichi_func(field[i, j])

def calculate_taichi_mandelbrot(complex_array):
    field = ti.Vector.field(n=2, dtype=ti.f32, shape=(4320, 7680))
    field.from_numpy(complex_array)
    out = ti.field(dtype=ti.i32, shape=(4320, 7680))
    run_taichi_kernel(field, out)
    return out.to_numpy()
ti.init(arch=ti.cpu)

x, y = np.meshgrid(np.linspace(-2, 1.2, 7680), np.linspace(-1.2, 1.2, 4320))
complex_array = np.stack([x, y], axis=2)

z=calculate_taichi_mandelbrot(complex_array)

# Warpのコード例
@wp.func
def warp_func(c: wp.vec2) -> wp.int32:
    counter = wp.int32(0)
    z = wp.vec2(0.0, 0.0)
    for i in range(1000):
        z = wp.vec2(z[0]*z[0] - z[1]*z[1], 2.0*z[0]*z[1]) + c
        if z[0]*z[0]+z[1]*z[1] >= 4.0:
            break
        counter += 1
    return counter

@wp.kernel
def warp_kernel(field: wp.array2d(dtype=wp.vec2), out: wp.array2d(dtype=wp.int32)):
    i, j = wp.tid()
    out[i, j] = warp_func(field[i, j])

def calculate_warp_mandelbrot(complex_array, device):
    field=wp.array(complex_array, dtype=wp.vec2, device=device)
    out=wp.zeros(shape=(4320, 7680), dtype=wp.int32, device=device)
    wp.launch(kernel=warp_kernel, dim=(4320, 7680), inputs=[field], outputs=[out], device=device)
    return out.numpy()

device = "cpu"

z=calculate_warp_mandelbrot(complex_array, device)

# JAXのコード例
def jax_func(c):
    def body(i, carry):
        counter, z = carry
        z = jnp.where(z.real**2 + z.imag**2 >= 4, z, z**2 + c)
        counter = jnp.where(z.real**2 + z.imag**2 >= 4, counter, counter + 1)
        return counter, z
    final_counter, _ = jax.lax.fori_loop(0, 1000, body, (0, 0.+0j))
    return final_counter

x, y = jnp.meshgrid(jnp.linspace(-2, 1.2, 7680, dtype=np.float32), jnp.linspace(-1.2, 1.2, 4320, dtype=np.float32))
complex_array = x + y * 1j

calculate_jax_mandelbrot = jax.jit(jnp.vectorize(jax_func), backend="cpu")
z = calculate_jax_mandelbrot(complex_array).block_until_ready()

速度比較結果

Kaggleノートブック上で、 NVIDIA Tesla P100のアクセラレータを選択して実行しました。

  • CPU最速: Numba
  • GPU最速: Warp
  • 最遅: JAX(CPU/GPUともに)
Numba Taichi Warp JAX
CPU 7.80s 13.6s 31.8s 63.5s
GPU 0.837s 0.708s 0.310s 2.34s

download.png

フレームワーク選択の指針

  1. 既存コードがNumPyベースの場合→Numba
    • Numbaが最適です。簡単な変更で大幅な高速化が可能です
  2. NVIDIA GPU専用環境の場合→Warp
    • Warpが最も効率的です。特にPyTorchやJAXとの連携が必要な場合にも有効です
  3. ポータビリティ重視の場合→Taichi
    • Taichiは複数GPUバックエンドをサポートしており、幅広い環境で利用可能です
  4. 巨大テンソル計算や機械学習モデルの場合→JAX
    • JAXは今回のベンチマークでは遅かったものの、自動微分やSPMD並列処理が強みとなるため、特定用途では有効だと思います
    • ただし、本ベンチマークのようなステートフルな計算(かつ遷移回数も多い場合)には不向きと考えられます

物理シミュレーションプラットフォームで選ぶのもいいかもしれません。
Omnivese⇔Warp
Genesis⇔Taichi
(資本的にはNVIDIA Omniverseがよさそうだけど、NVIDIAへの依存度が...)

おまけ(実装上の注意点)

共通の最適化ポイント

  • データ転送の最小化: CPU-GPU間のデータ転送は極力減らす
  • メモリアクセスパターン: 連続したメモリアクセスを心がける

フレームワーク別のヒント

  • Numba: parallel=Trueオプションでマルチコア活用、fastmath=Trueで数値精度と引き換えに速度向上(するかも)
  • Taichi: 適切なデータ構造(sparse/denseなど)選択が重要
  • Warp: NVIDIAの最新GPUドライバー利用でパフォーマンス向上
  • JAX: 関数型スタイルを採用し、状態変化を最小限に
  1. https://docs.taichi-lang.org/docs/sparse

  2. 「Single Program, Multiple Data」の略称。複数のプロセッサが同じプログラムを実行しながら、異なるデータに対して処理を行う並列計算のモデル。

  3. 例えば、The Computer Language Benchmarks Gameで使われています。

1
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?