3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

強化学習:価値反復法(Value Iteration)を高速化する

Posted at

はじめに

近年、AlphaGoやDQNの成功からモデルフリーの深層強化学習の分野が盛んに研究されています。これらのアルゴリズムは、状態行動空間が事情に大きい場合やダイナミクスの数理モデル化が困難である場合に有効なアプローチの一つです。しかし、現実に遭遇する問題の中には、環境の数理モデル化が比較的容易で、状態行動空間も工夫次第で削減できるような場合も多いのではないでしょうか。このような問題に対しては、モデルベースのテーブル強化学習を用いる方が、開発・運用コストの面からもメリットが大きいと思います。

ただし、テーブル強化学習では扱える状態行動空間の大きさはプログラムの速度に大きく依存しており、高速化は非常に重要です。そこで、本稿では、強化学習の基礎的なアルゴリズムである 価値反復法 を高速に実行するためのノウハウを紹介します。最終的にはナイーブな実装と比べて、500倍程度の高速化 が達成できました。

背景

マルコフ決定過程

マルコフ決定過程(Markov Decision Processes : MDP)とは、強化学習の問題設定で用いられるフレームワークです。「環境」は各時刻においてある状態$s$を取り、意思決定を行う「エージェント」はその状態において利用可能な行動$a$を任意に選択します。 その後環境はランダムに新しい状態へと遷移し、その際にエージェントは状態遷移に対応した報酬$r$を受け取ります。
MDPにおける基本的な問題設定は、ある状態においてエージェントが取る最適な行動へのマッピング(確率分布)である方策を求めることです。そして、目的関数を割引累計報酬とすると、次のような最適価値関数を求める問題に帰着します。

\pi^* = \text{arg}\max_{\pi}  \text{E}_{s \sim \mu, \tau \sim \pi}[\sum _t \gamma^t r(s_t, a_t)|s_0=s] = \text{arg}\max_{\pi}  \text{E}_{s \sim \mu} [V_{\pi}(s)]

状態価値関数$V(s)$および行動価値関数$Q(s, a)$をそれぞれ次のように定義します。

V_{\pi}(s) = \text{E}_{\tau \sim \pi}[\sum _t \gamma^t r(s_t, a_t)|s_0=s] \\
Q_{\pi}(s, a) = \text{E}_{\tau \sim \pi}[\sum _t \gamma^t r(s_t, a_t)|s_0=s, a_0=a]

価値反復法

最適方策$\pi^$における状態価値関数を $V^(s) = V_{\pi^}(s)$、行動価値関数を $Q^(s,a) = Q_{\pi^*}(s,a)$と定義します。価値関数の定義から、最適価値関数は次のベルマン方程式を満たすことがわかります。

Q^*(s, a) = \sum_{s', r} p(s', r|s, a) [r + \gamma V^*(s')] \\
V^*(s) = \max _a Q^*(s, a)

価値反復法(Value Iteration)は、適当な初期値から始め、ベルマン方程式を繰り返し適応して$V$と$Q$を交互に更新していくアルゴリズムです。最悪計算量は状態数と行動数に対して多項式となっています。
行動価値関数$Q^*(s,a)$を求めることができれば、状態$s$において取り得る行動$a$の内、最も高い行動価値のものを選ぶことで最適な方策を求めることができます。

\pi(a|s) = \text{arg}\max _a Q^*(s,a)

アルゴリズム

実験準備

価値反復法はアイデアとしては非常にシンプルですが、実装としては様々な方式が考えられます。ここではそのうち、いくつかの実装とその処理速度を具体的に紹介したいと思います。
各実装の比較のため、実験用のMDPを作ります。実装の容易さから、決定論的なMDP、つまりある行動$a$に対して、報酬$r$と次状態$s'$が確定的に決定するような世界を考えます。このときベルマン方程式は単純化され、次のようになります。

Q^*(s, a) = r + \gamma V^*(s')

決定論的なMDPは状態をノード、行動をエッジ、報酬をエッジの重み(属性)とするグラフで表すことができます。次の関数で、実験で使用するMDPを作成します。

import networkx as nx
import random

def create_mdp(num_states, num_actions, reward_ratio=0.01, neighbor_count=30):
    get_reward = lambda: 1.0 if random.random() < reward_ratio else 0.0
    get_neighbor = lambda u: random.randint(u - neighbor_count, u + neighbor_count) % (num_states - 1)
    edges = [
        (i, (i + 1) % (num_states - 1), get_reward())
        for i in range(num_states)
    ]
    for _ in range(num_states * (num_actions - 1)):
        u = random.randint(0, num_states - 1)
        v = get_neighbor(u)
        r = get_reward()
        edges.append((u, v, r))
    G = nx.DiGraph()
    G.add_weighted_edges_from(edges)
    return G

状態数と行動数(各状態の平均的な行動数)を指定し、ランダムな強連結グラフとスパースな報酬を生成しています。networkxのnodeが状態、edgeが行動、edgeのweight属性が報酬ととするDiGraphで表現されています。

今後の実験では、状態数が10000で各状態数の行動数が平均3程度のMDPを用いていきます。

num_states = 10000
num_actions = 3
G = create_mdp(num_states, num_actions)

ナイーブな価値反復法の実装

最も単純なアルゴリズムは、「全状態の状態価値の更新」と「全行動の行動価値の更新」を繰り返す方法で、同期的動的計画法(Synchronous Dynamic Programming)とも呼ばれることがあります。

class NonConvergenceError(Exception):
    pass

class SyncDP:
    
    def __init__(self, G, gamma, max_sweeps, threshold):
        self.G = G
        self.gamma = gamma
        self.max_sweeps = max_sweeps
        self.threshold = threshold
        self.V = {state : 0 for state in G.nodes}
        self.TD = {state : 0 for state in G.nodes}
        self.Q = {(state, action) : 0 for state, action in G.edges}

    def get_reward(self, s, a):
        return self.G.edges[s, a]['weight']

    def sweep(self):
        for state in self.G.nodes:
            for action in self.G.successors(state):
                self.Q[state, action] = self.get_reward(state, action) + self.gamma * self.V[action]
        for state in self.G.nodes:
            v_new = max([self.Q[state, action] for action in self.G.successors(state)])
            self.TD[state] = abs(self.V[state] - v_new)
            self.V[state] = v_new

    def run(self):
        for _ in range(self.max_sweeps):
            self.sweep()
            if (np.array(list(self.TD.values())) < self.threshold).all():
                return self.V
        raise NonConvergenceError

パラメータである割引率gamma、収束閾値threshold、最大スイープ数は、本来、アプリケーションの要件などから決まってくるものですが、ここでは適当な値を設定します。

gamma = 0.95
threshold = 0.01
max_sweeps = 1000

この条件で実行すると処理時間は次のようになりました。

%timeit V = SyncDP(G, gamma, max_sweeps, threshold).run()
8.83 s ± 273 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

状態数10000でこれだけの時間がかかってしまうと、利用できるアプリケーションはかなり限られてしまいそうです。この処理時間が今後の改善のベースラインとなります。

Asynchronous Dynamic Programming (Async DP)

同期的な価値更新アルゴリズムは一回のスイープで全ての状態の更新が発生するため、状態数が非常に多い場合には一回のスイープだけでもかなりの時間を要してしまいます。非同期DPはインプレースで反復的に状態価値を更新します。つまり、同期DPのように毎回更新済みの値を格納するための新しい配列を用意し、全ての状態価値を更新して新しい配列に格納するのではなく、更新のたびにそれまでに計算して利用できる別の状態価値を活用して、価値更新を繰り返します。収束を担保するためには全ての状態を更新し続ける必要がありますが、更新順序は自由に選ぶことが可能です。例えば、最適な方策と関連性が低い状態をスキップするなどにより価値更新を高速化することができます。

Prioritized Sweeping

非同期DPの場合、価値更新の順序は任意に決めることができます。価値更新の際、全ての状態が等しく他の状態の価値更新で有用であるわけではなく、いくつかの状態は他の状態価値に大きな影響を与えるということが予想されます。例えば、今考えているようなスパースな報酬が得られるMDPでは、報酬が得られる状態を効率良く他の状態へ伝搬させていくことが重要だと言えます。そこで、優先度付きキューを使った次のようなアルゴリズムが考えられます。

  1. 優先度付きキューで全ての状態において価値更新による変化量を管理する
  2. キューのトップにある状態の価値を更新する
  3. 前回の価値更新からの変化量が閾値を超えていたらキューに状態と変化量のペアをプッシュする

このアルゴリズムはPrioritized Sweepingと呼ばれています。実装は次のようになります。

import heapq

class PrioritizedDP(SyncDP):
    def run(self):
        self.sweep()
        pq = [
            (-abs(td_error), state)
            for state, td_error in self.TD.items()
            if abs(td_error) > self.threshold
            ]
        heapq.heapify(pq)
        while pq:
            _, action = heapq.heappop(pq)
            if self.TD[action] < self.threshold:
                continue
            self.TD[action] = 0
            for state in self.G.predecessors(action):
                self.Q[state, action] = self.get_reward(state, action) + self.gamma * self.V[action]
                v_new = max([self.Q[state, action] for action in self.G.successors(state)])
                td_error = abs(v_new - self.V[state])
                self.TD[state] += td_error
                if td_error > self.threshold:
                    heapq.heappush(pq, (-td_error, state))
                self.V[state] = v_new
        return self.V

はじめにsweep関数で全状態の価値更新を実行して、その内TD誤差が閾値を超えた状態をもとにヒープを構成します。あとは、キューが無くなるまで更新を繰り返します。まず、キューから取り出した状態を次の状態(=action)とするような行動価値$Q$を更新します。そして、更新した行動価値に依存する状態価値を$V$を更新します。更新前と更新後の差(td_error)が閾値を超えていればキューにプッシュします。処理時間は以下のようになり、2倍程度高速化が達成できました。

%timeit V = PrioritizedDP(G, gamma, max_sweeps, threshold).run()
4.06 s ± 115 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

ベクトル演算化

これまでnetworkxのオブジェクトを主としたアルゴリズムを実装してきましたが、successorsなどの頻繁に呼ばれるメソッドが大部分の時間を要しています。データ構造をより効率化しつつ、numpyによるベクトル演算を活用する方法を考えます。グラフをnumpy配列で表現することで、次のように行動価値の演算をベクトル演算を用いることができるようになります。

class ArraySyncDP:
    
    def __init__(self, A : ArrayGraph, gamma, max_sweeps, threshold):
        self.A = A
        self.gamma = gamma
        self.max_sweeps = max_sweeps
        self.threshold = threshold
        self.V = np.zeros(A.num_states)
        self.TD = np.full(A.num_states, np.inf)
        self.Q = np.zeros(A.num_actions)

    def run(self):
        for _ in range(self.max_sweeps):
            self.Q[:] = self.A.reward + self.gamma * self.V[self.A.action2next_state]
            for state_id in range(self.A.num_states):
                start, end = self.A.state2action_start[state_id], self.A.state2action_start[state_id + 1]
                v_new = self.Q[start : end].max()
                self.TD[state_id] = abs(self.V[state_id] - v_new)
                self.V[state_id] = v_new

            if (self.TD < self.threshold).all():
                return self.V
        raise NonConvergenceError
%timeit V = ArraySyncDP(A, gamma, max_sweeps, threshold).run()
3.5 s ± 99.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Prioritized Sweepingに比べて僅かに高速化されましたが、それほど効果はありませんでした。全てをベクトル演算化できておらず、まだ状態価値の更新でfor分を使っていることが主な原因のようです。

Cython

本来、メモリ上に散らばるオブジェクトを参照するリストに比べて、配列は基本データ型をメモリ上の連続領域に保持するので、高速にアクセスすることができるはずです。しかし、Pythonでは、配列の個々の要素を参照するためにPythonオブジェクトへの変換が行われ、オーバーヘッドがかかってしまうため、リストや辞書よりも要素アクセスが遅くなってしまうようです。

そこでCythonを使って、配列アクセスを高速化することを考えます。Cythonは型注釈のあるPythonをコンパイルされた拡張モジュールに変換するコンパイラです。変換した拡張モジュールは、通常のPythonモジュールと同様にimportで読み込むことができます。Cythonを使えば、以下の理由から、numpy配列をインターフェースとして、要素アクセスが発生する処理の高速化が行えそうです。

  • 配列を順番に処理する場合には、メモリ領域が連続しているので、いちいちPythonに次の要素のアドレスを尋ねずに、次の要素のアドレスを直接求めるようにコンパイラに指示することができる
  • numpy配列はmemoryviewという汎用のバッファインターフェースを実装しているオブジェクトとして低水準でアクセスすることができる
  • 従って、メモリ領域を簡単にCのライブラリと共有することができ、Pythonオブジェクトを他の形式に変換する必要がなくなる

Cythonで非同期DPを実装します。優先度付きキューを使いたいところですが、Pythonのオブジェクトの処理が発生してしまうため、TD誤差が小さい状態価値更新のスキップのみを実行しています。

%%cython
import numpy as np
cimport numpy as np
cimport cython

ctypedef np.float64_t FLOAT_t
ctypedef np.int64_t INT_t

@cython.boundscheck(False)
@cython.wraparound(False)
def cythonic_backup(
    FLOAT_t[:] V, FLOAT_t[:] Q, FLOAT_t[:] TD, FLOAT_t[:] reward,
    INT_t[:] state2action_start, INT_t[:] action2next_state,
    INT_t[:] next_state2inv_action, INT_t[:, :] inv_action2state_action,
    FLOAT_t gamma, FLOAT_t threshold
):
    cdef INT_t num_updates, state, action, next_state, inv_action, start_inv_action, end_inv_action, start_action, end_action
    cdef FLOAT_t v
    num_updates = 0
    for next_state in range(len(V)):
        if TD[next_state] < threshold:
            continue
            
        num_updates += 1
        TD[next_state] = 0
        start_inv_action = next_state2inv_action[next_state]
        end_inv_action = next_state2inv_action[next_state + 1]
        for inv_action in range(start_inv_action, end_inv_action):
            state = inv_action2state_action[inv_action][0]
            action = inv_action2state_action[inv_action][1]
            Q[action] = reward[action] + gamma * V[next_state]
            start_action = state2action_start[state]
            end_action = state2action_start[state + 1]
            v = -1e9
            for action in range(start_action, end_action):
                if v < Q[action]:
                    v = Q[action]
            if v > V[state]:
                TD[state] += v - V[state]
            else:
                TD[state] += V[state] - v
            V[state] = v
    return num_updates
class CythonicAsyncDP(ArraySyncDP):
    def run(self):
        A = self.A
        for i in range(self.max_sweeps):
            num_updates = cythonic_backup(
                self.V, self.Q, self.TD, A.reward, A.state2action_start, A.action2next_state,
                A.next_state2inv_action, A.inv_action2state_action, self.gamma, self.threshold
            )
            if num_updates == 0:
                return self.V
        raise NonConvergenceError
%timeit V = CythonicAsyncDP(A, gamma, max_sweeps, threshold).run()
18.6 ms ± 947 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

18.6ms と劇的な高速化が達成できました。この処理速度であれば、利用用途にもよりますが、状態数が100万程度の問題でも十分対応可能です。

参考文献

[1] ハイパフォーマンスPython, Micha Gorelick and Ian Ozsvald, オライリージャパン, 2015.
[2] Reinforcement Learning, R. S. Sutton and A. G. Barto, The MIT Press, 2018.

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?