この記事は自作している強化学習フレームワーク SimpleDistributedRL の解説記事です。
また、強化学習の基礎についてはこちらを参考にしてください。
追記:22/7/2 全体的に見直して修正しました。また REINFORCE を追加しました。
方策勾配法の基礎
ベルマン方程式は以下でした。
$$
\begin{align}
V_{\pi}(s) &= \sum_{a} \pi(a|s) Q_{\pi}(s, a)
\end{align}
$$
マルコフ決定過程ではアクションが決まると次の状態の確率が決まりました。
という事はこの式において方策 $\pi$ が決まると全ての不確定要素が決まるので、期待値(価値)を求めることができます。
これを利用し、価値が最大となる方策を学習しようという考えが方策ベースのアルゴリズムとなります。
ではどうするかというと、方策を確率モデルとし、確率を決めるパラメータ $\theta$ を用いてこの価値が最大となるパラメータを探すという問題に置き換えます。
$$
J(\theta) \propto V_{\pi_{\theta}}(s) = \sum_{a} \pi_{\theta}(a|s) Q_{\pi_{\theta}}(s, a)
$$
$J(\theta)$ はエピソード長に比例するので比例関係を表す $\propto$ (proportional to) を使っています。
勾配法(最急降下法)
勾配法は最適化問題を解くアルゴリズムの1つで、ある関数を最小にするパラメータを求めるアルゴリズムです。
以下手順で最適解を探します。
- 初期地点 $x$ を決める
- $x$ での傾き(微分値)を求める
- 傾きと学習率をもとに次の探索地点 $x'$ を求める
- 2~3の更新を傾きが0となる $x$ が見つかるまで繰り返す
更新式は以下です。($\eta$ は学習率です)
$$x \leftarrow x - \eta f'(x)$$
具体例で見てみます。
$f(x)=x^2+1$ の関数に対して勾配法で最小値を求めてみます。
微分した関数は $f'(x)=2x$、学習率は0.25とします。
最適解は $f'(x)=2x=0$ の時なので $x=0$ です。
$x=2$ から見てみます。
\begin{align}
x &\leftarrow 2 - 0.25 \times 4 = 1 \\
x &\leftarrow 1 - 0.25 \times 2 = 0.5 \\
x &\leftarrow 0.5 - 0.25 \times 1 = 0.25 \\
x &\leftarrow 0.25 - 0.25 \times 0.5 = 0.125 \\
...
\end{align}
これを繰り返すと $x=0$ になりそうです。
こうやって最小値を求める手法が勾配法です。
方策勾配法の更新
方策を、勾配法を用いて目的関数 $J(\theta)$ を最大化するパラメータ $\theta$ を求める問題として解きます。
$$\theta \leftarrow \theta + \alpha \nabla J(\theta)$$
$\alpha$ が学習率で、$\nabla J(\theta)$ が $J(\theta)$ の微分です。
符号が $+$ になっているのは最大値を求めるためです。
次に $\nabla J(\theta)$ ですが、これは方策勾配定理を用いると以下になります。
$$
\nabla J(\theta) \propto E_{\pi_{\theta}} \Bigl[ \nabla \log \pi_{\theta}(a|s) Q_{\pi_{\theta}}(s,a) \Bigr]
$$
この定理に関しては導出が複雑らしく詳細は以下を参考にしてください。
- Reinforcement Learning: An Introduction(269ページ)
- 方策勾配定理のすっきりした証明
- 【強化学習入門】方策勾配定理の証明メモ 【Policy Gradient Theorem】
これを計算すればいいわけですが解析的には解けないので、実際にはサンプリングして近似解を得ることで更新します。(モンテカルロ法)
\begin{align}
\nabla J(\theta) &\propto E_{\pi_{\theta}} \Bigl[ \nabla \log \pi_{\theta}(a|s) Q_{\pi_{\theta}}(s,a) \Bigr] \\
& \approx \frac{1}{M} \sum_{m=0}^{M-1} \frac{1}{T} \sum_{t=0}^{T-1} \nabla \log \pi_{\theta}(a_{t,m}|s_{t,m}) Q_{\pi_{\theta}}(s_{t,m},a_{t,m})
\end{align}
$M$ がエピソード回数、$T$ 1エピソードの総ステップ数です。
ベースライン
更新パラメータの $\nabla J(\theta)$ ですが、分散が大きいと更新後の方策が大きく変わるので、分散が小さいほうが学習が安定します。
新たにベースライン $B(s)$ というものを用いて分散を小さくするテクニックを紹介します。
以下のように行動価値関数にベースラインを引くというものです。
\begin{align}
E_{\pi_{\theta}} \Big[ \nabla \log \pi_{\theta}(a|s) & Q_{\pi_{\theta}}(s,a) \Big] \\
↓ \\
E_{\pi_{\theta}} \Big[ \nabla \log \pi_{\theta}(a|s) & \Big( Q_{\pi_{\theta}} (s,a) - B(s) \Big) \Big]
\end{align}
ベースラインを引いても期待値に変化はありません。(分散のみ変わります)
変化がない事を、以下の式で右辺が0である事で確認してみました。
(数式に強くないので間違っていたらすいません)
E_{\pi_{\theta}} \Big[ \nabla \log \pi_{\theta}(a|s) \Big( Q_{\pi_{\theta}} (s,a) - B(s) \Big) \Big] = E_{\pi_{\theta}} \Big[ \nabla \log \pi_{\theta}(a|s) Q_{\pi_{\theta}}(s,a) \Big] -
E_{\pi_{\theta}} \Big[ \nabla \log \pi_{\theta}(a|s) B(s) \Big]
以下0の計算。
\begin{align}
E_{\pi_{\theta}} \Big[ \nabla \log \pi_{\theta}(a|s) B(s) \Big]
&= E_{s_{0:t},a_{0:t-1}} \Big[ E_{s_{t+1:T},a_{t:T-1}} \Big[ \nabla \log \pi_{\theta}(a_t|s_t) B(s_t) \Big] \Big] \\
&= E_{s_{0:t},a_{0:t-1}} \Big[ B(s_t) E_{s_{t+1:T},a_{t:T-1}} \Big[ \nabla \log \pi_{\theta}(a_t|s_t) \Big] \Big] \\
&= E_{s_{0:t},a_{0:t-1}} \Big[ B(s_t) E_{,a_t} \Big[ \nabla \log \pi_{\theta}(a_t|s_t) \Big] \Big] \\
&= E_{s_{0:t},a_{0:t-1}} \Big[ B(s_t) \times 0 \Big] \\
&= 0 \\
\end{align}
$B(s)$ が外にでるのは $a$ に依存しないからです。
また、$E_{a_t}$ は以下です。
\begin{align}
E_{a_t} \Big[ \nabla \log \pi_{\theta}(a_t|s_t) \Big]
&= \int \frac{\nabla_\theta \pi_\theta(a_t|s_t)}{\pi_\theta(a_t|s_t)} \pi_\theta(a_t|s_t) da_t \\
&= \nabla_\theta \int \pi_\theta(a_t|s_t) da_t \\
&= \nabla_\theta 1 \\
&= 0 \\
\end{align}
ベースラインは状態 $s$ で決まる値ならどんな値でも問題ない事が分かります。
参考
・強化学習理論の基礎3
・Policy Gradient Algorithms
・Going Deeper Into Reinforcement Learning: Fundamentals of Policy Gradients
・https://yagami12.hatenablog.com/entry/2019/02/22/210608
REINFORCE
更新に必要なQ値ですが、REINFORCEではこれを報酬のみで近似します。
\begin{align}
Q_{\pi_{\theta}}(s_t,a_t) &= r_{t+1} + Q_{\pi_{\theta}}(s_{t+1},a_{t+1}) \\
& \approx r_{t+1}\\
\end{align}
またベースラインをモンテカルロ法で得た報酬の平均値にします。
\frac{1}{M} \sum_{m=0}^{M-1} \frac{1}{T} \sum_{t=0}^{T-1} \nabla \log \pi_{\theta}(a_{t,m}|s_{t,m}) Q_{\pi_{\theta}}(s_{t,m},a_{t,m}) \\
↓ \\
\frac{1}{M} \sum_{m=0}^{M-1} \frac{1}{T} \sum_{t=0}^{T-1} \nabla \log \pi_{\theta}(a_{t,m}|s_{t,m})(r_{t,m} - \bar{b}) \\
$$
\bar{b} = \frac{1}{M} \sum_{m=0}^{M-1} \frac{1}{T} \sum_{t=0}^{T-1} r_{t,m}
$$
Advantage
REINFORCEではベースラインに報酬の平均値を使いましたが、ベースラインに状態価値を用いた価値をAdvantageといいます。
B(s) = V(s) \\
A(s,a) = Q(s,a) - V(s)
E_{\pi_{\theta}} \Big[ \nabla \log \pi_{\theta}(a|s) \Big( Q_{\pi_{\theta}} (s,a) - B(s) \Big) \Big] \\
= E_{\pi_{\theta}} \Big[ \nabla \log \pi_{\theta}(a|s) A(s,a) \Big] \\
別の視点ですが、考え方はRainbowのDueling networksと同じですね。(Advantageの意味についてはリンク先の記事を参照)
方策には状態価値は直接必要ないので、Advantage関数を使用するのは理にかなっている気がします。
方策モデル
最後に実装に向けて $\nabla \log \pi_{\theta}(a|s)$ に関する具体的なモデルを見てみます。
Softmax(離散値)
Softmax関数は以下で、パラメータは各アクション毎に用意されます。
$$
\pi_{\theta}(a_i|s) = f(\theta_i) = \frac{ e^{\theta_i}}{\sum_{k=1}^n e^{\theta_k}} \quad(i=1,2,...,n)
$$
方策勾配定理で必要になる対数微分を計算します。
計算過程
\begin{align}
\nabla \log( f(x_i)) &= \nabla \log(\frac{e^{x_i}}{\sum_{k=1}^n e^{x_k}} ) \\
&= \nabla ( \log(e^{x_i}) - \log(\sum_{k=1}^n e^{x_k}))
\\
&= \nabla x_i - \nabla \log(\sum_{k=1}^n e^{x_k})
\\
&= 1 - \frac{\nabla \sum_{k=1}^n e^{x_k}}{\sum_{k=1}^n e^{x_k}}
\\
&= 1 - \frac{e^{x_i}}{\sum_{k=1}^n e^{x_k}}
\\
&= 1 - f(x_i)
\\
\end{align}
3から4行目の変換は微分の連鎖率を利用しています。
$$
\nabla \log f(x) = \frac{\nabla f(x)}{f(x)}
$$
参考:https://tadaoyamaoka.hatenablog.com/entry/2019/08/13/000546
$$
\nabla \log( \pi_{\theta}(a_i|s) ) = 1 - \pi_{\theta}(a_i|s) \quad(i=1,2,...,n)
$$
ガウス分布(連続値)
ガウス分布は以下で、パラメータは平均 $\mu$ と分散 $\sigma^2$ の2つとなります。
$$
\begin{align}
\pi_{\theta}(a|s) = f(a) &= \frac{1}{\sqrt{2 \pi \sigma^2 } }
\exp(- \frac{(a - \mu)^2}{ 2 \sigma^2} )
\end{align}
$$
対数微分ですが、パラメータ毎の偏微分になるので平均と分散を別々で計算します。
- 平均
計算過程
\begin{align}
\log(f(x)) &= \log (\frac{1}{\sqrt{2 \pi \sigma^2 } }
\exp(- \frac{(x - \mu)^2}{ 2 \sigma^2} ) ) \\
&= \log (\frac{1}{\sqrt{2 \pi \sigma^2 } }) + \log(\exp(- \frac{(x - \mu)^2}{ 2 \sigma^2} ) )
\\
&= log (1) - log(\sqrt{2 \pi \sigma^2 }) - \frac{(x - \mu)^2}{ 2 \sigma^2}
\\
&= 0 - log((2 \pi \sigma^2)^{ \frac{1}{2}}) - \frac{(x - \mu)^2}{ 2 \sigma^2}
\\
&= -\frac{1}{2} log(2 \pi \sigma^2) - \frac{(x - \mu)^2}{ 2 \sigma^2}
\\
\frac{\partial}{\partial \mu} \log(f(x)) &= \frac{\partial}{\partial \mu}(-\frac{1}{2} log(2 \pi \sigma^2)) - \frac{\partial}{\partial \mu}(\frac{(x - \mu)^2}{ 2 \sigma^2}))
\\
&= 0 - \frac{\partial}{\partial \mu}\frac{(x - \mu)^2}{ 2 \sigma^2} \\
&= \frac{2(x - \mu)}{ 2 \sigma^2} \\
&= \frac{x - \mu}{\sigma^2} \\
\end{align}
$$
\frac{\partial}{\partial \mu} \log( \pi_{\theta}(a|s) ) = \frac{a - \mu}{\sigma^2}
$$
- 分散
計算過程
\begin{align}
\frac{\partial}{\partial \sigma} \log(f(x)) &= \frac{\partial}{\partial \sigma} (-\frac{1}{2} log(2 \pi \sigma^2) - \frac{(x - \mu)^2}{ 2 \sigma^2})
\\
&= \frac{\partial}{\partial \sigma}(-\frac{1}{2} ( log(2 \pi) + \log(\sigma^2))) - \frac{\partial}{\partial \sigma}\frac{(x - \mu)^2}{ 2 \sigma^2}
\\
&= \frac{\partial}{\partial \sigma}(-\frac{1}{2} log(2 \pi)) - \frac{\partial}{\partial \sigma}(\frac{1}{2} \log(\sigma^2)) - \frac{\partial}{\partial \sigma}\frac{(x - \mu)^2}{ 2 \sigma^2}
\\
&= 0 - \frac{\partial}{\partial \sigma}(\log(\sigma)) - \frac{\partial}{\partial \sigma}\frac{(x - \mu)^2}{ 2 \sigma^2}
\\
&= -\frac{1}{\sigma} - \frac{-2(x - \mu)^2}{ 2 \sigma^3}
\\
&= -\frac{1}{\sigma} + \frac{(x - \mu)^2}{ \sigma^3}
\end{align}
$$
\frac{\partial}{\partial \sigma} \log( \pi_{\theta}(a|s) ) = \frac{(a - \mu)^2 - \sigma^2}{ \sigma^3}
$$
実装は予測値をモンテカルロ法(エピソード最後まで展開)で求めます。
$$
Q_{\pi_{\theta}}(s_t, a_t) = \sum(r_{t+1} + \gamma r_{t+2} + \gamma^2 r_{t+3} +...)
$$
ベースラインは引いていません。
またサンプリングの回数は1回毎としています。
(バニラな方策勾配法です)
実装(Softmax)
関係ある箇所を抜粋して書いています。
フレームワーク上の実装はgithubを見てください。
Config(ハイパーパラメータ)
ハイパーパラメータは以下です。
@dataclass
class Config(TableConfig):
gamma: float = 0.9 # 割引率
lr: float = 0.1 # 学習率
Parameter
ポリシーパラメータは状態毎・アクション毎に必要です。
状態は文字列にして柔軟に格納できるようにしています。
class Parameter(RLParameter):
def __init__(self, *args):
# ポリシーパラメータ
self.policy = {}
# 行動の各確率を返す
def get_probs(self, state: str):
# 状態がない場合、パラメータを初期化
if state not in self.policy:
self.policy[state] = [0.0 for _ in range(self.config.nb_actions)]
probs = np.array(self.policy[state])
# softmax
probs = np.exp(probs)
probs /= np.sum(probs)
return probs
Trainer
学習部分です。
class Trainer(RLTrainer):
def train(self):
batchs = self.remote_memory.sample()
for batch in batchs:
# 1stepで得た経験
state = batch["state"]
action = batch["action"]
reward = batch["reward"]
prob = self.parameter.get_probs(state)[action]
# ∇logπ = 1 - prob
diff_logpi = 1 - prob
# ∇J = ∇logπ Q
diff_j = diff_logpi * reward
# θ ← θ + α∇J
self.parameter.policy[state][action] += self.config.lr * diff_j
Worker
確率を元にアクションを決めます。
また、モンテカルロ法により価値を求めます。
class Worker(TableWorker):
def call_on_reset(self, state: np.ndarray, invalid_actions: List[int]) -> None:
self.history = [] # 1エピソード分の経験を保存
def call_policy(self, state: np.ndarray, invalid_actions: List[int]) -> int:
# 状態を文字列にする
self.state = str(state.tolist())
# 各アクションの確率
probs = self.parameter.get_probs(self.state)
# 確率を元にアクションを決定
self.action = np.random.choice([a for a in range(self.config.nb_actions)], p=probs)
return self.action
def call_on_step(
self,
next_state: np.ndarray,
reward: float,
done: bool,
next_invalid_actions: List[int],
) -> Dict[str, Union[float, int]]:
if not self.training:
return {}
# 各stepの経験を保存
self.history.append([self.state, self.action, reward])
# エピソード終了時に割引報酬を計算し、メモリに送る
if done:
reward = 0
for h in reversed(self.history):
reward = h[2] + self.config.gamma * reward
batch = {
"state": h[0],
"action": h[1],
"reward": reward,
}
self.remote_memory.add(batch)
return {}
実装(ガウス分布)
関係ある箇所を抜粋して書いています。
フレームワーク上の実装はgithubを見てください。
Config(ハイパーパラメータ)
ハイパーパラメータは以下です。
@dataclass
class Config(TableConfig):
gamma: float = 0.9 # 割引率
lr: float = 0.1 # 学習率
Parameter
パラメータは各状態に対して平均と標準偏差の線形値となります。
標準偏差は正 $\sigma > 0$ の制約を持つので、指数関数を通しています。
class Parameter(RLParameter):
def __init__(self, *args):
# ポリシーパラメータ
self.policy = {}
# パラメータを返す
def get_param(self, state: str):
# 状態がない場合、パラメータを初期化
if state not in self.policy:
self.policy[state] = {
"mean": 0.0,
"stddev_logits": 0.5,
}
mean = self.policy[state]["mean"]
stddev = np.exp(self.policy[state]["stddev_logits"])
return mean, stddev
Trainer
学習部分です。
logpiの計算部分を見ればわかりますが、値によっては更新幅がかなり大きな値になります。
(平均でσ^2、分散でσ^3の計算をしているので、例えばσが0.01の場合、分散はn/(0.01^3)の形になり、n=1の場合で1000000の値を取ります。)
勾配法の特性上、更新幅が大きすぎると学習が安定しません。
ですので、更新幅に制限を設けています。(値はヒューリスティックな値です)
class Trainer(RLTrainer):
def train(self):
batchs = self.remote_memory.sample()
for batch in batchs:
# 1stepで得た経験
state = batch["state"]
action = batch["action"]
reward = batch["reward"]
mean, stddev = self.parameter.get_param(state)
# 平均
mean_diff_logpi = (action - mean) / (stddev**2)
mean_diff_j = mean_diff_logpi * reward
new_mean = self.parameter.policy[state]["mean"] + self.config.lr * mean_diff_j
# 分散
stddev_diff_logpi = (((action - mean) ** 2) - (stddev**2)) / (stddev**3)
stddev_diff_j = stddev_diff_logpi * reward
new_stddev = self.parameter.policy[state]["stddev_logits"] + self.config.lr * stddev_diff_j
# 更新幅が大きすぎる場合は更新しない
if abs(mean_diff_j) < 1 and abs(stddev_diff_j) < 5:
self.parameter.policy[state]["mean"] = new_mean
self.parameter.policy[state]["stddev_logits"] = new_stddev
Worker
平均・分散を元にアクションを決めます。
また、予測値がないのでモンテカルロ法により価値を求めます。
class Worker(ActionContinuousWorker):
def call_on_reset(self, state: np.ndarray) -> None:
self.history = [] # 1エピソード分の経験を保存
def call_policy(self, state: np.ndarray) -> List[float]:
# 状態を文字列にする
self.state = str(state.tolist())
# パラメータ
mean, stddev = self.parameter.get_param(self.state)
# ガウス分布に従った乱数を出す
self.action = env_action = mean + np.random.normal() * stddev
# ガウス分布は-inf~infの範囲を取るので、
# 実際に環境に渡すアクションは、最小と最高で切り取る
# 本当はポリシーが変化しちゃうのでよくない(暫定対処)
env_action = np.clip(env_action, self.config.action_low[0], self.config.action_high[0])
return env_action
def call_on_step(
self,
next_state: np.ndarray,
reward: float,
done: bool,
next_invalid_actions: List[int],
) -> Dict:
if not self.training:
return {}
# 各stepの経験を保存
self.history.append([self.state, self.action, reward])
# エピソード終了時に割引報酬を計算し、メモリに送る
if done:
reward = 0
for h in reversed(self.history):
reward = h[2] + self.config.gamma * reward
batch = {
"state": h[0],
"action": h[1],
"reward": reward,
}
self.remote_memory.add(batch)
return {}
実行結果(softmax)
本フレームワーク内にある Grid という環境を学習した結果は以下です。
※srlが古いバージョンの時に記載しているので動かない可能性があります
import numpy as np
import srl
from srl import runner
# --- env & algorithm load
from srl.envs import grid # isort: skip # noqa F401
from srl.algorithms import vanilla_policy_discrete # isort: skip
def main():
env_config = srl.EnvConfig("Grid")
rl_config = vanilla_policy_discrete.Config()
config = runner.Config(env_config, rl_config)
# --- train
parameter, remote_memory, history = runner.train(config, timeout=10)
# --- evaluate
rewards = runner.evaluate(config, parameter, max_episodes=100)
print(f"Average reward for 100 episodes: {np.mean(rewards)}")
# --- render
rewards, _ = runner.render(config, parameter)
if __name__ == "__main__":
main()
- 学習結果(一部)
### 最初の地点
......
. G.
. . X.
.P .
......
← : 0.1% (-2.72310)
↓ : 1.5% (0.31042)
→ : 0.0% (-4.81747)
*↑ : 98.4% (4.46286)
### ゴール手前
......
. PG.
. . X.
. .
......
← : 0.0% (-0.12817)
↓ : 0.0% (-0.31378)
*→ : 98.6% (24.20097)
↑ : 1.4% (19.93994)
学習結果(全体)
Average reward for 100 episodes: 0.7299999999999999
### 0, action 3, rewards [0.], next 0
env None
work0 None
......
. G.
. . X.
.P .
......
← : 0.1% (-2.72310)
↓ : 1.5% (0.31042)
→ : 0.0% (-4.81747)
*↑ : 98.4% (4.46286)
### 1, action 3, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
.P. X.
. .
......
← : 0.0% (-0.91846)
↓ : 0.0% (-2.24249)
→ : 1.3% (4.22557)
*↑ : 98.7% (8.58121)
### 2, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
.P G.
. . X.
. .
......
← : 0.0% (-0.70717)
↓ : 0.0% (-0.18141)
*→ : 98.7% (13.16927)
↑ : 1.3% (8.80523)
### 3, action 3, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
.P. X.
. .
......
← : 0.0% (-0.91846)
↓ : 0.0% (-2.24249)
→ : 1.3% (4.22557)
*↑ : 98.7% (8.58121)
### 4, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
.P G.
. . X.
. .
......
← : 0.0% (-0.70717)
↓ : 0.0% (-0.18141)
*→ : 98.7% (13.16927)
↑ : 1.3% (8.80523)
### 5, action 3, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
.P. X.
. .
......
← : 0.0% (-0.91846)
↓ : 0.0% (-2.24249)
→ : 1.3% (4.22557)
*↑ : 98.7% (8.58121)
### 6, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
.P G.
. . X.
. .
......
← : 0.0% (-0.70717)
↓ : 0.0% (-0.18141)
*→ : 98.7% (13.16927)
↑ : 1.3% (8.80523)
### 7, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. P G.
. . X.
. .
......
← : 0.0% (-0.62838)
↓ : 0.5% (4.92942)
*→ : 99.5% (10.28625)
↑ : 0.1% (2.98535)
### 8, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. PG.
. . X.
. .
......
← : 0.0% (-0.12817)
↓ : 0.0% (-0.31378)
*→ : 98.6% (24.20097)
↑ : 1.4% (19.93994)
### 9, action 2, rewards [1.], done(env), next 0
env {}
work0 {}
......
. P.
. . X.
. .
......
実行結果(ガウス分布)
本フレームワーク内にある Grid という環境を学習した結果は以下です。
Gridの入力は離散値なのですが、強制的に連続値に変換しています。
なので、
0~0.5 : 0 (LEFT)
0.5~1.5 : 1 (DOWN)
1.5~2.5 : 2 (RIGHT)
2.5~3.0 : 3 (UP)
となります。
※srlが古いバージョンの時に記載しているので動かない可能性があります
import numpy as np
import srl
from srl import runner
# --- env & algorithm load
from srl.envs import grid # isort: skip # noqa F401
from srl.algorithms import vanilla_policy_continuous # isort: skip
def main():
env_config = srl.EnvConfig("Grid")
rl_config = vanilla_policy_continuous.Config()
config = runner.Config(env_config, rl_config)
# --- train
parameter, remote_memory, history = runner.train(config, timeout=30)
# --- evaluate
rewards = runner.evaluate(config, parameter, max_episodes=100)
print(f"Average reward for 100 episodes: {np.mean(rewards)}")
# --- render
rewards, _ = runner.render(config, parameter)
if __name__ == "__main__":
main()
- 学習結果(一部)
### 最初
......
. G.
. . X.
.P .
......
mean 7.02388, stddev 0.00477
### ゴール手前
......
. PG.
. . X.
. .
......
mean 2.01022, stddev 0.02746
学習結果(全体)
Average reward for 100 episodes: 0.6944000054895878
### 0, action 3, rewards [0.], next 0
env None
work0 None
......
. G.
. . X.
.P .
......
mean 7.02388, stddev 0.00477
### 1, action 0, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
mean -1.02970, stddev 0.01026
### 2, action 3, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
.P .
......
mean 7.02388, stddev 0.00477
### 3, action 3, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
.P. X.
. .
......
mean 9.90385, stddev 0.01101
### 4, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
.P G.
. . X.
. .
......
mean 1.63379, stddev 0.00830
### 5, action 3, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
.P. X.
. .
......
mean 9.90385, stddev 0.01101
### 6, action 3, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
.P. X.
. .
......
mean 9.90385, stddev 0.01101
### 7, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
.P G.
. . X.
. .
......
mean 1.63379, stddev 0.00830
### 8, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. P G.
. . X.
. .
......
mean 1.99982, stddev 0.00659
### 9, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. P G.
. . X.
. .
......
mean 1.99982, stddev 0.00659
### 10, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. P G.
. . X.
. .
......
mean 1.99982, stddev 0.00659
### 11, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. P G.
. . X.
. .
......
mean 1.99982, stddev 0.00659
### 12, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. PG.
. . X.
. .
......
mean 2.01022, stddev 0.02746
### 13, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. PG.
. . X.
. .
......
mean 2.01022, stddev 0.02746
### 14, action 2, rewards [1.], done(env), next 0
env {}
work0 {}
......
. P.
. . X.
. .
......
REINFORCEの実装
試してみましたが、Q値を即時報酬のみで近似するのはさすがに無理があるような…
割引報酬がないので次の状態しか学習できていません。
### 学習後の初期位置
......
. G.
. . X.
.P .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 学習後の右下
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
コード
※srlが古いバージョンの時に記載しているので動かない可能性があります
import json
from dataclasses import dataclass
from typing import List, cast
import numpy as np
import srl
from srl import runner
from srl.base.define import RLObservationType
from srl.base.rl.algorithms.discrete_action import (DiscreteActionConfig,
DiscreteActionWorker)
from srl.base.rl.base import RLParameter, RLTrainer
from srl.base.rl.registration import register
from srl.base.rl.remote_memory.sequence_memory import SequenceRemoteMemory
@dataclass
class Config(DiscreteActionConfig):
lr: float = 0.1
train_steps: int = 1
def __post_init__(self):
super().__init__()
@property
def observation_type(self) -> RLObservationType:
return RLObservationType.DISCRETE
@staticmethod
def getName() -> str:
return "REINFORCE_discrete"
class RemoteMemory(SequenceRemoteMemory):
pass
class Parameter(RLParameter):
def __init__(self, *args):
super().__init__(*args)
self.config = cast(Config, self.config)
# パラメータ
self.policy = {}
def call_restore(self, data, **kwargs) -> None:
self.policy = json.loads(data)
def call_backup(self, **kwargs):
return json.dumps(self.policy)
# ---------------------------------
def get_probs(self, state_str: str):
if state_str not in self.policy:
self.policy[state_str] = [0.0 for _ in range(self.config.action_num)]
probs = []
for val in self.policy[state_str]:
probs.append(np.exp(val))
probs /= np.sum(probs)
return probs
class Trainer(RLTrainer):
def __init__(self, *args):
super().__init__(*args)
self.config = cast(Config, self.config)
self.parameter = cast(Parameter, self.parameter)
self.remote_memory = cast(RemoteMemory, self.remote_memory)
self.train_count = 0
def get_train_count(self):
return self.train_count
def train(self):
if self.remote_memory.length() < self.config.train_steps:
return {}
batchs = self.remote_memory.sample()
# ベースラインは報酬の平均
baseline = np.mean([b["reward"] for b in batchs])
# 状態,アクション毎の平均値を計算
target = {}
for batch in batchs:
state = batch["state"]
action = batch["action"]
reward = batch["reward"]
if state not in target:
target[state] = {}
if action not in target[state]:
target[state][action] = []
prob = self.parameter.get_probs(state)[action]
# ∇logπ
diff_logpi = 1 - prob
# ∇J(θ) = ∇logπ (Q - b)
diff_j = diff_logpi * (reward - baseline)
target[state][action].append(diff_j)
# 更新
loss_list = []
for state, v in target.items():
for action, diff_j_list in v.items():
diff_j = np.mean(diff_j_list)
# ポリシー更新
self.parameter.policy[state][action] += self.config.lr * diff_j
loss_list.append(abs(diff_j))
self.train_count += 1
return {"loss": np.mean(loss_list)}
class Worker(DiscreteActionWorker):
def __init__(self, *args):
super().__init__(*args)
self.config = cast(Config, self.config)
self.parameter = cast(Parameter, self.parameter)
self.remote_memory = cast(RemoteMemory, self.remote_memory)
def call_on_reset(self, state: np.ndarray, invalid_actions: List[int]) -> None:
self.state = str(state.tolist())
self.invalid_actions = invalid_actions
self.history = []
def call_policy(self, state: np.ndarray, invalid_actions: List[int]) -> int:
self.state = str(state.tolist())
probs = self.parameter.get_probs(self.state)
action = np.random.choice([a for a in range(self.config.action_num)], p=probs)
self.action = int(action)
return action
def call_on_step(
self,
next_state: np.ndarray,
reward: float,
done: bool,
next_invalid_actions: List[int],
):
if not self.training:
return {}
self.remote_memory.add(
{
"state": self.state,
"action": self.action,
"reward": reward,
}
)
return {}
def render_terminal(self, env, worker, **kwargs) -> None:
probs = self.parameter.get_probs(self.state)
vals = [0 if v is None else v for v in self.parameter.policy[self.state]]
maxa = np.argmax(vals)
for a in range(self.config.action_num):
if maxa == a:
s = "*"
else:
s = " "
s += f"{env.action_to_str(a)}: {probs[a]*100:5.1f}% ({vals[a]:.5f})"
print(s)
register(
Config,
__name__ + ":RemoteMemory",
__name__ + ":Parameter",
__name__ + ":Trainer",
__name__ + ":Worker",
)
# =----------------------------------
def main():
from envs import grid
env_config = srl.EnvConfig("Grid")
rl_config = Config()
config = runner.Config(env_config, rl_config)
# --- train
parameter, remote_memory, history = runner.train(config, max_episodes=500)
# --- eval
rewards = runner.evaluate(config, parameter, max_episodes=100)
print(f"Average reward for 100 episodes: {np.mean(rewards)}")
# --- rendering
runner.render(config, parameter)
if __name__ == "__main__":
main()
実行結果(すべて)
Average reward for 100 episodes: -1.1995999736338854
### 0, action 3, rewards [0.], next 0
env None
work0 None
......
. G.
. . X.
.P .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 1, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
.P. X.
. .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 2, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
.P. X.
. .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 3, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
.P. X.
. .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 4, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
.P. X.
. .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 5, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
.P .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 6, action 0, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 7, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
.P .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 8, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
.P. X.
. .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 9, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
.P .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 10, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 11, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 12, action 0, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 13, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
*←: 25.0% (0.00000)
↓: 25.0% (0.00000)
→: 25.0% (0.00000)
↑: 25.0% (0.00000)
### 14, action 3, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
←: 9.5% (0.08017)
↓: 9.2% (0.04137)
→: 32.2% (1.29974)
*↑: 49.1% (1.72022)
### 15, action 0, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. .PX.
. .
......
*←: 59.8% (-0.09895)
↓: 30.7% (-0.76681)
→: 4.5% (-2.68017)
↑: 5.0% (-2.58166)
### 16, action 0, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. .PX.
. .
......
*←: 59.8% (-0.09895)
↓: 30.7% (-0.76681)
→: 4.5% (-2.68017)
↑: 5.0% (-2.58166)
### 17, action 0, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. .PX.
. .
......
*←: 59.8% (-0.09895)
↓: 30.7% (-0.76681)
→: 4.5% (-2.68017)
↑: 5.0% (-2.58166)
### 18, action 0, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. .PX.
. .
......
*←: 59.8% (-0.09895)
↓: 30.7% (-0.76681)
→: 4.5% (-2.68017)
↑: 5.0% (-2.58166)
### 19, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. .PX.
. .
......
*←: 59.8% (-0.09895)
↓: 30.7% (-0.76681)
→: 4.5% (-2.68017)
↑: 5.0% (-2.58166)
### 20, action 3, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
←: 9.5% (0.08017)
↓: 9.2% (0.04137)
→: 32.2% (1.29974)
*↑: 49.1% (1.72022)
### 21, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. .PX.
. .
......
*←: 59.8% (-0.09895)
↓: 30.7% (-0.76681)
→: 4.5% (-2.68017)
↑: 5.0% (-2.58166)
### 22, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
←: 9.5% (0.08017)
↓: 9.2% (0.04137)
→: 32.2% (1.29974)
*↑: 49.1% (1.72022)
### 23, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 24, action 0, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 25, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
←: 9.5% (0.08017)
↓: 9.2% (0.04137)
→: 32.2% (1.29974)
*↑: 49.1% (1.72022)
### 26, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 27, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 28, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 29, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 30, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 31, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 32, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 33, action 3, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
←: 9.5% (0.08017)
↓: 9.2% (0.04137)
→: 32.2% (1.29974)
*↑: 49.1% (1.72022)
### 34, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 35, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 36, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 37, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 38, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 39, action 3, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 40, action 3, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
←: 9.5% (0.08017)
↓: 9.2% (0.04137)
→: 32.2% (1.29974)
*↑: 49.1% (1.72022)
### 41, action 0, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. .PX.
. .
......
*←: 59.8% (-0.09895)
↓: 30.7% (-0.76681)
→: 4.5% (-2.68017)
↑: 5.0% (-2.58166)
### 42, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. .PX.
. .
......
*←: 59.8% (-0.09895)
↓: 30.7% (-0.76681)
→: 4.5% (-2.68017)
↑: 5.0% (-2.58166)
### 43, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
←: 9.5% (0.08017)
↓: 9.2% (0.04137)
→: 32.2% (1.29974)
*↑: 49.1% (1.72022)
### 44, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 45, action 1, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 46, action 0, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P.
......
←: 9.7% (-0.83343)
*↓: 57.7% (0.95076)
→: 30.4% (0.30896)
↑: 2.2% (-2.31935)
### 47, action 3, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. . X.
. P .
......
←: 9.5% (0.08017)
↓: 9.2% (0.04137)
→: 32.2% (1.29974)
*↑: 49.1% (1.72022)
### 48, action 0, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. .PX.
. .
......
*←: 59.8% (-0.09895)
↓: 30.7% (-0.76681)
→: 4.5% (-2.68017)
↑: 5.0% (-2.58166)
### 49, action 2, rewards [-0.04], next 0
env {}
work0 {}
......
. G.
. .PX.
. .
......
*←: 59.8% (-0.09895)
↓: 30.7% (-0.76681)
→: 4.5% (-2.68017)
↑: 5.0% (-2.58166)
### 50, action 2, rewards [-1.], done(env), next 0
env {}
work0 {}
......
. G.
. . P.
. .
......