イントロ
最近、マルチエージェント強化学習アルゴリズムの一つである 中央集権型PPO学習 (Centralized PPO Learning) のソースコードを読んでいて、やや難しい部分をだんだん理解してきました。
今回の記事は、自分のノートを兼ねて、実際のソースコードのコアーな部分を抽出して、説明してみたいです。
目次
-
Part 1:学習プロセスの全体像
-
train
関数を通して、PPOエージェントの学習プロセス全体を概観 - Actor と Critic の役割、データフロー、損失関数の計算など、主要な要素をコードと共に確認
-
-
Part 2:GAE (Generalized Advantage Estimation) の詳細
-
_compute_returns_advs
関数に注目し、GAE の計算方法をステップごとに解説 - 割引率 (
gamma
) や GAEラムダパラメータ (tau
) が Advantage 推定に与える影響を理解
-
-
Part 3:Actor の更新
- 方策関数(Actor)の更新部分を詳しく見ていく
- PPO の核心であるクリッピングされた目的関数、KLダイバージェンス、エントロピーボーナスなどの役割を解説
-
Part 4:Critic の更新
- 価値関数(Critic)の更新部分を解説
- 価値正規化 (Value Normalization) の有無による処理の違い、損失関数の計算方法などを確認
解説内容
Part 1:学習プロセスの全体像 (train
関数)
def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
# ... (省略:RMSの更新やマスクの調整など)
old_action_logits = []
self.mac.init_hidden(batch.batch_size)
for t in range(max_t):
actor_outs = self.mac.forward(batch, t=t, test_mode=False)
old_action_logits.append(actor_outs)
old_action_logits = torch.stack(old_action_logits, dim=1)
old_values_before = self.critic(batch).squeeze(dim=-1).detach()
if self.is_value_normalized or self.is_popart:
old_values_before = self.denormalize_value(old_values_before)
rewards = rewards.squeeze(dim=-1)
returns, advantages = self._compute_returns_advs(
old_values_before, rewards, terminated, self.args.gamma, self.args.tau
)
if getattr(self.args, "is_advantage_normalized", False):
# ... (省略:Advantage 正規化)
old_action_logits = old_action_logits.detach()
old_meta_data = compute_logp_entropy(old_action_logits, actions, avail_actions)
old_log_pac = old_meta_data["logp"]
central_old_log_pac = torch.sum(old_log_pac * alive_mask, dim=-1)
# ... (省略:Actor と Critic の更新処理 - Part 3, Part 4 で解説)
コード解説
1. old_action_logits
の計算 (Actor)
-
self.mac.forward(batch, t=t, test_mode=False)
:マルチエージェントコントローラ (mac
) を通して、各エージェントの行動 logits (行動選択前の値) を計算 -
old_action_logits.append(actor_outs)
:各タイムステップの行動 logits をリストに格納し、torch.stack
でテンソルに変換 -
old_action_logits
は、更新前の古い方策 に基づく行動 logits を保持
2. old_values_before
の計算 (Critic)
-
old_values_before = self.critic(batch).squeeze(dim=-1).detach()
:Critic ネットワーク (self.critic
) にbatch
(状態情報などを含むデータ) を入力し、状態価値を評価します。.squeeze(dim=-1)
で次元を調整し、.detach()
で計算グラフから切り離す -
if self.is_value_normalized or self.is_popart:
:価値正規化 (Value Normalization) または PopArt が使用されている場合、正規化された価値を元のスケールに戻す Denormalization を行う
3. GAE による returns
と advantages
の計算
-
returns, advantages = self._compute_returns_advs(...)
:_compute_returns_advs
関数を呼び出し、収益 (returns
) と アドバンテージ (advantages
) を計算 - GAE は、割引率 (
gamma
) と GAEラムダパラメータ (tau
) によって制御され、時間的割引を考慮した、より安定的な Advantage 推定を実現
4. Advantage 正規化 (オプション)
-
if getattr(self.args, "is_advantage_normalized", False):
:Advantage 正規化が有効になっている場合、Advantage を正規化
5. old_log_pac
と central_old_log_pac
の計算 (方策確率)
-
old_action_logits = old_action_logits.detach()
:古い行動 logits を計算グラフから切り離す -
old_meta_data = compute_logp_entropy(old_action_logits, actions, avail_actions)
:compute_logp_entropy
関数を用いて、古い方策における行動の対数確率 (logp
) とエントロピーを計算-
compute_logp_entropy
関数は、利用可能な行動 (avail_actions
) を考慮し、利用不可能な行動をマスク する処理を含む
-
-
old_log_pac = old_meta_data["logp"]
:各エージェントの行動の対数確率を取得 -
central_old_log_pac = torch.sum(old_log_pac * alive_mask, dim=-1)
:全てのアクティブなエージェント (alive_mask
) の行動の結合対数確率を計算します。alive_mask
は、各タイムステップで生存しているエージェントを示すマスク
Part 2:GAE の詳細 (_compute_returns_advs
関数)
def _compute_returns_advs(self, _values, _rewards, _terminated, gamma, tau):
returns = torch.zeros_like(_rewards)
advs = torch.zeros_like(_rewards)
lastgaelam = torch.zeros_like(_rewards[:, 0])
ts = _rewards.size(1)
for t in reversed(range(ts)):
nextnonterminal = 1.0 - _terminated[:, t]
nextvalues = _values[:, t + 1]
reward_t = _rewards[:, t]
delta = reward_t + gamma * nextvalues * nextnonterminal - _values[:, t]
advs[:, t] = lastgaelam = delta + gamma * tau * nextnonterminal * lastgaelam
returns = advs + _values[:, :-1]
return returns, advs
コード解説
1. 時間反転ループ
-
for t in reversed(range(ts)):
:タイムステップを逆順に ループします。GAE は、将来の報酬を考慮して現在の Advantage を計算するため、逆順の計算が必要
2. 各タイムステップでの計算
-
nextnonterminal = 1.0 - _terminated[:, t]
:次の状態が終端状態でない (nextnonterminal = 1.0
) か、終端状態 (nextnonterminal = 0.0
) かを示すマスクを作成 -
nextvalues = _values[:, t + 1]
:次の状態の価値 (nextvalues
) を取得 -
reward_t = _rewards[:, t]
:現在のタイムステップの報酬 (reward_t
) を取得 -
delta = reward_t + gamma * nextvalues * nextnonterminal - _values[:, t]
:誤差 (delta
) を計算-
gamma * nextvalues * nextnonterminal
:割引率 (gamma
) を考慮した、次の状態の価値の期待値です。nextnonterminal
を掛けることで、終端状態以降の価値を 0 にする -
_values[:, t]
:現在の状態の価値 -
delta
は、期待される価値の変化を表す
-
-
advs[:, t] = lastgaelam = delta + gamma * tau * nextnonterminal * lastgaelam
:GAE Advantage (advs[:, t]
) を計算-
lastgaelam
は、過去の TD 誤差を割引率 (gamma
) と GAEラムダパラメータ (tau
) で減衰させながら累積 -
tau
は、GAE の時間的視野を調整するパラメータです。tau = 1
の場合は TD(λ) に、tau = 0
の場合は TD(0) に近づく
-
3. 収益 (returns
) の計算
-
returns = advs + _values[:, :-1]
:収益 (returns
) は、Advantage (advs
) に現在の状態価値 (_values[:, :-1]
) を足し合わせる ことで計算される
Part 3:Actor の更新
for _ in range(0, self.mini_epochs_actor):
if self.agent_type == "rnn":
action_logits = []
self.mac.init_hidden(batch.batch_size)
for t in range(max_t):
actor_outs = self.mac.forward(batch, t=t, test_mode=False)
action_logits.append(actor_outs)
action_logits = torch.stack(action_logits, dim=1)
meta_data = compute_logp_entropy(action_logits, actions, avail_actions)
log_pac = meta_data["logp"]
central_log_pac = torch.sum(log_pac * alive_mask, dim=-1)
witorch.torch.no_grad():
approxkl = (0.5 * torch.sum((central_log_pac - central_old_log_pac) ** 2) / alive_mask.sum())
approxkl_lst.append(approxkl)
entropy = (torch.sum(meta_data["entropy"] * alive_mask) / alive_mask.sum())
entropy_lst.append(entropy)
prob_ratio = torch.clamp(torch.exp(central_log_pac - central_old_log_pac), 0.0, 16.0)
pg_loss_unclipped = -advantages * prob_ratio
pg_loss_clipped = -advantages * torch.clamp(
prob_ratio,
1 - self.args.ppo_policy_clip_param,
1 + self.args.ppo_policy_clip_param,
)
pg_loss = (torch.max(pg_loss_unclipped, pg_loss_clipped).sum() / alive_mask.sum())
actor_loss = pg_loss - self.args.entropy_loss_coeff * entropy
actor_loss_lst.append(actor_loss)
self.optimiser_actor.zero_grad()
actor_loss.backward()
self.optimiser_actor.step()
コード解説
1. action_logits
の計算 (新しい方策)
-
if self.agent_type == "rnn":
:エージェントが RNN を使用する場合の処理 -
self.mac.forward(batch, t=t, test_mode=False)
:マルチエージェントコントローラ (mac
) を通して、更新された新しい方策 に基づく行動 logits (action_logits
) を計算
2. log_pac
と central_log_pac
の計算 (新しい方策確率)
-
meta_data = compute_logp_entropy(action_logits, actions, avail_actions)
:compute_logp_entropy
関数を用いて、新しい方策における行動の対数確率 (logp
) とエントロピー を計算 -
central_log_pac = torch.sum(log_pac * alive_mask, dim=-1)
:新しい方策における結合対数確率 (central_log_pac
) を計算
3. 近似 KL ダイバージェンスの計算
-
witorch.torch.no_grad():
:勾配計算を抑制するコンテキスト内で処理を行う -
approxkl = (0.5 * torch.sum((central_log_pac - central_old_log_pac) ** 2) / alive_mask.sum())
:新旧方策の近似 KL ダイバージェンス (approxkl
) を計算
4. エントロピーボーナスの計算
-
entropy = (torch.sum(meta_data["entropy"] * alive_mask) / alive_mask.sum())
:エントロピー (entropy
) を計算 - エントロピーは、方策のランダム性 を表し、探索を促進 するために損失関数に組み込まれる
5. 確率比 (prob_ratio
) の計算
-
prob_ratio = torch.clamp(torch.exp(central_log_pac - central_old_log_pac), 0.0, 16.0)
:新旧方策の確率比 (prob_ratio
) を計算 -
torch.exp(central_log_pac - central_old_log_pac)
:対数確率の差を指数関数で戻すことで、確率比を計算 -
torch.clamp(..., 0.0, 16.0)
:確率比を0.0
から16.0
の範囲にクリップします。これは、数値的な安定性を高めるための処理
6. クリッピングされた Surrogate 目的関数の計算
-
pg_loss_unclipped = -advantages * prob_ratio
:クリッピングなしの方策勾配損失 (pg_loss_unclipped
) を計算 -
pg_loss_clipped = -advantages * torch.clamp(...)
:クリッピングされた方策勾配損失 (pg_loss_clipped
) を計算-
torch.clamp(prob_ratio, 1 - self.args.ppo_policy_clip_param, 1 + self.args.ppo_policy_clip_param)
:確率比 (prob_ratio
) を1 - clip_param
から1 + clip_param
の範囲にクリップします。clip_param
は、方策の更新幅を制限するハイパーパラメータ - PPO の核心: クリップされた目的関数は、方策が大きく改善する場合のみ更新を許可し、過剰な更新を抑制 することで、学習の安定性を高め
-
-
pg_loss = (torch.max(pg_loss_unclipped, pg_loss_clipped).sum() / alive_mask.sum())
:最終的な方策勾配損失 (pg_loss
) は、クリッピングなしの損失とクリッピングされた損失の最大値 を採用
7. Actor 損失 (actor_loss
) の構築
-
actor_loss = pg_loss - self.args.entropy_loss_coeff * entropy
:最終的な Actor 損失 (actor_loss
) は、方策勾配損失 (pg_loss
) からエントロピーボーナス (- self.args.entropy_loss_coeff * entropy
) を引く ことで構築され
Part 4:Critic の更新
critic_loss_lst = []
if self.is_value_normalized or self.is_popart:
returns = self.normalize_value(returns, mask)
for _ in range(0, self.mini_epochs_critic):
new_values = self.critic(batch).squeeze()
new_values = new_values[:, :-1]
vf_loss = torch.sum((new_values - returns) ** 2 * mask) / mask.sum()
self.optimiser_critic.zero_grad()
vf_loss.backward()
self.optimiser_critic.step()
critic_loss_lst.append(vf_loss)
コード解説
1. 価値正規化 (Value Normalization) (オプション)
-
if self.is_value_normalized or self.is_popart:
:価値正規化または PopArt が使用されている場合、収益 (returns
) を正規化
2. new_values
の計算 (新しい Critic)
-
new_values = self.critic(batch).squeeze()
:更新された新しい Critic (self.critic
) にbatch
を入力し、新しい状態価値 (new_values
) を評価 -
new_values = new_values[:, :-1]
:最後のタイムステップの価値を除外します。これは、収益 (returns
) が_values[:, :-1]
を基に計算されているため、形状を合わせるため
3. 価値損失 (vf_loss
) の計算
-
vf_loss = torch.sum((new_values - returns) ** 2 * mask) / mask.sum()
:価値損失 (vf_loss
) を計算 - 価値損失は、新しい状態価値 (
new_values
) と目標収益 (returns
) の二乗誤差 です。Critic は、この損失を最小化するように学習される -
* mask
:マスク (mask
) を掛けることで、有効な遷移のみを損失計算に含めます。エピソード終了後の遷移などは損失計算から除外される