LoginSignup
15
13

More than 5 years have passed since last update.

Open AI GymのCartPoleコードをいじりながら仕組みを学ぶ(2)

Last updated at Posted at 2016-09-07

過去6回で、Ubuntu14.04、CUDA、chainer、dqn、LIS、Tensorflow、Open AI Gymを順次インストールし、最後にOpen AI Gymのサンプルコードをちょっといじった。
http://qiita.com/masataka46/items/cc37d36137a4a162c04a

今回も前回と同様、Open AI GymのHPに載ってるCartPoleゲームのサンプルコードをいじりながら、仕組みを学んでいく。公式HPはこちら。
https://gym.openai.com/docs

CartPole-v0のサンプルコード(test06.py)はこちら。

import gym
env = gym.make('CartPole-v0')
for i_episode in range(20):
    observation = env.reset()
    for t in range(100):
        env.render()
        print(observation)
        action = env.action_space.sample()
        observation, reward, done, info = env.step(action)
        if done:
            print("Episode finished after {} timesteps".format(t+1))
            break

step関数

前回8行目まで見たので、今回は9行目。env.step()にactionを放り込むと、戻り値としていろいろ返ってきている。actionは7行目でランダムな値を生成しているので、ランダムに選択されたacttion(例えば右)を放り込んでいるわけだ。

このstep関数はcore.pyのEnvクラスに記載されている。
https://github.com/openai/gym/blob/master/gym/core.py

def step(self, action):
    self.monitor._before_step(action)
    observation, reward, done, info = self._step(action)
    done = self.monitor._after_step(observation, reward, done, info)
    return observation, reward, done, info

この関数内に書かれている解説をまとめると以下。
(1)環境を1step進める。もしepisodeの終わりに達すれば、reset()を自動的に呼び出す。
(2)引数としてactionオブジェクトをとり、戻り値としてobservation、reward、done、infoを含むタプルを返す。
(3)doneはboolean型でepisodeが終わったか否かを保持。
(4)infoはdictionary型でデバッグ情報など予備の診断情報を保持。

次にstep関数の中身を見ていく。2行目のmonitorはよくわからんが、画面への出力関係か?

3行目は_stepにactionを放り込んで、4つの情報が返ってきている。4行目も画面への出力関係か?5行目で4つの値を返す。

重要なのは3行目の_step関数。これに関しては

 # Override in ALL subclasses
def _step(self, action): raise NotImplementedError

とある。今回はどこかで実装されているのだろうか?

if doneの部分

次に9行目以降のif文。先ほど見たようにdoneはepisodeが終了すればTRUEを返す。よって、このif文はepisodeが終了した時に「〜timestepsで終了しましたよ」と出力するもの。

rewardも出力する

次にrewardがどんな値なのか、出力して確かめてみる。test06.pyの9行目と10行目の間に以下を追加する。

print("reward is "),
print(reward)

出力はこんな感じ。

[ 0.01509218 -0.01367975 -0.00316408  0.01795877]
reward is  1.0
[ 0.01481859 -0.20875618 -0.00280491  0.30964172]
reward is  1.0
[ 0.01064347 -0.40383806  0.00338793  0.60143874]
reward is  1.0
[ 0.0025667  -0.20876366  0.0154167   0.30982487]
reward is  1.0
[-0.00160857 -0.01386472  0.0216132   0.02204353]
reward is  1.0

報酬はずっと1.0となっている。基準がよくわからんが、poleが倒れなかったら毎時刻1.0もらえるのだろうか?

observationにもとづいてactionを操作する

実装の中心は、我々が受け取ったobservationに対してどのようなactionを返すか、という部分だろう。しかしtest06.pyのコードでは、9行目でobservationを受け取っているにも関わらず、それとは関係なく8行目でランダムなactionを渡している。これは面白くない。

そこで、以下のようにobsevationの値を処理し、それをactionに反映させてみる。

import gym
env = gym.make('CartPole-v0')
total_reward_sum = [0] * 20
for i_episode in range(20):
    observation = env.reset()
    total_reward = 0
    for t in range(100):
        env.render()
        print(observation)
        #action = env.action_space.sample()
        sum_obs = observation[0] + observation[1] + observation[2] + observation[3]
        if sum_obs > 0:
            action = 1
        else:
            action = 0
        observation, reward, done, info = env.step(action)
        total_reward += 1
        if done:
            print("Episode finished after {} timesteps".format(t+1))
            print("total reward is"),
            print(total_reward)
            total_reward_sum[i_episode] = total_reward
            if i_episode == 19:
                print("all total reward is"),
                print(total_reward_sum)
            break
        if t == 99:
            print("Episode finished after {} timesteps".format(t+1))
            print("total reward is"),
            print(total_reward)
            total_reward_sum[i_episode] = total_reward
            if i_episode == 19:
                print("all total reward is"),
                print(total_reward_sum)

トータルな報酬の出力をepisode終了時のみでなく、maxの100timestep終了時にも設定したのは、このアルゴリズムが良すぎて、棒が倒れること無く毎回ほぼmaxまで達するため。実際、episode20回分のトータル報酬は以下のようになった。

all total reward is [100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 97, 100, 100, 100, 100, 100, 100, 74, 100]

何も考えずに作ったのに、素晴らしいね!ちなみにactionを右と左入れ替えた結果が以下。

all total reward is [8, 9, 9, 9, 10, 9, 9, 10, 9, 9, 10, 9, 9, 9, 8, 10, 10, 9, 9, 9]

めっちゃ早く倒れている。というか倒れるのを助長させてるね。

15
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
13