3
2

エミュレータと連携できるGym環境を作った話(Gym×BizHawk×強化学習)

Posted at

強化学習といえばゲームですね。
ただゲームを1から作るのはかなり大変です。
そこでエミュレータBizHawkとGymを連携させるフレームワーク GymBizHawk を作成しました。
これによりBizHawkで実行できるゲームはすべて強化学習の環境に出来る可能性があります。
その仕組みについての記事となります。

GitHub

GymBizHawk概要

GymBizHawkですが、以下のような形でBizHawkとエージェントを橋渡しするようなフレームワークとなります。

aa-ページ2.drawio.png

Gymのカスタム環境

Gym側の仕様です。

公式ドキュメント
https://gymnasium.farama.org/content/basic_usage/
https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/

Gymの基本的な実行サイクルは以下です。

import gymnasium as gym

# 1. 環境の作成
env = gym.make("FrozenLake-v1", render_mode="ansi")

# 2. エピソードの初期化、
observation, info = env.reset()
print(env.render())

for _ in range(100):
    action = env.action_space.sample()

    # 3. アクションを元に1step進める
    observation, reward, terminated, truncated, info = env.step(action)
    print(env.render())

    if terminated or truncated:
        observation, info = env.reset()

# 4. close
env.close()

重要なのはreset,stepでこの2つの実装がメインとなります。

GymBizHawkの実装

BizHawk側の仕様ですが、comm関数が外部とのやりとり関数として実装されています。
プロトコルとしては、http,mmf(memory mapped file),socketの3種類があるようです。
(ただ使い方のドキュメントは探しても見つかりませんでした…)

Bizhawk/Lua Functions : http://tasvideos.org/Bizhawk/LuaFunctions.html

速度面からsocketを採用しました。
安定性はどうかなと思いましたが、基本ローカル通信なので大丈夫かなと思っています(一応遠隔で起動もできますが試していません)
起動までのフローは以下です。

  1. Socketサーバを作成
  2. BizHawkを起動(Lua経由で実行するゲームを指定)
  3. SocketサーバとBizHawkを接続
  4. Gym->BizHawk:初期情報を送信(実行モードなど)
  5. BizHawk->Gym:初期情報を送信(Space情報など)

クラス図は以下です。

overview.drawio.png

黄色部分がフレームワークの実装、緑部分がユーザが実装する必要のあるコード、青が外部ツールです。
この後は黄色部分の説明をしていきます。

また、記事内のコードは要点のみ記載し、例外処理やオプション的な機能は省いています。
コード全体を見たい方はGitHubを見てください。

Socketクラス(Python)

Gym側のSocketサーバです。
BizHawkのSocketクライアントは内部で実装されているので不要です。

import socket
import selectors

class SocketServer:
    def __init__(self, host: str, port: int):
        self.host = host
        self.port = port
        self.conn = None
        self.selector = None

        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

        # アドレスにbindする
        # ポートが開いてなかったら次のポートをトライする
        for _ in range(100):
            try:
                self.sock.bind((self.host, self.port))
                break
            except OSError as e:
                if e.errno == errno.EADDRINUSE:
                    self.port += 1
                elif e.errno == errno.WSAEADDRINUSE:
                    self.port += 1
                else:
                    self.close()
                    raise
        
        self.sock.listen(1)
        self.sock.setblocking(False)

        # timeout処理用にselectorを導入
        self.selector = selectors.DefaultSelector()
        self.selector.register(self.sock, selectors.EVENT_READ)

    def connect_wait(self) -> bool:
        # クライアントから接続があるまで待機する
        while True:
            events = self.selector.select(timeout=self.timeout)
            for key, _ in events:
                if key.fileobj != self.sock:
                    continue
                
                # クライアントと接続
                self.conn, addr = self.sock.accept()
                self.conn.setblocking(False)
                self.selector.register(self.conn, selectors.EVENT_READ)
                return True

            # timeoutの場合
            return False

    def send(self, data: str) -> bool:
        self.conn.send(data.encode(encoding="utf-8"))
        return True

    def recv(self, enable_decode: bool) -> str | bytes | None:
        # クライアントから受信があるまで待機する
        while True:
            events = self.selector.select(timeout=self.timeout)
            for key, _ in events:
                if key.fileobj != self.conn:
                    continue
                
                # データ受信、複数回に分けられる場合もあるのでそれも処理
                data = b""
                while True:
                    chunk = self.conn.recv(self.buffer_size)
                    if not chunk:
                        break
                    data += chunk

                # bytesデータをそのまま利用したい場合はdecodeしない
                if enable_decode:
                    data = data.decode("utf-8", "ignore")
                return data

            # timeout
            return None

BizHawkクラス(Python)

Python側でBizHawkをラップするためのクラスです。
最初はSocketクラスとgym.Envクラスを混ぜて書いていましたが複雑になったので独立しています。
gym.Envには依存しておらず、これ単体で動きます。(gym.spacesを使っているのでgymには依存しています)

主な役割
・Socketインスタンスの管理
・BizHawkプロセスの管理
・通信管理
・スクリーンショット処理

class BizHawk:
    def __init__(
        self,
        bizhawk_dir: str,
        lua_file: str,
        socket_ip: str = "127.0.0.1",
        socket_port: int = 30000,
    ):
        self.server = SocketServer(socket_ip, socket_port)

    def boot(self):
        # --- run bizhawk
        # 引数でsocketサーバ情報とluaファイルを指定
        cmd = os.path.join(self.bizhawk_dir, "EmuHawk.exe")
        cmd += " --luaconsole"
        cmd += " --socket_ip={}".format(self.server.host)
        cmd += " --socket_port={}".format(self.server.port)
        cmd += " --lua={}".format(self.lua_file)
        self.emu = subprocess.Popen(cmd)

        # --- connect
        self.server.connect_wait()

        # --- 1st send
        # 識別用のprefixとして先頭にa、区切り文字を|としています
        s = "a|{}".format("dummy")
        self.send(s)

        # --- 1st recv
        # 初期情報+画像情報が送られてくる
        d = self.recv(enable_split=True)
        img = self._recv_image()
        self.image_shape = img.shape

        # 初期情報を元にspaceを作成(コードは省略)
        self.action_space = 初期情報を元にgym.spacesを作成
        self.observation_space = 初期情報を元にgym.spacesを作成


    def send(self, data: str) -> None:
        # BizHawkへの送信データは $"{msg.Length:D} {msg}" のフォーマット
        data = f"{len(data)} {data}"
        self.server.send(data)

    def recv(self, enable_split: bool = False):
        data = self.server.recv(enable_decode=True)
        data = str(data).strip()
        if enable_split:
            return [v.strip() for v in data.split("|")]
        return data

    def _recv_image(self) -> np.ndarray:
        img_raw = self.server.recv(enable_decode=False)
        # rawデータをcv2形式で扱えるように変換(確かpngフォーマットで来たはず)
        img_arr = np.frombuffer(img_raw, dtype=np.uint8)
        img = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img

    def _decode_observation(self, obs_str: str) -> np.ndarray:
        # 状態通信用フォーマット
        # "s1 s2 s3 s4" スペース区切り
        return np.array([float(o) for o in obs_str.split(" ")])

    def _decode_invalid_actions(self, inv_act_str: str):
        return inv_act_strを元に無効なアクションリストをデコードするコードは省略

    # -----------------------------
    # gym用の簡易ラッパー
    # -----------------------------
    def reset(self):
        # "reset" を送信し、"invalid_actions|observation"とimageを受信
        self.send("reset")
        recv_str_list = self.recv(enable_split=True)
        self.invalid_actions = self._decode_invalid_actions(recv_str_list[0])
        state = self._decode_observation(recv_str_list[1])
        self.step_img = self._recv_image()
        return state

    def step(self, action: list):
        # actionのフォーマットは "step act1 act2 act3 ..."
        act_str = " ".join([str(a) for a in action])
        self.send(f"step {act_str}")

        # 受信は、invalid_actions "|" reward "|" terminated "|" truncated "|" observation
        recv_str_list = self.recv(enable_split=True)
        self.invalid_actions = self._decode_invalid_actions(recv_str_list[0])
        reward = float(recv_str_list[1])
        terminated = True if recv_str_list[2] == "1" else False
        truncated = True if recv_str_list[3] == "1" else False
        state = self._decode_observation(recv_str_list[1])
        self.step_img = self._recv_image()
        return state, reward, terminated, truncated

GymEnvクラス(Lua)

BizHawk側で動作するluaプログラムです。

主な役割
・socketクライアントの管理
・通信管理
・BizHawkの実行の管理
・Processorの管理(Lua側のgym.Envに相当し、具体的な動作を定義するクラス)

基本は無限ループで待ち、コマンドがpythonから送られてきたら実行する形です。

local function startswith(str, start)
    return string.sub(str, 1, string.len(start)) == start
end

local GymEnv = {}
GymEnv.new = function()
    local this = {}
    this.processor = nil

    this.run = function(self, processor)
        self.processor = processor

        ---- open rom
        client.openrom(self.processor.ROM)
        client.pause()

        ---- 1st recv
        local recv = self:recv_wait()
        local dummy = string.match(recv, "a|(.+)")

        ------ setup processor
        self.processor:setup(self)
        self.ACTION = self.processor.ACTION
        self.OBSERVATION = self.processor.OBSERVATION

        ---- emu setting
        client.speedmode(100)
        client.unpause()

        ---- 1st send
        local s = 初期情報を作成
        self:send(s)
        self:sendImage()

        ---- main loop
        while true do
            if self:_waitAction() == false then
                self:log_info("loop end.")
                break
            end
        end

        self:close()
        client.speedmode(100)
        client.pause()
    end

    this.recv_wait = function(self)
        return comm.socketServerResponse()
    end

    this.send = function(self, d)
        comm.socketServerSend(d)
    end

    this.sendImage = function(self)
        comm.socketServerScreenShot()
    end

    this._encodeObservation = function(self)
        local d = self.processor:getObservation()
        return dを送信用文字列に変換
    end

    this._encodeInvalidActions = function(self)
        local d = self.processor:getInvalidActions()
        return dを送信用文字列に変換
    end

    this._waitAction = function(self)
        ---- commandの受信を待つ
        local data = self:recv_wait()

        ---- reset
        if data == "reset" then
            self.processor:reset()

            -- invalid actionと状態と画像を送信する
            local s = self:_encodeInvalidActions() .. "|" .. self.:_encodeObservation()
            self:send(s)
            self:sendImage()
            return true
        end

        ---- step
        if startswith(data, "step") then
            -- 1. recv action
            local acts = dataをアクションの形式にデコード

            -- 2. step
            local reward, terminated, truncated = self.processor:step(acts)

            -- 3. send invalid_actions "|" reward "|" terminated "|" truncated "|" observation
            local s = self:_encodeInvalidActions()
            s = s .. "|" .. reward
            s = s .. "|" .. (terminated and "1" or "0")
            s = s .. "|" .. (truncated and "1" or "0")
            s = s .. "|" .. self:_encodeObservation()
            self:_sendExtendObservtion(s)
            self:send(s)
            self:sendImage()
            return true
        end

        その他いろいろなコマンド
        return true
    end

    return this
end

GymBizHawkクラス(Python)

Python側のBizHawkをラップし、gym.Env形式を提供するクラスです。
ただ処理として追加されてる項目はrenderぐらいでしょうか。

import gymnasium
import pygame
class BizHawkEnv(gym.Env):
    metadata = {"render_modes": ["rgb_array"]}

    def __init__(self, render_mode: str | None = None, **kwargs):
        self.render_mode = render_mode
        self.bizhawk = BizHawk(**kwargs)
        self.bizhawk.boot()
        self.action_space = self.bizhawk.action_space
        self.observation_space = self.bizhawk.observation_space

        self.screen = None

    def reset(self):
        state = self.bizhawk.reset()
        return state, {}

    def step(self, action: list):
        state, reward, terminated, truncated = self.bizhawk.step(action)
        return state, reward, terminated, truncated, {}
        
    def render(self):
        if self.screen is None:
            pygame.init()
            self.screen = pygame.Surface((self.bizhawk.image_shape[1], self.bizhawk.image_shape[0]))
        
        img = self.bizhawk.step_img
        img = img.swapaxes(0, 1)
        img = pygame.surfarray.make_surface(img)
        self.screen.blit(img, (0, 0))
        return np.transpose(np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2))

ユーザ側の実装コード

基本はLua側のUserProcessorの実装となり、具体的にゲーム毎に処理する内容を実装します。
更にオプションでPython側のBizHawkEnvを継承することでgym.Envの処理を上書きすることも可能です。

具体的な書き方はGitHubに書いたドキュメントやサンプルコードを参考にしてください。
オリジナル環境の作り方:https://pocokhc.github.io/GymBizHawk/pages/custom.html

学習例

Gym環境に対応しているフレームワークならどれでも対応しています。
(もちろんフレームワークを使わずにそのままの利用もできます)

例として私が別途自作している強化学習フレームワークを使って、スーパーマリオブラザーズ1-1とmoonのジンギスカンを学習させてみました。
実際のコードはGitHubのexamplesに置いてあります。

スーパーマリオブラザーズ1-1

愚直に実装すると学習できないので(スタートボタン連打など…)学習しやすいように手を加えています。
具体的には以下です。

  • Lua側
    • アクションは常に右とBを押しっぱなしでエージェントはAを押すかどうかのみの2通り
    • 状態を座標ではなくタイルという状態でまとめる(横0~9、縦0~5)
    • 障害物と敵の位置(タイル)はマリオからの相対距離
    • 一定時間進まなかったら終了(報酬-1)
    • 報酬は、ゴール+100、死亡-1、進まない-0.01、進む+0.1
  • 強化学習側
    • 1stepを5フレームとし、その間は同じアクションを実行する(frameskip)
    • 直近の8stepの状態を1つの状態としてエージェントの入力とする(window_length)

ss1.png

画像はdebugモードで実行した時のものです。
左上がエージェントが見る情報を可視化したものとなり、タイル単位で紫の範囲が見えています。

学習過程

アルゴリズムはRainbowで学習しています。
学習中の報酬の推移です。

smb.png
(横が学習回数、縦が報酬)

学習結果

ちゃんとゴールまで行けました。
右下の0と1はQ値となります。

smb_zip.gif

moonのジンギスカン(PS)

初期値は乱数で変動し(100通り)、LV5までのクリアとなります。

  • Lua側
    • アクションは変更なし(元々〇ボタンを押すか離すの2択)
    • 状態は、y速度、y座標、ゴールのy座標、ゴールとの相対距離(y)
    • 報酬は、LVクリア+100、ゲームオーバー-10、y座標がゴールに対して-1~3の間なら+0.1、それ以外は-0.01
  • 強化学習側
    • 直近の4stepの状態を1つの状態としてエージェントの入力とする(window_length)

学習過程

同じくアルゴリズムはRainbowで学習しています。
学習中の報酬の推移です。

moon.png
(横が学習回数、縦が報酬)

学習結果

長いのでLV1のみです。
かなり安定していますね。

moon_zip.gif

終わりに

BizHawkとGymの連携ですが、ありそうでなかったものかと思います。
できればluaまたはpythonをユーザが書かずに実現したかったのですが、報酬と終了条件を表現するのにどうしても必要になり…、Luaを書く形で実装しました。
Luaを書く必要はありますが、これでBizHawkで動かせるゲームはすべて強化学習できるようになったかと思います。

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2