TL;DR
Chainer v3でHessian-vector productが実装できるようになったのでTRPOを実装しました。
Trust Region Policy Optimization
Trust Region Policy Optimization (TRPO)はICML2015で発表された深層強化学習アルゴリズムです。2015年はDeepMindのDQNがNatureに掲載された年であり、Nature版DQNと同じくらい古いアルゴリズムといえますが、行動空間が連続値の場合の深層強化学習では未だにSOTAの1つです。
TRPOは次の制約付き最適化問題を解いて確率的方策$\pi_\theta$のパラメータを更新します。
\text{maximize}_\theta ~ L_{\theta_\text{old}}(\theta)\\
\text{subject to} ~ \bar{D}_\text{KL}(\theta_\text{old},\theta) \le \delta,\\
\text{where} ~ L_{\theta_\text{old}}(\theta) = \mathbb{E}_{s \sim \rho_{\theta_\text{old}}, a \sim \pi_{\theta_\text{old}}} \Big[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A_{\theta_\text{old}}(s,a) \Big],\\
\bar{D}_\text{KL}(\theta_\text{old},\theta) = \mathbb{E}_{s \sim \rho_{\theta_\text{old}}} [ D_\text{KL}(\pi_{\theta_{old}}(\cdot|s)||\pi_\theta(\cdot|s))].
ここで最大化している$L_{\theta_\text{old}}(\theta)$は、真に最大化したい値($\pi_\theta$に従った場合の期待収益)の$\theta \simeq \theta_\text{old}$における近似になっており、KLダイバージェンスの制約はその近似の信頼領域を表しています。詳しくは元論文や、自分のICML2015読み会発表資料などを参照してください。
上の制約付き最適化問題を具体的にどうやって解くかというと、まず$\theta \simeq \theta_\text{old}$において$L$、$\bar{D}_\text{KL}$をそれぞれ1次近似、2次近似します:
L_{\theta_\text{old}}(\theta) \approx g \cdot (\theta - \theta_\text{old}) + L_{\theta_\text{old}}(\theta_\text{old}),\\
\bar{D}_\text{KL}(\theta_\text{old},\theta) \approx \frac{1}{2} (\theta - \theta_\text{old})^T
A(\theta - \theta_\text{old}),\\
\text{where} ~ g = \nabla_\theta L_{\theta_\text{old}}(\theta)\Bigr|_{\theta=\theta_\text{old}},\\
A_{ij} = \frac{\partial}{\partial \theta_i} \frac{\partial}{\partial \theta_j} \bar{D}_\text{KL}(\theta_\text{old},\theta)\Bigr|_{\theta=\theta_\text{old}}.
近似後の制約付き最適化問題を解き、解となる新しいパラメータ$\theta_\text{new}$を求めます:
\theta_\text{new} = \theta_\text{old} + \beta s,\\
\text{where} ~ s = A^{-1} g,\\
\beta = \sqrt{2 \delta / s^T A s}.
ここで求めた更新量$\beta s$はあくまで元の問題の近似解なので、以下を満たすかどうか実際に計算して確認し、どちらも満たすまで$\beta$を減らしていきます:
L_{\theta_\text{old}}(\theta_\text{new}) > L_{\theta_\text{old}}(\theta_\text{old}),\\
\bar{D}_\text{KL}(\theta_\text{old},\theta_\text{new}) < \delta.
こうして最終的に得られた更新量がTRPOの1イテレーションの更新量になります。
Hessian-vector productの実装
TRPOの更新量の計算のうち、$s = A^{-1} g$を求めるには、Conjugate Gradient法を使用します。そのためには、Hessian-vector product(ヘッセ行列とベクトルの積)を計算する必要があります。ヘッセ行列を求めてから直接その逆行列を求めてもいいのですが、パラメータ数が多くなると現実的ではなくなります。
Chainerではv2まで勾配の勾配を計算することができなかったので、Hessian-vector productの計算には有限差分法を用いるしかなかったのですが、v3からは勾配の勾配の計算がサポートされたのでそれを使って効率的に求めることができます。$y$の$\theta$についてのヘッセ行列とベクトル$u$の積は、$\nabla_\theta (u^T \nabla_\theta y)$として求めることができるので、これをそのままChainerで実装すると以下のようなコードになります。
import chainer
import chainer.functions as F
import numpy as np
def flatten_and_concat_variables(vs):
# Variableをつなげて1つのベクトルにする関数
return F.concat([F.flatten(v) for v in vs], axis=0)
# パラメータa, bからなるモデルを考える
a = chainer.Parameter(initializer=0, shape=2)
b = chainer.Parameter(initializer=0, shape=2)
params = [a, b]
size = a.size + b.size
# この式の、パラメータについてのヘッセ行列を考える
y = F.sum((2 * a - b) ** 2)
# yの勾配をつなげて1つのベクトルにする(後でさらに勾配を計算するのでenable_double_backprop=Trueが必要)
gy = chainer.grad([y], params, enable_double_backprop=True)
flat_gy = _flatten_and_concat_variables(gy)
print('gy:', flat_gy.data)
# ランダムなベクトルに対して、ヘッセ行列とベクトルの積を求めたい
vec = np.random.rand(size).astype(np.float32)
print('random vector', vec)
# ヘッセ行列を計算せずに、ヘッセ行列とベクトルの積だけ求める!
hvp = chainer.grad([F.sum(flat_gy * vec)], params)
flat_hvp = _flatten_and_concat_variables(hvp)
print('Hessian-vector product (smart)', flat_hvp.data)
def compute_hessian(flat_gy):
# 愚直にヘッセ行列を計算する関数(計算の確認のみに使用)
ggy = []
for i in range(flat_gy.size):
ggyi = chainer.grad([flat_gy[i]], params)
flat_ggyi = _flatten_and_concat_variables(ggyi)
ggy.append(flat_ggyi.data)
return np.stack(ggy)
h = compute_hessian(flat_gy)
print('Hessian:', h)
print('Hessian-vector product (dumb)', h.dot(vec))
# 結果は一致する
np.testing.assert_allclose(h.dot(vec), flat_hvp.data)
出力は以下のようになり、Hessian-vector productが計算できているのが確認できます。
gy: [ 0. 0. -0. -0.]
random vector [ 0.94696099 0.29959747 0.1657497 0.05266208]
Hessian-vector product (smart) [ 6.91268921 2.18613148 -3.4563446 -1.09306574]
Hessian: [[ 8. 0. -4. -0.]
[ 0. 8. -0. -4.]
[-4. 0. 2. -0.]
[ 0. -4. -0. 2.]]
Hessian-vector product (dumb) [ 6.91268921 2.18613148 -3.4563446 -1.09306574]
TRPOの実装
TRPOはChainerRLのエージェントとして実装しました(まだmasterにはマージされてません)。
https://github.com/muupan/chainerrl/blob/trpo/chainerrl/agents/trpo.py
https://github.com/muupan/chainerrl/blob/trpo/examples/gym/train_trpo_gym.py
試しにOpenAI GymのHumanoid-v1を学習させてみました。アドバンテージの推定のために方策だけでなく状態価値関数も別にニューラルネットで学習しており、設定は隠れ層のユニット数を400にしている以外は上のtrain_trpo_gym.py
と同一です。