1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

JAXではじめないディープラーニング

Posted at

はじめに

こんにちは、@shoot16625 です👍
今回は、NumPyより高速とされるJAXが、CPUリソースが限られている状態でも効果が発揮できるのか検証しました。

そもそもJAXが、ゴリゴリのモデルをGPUでバリバリ動かすために開発されたようなものなので、JAXでディープラーニングをはじめて欲しいものですが、今日ははじめませんw。実システムでは、小規模モデル(ベクトル演算など)を低リソースなコンテナで処理するという状況は多多あるかと思います。そんな状況でも処理速度が求められることもあるでしょう。その際に、JAXが利用できるやもしれません。

実験コード

実験コードは、 JAX入門~高速なNumPyとして使いこなすためのチュートリアル~ を拝借させていただきました。ありがとうございます。

ファイルツリー
.
├── Dockerfile
├── docker-compose.yml
├── main
│   └── sample1.py
├── poetry.lock # なしでも可
└── pyproject.toml
sample1.py
from functools import partial
from timeit import timeit

import jax.numpy as jnp
import numpy as np
from jax import jit


# (size, size)の行列を作ってMod計算
@partial(jit, static_argnums=(0,))
def jax_jit_mod(size):
    x = jnp.arange(size, dtype=jnp.int32)
    mat = x[None, :] * x[:, None]  # (size, size)
    return mat % 256


def numpy_mod(size):
    x = np.arange(size, dtype=np.int32)
    mat = x[None, :] * x[:, None]
    return mat % 256


def main():
    for i in range(4):
        size = 10 ** (i + 1)
        repeat = 10 ** (4 - i)
        print("size =", size, "repeat =", repeat)
        print(" np:", np.round(timeit(lambda: numpy_mod(size), number=repeat), 3))
        print(
            "jnp:",
            jnp.round(
                timeit(lambda: jax_jit_mod(size).block_until_ready(), number=repeat), 3
            ),
        )
        print()


if __name__ == "__main__":
    main()
pyproject.toml
pyproject.toml
[tool.poetry]
name = "jax-sample"
version = "0.1.0"
description = ""
authors = ["Your Name <you@example.com>"]

[tool.poetry.dependencies]
python = "^3.9"
jax = "^0.3.0"
jaxlib = "^0.3.0"
numpy = "^1.22.2"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
Dockerfile
Dockerfile
FROM python:3.9.10-slim

WORKDIR /app

RUN pip install poetry==1.1.13 && \
    poetry config virtualenvs.create false

COPY pyproject.toml ./
# COPY poetry.lock ./
RUN poetry install

COPY main ./main
docker-compose.yml
docker-compose.yml
version: "2.0"

services:
  app:
    build:
      context: ./
      dockerfile: Dockerfile
    environment:
      - TF_CPP_MIN_LOG_LEVEL=2 # tensorflowのログを減らす
    volumes:
      - ./:/srv/
実行コマンド
docker-compose build
docker-compose run app python main/sample1.py 

パフォーマンス比較

MacBook Pro の docker 上でCPU数を変更して検証しました。
処理としては、(size, size)の行列演算をrepeat回数行っています。
結果はNumPyとJAXの実行時間[s]です。

CPU:1

size = 10 repeat = 10000
 np: 0.048
jnp: 0.058000002

size = 100 repeat = 1000
 np: 0.042
jnp: 0.029000001

size = 1000 repeat = 100
 np: 0.551
jnp: 0.034

size = 10000 repeat = 10
 np: 4.699
jnp: 1.5230001

sizeが1000を超えてくると十分効果が期待できそうですね!!
10000*10000の行列演算でも3倍は高速化できています。

CPU:4

size = 10 repeat = 10000
 np: 0.049
jnp: 0.24200001

size = 100 repeat = 1000
 np: 0.045
jnp: 0.109000005

size = 1000 repeat = 100
 np: 0.583
jnp: 0.051000003

size = 10000 repeat = 10
 np: 4.105
jnp: 0.73800004

sizeが大きい場合はCPUが増えることで高速化していますが、意外なことに、sizeが小さい場合はCPU:1のときより、遅くなっています。約4倍遅くなっているので、CPUへの割り振りによるオーバーヘッドなんでしょうか...
NumPyはCPU数が変化しても大差ないですね。

forまわり

またこんど
(for分処理が苦手と小耳に挟んだので検証)

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?