概要
土木の分野では、機械の制御(ダムの制御)や都市開発(交通×AI)などで、強化学習が使われ始めています。初めてニューラルネットワークを用いて強化学習を構築する場合、最初に構築するアルゴリズムは、やはり深層Q学習(DQN)だろうと思われます。
DQNのようにQ関数だけで行動と評価を行うのではなく、Actor Critic のように方策とQ関数を分けて行動と評価を行うほうが、サンプル効率が良いのではないかと(なんとなく)考えて、調べてみると、Soft Actor Critic の離散バージョンがありました。サンプル効率が良いかどうかはわかりませんが(連続バージョンのSoft Actor Criticよりはサンプル効率が良いらしい)、ゲーム環境によっては、Rainbowより精度が良いとのこと。
Soft Actor Criticと離散バージョンのSoft Actor Criticとの違いをかなり大雑把に述べると、ガウス分布かカテゴリカル分布かの違いです。今回は、Soft Actor Criticについて簡単に説明し、離散的なSoft Actor Criticについて紹介したいと思います。
論文は、以下を引用しています。
Soft Actor Criticについては、以下の論文を引用しています。
私より詳しい説明とPyTorchでの実装は@ku2482さんが説明されています。
Maximum entropy の方法
最初にMaximum entropy の方法について簡単に説明する。
通常の報酬の期待値は、ポリシー(方策)を$\pi$とし、経験分布$\rho_{\pi}$(と言っていいのだろうか)を使い
\sum_t \mathbb{E}_{(s_t,a_t) \sim \rho_{\pi} }\left[r(s_t,a_t) \right]
と書ける。$s_t,a_t,r(s_t,a_t) $は状態、行動、報酬である。この式にエントロピー$\mathcal{H}$を加え、目的関数を
J(\pi)=\sum_{t=0}^{T} \mathbb{E}_{(s_t,a_t) \sim \rho_{\pi} }\left[r(s_t,a_t) + \alpha \mathcal{H}(\pi(\cdot | s_t) )\right]
とする。ハイパーパラメータ$\alpha$は温度である。
最適な方策$\pi^*$は、目的関数が最大となるような$\pi$であり
\DeclareMathOperator*{\argmax}{arg\,max} %argmax
\pi^* =\argmax_{\pi} \sum_{t=0}^{T} \mathbb{E}_{(s_t,a_t) \sim \rho_{\pi} }\left[r(s_t,a_t) + \alpha \mathcal{H}(\pi(\cdot | s_t) )\right]
となるように求める。
Energy-Based Model
Energy-Basedの方策モデルについて説明する。
Energy-Basedの方策モデルは、物理でおなじみのカノニカル分布の形をしており、エネルギー関数 $\mathcal{E}$ を使い
\pi(a_t| s_t)\propto \exp\left\{-\mathcal{E}(a_t,s_t) \right\}
と書ける。Q学習の場合は、エネルギー関数 $\mathcal{E}$ を
\mathcal{E}(a_t,s_t) = - \frac{1}{\alpha} Q_{\rm{soft}}(a_t,s_t)
とする。
Soft Actor Critic のニューラルネットワークのパラメータの更新
Soft Actor Critic のニューラルネットワークのパラメータの更新について説明する。つまり、Actor Critic は2つのニューラルネットワーク、方策 $\pi_{\phi}$ とQ関数 $Q_{\theta}$ のパラメータを更新する必要がある。
最初に状態価値関数 $V(s_t)$ について説明する。Soft ベルマン演算子 $\mathcal{T}^{\pi}$を使い、$k+1$ステップにおけるQ関数を
Q^{k+1} = \mathcal{T}^{\pi} Q^k
とする。そして、状態価値関数 $V(s_t)$ は以下のように定義し、
V(s_t) = \mathbb{E}_{a_{t} \sim \pi} \left[Q(s_t,a_t) -\alpha \log\pi(a_t|s_t) \right]
Soft ベルマン演算子 $\mathcal{T}^{\pi}$は、状態価値関数を用いてQ関数を以下のように変換する。
\mathcal{T}^{\pi} Q(s_t,a_t) = r(s_t,a_t) + \gamma \mathbb{E}_{s_{t+1} \sim p}\left[V(s_{t+1}) \right]
Q関数の収束については、報酬をエントロピー付きで定義すれば、
r_{\pi}(s_t,a_t) \equiv r(s_t,a_t)+\gamma \mathbb{E}_{s_{t+1} \sim p}\left[\mathcal{H}(\pi(\cdot | s_t) ) \right]
Soft ベルマン演算子 $\mathcal{T}^{\pi}$によるQ関数の更新は、
Q(s_t,a_t) \longleftarrow r_{\pi}(s_t,a_t) + \gamma \mathbb{E}_{s_{t+1} \sim p ,a_{t+1} \sim \pi}\left[Q(s_{t+1},a_{t+1}) \right]
と書け、(Soft ではない)通常のベルマン演算子を用いたQ関数の更新となることから、Q関数の収束について証明される(らしい)。
次に Q 関数$Q_{\theta}(s_t,a_t)$のパラメータ$\theta$の更新について説明する。
Q 関数$Q_{\theta}(s_t,a_t)$のパラメータ$\theta$を決めるための目的関数は、推定値 $Q$ とベルマン演算子から更新された値 $\hat{Q}$ との誤差から計算される。
J_Q(\theta) = \mathbb{E}_{(s_t,a_t)\sim D}\left[\frac{1}{2}\left(Q_{\theta}(s_{t},a_{t})-\hat{Q}(s_{t},a_{t}) \right)^2 \right]
ベルマン演算子から更新された値 $\hat{Q}$ とは、
\hat{Q}(s_t,a_t) = r(s_t,a_t) + \gamma \mathbb{E}_{s_{t+1} \sim p}\left[V_{\bar{\theta}}(s_{t+1}) \right]
を指す。$(s_t,a_t)\sim D$は Replay Bufferからサンプルされた状態・行動の組である。
状態価値関数 $V(s_t)$ には $Q_{\theta}(s,a)$ を使うが、学習で勾配を更新する際は、パラメータ$\theta$は固定する。つまり、強化学習を安定させるために用いるtarget network の値を使う。target network のパラメータには、上付きbar(例えば$\bar{\theta}$)をつける。
$J_Q(\theta)$ の勾配を計算すると、期待値を計算する前の微分を$\hat{\nabla}$として
\hat{\nabla}_{\theta} J_Q(\theta) = \nabla_{\theta} Q_{\theta}(s_t,a_t) \left( Q_{\theta}(s_{t},a_{t}) - r(s_t,a_t) - \gamma V_{\bar{\theta}}(s_{t+1}) \right)
と計算され、勾配法で$\theta$を更新する。target network のパラメータ$\bar{\theta}$は、
\bar{\theta} \longleftarrow \tau \theta + (1-\tau) \bar{\theta}
で更新する。
最後に方策$\pi_{\phi}(a_t|s_t)$のパラメータ$\phi$の更新について説明する。
方策$\pi_{\phi}(a_t|s_t)$のパラメータ$\phi$を決めるための目的関数は、カルバック・ライブラー情報量(KL divergence)を用いて定義される。
J_{\pi}(\phi) = \mathbb{E}_{s_t\sim D} \left[D_{\rm{KL}}\left(\pi_{\phi}(\cdot |s_t)
\Biggr|\Biggr| \frac{\exp\{\frac{1}{\alpha}Q_{\theta}(s_t,\cdot) \} }{Z_{\theta}(s_t)} \right) \right]
KL divergence は、2つの確率分布の距離(厳密には違うが)を表し、方策 $\pi$ は、Energy-Basedの方策モデルの確率分布に近づくように更新される。
ただし、確率分布は誤差逆伝播法が使えないので、reparameterization trickを使う。つまり、行動 $a_t$ はガウス分布からのノイズ $\epsilon_t$ を使い、
a_t = f_{\phi}(\epsilon_t,s_t)
KL divergenceの定義から目的関数は、以下のように書く。
J_{\pi}(\phi) = \mathbb{E}_{s_t\sim D ,\epsilon_t \sim \mathcal{N}} \left[\alpha\log\pi_{\phi}(f_{\phi}(\epsilon_t,s_t) |s_t) -
Q_{\theta}(s_t,f_{\phi}(\epsilon_t,s_t)) +\alpha\log Z_{\theta}(s_t) \right]
$\log Z_{\theta}(s_t)$は、パラメータの更新に効いてこないのでプログラムでは省く。
$J_{\pi}(\phi)$ の勾配を計算すると、以下となる。
\hat{\nabla}_{\phi} J_{\pi}(\phi)= \alpha\nabla_{\phi}\log\pi_{\phi}(a_t|s_t) + \left(\alpha \nabla_{a_t}\log\pi_{\phi}(a_t|s_t)- \nabla_{a_t}Q_{\theta}(s_t,a_t) \right)\nabla_{\phi}f_{\phi}(\epsilon_t,s_t)
また、$\nabla_{\phi}$は、$\pi_{\phi}$とreparameterization trickで使った$f_{\phi}$に作用する。そして、勾配法で$\phi$を更新する。
正規化流
複雑なモデルにおいて、潜在変数の事後分布は、単純なガウス分布で近似することは難しい。したがって、複数回、微分可能で可逆な関数 $f$ を使い変換する事で、複雑な分布からのサンプルを取得する方法を正規化流という。正規化流について説明する。
関数 $f$ で変換前の値を $z$ 返還後を $\hat{z}$とする。 つまり、
\hat{z} = f(z)
さらに $f$ は可逆なので、
z = f^{-1}(\hat{z})
となる。$\hat{z}$で微分すると
\mathrm{d}z = \frac{\mathrm{d}f^{-1}}{\mathrm{d}\hat{z}} \mathrm{d}\hat{z}
となる。次に確率分布$q(z)$として座標変換を考えると
\int \mathrm{d}\hat{z} q(\hat{z})= \int \mathrm{d}z q(z)= \int \mathrm{d}\hat{z} q(z)\left| \mathrm{det}\left\{\frac{\mathrm{d}f^{-1}}{\mathrm{d}\hat{z}} \right\} \right|
つまり、(逆関数の定理を使うと?)
q(\hat{z})= q(z)\left| \mathrm{det}\left\{\frac{\mathrm{d}f^{-1}}{\mathrm{d}\hat{z}} \right\} \right| = q(z)\left| \mathrm{det}\left\{\frac{\mathrm{d}f}{\mathrm{d}z} \right\}^{-1} \right|
さらに、対数をとり
\log q(\hat{z})= \log q(z) - \log \mathrm{det}\left\{\frac{\mathrm{d}f}{\mathrm{d}z} \right\}
が得られる。
可逆で微分可能な関数として$f(u)=\tanh(u)$を採用し
\frac{\mathrm{d}}{\mathrm{d}u} \tanh(u) = 1- \tanh^2(u)
$\mu(u|s)$ をガウス分布として、方策 $\pi(a|s)$ を
\log \pi(a|s)= \log \mu(u|s) - \log \mathrm{det}\left\{\frac{\mathrm{d}f(u)}{\mathrm{d}u} \right\}
と表せば、$\log \mathrm{det} A = \mathrm{Tr}\log A $を用いて、最終的に方策$\pi(a|s)$は以下のように表せる。
\log \pi(a|s)= \log \mu(u|s) - \sum_{i=1}^D \log\left(1- \tanh^2(u_i) \right)
温度の調整
最後にハイパーパラメータであった温度$\alpha (\geq0)$の調整について説明する。強化学習は、ハイパーパラメータの調整がとても難しいので、学習で調整してくれるのはうれしい。
任意の$t$に対して条件
\mathbb{E}_{(s_t,a_t)\sim\rho_{\pi}} \left[-\log\pi_t(a_t|s_t) \right] \geq \mathcal{H}
のもと
\underset{\pi[0:T]}{\max} \mathbb{E}_{\rho_{\pi}}\left[\sum_{t=0}^T r(s_t,a_t )\right]
を求めることを考える(条件付最適化)。この式を展開してみると
\underset{\pi_0}{\max}\left\{\mathbb{E}[r(s_0,a_0)]+\underset{\pi_1}{\max}\left\{ \mathbb{E}[r(s_1,a_1)]+\ldots \left\{\underset{\pi_T}{\max}\mathbb{E}[r(s_T,a_T)] \right\}\ldots\right\} \right\}
の形となる。
まず、$t = T$ に注目すると、条件
\mathbb{E}_{(s_T,a_T)\sim\rho_{\pi}} \left[-\log\pi_T(a_T|s_T) \right] -\mathcal{H} \geq 0
から、$\alpha_T$ を双対変数として、最適な方策 $\pi_T^{\ast}$ は以下の式を満たすように求める(双対問題?)。
\underset{\pi_T}{\max}\mathbb{E}[r(s_T,a_T)] = \underset{\alpha_T\geq 0}{\min} \left\{\underset{\pi_T}{\max} \mathbb{E}_{(s_T,a_T)\sim\rho_{\pi}}\left[r(s_T,a_T)- \alpha_T \log\pi_T(a_T|s_T)\right] - \alpha_T \mathcal{H} \right\}
温度$\alpha_T$におけるMaximum entropy 方策を $\pi^*_T(a_T|s_T:\alpha_T)$ とすれば、最適な温度$\alpha_T^{\ast}$は以下のように求める。
\DeclareMathOperator*{\argmin}{arg\,min} %argmin
\alpha_T^* =\argmin_{\alpha_T} \mathbb{E}_{s_T,a_T\sim \pi_T^*}\left[- \alpha_T \log\pi_T^*(a_T|s_T:\alpha_T) - \alpha_T \mathcal{H} \right]
次に、$t=T-1$に注目すると、$t=T$における最適な Q 関数は以下であることに注意すれば
Q_{T}^{\ast}(s_{T},a_{T}) = \mathbb{E}[r(s_{T},a_{T})]
Q 関数は、$t=T$において最適化された温度$\alpha_T^{\ast}$と方策$\pi^*_T$の結果を使えば、以下のように書ける。
\begin{align}
Q_{T-1}^{\ast}(s_{T-1},a_{T-1}) &= \mathbb{E}[r(s_{T-1},a_{T-1})] + \mathbb{E}\left[ Q_{T-1}^{\ast}(s_{T},a_{T}) - \alpha_T^{\ast} \log\pi_T^{\ast}(a_T|s_T) \right] \\
&= \mathbb{E}[r(s_{T-1},a_{T-1})] +\underset{\pi_T}{\max}\mathbb{E}[r(s_T,a_T)]+\alpha_T \mathcal{H}
\end{align}
また、最適な方策 $\pi_{T-1}^{\ast}$ は以下の式を満たすように求める。
\underset{\pi_{T-1}}{\max}Q_{T-1}^{\ast}(s_{T-1},a_{T-1}) = \underset{\alpha_{T-1}\geq 0}{\min} \left\{\underset{\pi_{T-1}}{\max} \mathbb{E}\left[Q_{T-1}^{\ast}(s_{T-1},a_{T-1})- \alpha_{T-1} \log\pi_{T-1}(a_{T-1}|s_{T-1})\right] - \alpha_{T-1} \mathcal{H}\right\}
上2つの式を組み合わせれば
\begin{align}
&\underset{\pi_{T-1}}{\max}Q_{T-1}^{\ast}(s_{T-1},a_{T-1})-\alpha_T^{\ast} \mathcal{H} = \underset{\pi_{T-1}}{\max}\left\{\mathbb{E}[r(s_{T-1},a_{T-1})] +\underset{\pi_T}{\max}\mathbb{E}[r(s_T,a_T)] \right\} \\
&=\underset{\alpha_{T-1}\geq 0}{\min} \left\{\underset{\pi_{T-1}}{\max} \mathbb{E}\left[Q_{T-1}^{\ast}(s_{T-1},a_{T-1})- \alpha_{T-1} \log\pi_{T-1}(a_{T-1}|s_{T-1})\right] - \alpha_{T-1} \mathcal{H}\right\}-\alpha_T^{\ast} \mathcal{H}
\end{align}
が得られる。上式の1行目の右辺に注目すると、この操作を $t=0$まで行えば、
\underset{\pi[0:T]}{\max} \mathbb{E}_{\rho_{\pi}}\left[\sum_{t=0}^T r(s_t,a_t )\right]
が得られることが想像できる。したがって、最適な方策 $\pi_{t}^{\ast}$が求まり、条件を満たすように、最適な温度$\alpha_{t}^{\ast}$を決めるには、
\DeclareMathOperator*{\argmin}{arg\,min} %argmin
\alpha_{t}^* =\argmin_{\alpha_{t}} \mathbb{E}_{s_{t},a_{t}\sim \pi_{t}^*}\left[- \alpha_{t} \log\pi_{t}^*(a_{t}|s_{t}:\alpha_{t}) - \alpha_{t} \mathcal{H} \right]
となるように求めればよい。
したがって、温度$\alpha$に関する目的関数は、以下のように構築し
J(\alpha) = \mathbb{E}_{a_{t} \sim \pi_t} \left[-\alpha \left(\log\pi(a_t|s_t)+ \mathcal{H}\right) \right]
上式を最小となるように$\alpha$を勾配法で求めればよい。
離散的なSoft Actor Critic
離散的なSoft Actor Criticは、Soft Actor Criticと相違点が5点ありそれを説明する。
(1). Q 関数の入出力
Q関数は、$Q: S \times A \to \mathbb R$ から$Q: S \to\mathbb R^{|A|} $に変更。つまり、 $Q(s,a) \to Q(s)$ 。
Soft Actor Critic のQ 関数は状態と行動を入力していた。離散的なSoft Actor CriticのQ関数は状態のみを入力とする。
(2). 方策の出力
方策は、$\pi: S \to \mathbb R^{2|A|}$ から$ S \to \mathbb [0,1]^{|A|} $に変更。
つまり、離散的なSoft Actor Criticは、Softmax関数を使い直接確率を出力させる。
(3). 状態価値関数
Soft Actor Criticは、平均・分散を計算し、reparameterization trickを使い行動 $a$ をサンプリングしていた。
V(s_t) = \mathbb{E}_{a_{t} \sim \pi} \left[Q(s_t,a_t) -\alpha \log\pi(a_t|s_t) \right]
離散的なSoft Actor Criticは、方策の出力が直接確率になっているので、
V(s_t) = \pi^T(s_t) \left[Q(s_t) -\alpha \log\pi(s_t) \right]
と変更される。
(4). 温度調整のための目的関数
行動 $a$ をサンプリングする必要がないので、(3)と同様に
J(\alpha) = \mathbb{E}_{a_{t} \sim \pi_t} \left[-\alpha \left(\log\pi(a_t|s_t)+ \mathcal{H}\right) \right]
から
J(\alpha) = \pi_t^T(s_t) \left[-\alpha \left(\log\pi(s_t)+ \mathcal{H}\right) \right]
と変更される。
(5). 方策のパラメータを決定するための目的関数
行動 $a$ をサンプリングする必要がないので、(3)、(4)と同様に
J_{\pi}(\phi) = \mathbb{E}_{s_t\sim D ,\epsilon_t \sim \mathcal{N}} \left[\alpha\log\pi_{\phi}(f_{\phi}(\epsilon_t,s_t) |s_t) -
Q_{\theta}(s_t,f_{\phi}(\epsilon_t,s_t)) \right]
から
J_{\pi}(\phi) = \mathbb{E}_{s_t\sim D } \left[ \pi_t^T(s_t) \left(\alpha\log\pi_{\phi}(s_t) -
Q_{\theta}(s_t) \right) \right]
と変更される。
最後に
まだ学習をさせてないので、結果は次回?