1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

離散的な行動でも使える Soft Actor Critic

Last updated at Posted at 2021-07-07

概要

 土木の分野では、機械の制御(ダムの制御)や都市開発(交通×AI)などで、強化学習が使われ始めています。初めてニューラルネットワークを用いて強化学習を構築する場合、最初に構築するアルゴリズムは、やはり深層Q学習(DQN)だろうと思われます。

 DQNのようにQ関数だけで行動と評価を行うのではなく、Actor Critic のように方策とQ関数を分けて行動と評価を行うほうが、サンプル効率が良いのではないかと(なんとなく)考えて、調べてみると、Soft Actor Critic の離散バージョンがありました。サンプル効率が良いかどうかはわかりませんが(連続バージョンのSoft Actor Criticよりはサンプル効率が良いらしい)、ゲーム環境によっては、Rainbowより精度が良いとのこと。

 Soft Actor Criticと離散バージョンのSoft Actor Criticとの違いをかなり大雑把に述べると、ガウス分布かカテゴリカル分布かの違いです。今回は、Soft Actor Criticについて簡単に説明し、離散的なSoft Actor Criticについて紹介したいと思います。
 
 
 論文は、以下を引用しています。

 Soft Actor Criticについては、以下の論文を引用しています。

 私より詳しい説明とPyTorchでの実装は@ku2482さんが説明されています。

Maximum entropy の方法

 最初にMaximum entropy の方法について簡単に説明する。
 通常の報酬の期待値は、ポリシー(方策)を$\pi$とし、経験分布$\rho_{\pi}$(と言っていいのだろうか)を使い

\sum_t \mathbb{E}_{(s_t,a_t) \sim \rho_{\pi}  }\left[r(s_t,a_t) \right]

と書ける。$s_t,a_t,r(s_t,a_t) $は状態、行動、報酬である。この式にエントロピー$\mathcal{H}$を加え、目的関数を

J(\pi)=\sum_{t=0}^{T} \mathbb{E}_{(s_t,a_t) \sim \rho_{\pi}  }\left[r(s_t,a_t) + \alpha \mathcal{H}(\pi(\cdot | s_t) )\right]

とする。ハイパーパラメータ$\alpha$は温度である。
 最適な方策$\pi^*$は、目的関数が最大となるような$\pi$であり

\DeclareMathOperator*{\argmax}{arg\,max} %argmax
\pi^* =\argmax_{\pi} \sum_{t=0}^{T} \mathbb{E}_{(s_t,a_t) \sim \rho_{\pi}  }\left[r(s_t,a_t) + \alpha \mathcal{H}(\pi(\cdot | s_t) )\right]

となるように求める。

Energy-Based Model

 Energy-Basedの方策モデルについて説明する。
 Energy-Basedの方策モデルは、物理でおなじみのカノニカル分布の形をしており、エネルギー関数 $\mathcal{E}$ を使い

\pi(a_t| s_t)\propto \exp\left\{-\mathcal{E}(a_t,s_t) \right\}

と書ける。Q学習の場合は、エネルギー関数 $\mathcal{E}$ を

\mathcal{E}(a_t,s_t) = - \frac{1}{\alpha} Q_{\rm{soft}}(a_t,s_t)

とする。

Soft Actor Critic のニューラルネットワークのパラメータの更新

 Soft Actor Critic のニューラルネットワークのパラメータの更新について説明する。つまり、Actor Critic は2つのニューラルネットワーク、方策 $\pi_{\phi}$ とQ関数 $Q_{\theta}$ のパラメータを更新する必要がある。

 最初に状態価値関数 $V(s_t)$ について説明する。Soft ベルマン演算子 $\mathcal{T}^{\pi}$を使い、$k+1$ステップにおけるQ関数を

Q^{k+1} = \mathcal{T}^{\pi} Q^k

とする。そして、状態価値関数 $V(s_t)$ は以下のように定義し、

V(s_t) = \mathbb{E}_{a_{t} \sim \pi} \left[Q(s_t,a_t) -\alpha \log\pi(a_t|s_t) \right]

Soft ベルマン演算子 $\mathcal{T}^{\pi}$は、状態価値関数を用いてQ関数を以下のように変換する。

\mathcal{T}^{\pi} Q(s_t,a_t) = r(s_t,a_t) + \gamma \mathbb{E}_{s_{t+1} \sim p}\left[V(s_{t+1}) \right]

 Q関数の収束については、報酬をエントロピー付きで定義すれば、

r_{\pi}(s_t,a_t) \equiv r(s_t,a_t)+\gamma \mathbb{E}_{s_{t+1} \sim p}\left[\mathcal{H}(\pi(\cdot | s_t) ) \right]

Soft ベルマン演算子 $\mathcal{T}^{\pi}$によるQ関数の更新は、

Q(s_t,a_t) \longleftarrow  r_{\pi}(s_t,a_t)  + \gamma \mathbb{E}_{s_{t+1} \sim p ,a_{t+1} \sim \pi}\left[Q(s_{t+1},a_{t+1}) \right]

と書け、(Soft ではない)通常のベルマン演算子を用いたQ関数の更新となることから、Q関数の収束について証明される(らしい)。
 
 次に Q 関数$Q_{\theta}(s_t,a_t)$のパラメータ$\theta$の更新について説明する。
  Q 関数$Q_{\theta}(s_t,a_t)$のパラメータ$\theta$を決めるための目的関数は、推定値 $Q$ とベルマン演算子から更新された値 $\hat{Q}$ との誤差から計算される。

J_Q(\theta) = \mathbb{E}_{(s_t,a_t)\sim D}\left[\frac{1}{2}\left(Q_{\theta}(s_{t},a_{t})-\hat{Q}(s_{t},a_{t}) \right)^2 \right]

ベルマン演算子から更新された値 $\hat{Q}$ とは、

\hat{Q}(s_t,a_t) = r(s_t,a_t) + \gamma \mathbb{E}_{s_{t+1} \sim p}\left[V_{\bar{\theta}}(s_{t+1}) \right]

を指す。$(s_t,a_t)\sim D$は Replay Bufferからサンプルされた状態・行動の組である。

 状態価値関数 $V(s_t)$ には $Q_{\theta}(s,a)$ を使うが、学習で勾配を更新する際は、パラメータ$\theta$は固定する。つまり、強化学習を安定させるために用いるtarget network の値を使う。target network のパラメータには、上付きbar(例えば$\bar{\theta}$)をつける。

 $J_Q(\theta)$ の勾配を計算すると、期待値を計算する前の微分を$\hat{\nabla}$として

\hat{\nabla}_{\theta} J_Q(\theta) = \nabla_{\theta} Q_{\theta}(s_t,a_t) \left( Q_{\theta}(s_{t},a_{t}) - r(s_t,a_t) - \gamma V_{\bar{\theta}}(s_{t+1}) \right)

と計算され、勾配法で$\theta$を更新する。target network のパラメータ$\bar{\theta}$は、

\bar{\theta} \longleftarrow \tau \theta + (1-\tau) \bar{\theta}

で更新する。

 最後に方策$\pi_{\phi}(a_t|s_t)$のパラメータ$\phi$の更新について説明する。
 方策$\pi_{\phi}(a_t|s_t)$のパラメータ$\phi$を決めるための目的関数は、カルバック・ライブラー情報量(KL divergence)を用いて定義される。

J_{\pi}(\phi) = \mathbb{E}_{s_t\sim D} \left[D_{\rm{KL}}\left(\pi_{\phi}(\cdot |s_t) 
 \Biggr|\Biggr| \frac{\exp\{\frac{1}{\alpha}Q_{\theta}(s_t,\cdot) \} }{Z_{\theta}(s_t)} \right) \right]

KL divergence は、2つの確率分布の距離(厳密には違うが)を表し、方策 $\pi$ は、Energy-Basedの方策モデルの確率分布に近づくように更新される。
 ただし、確率分布は誤差逆伝播法が使えないので、reparameterization trickを使う。つまり、行動 $a_t$ はガウス分布からのノイズ $\epsilon_t$ を使い、

a_t = f_{\phi}(\epsilon_t,s_t)

 KL divergenceの定義から目的関数は、以下のように書く。

J_{\pi}(\phi) = \mathbb{E}_{s_t\sim D ,\epsilon_t \sim \mathcal{N}} \left[\alpha\log\pi_{\phi}(f_{\phi}(\epsilon_t,s_t) |s_t) -
  Q_{\theta}(s_t,f_{\phi}(\epsilon_t,s_t)) +\alpha\log Z_{\theta}(s_t)  \right]

$\log Z_{\theta}(s_t)$は、パラメータの更新に効いてこないのでプログラムでは省く。
 $J_{\pi}(\phi)$ の勾配を計算すると、以下となる。

\hat{\nabla}_{\phi} J_{\pi}(\phi)= \alpha\nabla_{\phi}\log\pi_{\phi}(a_t|s_t) + \left(\alpha \nabla_{a_t}\log\pi_{\phi}(a_t|s_t)- \nabla_{a_t}Q_{\theta}(s_t,a_t)  \right)\nabla_{\phi}f_{\phi}(\epsilon_t,s_t)

また、$\nabla_{\phi}$は、$\pi_{\phi}$とreparameterization trickで使った$f_{\phi}$に作用する。そして、勾配法で$\phi$を更新する。

正規化流

 複雑なモデルにおいて、潜在変数の事後分布は、単純なガウス分布で近似することは難しい。したがって、複数回、微分可能で可逆な関数 $f$ を使い変換する事で、複雑な分布からのサンプルを取得する方法を正規化流という。正規化流について説明する。
 関数 $f$ で変換前の値を $z$ 返還後を $\hat{z}$とする。 つまり、

\hat{z} = f(z)

さらに $f$ は可逆なので、

z = f^{-1}(\hat{z})

となる。$\hat{z}$で微分すると

\mathrm{d}z = \frac{\mathrm{d}f^{-1}}{\mathrm{d}\hat{z}} \mathrm{d}\hat{z}

となる。次に確率分布$q(z)$として座標変換を考えると

\int \mathrm{d}\hat{z} q(\hat{z})= \int \mathrm{d}z q(z)= \int  \mathrm{d}\hat{z}  q(z)\left| \mathrm{det}\left\{\frac{\mathrm{d}f^{-1}}{\mathrm{d}\hat{z}} \right\} \right|
  

つまり、(逆関数の定理を使うと?)

q(\hat{z})= q(z)\left| \mathrm{det}\left\{\frac{\mathrm{d}f^{-1}}{\mathrm{d}\hat{z}} \right\} \right| =  q(z)\left| \mathrm{det}\left\{\frac{\mathrm{d}f}{\mathrm{d}z} \right\}^{-1} \right| 
  

さらに、対数をとり

\log q(\hat{z})= \log q(z) - \log  \mathrm{det}\left\{\frac{\mathrm{d}f}{\mathrm{d}z} \right\}
  

が得られる。
 可逆で微分可能な関数として$f(u)=\tanh(u)$を採用し

\frac{\mathrm{d}}{\mathrm{d}u} \tanh(u) = 1- \tanh^2(u)
  

$\mu(u|s)$ をガウス分布として、方策 $\pi(a|s)$ を

\log \pi(a|s)= \log \mu(u|s) - \log  \mathrm{det}\left\{\frac{\mathrm{d}f(u)}{\mathrm{d}u} \right\}
  

と表せば、$\log \mathrm{det} A = \mathrm{Tr}\log A $を用いて、最終的に方策$\pi(a|s)$は以下のように表せる。

\log \pi(a|s)= \log \mu(u|s) - \sum_{i=1}^D \log\left(1- \tanh^2(u_i) \right)  

温度の調整

 最後にハイパーパラメータであった温度$\alpha (\geq0)$の調整について説明する。強化学習は、ハイパーパラメータの調整がとても難しいので、学習で調整してくれるのはうれしい。

 任意の$t$に対して条件

 \mathbb{E}_{(s_t,a_t)\sim\rho_{\pi}} \left[-\log\pi_t(a_t|s_t)   \right] \geq \mathcal{H}

のもと

\underset{\pi[0:T]}{\max} \mathbb{E}_{\rho_{\pi}}\left[\sum_{t=0}^T r(s_t,a_t )\right]

を求めることを考える(条件付最適化)。この式を展開してみると

\underset{\pi_0}{\max}\left\{\mathbb{E}[r(s_0,a_0)]+\underset{\pi_1}{\max}\left\{ \mathbb{E}[r(s_1,a_1)]+\ldots \left\{\underset{\pi_T}{\max}\mathbb{E}[r(s_T,a_T)] \right\}\ldots\right\} \right\}

の形となる。

まず、$t = T$ に注目すると、条件

 \mathbb{E}_{(s_T,a_T)\sim\rho_{\pi}} \left[-\log\pi_T(a_T|s_T)   \right] -\mathcal{H}  \geq 0 

から、$\alpha_T$ を双対変数として、最適な方策 $\pi_T^{\ast}$ は以下の式を満たすように求める(双対問題?)。

\underset{\pi_T}{\max}\mathbb{E}[r(s_T,a_T)] = \underset{\alpha_T\geq 0}{\min} \left\{\underset{\pi_T}{\max} \mathbb{E}_{(s_T,a_T)\sim\rho_{\pi}}\left[r(s_T,a_T)- \alpha_T \log\pi_T(a_T|s_T)\right] - \alpha_T \mathcal{H} \right\}

 温度$\alpha_T$におけるMaximum entropy 方策を $\pi^*_T(a_T|s_T:\alpha_T)$ とすれば、最適な温度$\alpha_T^{\ast}$は以下のように求める。

\DeclareMathOperator*{\argmin}{arg\,min} %argmin
\alpha_T^* =\argmin_{\alpha_T} \mathbb{E}_{s_T,a_T\sim \pi_T^*}\left[- \alpha_T \log\pi_T^*(a_T|s_T:\alpha_T) - \alpha_T \mathcal{H} \right]

次に、$t=T-1$に注目すると、$t=T$における最適な Q 関数は以下であることに注意すれば

Q_{T}^{\ast}(s_{T},a_{T}) = \mathbb{E}[r(s_{T},a_{T})] 

Q 関数は、$t=T$において最適化された温度$\alpha_T^{\ast}$と方策$\pi^*_T$の結果を使えば、以下のように書ける。

\begin{align}
Q_{T-1}^{\ast}(s_{T-1},a_{T-1}) &= \mathbb{E}[r(s_{T-1},a_{T-1})] + \mathbb{E}\left[ Q_{T-1}^{\ast}(s_{T},a_{T}) - \alpha_T^{\ast} \log\pi_T^{\ast}(a_T|s_T) \right] \\
&= \mathbb{E}[r(s_{T-1},a_{T-1})] +\underset{\pi_T}{\max}\mathbb{E}[r(s_T,a_T)]+\alpha_T \mathcal{H}
\end{align}

また、最適な方策 $\pi_{T-1}^{\ast}$ は以下の式を満たすように求める。

\underset{\pi_{T-1}}{\max}Q_{T-1}^{\ast}(s_{T-1},a_{T-1}) = \underset{\alpha_{T-1}\geq 0}{\min} \left\{\underset{\pi_{T-1}}{\max} \mathbb{E}\left[Q_{T-1}^{\ast}(s_{T-1},a_{T-1})- \alpha_{T-1} \log\pi_{T-1}(a_{T-1}|s_{T-1})\right] - \alpha_{T-1} \mathcal{H}\right\}

上2つの式を組み合わせれば

\begin{align}
&\underset{\pi_{T-1}}{\max}Q_{T-1}^{\ast}(s_{T-1},a_{T-1})-\alpha_T^{\ast} \mathcal{H} = \underset{\pi_{T-1}}{\max}\left\{\mathbb{E}[r(s_{T-1},a_{T-1})] +\underset{\pi_T}{\max}\mathbb{E}[r(s_T,a_T)] \right\} \\
&=\underset{\alpha_{T-1}\geq 0}{\min} \left\{\underset{\pi_{T-1}}{\max} \mathbb{E}\left[Q_{T-1}^{\ast}(s_{T-1},a_{T-1})- \alpha_{T-1} \log\pi_{T-1}(a_{T-1}|s_{T-1})\right] - \alpha_{T-1} \mathcal{H}\right\}-\alpha_T^{\ast} \mathcal{H} 
\end{align}

が得られる。上式の1行目の右辺に注目すると、この操作を $t=0$まで行えば、

\underset{\pi[0:T]}{\max} \mathbb{E}_{\rho_{\pi}}\left[\sum_{t=0}^T r(s_t,a_t )\right]

が得られることが想像できる。したがって、最適な方策 $\pi_{t}^{\ast}$が求まり、条件を満たすように、最適な温度$\alpha_{t}^{\ast}$を決めるには、

\DeclareMathOperator*{\argmin}{arg\,min} %argmin
\alpha_{t}^* =\argmin_{\alpha_{t}} \mathbb{E}_{s_{t},a_{t}\sim \pi_{t}^*}\left[- \alpha_{t} \log\pi_{t}^*(a_{t}|s_{t}:\alpha_{t}) - \alpha_{t} \mathcal{H} \right]

となるように求めればよい。

 したがって、温度$\alpha$に関する目的関数は、以下のように構築し

J(\alpha) = \mathbb{E}_{a_{t} \sim \pi_t} \left[-\alpha \left(\log\pi(a_t|s_t)+  \mathcal{H}\right) \right] 

上式を最小となるように$\alpha$を勾配法で求めればよい。

離散的なSoft Actor Critic

 離散的なSoft Actor Criticは、Soft Actor Criticと相違点が5点ありそれを説明する。

(1). Q 関数の入出力
 Q関数は、$Q: S \times A \to \mathbb R$ から$Q: S \to\mathbb R^{|A|} $に変更。つまり、 $Q(s,a) \to Q(s)$ 。
Soft Actor Critic のQ 関数は状態と行動を入力していた。離散的なSoft Actor CriticのQ関数は状態のみを入力とする。

(2). 方策の出力
 方策は、$\pi: S \to \mathbb R^{2|A|}$ から$ S \to \mathbb [0,1]^{|A|} $に変更。
つまり、離散的なSoft Actor Criticは、Softmax関数を使い直接確率を出力させる。

(3). 状態価値関数
 Soft Actor Criticは、平均・分散を計算し、reparameterization trickを使い行動 $a$ をサンプリングしていた。

V(s_t) = \mathbb{E}_{a_{t} \sim \pi} \left[Q(s_t,a_t) -\alpha \log\pi(a_t|s_t) \right]

離散的なSoft Actor Criticは、方策の出力が直接確率になっているので、

V(s_t) = \pi^T(s_t) \left[Q(s_t) -\alpha \log\pi(s_t) \right]

と変更される。

(4). 温度調整のための目的関数
 行動 $a$ をサンプリングする必要がないので、(3)と同様に

J(\alpha) = \mathbb{E}_{a_{t} \sim \pi_t} \left[-\alpha \left(\log\pi(a_t|s_t)+  \mathcal{H}\right) \right] 

から

J(\alpha) =  \pi_t^T(s_t) \left[-\alpha \left(\log\pi(s_t)+  \mathcal{H}\right) \right] 

と変更される。

(5). 方策のパラメータを決定するための目的関数
 行動 $a$ をサンプリングする必要がないので、(3)、(4)と同様に

J_{\pi}(\phi) = \mathbb{E}_{s_t\sim D ,\epsilon_t \sim \mathcal{N}} \left[\alpha\log\pi_{\phi}(f_{\phi}(\epsilon_t,s_t) |s_t) -
  Q_{\theta}(s_t,f_{\phi}(\epsilon_t,s_t))  \right]

から

J_{\pi}(\phi) = \mathbb{E}_{s_t\sim D } \left[ \pi_t^T(s_t) \left(\alpha\log\pi_{\phi}(s_t) -
  Q_{\theta}(s_t) \right) \right]

と変更される。

最後に

まだ学習をさせてないので、結果は次回?

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?