LoginSignup
19
19

More than 5 years have passed since last update.

強化学習 : SarsaとQ学習

Last updated at Posted at 2018-03-25

この記事は、強化学習の代表的なアルゴリズムであるSarsaとQ学習、2つのアルゴリズムの違いについてのまとめです。

SarsaとQ学習の違いは一言で言うと、Q関数の更新にエージェントのポリシー(方策)を用いるかどうか、です。まず、ポリシーとQ関数の関係について確認しておきます。

ポリシーとは、ある状態$s$が与えられたときにとる行動を返す関数$\pi(s)$です。次に得る利得が最も高くなりそうなポリシーは、基本的に状態$s$と行動$a$によって定まるQ関数$Q(s,a)$を$a$について最大化するようなものでしょう (greedy法):

$$\pi (s) := \rm{arg} \max_a Q(s,a)
$$

しかし、このポリシーではQ関数の初期値によってはまだとったことのない行動をとらないまま局所解に陥ってしまうかもしれません。そこで、例えば確率$\varepsilon$ですべての行動から一様にランダムに行動を決める、というポリシーも考えられます (ε-greedy法)。これは、$[0,1)$の一様乱数$x$を使って、

\pi'(s) := 
\left\{
\begin{array}{ll}
{\rm arg}\max_a Q(s,a) & {\rm if} \ x > \varepsilon \\
{\rm select} \ a \ {\rm randomly} & {\rm else}
\end{array}
\right.

ということです。この他にも、様々な観点からみた良いポリシーが多くあります。ポイントは、Q関数にもとづいてどのようなポリシーをとるかはアルゴリズムの設計者が決めるということです。

これを踏まえたうえで、SarsaとQ学習の違いを具体的に見ていきます。Sarsaでは現在の状態$S$とそこでとる行動$A$、行動によって得られた利得$R$、次の状態$S'$が分かったときに、現在の状態と行動についてのQ関数$Q(S,A)$の更新を、次の状態$S'$と$S'$が与えられたときポリシーによって決めたエージェントがとる行動$A'$によって決まるQ関数$Q(S',A')$を使って行います。$(S,A,R,S',A')$の組を使ってQ関数を更新するので、Sarsaと呼ばれます。アルゴリズムの疑似コードは次のようになります( [1] から引用)

スクリーンショット 2018-03-25 17.30.35.jpg

一方、Q学習では$(S,A,R,S')$が分かったときのQ関数の更新に、$S'$が与えられた時のQ関数の最大値を使います。アルゴリズムの疑似コードは次のようになります( [1] から引用)
スクリーンショット 2018-03-25 17.30.52.jpg

疑似コードから、もしポリシーにgreedy法を採用した場合、SarsaとQ学習は一致することが分かります。

実装の例として、[2] の記事のCardPoleの学習をQ学習で行ったコードをSarsaに変更し、jupyter notebookで実行しやすいように改変したものを以下にのせておきます。

import matplotlib.pyplot as plt

# Imports specifically so we can render outputs in Jupyter.
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display

import gym
import numpy as np
import random
import math
from time import sleep

def display_frames_as_gif(frames):
    """
    Displays a list of frames as a gif, with controls
    """
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(plt.gcf(), animate, frames = len(frames), interval=50)
    display(display_animation(anim, default_mode='loop'))



## Initialize the "Cart-Pole" environment
env = gym.make('CartPole-v0')

## Defining the environment related constants

# Number of discrete states (bucket) per state dimension
NUM_BUCKETS = (1, 1, 6, 3)
# Number of discrete actions
NUM_ACTIONS = env.action_space.n
# Bounds for each discrete state
STATE_BOUNDS = list(zip(env.observation_space.low, env.observation_space.high))
STATE_BOUNDS[3] = [-math.radians(50), math.radians(50)]
# Index of the action
ACTION_INDEX = len(NUM_BUCKETS)


## Creating a Q-Table for each state-action pair
q_table = np.zeros(NUM_BUCKETS + (NUM_ACTIONS,))

## Learning related constants
MIN_EXPLORE_RATE = 0.01
MIN_LEARNING_RATE = 0.1

## Defining the simulation related constants
NUM_EPISODES = 1000
MAX_T = 250
STREAK_TO_END = 120
SOLVED_T = 199

def simulate():

    ## Instantiating the learning related parameters
    learning_rate = get_learning_rate(0)
    explore_rate = get_explore_rate(0)
    discount_factor = 0.99  # since the world is unchanging

    num_streaks = 0

    for episode in range(NUM_EPISODES):

        # Reset the environment
        obv = env.reset()

        # the initial state
        state = state_to_bucket(obv)
        action = select_action(state, explore_rate)

        for t in range(MAX_T):
            # Execute the action
            obv, reward, done, _ = env.step(action)

            # Observe the result
            next_state = state_to_bucket(obv)
            next_action = select_action(next_state, explore_rate)

            # Update the Q based on the result
            next_q = q_table[next_state+(next_action,)]
            q_table[state + (action,)] += learning_rate*(reward + discount_factor*(next_q) - q_table[state + (action,)])

            # Setting up for the next iteration
            state = next_state
            action = next_action

            if done:
                print("Episode %d finished after %f time steps" % (episode, t))
                if (t >= SOLVED_T):
                    num_streaks += 1
                else:
                    num_streaks = 0
                break

        # It's considered done when it's solved over 120 times consecutively
        if num_streaks > STREAK_TO_END:
            break

        # Update parameters
        explore_rate = get_explore_rate(episode)
        learning_rate = get_learning_rate(episode)

def select_action(state, explore_rate):
    # Select a random action
    if random.random() < explore_rate:
        action = env.action_space.sample()
    # Select the action with the highest q
    else:
        action = np.argmax(q_table[state])
    return action


def get_explore_rate(t):
    return max(MIN_EXPLORE_RATE, min(1, 1.0 - math.log10((t+1)/25)))

def get_learning_rate(t):
    return max(MIN_LEARNING_RATE, min(0.5, 1.0 - math.log10((t+1)/25)))

def state_to_bucket(state):
    bucket_indice = []
    for i in range(len(state)):
        if state[i] <= STATE_BOUNDS[i][0]:
            bucket_index = 0
        elif state[i] >= STATE_BOUNDS[i][1]:
            bucket_index = NUM_BUCKETS[i] - 1
        else:
            # Mapping the state bounds to the bucket array
            bound_width = STATE_BOUNDS[i][1] - STATE_BOUNDS[i][0]
            offset = (NUM_BUCKETS[i]-1)*STATE_BOUNDS[i][0]/bound_width
            scaling = (NUM_BUCKETS[i]-1)/bound_width
            bucket_index = int(round(scaling*state[i] - offset))
        bucket_indice.append(bucket_index)
    return tuple(bucket_indice)

simulate()

obs_current = env.reset()
state = state_to_bucket(obs_current)
frames = []
for i in range(1000):
    frames.append(env.render(mode = 'rgb_array'))
    action = select_action(state, MIN_EXPLORE_RATE)
    obs_next, reward, done, info = env.step(action) 
    if done:
        print(i)
        break
    state = state_to_bucket(obs_next)
env.render(close=True)
display_frames_as_gif(frames)

参考文献

19
19
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
19
19