強化学習といえばゲームですね。
ただゲームを1から作るのはかなり大変です。
そこでエミュレータBizHawkとGymを連携させるフレームワーク GymBizHawk を作成しました。
これによりBizHawkで実行できるゲームはすべて強化学習の環境に出来る可能性があります。
その仕組みについての記事となります。
GitHub
GymBizHawk概要
GymBizHawkですが、以下のような形でBizHawkとエージェントを橋渡しするようなフレームワークとなります。
Gymのカスタム環境
Gym側の仕様です。
公式ドキュメント
・https://gymnasium.farama.org/content/basic_usage/
・https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/
Gymの基本的な実行サイクルは以下です。
import gymnasium as gym
# 1. 環境の作成
env = gym.make("FrozenLake-v1", render_mode="ansi")
# 2. エピソードの初期化、
observation, info = env.reset()
print(env.render())
for _ in range(100):
action = env.action_space.sample()
# 3. アクションを元に1step進める
observation, reward, terminated, truncated, info = env.step(action)
print(env.render())
if terminated or truncated:
observation, info = env.reset()
# 4. close
env.close()
重要なのはreset,stepでこの2つの実装がメインとなります。
GymBizHawkの実装
BizHawk側の仕様ですが、comm関数が外部とのやりとり関数として実装されています。
プロトコルとしては、http,mmf(memory mapped file),socketの3種類があるようです。
(ただ使い方のドキュメントは探しても見つかりませんでした…)
Bizhawk/Lua Functions : http://tasvideos.org/Bizhawk/LuaFunctions.html
速度面からsocketを採用しました。
安定性はどうかなと思いましたが、基本ローカル通信なので大丈夫かなと思っています(一応遠隔で起動もできますが試していません)
起動までのフローは以下です。
- Socketサーバを作成
- BizHawkを起動(Lua経由で実行するゲームを指定)
- SocketサーバとBizHawkを接続
- Gym->BizHawk:初期情報を送信(実行モードなど)
- BizHawk->Gym:初期情報を送信(Space情報など)
クラス図は以下です。
黄色部分がフレームワークの実装、緑部分がユーザが実装する必要のあるコード、青が外部ツールです。
この後は黄色部分の説明をしていきます。
また、記事内のコードは要点のみ記載し、例外処理やオプション的な機能は省いています。
コード全体を見たい方はGitHubを見てください。
Socketクラス(Python)
Gym側のSocketサーバです。
BizHawkのSocketクライアントは内部で実装されているので不要です。
import socket
import selectors
class SocketServer:
def __init__(self, host: str, port: int):
self.host = host
self.port = port
self.conn = None
self.selector = None
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# アドレスにbindする
# ポートが開いてなかったら次のポートをトライする
for _ in range(100):
try:
self.sock.bind((self.host, self.port))
break
except OSError as e:
if e.errno == errno.EADDRINUSE:
self.port += 1
elif e.errno == errno.WSAEADDRINUSE:
self.port += 1
else:
self.close()
raise
self.sock.listen(1)
self.sock.setblocking(False)
# timeout処理用にselectorを導入
self.selector = selectors.DefaultSelector()
self.selector.register(self.sock, selectors.EVENT_READ)
def connect_wait(self) -> bool:
# クライアントから接続があるまで待機する
while True:
events = self.selector.select(timeout=self.timeout)
for key, _ in events:
if key.fileobj != self.sock:
continue
# クライアントと接続
self.conn, addr = self.sock.accept()
self.conn.setblocking(False)
self.selector.register(self.conn, selectors.EVENT_READ)
return True
# timeoutの場合
return False
def send(self, data: str) -> bool:
self.conn.send(data.encode(encoding="utf-8"))
return True
def recv(self, enable_decode: bool) -> str | bytes | None:
# クライアントから受信があるまで待機する
while True:
events = self.selector.select(timeout=self.timeout)
for key, _ in events:
if key.fileobj != self.conn:
continue
# データ受信、複数回に分けられる場合もあるのでそれも処理
data = b""
while True:
chunk = self.conn.recv(self.buffer_size)
if not chunk:
break
data += chunk
# bytesデータをそのまま利用したい場合はdecodeしない
if enable_decode:
data = data.decode("utf-8", "ignore")
return data
# timeout
return None
BizHawkクラス(Python)
Python側でBizHawkをラップするためのクラスです。
最初はSocketクラスとgym.Envクラスを混ぜて書いていましたが複雑になったので独立しています。
gym.Envには依存しておらず、これ単体で動きます。(gym.spacesを使っているのでgymには依存しています)
主な役割
・Socketインスタンスの管理
・BizHawkプロセスの管理
・通信管理
・スクリーンショット処理
class BizHawk:
def __init__(
self,
bizhawk_dir: str,
lua_file: str,
socket_ip: str = "127.0.0.1",
socket_port: int = 30000,
):
self.server = SocketServer(socket_ip, socket_port)
def boot(self):
# --- run bizhawk
# 引数でsocketサーバ情報とluaファイルを指定
cmd = os.path.join(self.bizhawk_dir, "EmuHawk.exe")
cmd += " --luaconsole"
cmd += " --socket_ip={}".format(self.server.host)
cmd += " --socket_port={}".format(self.server.port)
cmd += " --lua={}".format(self.lua_file)
self.emu = subprocess.Popen(cmd)
# --- connect
self.server.connect_wait()
# --- 1st send
# 識別用のprefixとして先頭にa、区切り文字を|としています
s = "a|{}".format("dummy")
self.send(s)
# --- 1st recv
# 初期情報+画像情報が送られてくる
d = self.recv(enable_split=True)
img = self._recv_image()
self.image_shape = img.shape
# 初期情報を元にspaceを作成(コードは省略)
self.action_space = 初期情報を元にgym.spacesを作成
self.observation_space = 初期情報を元にgym.spacesを作成
def send(self, data: str) -> None:
# BizHawkへの送信データは $"{msg.Length:D} {msg}" のフォーマット
data = f"{len(data)} {data}"
self.server.send(data)
def recv(self, enable_split: bool = False):
data = self.server.recv(enable_decode=True)
data = str(data).strip()
if enable_split:
return [v.strip() for v in data.split("|")]
return data
def _recv_image(self) -> np.ndarray:
img_raw = self.server.recv(enable_decode=False)
# rawデータをcv2形式で扱えるように変換(確かpngフォーマットで来たはず)
img_arr = np.frombuffer(img_raw, dtype=np.uint8)
img = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
def _decode_observation(self, obs_str: str) -> np.ndarray:
# 状態通信用フォーマット
# "s1 s2 s3 s4" スペース区切り
return np.array([float(o) for o in obs_str.split(" ")])
def _decode_invalid_actions(self, inv_act_str: str):
return inv_act_strを元に無効なアクションリストをデコードする(コードは省略)
# -----------------------------
# gym用の簡易ラッパー
# -----------------------------
def reset(self):
# "reset" を送信し、"invalid_actions|observation"とimageを受信
self.send("reset")
recv_str_list = self.recv(enable_split=True)
self.invalid_actions = self._decode_invalid_actions(recv_str_list[0])
state = self._decode_observation(recv_str_list[1])
self.step_img = self._recv_image()
return state
def step(self, action: list):
# actionのフォーマットは "step act1 act2 act3 ..."
act_str = " ".join([str(a) for a in action])
self.send(f"step {act_str}")
# 受信は、invalid_actions "|" reward "|" terminated "|" truncated "|" observation
recv_str_list = self.recv(enable_split=True)
self.invalid_actions = self._decode_invalid_actions(recv_str_list[0])
reward = float(recv_str_list[1])
terminated = True if recv_str_list[2] == "1" else False
truncated = True if recv_str_list[3] == "1" else False
state = self._decode_observation(recv_str_list[1])
self.step_img = self._recv_image()
return state, reward, terminated, truncated
GymEnvクラス(Lua)
BizHawk側で動作するluaプログラムです。
主な役割
・socketクライアントの管理
・通信管理
・BizHawkの実行の管理
・Processorの管理(Lua側のgym.Envに相当し、具体的な動作を定義するクラス)
基本は無限ループで待ち、コマンドがpythonから送られてきたら実行する形です。
local function startswith(str, start)
return string.sub(str, 1, string.len(start)) == start
end
local GymEnv = {}
GymEnv.new = function()
local this = {}
this.processor = nil
this.run = function(self, processor)
self.processor = processor
---- open rom
client.openrom(self.processor.ROM)
client.pause()
---- 1st recv
local recv = self:recv_wait()
local dummy = string.match(recv, "a|(.+)")
------ setup processor
self.processor:setup(self)
self.ACTION = self.processor.ACTION
self.OBSERVATION = self.processor.OBSERVATION
---- emu setting
client.speedmode(100)
client.unpause()
---- 1st send
local s = 初期情報を作成
self:send(s)
self:sendImage()
---- main loop
while true do
if self:_waitAction() == false then
self:log_info("loop end.")
break
end
end
self:close()
client.speedmode(100)
client.pause()
end
this.recv_wait = function(self)
return comm.socketServerResponse()
end
this.send = function(self, d)
comm.socketServerSend(d)
end
this.sendImage = function(self)
comm.socketServerScreenShot()
end
this._encodeObservation = function(self)
local d = self.processor:getObservation()
return dを送信用文字列に変換
end
this._encodeInvalidActions = function(self)
local d = self.processor:getInvalidActions()
return dを送信用文字列に変換
end
this._waitAction = function(self)
---- commandの受信を待つ
local data = self:recv_wait()
---- reset
if data == "reset" then
self.processor:reset()
-- invalid actionと状態と画像を送信する
local s = self:_encodeInvalidActions() .. "|" .. self.:_encodeObservation()
self:send(s)
self:sendImage()
return true
end
---- step
if startswith(data, "step") then
-- 1. recv action
local acts = dataをアクションの形式にデコード
-- 2. step
local reward, terminated, truncated = self.processor:step(acts)
-- 3. send invalid_actions "|" reward "|" terminated "|" truncated "|" observation
local s = self:_encodeInvalidActions()
s = s .. "|" .. reward
s = s .. "|" .. (terminated and "1" or "0")
s = s .. "|" .. (truncated and "1" or "0")
s = s .. "|" .. self:_encodeObservation()
self:_sendExtendObservtion(s)
self:send(s)
self:sendImage()
return true
end
その他いろいろなコマンド
return true
end
return this
end
GymBizHawkクラス(Python)
Python側のBizHawkをラップし、gym.Env形式を提供するクラスです。
ただ処理として追加されてる項目はrenderぐらいでしょうか。
import gymnasium
import pygame
class BizHawkEnv(gym.Env):
metadata = {"render_modes": ["rgb_array"]}
def __init__(self, render_mode: str | None = None, **kwargs):
self.render_mode = render_mode
self.bizhawk = BizHawk(**kwargs)
self.bizhawk.boot()
self.action_space = self.bizhawk.action_space
self.observation_space = self.bizhawk.observation_space
self.screen = None
def reset(self):
state = self.bizhawk.reset()
return state, {}
def step(self, action: list):
state, reward, terminated, truncated = self.bizhawk.step(action)
return state, reward, terminated, truncated, {}
def render(self):
if self.screen is None:
pygame.init()
self.screen = pygame.Surface((self.bizhawk.image_shape[1], self.bizhawk.image_shape[0]))
img = self.bizhawk.step_img
img = img.swapaxes(0, 1)
img = pygame.surfarray.make_surface(img)
self.screen.blit(img, (0, 0))
return np.transpose(np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2))
ユーザ側の実装コード
基本はLua側のUserProcessorの実装となり、具体的にゲーム毎に処理する内容を実装します。
更にオプションでPython側のBizHawkEnvを継承することでgym.Envの処理を上書きすることも可能です。
具体的な書き方はGitHubに書いたドキュメントやサンプルコードを参考にしてください。
オリジナル環境の作り方:https://pocokhc.github.io/GymBizHawk/pages/custom.html
学習例
Gym環境に対応しているフレームワークならどれでも対応しています。
(もちろんフレームワークを使わずにそのままの利用もできます)
例として私が別途自作している強化学習フレームワークを使って、スーパーマリオブラザーズ1-1とmoonのジンギスカンを学習させてみました。
実際のコードはGitHubのexamplesに置いてあります。
スーパーマリオブラザーズ1-1
愚直に実装すると学習できないので(スタートボタン連打など…)学習しやすいように手を加えています。
具体的には以下です。
- Lua側
- アクションは常に右とBを押しっぱなしでエージェントはAを押すかどうかのみの2通り
- 状態を座標ではなくタイルという状態でまとめる(横0~9、縦0~5)
- 障害物と敵の位置(タイル)はマリオからの相対距離
- 一定時間進まなかったら終了(報酬-1)
- 報酬は、ゴール+100、死亡-1、進まない-0.01、進む+0.1
- 強化学習側
- 1stepを5フレームとし、その間は同じアクションを実行する(frameskip)
- 直近の8stepの状態を1つの状態としてエージェントの入力とする(window_length)
画像はdebugモードで実行した時のものです。
左上がエージェントが見る情報を可視化したものとなり、タイル単位で紫の範囲が見えています。
学習過程
アルゴリズムはRainbowで学習しています。
学習中の報酬の推移です。
学習結果
ちゃんとゴールまで行けました。
右下の0と1はQ値となります。
moonのジンギスカン(PS)
初期値は乱数で変動し(100通り)、LV5までのクリアとなります。
- Lua側
- アクションは変更なし(元々〇ボタンを押すか離すの2択)
- 状態は、y速度、y座標、ゴールのy座標、ゴールとの相対距離(y)
- 報酬は、LVクリア+100、ゲームオーバー-10、y座標がゴールに対して-1~3の間なら+0.1、それ以外は-0.01
- 強化学習側
- 直近の4stepの状態を1つの状態としてエージェントの入力とする(window_length)
学習過程
同じくアルゴリズムはRainbowで学習しています。
学習中の報酬の推移です。
学習結果
長いのでLV1のみです。
かなり安定していますね。
終わりに
BizHawkとGymの連携ですが、ありそうでなかったものかと思います。
できればluaまたはpythonをユーザが書かずに実現したかったのですが、報酬と終了条件を表現するのにどうしても必要になり…、Luaを書く形で実装しました。
Luaを書く必要はありますが、これでBizHawkで動かせるゲームはすべて強化学習できるようになったかと思います。