0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

perfettoでjaxを使ったpythonプログラムのGPUにおける実行時間を計測してみた

Last updated at Posted at 2023-11-04

やりたいこと

jaxを使ったpythonプログラムのGPUにおける実行時間を、スクリプト・関数ごとに計測したい。できるだけ楽に。

環境

  • local
    • OS: macOS Monterey 12.3.1
    • python: 3.11.5
  • remote
    • OS: Ubuntu 20.04.6 LTS
    • python: 3.10.0

背景

私は物理シミュレーションの研究をしており、基本的にプログラムはpythonで用意し、jaxライブラリを用いたコードをリモートのGPUで回しています。研究の方向性が"新手法の開発"であるので、計算速度をできるだけ向上し既存の手法と比較したいです。そのためには、プログラムの各スクリプト、各関数にどの程度時間がかかっているとわかるとやりやすいのですが、計測したい部分にマニュアルでtime.time()やtimeit.timeit()を挿入するのは非常に手間がかかります。

そこで見つけたのが、
Profiling JAX programs
です。jaxを含むプログラムの実行時間の計測方法が紹介されています。

実行時間計測ツールの候補の手段2つ

上記のサイトには、PerfettoTensorboardという2つのツールが紹介されています。PerfettoはAndroidのアプリ開発でよく用いられているようで、シンプルに時間が計測されわかりやすそうです。Tensorboardはその名の通りtensorflowとセットで使われるのが一般的で、DNNのアーキテクチャと、実行時間・メモリ使用量を対応させて見ることができる優れもののようです。今回は両方試しました。

まずはローカルで

新しいツールをいきなりリモートで使うと壁にぶつかることが個人的には多いので、まずはlocal(CPU)で試したのちにリモートで行いました。公式サイトそのままの内容です。

test_perfetto.py
import jax

with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
  # Run the operations to be profiled
  key = jax.random.PRNGKey(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

ここでjaxだけinstallされていれば十分かと思いましたが、tensorflowがないと以下のwarningが表示されますが、実行自体は問題なくできます。

E external/xla/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace
E external/xla/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace

以下のコマンドでtensorflowをinstallすると上記のwarningが消え、無事リンクが表示されます

$ pip instal tensorflow

その結果エラーは消え、無事時間計測を行うことができました。

$ python test_perfetto.py 
Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz

image.png

Tensorboardの場合は追加でpip install tensorboard-plugin-profileすることで同様に公式サイトのコードを実行することができました。

続いてリモートで

tensorflowが必須となると、gpuの場合はcudaのversionとの整合性の調整などで問題が発生しやすいので、ひとまずperffetoで行うことにしました。しっかりGPUとCPUに分かれて使用時間とメモリ消費が計測されています。

リモートで行う場合には、ssh接続の際にリモートとローカルのportを接続する必要があります。

$ ssh -L 9001:127.0.0.1:9001 <user>@<host>

先程のtest_perfetto.pyと同じスクリプトを実行することリンクが表示され、以下の画面を得ることができました。
image.png

わからなかったこと

perfettoはinteractiveにしか使えない、と公式サイトにあり、実際バッチジョブの形式では上手く実行できませんでした。今回は研究室所有の、interactiveジョブが可能なサーバーで行ったので問題ありませんでしたが、バッチジョブでも行えるとより便利だと思いました。今のところやり方はわかりません。

まとめ

pythonのGPU上の実行時間計測ツールとしてperfettoとtensorboardを検証しました。tensorboardについてはjax, tensorflowなどのversionの整合性をとるのが難しく断念し、perfettoを選択しました。まだ時間が計測される仕組みについては詳しく理解できていないので、より詳細な分析のために勉強したいです。

おまけ

localの仮想環境のyamlファイル

local.yaml
name: tensorflow2.14.0
channels:
  - defaults
dependencies:
  - bzip2=1.0.8=h1de35cc_0
  - ca-certificates=2023.08.22=hecd8cb5_0
  - libffi=3.4.4=hecd8cb5_0
  - ncurses=6.4=hcec6c5f_0
  - openssl=1.1.1w=hca72f7f_0
  - python=3.11.0=h1fd4e5f_3
  - readline=8.2=hca72f7f_0
  - setuptools=68.0.0=py311hecd8cb5_0
  - sqlite=3.41.2=h6c40b1e_0
  - tk=8.6.12=h5d9f67b_0
  - tzdata=2023c=h04d1e81_0
  - wheel=0.41.2=py311hecd8cb5_0
  - xz=5.4.2=h6c40b1e_0
  - zlib=1.2.13=h4dc903c_0
  - pip:
      - absl-py==2.0.0
      - anyio==4.0.0
      - appnope==0.1.3
      - argon2-cffi==23.1.0
      - argon2-cffi-bindings==21.2.0
      - arrow==1.3.0
      - asttokens==2.4.1
      - astunparse==1.6.3
      - async-lru==2.0.4
      - attrs==23.1.0
      - babel==2.13.1
      - beautifulsoup4==4.12.2
      - bleach==6.1.0
      - cachetools==5.3.2
      - certifi==2023.7.22
      - cffi==1.16.0
      - charset-normalizer==3.3.2
      - comm==0.1.4
      - debugpy==1.8.0
      - decorator==5.1.1
      - defusedxml==0.7.1
      - executing==2.0.1
      - fastjsonschema==2.18.1
      - flatbuffers==23.5.26
      - fqdn==1.5.1
      - gast==0.5.4
      - google-auth==2.23.4
      - google-auth-oauthlib==1.0.0
      - google-pasta==0.2.0
      - grpcio==1.59.2
      - gviz-api==1.10.0
      - h5py==3.10.0
      - idna==3.4
      - ipykernel==6.26.0
      - ipython==8.17.2
      - isoduration==20.11.0
      - jax==0.4.19
      - jaxlib==0.4.19
      - jedi==0.19.1
      - jinja2==3.1.2
      - json5==0.9.14
      - jsonpointer==2.4
      - jsonschema==4.19.2
      - jsonschema-specifications==2023.7.1
      - jupyter-client==8.5.0
      - jupyter-core==5.5.0
      - jupyter-events==0.8.0
      - jupyter-lsp==2.2.0
      - jupyter-server==2.9.1
      - jupyter-server-terminals==0.4.4
      - jupyterlab==4.0.7
      - jupyterlab-pygments==0.2.2
      - jupyterlab-server==2.25.0
      - keras==2.14.0
      - libclang==16.0.6
      - markdown==3.5.1
      - markupsafe==2.1.3
      - matplotlib-inline==0.1.6
      - mistune==3.0.2
      - ml-dtypes==0.2.0
      - nbclient==0.8.0
      - nbconvert==7.10.0
      - nbformat==5.9.2
      - nest-asyncio==1.5.8
      - notebook-shim==0.2.3
      - numpy==1.26.1
      - oauthlib==3.2.2
      - opt-einsum==3.3.0
      - overrides==7.4.0
      - packaging==23.2
      - pandocfilters==1.5.0
      - parso==0.8.3
      - pexpect==4.8.0
      - pip==23.3.1
      - platformdirs==3.11.0
      - prometheus-client==0.18.0
      - prompt-toolkit==3.0.39
      - protobuf==4.25.0
      - psutil==5.9.6
      - ptyprocess==0.7.0
      - pure-eval==0.2.2
      - pyasn1==0.5.0
      - pyasn1-modules==0.3.0
      - pycparser==2.21
      - pygments==2.16.1
      - python-dateutil==2.8.2
      - python-json-logger==2.0.7
      - pyyaml==6.0.1
      - pyzmq==25.1.1
      - referencing==0.30.2
      - requests==2.31.0
      - requests-oauthlib==1.3.1
      - rfc3339-validator==0.1.4
      - rfc3986-validator==0.1.1
      - rpds-py==0.10.6
      - rsa==4.9
      - scipy==1.11.3
      - send2trash==1.8.2
      - six==1.16.0
      - sniffio==1.3.0
      - soupsieve==2.5
      - stack-data==0.6.3
      - tensorboard==2.14.1
      - tensorboard-data-server==0.7.2
      - tensorboard-plugin-profile==2.14.0
      - tensorflow==2.14.0
      - tensorflow-estimator==2.14.0
      - tensorflow-io-gcs-filesystem==0.34.0
      - termcolor==2.3.0
      - terminado==0.17.1
      - tinycss2==1.2.1
      - tornado==6.3.3
      - traitlets==5.13.0
      - types-python-dateutil==2.8.19.14
      - typing-extensions==4.8.0
      - uri-template==1.3.0
      - urllib3==2.0.7
      - wcwidth==0.2.9
      - webcolors==1.13
      - webencodings==0.5.1
      - websocket-client==1.6.4
      - werkzeug==3.0.1
      - wrapt==1.14.1
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?