18
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

ChainerAdvent Calendar 2017

Day 24

ChainerでTrust Region Policy Optimization (TRPO)

Last updated at Posted at 2017-12-24

TL;DR

Chainer v3でHessian-vector productが実装できるようになったのでTRPOを実装しました。

out.gif

Trust Region Policy Optimization

Trust Region Policy Optimization (TRPO)はICML2015で発表された深層強化学習アルゴリズムです。2015年はDeepMindのDQNがNatureに掲載された年であり、Nature版DQNと同じくらい古いアルゴリズムといえますが、行動空間が連続値の場合の深層強化学習では未だにSOTAの1つです。

TRPOは次の制約付き最適化問題を解いて確率的方策$\pi_\theta$のパラメータを更新します。

\text{maximize}_\theta ~ L_{\theta_\text{old}}(\theta)\\
\text{subject to} ~ \bar{D}_\text{KL}(\theta_\text{old},\theta) \le \delta,\\
\text{where} ~ L_{\theta_\text{old}}(\theta) = \mathbb{E}_{s \sim \rho_{\theta_\text{old}}, a \sim \pi_{\theta_\text{old}}} \Big[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A_{\theta_\text{old}}(s,a) \Big],\\
\bar{D}_\text{KL}(\theta_\text{old},\theta) = \mathbb{E}_{s \sim \rho_{\theta_\text{old}}} [ D_\text{KL}(\pi_{\theta_{old}}(\cdot|s)||\pi_\theta(\cdot|s))].

ここで最大化している$L_{\theta_\text{old}}(\theta)$は、真に最大化したい値($\pi_\theta$に従った場合の期待収益)の$\theta \simeq \theta_\text{old}$における近似になっており、KLダイバージェンスの制約はその近似の信頼領域を表しています。詳しくは元論文や、自分のICML2015読み会発表資料などを参照してください。

上の制約付き最適化問題を具体的にどうやって解くかというと、まず$\theta \simeq \theta_\text{old}$において$L$、$\bar{D}_\text{KL}$をそれぞれ1次近似、2次近似します:

L_{\theta_\text{old}}(\theta) \approx g \cdot (\theta - \theta_\text{old}) + L_{\theta_\text{old}}(\theta_\text{old}),\\
\bar{D}_\text{KL}(\theta_\text{old},\theta) \approx \frac{1}{2} (\theta - \theta_\text{old})^T 
 A(\theta - \theta_\text{old}),\\
\text{where} ~ g = \nabla_\theta L_{\theta_\text{old}}(\theta)\Bigr|_{\theta=\theta_\text{old}},\\
 A_{ij} = \frac{\partial}{\partial \theta_i} \frac{\partial}{\partial \theta_j} \bar{D}_\text{KL}(\theta_\text{old},\theta)\Bigr|_{\theta=\theta_\text{old}}.

近似後の制約付き最適化問題を解き、解となる新しいパラメータ$\theta_\text{new}$を求めます:

\theta_\text{new} = \theta_\text{old} + \beta s,\\
\text{where} ~ s = A^{-1} g,\\
\beta = \sqrt{2 \delta / s^T A s}.

ここで求めた更新量$\beta s$はあくまで元の問題の近似解なので、以下を満たすかどうか実際に計算して確認し、どちらも満たすまで$\beta$を減らしていきます:

L_{\theta_\text{old}}(\theta_\text{new}) > L_{\theta_\text{old}}(\theta_\text{old}),\\
\bar{D}_\text{KL}(\theta_\text{old},\theta_\text{new}) < \delta.

こうして最終的に得られた更新量がTRPOの1イテレーションの更新量になります。

Hessian-vector productの実装

TRPOの更新量の計算のうち、$s = A^{-1} g$を求めるには、Conjugate Gradient法を使用します。そのためには、Hessian-vector product(ヘッセ行列とベクトルの積)を計算する必要があります。ヘッセ行列を求めてから直接その逆行列を求めてもいいのですが、パラメータ数が多くなると現実的ではなくなります。

Chainerではv2まで勾配の勾配を計算することができなかったので、Hessian-vector productの計算には有限差分法を用いるしかなかったのですが、v3からは勾配の勾配の計算がサポートされたのでそれを使って効率的に求めることができます。$y$の$\theta$についてのヘッセ行列とベクトル$u$の積は、$\nabla_\theta (u^T \nabla_\theta y)$として求めることができるので、これをそのままChainerで実装すると以下のようなコードになります。

import chainer
import chainer.functions as F
import numpy as np


def flatten_and_concat_variables(vs):
    # Variableをつなげて1つのベクトルにする関数
    return F.concat([F.flatten(v) for v in vs], axis=0)

# パラメータa, bからなるモデルを考える
a = chainer.Parameter(initializer=0, shape=2)
b = chainer.Parameter(initializer=0, shape=2)
params = [a, b]
size = a.size + b.size

# この式の、パラメータについてのヘッセ行列を考える
y = F.sum((2 * a - b) ** 2)

# yの勾配をつなげて1つのベクトルにする(後でさらに勾配を計算するのでenable_double_backprop=Trueが必要)
gy = chainer.grad([y], params, enable_double_backprop=True)
flat_gy = _flatten_and_concat_variables(gy)
print('gy:', flat_gy.data)

# ランダムなベクトルに対して、ヘッセ行列とベクトルの積を求めたい
vec = np.random.rand(size).astype(np.float32)
print('random vector', vec)

# ヘッセ行列を計算せずに、ヘッセ行列とベクトルの積だけ求める!
hvp = chainer.grad([F.sum(flat_gy * vec)], params)
flat_hvp = _flatten_and_concat_variables(hvp)
print('Hessian-vector product (smart)', flat_hvp.data)

def compute_hessian(flat_gy):
    # 愚直にヘッセ行列を計算する関数(計算の確認のみに使用)
    ggy = []
    for i in range(flat_gy.size):
        ggyi = chainer.grad([flat_gy[i]], params)
        flat_ggyi = _flatten_and_concat_variables(ggyi)
        ggy.append(flat_ggyi.data)
    return np.stack(ggy)

h = compute_hessian(flat_gy)
print('Hessian:', h)
print('Hessian-vector product (dumb)', h.dot(vec))


# 結果は一致する
np.testing.assert_allclose(h.dot(vec), flat_hvp.data)

出力は以下のようになり、Hessian-vector productが計算できているのが確認できます。

gy: [ 0.  0. -0. -0.]
random vector [ 0.94696099  0.29959747  0.1657497   0.05266208]
Hessian-vector product (smart) [ 6.91268921  2.18613148 -3.4563446  -1.09306574]
Hessian: [[ 8.  0. -4. -0.]
 [ 0.  8. -0. -4.]
 [-4.  0.  2. -0.]
 [ 0. -4. -0.  2.]]
Hessian-vector product (dumb) [ 6.91268921  2.18613148 -3.4563446  -1.09306574]

TRPOの実装

TRPOはChainerRLのエージェントとして実装しました(まだmasterにはマージされてません)。

https://github.com/muupan/chainerrl/blob/trpo/chainerrl/agents/trpo.py
https://github.com/muupan/chainerrl/blob/trpo/examples/gym/train_trpo_gym.py

試しにOpenAI GymのHumanoid-v1を学習させてみました。アドバンテージの推定のために方策だけでなく状態価値関数も別にニューラルネットで学習しており、設定は隠れ層のユニット数を400にしている以外は上のtrain_trpo_gym.pyと同一です。

400units_100M.png
学習曲線です。縦軸は100回の評価エピソードの平均収益です。

out.gif
元気よく走り回っています。

18
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
18
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?