Help us understand the problem. What is going on with this article?

Kerasで最短で強化学習(reinforcement learning)する with OpenAI Gym

More than 3 years have passed since last update.

はじめに

強化学習よくわからないけど,理論そっちのけでとりあえずパパッと動かして見たいせっかちな人向けです.つまり僕のような人間です.

OpenAI Gymで,強化学習の環境を提供してくれているので,それを用います.
OpenAI Gymはあくまでも環境だけで,実際に学習させるには他のものが必要です.
調べて見ると,Kerasで強化学習をやるkeras-rlを書いた人がいて,これを使うと簡単に試せそうだったので使います.先人に感謝.

環境の準備

今回の環境

  • Python 3.6.0 :: Anaconda 4.3.1 (x86_64)
  • Mac OS Sierra 10.12.5
  • keras 2.0.5 (backend tensorflow)
  • tensorflow 1.2.0

最初はディスプレイのないサーバーでやっていましたが,めんどくさかったので,ローカル環境でやりました.
ちなみにディスプレイのないサーバーでもXvfbで頑張ればいけそうです.仮想メモリー上で,ディスプレイを再現してくれるやつらしいです.

インストール

pip install gym
pip install keras-rl

インストールは両方ともpipでいけます.
kerasは入っているものとします.

CartPole

CartPoleとは

CartPoleとは,カートの上にポールが乗っていて,それを倒さないようにカートを動かしてバランスをとるゲーム(?)です.

これです.

Screen Shot 2017-07-23 at 1.44.51.png

カートは左右にしか動けません.なので,カートが取れる行動としては,右と左の2値です.
現在の環境に応じて,右か左か選択して,いい感じにバランスをとります.これは以下のように確認できます.

import gym
env = gym.make('CartPole-v0')
env.action_space
# Discrete(2)

env.action_space.sample()
# 0

また,カートが得ることができる環境についての情報は,

env.observation_space
# Box(4,)

env.observation_space.sample()
# array([  4.68609638e-01, 1.46450285e+38, 8.60908446e-02, 3.05459097e+37])

この4つの値です.それぞれ順番に,カートの場所,カートの速度,ポールの角度,ポールが回転する速さだそうです.(カートとポール早すぎじゃね?)
sample()メソッドは適当に行動や環境をサンプリングするためのメソッドです.

DQN example

keras-rlにこれをDQNでやるexampleがあるので,それをそのまま使います.
この記事を書くにあたって図が欲しかったので2行だけ追加してます.(追加と書いてあるところ)

DQNについては
[Python]強化学習(DQN)を実装しながらKerasに慣れる
ゼロからDeepまで学ぶ強化学習
あたりが参考になります.

行動価値関数をディープニューラルネットにしたものだそうです.ここでいうと,ポールが右に倒れているときは,カートを右に動かす行動の方が価値が高い,というようなことを表す関数の部分です.

import numpy as np
import gym
from gym import wrappers # 追加

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory

ENV_NAME = 'CartPole-v0'

# Get the environment and extract the number of actions.
env = gym.make(ENV_NAME)
env = wrappers.Monitor(env, './CartPole') # 追加
np.random.seed(123)
env.seed(123)
nb_actions = env.action_space.n

# Next, we build a very simple model.
model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
print(model.summary())

# Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
# even the metrics!
memory = SequentialMemory(limit=50000, window_length=1)
policy = BoltzmannQPolicy()
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=10,
               target_model_update=1e-2, policy=policy)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])

# Okay, now it's time to learn something! We visualize the training here for show, but this
# slows down training quite a lot. You can always safely abort the training prematurely using
# Ctrl + C.
dqn.fit(env, nb_steps=50000, visualize=True, verbose=2)

# After training is done, we save the final weights.
dqn.save_weights('dqn_{}_weights.h5f'.format(ENV_NAME), overwrite=True)

# Finally, evaluate our algorithm for 5 episodes.
dqn.test(env, nb_episodes=5, visualize=True)

このexampleでは,BoltzmannQPolicy()という方策を使っていますが,これはこれからの強化学習によると,行動を選択するときに行動価値関数の値のソフトマックス関数で決めるものだそうです.行動価値があるほどよく選ぶということですね.

結果

1エピソード目

openaigym.video.0.43046.video000001.gif

エピソードとは,強化学習の学習単位で,ゲームの勝敗が明らかになるまでが1エピソードです.そしてこれは1エピソード目の結果なので,まだ何も学習しておらず,完全にランダムです.

ポールが右に倒れそうなのに,カートは左に動いていますね.

なんかカクカクしているのは,CartPoleが15度以上傾くとゲーム終了となるので,そこから先は描画されていないためです.あと,左右に大きく動きすぎても終了します.

216エピソード目

openaigym.video.0.43046.video000216.gif

おぉ...なかなか持ちこたえている...

終わりに

kazetof
データ分析とかやってます.
https://kazetof.github.io/blog/
emcjpn
バイタルセンシング、IoT、データ分析、A.I.などを用いて、ヘルスケアにイノベーションを起こすことを目的とするスタートアップ
https://www.emcjpn.com/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away