LoginSignup
30
22

More than 1 year has passed since last update.

PFRLでスーパーマリオ1-1をクリアするまで

Last updated at Posted at 2021-10-26

以前はStableBaselinesを使っていましたが、ニューラルネットワークの構造をいじりにくいことやアルゴリズムに手を加えにくいと思っていました。ネットワークにAttentionを加えるだけでも一苦労でした。

そもそもTensorFlow向けであった点も使いづらかったので、PyTorch向けの深層強化学習ライブラリを探していたところPFRLというライブラリを見つけました。

Hello World代わりにスーパーマリオブラザーズ1-1をクリアしてみようと思います。

まずネットワークを定義します。

image.png

class Network(nn.Module):
    def __init__(self, hidden_dim, space_shape, n_actions):
        super(Network, self).__init__()
        h, w, c = space_shape
        self.hidden_dim = hidden_dim
        self.feature_extract = nn.Sequential(
            nn.Conv2d(c, hidden_dim, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1),
            nn.ReLU()
            )
        flatten_dim = hidden_dim * (h // 8) * (w // 8)
        self.flatten_dim = flatten_dim
        self.head = pfrl.nn.Branched(
            nn.Sequential(
                nn.Linear(flatten_dim, n_actions),
                SoftmaxCategoricalHead()
            ),
            nn.Linear(flatten_dim, 1)
        )

    def forward(self, image):
        b, _, _, _ = image.shape
        feature = self.feature_extract(image)
        feature = feature.view(b, -1)
        actions, values = self.head(feature)

        return actions, values

出力用のNNとOptimizerをagents.PPOに渡します。
agentsには様々な強化学習のアルゴリズムが用意されています。
ここではPPO(Proximal Policy Optimization)を使います。

    opt = torch.optim.Adam(network.parameters(), lr=3e-4, eps=1e-8)

    agent = agents.PPO(
        network,
        opt,
        gamma=0.9,
        gpu=0,
        update_interval=update_interval,
        minibatch_size=32,
        epochs=10,
        clip_eps=0.2,
        clip_eps_vf=None,
        standardize_advantages=True,
        entropy_coef=0.01,
        max_grad_norm=0.5,
    )

experimentsにはagentとenvを渡します。
envについてはStable Baselinesのときと同じなので省略します

    experiments.train_agent_batch_with_evaluation(
        agent=agent,
        env=env,
        eval_env=env_eval,
        steps=TOTAL_STEPS,
        eval_n_steps=None,
        eval_n_episodes=10,
        eval_interval=100000,
        outdir=output_dir,
        save_best_so_far_agent=True,
    )

10000ステップごとに16エピソードをプレイした平均報酬和を以下に示します。

image.png

右に進むことができた距離の最大値を更新できた場合に報酬を与えており、報酬和の最大値は316となるようにしています。
2000000ステップでかなり学習できています。

プレイ動画はこんな感じ。

プレイ動画を出力するためのコードを付記しておきます。

from matplotlib import animation

def display_frames_as_movie(frames, fname):
    plt.figure(figsize=(frames[0].shape[1]/72.0, frames[0].shape[0]/72.0), dpi=72)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=66)
    anim.save(fname, writer='ffmpeg')

env = make_env(ENV_NAME)
actions = env.action_space
n_actions = env.action_space.n
space_shape = env.observation_space.shape
network = Network(HIDDEN_DIM, space_shape, n_actions)
network.load_state_dict(torch.load(f"{model_dir}/5000000_finish/model.pt", map_location=torch.device('cpu'))) # 学習済みのweightを読み込む

frames = [] # 各時点でのゲーム画面を格納しておくためのリスト

state = env.reset()
while True:

    image = env.render(mode="rgb_array").copy() # ゲーム画面を取得

    state = torch.FloatTensor(state[np.newaxis, :, :, :])
    action_prob, value = network.forward(state)
    action = action_prob.sample().item()

    frames.append(image)

    state, rewards, done, info = env.step(action)

    if done:
        print(info["x_pos"])
        frames.append(env.render(mode='rgb_array').copy())
        break

display_frames_as_movie(max_frames, output_name)
env.close()

30
22
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
30
22