はじめに
Llama 系モデルを Hugging Face Transformers でファインチューニングするとき、Trainer をそのまま使うだけでも学習は回せますが、もう少し詳しく
- 損失以外の指標を見たい
- 各エポックごとに生成サンプルを残したい
- ちょっとした条件で学習を止めたい
みたいな場合がよくあります。
こういうときに効いてくるのが TrainerCallback です。TrainerCallback は PyTorch 版 Trainer の学習ループにフックを差し込むための仕組みで、ログの加工や早期終了などを柔軟に追加できます。
ここでは
- Llama を Trainer で学習する最小スクリプト
- そこに TrainerCallback を挟んで「Llama っぽい」ログや挙動を足す具体例
という流れで見ていきます。
なお、本記事では Llama 3 を例にしていますが、TrainerCallback の使い方自体は将来のバージョンや別の CausalLM でもほぼ同様です。
🐣 『結果だけ』出したら怒られた経験、私はよくあります
TrainerCallback の概要
TrainerCallback は「学習ループの特定イベントのたびに呼ばれるクラス」です。
代表的なイベントには以下のようなものがあります。
-
on_train_begin/on_train_end -
on_epoch_begin/on_epoch_end -
on_step_begin/on_step_end on_log-
on_evaluate/on_prediction_stepなど
各メソッドには
-
args: TrainingArguments -
state: 現在のステップ数やエポック数などを持つ TrainerState -
control: 学習の停止や保存タイミングを指示する TrainerControl
が渡されます。
Trainer に登録する方法はシンプルです。
from transformers import Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
callbacks=[MyCallback], # クラス or インスタンス
)
のように callbacks 引数にクラスかインスタンスを渡します。
途中で追加・削除したいときは
trainer.add_callback(MyCallback)
trainer.remove_callback(MyCallback)
のようにメソッドを使います。
🐣 デザインパターンの Observer パターンですね。
Llama を Trainer で学習する最小スクリプト
例として meta-llama/Meta-Llama-3-8B-Instruct を日本語 SFT すると仮定します。実際に使うときは Llama ライセンスへの同意や GPU メモリ要件に注意が必要です。
ここでは構造が分かればいいので
- データセットは datasets のダミー
- プロンプト整形も最小限
という前提のサンプルにします。
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
DataCollatorForLanguageModeling,
TrainingArguments,
Trainer,
)
import torch
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
)
raw_dataset = load_dataset("tatsu-lab/alpaca", split="train[:2000]")
def format_example(example):
prompt = f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:\n"
return {"text": prompt + example["output"]}
processed = raw_dataset.map(format_example, remove_columns=raw_dataset.column_names)
def tokenize(examples):
return tokenizer(
examples["text"],
truncation=True,
max_length=1024,
)
tokenized = processed.map(
tokenize,
batched=True,
remove_columns=["text"],
)
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False,
)
training_args = TrainingArguments(
output_dir="outputs/llama3-sft",
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
num_train_epochs=3,
logging_steps=10,
eval_strategy="no", # 最新のTransformersではevaluation_strategyよりこちらを推奨
save_strategy="epoch",
bf16=True,
report_to="none",
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized,
data_collator=data_collator,
)
trainer.train()
ここまでは普通の Trainer ベースの SFT です。この状態だと
- ログは loss 中心
- 生成サンプルが自動では残らない
- 早期終了やカスタム条件は自前で書かないといけない
という素朴な挙動になります。
例 1: Llama の生成サンプルを各エポックで自動保存する
LLM の学習では
- 損失は下がっているのに生成が微妙
- 逆に損失があまり変わらないのに生成品質はかなり改善している
みたいなことがよくあります。そこで TrainerCallback で各エポック終了時に固定プロンプトへの生成結果をファイルに追記していく例を置いておきます。
from pathlib import Path
from typing import List, Dict, Any
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
import torch
class SampleGenerationCallback(TrainerCallback):
def __init__(
self,
tokenizer,
sample_prompts: List[str],
out_path: str = "outputs/samples.txt",
max_new_tokens: int = 128,
) -> None:
self.tokenizer = tokenizer
self.sample_prompts = sample_prompts
self.out_path = Path(out_path)
self.max_new_tokens = max_new_tokens
self.out_path.parent.mkdir(parents=True, exist_ok=True)
self.out_path.write_text("")
def on_epoch_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs: Dict[str, Any],
) -> TrainerControl:
model = kwargs["model"]
model.eval()
lines: List[str] = []
lines.append(f"===== epoch {state.epoch:.1f}, step {state.global_step} =====\n")
for i, prompt in enumerate(self.sample_prompts, start=1):
inputs = self.tokenizer(
prompt,
return_tensors="pt",
).to(model.device)
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=True,
top_p=0.9,
temperature=0.7,
)[0]
text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
lines.append(f"[prompt {i}]\n{prompt}\n")
lines.append(f"[output {i}]\n{text}\n\n")
with self.out_path.open("a", encoding="utf-8") as f:
f.writelines(lines)
return control
Trainer 側でこのコールバックを追加します。
sample_prompts = [
"日本語で LLM のファインチューニング手法を初心者向けに説明してください",
"インターンの学生に Transformer の仕組みを雑に説明するときのトークを書いてください",
]
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized,
data_collator=data_collator,
callbacks=[
SampleGenerationCallback(
tokenizer=tokenizer,
sample_prompts=sample_prompts,
out_path="outputs/samples.txt",
max_new_tokens=128,
),
],
)
trainer.train()
これで
- 各エポック終了時に outputs/samples.txt に Llama のサンプル生成結果が追記される
- 学習条件を変えたときに生成の変化をテキストとして比較しやすくなる
という状態になります。
TrainerCallback 側は
- on_epoch_end で state.epoch / state.global_step を参照
- kwargs["model"] から生のモデルを触る
だけなので、Llama 以外の CausalLM にも同じパターンで使えます。
例 2: Llama の学習を EarlyStoppingCallback で安全に止める
TrainerCallback は自作だけでなく、Transformers が用意している標準の Early Stopping もコールバックとして提供されています。EarlyStoppingCallback を使うと、評価指標が悪化し続けたときに自動で学習を止められます。
🐣 これ以上学習しても将来性がない実験は早めに終えましょう、ということです。
Llama の SFT だと
- バッチサイズも小さく
- 1 エポックがかなり長くなりがち
なので、無駄なエポックを回さないためにも Early Stopping を入れておくと楽です。
設定例は以下の通りです。
from transformers import EarlyStoppingCallback
training_args = TrainingArguments(
output_dir="outputs/llama3-sft",
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
num_train_epochs=10,
logging_steps=10,
eval_strategy="epoch", # epochごとに評価
save_strategy="epoch", # epochごとに保存
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
bf16=True,
report_to="none",
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
data_collator=data_collator,
callbacks=[
EarlyStoppingCallback(
early_stopping_patience=2,
early_stopping_threshold=0.0,
),
],
)
ポイントは
-
eval_strategy="epoch"とsave_strategy="epoch"をそろえる -
metric_for_best_modelでどの指標を見るか決める -
early_stopping_patience回連続で悪化したら停止
Llama 系モデルのように 1 エポックの計算コストが重いモデルほど、こうしたコールバックで「頑張りすぎない」仕組みを入れておくと扱いやすくなります。
🐣 でもこういう損切りの基準って実は難しいんですよね
例 3: Llama 向けのモデル内フックを TrainerCallback から呼ぶ
最後に、TrainerCallback とモデル側のフックを併用するパターンを軽く触れておきます。
- モデルクラスに on_step_end という任意メソッドを生やす
- TrainerCallback の on_step_end から hasattr(model, "on_step_end") で呼ぶ
という形にしておくと
- 一部のモデルだけ追加のロギングや正則化を入れる
- 既存のモデルには影響を与えない
という構成を取りやすくなります。
雰囲気だけコードにするとこんな感じです。実際に使う場合は AutoModelForCausalLM をラップするか、ロード後のモデルインスタンスにメソッドを追加する形になります。
from typing import Dict, Any
import torch
class GradNormMixin:
"""モデルに grad_norm 記録機能を追加する Mixin"""
def init_grad_norm_history(self):
self.grad_norm_history = []
@torch.no_grad()
def on_step_end(self, step: int, logs: Dict[str, Any]):
total = 0.0
for p in self.parameters():
if p.grad is None:
continue
n = p.grad.data.norm(2).item()
total += n * n
grad_norm = total ** 0.5
self.grad_norm_history.append((step, grad_norm))
logs["grad_norm"] = grad_norm
class ModelHookCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
model = kwargs.get("model")
logs = kwargs.get("logs") or {}
if model is not None and hasattr(model, "on_step_end"):
model.on_step_end(step=state.global_step, logs=logs)
return control
このように
- TrainerCallback は「フレームワーク側に近いフック」
- モデルのメソッドは「モデル特有の追加挙動」
という役割分担にしておくと、Llama 系モデル向けの実験コードを増やしても Trainer 本体のコードがあまり汚れません。
おわりに
Llama のような大きめの言語モデルを Hugging Face Trainer で学習するとき、TrainerCallback をうまく使うと
- 生成サンプルの自動記録
- 早期終了の制御
- モデル固有のフック呼び出し
といった「LLM っぽい」運用を後から足しやすくなります。
Trainer 自体はブラックボックスにせず、どのイベントでどんなコールバックが呼ばれているかを一度追いかけておくと、Llama 系モデルの学習スクリプトを育てるときの自由度がかなり上がります。
🐣 最近ちょっと元気がない?ので Llama には頑張ってもらいたいですね!