はじめに
こんにちは、@shoot16625 です👍
今回は、NumPyより高速とされるJAXが、CPUリソースが限られている状態でも効果が発揮できるのか検証しました。
そもそもJAXが、ゴリゴリのモデルをGPUでバリバリ動かすために開発されたようなものなので、JAXでディープラーニングをはじめて欲しいものですが、今日ははじめませんw。実システムでは、小規模モデル(ベクトル演算など)を低リソースなコンテナで処理するという状況は多多あるかと思います。そんな状況でも処理速度が求められることもあるでしょう。その際に、JAXが利用できるやもしれません。
実験コード
実験コードは、 JAX入門~高速なNumPyとして使いこなすためのチュートリアル~ を拝借させていただきました。ありがとうございます。
.
├── Dockerfile
├── docker-compose.yml
├── main
│ └── sample1.py
├── poetry.lock # なしでも可
└── pyproject.toml
from functools import partial
from timeit import timeit
import jax.numpy as jnp
import numpy as np
from jax import jit
# (size, size)の行列を作ってMod計算
@partial(jit, static_argnums=(0,))
def jax_jit_mod(size):
x = jnp.arange(size, dtype=jnp.int32)
mat = x[None, :] * x[:, None] # (size, size)
return mat % 256
def numpy_mod(size):
x = np.arange(size, dtype=np.int32)
mat = x[None, :] * x[:, None]
return mat % 256
def main():
for i in range(4):
size = 10 ** (i + 1)
repeat = 10 ** (4 - i)
print("size =", size, "repeat =", repeat)
print(" np:", np.round(timeit(lambda: numpy_mod(size), number=repeat), 3))
print(
"jnp:",
jnp.round(
timeit(lambda: jax_jit_mod(size).block_until_ready(), number=repeat), 3
),
)
print()
if __name__ == "__main__":
main()
pyproject.toml
[tool.poetry]
name = "jax-sample"
version = "0.1.0"
description = ""
authors = ["Your Name <you@example.com>"]
[tool.poetry.dependencies]
python = "^3.9"
jax = "^0.3.0"
jaxlib = "^0.3.0"
numpy = "^1.22.2"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
Dockerfile
FROM python:3.9.10-slim
WORKDIR /app
RUN pip install poetry==1.1.13 && \
poetry config virtualenvs.create false
COPY pyproject.toml ./
# COPY poetry.lock ./
RUN poetry install
COPY main ./main
docker-compose.yml
version: "2.0"
services:
app:
build:
context: ./
dockerfile: Dockerfile
environment:
- TF_CPP_MIN_LOG_LEVEL=2 # tensorflowのログを減らす
volumes:
- ./:/srv/
実行コマンド
docker-compose build
docker-compose run app python main/sample1.py
パフォーマンス比較
MacBook Pro の docker 上でCPU数を変更して検証しました。
処理としては、(size, size)の行列演算をrepeat回数行っています。
結果はNumPyとJAXの実行時間[s]です。
CPU:1
size = 10 repeat = 10000
np: 0.048
jnp: 0.058000002
size = 100 repeat = 1000
np: 0.042
jnp: 0.029000001
size = 1000 repeat = 100
np: 0.551
jnp: 0.034
size = 10000 repeat = 10
np: 4.699
jnp: 1.5230001
sizeが1000を超えてくると十分効果が期待できそうですね!!
10000*10000の行列演算でも3倍は高速化できています。
CPU:4
size = 10 repeat = 10000
np: 0.049
jnp: 0.24200001
size = 100 repeat = 1000
np: 0.045
jnp: 0.109000005
size = 1000 repeat = 100
np: 0.583
jnp: 0.051000003
size = 10000 repeat = 10
np: 4.105
jnp: 0.73800004
sizeが大きい場合はCPUが増えることで高速化していますが、意外なことに、sizeが小さい場合はCPU:1のときより、遅くなっています。約4倍遅くなっているので、CPUへの割り振りによるオーバーヘッドなんでしょうか...
NumPyはCPU数が変化しても大差ないですね。
forまわり
またこんど
(for分処理が苦手と小耳に挟んだので検証)