DeepLearningの勉強中なのですが、(全くわからないので)「とりあえず他人の模倣をする」という方針で、多くの皆様方がやっておられるOpenAIのCartPoleを強化学習ライブラリKeras-RLを使って学習させてみました。
#CartPoleとは
CartPoleは横に移動できるカートの上にポールが立っており、そのポールをいかに倒さず長く維持できるかを競うゲームになっています。言い換えるとCartPoleでは「4つの入力情報(カートの位置、カートの速度、ポールの角度、ポールの角速度)を与えて、2つの出力(左に押す、右に押す:押す強さは固定)のうちどちらか1つを選ぶ、DeepLearningのモデルを作りなさい」と言うことです。
#動作環境
本投稿では下のライブラリを使っています。Keras-RL2が強化学習用ライブラリで、gymがCartPoleが入った実験環境を提供してくれるライブラリです。
- Tensorflow 2.5.0
- Keras-RL2
- OpenAI gym
pip install tensorflow
pip install keras-rl2
pip install gym
なお、Keras-RL2はKeras-RLが入っていると(私の環境では)動作しませんでした。Keras-RLをuninstallしてからインストールしてください。
#プログラム
他のサイトを参考に作りました。深く理解はしていないのですが一応学習は成功しました。
import gym
from tensorflow.keras.models import Sequential,load_model
from tensorflow.keras.layers import Dense,Flatten,BatchNormalization
from tensorflow.keras.optimizers import Adam
from rl.memory import SequentialMemory
from rl.policy import BoltzmannQPolicy
from rl.agents.dqn import DQNAgent
env = gym.make('CartPole-v0')
env.reset()
model = Sequential([
Flatten(input_shape=(1,4)),
Dense(32,activation='relu'),
BatchNormalization(),
Dense(32,activation='relu'),
BatchNormalization(),
Dense(32,activation='relu'),
BatchNormalization(),
Dense(2,activation='linear')
])
#model=load_model('cartpole') #保存したモデルをロードする時に使う
memory = SequentialMemory(limit=50000, window_length=1)
policy = BoltzmannQPolicy()
dqn = DQNAgent(model=model,nb_actions=2,memory=memory,nb_steps_warmup=10,target_model_update=1e-2,policy=policy)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])
dqn.fit(env,nb_steps=50000,visualize=True,verbose=2)
dqn.model.save('cartpole',overwrite=True)
dqn.test(env,nb_episodes=3,visualize=True)