今回はTD-MPC2という最新のモデルベース強化学習モデルをローカルで動かす方法について説明します。
TD-MPC2がどのようなものであるかについては,論文を参照してください。
また,コードの使い方などについてはGitHubのReadMeを参照するといいでしょう。
- 想定している環境
OS | Windows11 Pro |
---|---|
GPU | RTX 3090 |
NvidiaドライバはWindows上に既にインストールされているとします。
# nvidia-smi
Mon Sep 30 14:11:02 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.02 Driver Version: 560.94 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3090 On | 00000000:01:00.0 On | N/A |
| 0% 53C P8 37W / 370W | 1563MiB / 24576MiB | 2% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
また,wsl上にUbuntu22.04が既にインストールされている事を前提とします。
セットアップ
パッケージのインストールとリポジトリのダウンロード
wslを開いて,ターミナルで以下のコマンドを順に実行していきます。
apt-get -y update && apt-get install -y --no-install-recommends build-essential git nano rsync vim tree curl wget unzip htop tmux xvfb patchelf ca-certificates bash-completion libjpeg-dev libpng-dev ffmpeg cmake swig libssl-dev libcurl4-openssl-dev libopenmpi-dev python3-dev zlib1g-dev qtbase5-dev qtdeclarative5-dev libglib2.0-0 libglu1-mesa-dev libgl1-mesa-dev libvulkan1 libgl1-mesa-glx libosmesa6 libosmesa6-dev libglew-dev mesa-utils && apt-get clean && apt-get autoremove -y && rm -rf /var/lib/apt/lists/* && mkdir /root/.ssh
cd && mkdir SourceCode && cd SourceCode
git clone https://github.com/nicklashansen/tdmpc2.git
Anacondaのインストール
次にAnacondaをインストールします。
こちらのサイトの上の方にあるFree Download
をクリックするとダウンロードページが開きます。
自分の環境に合ったLinux用のインストーラーのダウンロードリンクをコピーしてください。
cd
mkdir Downloads && cd Downloads
wget コピーしたリンク
chmod +x Anaconda3-2024.06-1-Linux-x86_64.sh
./Anaconda3-2024.06-1-Linux-x86_64.sh
Anaconda3-2024.06-1-Linux-x86_64.sh
は名前がタイミングによって異なる場合があるので,確認してから実行してください。
インストール出来たら,wslを一度開きなおします。
Pythonライブラリのインストール
Gitリポジトリのdocker
ディレクトリ内にenvironment.yaml
が用意されていますが,そのまま実行するとうまくいかないので,編集しておきます。
13行目 pytorch=2.4.1
14行目 torchvision=0.19.1
編集出来たらconda環境を作りましょう。
cd ~/SourceCode/tdmpc2
conda env create -f docker/environment.yaml
実行するとtdmpc2
という環境が作られるはずです。
pipでgymとmeta-worldライブラリを導入します。
conda activate tdmpc2
pip install gym==0.21.0
pip install git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
cudaのインストール
次にwsl上にcudaをインストールします。特に難しくないので説明は割愛します。
DMCを動かすだけなら必要ありませんが,Meta-Worldを動かしたかったので,Mujocoをインストールします。ついでにWandbのAPIキーも登録しておくといいでしょう。
cd
sudo nano .bashrc
# 末尾に追加
export MUJOCO_GL=egl
export LD_LIBRARY_PATH=$HOME/.mujoco/mujoco210/bin:/usr/lib/nvidia:$LD_LIBRARY_PATH
export WANDB_API_KEY=WandBのAPIキー
# 保存して閉じる
source .bashrc
mkdir -p $HOME/.mujoco
wget https://www.tdmpc2.com/files/mjkey.txt
wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz
tar -xzf mujoco210-linux-x86_64.tar.gz
rm mujoco210-linux-x86_64.tar.gz
mv mujoco210 $HOME/.mujoco/mujoco210
mv mjkey.txt $HOME/.mujoco/mjkey.txt
以下のコマンドをエラーなく実行できればmujocoをインストールできています。
python -c "import mujoco_py"
テスト
evaluation
折角なので動作確認も兼ねて,公式が配布している重みを使ってMeta-Worldタスクを実行してみます。
こちらのページから好きな重みのダウンロードリンクをコピーします。今回はmw-door-open-1.pt
を使うことにします。
cd ~/SourceCode/tdmpc2
mkdir weights && cd weights
wget https://huggingface.co/nicklashansen/tdmpc2/resolve/main/metaworld/mw-door-open-1.pt
リンクの部分は適宜差し替えてください。リンクの末尾に?download=true
が付いている場合は削除してください。
cd ../tdmpc2
python evaluate.py task=mw-door-open checkpoint=/root/SourceCode/tdmpc2/weights/mw-door-open-1.pt
動かせました。
training
ついでに学習もできることを確認しておきます。
tdmpc2/tdmpc2/config.yaml
の'disable_wandb'をfalseにすればwandbでログを見る事ができます。
なお,config.yaml
のパラメータは実行時に引数として渡すことができます。
WandBを使う場合は,wandb_entity
にユーザー名,wandb_project
に適当なプロジェクト名を指定してください。
# python train.py task=mw-pick-place-wall
libEGL warning: MESA-LOADER: failed to open vgem: /usr/lib/dri/vgem_dri.so: cannot open shared object file: No such file or directory (search paths /usr/lib/x86_64-linux-gnu/dri:\$${ORIGIN}/dri:/usr/lib/dri, suffix _dri)
libEGL warning: NEEDS EXTENSION: falling back to kms_swrast
Work dir: /root/SourceCode/tdmpc2/tdmpc2/logs/mw-pick-place-wall/1/default
-------------------------------------------
Task: Mw Pick Place Wall
Steps: 10,000,000
Observations: [39]
Actions: 4
Experiment: default
-------------------------------------------
Wandb disabled.
Architecture: WorldModel(
(_encoder): ModuleDict(
(state): Sequential(
(0): NormedLinear(in_features=39, out_features=256, bias=True, act=Mish)
(1): NormedLinear(in_features=256, out_features=512, bias=True, act=SimNorm)
)
)
(_dynamics): Sequential(
(0): NormedLinear(in_features=516, out_features=512, bias=True, act=Mish)
(1): NormedLinear(in_features=512, out_features=512, bias=True, act=Mish)
(2): NormedLinear(in_features=512, out_features=512, bias=True, act=SimNorm)
)