
Keeping a Pole Upright with Reinforcement Learning (Policy Gradient)

Posted at 2022-02-17

Overview

We will work through CartPole, the most basic reinforcement learning environment.
The code is meant to serve as a reinforcement learning template, so the implementation is kept as simple as possible.

Setup

pip install gym
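
The code in this article uses the classic gym API (env.reset() returning only the observation, env.step() returning four values). Newer gym/gymnasium releases changed these signatures, so if you are installing today it may be necessary to pin an older version, for example:

pip install "gym<0.26"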

Environment

CartPole

State: 4-dimensional continuous values
Action: 0 or 1 (a quick check of both spaces is sketched below)
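
A minimal sketch to confirm this, assuming the classic gym API used throughout this article:

import gym

env = gym.make("CartPole-v0")
print(env.observation_space)  # Box with shape (4,): cart position, cart velocity, pole angle, pole angular velocity
print(env.action_space)       # Discrete(2): 0 or 1

state = env.reset()           # classic gym API: reset() returns only the observation
print(state.shape)            # (4,)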

Code

  • Imports
import gym
import math
import random
from copy import deepcopy
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  • Configuration
STATE_DIM = 4
ACTIVE_DIM = 2
EPOCH_NUM = 10000000
SAMPLE_NUM = 100  #update parameters once every 100 episodes
ACT_NUM = 10000  #maximum number of actions per episode
  • Model (input: 4-dimensional continuous-valued vector, output: 2-dimensional logits)
class MLP(nn.Module):
    def __init__(self,input_size,hidden_size,output_size):
        super().__init__()
        # two-layer MLP: state vector -> hidden layer (ReLU) -> action logits
        self.mlp = nn.Sequential(nn.Linear(input_size,hidden_size),nn.ReLU(),nn.Linear(hidden_size,output_size))

    def forward(self,x):
        return self.mlp(x.float())
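
A quick shape check of the network (the random batch of states here is purely illustrative):

dummy_net = MLP(input_size=STATE_DIM,hidden_size=100,output_size=ACTIVE_DIM)
dummy_states = torch.randn(8,STATE_DIM)  # a fake batch of 8 states
print(dummy_net(dummy_states).shape)     # torch.Size([8, 2])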
  • Agent
class Agent(nn.Module):
    def __init__(self,device):
        super(Agent,self).__init__()
        self.device = device
        self.action_space = [action for action in F.one_hot(torch.tensor(range(ACTIVE_DIM)),num_classes=ACTIVE_DIM).to(self.device)]
        self.neural = MLP(input_size=STATE_DIM,hidden_size=100,output_size=ACTIVE_DIM)

    def act(self,state,p=0.7): #state[4] -> (one-hot action[2], action index 0 or 1); explores randomly with probability 1-p
        if random.random()<p:
            with torch.no_grad():
                act_prob = F.softmax(self.neural(state.unsqueeze(0)).squeeze(),dim=-1)
            act_index = torch.multinomial(act_prob,num_samples=1).item()
        else:
            act_index = random.choice(range(ACTIVE_DIM))
        return self.action_space[act_index],act_index

    def determine_act(self,state): #state[4] -> (one-hot action[2], action index 0 or 1); greedy action
        with torch.no_grad():
            act_prob = F.softmax(self.neural(state.unsqueeze(0)).squeeze(),dim=-1)
        act_index = torch.argmax(act_prob).item()
        return self.action_space[act_index],act_index

    def forward(self,states): #batch_states[batch,4] -> batch_action_probs[batch,2]
        act_probs = F.softmax(self.neural(states),dim=-1)
        return act_probs
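
For example, sampling one action from a (purely illustrative) random state looks like this:

dummy_agent = Agent(device).to(device)
dummy_state = torch.randn(STATE_DIM).to(device)
action_onehot,action_index = dummy_agent.act(dummy_state,p=0.95)
print(action_onehot,action_index)  # e.g. tensor([0, 1]) 1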
  • Loss function
class PGLoss(nn.Module):
    def __init__(self):
        super(PGLoss,self).__init__()

    def forward(self,acts,preds,values):
        # acts: one-hot taken actions, preds: policy probabilities, values: discounted returns
        loss = -(torch.log(preds)*acts).sum(-1)*values
        return loss.mean()
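
The loss is -log π(a_t|s_t) · G_t averaged over the batch. A tiny numerical check with made-up numbers:

pgloss_check = PGLoss()
acts   = torch.tensor([[1.0,0.0]])      # one-hot of the action actually taken
preds  = torch.tensor([[0.5,0.5]])      # policy probabilities from the network
values = torch.tensor([2.0])            # discounted return for that step
print(pgloss_check(acts,preds,values))  # -log(0.5)*2 ≈ 1.386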
  • Compute values from rewards and the discount factor (V_t = R_t + γR_{t+1} + γ²R_{t+2} + ...)
def reward_to_value(episode,gamma_rate=0.9):
    # episode: list of (action_onehot, state, reward); returns a list of (action_onehot, state, discounted value)
    act_value = []
    for i,s in enumerate(episode):
        gamma=1.0
        value = 0.0
        for j in range(1,len(episode)-i+1):
            value += episode[i+j-1][2]*gamma
            gamma *= gamma_rate
        act_value.append((s[0],s[1],value))
    return act_value
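
A quick trace with toy numbers (note that in the training loop below the third element of each tuple is the accumulated reward up to that step, not the per-step reward; None stands in for the action and state):

toy_episode = [(None,None,1.0),(None,None,2.0),(None,None,3.0)]
for a,s,v in reward_to_value(toy_episode,gamma_rate=0.9):
    print(v)
# 1.0 + 0.9*2.0 + 0.81*3.0 = 5.23
# 2.0 + 0.9*3.0 = 4.7
# 3.0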
  • Training loop
agent = Agent(device).to(device)
env = gym.make("CartPole-v0") 
pgloss = PGLoss().to(device)
optimizer = optim.SGD(agent.parameters(),lr=0.0001)

for epoch in range(EPOCH_NUM):
    samples = []
    for j in tqdm(range(SAMPLE_NUM)):
        state = env.reset()
        add_reward = 0
        episode = []
        for i in (range(ACT_NUM)):
            action_onehot,action_index = agent.act(torch.tensor(state).to(device),p=0.95)
            next_state, reward, done, info = env.step(action_index)
            add_reward += reward #accumulate the reward from every action
            episode.append((deepcopy(action_onehot),deepcopy(state),deepcopy(add_reward)))
            state = deepcopy(next_state)
            if done: break
        samples += reward_to_value(episode)

    # update parameters
    dataloader = DataLoader(samples,batch_size=200,shuffle=True)
    for batch in dataloader:
        actions,states,values = batch
        actions = actions.to(device)
        states = states.to(device)
        values = values.to(device)
        optimizer.zero_grad()
        act_probs = agent(states)
        loss = pgloss(actions,act_probs,values)
        loss.backward()
        optimizer.step()
    torch.save(agent.state_dict(),"./model.bin")
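
To monitor progress between epochs, a small evaluation helper can be useful (this evaluate function is not part of the original article, just a sketch that reuses determine_act):

def evaluate(agent,n_episodes=5):
    # run the greedy policy for a few episodes and return the average episode length
    eval_env = gym.make("CartPole-v0")
    lengths = []
    for _ in range(n_episodes):
        state = eval_env.reset()
        for t in range(ACT_NUM):
            _,action_index = agent.determine_act(torch.tensor(state).to(device))
            state,reward,done,info = eval_env.step(action_index)
            if done: break
        lengths.append(t+1)
    eval_env.close()
    return sum(lengths)/len(lengths)

print(evaluate(agent))  # CartPole-v0 caps episodes at 200 steps, so higher is better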
  • Test
env = gym.make("CartPole-v0") 
agent = Agent(device).to(device)
agent.load_state_dict(torch.load("./model.bin"))

state = env.reset()
for i in (range(ACT_NUM)):
    env.render()
    action_onehot,action_index = agent.determine_act(torch.tensor(state).to(device))
    next_state, reward, done, info = env.step(action_index)
    state = deepcopy(next_state)
    if done: break

Results

Before training

(animation: ezgif-4-e6b744bbf6.gif)

After training

(animation: ezgif-4-5a2bb7e9a8.gif)
