はじめに
policy gradientを説明する多くの資料で方策勾配定理の証明が一部省略してあり気持ち悪かったので、行間を埋めました。
強化学習全般、およびpolicy gradient自体の説明はUS Berkeleyの強化学習の講義資料が分かりやすいです。本記事ではこの資料を補足する形で方策勾配定理の証明を行います。
方策勾配定理の証明はこの記事などにも詳しく書かれていますが、今回は少し違った形で証明を行います。
記号
$s_t$をstate、$a_t$をaction、$p(s_{t+1}|s_t, a_t)$を状態変化を表す確率分布(given)、$r(s_t,a_t)$を即時報酬、$\pi_\theta(a_t|s_t)$を方策(policy)とします。
また
s_{\leqq t} = (s_1, s_2,...,s_t) \\
a_{\leqq t} = (a_1, a_2,...,a_t) \\
s_{\geqq t} = (s_t, s_{t+1},...,s_{T+1}) \\
a_{\geqq t} = (a_t, a_{t+1},...,a_T) \\
とおきます。
方策$\pi_\theta$に従った行動した場合のstateとactionの列を$\tau$とします。
\tau = (s_1,a_1,...,s_T,a_T,s_{T+1})
ここで
p_\theta(\tau) = p(s_1)\prod_{t=1}^{T}p(s_{t+1}|s_t,a_t)\pi_\theta(a_t|s_t)
に注意します。
この$p_\theta(\tau)$による期待値を
E_{\pi_\theta}[f(\tau)] = \sum_{\tau}f(\tau)p_\theta(\tau)
と書くことにします。
#policy gradient
次のような累積報酬の期待値$J(\theta)$を最大化するような$\pi_\theta$のパラメータ$\theta$を見つけることを考えます。
\begin{align}
R(\tau)&\equiv \sum_{t=1}^{T}r(s_t,a_t) \\
J(\theta)&\equiv E_{\pi_\theta}[R(\tau)] \\
&=\sum_{\tau} R(\tau) p_\theta(\tau)
\end{align}
この目的を達成するため、policy gradientでは$\nabla_\theta J(\theta)$を計算し、勾配降下法で$J(\theta)$を最大化することを目指します。
方策勾配定理とはこの$\nabla_\theta J(\theta)$の計算方法を与える次のような定理を指します。
方策勾配定理
\begin{align}
\nabla_\theta J(\theta) &= E_{\pi_\theta}[\sum_{t=1}^T \nabla_\theta log(\pi_\theta(a_t|s_t))\sum_{t=1}^T r(s_t, a_t)] \tag{A}\\
&= E_{\pi_\theta}[\sum_{t=1}^T \nabla_\theta log(\pi_\theta(a_t|s_t))\sum_{t'=t}^T r(s_{t'}, a_{t'})] \tag{B} \\
&= E_{\pi_\theta}[\sum_{t=1}^T Q^{\pi_\theta}(s_t, a_t)\nabla_\theta log(\pi_\theta(a_t|s_t))] \tag{C}
\end{align}
ここで$Q^{\pi_\theta}(s,a)$は行動価値関数
\begin{align}
Q^{\pi_\theta}(s,a)&=E_{\pi_\theta}[\sum_{t'=t}^T r(s_{t'},a_{t'})|s_t=s,a_t=a] \\
&= \sum_{s_{\geqq t+1},a_{\geqq t+1}} \sum_{t'=t}^T r(s_{t'},a_{t'}) p(s_{t+1}|s, a)\prod_{l=t+1}^{T}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l)
\end{align}
を指します。
補題:期待値の性質
定理の証明に入る前に期待値の次の性質に注意します。
E_{\pi_\theta}[f(s_{\leqq t},a_{\leqq t})] = \sum_{s_{\leqq t},a_{\leqq t}} f(s_{\leqq t},a_{\leqq t}) \pi_\theta(a_t|s_t) p(s_1) \prod_{l=1}^{t-1}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l) \tag{a} \\
実際、
\begin{align}
E_{\pi_\theta}[f(s_{\leqq t},a_{\leqq t})] &= \sum_{s_{\leqq T+1}, a_{\leqq T}}f(s_{\leqq t},a_{\leqq t}) p(s_1)\prod_{l=1}^{T}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l) \\
&= \sum_{s_{\leqq T}, a_{\leqq T-1}} f(s_{\leqq t},a_{\leqq t}) p(s_1)\prod_{l=1}^{T-1}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l) \sum_{s_{T+1},a_{T}}p(s_{T+1}|s_{T}, a_{T})\pi_\theta(a_{T}|s_{T}) \\
&= \sum_{s_{\leqq T}, a_{\leqq T-1}} f(s_{\leqq t},a_{\leqq t}) p(s_1)\prod_{l=1}^{T-1}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l) \\
&... \\
&=\sum_{s_{\leqq t},a_{\leqq t}} f(s_{\leqq t},a_{\leqq t}) \pi_\theta(a_t|s_t) p(s_1) \prod_{l=1}^{t-1}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l)
\end{align}
となります。
定理の証明
まず
\begin{align}
\nabla_\theta J(\theta) &= \nabla_\theta\sum_{\tau}p_\theta(\tau)R(\tau) \\
&= \sum_{\tau}\nabla_\theta p_\theta(\tau)R(\tau) \\
&= \sum_{\tau}p_\theta(\tau)\nabla_\theta log(p_\theta(\tau))R(\tau) \\
&= E_{\pi_{\theta}}[\nabla_\theta log(p_\theta(\tau)) R(\tau)]
\end{align}
となります。ここで、
f \frac{d log(f)}{dx} = f \frac{df}{dx} \frac{1}{f} = \frac{df}{dx}
を使いました。
また、
\begin{align}
\nabla_\theta log(p_\theta(\tau)) &= \nabla_\theta log(\prod_t p(s_{t+1}|s_t, a_t)\pi_\theta (a_t|s_t)) \\
&= \nabla_\theta \{ \sum_t log(p(s_{t+1}|s_t, a_t)) + \sum_t log(\pi_\theta (a_t|s_t)) \} \\
&= \sum_t \nabla_\theta log(\pi_\theta (a_t|s_t))
\end{align}
なので
\nabla_\theta J(\theta) = E_{\pi_\theta}[\sum_{t=1}^T \nabla_\theta log(\pi_\theta(a_t|s_t))\sum_{t=1}^T r(s_t, a_t)]
となり、一つ目の式が示せました。
(A)から(B)の式変形は多くの資料で
過去の報酬はその時点の方策の良しあしに関係ないはずなので無視してよい
と説明されますが、次のようにして導けます。
$t'\leq t$ のとき
E_{\pi_\theta}[\nabla_\theta log(\pi_\theta(a_t|s_t)) r(s_{t'}, a_{t'})] = 0
を示せばよいですが、実際
\begin{align}
& E_{\pi_\theta}[\nabla_\theta log(\pi_\theta(a_t|s_t)) r(s_{t'}, a_{t'})] \\
&= \sum_{s_{\leqq t},a_{\leqq t}} \nabla_\theta log(\pi_\theta(a_t|s_t)) r(s_{t'}, a_{t'}) \pi_\theta(a_t|s_t) p(s_1) \prod_{l=1}^{t-1}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l) \hspace{10mm} (a)より\\
&= \sum_{s_{\leqq t},a_{\leqq t-1}} \sum_{a_t} \nabla_\theta \pi_\theta(a_t|s_t) r(s_{t'}, a_{t'}) p(s_1) \prod_{l=1}^{t-1}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l) \\
&= \sum_{s_{\leqq t},a_{\leqq t-1}} \nabla_\theta (\sum_{a_t} \pi_\theta(a_t|s_t)) r(s_{t'}, a_{t'}) p(s_1) \prod_{l=1}^{t-1}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l) \\
&= \sum_{s_{\leqq t},a_{\leqq t-1}} \nabla_\theta (1) r(s_{t'}, a_{t'}) p(s_1) \prod_{l=1}^{t-1}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l) \\
&= 0
\end{align}
となります。
最後に(B)から(C)への式変形を証明します。
(B)は(C)の行動価値関数をワンパスで近似している
とみなせますが、イコールが成り立ちます。
実際、
E_{\pi_\theta}[ \nabla_\theta log(\pi_\theta(a_t|s_t))\sum_{t'=t}^T r(s_{t'}, a_{t'})]
= E_{\pi_\theta}[Q^{\pi_\theta}(s_t, a_t) \nabla_\theta log(\pi_\theta(a_t|s_t))]
を示します。
\begin{align}
& E_{\pi_\theta}[Q^{\pi_\theta}(s_t, a_t) \nabla_\theta log(\pi_\theta(a_t|s_t))] \\
& = \sum_{s_{\leqq t},a_{\leqq t}} Q^{\pi_\theta}(s_t, a_t)\nabla_\theta log(\pi_\theta(a_t|s_t)) \pi_\theta(a_t|s_t) p(s_1) \prod_{l=1}^{t-1}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l) \hspace{10mm} (a)より\\
& = \sum_{s_{\leqq t},a_{\leqq t}} \nabla_\theta log(\pi_\theta(a_t|s_t)) \pi_\theta(a_t|s_t) p(s_1) \prod_{l=1}^{t-1}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l) \\
& \times \sum_{s_{\geqq t+1},a_{\geqq t+1}} (\sum_{t'=t}^T r(s_{t'},a_{t'})) p(s_{t+1}|s_t, a_t)\prod_{l=t+1}^{T}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l) \\
& = \sum_{s_{\leqq t},a_{\leqq t}}\sum_{s_{\geqq t+1},a_{\geqq t+1}} \nabla_\theta log(\pi_\theta(a_t|s_t))(\sum_{t'=t}^T r(s_{t'},a_{t'})) p(s_1) p(s_{t+1}|s_t, a_t)\pi_\theta(a_t|s_t) \\
& \times \prod_{l=1}^{t-1}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l)
\prod_{l=t+1}^{T}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l) \\
&=\sum_{s_{\leqq T}, a_{\leqq T-1}}\nabla_\theta log(\pi_\theta(a_t|s_t))(\sum_{t'=t}^T r(s_{t'},a_{t'}))p(s_1)\prod_{l=1}^{T}p(s_{l+1}|s_l, a_l)\pi_\theta(a_l|s_l) \\
&= E_{\pi_\theta}[ \nabla_\theta log(\pi_\theta(a_t|s_t))\sum_{t'=t}^T r(s_{t'}, a_{t'})]
\end{align}