11
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Flappy BirdをAIに自動プレイさせる話

Last updated at Posted at 2020-08-10

こんばんは。りーぜんとです。

今回はFizzBuzzのやつに引き続き、強化学習第二弾ということで、Flappy BirdをプレイするAIを作ってみます。
作成したプログラムは全てGitHubにあるので参考にしてください。

目次

  1. Flappy Birdとは
  2. ゲームを実装
  3. DQNで挑戦
  4. NEATで挑戦
  5. 結果

Flappy Birdとは

Flappy Birdというゲームを聞いたことがある人は多いんじゃないでしょうか。無料でプレイできるのでぜひ遊んでみてください。

とりあえず僕もやってみます。

スクリーンショット 2020-08-10 19.46.44.png

いやむっず。
鳥が土管を超えるとポイントが入るのですが、結構頑張ったけど5点までいけません。

今からAIを作って5点を突破できれば僕より強いってことですね。ひとまず10点を目標にしましょう。

ゲームを実装

まずは学習に使うゲーム環境を作っていきます。
今回はpygameというライブラリを使って実際に学習してる様子が見えるように作ります。
pygameについて詳しい解説はしないので、是非公式リファレンス等を読んでみてください。

ライブラリのインポート

最初にライブラリをインポートしたり定数を定義します。

flappy_bird.py
import pygame
import random
import sys
import math
from numpy import array

WIN_WIDTH  = 600
WIN_HEIGHT = 800

COLORS = {
    'sky':    (135, 206, 250),
    'bird':   (255, 255, 0),
    'pipe':   (50,  205, 50),
    'ground': (160, 82,  45)
}

BIRD_SIZE = 50

PIPE_VEL = 4
PIPE_GAP = 200
PIPE_WIDTH  = 100
PIPE_MARGIN = 100

GROUND_HEIGHT = 100

鳥クラスの作成

次に鳥のクラスから作っていきます。今回もFizzBuzzの記事と同様に作っていくので、詳しくはそちらをみてください。

bird.py
class Bird:
    def __init__(self, x=200, y=350):
        self.y = y
        self.vel = 0
        self.rect = pygame.Rect(x, self.y, BIRD_SIZE, BIRD_SIZE)

    def jump(self):
        self.vel = -6

    def move(self):
        self.vel += 0.4

        self.y += self.vel

        self.rect.top = self.y

    def get_state(self):
        return [self.y, self.vel]

    def draw(self, win):
        pygame.draw.rect(win, COLORS['bird'], self.rect)

y座標と、加速度、描画用のpygame.Rectオブジェクトを持たせてあります。今回はお絵かきがめんどくさかったので、全ての物体は長方形で構成されます。お許しを。
また、get_state()で現在の鳥の状態(今の位置と加速度)を取得可能にしています。これを強化学習の際に使用します。

土管クラスの作成

次は土管です。

pipe.py
class Pipe:
    def __init__(self, x=700):
        self.top    = random.randrange(PIPE_MARGIN, WIN_HEIGHT - PIPE_GAP - PIPE_MARGIN - GROUND_HEIGHT)
        self.bottom = self.top + PIPE_GAP

        self.top_rect    = pygame.Rect(x, 0, PIPE_WIDTH, self.top)
        self.bottom_rect = pygame.Rect(x, self.bottom, PIPE_WIDTH, WIN_HEIGHT - self.bottom)

    def move(self):
        self.top_rect    = self.top_rect.move(-PIPE_VEL, 0)
        self.bottom_rect = self.bottom_rect.move(-PIPE_VEL, 0)

        if self.top_rect.right < 0:
            self.__init__()

        return self.top_rect.left == 200

    def draw(self, win):
        pygame.draw.rect(win, COLORS['pipe'], self.top_rect)
        pygame.draw.rect(win, COLORS['pipe'], self.bottom_rect)

常に二本の土管をいい感じの間隔でゲーム内に置いておき、左端まで進んだ土管を右端にテレポートさせることでずっと土管が来るようにしました。

地面クラスの作成

次の地面ですが、これに関して語ることはありません。

ground.py
class Ground:
    def __init__(self):
        self.rect = pygame.Rect(0, WIN_HEIGHT - GROUND_HEIGHT, WIN_WIDTH, GROUND_HEIGHT)

    def draw(self, win):
        pygame.draw.rect(win, COLORS['ground'], self.rect)

ゲームクラスの作成

最後にゲーム本体のクラスです。

flappy_bird.py
class FlappyBird:
    def __init__(self, n_bird=1):
        pygame.init()
        self.win = pygame.display.set_mode((WIN_WIDTH, WIN_HEIGHT))
        pygame.display.set_caption('Flappy Bird')

        self.n_bird = n_bird
        self.birds  = [Bird() for _ in range(self.n_bird)]
        self.pipes  = [Pipe(800), Pipe(1200)]
        self.ground = Ground()

        self.score = 0
        
    def reset(self):
        self.__init__(self.n_bird)

    def draw(self):
        self.win.fill(COLORS['sky'])

        for bird in self.birds:
            bird.draw(self.win)

        for pipe in self.pipes:
            pipe.draw(self.win)

        self.ground.draw(self.win)

        pygame.display.update()

    def check_collide(self, bird):
        if bird.y <= -BIRD_SIZE:
            return True

        for pipe in self.pipes:
            if pipe.top_rect.colliderect(bird.rect) or pipe.bottom_rect.colliderect(bird.rect):
                return True

        if self.ground.rect.colliderect(bird.rect):
            return True

        return False

    def step(self, actions):
        passed = False

        for pipe in self.pipes:
            if pipe.move():
                self.score += 1
                passed = True

        next_birds = []
        states     = []
        rewards    = []

        for action, bird in zip(actions, self.birds):
            if action.argmax():
                bird.jump()

            last_y = bird.y

            bird.move()

            pipe_idx = 0
            while bird.rect.x > self.pipes[pipe_idx].top_rect.left:
                pipe_idx += 1

            rewards.append(
                1 if abs(
                    self.pipes[pipe_idx].top_rect.bottom + PIPE_GAP / 2 - last_y
                ) < abs(
                    self.pipes[pipe_idx].top_rect.bottom + PIPE_GAP / 2 - bird.y
                ) else 0
            )

            pipe_state = [
                self.pipes[pipe_idx].top_rect.bottom - bird.y,
                self.pipes[pipe_idx].top_rect.left - 200
            ]

            states.append(bird.get_state() + pipe_state)

            finished = self.check_collide(bird)
            if not finished:
                next_birds.append(bird)

        self.birds = next_birds

        return array(states), array(rewards), finished

    def random_step(self):
        for pipe in self.pipes:
            pipe.move()

        state = []
        for bird in self.birds:
            pipe_state = [
                self.pipes[0].top_rect.bottom - bird.y,
                self.pipes[0].top_rect.left - 200
            ]

            state.append(bird.get_state() + pipe_state)

        return array(state), array([0 for _ in range(self.n_bird)]), False

環境を初期化するreset()、鳥が死んだかどうかを判定するcheck_collide()、次フレームに遷移するstep()、random_step()を実装しました。
step()は今の状況、それに対する報酬、ゲームが終了したかどうかを返します。

DQNで挑戦

さて、前の記事同様、Deep Q Learningという手法を使って学習をしてみます。細かい解説はしないので、いろいろ調べてみてください。
エージェント、メモリ、モデルは前の記事で実装したものをそのまま使います。

train.py
import pygame

from flappy_bird import FlappyBird
from model import Model
from memory import Memory
from agent import Agent

def evaluate(env, agent):
    env.reset()
    state, _, finished = env.random_step()
    while not finished:
        action = agent.get_action(state, N_EPOCHS, main_model)
        next_state, _, finished = env.step(action.argmax(), verbose=True)
        state = next_state

def main():
    clock = pygame.time.Clock()

    N_EPOCHS = 1000
    GAMMA    = 0.99
    N_BIRD   = 64
    S_BATCH  = 256

    env = FlappyBird(N_BIRD)

    main_model   = Model()
    target_model = Model()

    memory = Memory()
    agent  = Agent()

    for epoch in range(1, N_EPOCHS + 1):
        print('Epoch: {}'.format(epoch))

        env.reset()
        states, rewards, finished = env.random_step() 
        target_model.model.set_weights(main_model.model.get_weights())

        running = True
        while running:
            clock.tick(60)

            actions = []
            for state in states:
                actions.append(agent.get_action(state, epoch, main_model))

            next_states, rewards, finished = env.step(actions)
            for state, reward, action, next_state in zip(states, rewards, actions, next_states):
                memory.add((state, action, reward, next_state))

            states = next_states

            if len(memory.buffer) % S_BATCH == 0:
                main_model.replay(memory, env.n_bird, GAMMA, target_model)

            target_model.model.set_weights(main_model.model.get_weights())

            if not len(env.birds):
                running = False
                break

            env.draw()

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False

        print('\tScore: {}'.format(env.score))
    
    pygame.quit()

if __name__ == '__main__':
    main()

さて、早速学習させてみましょう。

動画を見たらわかるように、全然成長しません。FizzBuzzのときはうまくいったのに、、、
ゲームの内容が複雑になったからかな?

違うアルゴリズムに挑戦してみます。

NEATで挑戦

NEATとは

DQNでは上手くいかなかったので、NEATというアルゴリズムを使ってみます。これは遺伝的アルゴリズムといわれるものです。

簡単に説明をしてみます。

まず、DQNではエポックという単位で学習を進めましたが、NEATでは世代という考え方をします。これは人間でいう世代と同じものだと思ってください。
一世代に100羽の鳥がいるとします。各鳥は自分のニューラルネットワークを持っていて、それをもとに動きます。
一世代が全員死ぬまでゲームを動かすと、その世代の1位から100位まで順位をつけることができます。
この100羽の中で優秀な鳥から次の世代の100羽を生み出します。生み出すとは、各鳥が持ってるニューラルネットワークを少しずつ改変して新しいニューラルネットワークにするという意味です。

※イメージ図
名称未設定のノート (2)-1-min.jpg

Configの設定

今回はneat-pythonというライブラリを使って学習させます。そのためにはconfigファイルを作成して設定をしておかないといけません。今回は公式サイトにあったやつを少しだけいじって使います。

neat_config.txt
[NEAT]
fitness_criterion   = max
fitness_threshold   = 500
pop_size            = 50
reset_on_extinction = False

[DefaultGenome]
activation_default     = sigmoid
activation_mutate_rate = 0.0
activation_options     = sigmoid

aggregation_default     = sum
aggregation_mutate_rate = 0.0
aggregation_options     = sum

bias_init_mean    = 0.0
bias_init_stdev   = 1.0
bias_max_value    = 100
bias_min_value    = -100
bias_mutate_power = 0.5
bias_mutate_rate  = 0.7
bias_replace_rate = 0.1

compatibility_disjoint_coefficient = 1.0
compatibility_weight_coefficient   = 0.5

conn_add_prob    = 0.5
conn_delete_prob = 0.5

enabled_default     = True
enabled_mutate_rate = 0.01

feed_forward       = True
initial_connection = full

node_add_prob    = 0.2
node_delete_prob = 0.2

num_inputs  = 4
num_hidden  = 0
num_outputs = 1

response_init_mean      = 1.0
response_init_stdev     = 0.0
response_max_value      = 100
response_min_value      = -100
response_mutate_power   = 0.0
response_mutate_rate    = 0.0
response_replace_rate   = 0.0

weight_init_mean        = 0.0
weight_init_stdev       = 1.0
weight_max_value        = 100
weight_min_value        = -100
weight_mutate_power     = 0.5
weight_mutate_rate      = 0.8
weight_replace_rate     = 0.1

[DefaultSpeciesSet]
compatibility_threshold = 3.0

[DefaultStagnation]
species_fitness_func = max
max_stagnation       = 20
species_elitism      = 2

[DefaultReproduction]
elitism            = 2
survival_threshold = 0.2

学習させてみる

一世代の鳥を50羽として学習させてみます。

train.py
import neat
import pygame

from flappy_bird import FlappyBird

CFG_PATH = 'neat_config.txt'
NUM_BIRD = 50

env = FlappyBird(NUM_BIRD)

def gen(genomes, config):
    clock = pygame.time.Clock()
    env.reset()

    nets = []
    ge   = []
    for _, g in genomes:
        nets.append(neat.nn.FeedForwardNetwork.create(g, config))
        g.fitness = 0
        ge.append(g)

    while len(env.birds) > 0:
        clock.tick(60)
        for pipe in env.pipes:
            if pipe.move():
                env.score += 1

                for g in ge:
                    g.fitness += 3

        pipe_idx = 0
        while env.birds[0].rect.x > env.pipes[pipe_idx].top_rect.left:
            pipe_idx += 1

        for i, bird in enumerate(env.birds):
            bird_state = bird.get_state()
            output = nets[i].activate((
                bird_state[0],
                bird_state[1],
                env.pipes[pipe_idx].top_rect.bottom - bird.rect.y,
                env.pipes[pipe_idx].top_rect.left - bird.rect.x))
            if output[0] > 0.5:
                bird.jump()

            bird.move()
            ge[i].fitness += 0.1

            if env.check_collide(bird):
                ge[i].fitness -= 1

                env.birds.pop(i)
                nets.pop(i)
                ge.pop(i)

        env.draw()

def train():
    config = neat.config.Config(
        neat.DefaultGenome,
        neat.DefaultReproduction,
        neat.DefaultSpeciesSet,
        neat.DefaultStagnation,
        CFG_PATH)

    p = neat.Population(config)
    p.add_reporter(neat.StdOutReporter(True))
    p.add_reporter(neat.StatisticsReporter())

    winner = p.run(gen, NUM_BIRD)

    pygame.quit()

if __name__ == '__main__':
    train()

学習の様子です。

DQNのときよりもちゃんと成長してるのが分かりますね!

結果

さて、今回はAIにFlappy Birdを学習させてみました。ビジュアライズも可能にすると学習の様子がみれてとても可愛いですね。

数時間学習をさせたら50点に到達しました。目標の10点を軽々達成しました。

スクリーンショット 2020-08-10 14.49.48.png

よければTwitterフォローしてください。じゃあね。

11
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?