LoginSignup
2
2

More than 5 years have passed since last update.

OpenAI GymのCartPole-v0をDQNで解く

Last updated at Posted at 2017-08-15

OpenAI GymのCartPole-v0をKeras-RLのサンプルDQN[1]で解こうとしてみました.

  • DQN版[2]とDuel-DQN版[3]があり,DQNAgentのコンストラクタから設定可能

cartpole.gif

コード

DQN版

import numpy as np
import gym
from gym import wrappers
from keras.layers import Flatten, Dense, Input
from keras.models import Model
from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory

def build_model(input_dim, output_dim, hidden_units=16, hidden_layers=3):
    """Build a feed-forward Q-network for Keras-RL's DQNAgent.

    Keras-RL prepends a window-length axis to each observation, hence the
    (1, input_dim) input shape immediately followed by Flatten.

    Args:
        input_dim: Size of the observation vector (4 for CartPole-v0).
        output_dim: Number of discrete actions; the network emits one
            Q-value per action.
        hidden_units: Width of each hidden layer (default 16, matching
            the original fixed architecture).
        hidden_layers: Number of ReLU hidden layers (default 3, matching
            the original fixed architecture).

    Returns:
        An uncompiled Keras ``Model`` mapping observations to Q-values.
    """
    x_input = Input(shape=(1, input_dim))
    x = Flatten()(x_input)
    for _ in range(hidden_layers):
        x = Dense(hidden_units, activation="relu")(x)
    # Linear head: raw (unbounded) Q-value estimates, one per action.
    x = Dense(output_dim, activation="linear")(x)
    return Model(inputs=x_input, outputs=x)

def run():
    """Train a DQN agent on CartPole-v0, recording episodes under /tmp.

    Side effects: writes Gym monitor output to /tmp/cartpole-v0-dqn
    (overwriting any previous run) and trains for 50,000 steps.
    """
    env = wrappers.Monitor(
        gym.make("CartPole-v0"), '/tmp/cartpole-v0-dqn', force=True)
    n_actions = env.action_space.n
    obs_dim = env.observation_space.shape[0]
    agent = DQNAgent(
        model=build_model(obs_dim, n_actions),
        nb_actions=n_actions,
        memory=SequentialMemory(limit=50000, window_length=1),
        policy=BoltzmannQPolicy(),
    )
    agent.compile("adam", metrics=["mae"])
    agent.fit(env, nb_steps=50000, visualize=False, verbose=2)

if __name__ == "__main__":
    run()

Duel-DQN版

import numpy as np
import gym
from gym import wrappers
from keras.layers import Flatten, Dense, Input
from keras.models import Model
from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory

def build_model(input_dim, output_dim, hidden_units=16, hidden_layers=3):
    """Build a feed-forward Q-network for Keras-RL's DQNAgent.

    Keras-RL prepends a window-length axis to each observation, hence the
    (1, input_dim) input shape immediately followed by Flatten.

    Args:
        input_dim: Size of the observation vector (4 for CartPole-v0).
        output_dim: Number of discrete actions; the network emits one
            Q-value per action.
        hidden_units: Width of each hidden layer (default 16, matching
            the original fixed architecture).
        hidden_layers: Number of ReLU hidden layers (default 3, matching
            the original fixed architecture).

    Returns:
        An uncompiled Keras ``Model`` mapping observations to Q-values.
    """
    x_input = Input(shape=(1, input_dim))
    x = Flatten()(x_input)
    for _ in range(hidden_layers):
        x = Dense(hidden_units, activation="relu")(x)
    # Linear head: raw (unbounded) Q-value estimates, one per action.
    x = Dense(output_dim, activation="linear")(x)
    return Model(inputs=x_input, outputs=x)

def run():
    """Train a dueling-network DQN agent on CartPole-v0.

    Side effects: writes Gym monitor output to /tmp/cartpole-v0-duel-dqn
    (overwriting any previous run) and trains for 50,000 steps.
    """
    env = wrappers.Monitor(
        gym.make("CartPole-v0"), '/tmp/cartpole-v0-duel-dqn', force=True)
    n_actions = env.action_space.n
    obs_dim = env.observation_space.shape[0]
    agent = DQNAgent(
        model=build_model(obs_dim, n_actions),
        nb_actions=n_actions,
        memory=SequentialMemory(limit=50000, window_length=1),
        policy=BoltzmannQPolicy(),
        # Dueling architecture: split into value/advantage streams,
        # recombined with the mean-advantage ("avg") aggregation.
        enable_dueling_network=True,
        dueling_type="avg",
    )
    agent.compile("adam", metrics=["mae"])
    agent.fit(env, nb_steps=50000, visualize=False, verbose=2)

if __name__ == "__main__":
    run()

スコア

  • 50,000ステップほど試してみましたが,両方とも解けずじまいでした.
  • コンストラクタ時のパラメータを設定すれば解けるかもしれないが, Keras==2.0.6 だと,まともに動かなくなっている模様.
DQN: 32.98 ± 2.91
Duel-DQN: 42.46 ± 3.83

References

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2