やりたいこと
jaxを使ったpythonプログラムのGPUにおける実行時間を、スクリプト・関数ごとに計測したい。できるだけ楽に。
環境
- local
- OS: macOS Monterey 12.3.1
- python: 3.11.5
- remote
- OS: Ubuntu 20.04.6 LTS
- python: 3.10.0
背景
私は物理シミュレーションの研究をしており、基本的にプログラムはpythonで用意し、jaxライブラリを用いたコードをリモートのGPUで回しています。研究の方向性が"新手法の開発"であるので、計算速度をできるだけ向上し既存の手法と比較したいです。そのためには、プログラムの各スクリプト、各関数にどの程度時間がかかっているとわかるとやりやすいのですが、計測したい部分にマニュアルでtime.time()やtimeit.timeit()を挿入するのは非常に手間がかかります。
そこで見つけたのが、
Profiling JAX programs
です。jaxを含むプログラムの実行時間の計測方法が紹介されています。
実行時間計測ツールの候補の手段2つ
上記のサイトには、PerfettoとTensorboardという2つのツールが紹介されています。PerfettoはAndroidのアプリ開発でよく用いられているようで、シンプルに時間が計測されわかりやすそうです。Tensorboardはその名の通りtensorflowとセットで使われるのが一般的で、DNNのアーキテクチャと、実行時間・メモリ使用量を対応させて見ることができる優れもののようです。今回は両方試しました。
まずはローカルで
新しいツールをいきなりリモートで使うと壁にぶつかることが個人的には多いので、まずはlocal(CPU)で試したのちにリモートで行いました。公式サイトそのままの内容です。
import jax
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
# Run the operations to be profiled
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
ここでjaxだけinstallされていれば十分かと思いましたが、tensorflowがないと以下のwarningが表示されますが、実行自体は問題なくできます。
E external/xla/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace
E external/xla/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace
以下のコマンドでtensorflowをinstallすると上記のwarningが消え、無事リンクが表示されます
$ pip instal tensorflow
その結果エラーは消え、無事時間計測を行うことができました。
$ python test_perfetto.py
Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz
Tensorboardの場合は追加でpip install tensorboard-plugin-profile
することで同様に公式サイトのコードを実行することができました。
続いてリモートで
tensorflowが必須となると、gpuの場合はcudaのversionとの整合性の調整などで問題が発生しやすいので、ひとまずperffetoで行うことにしました。しっかりGPUとCPUに分かれて使用時間とメモリ消費が計測されています。
リモートで行う場合には、ssh接続の際にリモートとローカルのportを接続する必要があります。
$ ssh -L 9001:127.0.0.1:9001 <user>@<host>
先程のtest_perfetto.pyと同じスクリプトを実行することリンクが表示され、以下の画面を得ることができました。
わからなかったこと
perfettoはinteractiveにしか使えない、と公式サイトにあり、実際バッチジョブの形式では上手く実行できませんでした。今回は研究室所有の、interactiveジョブが可能なサーバーで行ったので問題ありませんでしたが、バッチジョブでも行えるとより便利だと思いました。今のところやり方はわかりません。
まとめ
pythonのGPU上の実行時間計測ツールとしてperfettoとtensorboardを検証しました。tensorboardについてはjax, tensorflowなどのversionの整合性をとるのが難しく断念し、perfettoを選択しました。まだ時間が計測される仕組みについては詳しく理解できていないので、より詳細な分析のために勉強したいです。
おまけ
localの仮想環境のyamlファイル
name: tensorflow2.14.0
channels:
- defaults
dependencies:
- bzip2=1.0.8=h1de35cc_0
- ca-certificates=2023.08.22=hecd8cb5_0
- libffi=3.4.4=hecd8cb5_0
- ncurses=6.4=hcec6c5f_0
- openssl=1.1.1w=hca72f7f_0
- python=3.11.0=h1fd4e5f_3
- readline=8.2=hca72f7f_0
- setuptools=68.0.0=py311hecd8cb5_0
- sqlite=3.41.2=h6c40b1e_0
- tk=8.6.12=h5d9f67b_0
- tzdata=2023c=h04d1e81_0
- wheel=0.41.2=py311hecd8cb5_0
- xz=5.4.2=h6c40b1e_0
- zlib=1.2.13=h4dc903c_0
- pip:
- absl-py==2.0.0
- anyio==4.0.0
- appnope==0.1.3
- argon2-cffi==23.1.0
- argon2-cffi-bindings==21.2.0
- arrow==1.3.0
- asttokens==2.4.1
- astunparse==1.6.3
- async-lru==2.0.4
- attrs==23.1.0
- babel==2.13.1
- beautifulsoup4==4.12.2
- bleach==6.1.0
- cachetools==5.3.2
- certifi==2023.7.22
- cffi==1.16.0
- charset-normalizer==3.3.2
- comm==0.1.4
- debugpy==1.8.0
- decorator==5.1.1
- defusedxml==0.7.1
- executing==2.0.1
- fastjsonschema==2.18.1
- flatbuffers==23.5.26
- fqdn==1.5.1
- gast==0.5.4
- google-auth==2.23.4
- google-auth-oauthlib==1.0.0
- google-pasta==0.2.0
- grpcio==1.59.2
- gviz-api==1.10.0
- h5py==3.10.0
- idna==3.4
- ipykernel==6.26.0
- ipython==8.17.2
- isoduration==20.11.0
- jax==0.4.19
- jaxlib==0.4.19
- jedi==0.19.1
- jinja2==3.1.2
- json5==0.9.14
- jsonpointer==2.4
- jsonschema==4.19.2
- jsonschema-specifications==2023.7.1
- jupyter-client==8.5.0
- jupyter-core==5.5.0
- jupyter-events==0.8.0
- jupyter-lsp==2.2.0
- jupyter-server==2.9.1
- jupyter-server-terminals==0.4.4
- jupyterlab==4.0.7
- jupyterlab-pygments==0.2.2
- jupyterlab-server==2.25.0
- keras==2.14.0
- libclang==16.0.6
- markdown==3.5.1
- markupsafe==2.1.3
- matplotlib-inline==0.1.6
- mistune==3.0.2
- ml-dtypes==0.2.0
- nbclient==0.8.0
- nbconvert==7.10.0
- nbformat==5.9.2
- nest-asyncio==1.5.8
- notebook-shim==0.2.3
- numpy==1.26.1
- oauthlib==3.2.2
- opt-einsum==3.3.0
- overrides==7.4.0
- packaging==23.2
- pandocfilters==1.5.0
- parso==0.8.3
- pexpect==4.8.0
- pip==23.3.1
- platformdirs==3.11.0
- prometheus-client==0.18.0
- prompt-toolkit==3.0.39
- protobuf==4.25.0
- psutil==5.9.6
- ptyprocess==0.7.0
- pure-eval==0.2.2
- pyasn1==0.5.0
- pyasn1-modules==0.3.0
- pycparser==2.21
- pygments==2.16.1
- python-dateutil==2.8.2
- python-json-logger==2.0.7
- pyyaml==6.0.1
- pyzmq==25.1.1
- referencing==0.30.2
- requests==2.31.0
- requests-oauthlib==1.3.1
- rfc3339-validator==0.1.4
- rfc3986-validator==0.1.1
- rpds-py==0.10.6
- rsa==4.9
- scipy==1.11.3
- send2trash==1.8.2
- six==1.16.0
- sniffio==1.3.0
- soupsieve==2.5
- stack-data==0.6.3
- tensorboard==2.14.1
- tensorboard-data-server==0.7.2
- tensorboard-plugin-profile==2.14.0
- tensorflow==2.14.0
- tensorflow-estimator==2.14.0
- tensorflow-io-gcs-filesystem==0.34.0
- termcolor==2.3.0
- terminado==0.17.1
- tinycss2==1.2.1
- tornado==6.3.3
- traitlets==5.13.0
- types-python-dateutil==2.8.19.14
- typing-extensions==4.8.0
- uri-template==1.3.0
- urllib3==2.0.7
- wcwidth==0.2.9
- webcolors==1.13
- webencodings==0.5.1
- websocket-client==1.6.4
- werkzeug==3.0.1
- wrapt==1.14.1