1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

AI・機械学習関連論文Advent Calendar 2024

Day 25

論文紹介: Discovering Preference Optimization Algorithms with and for Large Language Models ~その 2 付録編~

Last updated at Posted at 2024-12-24

はじめに

こんにちは、ひよこです。
今回は、前回紹介した「Discovering Preference Optimization Algorithms with and for Large Language Models」その 2 です。前回は LLM が見つけた DiscoPOP の性能に関する内容でしたが、今回は付録の内容をセクション A ~ G に分けて詳細に解説します。この付録部分には、こちらの方が論文の核心といえるのではないか、という LLM による損失関数生成の実装、評価、発見された関数の特性 が詳述されています。

参考文献:
C. Lu, S. Holt, C. Fanconi, A. J. Chan, J. Foerster, M. van der Schaar, & R. T. Lange. "Discovering Preference Optimization Algorithms with and for Large Language Models", NeurIPS 2024.

DiscoPOP 紹介 その 1

🐣 DiscoPOP って LLM が見つけた損失関数の名前で LLM で発見する提案手法の名前じゃないんですよね


A. LLM 駆動型目的関数探索の実装詳細

プロンプト設計

LLM を使用して目的関数を探索する際の実装詳細について説明します。システムプロンプトは、機械学習研究者として以下のような JSON 形式での出力を要求します。
まずは、LLM へ以下のような指示が出されています。(原論文では英語です) また原論文ではプロンプトは英語ですがこれ以降理解に重要な thought の部分は日本語に訳しています。

あなたは、異なる RLHF 損失関数をテストしている機械学習研究者です。応答する際には、次の形式で JSON を出力してください。最初のキー ("thought") は、次の関数を設計する際の思考プロセスに対応します。2 番目のキー ("name") は次の関数の名前に対応します。最後のキー ("code") は、試したい正確な Python コードに対応します。以下に例を示します:

{
  "thought": "過去の出力を基に、直接選好最適化アルゴリズムを試すべき",
  "name": "dpo",
  "code": "def sigmoid_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
  ) -> torch.FloatTensor:
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    losses = -F.logsigmoid(self.beta * logits)
    return losses"
}

プロンプトでは、以下の点を明確に指示します

  • バイナリ分類の損失関数に関する深い知識の活用を要求
  • 関数インターフェースの厳密な指定
  • self.beta = 0.05 の固定値の使用
  • 選好ランク付きのデータセットでの訓練を想定

LLM からの応答を処理する際は以下のステップを踏みます

  1. 関数の有効性を確認するユニットテスト実行
  2. テスト失敗時は、エラーメッセージを LLM にフィードバック
  3. テスト成功時は、性能評価値を返却し次の提案を要求

このプロセスを通じて、LLM は既存の知識を活用しながら、新しい目的関数を段階的に改良していきます。エラー処理や性能フィードバックを含めた反復的な対話により、実用的な損失関数の発見を目指します。
ここで一番興味があるのは LLM がどのように思考し、優れた損失関数の発見に至ったかという過程でしょう。
そこで、DiscoPOP を発見したときの thought の変化を以下に示します。提案システムは上の初期プロンプトから 6 回の思考を経て DiscoPOP を発見しています。以下に、DiscoPOP までの**thought と損失関数の式**をまとめてみました。

DiscoPOP 発見までの過程

🐣 元の thought はすべて英語なのでご注意下さい

0. 初期の損失関数 (Direct Preference Optimization: DPO)

Thought:
"過去の出力を基に、直接選好最適化アルゴリズムを試すべき"
:

f_{\mathrm{DPO}}(\beta\rho) = -\log(\sigma(\beta\rho)) = \log(1 + \exp(-\beta\rho))

ポイント: Logistic loss を用いたシンプルな直接最適化。


1. Logistic Margin Loss

Thought:
"Logistic log loss を元に、SVM のヒンジ損失から着想を得た margin を導入することで、よりロバストな学習を実現できる可能性があります"
:

f_{\mathrm{LML}}(\beta\rho) = -\log(\sigma(\beta(\rho - m)))

ポイント: Logistic loss に静的なマージン $m$ を導入することで、ロバスト性を向上。


2. Adaptive Margin Logistic Loss

Thought:
"固定 margin をやめ、信頼度に応じた adaptive margin を導入することで、より柔軟なペナルティ設計が可能かもしれません"
:

f_{\mathrm{AML}}(\beta\rho) = -\log(\sigma(\beta(\rho - m(\rho))))

ポイント: マージン $m(\rho)$ を信頼度に応じて動的に設定。


3. Combined Logistic Hinge Loss

Thought:
"Logistic loss の確率的な性質とヒンジ損失のマージン特性を組み合わせることで、双方のメリットを引き出すアプローチを試みます"
:

f_{\mathrm{CLHL}}(\beta\rho) = -\log(\sigma(\beta\rho)) + \max(0, 1 - \beta\rho)

ポイント: Logistic loss と Hinge loss を加算し、それぞれの特性を活用。


4. Scaled Combined Logistic Hinge Loss

Thought:
"Logistic と Hinge のバランスを調整するためにスケーリング係数を導入し、それぞれの寄与を調整します"
:

f_{\mathrm{SCLHL}}(\beta\rho) = \alpha(-\log(\sigma(\beta\rho))) + (1-\alpha)\max(0, 1 - \beta\rho)

ポイント: Logistic loss と Hinge loss のスケーリングを調整することでバランスを最適化。$\alpha$ はスケール係数。


5. Logistic Exponential Hybrid Loss

Thought:
"Logistic loss に Exponential loss を加えることで、過剰な信頼を防ぎながらも極値での挙動を調整するハイブリッドアプローチを試します"
:

f_{\mathrm{LEHL}}(\beta\rho) = 0.5(-\log(\sigma(\beta\rho))) + 0.5\exp(-\beta\rho)

ポイント: Logistic loss と Exponential loss を同程度の重みでハイブリッド化。


6. Log Ratio Modulated Loss (DiscoPOP)

Thought:
"ロジスティック損失と指数損失をロジットの大きさに基づいて動的に切り替えることで、柔軟な損失設計を目指します"
:

f_{\mathrm{LRML}}(\beta\rho) = (1 - \sigma(\beta\rho/\tau))(-\log(\sigma(\beta\rho))) + \sigma(\beta\rho/\tau)\exp(-\beta\rho)

ポイント: ロジットのスケール $\sigma(\beta\rho/\tau)$ を基に、Logistic loss と Exponential loss を動的に切り替え。

🐣 これが論文本体で示していた DiscoPOP です!


B. 学習の詳細

項目 シングルターンダイアログ TL;DR 要約 IMDb 肯定的テキスト生成
ベースモデル zephyr-7b-gemma-sft zephyr-7b-gemma-sft GPT-2(IMDb で事前学習済)
データセット Argilla DPO Mix 7K Reddit TL;DR の 10% (約 8K サンプル) IMDb センチメント
$\beta$ 固定値 0.05 固定値 0.05 {0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5}
学習設定
- エポック数 2 2 3
- 学習率 5e-7 5e-7 5.0e-5
- バッチサイズ/デバイス 2 2 2
- 勾配累積ステップ 8 16 8
- 浮動小数点形式 bfloat16 bfloat16 bfloat16
- 最適化手法 AdamW + コサイン学習率 AdamW + コサイン学習率 AdamW + コサイン学習率
計算資源
- GPU A100 × 8 A100 × 4 A100 × 4
- 実行時間/回 約 30 分 約 1.5 時間 約 30 分

IMDb タスクでは広範な $\beta$ 値が探索される一方、他のタスクは固定値 0.05 を採用しており、タスクごとに異なる最適化戦略が見られます。また、TL;DR 要約では勾配累積ステップを 16 に設定し、長いテキスト処理に対応している点が特徴的です。


C. 評価指標

評価指標 概要 評価基準 評価手法
MT-Bench マルチターンの対話生成能力測定 0 ~ 10 のスコア GPT-4 による評価
Alpaca Eval 2.0 単一ターン対話の性能測定 勝率(通常・長さ制御) GPT-4 Turbo による評価
TL;DR Win-Rate 要約の正確性評価 カスタム評価データセット(694 サンプル) カスタム GPT-4 テンプレート
IMDb Rewards 損失関数のバランス評価 事前学習済み感情分類器スコア vs KL 発散値 生成テキスト 10 回の平均

MT-Bench では GPT-4 が対話生成能力を 0-10 でスコア化する一方、Alpaca Eval 2.0 では勝率を基準に評価し、通常と長さ制御の結果が比較されます。また、TL;DR の要約タスクでは GPT-4 用のカスタムテンプレートを使用して評価を特化させており、IMDb タスクでは感情分類器スコアと KL 発散値のバランスを計測している点が特徴です。

🐣 この辺はタスクごとの作り込みが必要なのかもしれません


D. 追加実験結果

D.1 Frontiers of Reward vs KL-Divergence

  • 図示: 報酬と KL 発散のトレードオフを視覚化。
  • DiscoPOP が報酬と KL-Divergence の両面で最適化されていることを示唆。

D.2 Loss Sweeps for Different β Parameters

  • 観察: $\beta$ が小さい場合はロジスティック損失、大きい場合は指数損失が支配。
  • DiscoPOP は $\beta \in [0.025, 0.1]$ の範囲で最良の性能。

D.3 Discovery Robustness

  • 結果: 提案手法は LLM のサンプリング温度やプロンプト構造に対してロバスト性がある。
実験 観察結果 知見
報酬-KL発散のトレードオフ DiscoPOP が最適なバランスを達成 報酬と汎化性能の両立を実現
$\beta$ パラメータの影響 $\beta \in [0.025, 0.1]$ で最良の性能 小:ロジスティック損失が支配的
大:指数損失が支配的
ロバスト性評価 サンプリング温度やプロンプト構造に対して安定 手法の信頼性を確認

E. 発見された目的関数群

提案手法で LLM が発見した主要な損失関数を示します。

  1. DBAQL (Dynamic Blended Adaptive Quantile Loss)
f_{\mathrm{DBAQL}}(\beta\rho) = \sigma(\mathrm{Var}[\beta\rho/\tau]) \cdot f_{\mathrm{DPO}}(\beta\rho/0.9) + (1-\sigma(\mathrm{Var}[\beta\rho/\tau])) \cdot f_{\mathrm{exp}}(\beta\rho \cdot 0.9)

🐣 分散に基づき DPO と指数損失を動的にブレンドしてます

  1. PADLL (Performance Adaptive Decay Logistic Loss)
f_{\mathrm{PADLL}}(\beta\rho) = 0.9 \cdot (1 - 0.5 \cdot 1_{[\rho < 0]}) \cdot f_{\mathrm{DPO}}(\beta\rho)

🐣 性能に応じて重みを減衰させる DPO の拡張ですね

  1. AQFL (Adaptive Quantile Feedback Loss)
f_{\mathrm{AQFL}}(\beta\rho) = r \cdot f_{\mathrm{DPO}}(\beta\rho) + (1-r) \cdot f_{\mathrm{SLIC}}(\beta\rho)

🐣 DPO と SLiC を適応的に組み合わせてますね

  1. LRML (DiscoPOP)
f_{\mathrm{LRML}}(\beta\rho) = (1-\sigma(\beta\rho/\tau)) \cdot f_{\mathrm{DPO}}(\beta\rho) + \sigma(\beta\rho/\tau) \cdot f_{\mathrm{exp}}(\beta\rho)

🐣 ロジスティック損失と指数損失を対数比に基づいて動的に制御しています


F. 関連研究と考察

LLM を活用した最適化研究の全体像

LLM(大規模言語モデル)は、コード生成やアーキテクチャ探索、報酬関数設計など、多様な分野でその応用が進んでいます。本研究はその中でも、損失関数の生成に特化している点が特徴で、従来の最適化手法を補完する新しいアプローチを提案しています。

自動機械学習との違い

他の自動機械学習(AutoML)アプローチでは、遺伝的アルゴリズムやドメイン特化言語(DSL)、ニューラルネットによるパラメタライズ、シンボリック回帰などが主流です。一方、本研究は LLM を用いて損失関数そのものを生成するという直接的なアプローチを取っており、これが本研究の独自性と言えます。

補足: LLM による損失関数生成の課題と考察

LLM が損失関数を生成できる理由やその限界について、以下の疑問が挙げられます

  • 数学的理解の限界: LLM が損失関数の構造を本質的に「理解」しているのか、あるいは学習データ中のパターンを再現しているに過ぎないのか
  • 既存研究の影響: 生成された関数が既存の損失関数とどの程度類似しているか、また学習データ中の論文や数理最適化に関する資料からどの程度影響を受けているか
  • 汎用性とスケーラビリティ: 特定タスクに過度に最適化されていないか、異なるドメインや大規模モデルへの適用が可能か

これらの疑問を解消することで、LLM による損失関数生成の能力と限界がより深く理解されるでしょう。特に、LLM が汎用的な能力を示す場合、損失関数生成だけでなく、例えば NAS(Neural Architecture Search)といった構造最適化への応用も考えられます。その他にも本研究の手法は以下の方向に発展する可能性があります。

  • 変分法: 理論的に最適な関数形状を探索する新しいアプローチ
  • 遺伝的プログラミング: 損失関数を進化的に探索する柔軟な手法

これらを組み込むことで、より高度で汎用的な損失関数生成や、抽象的な概念の最適化が実現できるかもしれません。


G. 実行ログ

G.1 実験結果の完全ログ

非常に長いので興味のある部分だけ確認してみて下さい。原論文ではすべて英語ですが、理解に重要な Thought の部分は日本語に訳してあります。

1. Logistic Margin Loss

Thought
"Logistic log loss を元に、SVM のヒンジ損失から着想を得た margin を導入することで、よりロバストな学習を実現できる可能性があります"

Code

def logistic_margin_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> torch.FloatTensor:
    margin = 0.1
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits_with_margin = pi_logratios - ref_logratios - margin
    losses = -F.logsigmoid(self.beta * logits_with_margin)
    return losses

Fitness
7.709375

2. Adaptive Margin Logistic Loss

Thought
"固定 margin をやめ、信頼度に応じた adaptive margin を導入することで、より柔軟なペナルティ設計が可能かもしれません"

Code

def adaptive_margin_logistic_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> torch.FloatTensor:
    confidence_margin = torch.abs(policy_chosen_logps - policy_rejected_logps) - \
                        torch.abs(reference_chosen_logps - reference_rejected_logps)
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits_with_margin = pi_logratios - ref_logratios - confidence_margin
    losses = -F.logsigmoid(self.beta * logits_with_margin)
    return losses

Fitness
7.51875

3. Combined Logistic Hinge Loss

Thought
"Logistic loss の確率的な性質とヒンジ損失のマージン特性を組み合わせることで、双方のメリットを引き出すアプローチを試みます"

Code

def combined_logistic_hinge_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> torch.FloatTensor:
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    logistic_losses = -F.logsigmoid(self.beta * logits)
    hinge_losses = torch.relu(1 - self.beta * logits)
    combined_losses = logistic_losses + hinge_losses
    return combined_losses

Fitness
7.7375

4. Scaled Combined Logistic Hinge Loss

Thought
"Logistic と Hinge のバランスを調整するためにスケーリング係数を導入し、それぞれの寄与を調整します"

Code

def scaled_combined_logistic_hinge_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> torch.FloatTensor:
    logistic_scale = 2.0  # Logistic loss を強調
    hinge_scale = 0.5  # Hinge loss を抑える
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    logistic_losses = logistic_scale * (-F.logsigmoid(self.beta * logits))
    hinge_losses = hinge_scale * (torch.relu(1 - self.beta * logits))
    combined_losses = logistic_losses + hinge_losses
    return combined_losses

Fitness
7.85625

5. Logistic Exponential Hybrid Loss

Thought
"Logistic loss に Exponential loss を加えることで、過剰な信頼を防ぎながらも極値での挙動を調整するハイブリッドアプローチを試します"

Code

def logistic_exponential_hybrid_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> torch.FloatTensor:
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    logistic_losses = -F.logsigmoid(self.beta * logits)
    exp_losses = torch.exp(-self.beta * logits)
    combined_losses = 0.5 * logistic_losses + 0.5 * exp_losses
    return combined_losses

Fitness
7.925

6. Log Ratio Modulated Loss (DiscoPOP)

Thought
"ロジスティック損失と指数損失をロジットの大きさに基づいて動的に切り替えることで、柔軟な損失設計を目指します"

Code

def log_ratio_modulated_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> torch.FloatTensor:
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    modulation_factor = torch.sigmoid(logits / 0.05)
    logistic_component = -F.logsigmoid(self.beta * logits)
    exp_component = torch.exp(-self.beta * logits)
    losses = (1 - modulation_factor) * logistic_component + modulation_factor * exp_component
    return losses

Fitness
7.916

7. Performance Adaptive Decay Logistic Loss

Thought
"不正確な予測にはペナルティを加え、正確な予測にはロジスティック損失を維持するような適応的な減衰を試します"

Code

def performance_adaptive_decay_logistic_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> torch.FloatTensor:
    base_decay = 0.9
    penalty_factor = 0.5
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    mismatches = (logits < 0).float()
    decay_factor = base_decay * (1 - mismatches * penalty_factor)
    losses = decay_factor * -F.logsigmoid(self.beta * logits)
    return losses

Fitness
7.941

8. Adaptive Quantile Feedback Loss

Thought
"損失関数に適応的な分位点を導入し、データの分散を考慮した損失の調整を行います"

Code

def adaptive_quantile_feedback_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> torch.FloatTensor:
    quantile_update_rate = 0.05
    distance_scale = 0.1
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    logits_std = logits.std()
    adaptive_quantile = logits_std * torch.sigmoid(-logits).mean()
    adaptive_quantile += quantile_update_rate * (torch.sigmoid(logits.mean()) - adaptive_quantile)
    distance_from_quantile = (logits - adaptive_quantile).abs()
    blend_rate = torch.sigmoid(distance_scale * distance_from_quantile)
    logistic_losses = -F.logsigmoid(self.beta * logits)
    hinge_losses = torch.relu(1 - self.beta * logits)
    losses = blend_rate * logistic_losses + (1 - blend_rate) * hinge_losses
    return losses

Fitness
7.931

9. Dynamic Blended Adaptive Quantile Loss

Thought
"ロジスティック損失と指数損失を動的にブレンドし、ログ確率比の分散に基づいて適応的な量的制御を試します"

Code

def dynamic_blended_adaptive_quantile_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> torch.FloatTensor:
    starting_quantile = 0.5
    quantile_adapt_rate = 0.01
    temperature = 0.9
    dynamic_blend_rate = 1.0
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    logits_variability = logits.var()
    moving_quantile = starting_quantile + quantile_adapt_rate * (torch.sigmoid(logits.mean()) - starting_quantile)
    dynamic_blend_coeff = torch.sigmoid(logits_variability) * dynamic_blend_rate
    logistic_loss = -F.logsigmoid(self.beta * logits / temperature)
    exp_loss = torch.exp(-self.beta * logits * temperature)
    losses = dynamic_blend_coeff * logistic_loss + (1 - dynamic_blend_coeff) * exp_loss
    return losses

Fitness
7.978

10. Policy Focused Loss

Thought
"正確な予測にはペナルティを減少させ、誤った予測にはより強いヒンジ損失を適用することで、ポリシーの焦点を強調します"

Code

def policy_focused_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> torch.FloatTensor:
    focus_scale = 2.0  # 正確な予測の強調度
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    is_correct = policy_chosen_logps > policy_rejected_logps
    logistic_losses = -F.logsigmoid(self.beta * logits)
    hinge_losses = torch.relu(1 - logits)
    focused_loss = torch.where(
        is_correct,
        logistic_losses / focus_scale,  # 正確な予測のペナルティを軽減
        hinge_losses * focus_scale  # 誤った予測により強いペナルティ
    )
    return focused_loss

Fitness
7.900

11. Annealed Logistic Hinge Loss

Thought
"時間経過に伴い損失関数をスケールさせるアニーリングを適用し、初期はロジスティック成分を強調し、後半はヒンジ成分を優先するアプローチを試します"

Code

def annealed_logistic_hinge_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> torch.FloatTensor:
    annealing_factor = torch.sigmoid(self.beta)  # 時間経過に応じたアニーリング
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    logistic_losses = annealing_factor * (-F.logsigmoid(self.beta * logits))
    hinge_losses = (1 - annealing_factor) * torch.relu(1 - self.beta * logits)
    losses = logistic_losses + hinge_losses
    return losses

Fitness
7.850


まとめ

今回は、DiscoPOP の論文の付録部分について示しました。この部分が付録ということはまだまだ粗削りで改善の余地がありということかもしれませんが、「LLM がアルゴリズムや数式を独自に発見するフレームワーク」は、今後の AI 研究に大きな影響を与えそうですね。是非とも動かして雰囲気だけでも体感してみたいと思います。ではまた次の記事でお会いしましょう。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?