9
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

三目並べで強化学習を練習してみた

9
Posted at

はじめに

今回は、「ChatGPTにハンズオンを作らせてみた」の第8弾で、三目並べを使って強化学習(Q学習)を勉強しました。

第7弾はこちら↓

使用コード・結果

import numpy as np
import random

class TicTacToeEnv:
    def __init__(self, seed=42):
        self.board = np.zeros((3, 3), dtype=int)  # 0: 空白, 1: X, -1: O
        self.seed = seed
        self.reset()
    
    def set_seed(self):
        """ 乱数のシードを固定 """
        random.seed(self.seed)
        np.random.seed(self.seed)
    
    def reset(self):
        """ 初期状態のボードをセット """
        self.set_seed()
        self.board.fill(0)
        self.current_player = 1  # 先手は X (1)
        return self.board.copy()
    
    def get_valid_moves(self):
        """ 有効な手のリストを取得 """
        return [(x, y) for x in range(3) for y in range(3) if self.board[x, y] == 0]

    def step(self, action):
        """ 手を打つ """
        x, y = action
        if self.board[x, y] != 0:
            return self.board.copy(), -10, True  # 無効な手
        
        self.board[x, y] = self.current_player
        done, winner = self.check_winner()
        
        reward = 1 if winner == self.current_player else 0  # 勝ったら報酬
        if done and winner == 0:
            reward = 0.5  # 引き分けの報酬

        self.current_player *= -1  # 手番交代
        return self.board.copy(), reward, done

    def check_winner(self):
        """ 勝敗判定 """
        for i in range(3):
            if abs(sum(self.board[i, :])) == 3:  # 横ライン
                return True, np.sign(sum(self.board[i, :]))
            if abs(sum(self.board[:, i])) == 3:  # 縦ライン
                return True, np.sign(sum(self.board[:, i]))

        # 斜めチェック
        if abs(self.board[0, 0] + self.board[1, 1] + self.board[2, 2]) == 3:
            return True, np.sign(self.board[0, 0])
        if abs(self.board[0, 2] + self.board[1, 1] + self.board[2, 0]) == 3:
            return True, np.sign(self.board[0, 2])

        # 盤面が埋まったら引き分け
        if not self.get_valid_moves():
            return True, 0

        return False, None

    def render(self):
        """ 盤面の表示 """
        symbols = {0: ".", 1: "X", -1: "O"}
        for row in self.board:
            print(" ".join(symbols[cell] for cell in row))
        print()
import matplotlib.pyplot as plt
import japanize_matplotlib

class QLearningAgent:
    def __init__(self, alpha=0.1, gamma=0.9, epsilon_max=1.0, epsilon_min=0.05, decay_rate=0.01, seed=42):
        self.q_table = {}  # Q値を保存する辞書
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon_max = epsilon_max
        self.epsilon_min = epsilon_min
        self.decay_rate = decay_rate
        self.epsilon = epsilon_max
        self.seed = seed
        self.set_seed()

    def set_seed(self):
        """ 乱数のシードを固定 """
        random.seed(self.seed)
        np.random.seed(self.seed)

    def get_q_value(self, state, action):
        """ 状態と行動の Q 値を取得 """
        state_key = tuple(state.flatten())
        return self.q_table.get((state_key, action), 0)

    def update_q_value(self, state, action, reward, next_state):
        """ Q 値の更新 """
        state_key = tuple(state.flatten())

        # 次の状態での最大Q値を取得
        future_rewards = [self.get_q_value(next_state, a) for a in env.get_valid_moves()]
        best_next_q = max(future_rewards, default=0)

        # Q値の更新
        old_value = self.get_q_value(state, action)
        new_value = old_value + self.alpha * (reward + self.gamma * best_next_q - old_value)
        self.q_table[(state_key, action)] = new_value

    def choose_action(self, state, valid_moves):
        """ ε-greedy 方策に基づく行動選択 """
        if random.uniform(0, 1) < self.epsilon:
            return random.choice(valid_moves)
        q_values = {move: self.get_q_value(state, move) for move in valid_moves}
        return max(q_values, key=q_values.get)

    def decay_epsilon(self, episode):
        """ エピソードごとに探索率 ε を減少 """
        self.epsilon = self.epsilon_min + (self.epsilon_max - self.epsilon_min) * np.exp(-self.decay_rate * episode)

# 学習の実行
agent = QLearningAgent(seed=42)
env = TicTacToeEnv(seed=42)
num_episodes = 1000

reward_history = []
epsilon_history = []

for episode in range(1, num_episodes + 1):
    state = env.reset()
    done = False
    total_reward = 0

    while not done:
        valid_moves = env.get_valid_moves()
        if not valid_moves:
            break  # すべて埋まったら終了

        action = agent.choose_action(state, valid_moves)
        next_state, reward, done = env.step(action)
        agent.update_q_value(state, action, reward, next_state)

        state = next_state
        total_reward += reward

    reward_history.append(total_reward)
    epsilon_history.append(agent.epsilon)

    agent.decay_epsilon(episode)

    if episode % 100 == 0:
        avg_reward = np.mean(reward_history[-100:])
        print(f"エピソード {episode}/{num_episodes} - 平均報酬: {avg_reward:.2f}, ε: {agent.epsilon:.3f}")

print("学習完了")

# 学習曲線の可視化
plt.figure(figsize=(12, 5))

# 報酬の推移
plt.subplot(1, 2, 1)
plt.plot(reward_history)
plt.xlabel("エピソード")
plt.ylabel("累積報酬")
plt.title("三目並べ 強化学習の進捗")

# εの減衰推移
plt.subplot(1, 2, 2)
plt.plot(epsilon_history)
plt.xlabel("エピソード")
plt.ylabel("ε(探索率)")
plt.title("探索率 (ε) の推移")

plt.show()

image.png

def play_against_agent(agent, env):
    """学習済みのエージェントとプレイする(先攻・後攻を選択)"""
    state = env.reset()
    
    # 先攻・後攻の選択
    while True:
        choice = input("先攻(X)でプレイしますか? (yes/no): ").strip().lower()
        if choice in ["yes", "y"]:
            human_player = 1  # 人間が X (先攻)
            agent_player = -1  # エージェントが O (後攻)
            break
        elif choice in ["no", "n"]:
            human_player = -1  # 人間が O (後攻)
            agent_player = 1  # エージェントが X (先攻)
            break
        else:
            print("無効な入力です。'yes' または 'no' を入力してください。")

    env.render()
    print(f"あなたは {'X' if human_player == 1 else 'O'} です。")
    print("行と列をスペース区切りで入力してください(例: 0 2)\n")

    done = False
    while not done:
        if env.current_player == human_player:  # 人間のターン
            valid_moves = env.get_valid_moves()
            if not valid_moves:
                print("有効な手がないため、エージェントの番です。")
                env.current_player *= -1
                continue
            
            while True:
                try:
                    move = input("あなたの手(行 列): ")
                    x, y = map(int, move.split())
                    if (x, y) in valid_moves:
                        break
                    else:
                        print("無効な手です。もう一度入力してください。")
                except ValueError:
                    print("正しいフォーマットで入力してください(例: 0 2)")

            state, _, done = env.step((x, y))

        else:  # エージェントのターン
            valid_moves = env.get_valid_moves()
            if not valid_moves:
                print("エージェントはパスしました。")
                env.current_player *= -1
                continue

            action = agent.choose_action(state, valid_moves)
            print(f"エージェントの手: {action}")
            state, _, done = env.step(action)

        env.render()
        game_over, winner = env.check_winner()
        if game_over:
            if winner == human_player:
                print("あなたの勝ち! 🎉")
            elif winner == agent_player:
                print("エージェントの勝ち! 🤖")
            else:
                print("引き分け! 😐")
            done = True

# 学習済みのエージェントとプレイ
play_against_agent(agent, env)

おわりに

長くなったので出力は省略しましたが、さすがに人間と勝負すると弱かったです。もっと強くできるはずなので、いろいろ試してみようと思います!

9
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?