LoginSignup
0
1

More than 5 years have passed since last update.

chainer3で足し算ゲームを、強化学習させるための初心者備忘録

Last updated at Posted at 2018-05-11

Chainerで機械学習と戯れる: 足し算ゲームをChainerを使って強化学習できるか? を読ませて頂き、
自分でも chainer3 で実装してみました。その時の自分用備忘録を残すために記述してみました。

お題: 足し算ゲーム by Chainer3

アクションを 1〜5 に設定することにしました。
アクションが 1〜4 の場合、7 から次の 7 まで 最善でも3手掛かりますが、
アクションが 1〜5 の場合、2 -> 7 -> 2 -> 7 -> 2 -> ・・・ の2手が最善手であるためです。
学習が進んだ後、最善手である2手の戦略を選ぶようになれば、OKだと思いました。

  • 状態S: 0~9 の整数
  • アクションA: 1~5の整数
  • 次状態S': (S + A) % 10
  • 報酬R:
    • +5 : 次状態S'が7の場合
    • -1 : 次状態S'が7以外だった場合

実験

以下のような方針で実験しました。

・ ニューラルネットのモデル以外の class は定義せず、一本調子なコードにしました。(自分のコーディング力の問題)
・ 学習 -> テスト -> 学習 -> テスト -> ・・・ とはせず、学習 -> テスト のように、ループは一周にしました。
・ ニューラルネットのモデルは何が良いかとか分かってないので、エイヤで作りました。
・ optimizers.SGD() を選択した根拠も特にありません。
・ Ubuntu 16.04 上で実行しました。
・ chainer のバージョンは、3.5.0 です。

test.py
import numpy as np

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import Variable, optimizers

ALPHA   = 0.1
GAMMA   = 0.99
EPSILON = 0.3
actions = [1, 2, 3, 4, 5]

class MyModel(chainer.Chain):

    def __init__(self):
        super(MyModel, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(10, 10)
            self.l2 = L.Linear(10, 10)
            self.l3 = L.Linear(10, len(actions))

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

def select_action(model, state, is_training):
    if is_training and np.random.random() < EPSILON:
        return actions[np.random.randint(0, len(actions))]
    else:
        return actions[np.argmax(model(state).data)]

def act(state, action):
    state_num = np.argmax(state.data[0])
    new_state_num = int(state_num + action) % 10
    tmp_list = np.zeros(10)
    tmp_list[new_state_num] = 1
    new_state = Variable(np.array(tmp_list, dtype=np.float32).reshape(1, -1))
    return new_state_num, new_state

if __name__ == '__main__':
    model = MyModel()
    optimizer = optimizers.SGD()
    optimizer.setup(model)

    # 初期状態Sは 0 にした
    tmp_list = np.zeros(10)
    tmp_list[0] = 1
    state = Variable(np.array(tmp_list, dtype=np.float32).reshape(1, -1))

    # これから学習
    is_training = True
    for i in range(100000):
        action = select_action(model, state, is_training)
        new_state_num, new_state = act(state, action)

        reward = -1
        if new_state_num == 7:
            reward = 5

        target     = np.copy(model(state).data)
        target_ofs = ALPHA * (reward + GAMMA * np.argmax(model(new_state).data) - target[0][action-1])
        target[0][action-1] += target_ofs

        t = Variable(target)

        if i % 1000 == 0:
            print("t      ", t.data)
            print("i ", i, " state ", np.argmax(state.data[0]), " action ", action, " new_state_num ", new_state_num, " target_ofs ", target_ofs)

        model.zerograds()
        loss = F.mean_squared_error(model(state), t)
        loss.backward()
        optimizer.update()

        state = new_state

    # 学習後に実戦 (状態はそのまま引き継ぐ)
    is_training = False
    for i in range(100):
        action = select_action(model, state, is_training)
        new_state_num, new_state = act(state, action)

        print("*state ", np.argmax(state.data[0]), " action ", action, " new_state_num ", new_state_num)

        state = new_state

結果

結果の抜粋ですが、2 -> 7 -> 2 -> 7 -> 2 -> ・・・ というように2手の戦略を採用できているので、
概ね学習が出来たのかなと思いました。

ただ、手探りで作成してきたので、自分のコードが正しいのかよく分かっていません…
コードはそれほど長くないのですが、自分としては相当苦労しました…

*state  7  action  5  new_state_num  2
*state  2  action  5  new_state_num  7
*state  7  action  5  new_state_num  2
*state  2  action  5  new_state_num  7
*state  7  action  5  new_state_num  2
*state  2  action  5  new_state_num  7
*state  7  action  5  new_state_num  2
*state  2  action  5  new_state_num  7
*state  7  action  5  new_state_num  2
*state  2  action  5  new_state_num  7
*state  7  action  5  new_state_num  2
*state  2  action  5  new_state_num  7
*state  7  action  5  new_state_num  2
*state  2  action  5  new_state_num  7
*state  7  action  5  new_state_num  2
*state  2  action  5  new_state_num  7
0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1