4
2

More than 3 years have passed since last update.

モデルベースのアルゴリズム DreamerV2についての説明

Last updated at Posted at 2021-09-06

概要

 土木の分野では、機械の制御(ダムの制御)や都市開発(交通×AI)などで、強化学習が使われ始めています。

 強化学習のアルゴリズムには、モデルフリーとモデルベースがあります。モデルフリーの有名なアルゴリズムは、深層Q学習(DQN)、アクター・クリティック(Actor Critic) などがあります。モデルベースの有名なアルゴリズムは、AlphaGo、AlphaZero、MuZeroなどがあります。

 モデルフリーなアルゴリズムは、どのような行動すれば、高い報酬が得られるか探索する必要があり、サンプリングを多く行わなければならず、学習に時間がかかることも知られています。反対に、モデルベースなアルゴリズムは、サンプル効率が良いことが知られていますが、MuZeroのようなネットワークは、実用的ではないことが知られています。

 今回の記事では、モデルベースのアルゴリズムであるDreamerV2について解説したいと思います。Dreamerの利点は、モデルフリーに比べ、大量のサンプルは必要でなく、MuZeroに比べ実用的なところです。
 
 DreamerV2を一言で説明すると、world modelでモデルを構築し、Actor Critic をベースに行動を行うアルゴリズムです。

 以下の論文を引用しています。

 論文の著者である Danijar さんがコードを公開しています。

world model については、以下の論文を引用しています。

Information bottleneck の計算は、以下の2つの論文を引用しています。

モデルベースについて

 最初に、モデルフリーのアルゴリズムについて説明する。以下の図は、モデルフリーの概要を表した図である。

image.png

 モデルフリーのアルゴリズムは(方策・価値関数を決めるAIは)、直接環境(経験)から状態・報酬を受け取り、方策・価値関数を学習する。そして、行動を決定する。
 モデルフリーのアルゴリズムは、どのような行動をすれば高い報酬が得られるかわからない場合に有効である。しかし、どのような行動すれば、高い報酬が得られるか探索する必要があり、サンプリングを多く行わなければならず、学習に時間がかかることが知られている。

 次に、モデルベースのアルゴリズムについて説明する。以下の図は、モデルベースの概要を表した図である。
image.png

 モデルベースのアルゴリズムは、環境(経験)からモデルを構築する。方策・価値関数を決めるAIは、環境(経験)から状態・報酬を受け取らず、モデルから得られた状態・報酬を受け取り、方策・価値関数を学習する。このように、モデルを使い方策を改善していくことをプランニングと言う。

 モデルベースのアルゴリズムは、新しいデータを取り出す際に、モデルから取り出すことができ、モデルフリーのアルゴリズムと比べサンプル効率が良いと言われる。しかし、環境から得られるデータは、多いデータと少ないデータがあり、少ないデータについては学習されない可能性がある。さらに、方策・価値関数を決めるAIは、直接環境から学習するのではなく、モデルから推定された状態・報酬から学習されるので、実際の環境と大きくかけ離れ、誤差が大きくなる可能性がある。

Credit assignment problem

 Credit assignment problem とは、強化学習において、報酬が得られるステップと報酬に貢献した行動が、時間的に離れているので学習が難しくなる問題である。

 例えば、ブロック崩し(break out)の場合、ブロックを崩せば報酬が貰える。しかし、そのステップは特に重要でなく、ボールがバーにあてた時の行動が重要であるが、その時の報酬はゼロである。

 そのため、方策(ポリシー)はすぐに変更できることが望ましく、多くの強化学習のモデルにおいて、小さいネットワークである理由が、トレーニング中に、素早く、適切なポリシーに反復(改善)する必要があるためである。

 理想的には、大規模なネットワークであるVAEやRNNを使い、効率的に学習したい。 

World model

 World modelについて、簡単に説明する。
 使用されている図は論文から引用した。図のように、人間は、限られた感覚や知覚できるものに基づき、世界のメンタルモデルを開発する。その内部モデル(メンタルモデル?)に基づき、人間は行動している。内部モデルは、空間的側面と時間的側面を持ち抽象的な表現を学習している。

 スクリーンショット 2021-09-02 174604.png
 
 以上のことから、空間的な側面はVAEを、時間的な側面はRNNを使いWorld model を構築する。World model では、以下の図のように、VAEやRNNのような大きなネットワークと、行動を決める小さなネットワーク(コントローラー)からなる。 
 
 大規模なネットワークであるVAEやRNNを使うことで、高い表現力を学習する。そして、行動は小さなネットワークで行うので、Credit assignment problem に対応させることができ、効率的に学習が行える。

スクリーンショット 2021-09-02 174541.png

Information bottleneck objective

 
 World model の誤差関数として使われるInformation bottleneckについて説明する。

 最初に相互情報量について説明すると、相互情報量$I(X;Y)$は、確率変数$X$と$Y$の相互依存性の尺度であり

\begin{align}
I(X;Y)&\overset{def}{\equiv}\sum_{x\in X,y\in Y} P(x,y)\log\left(\frac{P(x,y)}{P(x)P(y)} \right) \\
&=\sum_{x\in X,y\in Y} P(x,y)\log\left(\frac{P(y|x)P(x)}{P(x)P(y)} \right) \\
&=\sum_{x\in X,y\in Y} P(x,y)\log\left(\frac{P(y|x)}{P(y)} \right) \\
\end{align}

である。2行目は公式$P(X,Y)=P(Y|X)P(X)$を使った。確率変数$X$と$Y$が独立であることは(依存性がないことは)、$P(X,Y)=P(Y)P(X)$であることなので、相互情報量はゼロになる。

\begin{align}
\log\left(\frac{P(x,y)}{P(x)P(y)} \right) = \log\left(\frac{P(x)P(y)}{P(x)P(y)} \right) = \log1 =0
\end{align}

 次に、Information bottleneck について説明する。
 確率変数$X,Y$について、VAEのように、$Y$を復元できるように$X$を$Z$に圧縮したいとする。$Z$は圧縮表現と呼ばれる。相互情報量をもとに、Information bottleneckを以下のように定義する。

\begin{align}
\max I(Y;Z) - \beta I(X;Z)
\end{align}

 Information bottleneckを最大化することを考える。学習で用いる際は、これにマイナスをつけて最小化する。
 $\max I(Y;Z)$は、圧縮表現$Z$と$Y$を依存させるようにすることで、圧縮表現$Z$から$Y$を復元しやすくする。$\beta I(X;Z)$は、小さくすることで圧縮表現$Z$と$X$を独立にし、情報を落としながら圧縮させる。圧縮表現$Z$が$X$に依存してしまうと、圧縮されていないこと意味し、つまり、$Z=X$を意味する。

 また、$\beta$を上手く調節する必要があり、$\beta$が小さすぎると、$I(X;Z)$は最大化に寄与しないので、$Z=X$となりやすく圧縮されない。$\beta$が大きすぎると、$I(X;Z)$は最大化に寄与しすぎてしまい、圧縮表現$Z$と$Y$が依存しにくくなるので($\max I(Y;Z)$が最大化に寄与しにくくなるので)、復元されない。

DreamerV2におけるWorld modelのネットワーク

 DreamerV2で使われるWorld model について説明する。以下使用されている図は、論文から引用している。World modelは、過去の経験から得られる$T$ステップまでの列、画像の列$x_{1:T}$ 、行動の列$a_{1:T}$、報酬の列$r_{1:T}$ 、割引率(discount forctor)の列$\gamma_{1:T}$ から学習される。論文では、バッチサイズ$B=50$、列の長さ$T=50$としている。割引率は、ステップに関係なく$\gamma=0.999$とする。(プログラムでは$\gamma=1$?)

スクリーンショット 2021-09-03 093852.png

 図のように、モデルはImage encoder、RSSM(Recurrent State-Space Model)、3つのpredictor(Image, Reward, Discount)からなる。 RSSMは、3つのモデルからなり

\begin{align}
\mbox{Recurrent model} \ \ :& \ \ h_t = f_{\phi}(h_{t-1},z_{t-1},a_{t-1}) \\
\mbox{Representation model} \ \ :&  \ \ z_t \sim q_{\phi}(z_t|h_t,x_t )  \\
\mbox{Transition predictor} \ \ :&  \ \ \hat{z}_t \sim p_{\phi}(\hat{z}_t|h_t)  \\
\end{align}

3つのpredictorは、

\begin{align}
\mbox{Image predictor} \ \ :& \ \ \hat{x}_t \sim p_{\phi}(\hat{x}_t|h_t,z_t ) \\
\mbox{Reward predictor} \ \ :&  \ \ \hat{r}_t \sim p_{\phi}(\hat{r}_t|h_t,z_t ) \\
\mbox{Discount predictor} \ \ :&  \ \ \hat{\gamma}_t \sim p_{\phi}(\hat{\gamma}_t |h_t,z_t ) \\
\end{align}

と表せる。Representation modelの$x_t$は、Image encoderから出力された潜在変数である。$\hat{x}_t$は、decoderによって復元された予測画像であり、$\hat{r}_t$は予測された報酬、 $\hat{\gamma}_t$は予測された割引率である。$z_t$は事後状態を表し、$\hat{z}_t$は事前状態を表す。

 Image predictor および Reward predictor ではガウス分布を使い、Discount predictorでは、ベルヌーイ分布を使い、Representation model および Transition predictor ではカテゴリカル分布を使う。上の図において、Representation model および Transition predictorから出力される$z_t$および$\hat{z}_t$は、32個の隠れ変数に対し32個のカテゴリーがある。

 ガウス分布の場合(DreamerV1の場合)、誤差逆伝播法を使うため reparameterization trick を使用していた。DreamerV2の論文では、カテゴリカル分布を用いるため、Straight-Through gradients with Automatic Differentiation を使用して学習を行う。アルゴリズムは以下の図である。
 スクリーンショット 2021-09-03 115806.png

World model の誤差関数

 World model の誤差関数は、以下の Information bottleneck を使う。

\begin{align}
\max I\left(z_{1:T};(\hat{x}_{1:T},\hat{r}_{1:T},\hat{\gamma}_{1:T}) |a_{1:T},h_{1:T} \right) - \beta I(z_{1:T};x_{1:T}|a_{1:T},h_{1:T})
\end{align}

上式からELBO(evidence lower bound)を求める。1項目を計算すると、

\begin{align}
& I\left(z_{1:T};(\hat{x}_{1:T},\hat{r}_{1:T},\hat{\gamma}_{1:T}) |a_{1:T},h_{1:T} \right) \\
&= \mathbb{E}\left[\log P(\hat{x}_{1:T},\hat{r}_{1:T},\hat{\gamma}_{1:T}|z_{1:T},a_{1:T},h_{1:T} )- \log P(\hat{x}_{1:T},\hat{r}_{1:T},\hat{\gamma}_{1:T}|a_{1:T},h_{1:T} )  \right] \\
& \overset{+}{=} \mathbb{E}\left[\log P(\hat{x}_{1:T},\hat{r}_{1:T},\hat{\gamma}_{1:T}|z_{1:T},a_{1:T},h_{1:T} )\right] \\
& \geq \mathbb{E}\left[\log P(\hat{x}_{1:T},\hat{r}_{1:T},\hat{\gamma}_{1:T}|z_{1:T},a_{1:T},h_{1:T} )\right] \\
&-\mbox{KL}\left[P(\hat{x}_{1:T},\hat{r}_{1:T},\hat{\gamma}_{1:T}|z_{1:T},a_{1:T},h_{1:T} )\big|\big| \underset{t}{\Pi} p_{\phi}(\hat{x}_t|h_t,z_t )  p_{\phi}(\hat{r}_t|h_t,z_t )  p_{\phi}(\hat{\gamma}_t|h_t,z_t ) \right] \\
& =\mathbb{E}\left[\sum_t \left\{ \log p_{\phi}(\hat{x}_t|h_t,z_t ) + \log p_{\phi}(\hat{r}_t|h_t,z_t ) + \log p_{\phi}(\hat{\gamma}_t|h_t,z_t ) \right\}\right]
\end{align}

と計算される。相互情報量の非負性から、最初の項を下限とする。$\overset{+}{=}$ は、Representation modelに関係ない項(2行目の2項目)を省いたことを意味する。KL情報量(Kullback–Leibler divergence)は非負であることを用いた。2項目を計算すると、

\begin{align}
&I(z_{1:T};x_{1:T}|a_{1:T},h_{1:T})\\
&= \mathbb{E}\left[\sum_t \left\{ \log P(z_t|z_{t-1},h_{t-1},a_{t-1},x_{t}) -  \log P(z_t|z_{t-1},h_{t-1},a_{t-1}) \right\} \right] \\
& \leq \mathbb{E}\left[\sum_t \left\{ \log q_{\phi}(z_t|z_{t-1},h_{t-1},a_{t-1},x_{t}) -  \log p_{\phi}({z}_t|z_{t-1},h_{t-1},a_{t-1}) \right\} \right] \\
& =\mathbb{E}\left[\sum_t \mbox{KL}\left[q_{\phi}(z_t|h_t,x_{t})\big|\big|p_{\phi}({z}_t|h_t) \right]      \right]
\end{align}

と計算される。途中で$h_t = f_{\phi}(h_{t-1},z_{t-1},a_{t-1})$を使った。また、数式では$x_t$は確率変数であるが、決定的に決まるとする。($x_t \sim \delta(x_t)$のように確率分布がデルタ関数的になっているとする。)3行目の不等式については、KL情報量が非負であることを用いた。

\begin{align}
\mbox{KL}\left[P(z_t|h_t)\big|\big|p_{\phi}({z}_t|h_t) \right] \geq 0  \ \ \Rightarrow \ \ \mathbb{E}\left[\log P(z_t|h_t)\right] \geq \mathbb{E}\left[\log p_{\phi}({z}_t|h_t)\right]
\end{align}

 以上の2つの結果を合わせればELBOが求まる。最終的に、誤差関数は、

\begin{align}
\mathcal{L}&= -\mathbb{E}\left[\sum_t\left\{\log p_{\phi}(\hat{x}_t|h_t,z_t ) + \log p_{\phi}(\hat{r}_t|h_t,z_t ) + \log p_{\phi}(\hat{\gamma}_t|h_t,z_t )  \right\}  \right] +\beta \mathbb{E}\left[ \mbox{KL}\left[q_{\phi}(z_t|h_t,x_{t})\big|\big|p_{\phi}(\hat{z}_t|h_t) \right]  \right] 
\end{align}

を最小化する。論文において、atariの場合は$\beta=0.1$、連続値を扱う制御の場合は$\beta=1.0$としている。

 誤差関数の2項目のKL情報量は、2つの目的があり、Representation modelに対して事前状態を学習させること、もう一つは、事前状態に対してRepresentation modelを正則化することである。(KL情報量が最小となるのは、確率 $q_{\phi}(z_t|h_t,x_{t})$ と $p_{\phi}(\hat z_t|h_t)$ が同じときである。おそらく、正則化とは、事後分布$q_{\phi}(z_t|h_t,x_{t})$ を事前分布$p_{\phi}(\hat z_t|h_t)$に近づけること。)

 ただし、事前状態の学習が不十分であるのに、Representation modelを正則化してほしくない。よって、$\alpha=0.8$とし、図のようなアルゴリズムで学習させる。
スクリーンショット 2021-09-03 152358.png
KL情報量を計算するときに、事前状態の学習に比重を置くことで、事前状態の学習を促進し、事前状態を早く学習させる。

Imagination MDP

 最後に、行動の部分にあたるActor Critic の学習について説明する。以下使用されている図は、論文から引用している。
 
スクリーンショット 2021-09-03 154833.png
 Actor Critic の学習は、Imagination MDP を使い学習される。Imagination MDPについて説明すると、最初に、World modelの学習中に使われたデータを使い、初期状態分布$\hat z_0$を生成させる。
(恐らく、論文で使われる$\hat{z}_t$は、$(h_t,z_t)$を意味している。World modelを学習させるとき、バッチサイズ$B=50$、列の長さ$T=50$としているので、初期状態分布から得られるサンプルは$50*50=2500$個ある。)

 imagination horizon $H=15$として、初期状態分布から図のように、transition predictor $p_{\phi}(\hat z_t|h_t)$ から事前状態の列$\hat z_{1:H}$、reward predictor $p_{\phi}(\hat r_t|h_t,z_t )$から報酬の列$\hat r_{1:H}$、discount predictor $p_{\phi}(\hat \gamma_t |h_t,z_t )$ から割引率の列$\hat \gamma_{1:H}$を生成させる。

 Actor は、Critic の出力を最大化する行動をとるように学習し、Critic は Actor が獲得する将来の報酬の合計を予測するように学習を行う。Actor Criticのモデルは、Imagination MDPによって生成された列を使い、

\begin{align}
\mbox{Actor} \ \ :& \ \ \hat a_t \sim p_{\psi}(\hat a_t|h_t,z_t)  \\
\mbox{Critic} \ \ :&  \ \ v_{\xi}(h_t,z_t) \approx \mathbb{E}_{p_{\phi},p_{\psi}}\left[\sum_{\tau \geq t} \hat \gamma^{\tau-t} \hat r_{\tau} \right] \\
&\mbox{or} \ \  v_t \sim p_{\xi}(v_t|h_t,z_t)
\end{align}

とする。Actor は、カテゴリカル分布を用いて行動$a_t$を生成し、Critic は、ガウス分布を用いて価値関数$v_t$を生成させる。

 Critic loss については、general $\lambda$ target を使い価値関数を

\begin{align}
V^{\lambda}_t = \hat r_t + \hat\gamma_t
\begin{cases}
(1-\lambda)v_{\xi}(h_{t+1},z_{t+1}) +\lambda V^{\lambda}_{t+1}  \ \ \ \ &\mbox{if}\ \ t< H \\
v_{\xi}(h_{H},z_{H})   &\mbox{if}\ \ t=H
\end{cases}
\end{align}

とする。$\lambda=0.95$とする。loss は自乗誤差を使い

\begin{align}
L(\xi) = \mathbb{E}_{p_{\phi},p_{\psi}}\left[\sum_{t=1}^{H-1} \frac{1}{2}\left(v_{\xi}(h_{t},z_{t})-\mbox{sg}(V^{\lambda}_t)  \right)^2  \right]
\end{align}

とする。$\mbox{sg}$は、stop gradientを意味し、勾配の更新は行わない。Critic は学習を安定させるため、target network を使い100step ごとにパラメータを更新する。Critic loss の実際の計算は、

\begin{align}
L(\xi) = -\mathbb{E}_{p_{\phi},p_{\psi}}\left[\sum_{t=1}^{H-1} \log p_{\xi}(\mbox{sg}(V^{\lambda}_t) |h_t,z_t)  \right]
\end{align}

を計算する。(公開されているコードでは)$p_{\xi}$は分散1のガウス分布に従っているので、

\begin{align}
\log p_{\xi}(v_t|h_t,z_t) = -\frac{1}{2}\log 2\pi - \frac{1}{2}(v_t-v_{\xi}(h_t,z_t)) ^2
\end{align}

となり、定数部分は勾配の更新に寄与しないので、同じ意味である。

 Actor loss については、

\begin{align}
L(\psi) = -\mathbb{E}_{p_{\phi},p_{\psi}}\left[\sum_{t=1}^{H-1} \left\{\rho\log  p_{\psi}(\hat a_t|h_t,z_t) \mbox{sg}(V^{\lambda}_t-v_{\xi}(h_{t},z_{t}) )+(1-\rho) V^{\lambda}_t 
 -\eta H[a_t|h_{t},z_{t}]\right\}\right]
\end{align}

を使う。$v_{\xi}(h_{t},z_{t})$は、$p_{\xi}(v_t|h_t,z_t)$の最頻値から計算する。論文において、atariの場合は$\rho=1,\eta=10^{-3}$、連続値を扱う制御の場合は$\rho=0,\eta=10^{-4}$としている。

(公開されているコードでは、Actor loss および Critic loss に対し、Imagination MDPにおけるステップごとの割引率の累積積を計算し、その累積積をloss に重みづけしている。)

break out を学習させた結果

 break out を約1週間ほど学習させた。論文のように、モデルフリーのアルゴリズムと比べてはないが、経験的には学習が早く行われている。(論文のbreak out の結果は、モデルフリーのアルゴリズムより精度がよい。)以下の図は、各ステップにおけるbreak out の報酬である。

history_plot_eval_return.png

 以下の図は、world model の再構成画像である。1段目は環境から得られた画像であり、2段目が再構成画像、3段目は1段目と2段目の差(本来の画像と再構成画像の誤差)である。 Replay Buffer からランダムにデータを6個取り出し、さらに、アニメーションの長さが(連続で)50になるようにランダムにクリップする。3段目を確認すると、特に、バーの部分とボールの部分が再構成されていないことが分かり、もう少し学習が必要かもしれない。

world_model_output.gif

論文の著者のコードでは、混合精度(Mixed precision)を使用している。
私の環境でtensorflowなどのアップデートをしなければならないので、混合精度は使っていない。

最後に

 
 土木の分野では、画像を入力とする強化学習の需要があまりないような気がする。(ロボット系は除いて)

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2