2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

OSごとにCPU版、GPU版のPyTorchをそれぞれインストールするようにPoetryで管理する

Last updated at Posted at 2022-12-30

目的

M1 Mac(GPU未対応)でPyTorchのコードを書き、Linuxの計算機サーバのGPU(CUDA)上で動かしたい。
PyTorch公式で使われているpipではなく、PoetryでPyTorchを管理したい。
ただし、poetry lockでwhlファイルのダウンロードが発生しないようにしたい。

結論

  • torchを次のように指定
    • MacのCPU上、およびLinuxのCUDA11.6上でtorch1.13以上を動かしたいときの例
pyproject.toml
(前略)
[tool.poetry.dependencies]
(中略)
torch = [
    { version = ">=1.13+cpu", markers = "sys_platform == 'darwin'" },
    { version = ">=1.13+cu116", markers = "sys_platform == 'linux'" },
]
(後略)
  • 開発機(Mac)のターミナルで以下を入力
$ poetry lock
$ poetry export -o poetry-requirements.txt
  • poetry-requirements.txtをgitなどで同期し、Linuxの適切な仮想環境下で以下を入力(pipの場合)
$ pip install -r poetry-requirements.txt

ポイント

  • 環境ごとにOSを指定するには {..., markers = "sys_platform == 'OS_NAME'"}と記入(Poetry日本語ドキュメントPEP508)
    • {..., platform = "OS_NAME"}もサポートされている(StackOverflow, Issue)
      • 何故かドキュメントに情報がない(非推奨?)
    • OS_NAMEはここから選ぶ
      • Windows: win32
      • MacOS: darwin
      • Linux: linux
  • CPU/GPU指定について(PyTorch公式)
    • CPUを指定する場合は{ version = ">=VERSION+cpu", ... }
    • CUDA11.6を指定する場合は{ version = ">=VERSION+cu116", ... }

以下はこのメモ書き投稿に至った経緯です。

背景: pip管理からpoetry管理への移行のモチベーション

これまでpipでインストールを行なってきたが、インストールやバージョンアップ時に依存関係で面倒がよく生じていた:

  • tensorboardとprotobuf, grpcio, google-auth-oauthlibの関係
  • numbaがnumpy最新版に非対応など

pip list -oでバージョンアップ可能なパッケージが表示されるが実際には依存関係の問題でアップグレードできず、何となく気持ち悪かったので、自動的に依存関係を管理してくれるpoetryに移行することにした。

先行手法1: [[tool.poetry.source]]にレポジトリURLを指定(参考)

PyTorch公式のpipで推奨されている--extra-index-urlに近い方法。
しかし次のような問題点がある。

  • Authorization errorが大量に出る(その後、PyPIにフォールバックされて正常にインストールされる)
    • 何が起こっているかはこちらで詳しく書かれています。
  • poetry lockに恐ろしく時間がかかる
    • 同じくこちらに詳しく書かれています。
      • ただ最新版をインストールしてほしいだけなのに……

先行手法2: 直接whlファイルのURLを指定する(参考)

torchを次のように指定する:

pyproject.toml
torch = [
    { version = ">=1.13", markers = "sys_platform == 'darwin'" },
    { url = "https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp311-cp311-linux_x86_64.whl", markers = "sys_platform == 'linux'" },
]

しばらくはこちらで運用していたが、次の点が少し不満だった:

  • pytorchバージョンが固定されたwhlファイルのURLを指定する必要があるため、pytorchのバージョンアップに自動的に対応できない
  • whlファイル1つとはいえ約1.8GBあり、poetry lockに時間がかかる
    • 200Mbpsの環境でも2分程度

今回の方法: バージョン名でCPUやCUDAを指定する

公式PyTorchを見ると、CPUやCUDAの指定の別解として"torch==1.12.1+cu116"のように指定する方法が使われており、この記法を適用してみたところうまくいった。

pyproject.toml(再掲)
[tool.poetry.dependencies]
torch = [
    { version = ">=1.13+cpu", markers = "sys_platform == 'darwin'" },
    { version = ">=1.13+cu116", markers = "sys_platform == 'linux'" },
]

これによりpoetry lock時のwhlファイルのダウンロードを回避できただけでなく、torchの新バージョンのリリースで自動的にバージョンアップされるようになった。

終わりに

情報の誤り、改善点や補足情報、他の良い方法、同一手法の存在、その他何かありましたら遠慮なくコメントしていただけますと幸いです。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?