はじめに
LLMのファインチューニング時に、デフォルトで提供される損失関数(例えば、CrossEntropyLoss)を使っており、「もっとLLMの性能を直接的に測る指標を損失関数に組み込めないか?」と思い、LLM as a judgeの結果をもとにロス計算するとより簡単にLLMをファインチューニングできるのではないかと思いついたので、やってみた。
損失関数のカスタマイズ
デフォルトの損失関数(例えば、CrossEntropyLoss)は、トークンごとの予測の正確さを測るには優れています。しかし、最終的な目的である「人間にとって自然で質の高いテキストの生成」を直接的に最適化しているわけではありません。そこで、以下のようなメリットを期待してカスタム損失関数を導入しました。
- 人間が判断する品質を損失関数に反映: GPT-4のような強力なモデルを評価者とすることで、人間が感じるテキストの品質をより直接的に損失関数に取り込む
- ファインチューニングの効率化: 最終的なゴールに直結する指標を損失関数に用いることで、学習の効率が向上する可能性
from transformers import TrainerCallback
from transformers import EarlyStoppingCallback
from transformers import EvalPrediction
from typing import Dict
import openai
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def custom_compute_metrics(res: EvalPrediction) -> Dict:
# res.predictionsとres.label_idsはnumpyのarray
pred = res.predictions.argmax(axis=2)
target = res.label_ids # inputsの代わりにlabel_idsを使用
if target is None:
print("警告: label_idsが空です。評価を実行できません。")
return {'gpt4_score': None}
# バッチごとに処理
valid_preds = []
valid_targets = []
for batch_pred, batch_target in zip(pred, target):
# labelsが-100でない部分のみを抽出
valid_indices = batch_target != -100
valid_preds.append(batch_pred[valid_indices])
valid_targets.append(batch_target[valid_indices])
# print("有効な予測:", valid_preds)
# print("有効なターゲット:", valid_targets)
# print("有効な予測の形状:", [p.shape for p in valid_preds])
# print("有効なターゲットの形状:", [t.shape for t in valid_targets])
# GPT-4を使用して評価を行う
gpt4_score = gpt4_evaluate(valid_preds, valid_targets, tokenizer)
return {
'gpt4_eval_score': gpt4_score
}
@weave.op()
def gpt4_evaluate(predictions, labels, tokenizer):
# GPT-4を使用して評価を行う関数
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# print(decoded_preds)
# print(decoded_labels)
results = []
for pred, label in zip(decoded_preds, decoded_labels):
prompt = f"""
生成されたテキスト: {pred}
正解のテキスト: {label}
上記の生成されたテキストと正解のテキストを比較し、以下の基準で0から10の間でスコアを付けてください:
- 内容の一致度
- 文体や表現の類似性
- 全体的な質
- 話者の性格や口調の一致度
10が完全に一致または非常に高品質、0が全く異なるまたは低品質です。スコアのみを返してください。
"""
response = client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "あなたは生成されたテキストの品質を評価するAIアシスタントです。"},
{"role": "user", "content": prompt}
]
)
score = float(response.choices[0].message.content.strip()) / 10 # 0-1のスケールに変換
results.append(score)
return sum(results) / len(results)
class GPT4EvalCallback(TrainerCallback):
def __init__(self, trainer, tokenizer, eval_dataset, num_samples=50):
self.trainer = trainer
self.tokenizer = tokenizer
self.eval_dataset = eval_dataset.select(range(min(num_samples, len(eval_dataset))))
self.best_score = float('-inf')
def on_evaluate(self, args, state, control, **kwargs):
model = self.trainer.model
eval_dataloader = self.trainer.get_eval_dataloader(self.eval_dataset)
predictions = []
labels = []
for batch in eval_dataloader:
with torch.no_grad():
outputs = model.generate(**batch, max_new_tokens=100)
predictions.extend(outputs.tolist())
labels.extend(batch["labels"].tolist())
print(predictions)
print(labels)
gpt4_score = gpt4_evaluate(predictions, labels, self.tokenizer)
wandb.log({"gpt4_eval_score": gpt4_score})
if gpt4_score > self.best_score:
self.best_score = gpt4_score
self.trainer.save_model(f"{args.output_dir}/best_gpt4_model")
# TrainerStateに現在のGPT-4スコアを追加
state.metrics["gpt4_eval_score"] = gpt4_score
import torch.nn.functional as F
class CustomTrainer(SFTTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.step = 0
def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.logits
# 通常のloss計算(例:CrossEntropyLoss)
loss_fct = F.cross_entropy
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
# カスタム評価指標を計算
eval_pred = EvalPrediction(predictions=logits, label_ids=labels)
metrics = custom_compute_metrics(eval_pred)
# カスタム評価指標をlossに加える(例:重み付け)
custom_loss_weight = 1 # この重みは調整が必要かもしれません
custom_loss = (1 - metrics['gpt4_eval_score']) * custom_loss_weight
# 最終的なlossを計算
final_loss = (1 - alpha) * loss + alpha * custom_loss
# メトリクスをログに記録
self.log(metrics, loss, custom_loss, final_loss)
return (final_loss, outputs) if return_outputs else final_loss
def log(self, metrics, loss, custom_loss, final_loss):
# wandbを使用してメトリクスをログに記録
self.step += 1
wandb.log({
**metrics,
"step": self.step,
"loss": loss,
"custom_loss": custom_loss,
"final_loss": final_loss
})
-
custom_compute_metrics
関数:-
EvalPrediction
を受け取り、モデルの予測結果と正解ラベルから有効な部分(-100
以外の部分)を抽出。 - 抽出した予測と正解を
gpt4_evaluate
関数に渡して、GPT-4による評価スコアを取得。 - 評価スコアを辞書形式で返す。
-
-
gpt4_evaluate
関数:- モデルの予測結果と正解ラベルをトークンIDからテキストに変換。
- 各ペアに対して、GPT-4に評価を依頼するプロンプトを作成。
- GPT-4の回答をスコアとして解釈し、平均スコアを返す。
-
GPT4EvalCallback
クラス:-
TrainerCallback
を継承し、評価時にGPT-4による評価を行い、wandb
にログを記録し、最高のGPT-4スコアを達成したモデルを保存する。
-
-
CustomTrainer
クラス:-
SFTTrainer
を継承し、compute_loss
メソッドをオーバーライド。 - 通常の損失関数(
F.cross_entropy
)に加え、custom_compute_metrics
で取得したGPT-4の評価スコアを損失関数に組み込む。 -
wandb
を用いて、通常の損失、カスタム損失、最終的な損失、GPT-4評価スコアをログ記録。
-
最後に
もともとのクラスを継承しているが、wandbで実験管理がうまくできなかったりして、結構苦戦した。
この取り組みをやっている間にTransformersのロス関数にバグがあったという報告があり、今後はロス関数を独自にカスタマイズするのが楽になるらしい...
少し悲しい...