1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ソースコードを使って中央集権型PPO学習を説明してみたい

Posted at

イントロ

最近、マルチエージェント強化学習アルゴリズムの一つである 中央集権型PPO学習 (Centralized PPO Learning) のソースコードを読んでいて、やや難しい部分をだんだん理解してきました。

今回の記事は、自分のノートを兼ねて、実際のソースコードのコアーな部分を抽出して、説明してみたいです。

目次

  • Part 1:学習プロセスの全体像
    • train 関数を通して、PPOエージェントの学習プロセス全体を概観
    • Actor と Critic の役割、データフロー、損失関数の計算など、主要な要素をコードと共に確認
  • Part 2:GAE (Generalized Advantage Estimation) の詳細
    • _compute_returns_advs 関数に注目し、GAE の計算方法をステップごとに解説
    • 割引率 (gamma) や GAEラムダパラメータ (tau) が Advantage 推定に与える影響を理解
  • Part 3:Actor の更新
    • 方策関数(Actor)の更新部分を詳しく見ていく
    • PPO の核心であるクリッピングされた目的関数、KLダイバージェンス、エントロピーボーナスなどの役割を解説
  • Part 4:Critic の更新
    • 価値関数(Critic)の更新部分を解説
    • 価値正規化 (Value Normalization) の有無による処理の違い、損失関数の計算方法などを確認

解説内容

Part 1:学習プロセスの全体像 (train 関数)

def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
	# ... (省略:RMSの更新やマスクの調整など)

    old_action_logits = []
    self.mac.init_hidden(batch.batch_size)
    for t in range(max_t):
        actor_outs = self.mac.forward(batch, t=t, test_mode=False)
        old_action_logits.append(actor_outs)
    old_action_logits = torch.stack(old_action_logits, dim=1)

	old_values_before = self.critic(batch).squeeze(dim=-1).detach()
	if self.is_value_normalized or self.is_popart:
		old_values_before = self.denormalize_value(old_values_before)

	rewards = rewards.squeeze(dim=-1)

    returns, advantages = self._compute_returns_advs(
        old_values_before, rewards, terminated, self.args.gamma, self.args.tau
    )

	if getattr(self.args, "is_advantage_normalized", False):
		# ... (省略:Advantage 正規化)

	old_action_logits = old_action_logits.detach()
	old_meta_data = compute_logp_entropy(old_action_logits, actions, avail_actions)
	old_log_pac = old_meta_data["logp"]
    
	central_old_log_pac = torch.sum(old_log_pac * alive_mask, dim=-1)

	# ... (省略:Actor と Critic の更新処理 - Part 3, Part 4 で解説)

コード解説

1. old_action_logits の計算 (Actor)

  • self.mac.forward(batch, t=t, test_mode=False):マルチエージェントコントローラ (mac) を通して、各エージェントの行動 logits (行動選択前の値) を計算
  • old_action_logits.append(actor_outs):各タイムステップの行動 logits をリストに格納し、torch.stack でテンソルに変換
  • old_action_logits は、更新前の古い方策 に基づく行動 logits を保持

 

2. old_values_before の計算 (Critic)

  • old_values_before = self.critic(batch).squeeze(dim=-1).detach():Critic ネットワーク (self.critic) に batch (状態情報などを含むデータ) を入力し、状態価値を評価します。.squeeze(dim=-1) で次元を調整し、.detach() で計算グラフから切り離す
  • if self.is_value_normalized or self.is_popart::価値正規化 (Value Normalization) または PopArt が使用されている場合、正規化された価値を元のスケールに戻す Denormalization を行う

 

3. GAE による returnsadvantages の計算

  • returns, advantages = self._compute_returns_advs(...)_compute_returns_advs 関数を呼び出し、収益 (returns)アドバンテージ (advantages) を計算
  • GAE は、割引率 (gamma) と GAEラムダパラメータ (tau) によって制御され、時間的割引を考慮した、より安定的な Advantage 推定を実現

 

4. Advantage 正規化 (オプション)

  • if getattr(self.args, "is_advantage_normalized", False)::Advantage 正規化が有効になっている場合、Advantage を正規化

 

5. old_log_paccentral_old_log_pac の計算 (方策確率)

  • old_action_logits = old_action_logits.detach():古い行動 logits を計算グラフから切り離す
  • old_meta_data = compute_logp_entropy(old_action_logits, actions, avail_actions)compute_logp_entropy 関数を用いて、古い方策における行動の対数確率 (logp) とエントロピーを計算
    • compute_logp_entropy 関数は、利用可能な行動 (avail_actions) を考慮し、利用不可能な行動をマスク する処理を含む
  • old_log_pac = old_meta_data["logp"]:各エージェントの行動の対数確率を取得
  • central_old_log_pac = torch.sum(old_log_pac * alive_mask, dim=-1):全てのアクティブなエージェント (alive_mask) の行動の結合対数確率を計算します。alive_mask は、各タイムステップで生存しているエージェントを示すマスク

Part 2:GAE の詳細 (_compute_returns_advs 関数)

    def _compute_returns_advs(self, _values, _rewards, _terminated, gamma, tau):
        returns = torch.zeros_like(_rewards)
        advs = torch.zeros_like(_rewards)
        lastgaelam = torch.zeros_like(_rewards[:, 0])
        ts = _rewards.size(1)

        for t in reversed(range(ts)):
            nextnonterminal = 1.0 - _terminated[:, t]
            nextvalues = _values[:, t + 1]

            reward_t = _rewards[:, t]
            delta = reward_t + gamma * nextvalues * nextnonterminal - _values[:, t]
            advs[:, t] = lastgaelam = delta + gamma * tau * nextnonterminal * lastgaelam

        returns = advs + _values[:, :-1]

        return returns, advs

コード解説

1. 時間反転ループ

  • for t in reversed(range(ts)):タイムステップを逆順に ループします。GAE は、将来の報酬を考慮して現在の Advantage を計算するため、逆順の計算が必要

 

2. 各タイムステップでの計算

  • nextnonterminal = 1.0 - _terminated[:, t]:次の状態が終端状態でない (nextnonterminal = 1.0) か、終端状態 (nextnonterminal = 0.0) かを示すマスクを作成
  • nextvalues = _values[:, t + 1]:次の状態の価値 (nextvalues) を取得
  • reward_t = _rewards[:, t]:現在のタイムステップの報酬 (reward_t) を取得
  • delta = reward_t + gamma * nextvalues * nextnonterminal - _values[:, t]:誤差 (delta) を計算
    • gamma * nextvalues * nextnonterminal:割引率 (gamma) を考慮した、次の状態の価値の期待値です。nextnonterminal を掛けることで、終端状態以降の価値を 0 にする
    • _values[:, t]:現在の状態の価値
    • delta は、期待される価値の変化を表す
  • advs[:, t] = lastgaelam = delta + gamma * tau * nextnonterminal * lastgaelamGAE Advantage (advs[:, t]) を計算
    • lastgaelam は、過去の TD 誤差を割引率 (gamma) と GAEラムダパラメータ (tau) で減衰させながら累積
    • tau は、GAE の時間的視野を調整するパラメータです。tau = 1 の場合は TD(λ) に、tau = 0 の場合は TD(0) に近づく

 

3. 収益 (returns) の計算

  • returns = advs + _values[:, :-1]:収益 (returns) は、Advantage (advs) に現在の状態価値 (_values[:, :-1]) を足し合わせる ことで計算される

Part 3:Actor の更新

	for _ in range(0, self.mini_epochs_actor):
		if self.agent_type == "rnn":
			action_logits = []
			self.mac.init_hidden(batch.batch_size)
			for t in range(max_t):
				actor_outs = self.mac.forward(batch, t=t, test_mode=False)
				action_logits.append(actor_outs)
			action_logits = torch.stack(action_logits, dim=1)

		meta_data = compute_logp_entropy(action_logits, actions, avail_actions)
		log_pac = meta_data["logp"]

		central_log_pac = torch.sum(log_pac * alive_mask, dim=-1)

		witorch.torch.no_grad():
			approxkl = (0.5 * torch.sum((central_log_pac - central_old_log_pac) ** 2) / alive_mask.sum())
			approxkl_lst.append(approxkl)

		entropy = (torch.sum(meta_data["entropy"] * alive_mask) / alive_mask.sum())
		entropy_lst.append(entropy)

		prob_ratio = torch.clamp(torch.exp(central_log_pac - central_old_log_pac), 0.0, 16.0)

		pg_loss_unclipped = -advantages * prob_ratio
		pg_loss_clipped = -advantages * torch.clamp(
			prob_ratio,
			1 - self.args.ppo_policy_clip_param,
			1 + self.args.ppo_policy_clip_param,
		)

		pg_loss = (torch.max(pg_loss_unclipped, pg_loss_clipped).sum() / alive_mask.sum())

		actor_loss = pg_loss - self.args.entropy_loss_coeff * entropy
		actor_loss_lst.append(actor_loss)

		self.optimiser_actor.zero_grad()
		actor_loss.backward()
		self.optimiser_actor.step()

コード解説

1. action_logits の計算 (新しい方策)

  • if self.agent_type == "rnn"::エージェントが RNN を使用する場合の処理
  • self.mac.forward(batch, t=t, test_mode=False):マルチエージェントコントローラ (mac) を通して、更新された新しい方策 に基づく行動 logits (action_logits) を計算

 

2. log_paccentral_log_pac の計算 (新しい方策確率)

  • meta_data = compute_logp_entropy(action_logits, actions, avail_actions)compute_logp_entropy 関数を用いて、新しい方策における行動の対数確率 (logp) とエントロピー を計算
  • central_log_pac = torch.sum(log_pac * alive_mask, dim=-1):新しい方策における結合対数確率 (central_log_pac) を計算

 

3. 近似 KL ダイバージェンスの計算

  • witorch.torch.no_grad()::勾配計算を抑制するコンテキスト内で処理を行う
  • approxkl = (0.5 * torch.sum((central_log_pac - central_old_log_pac) ** 2) / alive_mask.sum())新旧方策の近似 KL ダイバージェンス (approxkl) を計算

 

4. エントロピーボーナスの計算

  • entropy = (torch.sum(meta_data["entropy"] * alive_mask) / alive_mask.sum()):エントロピー (entropy) を計算
  • エントロピーは、方策のランダム性 を表し、探索を促進 するために損失関数に組み込まれる

 

5. 確率比 (prob_ratio) の計算

  • prob_ratio = torch.clamp(torch.exp(central_log_pac - central_old_log_pac), 0.0, 16.0)新旧方策の確率比 (prob_ratio) を計算
  • torch.exp(central_log_pac - central_old_log_pac):対数確率の差を指数関数で戻すことで、確率比を計算
  • torch.clamp(..., 0.0, 16.0):確率比を 0.0 から 16.0 の範囲にクリップします。これは、数値的な安定性を高めるための処理

 

6. クリッピングされた Surrogate 目的関数の計算

  • pg_loss_unclipped = -advantages * prob_ratio:クリッピングなしの方策勾配損失 (pg_loss_unclipped) を計算
  • pg_loss_clipped = -advantages * torch.clamp(...):クリッピングされた方策勾配損失 (pg_loss_clipped) を計算
    • torch.clamp(prob_ratio, 1 - self.args.ppo_policy_clip_param, 1 + self.args.ppo_policy_clip_param):確率比 (prob_ratio) を 1 - clip_param から 1 + clip_param の範囲にクリップします。clip_param は、方策の更新幅を制限するハイパーパラメータ
    • PPO の核心: クリップされた目的関数は、方策が大きく改善する場合のみ更新を許可し、過剰な更新を抑制 することで、学習の安定性を高め
  • pg_loss = (torch.max(pg_loss_unclipped, pg_loss_clipped).sum() / alive_mask.sum()):最終的な方策勾配損失 (pg_loss) は、クリッピングなしの損失とクリッピングされた損失の最大値 を採用

 

7. Actor 損失 (actor_loss) の構築

  • actor_loss = pg_loss - self.args.entropy_loss_coeff * entropy:最終的な Actor 損失 (actor_loss) は、方策勾配損失 (pg_loss) からエントロピーボーナス (- self.args.entropy_loss_coeff * entropy) を引く ことで構築され

Part 4:Critic の更新

	critic_loss_lst = []
	if self.is_value_normalized or self.is_popart:
		returns = self.normalize_value(returns, mask)

	for _ in range(0, self.mini_epochs_critic):
		new_values = self.critic(batch).squeeze()
		new_values = new_values[:, :-1]

		vf_loss = torch.sum((new_values - returns) ** 2 * mask) / mask.sum()

		self.optimiser_critic.zero_grad()
		vf_loss.backward()
		self.optimiser_critic.step()
		critic_loss_lst.append(vf_loss)

コード解説

1. 価値正規化 (Value Normalization) (オプション)

  • if self.is_value_normalized or self.is_popart::価値正規化または PopArt が使用されている場合、収益 (returns) を正規化

 

2. new_values の計算 (新しい Critic)

  • new_values = self.critic(batch).squeeze()更新された新しい Critic (self.critic) に batch を入力し、新しい状態価値 (new_values) を評価
  • new_values = new_values[:, :-1]:最後のタイムステップの価値を除外します。これは、収益 (returns) が _values[:, :-1] を基に計算されているため、形状を合わせるため

 

3. 価値損失 (vf_loss) の計算

  • vf_loss = torch.sum((new_values - returns) ** 2 * mask) / mask.sum():価値損失 (vf_loss) を計算
  • 価値損失は、新しい状態価値 (new_values) と目標収益 (returns) の二乗誤差 です。Critic は、この損失を最小化するように学習される
  • * mask:マスク (mask) を掛けることで、有効な遷移のみを損失計算に含めます。エピソード終了後の遷移などは損失計算から除外される
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?