9
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Trust Region Policy Optimization (TRPO) 実装のためのTips

Last updated at Posted at 2020-07-24

はじめに

TRPOを実装する際に詰まったところがあったのでそのメモです。あくまでメモですが、実装したいなーと思っている方の手助けとなればと思います。

Step1: TRPOの実装はいろいろ種類があることを知ろう

TRPOのオリジナルは2015年に提案されたこちら

Trust Region Policy Optimization

GAE (Generalized Advantage Estimation)の論文内で言及されているもの

High-Dimensional Continuous Control Using Generalized Advantage Estimation

PPO内で言及されているもの

Proximal Policy Optimization Algorithms

RL that Mattersで言及されているもの

Deep Reinforcement Learning that Matters

です。ただ、実装例を見ていくと、Chainerrl、著者の実装、Open AI gym baselineの実装はどれもGAEをアドバンテージの推定に使ってそうです。ネットワーク構造やパラメータは各実装でばらつきがあり、パラメータに敏感すぎる深層強化学習手法においてこれだけいろいろとあると、ちょっと弱ってしまいますが、GAEを使う実装が普及していることを知っておいた方が良さそうです。何もしらずにTRPOのオリジナルだけ読んで、実装しようとすると、参考になる実装例があまりないかと思います。chainerrlはRL that mattersのパラメータで再現されていますのでそれを信頼するのが一番良さそうです。

Step2: Appendix Cをちゃんと読もう、そのほかの説明記事を参考にしよう

深層強化学習論文のあるあるですが、TRPOの論文はAppendixC(特にAppendixC.1)を読まないと実装ができないようになっています。重要な理論パートを全部飛ばすとTRPOがやっているのは

  1. 下記式のold方策と更新後の方策のkl divergenceのパラメータに関するヘッシアンと近似収益のパラメータに関する勾配を算出
  2. 算出したヘッシアンと勾配を使い、共役勾配法でパラメータの更新値を算出
  3. そのパラメータ更新値が制約を満たすか、性能向上するかを確認しながら、直線探索

の3手順です。オリジナルの論文内でいうと重要な式は以下です。(muupanさんの記事より引用)

\text{maximize}\theta ~ L{\theta_\text{old}}(\theta)\
\text{subject to} ~ \bar{D}\text{KL}(\theta\text{old},\theta) \le \delta,\
\text{where} ~ L_{\theta_\text{old}}(\theta) = \mathbb{E}{s \sim \rho{\theta_\text{old}}, a \sim \pi_{\theta_\text{old}}} \Big[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A_{\theta_\text{old}}(s,a) \Big],\
\bar{D}\text{KL}(\theta\text{old},\theta) = \mathbb{E}{s \sim \rho{\theta_\text{old}}} [ D_\text{KL}(\pi_{\theta_{old}}(\cdot|s)||\pi_\theta(\cdot|s))].


もっと詳しい内容は、muupanさんの記事やスライドYu Ishiharaさんの記事が分かりやすいです。

https://qiita.com/muupan/items/34f3c89c100d414e9afc

https://www.slideshare.net/mooopan/trust-region-policy-optimization

https://speakerdeck.com/yuishihara/model-based-reinforcement-learning-for-atari

Open AIの記事も英語が読める方にとっては良いかと思います。

https://spinningup.openai.com/en/latest/algorithms/trpo.html

# Step3: Hessianを自動微分ライブラリで計算できるようにしよう、パラメータの持つ勾配の抜き出し方などを把握しよう

TRPOは語弊を恐れずに言うならパラメータの勾配に関する扱い方が普通の強化学習手法と異なります。通常のPytorchやtensorflow、nnablaなどに実装されているSolver(Optimizer)を用いる場合、パラメータやネットワークをSolverに渡してあげてこちらは、lossを計算し、backwardして、solverをupdateすればSolverが勝手にパラメータを更新してくれるので良いですが、TRPOの場合、各パラメータの勾配やヘッシアンを実際に抜き出し、こちらで更新量を計算する必要があります。

そのため各パラメータが保持している、勾配を抜き出すなどの処理を行う必要があります。ヘッシアンを計算するためには、Pytorchであれば、grad.autograd.grad、nnablaであればgradなどを使ってパラメータの勾配を計算し、その勾配を使ってさらに、backwardしてまたパラメータの勾配を算出することが必要です。そのためにも、grad関数(通常、引数はloss、パラメータで構成されているはず)の使い方を理解しましょう。

# Step4: Hessianとxの掛け算の結果を高速に計算できることを知って、実装しよう

パラメータに関するヘッシアンをナイーブに計算するとなると、時間もメモリも消費します(ネットワークのパラメータ数×ネットワークのパラメータ数)。そこで、それを工夫するための手法というかテクニックがあります。

https://www.telesens.co/2018/06/09/efficiently-computing-the-fisher-vector-product-in-trpo/

の記事のみていただくと分かりますが、テクニック次第で効率よく計算することができます。なお、この記事の内容は一応Appendix C.1に書かれていますが、分かりにくいので上記の記事がとっても参考になりました。

# Step5: 共役勾配法を理解しよう

TRPOはパラメータの更新量をもとめる際に共役勾配法を使います。なので、まず共役勾配法の実装が必要になります。なお、共役勾配法は、以下式を解くために使われます。$\boldsymbol{s}$は更新量を計算するために必要なパラメータのベクトルです。(なので、shape=パラメータ数です。よって、$\boldsymbol{g}$もそうです。)

```math
\boldsymbol{s} = \boldsymbol{H}^{-1}\boldsymbol{g}

ナイーブに計算するなら、ヘッシアンの算出+ヘッシアンの逆行列が必要になりますが、共役勾配法を使うことで、$\boldsymbol{H}\boldsymbol{s}$の計算できる関数と、$\boldsymbol{g}$(近似の期待収益のパラメータに関する勾配)があれば良く、上記の式を計算できます。よってナイーブなヘッシアンの計算がいらないというメリットがあります。Appendix中に記述がありますが、共役勾配法で10回反復したものを近似解とすることで、計算量や時間、メモリなどを節約しています。(RL that mattersでは20回になっています。)共役勾配法については、Wikipediaの解説が分かりやすいです。共役勾配法は最適化のための手法ですが、$x = A^{-1}b$のような連立方程式を解法することに使えます。

実装例をみていくと、著者実装参考

def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10):
    p = b.copy()
    r = b.copy()
    x = np.zeros_like(b)
    rdotr = r.dot(r)

    fmtstr =  "%10i %10.3g %10.3g"
    titlestr =  "%10s %10s %10s"
    if verbose: print titlestr % ("iter", "residual norm", "soln norm")

    for i in xrange(cg_iters):
        if callback is not None:
            callback(x)
        if verbose: print fmtstr % (i, rdotr, np.linalg.norm(x))
        z = f_Ax(p)
        v = rdotr / p.dot(z)
        x += v*p
        r -= v*z
        newrdotr = r.dot(r)
        mu = newrdotr/rdotr
        p = r + mu*p

        rdotr = newrdotr
        if rdotr < residual_tol:
            break

    if callback is not None:
        callback(x)
    if verbose: print fmtstr % (i+1, rdotr, np.linalg.norm(x))  # pylint: disable=W0631
    return x

この共役勾配法の関数の引数になっているAxがヘッシアンとベクトルの積を計算する関数になります。(多くの実装でHessian(fisher)-vector-productと名前が付けられています。)

ヘッシアンとベクトルの積は実際の実装例を見ていただければわかると思いますが、多くが、vectorを受け取り、そのvectorを使ってパラメータの勾配との積を計算し、その勾配を求めることでパラメータのヘッシアンを算出しています。chainerrlの実装を参考

def _hessian_vector_product(flat_grads, params, vec):
    """Compute hessian vector product efficiently by backprop."""
    grads = chainer.grad([F.sum(flat_grads * vec)], params)
    assert all(grad is not None for grad in grads),\
        "The Hessian-vector product contains None."
    grads_data = [grad.array for grad in grads]
    return _flatten_and_concat_ndarrays(grads_data)

なお、実際にこの関数を使うときに、

def fisher_vector_product_func(vec):
    fvp = _hessian_vector_product(flat_kl_grads, policy_params, vec)
    return fvp + self.conjugate_gradient_damping * vec

というようにdampingが存在していますが、

の3pの6式に書かれているように、これは正則化項で計算を安定させるためのものです。

Step6: ここまでで必ずテストを書こう

ここまで理解できれば、様々な方の実装例を参考にしながらTRPOのコアな部分は実装ができるかと思います。しかし、要素が多いのでテストを書くことをお勧めします。各要素が比較的独立しているのでtrpoはテストしやすいかと思います。

  1. 共役勾配法が正しく実装できているのか確認、$x = A^{-1}b$を正しく計算できるか(numpyのinvやpinvを使えばすぐ計算できるのでそれと比較する。)
  2. ヘッシアンの計算が正しく理解できているのか、適当な関数($x^2 + y^3 + ...$)を用いて手で計算したヘッシアンと比較し、自分が正しくライブラリを扱えているのかテスト。不安な場合は、自分の使っているライブラリと異なるライブラリで計算して比較する。
  3. ヘッシアンとベクトルの積を計算する関数があっているのかをmuupanさんの記事のようにテストしてみる、または、実際の使われ方のように、共役勾配法を含めて$x = A^{-1}b$を本当に正しく計算できるのかテスト。不安な場合は、自分の使っているライブラリと異なるライブラリで計算して比較する。
  4. 計算グラフを一応表示できるなにかのツールを使って目視で確認する、shapeのテストを行う。

STEP7: GAEやモンテカルロ推定を実装しよう

本記事では触れませんが、TRPOはQ値またはアドバンテージの推定が必要なので、その値を推定するアルゴリズムを実装しましょう。こちらもテストすることをおすすめします。

最後に

学習が上手くいくことを祈りましょう。。。

9
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?