2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LLMのファインチューニングに使うロス関数を自作してみた

Posted at

はじめに

LLMのファインチューニング時に、デフォルトで提供される損失関数(例えば、CrossEntropyLoss)を使っており、「もっとLLMの性能を直接的に測る指標を損失関数に組み込めないか?」と思い、LLM as a judgeの結果をもとにロス計算するとより簡単にLLMをファインチューニングできるのではないかと思いついたので、やってみた。

損失関数のカスタマイズ

デフォルトの損失関数(例えば、CrossEntropyLoss)は、トークンごとの予測の正確さを測るには優れています。しかし、最終的な目的である「人間にとって自然で質の高いテキストの生成」を直接的に最適化しているわけではありません。そこで、以下のようなメリットを期待してカスタム損失関数を導入しました。

  • 人間が判断する品質を損失関数に反映: GPT-4のような強力なモデルを評価者とすることで、人間が感じるテキストの品質をより直接的に損失関数に取り込む
  • ファインチューニングの効率化: 最終的なゴールに直結する指標を損失関数に用いることで、学習の効率が向上する可能性
from transformers import TrainerCallback
from transformers import EarlyStoppingCallback
from transformers import EvalPrediction
from typing import Dict
import openai

client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def custom_compute_metrics(res: EvalPrediction) -> Dict:
    # res.predictionsとres.label_idsはnumpyのarray
    pred = res.predictions.argmax(axis=2)
    target = res.label_ids  # inputsの代わりにlabel_idsを使用
    
    if target is None:
        print("警告: label_idsが空です。評価を実行できません。")
        return {'gpt4_score': None}
    
    # バッチごとに処理
    valid_preds = []
    valid_targets = []
    for batch_pred, batch_target in zip(pred, target):
        # labelsが-100でない部分のみを抽出
        valid_indices = batch_target != -100
        valid_preds.append(batch_pred[valid_indices])
        valid_targets.append(batch_target[valid_indices])
    
    # print("有効な予測:", valid_preds)
    # print("有効なターゲット:", valid_targets)
    # print("有効な予測の形状:", [p.shape for p in valid_preds])
    # print("有効なターゲットの形状:", [t.shape for t in valid_targets])
    
    # GPT-4を使用して評価を行う
    gpt4_score = gpt4_evaluate(valid_preds, valid_targets, tokenizer)
    
    return {
        'gpt4_eval_score': gpt4_score
    }

@weave.op()
def gpt4_evaluate(predictions, labels, tokenizer):
    # GPT-4を使用して評価を行う関数
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # print(decoded_preds)
    # print(decoded_labels)
    
    results = []
    for pred, label in zip(decoded_preds, decoded_labels):
        prompt = f"""
        生成されたテキスト: {pred}
        正解のテキスト: {label}

        上記の生成されたテキストと正解のテキストを比較し、以下の基準で0から10の間でスコアを付けてください:
        - 内容の一致度
        - 文体や表現の類似性
        - 全体的な質
        - 話者の性格や口調の一致度

        10が完全に一致または非常に高品質、0が全く異なるまたは低品質です。スコアのみを返してください。
        """
        
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "あなたは生成されたテキストの品質を評価するAIアシスタントです。"},
                {"role": "user", "content": prompt}
            ]
        )
        
        score = float(response.choices[0].message.content.strip()) / 10  # 0-1のスケールに変換
        results.append(score)
    
    return sum(results) / len(results)

class GPT4EvalCallback(TrainerCallback):
    def __init__(self, trainer, tokenizer, eval_dataset, num_samples=50):
        self.trainer = trainer
        self.tokenizer = tokenizer
        self.eval_dataset = eval_dataset.select(range(min(num_samples, len(eval_dataset))))
        self.best_score = float('-inf')
    
    def on_evaluate(self, args, state, control, **kwargs):
        model = self.trainer.model
        eval_dataloader = self.trainer.get_eval_dataloader(self.eval_dataset)
        
        predictions = []
        labels = []
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model.generate(**batch, max_new_tokens=100)
            predictions.extend(outputs.tolist())
            labels.extend(batch["labels"].tolist())
        
        print(predictions)
        print(labels)
        gpt4_score = gpt4_evaluate(predictions, labels, self.tokenizer)
        wandb.log({"gpt4_eval_score": gpt4_score})
        
        if gpt4_score > self.best_score:
            self.best_score = gpt4_score
            self.trainer.save_model(f"{args.output_dir}/best_gpt4_model")
        
        # TrainerStateに現在のGPT-4スコアを追加
        state.metrics["gpt4_eval_score"] = gpt4_score

import torch.nn.functional as F

class CustomTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.step = 0

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        
        # 通常のloss計算(例:CrossEntropyLoss)
        loss_fct = F.cross_entropy
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        
        # カスタム評価指標を計算
        eval_pred = EvalPrediction(predictions=logits, label_ids=labels)
        metrics = custom_compute_metrics(eval_pred)
        
        # カスタム評価指標をlossに加える(例:重み付け)
        custom_loss_weight = 1  # この重みは調整が必要かもしれません
        custom_loss = (1 - metrics['gpt4_eval_score']) * custom_loss_weight
        
        # 最終的なlossを計算
        final_loss = (1 - alpha) * loss + alpha * custom_loss
        
        # メトリクスをログに記録
        self.log(metrics, loss, custom_loss, final_loss)
        
        return (final_loss, outputs) if return_outputs else final_loss

    def log(self, metrics, loss, custom_loss, final_loss):
        # wandbを使用してメトリクスをログに記録
        self.step += 1
        wandb.log({
            **metrics,
            "step": self.step,
            "loss": loss,
            "custom_loss": custom_loss,
            "final_loss": final_loss
        })
  1. custom_compute_metrics 関数:
    • EvalPrediction を受け取り、モデルの予測結果と正解ラベルから有効な部分(-100以外の部分)を抽出。
    • 抽出した予測と正解を gpt4_evaluate 関数に渡して、GPT-4による評価スコアを取得。
    • 評価スコアを辞書形式で返す。
  2. gpt4_evaluate 関数:
    • モデルの予測結果と正解ラベルをトークンIDからテキストに変換。
    • 各ペアに対して、GPT-4に評価を依頼するプロンプトを作成。
    • GPT-4の回答をスコアとして解釈し、平均スコアを返す。
  3. GPT4EvalCallback クラス:
    • TrainerCallback を継承し、評価時にGPT-4による評価を行い、wandb にログを記録し、最高のGPT-4スコアを達成したモデルを保存する。
  4. CustomTrainer クラス:
    • SFTTrainer を継承し、compute_loss メソッドをオーバーライド。
    • 通常の損失関数(F.cross_entropy)に加え、custom_compute_metrics で取得したGPT-4の評価スコアを損失関数に組み込む。
    • wandb を用いて、通常の損失、カスタム損失、最終的な損失、GPT-4評価スコアをログ記録。

最後に

もともとのクラスを継承しているが、wandbで実験管理がうまくできなかったりして、結構苦戦した。

この取り組みをやっている間にTransformersのロス関数にバグがあったという報告があり、今後はロス関数を独自にカスタマイズするのが楽になるらしい...
少し悲しい...

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?