Breakout with ChainerRL

I tried Breakout-v0 with a CNN using ChainerRL.

The code runs, but the agent's learning does not progress and it earns no reward.
I am a beginner, so I would like to know why learning is not progressing.

I used the following code with Python 2.7.
Any advice would be appreciated.

import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl
import gym
import numpy as np

from chainer import cuda

import datetime
from skimage.color import rgb2gray
from skimage.transform import resize

env = gym.make('Breakout-v0')
obs = env.reset()

print("observation space : {}".format(env.observation_space))
print("action space : {}".format(env.action_space))

action = env.action_space.sample()
obs, r, done, info = env.step(action)
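# A sampled step: Breakout-v0 returns a raw RGB frame of shape (210, 160, 3)
# (HWC order, uint8 values in [0, 255]), a scalar reward, a done flag, and an
# info dict.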
class QFunction(chainer.Chain):
    def __init__(self, obs_size, n_action):
        super(QFunction, self).__init__(
            l1=L.Convolution2D(obs_size, 4, ksize=2, pad=1),  # 210x160
            bn1=L.BatchNormalization(4),
            l2=L.Convolution2D(4, 4, ksize=2, pad=1),  # 105x80
            bn2=L.BatchNormalization(4),
            #l3=L.Convolution2D(64, 64, ksize=2, pad=1),  # 100x100
            #bn3=L.BatchNormalization(64),
            #l4=L.Convolution2D(64, 3, ksize=2, pad=1),  # 50x50
            #bn4=L.BatchNormalization(3),
            l5=L.Linear(972, 512),
            out=L.Linear(512, n_action, initialW=np.zeros((n_action, 512), dtype=np.float32))
        )

    def __call__(self, x, test=False):
        h1 = F.relu(self.bn1(self.l1(x)))
        h2 = F.max_pooling_2d(F.relu(self.bn2(self.l2(h1))), 2)
        #h3 = F.relu(self.bn3(self.l3(h2)))
        #h4 = F.max_pooling_2d(F.relu(self.bn4(self.l4(h3))), 2)
        #print h4.shape

        return chainerrl.action_value.DiscreteActionValue(self.out(self.l5(h2)))

n_action = env.action_space.n
obs_size = env.observation_space.shape[0] #(210,160,3)
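# Note: observation_space.shape is (210, 160, 3), so obs_size is 210 (the image
# height). It is passed to L.Convolution2D as in_channels, so the raw HWC frame
# is effectively treated as 210 channels of 160x3 "images" rather than as a
# 3-channel (or grayscale) image in CHW order.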
q_func = QFunction(obs_size, n_action)

optimizer = chainer.optimizers.Adam(eps=1e-2)
optimizer.setup(q_func)

gamma = 0.99

explorer = chainerrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.2, random_action_func=env.action_space.sample)
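# ConstantEpsilonGreedy keeps epsilon fixed at 0.2 for the whole run; ChainerRL
# also provides chainerrl.explorers.LinearDecayEpsilonGreedy if an annealed
# schedule is wanted.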

replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)

phi = lambda x: x.astype(np.float32, copy=False)
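# This phi only casts the raw frame to float32; no grayscaling, resizing, or
# rescaling of pixel values to [0, 1] is done here (see the note after the
# training loop).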
agent = chainerrl.agents.DoubleDQN(
    q_func, optimizer, replay_buffer, gamma, explorer,
    minibatch_size=4, replay_start_size=100, update_interval=10,
    target_update_interval=10, phi=phi)
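# Note: minibatch_size=4 and replay_start_size=100 are much smaller than the
# values commonly used for Atari DQN (minibatches of 32 and tens of thousands
# of warm-up transitions), and target_update_interval=10 is a very short
# period for a target network.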

last_time = datetime.datetime.now()
n_episodes = 10000
for i in range(1, n_episodes + 1):
    obs = env.reset()

    reward = 0
    done = False
    R = 0

    while not done:
        env.render()
        action = agent.act_and_train(obs, reward)
        obs, reward, done, _ = env.step(action)

        if reward != 0:
            R += reward

    elapsed_time = datetime.datetime.now() - last_time
    print('episode:', i, 'reward:', R)
    last_time = datetime.datetime.now()

    if i % 100 == 0:
        filename = 'agent_Breakout' + str(i)
        agent.save(filename)

    agent.stop_episode_and_train(obs, reward, done)
print('Finished.')
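Supplementary note: rgb2gray and resize are imported above but never used. Below is a minimal sketch of the standard DQN-style frame preprocessing they suggest (assumptions: 84x84 grayscale frames in CHW order; the preprocess name is mine, and this is untested with the code above):

def preprocess(obs):
    # rgb2gray converts the uint8 RGB frame to a float image in [0, 1].
    gray = rgb2gray(obs)            # (210, 160)
    # Downscale to the 84x84 resolution conventional for Atari DQN.
    small = resize(gray, (84, 84))  # (84, 84), float in [0, 1]
    # Add a channel axis so the network sees a CHW array of shape (1, 84, 84).
    return small[np.newaxis, :, :].astype(np.float32)

phi = preprocess

With a phi like this, the Q-function would take in_channels=1 instead of obs_size, and the input size of l5 would have to be recomputed for 84x84 inputs (or written as L.Linear(None, 512) so that Chainer infers it on first use).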
