LoginSignup
10
9

More than 3 years have passed since last update.

【OpenAI Gym】CartPole問題の解答例

Posted at

要約

OpenAI Gymの「CartPole」のQ学習での解答例を共有します。強化学習について学ぶ際の理解の一助になれたらと思っています。ある程度のpythonの知識を有している方を対象としています。

コードだけを見たい場合は、こちらをご参照ください。

経緯・動機

強化学習に辿り着くまで

業務においてAIについて学ぶ必要性があるなかで、教師あり機械学習や教師なし機械学習について、インターネット上の情報やソースコードを元に手元でコーディングをして機械学習を試しても、なかなかきっちり「推定できている」、「分類できている」など実感が湧きませんでした。

私の学習不足も要因として多々あるかと思いますが、なかなか機械学習に関する技術が身につかず色々とどうするかを考えていた時に、AIに関する "とある記事" を見つけました。

それは「負けられないオセロAIが発表され話題になっている」という内容のものでした。そのオセロを試した際に、プレイヤーとして負けようと挑むよりも「どうやって作ったのか」ということに興味が湧きました。

合わせて「自分でもオセロくらいならば作れるのではないか」と考え作ってみることにしました。

この時点ではまずは「オセロを作る」ことに主眼を置いているため、AIをどのように作るかまでは考えていませんでしたが、これが強化学習を学んでいくことのきっかけだったと感じています。

オセロゲームの開発

(開発期間: 2019/8/14~8/16)

オセロゲームのUIはまったく凝っておらず、機能も最小限で作ることとし、内部のロジックをきちんと作ることを目的に開発を進めていきました。

開発にはIonic3を利用しました。

開発したオセロゲームの実際の画面

開発したオセロゲームの実際の画面

まずはオセロの仕様を満たすように機能を実装し、次に相手のロジックについて考えるようになりました。
この時に「とりあえずランダムで」と考え実装してプレイをしてみたところ、とてつもなく弱い相手であることを知り、ランダムとは人間が考えて行動するのとは全く違い、プレイヤーとしては弱いことを知りました。

ランダムが弱いことを実感できたことは強化学習について学んでいく上で一つの大きな経験と感じています。

次にもう少し複雑な環境を作りたくなり、「将棋」も作ろうと思えば作れるのではないかと考えました。
余談ですが、ここでプレイロジックの強化に向かわなかったのは自分の志向がアプリ開発にあるからと思っています。

将棋ゲームの開発

(開発期間: 2019/9/16~10/19)※自動プレイ実装は10/19~11/10

将棋ゲームもUIは全く凝らずに必要な項目やロジックを考え実装しました。将棋は駒の種類が分かれていたり、移動できる方向も分かれていたり、持ち駒があったり、二歩、成るなどの特殊ルールもあり、オセロに比べると開発には大分時間を要しました。

開発にはVue.jsを利用しました。

開発した将棋ゲームの実際の画面

開発した将棋ゲームの実際の画面

一通り将棋ロジックの実装が終わって、ランダムのプレイを実装し後に、今度は自動でプレイする機能の掘り下げを考えました。

この中で強いプレイヤーとするためには将棋という強化学習における「環境」を知る必要があるとわかりました。人工無能として自前でロジックを用意するときに、どのようにしたら強くなるか、ということを考え実装するためには環境に対する知識は必須でした。

実装で工夫した点は、指し手ごとにポイントを付けることでした。「王将」を獲得することを最大ポイントとして、駒の序列に合わせてポイント設定をし、駒を奪われる場合にはマイナスポイントを設定するようにして、全パターンに対してポイントが最大になるものを選ぶようにしました。

ここで気づいたのは、毎回最大のポイントを選んでしまうと、次のように持ち駒がなく「歩」が動いたら駒をとられてしまうような状態になった場合に、「歩」が動かなくなり、結果無限ループになってしまうということです。

最大ポイントを選ぶ行為自体が単調で、結果、強くはならないと解釈しました。

だからランダムのような「探索」をする方法が必要なのだと感じました。
これは強化学習でいうε(イプシロン)グリーディ法が有効であることを示しているのではないかと思います。

ランダムがないと無限ループしてしまう状態

ランダムがないと無限ループしてしまう状態

このようにオセロや将棋の開発を通じて強化学習で大切な次の要素について学びました。

  • 環境
  • 状態
  • 行動
  • 報酬

OpenAI Gymを始めたきっかけ

将棋を作った後にAIを実装することを考え、目標として「自分で考えたロジック以上の強さ」を設定ましたが「では実際に強化学習によるAIを実装して」と考えたときに、どのように実装すればいいのかがわかりませんでした。

単純に考えれば、「状態・全行動可能パターンを受け取って最適な行動パターンを返すAPI」を準備すればよいのだと思いますが、全行動可能パターンのデータは多すぎるので、APIのパラメータには設定したくはないですし、学習するタイミングはどうするか、データベースのテーブル定義はどのようにすればよいか、なにをもって強くするか、自分とプレイしたデータがいいのか、AI同士のデータがいいのか、など不明点が色々とありAIの実行環境を構築するまでには至りませんでした。

そこで強化学習において環境を提供してくれるOpenAI Gymを知り、まずはここで学ぼうと考え、OpenAI Gymでは初歩である「CartPole」について試してみることにしました。

その結果、インターネット上のコードを参考にしつつも自分なりのコードが書けたので、記録をしておこうとして記事を書いています。

環境

  • OS : Microsoft Windows 10 Home
  • プロセッサ : Intel(R) Core(TM) i7-7500U CPU @ 2.70GHz、2904 Mhz、2 個のコア、4 個のロジカル プロセッサ
  • メモリ : 16GB
  • プログラミング言語 : Python3.7
  • IDE : PyCharm 2019.2.4 (Community Edition)

実行環境は機械学習向けの高スペックな端末ではなく一般コンシューマ向けの中で少し良さめなノートPCになります。
OpenAIの環境構築方法は本記事の趣旨とははずれるため、割愛させていただきます。

問題

環境、観察値や問題などの詳細は次のURLに書かれています。

CartPole v0
https://github.com/openai/gym/wiki/CartPole-v0

上記ページに書かれている通りクリアの条件として
「100回の連続試行で平均報酬が195.0以上」
が設定されています。

今回はQ学習を用いてこの条件を満たすコードについて説明をします。

Solved Requirements
Considered solved when the average reward is greater than or equal to 195.0 over 100 consecutive trials.

クラス概要

作成したクラスの概要は次の通りです。

  • EnvCartPole : CartPole-v0 強化学習クラス
    • __init__ : コンストラクタ
    • save_q_text : Qテーブルをファイルに保存する
    • digitize_state : 状態のデジタル化をする
    • get_action : 状態から次のアクションを取得する
    • update_q_table : Qテーブルを更新する
    • run : 学習を実行する
    • finish : 学習を終了する

コード解説

コードについてポイントを解説します。

ライブラリ参照

import gym
import numpy as np
import time
import math
from statistics import mean, median
import matplotlib.pyplot as plt

実行結果をグラフとして視覚的にとらえるために「statics」や「matplotlib」もインポートします。

コンストラクタ

    def __init__(self):
        """
        コンストラクタ
        """
        self.env = gym.make('CartPole-v0')
        # 各種観察値の等差数列を生成
        # pole_velocityの範囲は特に重要
        self.bins = {
            'cart_position': np.linspace(-2.4, 2.4, 3)[1:-1],
            'cart_velocity': np.linspace(-3, 3, 5)[1:-1],
            'pole_angle':    np.linspace(-0.5, 0.5, 5)[1:-1],
            'pole_velocity': np.linspace(-2, 2, 5)[1:-1]
        }
        num_state = (
            (len(self.bins['cart_position']) + 1)
            * (len(self.bins['cart_velocity']) + 1)
            * (len(self.bins['pole_angle']) + 1)
            * (len(self.bins['pole_velocity']) + 1)
        )
        self.q_table = np.random.uniform(low=-1, high=1, size=(num_state, self.env.action_space.n))
        self.q_table_update = [[math.floor(time.time()), 0]] * num_state
        self.save_q_text()
        self.episodes = []
        self.episode_rewards = []
        self.episode_rewards_mean = []
        self.episode_rewards_median = []

学習において1度だけ実行する必要がある処理をまとめています。

変数 : self.env

CartPoleの環境オブジェクトです。
self.env = gym.make('CartPole-v0')により取得します。

変数 : self.bins

状態について区分けするための情報(bins)を色々試せるように個々で値設定ができるようにしました。
[1:-1]で両端(最初と最後の要素)を省いています。

変数 : num_state

状態数を表しています。
状態を数値で分けられるように各種観察値を区分けしたものを使って表しており、その種類数です。

len関数に対して+1をしているのは、self.binsでは区切り位置となる数値を格納しており、+1することで区分け数として扱っているためです。

上記コードでは、「2×4×4×4」で 128 になります。

変数 : self.q_table

Qテーブルで状態毎、行動毎の価値情報を保持しています。

変数 : self.q_table_update

Qテーブルの更新履歴で、更新日時と更新回数を保持しています。

変数 : self.episode_rewards、self.episode_rewards_mean、self.episode_rewards_median

各エピソードにおける「報酬」「平均値」「中央値」の情報を格納します。

Qテーブルをファイルに保存する

    def save_q_text(self):
        """
        Qテーブルをファイルに保存する
        ・q_table_cartpole.txt: Qテーブル情報(行: 状態、列: 行動)
        ・q_table_cartpole_update.txt: Qテーブル更新履歴(行 状態、列: 更新日時UNIXタイムスタンプ、更新回数)
        :return: void
        """
        np.savetxt('q_table_cartpole.txt', self.q_table)
        np.savetxt('q_table_cartpole_update.txt', self.q_table_update, '%.0f')

Qテーブルの情報をファイルとして残すための処理です。
Qテーブルのデータがどれだけ更新できるかを確認するため更新履歴のファイルも用意しました。

状態のデジタル化をする

    def digitize_state(self, observation):
        """
        状態のデジタル化をする
        :param observation: 観察値
        :return: デジタル化した数値
        """
        cart_position, cart_velocity, pole_angle, pole_velocity = observation
        state = (
            np.digitize(cart_position, bins=self.bins['cart_position'])
            + (
                np.digitize(cart_velocity, bins=self.bins['cart_velocity'])
                * (len(self.bins['cart_position']) + 1)
            )
            + (
                np.digitize(pole_angle, bins=self.bins['pole_angle'])
                * (len(self.bins['cart_position']) + 1)
                * (len(self.bins['cart_velocity']) + 1)
            )
            + (
                np.digitize(pole_velocity, bins=self.bins['pole_velocity'])
                * (len(self.bins['cart_position']) + 1)
                * (len(self.bins['cart_velocity']) + 1)
                * (len(self.bins['pole_angle']) + 1)
            )
        )
        return state

各種観察値を元に状態を数値で表します。
イメージとしては、

  • 1の位が「cart_position」
  • 10の位が「cart_velocity」
  • 100の位が「pole_angle」
  • 1000の位が「pole_velocity」

という風に桁で分けているイメージです。
実際は観察値の値は10種類以上あるので単純な10進数ではありませんが、
上記処理で表しているのはこのような桁毎に観察値を割りてて一つの数値にしていることです。

状態から次のアクションを取得する

    def get_action(self, state, episode):
        """
        状態から次のアクションを取得する
        (εグリーディ法使用)
        :param state: 状態値
        :param episode: エピソード数
        :return: 次のアクション
        """
        # 重要
        epsilon = 0.5 * (1 / (episode + 1))
        if epsilon <= np.random.uniform(0, 1):
            action = np.argmax(self.q_table[state])
        else:
            action = np.random.choice([0, 1])
        return action

εグリーディ法の考え方でエピソード数を引数として「報酬の最大値」をとるか「ランダム」をとるかを判定します。
これがないと単調な処理になり学習が捗りません。

εグリーディ法について自分の理解が足りないのでもっと学ぼうと考えていますが、
今はこの式で認識をしています。

Qテーブルを更新する

    def update_q_table(self, state, action, reward, next_state):
        """
        Qテーブルを更新する
        :param state: 前の状態
        :param action: 前のアクション
        :param reward: 前のアクションにより獲得した報酬
        :param next_state: 次の状態
        :return: void
        """
        alpha = 0.2     # 学習率
        gamma = 0.99    # 割引率
        max_q_value = max(self.q_table[next_state])
        current_q_value = self.q_table[state, action]
        self.q_table[state, action] = (
            current_q_value
            + alpha
            * (reward + (gamma * max_q_value) - current_q_value)
        )
        self.q_table_update[state] = [math.floor(time.time()), self.q_table_update[state][1] + 1]

wikiにも掲載されているよくあるQ学習の式をプログラムにしたものです。

Q(s_{t},a)\leftarrow Q(s_{t},a)+\alpha \left[r_{{t+1}}+\gamma \max _{p}Q(s_{{t+1}},p)-Q(s_{t},a)\right]

Wikipedia: Q学習
https://ja.wikipedia.org/wiki/Q%E5%AD%A6%E7%BF%92

self.q_tableの元の状態の価値を更新します。
合わせてQテーブルの更新情報も更新します。

学習を実行する

    def run(self, num_episode, num_step_max):
        """
        学習を実行する
        :param num_episode: エピソード数
        :param num_step_max: 最大ステップ数
        :return: True: 課題解決成功、false: 課題解決失敗
        """
        num_solved = 0
        num_solved_max = 0
        self.episodes = []
        self.episode_rewards = []
        self.episode_rewards_mean = []
        self.episode_rewards_median = []

        for episode in range(num_episode):
            observation = self.env.reset()
            state = self.digitize_state(observation)
            action = np.argmax(self.q_table[state])
            episode_reward = 0

            for step in range(num_step_max):
                # if episode % 100 == 0:
                #     self.env.render()
                observation, reward, done, _ = self.env.step(action)
                # 重要
                if done and step < num_step_max - 1:
                    reward -= num_step_max
                episode_reward += reward

                next_state = self.digitize_state(observation)
                self.update_q_table(state, action, reward, next_state)
                action = self.get_action(next_state, episode)
                state = next_state
                if done:
                    break
            print(f'episode: {episode}, episode_reward: {episode_reward}')
            if episode_reward >= 195:
                num_solved += 1
            else:
                num_solved = 0
            if num_solved_max < num_solved:
                num_solved_max = num_solved
            self.episodes.append(episode)
            self.episode_rewards.append(episode_reward)
            self.episode_rewards_mean.append(mean(self.episode_rewards))
            self.episode_rewards_median.append(median(self.episode_rewards))

学習の実行処理です。
メインルーチンの考え方は次の通りです。

順番 処理 コード
1 「状態」を取得 observation = self.env.reset()
2 「状態」を数値化 state = self.digitize_state(observation)
3 「行動」を選択 action = np.argmax(self.q_table[state])
4 「行動」を実行し、実行後の「状態」「報酬」「終了か」を取得 observation, reward, done, _ = self.env.step(action)
5 取得した実行後の「状態」「報酬」から、実行前の「状態」のアクションの『価値』を更新 self.update_q_table(state, action, reward, next_state)
6 次の「行動」を選択 action = self.get_action(next_state, episode)
7 4~6を終わるまで繰り返す -

次のコードのコメントアウトを外すと、
100回おきに実行状況を表示します。

                # if episode % 100 == 0:
                #     self.env.render()

指定回数に到達する前に終わってしまった場合は、
次のようにペナルティとして報酬を減らし、
最終行動の価値が下がるようにします。

                if done and step < num_step_max - 1:
                    reward -= num_step_max

変数 : num_solved

連続解決回数です。失敗(報酬が195未満)するとリセットされます。

変数 : num_solved_max

最大連続解決回数です。

学習を終了する

    def finish(self, num_solved_max):
        """
        学習を終了する
        :param num_solved_max: 最大連続解決数
        :return: True: 課題解決成功、false: 課題解決失敗
        """
        self.env.close()
        self.save_q_text()
        plt.plot(self.episodes, self.episode_rewards, label='reward')
        plt.plot(self.episodes, self.episode_rewards_mean, label='mean')
        plt.plot(self.episodes, self.episode_rewards_median, label='median')
        plt.legend()
        plt.show()
        print(f'num_solved_max: {num_solved_max}')
        return num_solved_max > 100

強化学習実行後の後処理をまとめています。
最大連続解決数が 100 を超えた場合、クリアと判定します。

実行完了後に次のような学習結果のグラフを表示します。

学習結果例

実行結果例

上記グラフからは実行を繰り返す度に次のようになったことより学習が成功したことが分かります。

  • 報酬(reward)が200固定になった
  • 平均値(mean)が右肩上がりするようになった
  • 中央値(median)が200固定になった

学習実行呼び出し

if __name__ == '__main__':
    this = EnvCartPole()
    result = this.run(num_episode=500, num_step_max=200)
    if result:
        print('CLEAR!')
    else:
        print('FAILED...')

「CLEAR!」と出力されれば問題をクリアしたことを表します。
実行回数やパターンが少ないため、500回の学習にかかる時間は5秒程度です。(実行状況表示をしない場合)

引数 : num_episode

学習回数です。
500回で十分クリアできますが、初回のQテーブルのランダムでの作成状況によってはクリアできない場合があります。
その時は再度実行するか、回数を増やすかをすればクリアできることが多いです。

解答例

全コードは次の通りです。

env_cart_pole.py
import gym
import numpy as np
import time
import math
from statistics import mean, median
import matplotlib.pyplot as plt


class EnvCartPole:
    """
    CartPole-v0 強化学習クラス
    @see https://github.com/openai/gym/wiki/CartPole-v0
    """
    def __init__(self):
        """
        コンストラクタ
        """
        self.env = gym.make('CartPole-v0')
        # 各種観察値の等差数列を生成
        # pole_velocityの範囲は特に重要
        self.bins = {
            'cart_position': np.linspace(-2.4, 2.4, 3)[1:-1],
            'cart_velocity': np.linspace(-3, 3, 5)[1:-1],
            'pole_angle':    np.linspace(-0.5, 0.5, 5)[1:-1],
            'pole_velocity': np.linspace(-2, 2, 5)[1:-1]
        }
        num_state = (
            (len(self.bins['cart_position']) + 1)
            * (len(self.bins['cart_velocity']) + 1)
            * (len(self.bins['pole_angle']) + 1)
            * (len(self.bins['pole_velocity']) + 1)
        )
        self.q_table = np.random.uniform(low=-1, high=1, size=(num_state, self.env.action_space.n))
        self.q_table_update = [[math.floor(time.time()), 0]] * num_state
        self.save_q_text()
        self.episodes = []
        self.episode_rewards = []
        self.episode_rewards_mean = []
        self.episode_rewards_median = []

    def save_q_text(self):
        """
        Qテーブルをファイルに保存する
        ・q_table_cartpole.txt: Qテーブル情報(行: 状態、列: 行動)
        ・q_table_cartpole_update.txt: Qテーブル更新履歴(行 状態、列: 更新日時UNIXタイムスタンプ、更新回数)
        :return: void
        """
        np.savetxt('q_table_cartpole.txt', self.q_table)
        np.savetxt('q_table_cartpole_update.txt', self.q_table_update, '%.0f')

    def digitize_state(self, observation):
        """
        状態のデジタル化をする
        :param observation: 観察値
        :return: デジタル化した数値
        """
        cart_position, cart_velocity, pole_angle, pole_velocity = observation
        state = (
            np.digitize(cart_position, bins=self.bins['cart_position'])
            + (
                np.digitize(cart_velocity, bins=self.bins['cart_velocity'])
                * (len(self.bins['cart_position']) + 1)
            )
            + (
                np.digitize(pole_angle, bins=self.bins['pole_angle'])
                * (len(self.bins['cart_position']) + 1)
                * (len(self.bins['cart_velocity']) + 1)
            )
            + (
                np.digitize(pole_velocity, bins=self.bins['pole_velocity'])
                * (len(self.bins['cart_position']) + 1)
                * (len(self.bins['cart_velocity']) + 1)
                * (len(self.bins['pole_angle']) + 1)
            )
        )
        return state

    def get_action(self, state, episode):
        """
        状態から次のアクションを取得する
        (εグリーディ法使用)
        :param state: 状態値
        :param episode: エピソード数
        :return: 次のアクション
        """
        # 重要
        epsilon = 0.5 * (1 / (episode + 1))
        if epsilon <= np.random.uniform(0, 1):
            action = np.argmax(self.q_table[state])
        else:
            action = np.random.choice([0, 1])
        return action

    def update_q_table(self, state, action, reward, next_state):
        """
        Qテーブルを更新する
        :param state: 前の状態
        :param action: 前のアクション
        :param reward: 前のアクションにより獲得した報酬
        :param next_state: 次の状態
        :return: void
        """
        alpha = 0.2     # 学習率
        gamma = 0.99    # 割引率
        max_q_value = max(self.q_table[next_state])
        current_q_value = self.q_table[state, action]
        self.q_table[state, action] = (
            current_q_value
            + alpha
            * (reward + (gamma * max_q_value) - current_q_value)
        )
        self.q_table_update[state] = [math.floor(time.time()), self.q_table_update[state][1] + 1]

    def run(self, num_episode, num_step_max):
        """
        学習を実行する
        :param num_episode: エピソード数
        :param num_step_max: 最大ステップ数
        :return: True: 課題解決成功、false: 課題解決失敗
        """
        num_solved = 0
        num_solved_max = 0
        self.episodes = []
        self.episode_rewards = []
        self.episode_rewards_mean = []
        self.episode_rewards_median = []

        for episode in range(num_episode):
            observation = self.env.reset()
            state = self.digitize_state(observation)
            action = np.argmax(self.q_table[state])
            episode_reward = 0

            for step in range(num_step_max):
                # if episode % 100 == 0:
                #     self.env.render()
                observation, reward, done, _ = self.env.step(action)
                # 重要
                if done and step < num_step_max - 1:
                    reward -= num_step_max
                episode_reward += reward

                next_state = self.digitize_state(observation)
                self.update_q_table(state, action, reward, next_state)
                action = self.get_action(next_state, episode)
                state = next_state
                if done:
                    break
            print(f'episode: {episode}, episode_reward: {episode_reward}')
            if episode_reward >= 195:
                num_solved += 1
            else:
                num_solved = 0
            if num_solved_max < num_solved:
                num_solved_max = num_solved
            self.episodes.append(episode)
            self.episode_rewards.append(episode_reward)
            self.episode_rewards_mean.append(mean(self.episode_rewards))
            self.episode_rewards_median.append(median(self.episode_rewards))

        return self.finish(num_solved_max)

    def finish(self, num_solved_max):
        """
        学習を終了する
        :param num_solved_max: 最大連続解決数
        :return: True: 課題解決成功、false: 課題解決失敗
        """
        self.env.close()
        self.save_q_text()
        plt.plot(self.episodes, self.episode_rewards, label='reward')
        plt.plot(self.episodes, self.episode_rewards_mean, label='mean')
        plt.plot(self.episodes, self.episode_rewards_median, label='median')
        plt.legend()
        plt.show()
        print(f'num_solved_max: {num_solved_max}')
        return num_solved_max > 100


if __name__ == '__main__':
    this = EnvCartPole()
    result = this.run(num_episode=500, num_step_max=200)
    if result:
        print('CLEAR!')
    else:
        print('FAILED...')

最後に

掲載されている一部素材は、次のサイト様のものを使用しております。

しんえれ外部駒 様
http://shineleckoma.web.fc2.com/page1.htm

掲載したコードは、次のサイト様の情報を参考にさせていただきました。

Pythonを使って強化学習をする方法を徹底解説
https://ai-kenkyujo.com/2019/05/24/python/

今回作成したクラスは他の環境でも同じように使えるものと思いますので、これを活用して他の環境でも同様に強化学習を試してより知識と経験を深めようと思います。

以上です。

この記事がAIや強化学習を学びたいと思っている方にとって有用な情報になれば幸いです。

10
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
9