LoginSignup
3
1

More than 5 years have passed since last update.

OpenAIGymのCartPoleをCartほむほむ化する

Last updated at Posted at 2019-04-12

痛RL環境を作りたい

強化学習(Reinforcement Learning, RL)を使って研究をしているとき,
研究や仕事のモチベを上げるために痛エディタがあるなら
痛RL環境があってもいいじゃない.
と思いたったので作ってみた.
natm.gif
かわいい.

OpenAIGymのCartPoleのPole部分を画像に差し替える

OpenAIGymのパッケージのフォルダを開いて
gym/envs に MyEnv フォルダを作る.
中身は以下のとおり.

gym/envs/MyEnv
 ├ __init__.py
 ├ cartpole_img.py
 └ assets/
  └ pole_img.png

gym/envs/MyEnv/cartpole_img.py の中身

CartPoleEnv_Img クラスを作り,
既存の gym/envs/classic_control/cartpole.py の中にある CartPoleEnv クラスの
render メソッドをコピペしてきてちょっと変える.
ちなみに画像読み込み周りの処理は gym/envs/classic_control/pendulum.py からのコピペ.

gym/envs/MyEnv/cartpole_img.py
from os import path
from gym.envs.classic_control.cartpole import CartPoleEnv

class CartPoleEnv_Img(CartPoleEnv):     # 既存のCartPoleEnvクラスを継承
    def render(self, mode='human'):
        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold*2
        scale = screen_width/world_width
        carty = 100 # TOP OF CART
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height)
            l,r,t,b = -cartwidth/2, cartwidth/2, cartheight/2, -cartheight/2
            axleoffset =cartheight/4.0
            cart = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            self.carttrans = rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)
            l,r,t,b = -polewidth/2,polewidth/2,polelen-polewidth/2,-polewidth/2

            # Pole部分を作る元々の処理をコメントアウト
            # pole = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            # pole.set_color(.8,.6,.4)

            # 改変部分
            fname = path.join(path.dirname(__file__), "assets/pole_pic.png")    # 画像指定
            scale = 200.
            pole = rendering.Image(fname, scale, scale*2.)  # 画像の大きさ指定
            pole.set_color(1, 1, 1)     # 画像の色指定

            self.poletrans = rendering.Transform(translation=(0, axleoffset))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)
            self.axle = rendering.make_circle(polewidth/2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(.5,.5,.8)
            self.viewer.add_geom(self.axle)
            self.track = rendering.Line((0,carty), (screen_width,carty))
            self.track.set_color(0,0,0)
            self.viewer.add_geom(self.track)

            self._pole_geom = pole

        if self.state is None: return None

        # Edit the pole polygon vertex
        pole = self._pole_geom
        l,r,t,b = -polewidth/2,polewidth/2,polelen-polewidth/2,-polewidth/2
        pole.v = [(l,b), (l,t), (r,t), (r,b)]

        x = self.state
        cartx = x[0]*scale+screen_width/2.0 # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(-x[2])

        return self.viewer.render(return_rgb_array = mode=='rgb_array')

gym/envs/MyEnv/__init__.py の中身

作ったクラスをインポート.

gym/envs/MyEnv/__init__.py
from gym.envs.MyEnv.cartpole_img import CartPoleEnv_Img

gym/envs/MyEnv/assets/pole_pic.png

pole_pic.png には好きな画像を用意すればいいが,
棒の支点にあたる部分が中心に来るような画像にしておく.
筆者はAzPainterを使って作った.
キャプチャ.PNG

自作環境を使えるようにする

一つ上の階層 gym/envs/ にある __init__.py に,
自作環境の情報を追加しておく.

gym/envs/__init__.py
~~~省略~~~

# Classic
# ----------------------------------------

register(
    id='CartPole-v0',
    entry_point='gym.envs.classic_control:CartPoleEnv',
    max_episode_steps=200,
    reward_threshold=195.0,
)

register(
    id='CartPole-v1',
    entry_point='gym.envs.classic_control:CartPoleEnv',
    max_episode_steps=500,
    reward_threshold=475.0,
)

# ここに追加
register(
    id='CartPoleImg-v0',
    entry_point='gym.envs.MyEnv:CartPoleEnv_Img',
    max_episode_steps=200,
    reward_threshold=195.0,
)

register(
    id='MountainCar-v0',
    entry_point='gym.envs.classic_control:MountainCarEnv',
    max_episode_steps=200,
    reward_threshold=-110.0,
)

register(
    id='MountainCarContinuous-v0',
    entry_point='gym.envs.classic_control:Continuous_MountainCarEnv',
    max_episode_steps=999,
    reward_threshold=90.0,
)

register(
    id='Pendulum-v0',
    entry_point='gym.envs.classic_control:PendulumEnv',
    max_episode_steps=200,
)

register(
    id='Acrobot-v1',
    entry_point='gym.envs.classic_control:AcrobotEnv',
    reward_threshold=-100.0,
    max_episode_steps=500,
)

~~~省略~~~

これで,env = gym.make("CartPoleImg-v0") できるようになる.

一応

版権絵を使うときは個人利用に留めましょう.
怒られても筆者は責任を負わない.

参考ページ

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1