こんばんは。りーぜんとです。
今回はFizzBuzzのやつに引き続き、強化学習第二弾ということで、Flappy BirdをプレイするAIを作ってみます。
作成したプログラムは全てGitHubにあるので参考にしてください。
目次
Flappy Birdとは
Flappy Birdというゲームを聞いたことがある人は多いんじゃないでしょうか。無料でプレイできるのでぜひ遊んでみてください。
とりあえず僕もやってみます。
いやむっず。
鳥が土管を超えるとポイントが入るのですが、結構頑張ったけど5点までいけません。
今からAIを作って5点を突破できれば僕より強いってことですね。ひとまず10点を目標にしましょう。
ゲームを実装
まずは学習に使うゲーム環境を作っていきます。
今回はpygameというライブラリを使って実際に学習してる様子が見えるように作ります。
pygameについて詳しい解説はしないので、是非公式リファレンス等を読んでみてください。
ライブラリのインポート
最初にライブラリをインポートしたり定数を定義します。
import pygame
import random
import sys
import math
from numpy import array
WIN_WIDTH = 600
WIN_HEIGHT = 800
COLORS = {
'sky': (135, 206, 250),
'bird': (255, 255, 0),
'pipe': (50, 205, 50),
'ground': (160, 82, 45)
}
BIRD_SIZE = 50
PIPE_VEL = 4
PIPE_GAP = 200
PIPE_WIDTH = 100
PIPE_MARGIN = 100
GROUND_HEIGHT = 100
鳥クラスの作成
次に鳥のクラスから作っていきます。今回もFizzBuzzの記事と同様に作っていくので、詳しくはそちらをみてください。
class Bird:
def __init__(self, x=200, y=350):
self.y = y
self.vel = 0
self.rect = pygame.Rect(x, self.y, BIRD_SIZE, BIRD_SIZE)
def jump(self):
self.vel = -6
def move(self):
self.vel += 0.4
self.y += self.vel
self.rect.top = self.y
def get_state(self):
return [self.y, self.vel]
def draw(self, win):
pygame.draw.rect(win, COLORS['bird'], self.rect)
y座標と、加速度、描画用のpygame.Rectオブジェクトを持たせてあります。今回はお絵かきがめんどくさかったので、全ての物体は長方形で構成されます。お許しを。
また、get_state()で現在の鳥の状態(今の位置と加速度)を取得可能にしています。これを強化学習の際に使用します。
土管クラスの作成
次は土管です。
class Pipe:
def __init__(self, x=700):
self.top = random.randrange(PIPE_MARGIN, WIN_HEIGHT - PIPE_GAP - PIPE_MARGIN - GROUND_HEIGHT)
self.bottom = self.top + PIPE_GAP
self.top_rect = pygame.Rect(x, 0, PIPE_WIDTH, self.top)
self.bottom_rect = pygame.Rect(x, self.bottom, PIPE_WIDTH, WIN_HEIGHT - self.bottom)
def move(self):
self.top_rect = self.top_rect.move(-PIPE_VEL, 0)
self.bottom_rect = self.bottom_rect.move(-PIPE_VEL, 0)
if self.top_rect.right < 0:
self.__init__()
return self.top_rect.left == 200
def draw(self, win):
pygame.draw.rect(win, COLORS['pipe'], self.top_rect)
pygame.draw.rect(win, COLORS['pipe'], self.bottom_rect)
常に二本の土管をいい感じの間隔でゲーム内に置いておき、左端まで進んだ土管を右端にテレポートさせることでずっと土管が来るようにしました。
地面クラスの作成
次の地面ですが、これに関して語ることはありません。
class Ground:
def __init__(self):
self.rect = pygame.Rect(0, WIN_HEIGHT - GROUND_HEIGHT, WIN_WIDTH, GROUND_HEIGHT)
def draw(self, win):
pygame.draw.rect(win, COLORS['ground'], self.rect)
ゲームクラスの作成
最後にゲーム本体のクラスです。
class FlappyBird:
def __init__(self, n_bird=1):
pygame.init()
self.win = pygame.display.set_mode((WIN_WIDTH, WIN_HEIGHT))
pygame.display.set_caption('Flappy Bird')
self.n_bird = n_bird
self.birds = [Bird() for _ in range(self.n_bird)]
self.pipes = [Pipe(800), Pipe(1200)]
self.ground = Ground()
self.score = 0
def reset(self):
self.__init__(self.n_bird)
def draw(self):
self.win.fill(COLORS['sky'])
for bird in self.birds:
bird.draw(self.win)
for pipe in self.pipes:
pipe.draw(self.win)
self.ground.draw(self.win)
pygame.display.update()
def check_collide(self, bird):
if bird.y <= -BIRD_SIZE:
return True
for pipe in self.pipes:
if pipe.top_rect.colliderect(bird.rect) or pipe.bottom_rect.colliderect(bird.rect):
return True
if self.ground.rect.colliderect(bird.rect):
return True
return False
def step(self, actions):
passed = False
for pipe in self.pipes:
if pipe.move():
self.score += 1
passed = True
next_birds = []
states = []
rewards = []
for action, bird in zip(actions, self.birds):
if action.argmax():
bird.jump()
last_y = bird.y
bird.move()
pipe_idx = 0
while bird.rect.x > self.pipes[pipe_idx].top_rect.left:
pipe_idx += 1
rewards.append(
1 if abs(
self.pipes[pipe_idx].top_rect.bottom + PIPE_GAP / 2 - last_y
) < abs(
self.pipes[pipe_idx].top_rect.bottom + PIPE_GAP / 2 - bird.y
) else 0
)
pipe_state = [
self.pipes[pipe_idx].top_rect.bottom - bird.y,
self.pipes[pipe_idx].top_rect.left - 200
]
states.append(bird.get_state() + pipe_state)
finished = self.check_collide(bird)
if not finished:
next_birds.append(bird)
self.birds = next_birds
return array(states), array(rewards), finished
def random_step(self):
for pipe in self.pipes:
pipe.move()
state = []
for bird in self.birds:
pipe_state = [
self.pipes[0].top_rect.bottom - bird.y,
self.pipes[0].top_rect.left - 200
]
state.append(bird.get_state() + pipe_state)
return array(state), array([0 for _ in range(self.n_bird)]), False
環境を初期化するreset()、鳥が死んだかどうかを判定するcheck_collide()、次フレームに遷移するstep()、random_step()を実装しました。
step()は今の状況、それに対する報酬、ゲームが終了したかどうかを返します。
DQNで挑戦
さて、前の記事同様、Deep Q Learningという手法を使って学習をしてみます。細かい解説はしないので、いろいろ調べてみてください。
エージェント、メモリ、モデルは前の記事で実装したものをそのまま使います。
import pygame
from flappy_bird import FlappyBird
from model import Model
from memory import Memory
from agent import Agent
def evaluate(env, agent):
env.reset()
state, _, finished = env.random_step()
while not finished:
action = agent.get_action(state, N_EPOCHS, main_model)
next_state, _, finished = env.step(action.argmax(), verbose=True)
state = next_state
def main():
clock = pygame.time.Clock()
N_EPOCHS = 1000
GAMMA = 0.99
N_BIRD = 64
S_BATCH = 256
env = FlappyBird(N_BIRD)
main_model = Model()
target_model = Model()
memory = Memory()
agent = Agent()
for epoch in range(1, N_EPOCHS + 1):
print('Epoch: {}'.format(epoch))
env.reset()
states, rewards, finished = env.random_step()
target_model.model.set_weights(main_model.model.get_weights())
running = True
while running:
clock.tick(60)
actions = []
for state in states:
actions.append(agent.get_action(state, epoch, main_model))
next_states, rewards, finished = env.step(actions)
for state, reward, action, next_state in zip(states, rewards, actions, next_states):
memory.add((state, action, reward, next_state))
states = next_states
if len(memory.buffer) % S_BATCH == 0:
main_model.replay(memory, env.n_bird, GAMMA, target_model)
target_model.model.set_weights(main_model.model.get_weights())
if not len(env.birds):
running = False
break
env.draw()
for event in pygame.event.get():
if event.type == pygame.QUIT:
running = False
print('\tScore: {}'.format(env.score))
pygame.quit()
if __name__ == '__main__':
main()
さて、早速学習させてみましょう。
DQN使ったやつ
— りーぜんと (@50m_regent) August 10, 2020
全然学習が進まない pic.twitter.com/weuOJx0DmZ
動画を見たらわかるように、全然成長しません。FizzBuzzのときはうまくいったのに、、、
ゲームの内容が複雑になったからかな?
違うアルゴリズムに挑戦してみます。
NEATで挑戦
NEATとは
DQNでは上手くいかなかったので、NEATというアルゴリズムを使ってみます。これは遺伝的アルゴリズムといわれるものです。
簡単に説明をしてみます。
まず、DQNではエポックという単位で学習を進めましたが、NEATでは世代という考え方をします。これは人間でいう世代と同じものだと思ってください。
一世代に100羽の鳥がいるとします。各鳥は自分のニューラルネットワークを持っていて、それをもとに動きます。
一世代が全員死ぬまでゲームを動かすと、その世代の1位から100位まで順位をつけることができます。
この100羽の中で優秀な鳥から次の世代の100羽を生み出します。生み出すとは、各鳥が持ってるニューラルネットワークを少しずつ改変して新しいニューラルネットワークにするという意味です。
Configの設定
今回はneat-pythonというライブラリを使って学習させます。そのためにはconfigファイルを作成して設定をしておかないといけません。今回は公式サイトにあったやつを少しだけいじって使います。
[NEAT]
fitness_criterion = max
fitness_threshold = 500
pop_size = 50
reset_on_extinction = False
[DefaultGenome]
activation_default = sigmoid
activation_mutate_rate = 0.0
activation_options = sigmoid
aggregation_default = sum
aggregation_mutate_rate = 0.0
aggregation_options = sum
bias_init_mean = 0.0
bias_init_stdev = 1.0
bias_max_value = 100
bias_min_value = -100
bias_mutate_power = 0.5
bias_mutate_rate = 0.7
bias_replace_rate = 0.1
compatibility_disjoint_coefficient = 1.0
compatibility_weight_coefficient = 0.5
conn_add_prob = 0.5
conn_delete_prob = 0.5
enabled_default = True
enabled_mutate_rate = 0.01
feed_forward = True
initial_connection = full
node_add_prob = 0.2
node_delete_prob = 0.2
num_inputs = 4
num_hidden = 0
num_outputs = 1
response_init_mean = 1.0
response_init_stdev = 0.0
response_max_value = 100
response_min_value = -100
response_mutate_power = 0.0
response_mutate_rate = 0.0
response_replace_rate = 0.0
weight_init_mean = 0.0
weight_init_stdev = 1.0
weight_max_value = 100
weight_min_value = -100
weight_mutate_power = 0.5
weight_mutate_rate = 0.8
weight_replace_rate = 0.1
[DefaultSpeciesSet]
compatibility_threshold = 3.0
[DefaultStagnation]
species_fitness_func = max
max_stagnation = 20
species_elitism = 2
[DefaultReproduction]
elitism = 2
survival_threshold = 0.2
学習させてみる
一世代の鳥を50羽として学習させてみます。
import neat
import pygame
from flappy_bird import FlappyBird
CFG_PATH = 'neat_config.txt'
NUM_BIRD = 50
env = FlappyBird(NUM_BIRD)
def gen(genomes, config):
clock = pygame.time.Clock()
env.reset()
nets = []
ge = []
for _, g in genomes:
nets.append(neat.nn.FeedForwardNetwork.create(g, config))
g.fitness = 0
ge.append(g)
while len(env.birds) > 0:
clock.tick(60)
for pipe in env.pipes:
if pipe.move():
env.score += 1
for g in ge:
g.fitness += 3
pipe_idx = 0
while env.birds[0].rect.x > env.pipes[pipe_idx].top_rect.left:
pipe_idx += 1
for i, bird in enumerate(env.birds):
bird_state = bird.get_state()
output = nets[i].activate((
bird_state[0],
bird_state[1],
env.pipes[pipe_idx].top_rect.bottom - bird.rect.y,
env.pipes[pipe_idx].top_rect.left - bird.rect.x))
if output[0] > 0.5:
bird.jump()
bird.move()
ge[i].fitness += 0.1
if env.check_collide(bird):
ge[i].fitness -= 1
env.birds.pop(i)
nets.pop(i)
ge.pop(i)
env.draw()
def train():
config = neat.config.Config(
neat.DefaultGenome,
neat.DefaultReproduction,
neat.DefaultSpeciesSet,
neat.DefaultStagnation,
CFG_PATH)
p = neat.Population(config)
p.add_reporter(neat.StdOutReporter(True))
p.add_reporter(neat.StatisticsReporter())
winner = p.run(gen, NUM_BIRD)
pygame.quit()
if __name__ == '__main__':
train()
学習の様子です。
Flappy Birdの強化学習できた!いつか記事書きたい pic.twitter.com/6G9XsPOvQ1
— りーぜんと (@50m_regent) August 10, 2020
DQNのときよりもちゃんと成長してるのが分かりますね!
結果
さて、今回はAIにFlappy Birdを学習させてみました。ビジュアライズも可能にすると学習の様子がみれてとても可愛いですね。
数時間学習をさせたら50点に到達しました。目標の10点を軽々達成しました。
よければTwitterフォローしてください。じゃあね。