DQN (Deep Q-Network) on Google Colab

Posted at 2023-12-02

The programs used in the video are transcribed in this article.

I have written programs that run on Google Colab and on Windows.
This article covers the version that runs on Google Colab.
The version that runs on Windows is transcribed in a separate article, linked below.

Installing the required library

To render CartPole on screen in Google Colab, install pyvirtualdisplay with pip.
!pip install pyvirtualdisplay
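
Note: pyvirtualdisplay is only a wrapper around an X virtual framebuffer (Xvfb), which a standard Colab runtime does not ship by default, so it usually needs to be installed as well (a sketch under that assumption):

!apt-get install -y xvfb > /dev/null

The code in this article also uses the pre-0.26 Gym API (env.reset() returning only the observation, env.step() returning four values, and env.render() taking the render mode as an argument). If your runtime comes with Gym 0.26 or later, you may need to pin an older release such as gym==0.25.2.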

Checking that rendering works

You can confirm that CartPole is rendered on screen.

import numpy as np
import gym
import matplotlib.pyplot as plt
from IPython import display
from pyvirtualdisplay import Display

# Start a virtual display so that env.render('rgb_array') has a screen to draw on
# (requires Xvfb, see the note above)
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()

env = gym.make('CartPole-v1')
state = env.reset()
done = False

# Set up the inline display output
img = plt.imshow(env.render('rgb_array'))

while not done:
    # Update the inline display output
    display.clear_output(wait=True)
    img.set_data(env.render('rgb_array'))
    plt.axis('off')
    display.display(plt.gcf())

    # Choose an action at random
    action = np.random.choice([0, 1])

    # Take the action and store the resulting state
    next_state, reward, done, info = env.step(action)
env.close()

(Screenshot: rendered CartPole frame)

OpenAI Gym (CartPole)

This is the program for training a DQN agent on CartPole.
Run as-is, it takes roughly 1.5 hours; if you comment out the lines marked "# Comment out this line to make training finish quickly", it finishes in about 2 minutes.
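
For reference, the update implemented below is the standard DQN temporal-difference update. The online network self.qnet is trained toward a target computed with the periodically synced target network self.qnet_target:

target = reward + (1 - done) * gamma * max_a' Q_target(next_state, a')

and the loss is the mean squared error between Q(state, action) and this target. Exploration is epsilon-greedy with epsilon = 0.1, transitions are stored in a replay buffer of 10,000 entries, minibatches of 32 are sampled for each update, gamma is 0.98, and the target network is synced every 20 episodes.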

import copy
from collections import deque
import random
import numpy as np
import gym
import matplotlib.pyplot as plt
from IPython import display
import torch
import torch.nn as nn
import torch.optim as optimizers
import torch.nn.functional as F

# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class ReplayBuffer:
    def __init__(self, buffer_size, batch_size):
        self.buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size

    def add(self, state, action, reward, next_state, done):
        data = (state, action, reward, next_state, done)
        self.buffer.append(data)

    def __len__(self):
        return len(self.buffer)

    def get_batch(self):
        data = random.sample(self.buffer, self.batch_size)
        state, action, reward, next_state, done = zip(*data)
        return (torch.tensor(np.array(state), dtype=torch.float32).to(device),
                torch.tensor(action, dtype=torch.long).to(device),
                torch.tensor(reward, dtype=torch.float32).to(device),
                torch.tensor(np.array(next_state), dtype=torch.float32).to(device),
                torch.tensor(done, dtype=torch.float32).to(device))

class QNet(nn.Module):
    def __init__(self, action_size):
        super().__init__()
        self.fc1 = nn.Linear(4, 128)  # CartPole's observation is 4-dimensional
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class DQNAgent:
    def __init__(self):
        self.gamma = 0.98
        self.lr = 0.0005
        self.epsilon = 0.1
        self.buffer_size = 10000
        self.batch_size = 32
        self.action_size = 2

        self.replay_buffer = ReplayBuffer(self.buffer_size, self.batch_size)
        self.qnet = QNet(self.action_size).to(device)
        self.qnet_target = QNet(self.action_size).to(device)
        self.optimizer = optimizers.Adam(self.qnet.parameters(), self.lr)

    def sync_qnet(self):
        # Copy the online network's weights into the target network
        self.qnet_target.load_state_dict(self.qnet.state_dict())

    def get_action(self, state):
        state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(device)  # Convert the NumPy array to a tensor
        if np.random.rand() < self.epsilon:
            return np.random.choice(self.action_size)
        else:
            with torch.no_grad():
                qs = self.qnet(state_tensor)
            return qs.argmax().item()  # .item() converts the result to a Python int

    def update(self, state, action, reward, next_state, done):
        self.replay_buffer.add(state, action, reward, next_state, done)
        if len(self.replay_buffer) < self.batch_size:
            return

        state, action, reward, next_state, done = self.replay_buffer.get_batch()
        state = state.to(device)
        action = action.to(device)
        reward = reward.to(device)
        next_state = next_state.to(device)
        done = done.to(device)

        # Compute the network outputs and the loss
        qs = self.qnet(state)
        q = qs.gather(1, action.unsqueeze(1)).squeeze(1)

        # The TD target uses the target network; gradients should not flow
        # through it, so compute it under no_grad
        with torch.no_grad():
            next_qs = self.qnet_target(next_state)
            next_q = next_qs.max(1)[0]
        target = reward + (1 - done) * self.gamma * next_q
        loss = F.mse_loss(q, target)

        # Backpropagation
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss

episodes = 300
sync_interval = 20

env = gym.make('CartPole-v1')
agent = DQNAgent()
reward_history = []

for episode in range(episodes):
    print('episode:', episode)
    state = env.reset()
    done = False
    total_reward = 0

    # Set up the inline display output
    img = plt.imshow(env.render('rgb_array')) # Comment out this line to make training finish quickly
    plt.axis('off') # Comment out this line to make training finish quickly

    while not done:
        # Choose an action with the epsilon-greedy policy
        action = agent.get_action(state)

        # Update the rendered frame
        img.set_data(env.render('rgb_array')) # Comment out this line to make training finish quickly
        display.display(plt.gcf()) # Comment out this line to make training finish quickly
        display.clear_output(wait=True) # Comment out this line to make training finish quickly

        # Take the action and store the resulting state
        next_state, reward, done, info = env.step(action)

        agent.update(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward

    if episode % sync_interval == 0:
        agent.sync_qnet()

    reward_history.append(total_reward)
    print('Total Reward:', total_reward)

env.close()
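
The loop above stores the per-episode return in reward_history but never plots it. The sketch below (mine, not part of the original article) uses the same matplotlib pattern as the 8-puzzle section further down:

# Plot the total reward per episode for the CartPole run
# (assumes the training cell above has been executed in the same session)
plt.figure(figsize=(10, 6))
plt.plot(range(len(reward_history)), reward_history)
plt.title('Episode vs Total Reward (CartPole)')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.grid(True)
plt.show()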

DQN with PyTorch (8-puzzle)

This is a program for training a DQN agent on the 8-puzzle (a 3x3 sliding tile puzzle).
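
Unlike CartPole, the environment is hand-written: the state is the flattened 3x3 board (9 integers, with 0 as the blank), there are 4 actions (up/down/left/right), and the step reward is the number of cells already matching the goal arrangement. Once the TileSlidePuzzleEnv class below has been run, a quick interaction check looks like this (a sketch of mine, not in the original article):

# Quick interaction check (assumes the TileSlidePuzzleEnv class defined below has been executed)
env = TileSlidePuzzleEnv()
state = env.reset()                           # 9-element flattened board, 0 is the blank
next_state, reward, done, info = env.step(0)  # 0=up, 1=down, 2=left, 3=right
print(reward)                                 # number of cells already matching the goal board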

import copy
from collections import deque
import random
import numpy as np
import matplotlib.pyplot as plt
from IPython import display
import torch
import torch.nn as nn
import torch.optim as optimizers
import torch.nn.functional as F

# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class ReplayBuffer:
    def __init__(self, buffer_size, batch_size):
        self.buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size

    def add(self, state, action, reward, next_state, done):
        data = (state, action, reward, next_state, done)
        self.buffer.append(data)

    def __len__(self):
        return len(self.buffer)

    def get_batch(self):
        data = random.sample(self.buffer, self.batch_size)
        state, action, reward, next_state, done = zip(*data)
        return (torch.tensor(np.array(state), dtype=torch.float32).to(device),
                torch.tensor(action, dtype=torch.long).to(device),
                torch.tensor(reward, dtype=torch.float32).to(device),
                torch.tensor(np.array(next_state), dtype=torch.float32).to(device),
                torch.tensor(done, dtype=torch.float32).to(device))

class QNet(nn.Module):
    def __init__(self, action_size):
        super().__init__()
        self.fc1 = nn.Linear(9, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class DQNAgent:
    def __init__(self):
        self.gamma = 0.98
        self.lr = 0.0005
        self.epsilon = 0.2 # With 0.2, a random action is taken roughly 2 times out of 10
        self.buffer_size = 10000
        self.batch_size = 32
        self.action_size = 4 # 'up', 'down', 'left', 'right'

        self.replay_buffer = ReplayBuffer(self.buffer_size, self.batch_size)
        self.qnet = QNet(self.action_size).to(device)
        self.qnet_target = QNet(self.action_size).to(device)
        self.optimizer = optimizers.Adam(self.qnet.parameters(), self.lr)

    def sync_qnet(self):
        self.qnet_target.load_state_dict(self.qnet.state_dict())

    def get_action(self, state):
        state_tensor = torch.from_numpy(state.flatten()).float().unsqueeze(0).to(device)
        print('epsilon:', self.epsilon)
        if np.random.rand() < self.epsilon:
            print('random mode')
            return np.random.choice(self.action_size)
        else:
            print('qnet mode')
            with torch.no_grad():
                qs = self.qnet(state_tensor)
            return qs.argmax().item()  # .item() converts the result to a Python int

    def update(self, state, action, reward, next_state, done):
        self.replay_buffer.add(state, action, reward, next_state, done)
        if len(self.replay_buffer) < self.batch_size:
            return

        state, action, reward, next_state, done = self.replay_buffer.get_batch()
        state = state.view(self.batch_size, -1).to(device) # Flatten the states for batch processing
        action = action.to(device)
        reward = reward.to(device)
        next_state = next_state.view(self.batch_size, -1).to(device) # Flatten the next states for batch processing
        done = done.to(device)

        # Compute the network outputs and the loss
        qs = self.qnet(state)
        q = qs.gather(1, action.unsqueeze(1)).squeeze(1)

        # The TD target uses the target network; gradients should not flow
        # through it, so compute it under no_grad
        with torch.no_grad():
            next_qs = self.qnet_target(next_state)
            next_q = next_qs.max(1)[0]
        target = reward + (1 - done) * self.gamma * next_q
        loss = F.mse_loss(q, target)

        # Backpropagation
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

class TileSlidePuzzleEnv:
    def __init__(self, size=3):
        self.size = size
        self.board = np.zeros((self.size, self.size), dtype=int)
        self.reset()

    def reset(self):
        # Shuffle tiles 1..8 and place the blank (0) in the top-left corner
        # (note: a purely random shuffle can produce unsolvable configurations)
        nums = np.arange(1, self.size**2)
        np.random.shuffle(nums)
        self.board = np.insert(nums, 0, 0).reshape(self.size, self.size)
        return self.board.flatten()

    def step(self, action):
        # Slide a tile in the given direction
        # action is the direction to move a tile (0=up, 1=down, 2=left, 3=right)
        print('action:', action)
        moved = self.slide_tile(action)  # True if the slide was valid (currently unused)
        self.display()

        # Next state (the flattened board)
        next_state = self.board.flatten()

        # Reward (here: the number of cells matching the goal arrangement)
        reward = self.calculate_reward()
        print('reward:', reward)

        # Check the termination condition
        done = self.is_solved()

        # Additional info (an empty dict)
        info = {}

        # Return next_state, reward, done, info
        return next_state, reward, done, info

    def slide_tile(self, direction):
        # Find the position of the blank tile
        x, y = np.where(self.board == 0)
        x, y = int(x[0]), int(y[0])

        # Swap the blank with the neighbouring tile in the given direction and
        # return True if the move was valid, False otherwise
        if direction == 0 and x < self.size - 1:
            self.board[x, y], self.board[x+1, y] = self.board[x+1, y], self.board[x, y]
        elif direction == 1 and x > 0:
            self.board[x, y], self.board[x-1, y] = self.board[x-1, y], self.board[x, y]
        elif direction == 2 and y < self.size - 1:
            self.board[x, y], self.board[x, y+1] = self.board[x, y+1], self.board[x, y]
        elif direction == 3 and y > 0:
            self.board[x, y], self.board[x, y-1] = self.board[x, y-1], self.board[x, y]
        else:
            return False
        return True

    def calculate_reward(self):
        # Reward calculation:
        # here, the number of cells already matching the goal arrangement
        correct_tiles = np.sum(self.board == np.arange(self.size**2).reshape(self.size, self.size))
        return correct_tiles

    def is_solved(self):
        # Check whether the puzzle has been solved
        return np.array_equal(self.board, np.arange(self.size**2).reshape(self.size, self.size))

    def display(self):
        # Print the current state of the board
        print(self.board)

    def render(self):
        print(self.board)
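
One caveat about the random shuffle in reset(): roughly half of all 8-puzzle configurations are unsolvable, so some episodes can never reach done=True and end only through the slide-count limit used below. If you want every episode to be solvable, a standard inversion-count check can be added; the helper below is a sketch of my own (not part of the original program) and is valid for odd board widths such as 3:

def is_solvable(board):
    # For an odd board width, a sliding puzzle is solvable if and only if the
    # number of inversions among the non-blank tiles is even
    # (the goal state 0, 1, ..., 8 has zero inversions).
    tiles = [t for t in board.flatten() if t != 0]
    inversions = sum(
        1
        for i in range(len(tiles))
        for j in range(i + 1, len(tiles))
        if tiles[i] > tiles[j]
    )
    return inversions % 2 == 0

# Example: inside TileSlidePuzzleEnv.reset(), reshuffle until the board is solvable
# while not is_solvable(self.board):
#     np.random.shuffle(nums)
#     self.board = np.insert(nums, 0, 0).reshape(self.size, self.size)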

episodes = 5000
sync_interval = 20
limit_slide_count = 200

# Instantiate the tile-slide puzzle environment
env = TileSlidePuzzleEnv()
agent = DQNAgent()
reward_history = []

for episode in range(episodes):
    print('----- ----- episode:', episode, ' ----- -----')
    state = env.reset()
    done = False
    total_reward = 0
    slide_count = 0

    while not done:
        slide_count += 1
        print('----- ----- episode:', episode, ' ----- -----')
        print('slide_count:', slide_count)

        # Choose an action with the epsilon-greedy policy
        action = agent.get_action(state)

        # Take the action and store the resulting state
        next_state, reward, done, info = env.step(action)

        agent.update(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward
        if done:
            total_reward = 10000  # Overwrite the total with a large fixed reward when the puzzle is solved
        print('total_reward:', total_reward)

        if slide_count >= limit_slide_count:
            done = True  # Give up after limit_slide_count slides so an episode cannot run forever

    if episode % sync_interval == 0:
        agent.sync_qnet()

    reward_history.append(total_reward)
    print('Total Reward:', total_reward)

print(reward_history)

# Generate the list of episode numbers
episodes = list(range(len(reward_history)))

# Create the plot
plt.figure(figsize=(10, 6))
plt.plot(episodes, reward_history, marker='o')
plt.title('Episode vs Total Reward')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.grid(True)
plt.show()