3
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

強化学習をしてみた

Last updated at Posted at 2018-01-14

#はじめに
前回は強化学習っぽいことをしたので、今回は強化学習をしてみたいと思います。そもそもなぜ前回したことが強化学習っぽいかというと、強化学習は各行動ではなく連続した行動に対して報酬が与えられ適切な学習をしていきます。なので、前回のような行動に対して即報酬が与えらるようなものを強化学習と呼んでいいのかと思い強化学習っぽいと表現しました。

#問題設定
今回は以下のような環境でQ学習を使って適切な学習をするプログラムを書いてみます。
state.PNG

  • 状態2がスタート地点
  • 状態0と5がゴール
  • 状態5は天国、正の報酬が得られる。
  • 状態0が地獄、負の報酬が得られる。
  • 選択できる行動は右か左に動くことだけ。

すなわち、学習するプログラムが天国である状態5に行くための行動を適切に学習することが今回の目的です。

#Q学習のアルゴリズム

  1. Q値を乱数を用いて初期化
  2. 初期状態(状態2)に移る。
  3. Q値に基づいて次の行動を選択
  4. Q値の更新
  5. 次の状態に遷移
  6. 目標状態のとき、報酬が与えられ 2.に戻る。
  7. 目標状態ではないが、決められた回数の行動選択に達したとき 2.に戻る。

#プログラム
アルゴリズムに基づいてプログラムを書いていきます。
実装はpythonです。

##Q値の初期化

# Initial Q values
self._q = [[random.random() for _ in range(2)] \
                            for _ in range(6)]

ランダムモジュールのrandom関数を用いてQ値を初期化する。
今回の問題で存在する状態は6つです。それぞれの状態で選択できる行動は右か左に移動することだけなので2つになります。

##初期状態に移る

# Initial state
self._state = [2, 2]

状態を初期化する。ここで初期状態を状態2にセットしています。
1番目の要素は現在の状態、2番目の要素に次の状態を入れるようにします。

##Q値に基づいて次の行動を選択

def _select_action(self, state):
    if random.random() < EPSILON:
        return random.randrange(2)
    Q = self._q
    return 1 if Q[state][1] > Q[state][0] else 0

def _transition(self, action):
    state = self._state[0]
    if action:
        return state + 1
    return state - 1

# Decide action
state = self._state[0]
action = self._select_action(state)
# Get next state
self._state[1] = self._transition(action)

現在の状態から行動を選択します。行動選択にはεグリーディ法を用いる。
_select_actionで行動を選択、戻り値の1は右、0は左へ行くことを表しています。
行動を選択後、次の状態に移ります。
状態が2のとき右に行くと状態は3になるので+1、左に行くと状態は1になるので-1しています。

##Q値の更新

def _update(self, action):
    state, next_state = self._state
    Q = self._q
    if next_state in [0, 5]:
        reward = self._reward[next_state]
        return Q[state][action] + ALPHA * (reward - Q[state][action])
    
    next_action = 1 if Q[next_state][1] > Q[next_state][0] else 0
    return Q[state][action] + ALPHA * (GAMMA * Q[next_state][next_action] - Q[state][action])

# update q values
self._q[state][action] = self._update(action)

状態と選択した行動よりQ値を更新していきます。Q値の更新式は次のようになります。
$Q(s,a) \gets Q(s,a) + \alpha(r + \gamma maxQ(s_{next}, a_{next}) - Q(s,a))$

  • s : 状態
  • a : 状態sで選択した行動
  • r : 報酬(得られなければ0)
  • $\alpha$ : 学習係数
  • $\gamma$ : 割引率
  • $maxQ(s_{next}, a_{next})$ : 次の状態で選択できる行動に対するQ値のうちの最大の値

報酬が与えられるときと、与えられないときとでは上記のQ値の更新式は異なります。
報酬を与えられるときのQ値の更新式は次のようになります。
$Q(s,a) \gets Q(s,a) + \alpha(r - Q(s,a))$
報酬が与えられるとき次の状態に遷移しないので$maxQ(s_{next}, a_{next})$は計算しません。
今回の問題では状態が状態0か5ならば報酬が与えられます。今回は状態0ならー10、状態5に行くことができれば+10の報酬を設定しています。

報酬が与えられないとき報酬は0なので更新式は次のようになります。
$Q(s,a) \gets Q(s,a) + \alpha(\gamma maxQ(s_{next}, a_{next}) - Q(s,a))$

#学習の様子
学習を始めた直後はマイナスの報酬をもらってしまうことがありますが、学習が進むにつれだんだんプラスの報酬をもらうように行動を選択していくようになります。
turn1.PNG

turn2.PNG

#結果
下の2つは初期状態と学習終了後のQ値です。初期状態での状態2を見ると左を選択するときのQ値が高いことがわかります。しかし学習するにつれてプログラムが適切なQ値を学習していき学習終了後には状態2のQ値は右を選択するときのほうが高くなっています。ほかの状態でのQ値を見ても正しく学習していることがわかります。状態1~4では右を選択するときのQ値が高くなっています。

初期状態
  状態0      状態1      状態2     状態3      状態4      状態5
0.836 0.992     0.662 0.390     0.618 0.108     0.969 0.519     0.051 0.881     0.217 0.791
学習終了後
  状態0      状態1      状態2     状態3      状態4      状態5
0.836 0.992     -10.000 5.909   3.319 7.566     6.509 8.745     7.676 10.000    0.217 0.791

#全体のコード
Windowsユーザの方はbash on Windowsで実行してください。

import argparse
import random
import sys
import time

random.seed()
out = sys.stdout

ALPHA = 0.1
EPSILON = 0.3
GAMMA = 0.9


class Agent:
    def __init__(self, num_train, limit):
        self._num_train = num_train
        self._limit = limit
        # Initial Q values
        self._q = [[random.random() for _ in range(2)] for _ in range(6)]
        # state
        self._state = [2, 2]
        # reward
        self._reward = {0: -10, 5: 10}
        # Training process
        self._console_format = ["\r\033[5A\033[KTraining {}\n",
                                "Turn {}\n"
                                "{:^3d}{:^2d}{:^2d}{:^2d}{:^2d}{:^2d}\033[10D\n",
                                "-------------\n",
                                "|{}|{}|{}|{}|{}|{}|\n",
                                "-------------"]

    def _update(self, action):
        state, next_state = self._state
        Q = self._q
        if next_state in [0, 5]:
            reward = self._reward[next_state]
            return Q[state][action] + ALPHA * (reward - Q[state][action])
        next_action = 1 if Q[next_state][1] > Q[next_state][0] else 0
        return Q[state][action] + ALPHA * (GAMMA * Q[next_state][next_action] - Q[state][action])

    def _select_action(self, state):
        if random.random() < EPSILON:
            return random.randrange(2)
        Q = self._q
        return 1 if Q[state][1] > Q[state][0] else 0

    def _transition(self, action):
        state = self._state[0]
        if action:
            return state + 1
        return state - 1

    def _get_symbol(self):
        symbol = self._state[0]
        items = []
        for i in range(6):
            if symbol == i:
                if i not in [0, 5]:
                    items.append("o")
                elif i == 0:
                    items.append("x")
                else:
                    items.append("O")
            else:
                items.append(" ")
        return items

    def _console(self, i, t):
        symbols = self._get_symbol()
        console = "".join(self._console_format)
        out.write(console.format(i, t, *list(range(6)), *symbols))
        time.sleep(0.5)

    def _print_q(self):
        for i in range(6):
            print("{:.3f} {:.3f}\t".format(*self._q[i]), end='')
        print("\n"+"#"*30)

    def train(self):
        #self._print_q()
        for i in range(self._num_train):
            self._state = [2, 2]
            print("\n"*5)
            for t in range(self._limit):
                # Decide action
                state = self._state[0]
                action = self._select_action(state)
                # Get next state
                self._state[1] = self._transition(action)
                # Training process
                self._console(i+1, t)
                # update q values
                self._q[state][action] = self._update(action)
                # update state
                self._state[0] = self._state[1]

                if self._state[0] in [0, 5]:
                    self._console(i+1, t+1)
                    break
            #self._print_q()
            out.write("\033[6A\033[K\033[10D")
        out.write("\033[10D\033[J")
        out.flush()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_train', '-n', type=int, default=10,
                        help="Number of training.")
    parser.add_argument('--limit', '-l', type=int, default=10,
                        help="Limit.")
    args = parser.parse_args()
    agent = Agent(args.num_train, args.limit)
    agent.train()
3
9
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?