12
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Prioritized Experience Replay ~強化学習において効率よくサンプリングする方法~

Last updated at Posted at 2021-07-05

概要

 土木の分野では、機械の制御(ダムの制御)や都市開発(交通×AI)などで、強化学習が使われ始めています。初めてニューラルネットワークを用いて強化学習を構築する場合、最初に構築するアルゴリズムは、やはり深層Q学習(DQN)だろうと思われます。keras のブログで紹介されています。

 DQNは、モデルフリーと呼ばれるアルゴリズムです。モデルフリーは、どのような行動をすれば高い報酬が得られるかわからない場合に有効だと言われています。しかし、どのような行動すれば、高い報酬が得られるか探索する必要があり、サンプリングを多く行わなければならず、学習に時間がかかることも知られています。

 今回の記事では、DQNで使われるReplay Bufferを拡張したPrioritized Experience Replay を紹介したいと思います。Prioritized Experience Replayは、目標値と学習の推定値の差で定義されるTD誤差が大きいサンプルを重点的に学習に使う方法です。

 論文は、以下を引用しています。

 強化学習について、メンダコさんのサイトはとても参考になります。 

 @ymd_hさんからReplay Bufferに関するライブラリ cpprb を紹介していただきました。

Replay Buffer

 強化学習は、マルコフ決定過程を前提に構築されている。つまり、現状態$s(t)$から行動$a(t)$を選んだ場合、状態$s(t+1)$に遷移する確率は

P(s'|s,a) = P(s(t+1)=s'|s(t)=s,a(t)=a) 

である。この式が言いたいことは、次の状態$s(t+1)$は、現状態$s(t)$と行動$a(t)$のみから決まり、過去の状態$s(t-1),s(t-2),...$や過去の行動$a(t-1),a(t-2),...$からは決まらないことを前提としている。

 例えば、アタリなどのゲームの画像を状態として、時系列順にデータを入力し学習させると、強い相関(前の状態に対して強い依存)を持ってしまう。恐らくは、マルコフ決定過程でなくなり学習が上手くいかない。アタリなどのゲームを学習させる場合、時系列順に4枚の画像を1つの入力として学習させるが、それでもマルコフ決定過程でなくなるだろう。

 Replay Bufferは、データを貯めて、貯めたデータからランダムに選択(サンプリング)し、サンプリングされたデータをもとに学習させる。そうすれば、時系列に依らない瞬間的なデータを得ることができる。

 Prioritized Experience Replayは、ランダムにサンプリングするのではなく、TD誤差が多いデータを選ぶようにサンプリングする。

Sampling transition

 TD誤差が多いデータを選ぶようにsampling transitionについて説明する。データ$i$ における sampling transition の確率を以下のように定義する。

P(i) = \frac{p_i^{\alpha}}{\sum_k^N p_k^{\alpha}} 

$p_i$はデータ$i$ における優先度である。ハイパーパラメータを$\alpha \rightarrow 0$とすると

P(i) \underset{\alpha \rightarrow 0}{\longrightarrow} \frac{1}{N}

となり、一様な分布になる。これは Replay Bufferに対応する。

 次にデータ$i$ における優先度を、$\epsilon$を微小で正の値として、以下のように定義する。

p_i = |\delta_i| + \epsilon

$\delta_i$はTD誤差である。TD誤差は、目標値と学習の推定値の差で定義される。

\delta_j = R_j + \gamma \underset{a}{\max} Q(s_j,a_j) -  Q(s_{j-1},a_{j-1})

目標値は$ R_j + \gamma \underset{a}{\max} Q(s_j,a_j)$であり、推定値は$Q(s_{j-1},a_{j-1})$である。

Importance sampling weight

 sampling transition を使い、サンプリングするとTD誤差が大きくなる部分だけ学習してしまう。学習が不安定にならないように、Importance sampling weightを以下のように定義する。

w_i = \left(\frac{1}{N} \frac{1}{P(i)} \right)^{\beta}

ハイパーパラメータを$\beta$は、最終的に$\beta=1$となるようにアニーリングする。

 Importance sampling weightを使い、学習パラメータを更新することを考える。学習パラメータ$\theta$は、勾配法により以下のように更新する。

\theta \longleftarrow \theta + \eta w_j \delta_j \nabla_{\theta}Q(s_{j-1},a_{j-1})

tensorflow などでニューラルネットを構築する場合は、誤差関数を以下のようにすればよい。

E(\theta) = \sum_j w_j \delta_j^2

誤差関数を$\theta$で微分すると

\nabla_{\theta}E(\theta) = -\sum_j w_j \delta_j \nabla_{\theta}Q(s_{j-1},a_{j-1})

となるので、教科書に書いてある勾配法の$\theta \longleftarrow \theta - \eta \nabla_{\theta}E(\theta) $ となることが分かる。

 DQNでは Huber loss が使われるので絶対値誤差の場合も計算しておく。

E(\theta) = \sum_j w_j |\delta_j|

誤差関数を$\theta$で微分すると

\nabla_{\theta}E(\theta) = -\sum_j w_j \rm{sgn}(\delta_j) \nabla_{\theta}Q(s_{j-1},a_{j-1})

となる。$ \rm{sgn}(\delta_j) $は、$\delta_j$の符号である。

ちなみに、 Huber loss は

E(a) = \left\{
\begin{array}{ll}
\frac{1}{2}a^2 & ,|a |< \delta \\
\delta(|a|-\frac{1}{2}\delta) & ,else
\end{array}
\right.

である。

勾配法の部分のコードは(たぶん)こんな感じになる。

# tensorflow2 のコードが1だったので
delta = 1.0
with tf.GradientTape() as tape:
  # Q-values を求める
    q_values = model(state_sample)
    # masksがゼロの部分のq_valuesをゼロにする。
    q_action = tf.reduce_sum(tf.multiply(q_values, masks), axis=1)
    # Huber lossの計算
    TD_errors = updated_q_values - q_action
    TD_loss_aqua = 0.5* tf.square(TD_errors)
    TD_loss_abs = delta*tf.math.abs(TD_errors)- 0.5*tf.square(delta)
    TD_loss = tf.where(tf.math.abs(TD_errors)<delta, TD_loss_aqua, TD_loss_abs) 
    Q_loss = tf.reduce_mean( weights_sample* TD_loss)

# Backpropagation
grads = tape.gradient(Q_loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

まとめ

学習はまだしていないので、次回あたりに。

12
11
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?