1. ベルマン方程式
任意の時刻 $t$ において選択可能な行動 $a_t$ は、その時の状態 $x_t$ の関数であるため、$a_t \in \Gamma(x_t)$と書ける。 時刻 $t$ で行動 $a_t$をとったとき、時刻 $t$+1 における状態への遷移確率は $\pi(a_t \mid x_{t}) $で与えられるとする。この時、得られる利潤を $r_t = r(x_t, a_t)$ とする。割引率を $\gamma \in (0,1)$ とすると、時刻 $t$ における報酬関数 $R_t$ は、以下のように書ける。
R_t = \sum_{j=0}^{T} \gamma^{j} r_{t+j+1}
簡単のため、以下のように記載することにする。
\begin{eqnarray}
r(x, a, x') &=& \Bbb{E} \left[ r_{t+1} \mid x_t=x, a_t =a, x_{t+1}=x' \right]
\\\
\mathcal{p}(x'\mid x, a)
&=& \Bbb{P} \left[ x_{t+1}=x'\mid x_t=x, a_t =a \right]
\end{eqnarray}
価値関数 $V=V(x)$、及び行動価値関数 $Q=Q(x,a)$ を報酬関数の期待値として定義する。
\begin{eqnarray}
V_{\pi}(x) &=& \Bbb{E}_{\pi} \left[ R_t \mid x=x_t \right]
\\\
&=& \Bbb{E}_{\pi} \left[ \sum_{j=0}^{T} \gamma^{j} r_{t+j+1}\mid x_t=x \right]
\\\
Q_{\pi}(x, a) &=& \mathbb{E}_{\pi} \left[ R_t \mid x_t=x ,~ a_t = a \right]
\\\
&=& \Bbb{E}_{\pi} \left[ \sum_{j=0}^{T} \gamma^{j} r_{t+j+1}\mid x_t=x ,~ a_t = a \right]
\end{eqnarray}
ここで、価値関数 $V=V(x)$ の定義から以下のように書き下すことができる。
\begin{eqnarray}
V_{\pi}(x) &=& \Bbb{E}_{\pi} \left[ R_t \mid x=x_t \right]
\\\
&=&
\Bbb{E}_{\pi} \left[ r_{t+1} + \gamma \,R_{t+1} \mid x=x_t \right]
\\\
&=&
\sum_{a \in \Gamma(x)} \pi(a \mid x)
\sum_{x'} \mathcal{p}(x' \,|\, x, ~a )\, r(x, a, x')
\\\
&+&
\gamma \, \sum_{a \in \Gamma(x)} \pi(a \mid x)
\sum_{x'} \mathcal{p}(x' \,|\, x, ~a )\, V_{\pi}(x')
\end{eqnarray}
同様にして、行動価値関数 $Q=Q(x,a)$ の定義から
\begin{eqnarray}
Q_{\pi}(x,a) &=& \Bbb{E}_{\pi} \left[ R_t \mid x=x_t , a=a_t \right]
\\\
&=&
\Bbb{E}_{\pi} \left[ r_{t+1} + \gamma \,R_{t+1} \mid
x=x_t, a=a_t \right]
\\\
&=&
\sum_{x'} \mathcal{p}(x' \,|\, x, ~a )\, r(x, a, x')
+
\gamma \,\sum_{x'} \mathcal{p}(x' \,|\, x, ~a )\, V_{\pi}(x')
\end{eqnarray}
を得る。
ここで、最適解を求めるため、価値関数 $V$ と $Q$ の最大値をとる。
\begin{eqnarray}
V_*(x) &=& \max_{\pi} V_{\pi}(x)
\\\
Q_*(x, a) &=& \max_{\pi} Q_{\pi}(x, a)
\end{eqnarray}
このとき、最適化された価値関数 $V$ と $Q$ は、以下の式で与えられる。
\begin{eqnarray}
V_{*}(x) &=& \max_{a\in \Gamma(x)} \sum_{x'} \mathcal{P}(x'\,|\, x, \,a) \left[ r(x,a) + \gamma V_{*}(x')
\right]
\\\
Q_{*}(x, a) &=& \sum_{x'} \mathcal{P}(x'\,|\, x, ~a )
\left[ \, r(x,a)
+ \gamma \max_{a'\in \Gamma(x')} Q_{*}(x',a') \, \right]
\end{eqnarray}
以下、簡単のため、マルコフ決定過程に従う状況、すなわち、状態の遷移確率 $\mathcal{p}(x'\mid x,~a)$ がマルコフ性を持つ場合を考える。このとき、価値関数は行動価値関数を最大化する行動を選択することで得られる。
V_{*}(x) = \max_{a\in \Gamma(x)} Q_{*}(x, a)
この式を上記の行動価値関数が従う方程式に代入すると、以下のベルマン方程式を得る。
Q_{*}(x,a) = r(x, a) + \gamma \max_{a'\in \Gamma(x')} Q_{*}(x',a')
2. Q学習
ベルマン方程式の解法の一つは、与えられた初期条件について、繰り返し、Q値を更新していく方法であり、この解法はQ学習と呼ばれている。具体的には、Q学習は下記のように書き下すことができる。
Q(x_t, a_t) = Q(x_t, a_t) + \alpha \left[ r_{t+1} + \gamma \max_{a\in \Gamma(x_{t+1})} Q(x_{t+1},a)
- Q(x_{t},a_{t}) \right]
ここで、$\alpha$ は学習率と呼ばれるパラメータで、以下の条件を満たすものである。
\sum_{t=0}^{\infty} Q(t) \rightarrow \infty
\\\
\sum_{t=0}^{\infty} Q(t)^2 < \infty
この条件を満たすとき、Q値は確率 1 で最適な評価値に収束することが示されている。
import gym
import collections
from torch.utils.tensorboard import SummaryWriter
ENV_NAME = "FrozenLake-v1"
GAMMA = 0.9
TEST_EPISODES = 20
REWARD_GOAL = 0.8
N =100
class Agent:
def __init__(self):
self.env = gym.make(ENV_NAME)
self.state = self.env.reset()
self.rewards = collections.defaultdict(float)
self.transits = collections.defaultdict(
collections.Counter)
self.values = collections.defaultdict(float)
def play_n_random_steps(self, count):
for _ in range(count):
action = self.env.action_space.sample()
new_state, reward, is_done, _ = self.env.step(action)
self.rewards[(self.state, action, new_state)] = reward
self.transits[(self.state, action)][new_state] += 1
if is_done:
self.state = self.env.reset()
else:
self.state = new_state
def calc_action_value(self, state, action):
target_counts = self.transits[(state, action)]
total = sum(target_counts.values())
action_value = 0.0
for tgt_state, count in target_counts.items():
reward = self.rewards[(state, action, tgt_state)]
val = reward + GAMMA * self.values[tgt_state]
action_value += (count / total) * val
return action_value
def select_action(self, state):
best_action, best_value = None, None
for action in range(self.env.action_space.n):
action_value = self.values[(state, action)]
if best_value is None or best_value < action_value:
best_value = action_value
best_action = action
return best_action
def value_iteration_for_Q(self):
for state in range(self.env.observation_space.n):
for action in range(self.env.action_space.n):
action_value = 0.0
target_counts = self.transits[(state, action)]
total = sum(target_counts.values())
for tgt_state, count in target_counts.items():
key = (state, action, tgt_state)
reward = self.rewards[key]
best_action = self.select_action(tgt_state)
val = reward + GAMMA * \
self.values[(tgt_state, best_action)]
action_value += (count / total) * val
self.values[(state, action)] = action_value
test_env = gym.make(ENV_NAME)
agent = Agent()
writer = SummaryWriter()
iter_no = 0
best_reward = 0.0
while best_reward < REWARD_GOAL:
agent.play_n_random_steps(N)
agent.value_iteration_for_Q()
iter_no += 1
reward_test = 0.0
for _ in range(TEST_EPISODES):
total_reward = 0.0
state = test_env.reset()
while True:
action = agent.select_action(state)
new_state, new_reward, is_done, _ = test_env.step(action)
total_reward += new_reward
if is_done: break
state = new_state
reward_test += total_reward
reward_test /= TEST_EPISODES
writer.add_scalar("reward", reward_test, iter_no)
if reward_test > best_reward:
print("Best reward updated %.2f at iteration %d " % (reward_test ,iter_no) )
best_reward = reward_test
writer.close()
3. モンテカルロ法による解法
import sys
import gym
import numpy as np
from collections import defaultdict
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
env = gym.make('Blackjack-v1')
state = env.reset()
while True:
print('state:', state)
action = env.action_space.sample()
print('action: ', action)
state, reward, done, info = env.step(action)
if done:
if reward > 0:
print('Reward: ', reward)
else:
print('Reward: ', reward)
break
def generate_episode(env):
episode = []
state = env.reset()
while True:
probs = [0.75, 0.25] if state[0] > 18 else [0.25, 0.75]
action = np.random.choice(np.arange(2), p=probs)
next_state, reward, done, info = env.step(action)
episode.append((state, action, reward))
state = next_state
if done:
break
return episode
episode = generate_episode(env)
states, actions, rewards = zip(*episode)
def mc_prediction(env, num_episodes, generate_episode, gamma=1.0):
returns_sum = defaultdict(lambda: np.zeros(env.action_space.n))
N = defaultdict(lambda: np.zeros(env.action_space.n))
Q = defaultdict(lambda: np.zeros(env.action_space.n))
for episode in range(1, num_episodes+1):
if episode % 10000 == 0: # monitor progress
print("\repisode {}/{}.".format(episode, num_episodes), end="")
sys.stdout.flush()
episode = generate_episode(env)
states, actions, rewards = zip(*episode)
discounts = np.array([gamma**i for i in range(len(rewards)+1)])
for i, state in enumerate(states):
returns_sum[state][actions[i]] += sum(rewards[i:]*discounts[:-(1+i)])
N[state][actions[i]] += 1.0
Q[state][actions[i]] = returns_sum[state][actions[i]] / N[state][actions[i]]
return Q
num_episodes=10000
Q = mc_prediction(env, num_episodes, generate_episode)
State_Value_table={}
for state, actions in Q.items():
State_Value_table[state]= (state[0]>18)*(np.dot([0.75, 0.25],actions)) + (state[0]<=18)*(np.dot([0.75, 0.25],actions))
def plot_blackjack_values(V):
def get_Z(x, y, usable_ace):
if (x,y,usable_ace) in V:
return V[x,y,usable_ace]
else:
return 0
def get_figure(usable_ace, ax):
x_range = np.arange(11, 22)
y_range = np.arange(1, 11)
X, Y = np.meshgrid(x_range, y_range)
Z = np.array([get_Z(x,y,usable_ace) for x,y in zip(np.ravel(X), np.ravel(Y))]).reshape(X.shape)
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=plt.cm.coolwarm, vmin=-1.0, vmax=1.0)
ax.set_xlabel('Player\'s Current Sum')
ax.set_ylabel('Dealer\'s Showing Card')
ax.set_zlabel('State Value')
ax.view_init(ax.elev, -120)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(211, projection='3d')
ax.set_title('Usable Ace')
get_figure(True, ax)
ax = fig.add_subplot(212, projection='3d')
ax.set_title('No Usable Ace')
get_figure(False, ax)
plt.show()
plot_blackjack_values(State_Value_table)