はじめに
こんにちは!thunderです。
著書 ゲームで学ぶ探索アルゴリズム実践入門 ~木探索とメタヒューリスティクス を2月に販売開始してから約3カ月の月日が流れました。早いものですね。
さて、「ゲームで学ぶ~」の読者からは「サンプルコードがC++なのが難しい」という意見をよく聞きます。
実は、本書の執筆時にはPythonのサンプルコードも公開する構想もありました。
ところが、Pythonで実装したコードはあまりにも遅すぎて役に立たなそうだったためボツにしたという経緯があります。
また、探索アルゴリズムは実社会問題を最適化するのにも強力な武器となります。実社会問題を解く際、pandasやjupyterなど豊富なライブラリ、ツールが手軽に利用できる点からもPythonが有力候補となります。厳密解法を試してうまくいかないことがわかった後に探索アルゴリズムに頼ることも多いため、ソルバのPuLPやGurobiのためにPythonを利用するのも自然な流れでしょう。
要約すると「Pythonで探索アルゴリズムを実装するのは速度面で不利」かつ「Pythonで探索アルゴリズムを利用する需要がある」という相反する事象が発生しています。
そこで、C++で実装した探索アルゴリズムライブラリをPythonから呼び出せばいいのでは?という結論に至りました。
それ、つくりました。やったー!!解決です。
前提知識
以下の内容を概要だけでもいいので知っている方に向けて記事を書いています。
- ビームサーチ
- クラスの継承
- オーバーライド
※記事中に著書「ゲームで学ぶ~」について言及することがしばしばありますが、こちらは必ずしも読んでいる必要はありません。
インストール方法
pipが使える状態で以下のコマンドを実行します。かんたん!やったー!!
pip install thunsearch
警告
先述のgithubのリポジトリはthunとsearchの間に_がある"thun_search"ですが、モジュール名はアンダーバーのない"thunsearch"です。
警告
v0.0.1ではUbuntuでのみ動作確認をしています。
manylinux2014にてビルドしているので、Ubuntu以外でもLinux環境ではPyPIからインストールできるはずです。
pip install thunsearch
でうまくインストールできない場合、https://github.com/thun-c/thun_search からリポジトリをクローンし、ルートフォルダでpip install .
とするとインストールできるかもしれません。
現状、M1 Macにてこの方法でインストールできたという報告があります。
Windowsユーザの方にはWSL上での利用を強く推奨します。
使い方
thunsearchは、問題の種別ごとに用意されたベースクラスと探索アルゴリズム関数をペアで使用します。
現状のv0.0.1では文脈のある最適化問題1に対応するベースクラス BaseContextualState
と探索アルゴリズム beam_search_action
が 実装済みです。
以下に探索アルゴリズム関数を呼び出すまでの流れを説明していきます。
使い方概要
- ベースクラス (v0.0.1はBaseContextualStateのみ) を継承したサブクラスを定義する
- サブクラス側でベースクラスの特定のメソッドをオーバーライドして実装する
- サブクラスのインスタンスを引数に探索アルゴリズム (beam_search_actionなど)を呼ぶと、戻り値として解が得られる。
サンプル問題の紹介
さて、ここからは具体的な問題で説明していきます。
たとえば、以下の「罠あり数字集め迷路」を考えます。
見出し | 説明 |
---|---|
目的 | 終了時点のスコアを高くする。 |
着手タイミング | 1ターンに1回 |
着手するタスク | 各ターン、キャラクター(@)を上下左右の四方向いずれかの場所に1マス移動させる。立ち止まることや、盤面の外に移動させることはできない。 |
終了条件 | 4ターン経過する。またはトラップマス(X) にキャラクターが侵入する。 |
その他 | キャラクターはランダムに初期配置される。 キャラクターが移動した先にポイントがある場合、そのポイントの値をゲームスコアに加算し、床のポイントは消失する。 |
例えば、以下のような初期盤面が与えられたとします。
右、右、下、右に移動すると、4ターン経過したのでタスクを終了します。この場合、床のポイント7,7,3,1を取得してスコアは18になります。
1ターン目で下に移動した場合、キャラクターが罠を踏むので4ターンを待たずにここでタスクを終了します。床のポイントを取得していないのでスコアは0です。
この問題をthunsearchを用いて解いていく過程を説明します。
ベースクラスのメソッド用ラベル
まず、ベースクラスを継承して罠あり数字集め迷路を表現するサブクラスを定義していきます。
ベースクラスのメソッドにはmust
,should
,can
の3種類のラベルが付与されています。
サブクラスではベースクラスのメソッドをオーバーライドして実装する必要がありますが、どのメソッドをオーバーライドすべきかは以下のラベル表を見ながら選択しましょう。
ラベル名 | 説明 | 実装しなかった場合の挙動 |
---|---|---|
must | 実装しなければならないメソッド | インスタンスを生成できない |
should | 実装をしたほうがいいメソッド | インスタンスは生成できるが、一部の関数がエラー停止する |
can | 実装してもよい関数 | ベースクラス側で実装されたメソッドがそのまま実行される |
以下のように、サブクラスを実装している途中でget_not_implemented_must_methods
,get_not_implemented_should_methods
,get_not_implemented_can_methods
を実行するとそのサブクラスで未実装の該当ラベルのメソッドのセットを取得できます。
import thunsearch as thun
class MazeState(thun.BaseContextualState):
pass
print("not_implemented_must",
MazeState.get_not_implemented_must_methods())
print("not_implemented_should",
MazeState.get_not_implemented_should_methods())
print("not_implemented_can",
MazeState.get_not_implemented_can_methods())
結果は以下のようになります。
この場合、少なくともmust
ラベルのついたis_done
,legal_actions
,advance
,evaluate_score
を実装しなければStateAのインスタンスを生成できません。
not_implemented_must {'is_done', 'legal_actions', 'advance', 'evaluate_score'}
not_implemented_should set()
not_implemented_can {'is_dead', '__str__', 'clone'}
試しにこの状態でMazeStateのインスタンスを生成してみましょう。
import thunsearch as thun
class MazeState(thun.BaseContextualState):
pass
state = MazeState()
このように、未実装のmust
メソッドの一覧と共にNotImplementedErrorを出力してプログラムが停止します。
正しく実装したつもりなのにエラーが出た場合、エラーメッセージ内の未実装のmust
メソッド一覧と自分の実装を照らし合わせて足りないメソッドを探しましょう。
Traceback (most recent call last):
~略~
raise NotImplementedError(
NotImplementedError: must functions are not implemented.
[advance , evaluate_score , is_done , legal_actions]
コンストラクタの実装
まずはコンストラクタを実装します。ベースクラス側のコンストラクタsuper().__init__()
を呼ぶのを忘れずにしましょう。
class MazeState(thun.BaseContextualState):
dy = [0, 0, 1, -1] # 右、左、下、上への移動方向のy成分
dx = [1, -1, 0, 0] # 右、左、下、上への移動方向のx成分
H = 3
W = 4
END_TURN = 4
INF = 1000000000
def __init__(self, seed=None) -> None:
super().__init__()
self.turn_ = 0
self.points_ = [[0 for w in range(MazeState.W)]
for h in range(MazeState.H)]
self.character_ = Coord(0, 0)
self.trap_ = Coord(0, 0)
self.task_score_ = 0
if seed is not None:
random.seed(seed)
self.character_.y_ = random.randrange(MazeState.H)
self.character_.x_ = random.randrange(MazeState.W)
while self.character_ == self.trap_:
self.trap_.y_ = random.randrange(MazeState.H)
self.trap_.x_ = random.randrange(MazeState.W)
for y in range(MazeState.H):
for x in range(MazeState.W):
if (y, x) == (self.character_.y_, self.character_.x_):
continue
if (y, x) == (self.trap_.y_, self.trap_.x_):
continue
self.points_[y][x] = random.randrange(10)
"must" メソッドのオーバーライド
BaseContextualState
の"must" メソッドは以下の通りです。
これらのメソッドはどの問題を解く時も必ず実装しなければなりません。
メソッド | 説明 |
---|---|
advance(self, action: int) -> None | 指定したactionでタスクを進める |
legal_actions(self) -> List[int] | 合法手の一覧を取得する |
is_done(self) -> bool | タスクが正常終了したかどうかを取得する |
evaluate_score(self) -> float | 現在の状態を評価して取得する |
まずはadvance
を実装します。
罠あり数字集め迷路ではキャラクターを指定方向に移動させ、床にポイントがあれば取得し、self.task_scoreを更新します。
class MazeState(thun.BaseContextualState):
def advance(self, action):
self.character_.y_ += MazeState.dy[action]
self.character_.x_ += MazeState.dx[action]
if self.points_[self.character_.y_][self.character_.x_] > 0:
self.task_score_ += self.points_[
self.character_.y_][self.character_.x_]
self.points_[self.character_.y_][self.character_.x_] = 0
self.turn_ += 1
return
legal_actions
を実装します。
罠あり数字集め迷路では場外に出ない行動のリストを取得します。
class MazeState(thun.BaseContextualState):
def legal_actions(self):
actions = []
for action in range(4):
ty = self.character_.y_+MazeState.dy[action]
tx = self.character_.x_+MazeState.dx[action]
if ty >= 0 and ty < MazeState.H and tx >= 0 and tx < MazeState.W:
actions.append(action)
return actions
is_done
を実装します。
罠あり数字集め迷路では4ターン経過したかどうかを取得します。
罠あり数字集め迷路ではトラップを踏んだタイミングでもゲームが終了しますが、ここでは判定しません。
これは後程is_dead
の実装で説明します。
class MazeState(thun.BaseContextualState):
def is_done(self):
return self.turn_ == MazeState.END_TURN
evaluate_score
を実装します。
罠あり数字集め迷路ではタスクのスコアをそのまま評価値として取得します。2
class MazeState(thun.BaseContextualState):
def evaluate_score(self) -> float:
return self.task_score_
"can" メソッドのオーバーライド
BaseContextualState
の"can" メソッドは以下の通りです。
これらのメソッドはベースクラス側にデフォルト処理が実装済みのため、自分で実装しなくても実行可能です。
問題の特性や用途に応じて実装すべきかどうか判断しましょう。
メソッド | 説明 | デフォルト処理 |
---|---|---|
is_dead(self) -> bool | タスクが異常終了したかどうかを取得する | 常にFalseを返す |
__str__(self) -> str | 現在の状態を文字列にする | self.__dict__の中身をstrにして連結して返す |
clone(self) | インスタンスをコピーする。 | self.__dict__の中身を全てdeepcopyする |
今回、is_done
の方で全ての終了条件を記述しなかったのでis_dead
も実装します。
is_done
は正常終了、
is_dead
は異常終了を判定します。
罠あり数字集め迷路の場合、ターン経過による終了は正常終了、罠を踏むことによる終了は異常終了とみなせます。
thunsearhの探索アルゴリズムは、探索中で異常終了したノードは切り捨てて探索を続行し、ベストスコアのノードが正常終了まで探索するような実装しています。
そもそも異常終了するような行動を合法手扱いすべきではないので、legal_actions
のほうで罠を踏まない行動のみを返すようにしてもいいです。その場合はis_dead
は実装しなくてもいいです。
class MazeState(thun.BaseContextualState):
def is_dead(self):
return self.trap_ == self.character_
__str__
は探索中では影響がないのですが、タスクの状態を確認するにはデフォルト実装ではやや見にくいです。
罠あり数字集め迷路では盤面を長方形の形で表現できるようにしておくと便利です。
class MazeState(thun.BaseContextualState):
def __str__(self):
ss = ""
ss += f"turn:\t{self.turn_}\n"
ss += f"task_score_:\t{self.task_score_}\n"
for h in range(MazeState.H):
for w in range(MazeState.W):
if (self.character_ == Coord(h, w)):
ss += "@"
elif (self.trap_ == Coord(h, w)):
ss += "X"
elif self.points_[h][w] > 0:
ss += str(self.points_[h][w])
else:
ss += "."
ss += '\n'
return ss
clone
はデフォルト実装で困らないので、今回はオーバーライドしません。
すべてのメンバ変数をdeepcopyするのが速度的に問題だと感じた場合にオーバーライドして独自実装しましょう。
探索アルゴリズムの適用
MazeStateの実装が終わったので、いよいよ探索アルゴリズムを適用できます。
現状実装済みなのはビームサーチだけなので、ビームサーチを実行しましょう。
beam_search_action
は第1引数にBaseContextualStateを継承したサブクラスのインスタンス、第2引数にビーム幅をとり、タスク終了までにとるべき行動リストを返します。
ビーム幅2のビームサーチを適用する場合は以下のようなコードになります。
state = MazeState(0)
actions = thun.beam_search_action(state, 2)
print(actions)
上記のコードを実行すると以下のような結果になります。
シード0で生成した罠あり数字集め迷路の盤面に対し、
左(1)、下(2)、左(1)、左(1)の順で行動するとよいということがわかります。
VectorInt[1, 2, 1, 1]
これで探索アルゴリズムの適用は終わりなんですが、行動リストだけ見せられてもピンときませんね。
行動リストをもとにタスクを進める過程を可視化する関数show_task
を用意しているので、使ってみましょう。
state = MazeState(0)
actions = thun.beam_search_action(state, 2)
thun.show_task(state, actions)
実行結果は以下のようになります。
左、下、左、左に移動してスコア25を達成している様子がよくわかりますね。
##############################
turn: 0
task_score_: 0
X.48
764@
7593
##############################
turn: 1
task_score_: 4
X.48
76@.
7593
##############################
turn: 2
task_score_: 13
X.48
76..
75@3
##############################
turn: 3
task_score_: 18
X.48
76..
7@.3
##############################
turn: 4
task_score_: 25
X.48
76..
@..3
##############################
しかし、いちいち探索アルゴリズムを適用してactionsをとりだしてshow_taskに引数としていれて…とやるのはちょっとめんどうですね(?)
そこで、探索からタスク進行表示を一括で行う関数play_task
も用意しました。
play_task
では第1引数に対象問題のインスタンス、第2引数に探索アルゴリズムの関数、第3引数以降に探索アルゴリズム用引数を与えます。
ビーム幅2の探索アルゴリズムを適用する場合は以下のコードになります。
state = MazeState(0)
thun.play_task(state, thun.beam_search_action, 2)
実行結果は先ほどと同じになるので省略します。
ここまでで実装したサンプルコードの全文は以下で確認できます。
また、以下のサンプルコードを実行すると、未実装のメソッドがある状態で実行した場合のエラーを確認できます。
contextual_not_implemented_sample.py
速度評価
さて、今回ライブラリを開発した目的は「Pythonでも速い探索アルゴリズムを適用したい」ということでした。
ということはPythonのみで実装した探索アルゴリズムよりは速くないと意味がありませんね。
ということで実験しました。
実験コードは以下にあります。
https://github.com/thun-c/thun_search/tree/master/test_speed
先述の罠あり数字集め迷路をシード10種類の盤面にたいして10回ずつビームサーチで解き、1シードあたりの平均経過時間を確認しました。
thunsearchを利用した場合
ビーム幅 | スコア | 経過時間(s) |
---|---|---|
2 | 22.9 | 7 |
4 | 23.3 | 14 |
8 | 23.5 | 21 |
16 | 23.5 | 30 |
32 | 23.5 | 40 |
Pythonのみで実装した場合
ビーム幅 | スコア | 経過時間(s) |
---|---|---|
2 | 22.9 | 14 |
4 | 23.3 | 24 |
8 | 23.5 | 40 |
16 | 23.5 | 62 |
32 | 23.5 | 86 |
thunsearchを利用した場合、Pythonのみで実装する場合と同じスコアを維持しながらほぼ実行時間を約半分にできました。
C++とPythonの速度差は倍どころではない気がしますが、C++とPythonのデータ受け渡しのオーバーヘッドなどもあるため、とりあえず倍程度でも速くなったのでよしとします。
最後に
今回、Pythonで動く高速な探索アルゴリズムライブラリthunsearch
をリリースし、その使い方を説明しました。
説明のために「罠あり数字集め迷路」を実装してthunsearchを適用しましたが、同じ形に落とし込んでmust
ラベルのついたメソッドを実装さえすればあらゆる最適化問題に探索アルゴリズムを適用できます。
読者の皆さんにも、「罠あり数字集め迷路」ではなくいろいろな問題にthunsearchを適用して気軽な気持ちで最適化を楽しんでもらえるとうれしいです。