2
2

More than 3 years have passed since last update.

CartPoleを強化学習ライブラリKeras-rlを使って学習させてみた

Last updated at Posted at 2021-08-12

DeepLearningの勉強中なのですが、(全くわからないので)「とりあえず他人の模倣をする」という方針で、多くの皆様方がやっておられるOpenAIのCartPoleを強化学習ライブラリKeras-RLを使って学習させてみました。

CartPoleとは

CartPoleは横に移動できるカートの上にポールが立っており、そのポールをいかに倒さず長く維持できるかを競うゲームになっています。言い換えるとCartPoleでは「4つの入力情報(カートの位置、カートの速度、ポールの角度、ポールの角速度)を与えて、2つの出力(左に押す、右に押す:押す強さは固定)のうちどちらか1つを選ぶ、DeepLearningのモデルを作りなさい」と言うことです。

cartpole.png

動作環境

本投稿では下のライブラリを使っています。Keras-RL2が強化学習用ライブラリで、gymがCartPoleが入った実験環境を提供してくれるライブラリです。

  • Tensorflow 2.5.0
  • Keras-RL2
  • OpenAI gym
pip install tensorflow
pip install keras-rl2
pip install gym

なお、Keras-RL2はKeras-RLが入っていると(私の環境では)動作しませんでした。Keras-RLをuninstallしてからインストールしてください。

プログラム

他のサイトを参考に作りました。深く理解はしていないのですが一応学習は成功しました。

import gym
from tensorflow.keras.models import Sequential,load_model
from tensorflow.keras.layers import Dense,Flatten,BatchNormalization
from tensorflow.keras.optimizers import Adam
from rl.memory import SequentialMemory
from rl.policy import BoltzmannQPolicy
from rl.agents.dqn import DQNAgent

env = gym.make('CartPole-v0') 
env.reset()

model = Sequential([
    Flatten(input_shape=(1,4)),
    Dense(32,activation='relu'),
    BatchNormalization(),
    Dense(32,activation='relu'),
    BatchNormalization(),
    Dense(32,activation='relu'),
    BatchNormalization(),
    Dense(2,activation='linear')
])
#model=load_model('cartpole')     #保存したモデルをロードする時に使う

memory = SequentialMemory(limit=50000, window_length=1)
policy = BoltzmannQPolicy()
dqn = DQNAgent(model=model,nb_actions=2,memory=memory,nb_steps_warmup=10,target_model_update=1e-2,policy=policy)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])
dqn.fit(env,nb_steps=50000,visualize=True,verbose=2)
dqn.model.save('cartpole',overwrite=True)
dqn.test(env,nb_episodes=3,visualize=True)
2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2