本記事では、Googleの軽量LLMであるGemma-3-270m-itをベースモデルとし、特定タスク(ここでは「かな漢字変換」)に特化させるためのポストトレーニングパイプラインを解説します。
単なる教師あり微調整(SFT)にとどまらず、以下の手法を段階的に適用することで、モデルの性能を最大限に引き出す手法を実装します。
- SFT (Supervised Fine-Tuning): 基礎的な応答能力の学習
- DPO (Direct Preference Optimization): 自己生成データを用いた選好学習
- GRPO (Group Relative Policy Optimization): ルールベース報酬による強化学習
- GGUF変換: エッジデバイスでの利用を想定した量子化
第1章:環境構築と教師あり微調整 (SFT)
まずはライブラリのインストールと、ベースとなるデータセットの準備を行います。ここではtrlライブラリのSFTTrainerを使用し、効率的な学習を行います。
1.1 ライブラリ導入とデータ前処理
データセットにはHugging Face上のkatsukiono/kana-kanji-pairsを使用します。出現頻度(count)が低いノイズを除去し、学習用と評価用に分割します。※データはアップデート前のold/mozc_n10_20260102.jsonlを使用。
# -*- coding: utf-8 -*-
"""
Gemma-3-270m かな漢字変換モデル学習パイプライン
Chapter 1: 環境構築とSFT
"""
!pip install trl
from google.colab import userdata
from huggingface_hub import login
from datasets import load_dataset
# Hugging Faceへのログイン
hf_token = userdata.get("HF_TOKEN")
login(hf_token)
# データセットのロードとフィルタリング
ds = load_dataset("katsukiono/kana-kanji-pairs", split="train")
# 出現頻度が5以上のデータのみを利用し、ノイズを低減
ds = ds.filter(lambda x: int(x["count"]) >= 5)
ds = ds.train_test_split(test_size=2000, seed=42)
def make_user_prompt(kana: str, n: int) -> str:
"""
ユーザープロンプトを作成する関数
"""
return (
"キーボードの予測変換として以下のかなを"
f"{n}"
"個の単語に予測変換してください。必ず単語のみを予測してlist形式で返してください。\n\n"
"出力形式\n"
"[候補1, 候補2, 候補3...候補10]\n\n"
"ーーーー以下が予測変換対象ーーーー\n".replace(" \ n", "\n") + # Colab上のコピー事故防止策
f"{kana}"
)
def preprocess_function(example):
"""
データセットをPrompt-Completion形式に変換する前処理関数
"""
kana = example["input"]
n = min(int(example["count"]), 10)
cands = example["output"][:n]
# 出力をPythonリスト形式の文字列として整形
assistant_response = "[" + ", ".join([str(x).strip() for x in cands]) + "]"
# TRLが推奨する会話形式(prompt/completion)に整形
return {
"prompt": [{"role": "user", "content": make_user_prompt(kana, n)}],
"completion": [{"role": "assistant", "content": assistant_response}],
}
# データセットへの適用
train_ds = ds["train"].map(preprocess_function, remove_columns=ds["train"].column_names)
eval_ds = ds["test"].map(preprocess_function, remove_columns=ds["test"].column_names)
# 前処理結果の確認
print(train_ds[0])
1.2 モデルのロードとSFTの実行
Gemma-3-270m-itをロードし、学習を実行します。モデルサイズが小さいため、packing=Trueを有効にして学習効率を高めています。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
BASE_MODEL_ID = "google/gemma-3-270m-it"
# トークナイザーのロード
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# モデルのロード
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL_ID,
torch_dtype="auto",
device_map="auto",
)
model.config.use_cache = False
model.gradient_checkpointing_enable()
# 学習設定
sft_config = SFTConfig(
output_dir="gemma3_270m_kana_fullft",
seed=42,
per_device_train_batch_size=32,
per_device_eval_batch_size=32,
gradient_accumulation_steps=2,
num_train_epochs=3,
learning_rate=5e-5,
warmup_ratio=0.03,
weight_decay=0.1,
lr_scheduler_type="cosine",
logging_steps=50,
eval_steps=500,
save_steps=500,
save_total_limit=2,
bf16=True, # エラーが発生する場合は fp16=True, bf16=False に変更
max_length=256,
packing=True,
report_to="none",
)
# Trainerの初期化と実行
trainer = SFTTrainer(
model=model,
args=sft_config,
train_dataset=train_ds,
eval_dataset=eval_ds,
processing_class=tokenizer,
)
trainer.train()
# 学習済みモデルとトークナイザーの保存
trainer.save_model(sft_config.output_dir + "/merged_hf_fullft")
tokenizer.save_pretrained(sft_config.output_dir + "/merged_hf_fullft")
第2章:DPO用データの自己生成 (Rejection Sampling)
DPO(Direct Preference Optimization)には、「正解(Chosen)」と「不正解(Rejected)」のペアデータが必要です。ここでは外部データセットに頼らず、SFT済みモデル自身を使用してRejectedデータを生成します。これにより、モデルが犯しやすい間違いを効率的に学習させることが可能になります。
2.1 データ生成ロジックの実装
高温度(Temperature)でのサンプリングを行い、正解データ(Chosen)と異なる出力が得られるまで再試行するロジックを実装します。
"""
Chapter 2: DPO用データの生成
SFT済みモデルを用いて、Chosen(正解)に対するRejected(不正解)を生成する
"""
import os, re, json, torch
from tqdm.auto import tqdm
from datasets import load_dataset
from transformers import (
AutoTokenizer, AutoModelForCausalLM,
StopStringCriteria, StoppingCriteriaList, set_seed
)
set_seed(42)
BASE_MODEL_ID = "google/gemma-3-270m-it"
SFT_MODEL_DIR = "gemma3_270m_kana_fullft/merged_hf_fullft"
OUTPUT_JSONL = "gemma3_270m_kana_dpo_pref.jsonl"
BATCH_SIZE = 64
MAX_NEW_TOKENS = 96
SAMPLE_LIMIT = None # 全件生成の場合はNone、テスト時は数値を指定
# トークナイザーの準備(ベースモデルからロード)
tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
# SFT済みモデルのロード
model = AutoModelForCausalLM.from_pretrained(
SFT_MODEL_DIR,
torch_dtype="auto",
device_map="auto",
)
model.eval()
# 生成停止条件の設定(リストの閉じ括弧を検知したら停止)
stopper = StopStringCriteria(tokenizer=tok, stop_strings=["]"])
stopping = StoppingCriteriaList([stopper])
def make_user_prompt(kana: str, n: int) -> str:
"""Prompt作成関数(Chapter 1と同等)"""
return (
"キーボードの予測変換として以下のかなを"
f"{n}"
"個の単語に予測変換してください。必ず単語のみを予測してlist形式で返してください。\n\n"
"出力形式\n"
"[候補1, 候補2, 候補3...候補10]\n\n"
"ーーーー以下が予測変換対象ーーーー\n"
f"{kana}"
)
def format_gold_list(cands) -> str:
"""正解リストを文字列化"""
return "[" + ", ".join([str(x).strip() for x in cands]) + "]"
def extract_bracket_content(text: str):
"""生成テキストからリスト部分を抽出"""
m = re.search(r"\[[^\[\]\n]*\]", text)
return m.group(0).strip() if m else None
@torch.inference_mode()
def generate_responses(chat_prompts, do_sample: bool, temperature=0.9, top_p=0.95, top_k=50):
"""推論実行関数"""
enc = tok(chat_prompts, return_tensors="pt", padding=True).to(model.device)
prompt_lens = enc["attention_mask"].sum(dim=1).tolist()
gen_kwargs = dict(
**enc,
max_new_tokens=MAX_NEW_TOKENS,
pad_token_id=tok.pad_token_id,
eos_token_id=tok.eos_token_id,
stopping_criteria=stopping,
)
if do_sample:
gen_kwargs.update(dict(do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k))
else:
gen_kwargs.update(dict(do_sample=False))
out = model.generate(**gen_kwargs)
results = []
for j in range(out.shape[0]):
gen_ids = out[j, int(prompt_lens[j]):]
gen_text = tok.decode(gen_ids, skip_special_tokens=True)
results.append(extract_bracket_content(gen_text))
return results
# データセットのロードとフィルタリング
ds = load_dataset("katsukiono/kana-kanji-pairs", split="train")
ds = ds.filter(lambda x: int(x["count"]) >= 5)
if SAMPLE_LIMIT is not None:
ds = ds.select(range(min(SAMPLE_LIMIT, len(ds))))
total_rows = len(ds)
print(f"処理対象データ数: {total_rows}")
# 途中再開の判定
start_index = 0
if os.path.exists(OUTPUT_JSONL):
with open(OUTPUT_JSONL, "r", encoding="utf-8") as rf:
start_index = sum(1 for _ in rf)
print(f"再開位置: {start_index} 行目から")
# 統計情報の初期化
stats = {
"saved": 0,
"dropped_none": 0, # 抽出失敗
"dropped_same": 0, # Chosenと一致
"retry_1": 0, # 1回目のリトライ数
"retry_2": 0, # 2回目のリトライ(Greedy)数
}
# データ生成ループ
with open(OUTPUT_JSONL, "a", encoding="utf-8") as wf:
pbar = tqdm(range(start_index, total_rows, BATCH_SIZE), desc="Generating Rejected Data", unit="batch")
for i in pbar:
batch = ds[i : min(i + BATCH_SIZE, total_rows)]
kana_list = batch["input"]
count_list = batch["count"]
out_list = batch["output"]
ns = [min(int(c), 10) for c in count_list]
chosens = [format_gold_list(o[:n]) for o, n in zip(out_list, ns)]
# プロンプト作成
chat_prompts = [
tok.apply_chat_template(
[{"role":"user","content":make_user_prompt(k, n)}],
tokenize=False,
add_generation_prompt=True
)
for k, n in zip(kana_list, ns)
]
# ---- Pass 1: 通常サンプリング ----
rejected = generate_responses(chat_prompts, do_sample=True, temperature=0.9)
# 失敗データの特定
retry_indices = [j for j, (r, c) in enumerate(zip(rejected, chosens)) if (r is None) or (r == c)]
stats["retry_1"] += len(retry_indices)
# ---- Pass 2: 高温度サンプリング(リトライ) ----
if retry_indices:
prompts2 = [chat_prompts[j] for j in retry_indices]
r2 = generate_responses(prompts2, do_sample=True, temperature=1.1, top_p=0.98)
for j, val in zip(retry_indices, r2):
rejected[j] = val
retry_indices_2 = [j for j, (r, c) in enumerate(zip(rejected, chosens)) if (r is None) or (r == c)]
stats["retry_2"] += len(retry_indices_2)
# ---- Pass 3: Greedy生成(最終手段) ----
if retry_indices_2:
prompts3 = [chat_prompts[j] for j in retry_indices_2]
r3 = generate_responses(prompts3, do_sample=False)
for j, val in zip(retry_indices_2, r3):
rejected[j] = val
# ---- ファイルへの書き込み ----
for k, n, c, r in zip(kana_list, ns, chosens, rejected):
if r is None:
stats["dropped_none"] += 1
continue
if r == c:
stats["dropped_same"] += 1
continue
record = {
"prompt": [{"role":"user", "content": make_user_prompt(k, n)}],
"chosen": [{"role":"assistant", "content": c}],
"rejected": [{"role":"assistant", "content": r}],
"kana": k,
"n": n,
}
wf.write(json.dumps(record, ensure_ascii=False) + "\n")
stats["saved"] += 1
pbar.set_postfix({k: v for k, v in stats.items() if v > 0})
print("生成完了")
print(f"保存件数: {stats['saved']}")
print(f"除外件数: {stats['dropped_none'] + stats['dropped_same']} (None: {stats['dropped_none']}, Same: {stats['dropped_same']})")
第3章:DPOによる選好最適化
生成したペアデータセットを用いてDPOを実行します。ここではDPOTrainerを使用します。SFTモデルを参照モデル(Reference Model)として固定し、そこから過度に逸脱しない範囲で、Rejectedを避けChosenを生成するよう重みを更新します。
"""
Chapter 3: DPO (Direct Preference Optimization) の実行
"""
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import DPOTrainer, DPOConfig
import torch
PREF_JSONL = "gemma3_270m_kana_dpo_pref.jsonl"
BASE_MODEL_ID = "google/gemma-3-270m-it"
SFT_MODEL_DIR = "gemma3_270m_kana_fullft/merged_hf_fullft"
DPO_OUTPUT_DIR = "gemma3_270m_kana_dpo"
# データセットの準備
ds = load_dataset("json", data_files=PREF_JSONL, split="train")
required_columns = ["prompt", "chosen", "rejected"]
ds = ds.remove_columns([c for c in ds.column_names if c not in required_columns])
ds = ds.train_test_split(test_size=0.05, seed=42)
train_ds, eval_ds = ds["train"], ds["test"]
# トークナイザー設定
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Policyモデル(学習対象)
model = AutoModelForCausalLM.from_pretrained(SFT_MODEL_DIR, torch_dtype="auto", device_map="auto")
model.config.use_cache = False
model.gradient_checkpointing_enable()
# Referenceモデル(参照用・凍結)
ref_model = AutoModelForCausalLM.from_pretrained(SFT_MODEL_DIR, torch_dtype="auto", device_map="auto")
ref_model.requires_grad_(False)
ref_model.eval()
# DPO学習設定
dpo_config = DPOConfig(
output_dir=DPO_OUTPUT_DIR,
bf16=True,
learning_rate=2e-6, # 安定性を重視し低めの学習率を設定
beta=0.1, # Referenceモデルとの乖離許容度
num_train_epochs=1,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
gradient_accumulation_steps=2,
max_length=256,
max_prompt_length=192,
logging_steps=50,
eval_strategy="steps",
eval_steps=500,
save_steps=500,
save_total_limit=2,
report_to="none",
)
trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=dpo_config,
train_dataset=train_ds,
eval_dataset=eval_ds,
processing_class=tokenizer,
)
trainer.train()
# 保存
trainer.save_model(DPO_OUTPUT_DIR + "/merged_hf_dpo")
tokenizer.save_pretrained(DPO_OUTPUT_DIR + "/merged_hf_dpo")
第4章:GRPOによるルールベース強化学習
DPOに加え、さらにGRPO (Group Relative Policy Optimization) を適用します。今回は外部の報酬モデルを使用せず、ルールベースの関数(Reward Function)を用いて生成結果を評価・学習します。
ここでは以下のルールを報酬として定義します。
-
形式遵守:
[ ]の形式であるか。 - 正解含有: リスト内に正解データが含まれているか。
- 幻覚抑制: 正解セットに含まれない単語を出力した場合のペナルティ。
- 重複抑制: 同一単語の重複に対するペナルティ。
"""
Chapter 4: GRPO (Group Relative Policy Optimization) の実行
"""
import re
from collections import Counter
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
from transformers import AutoTokenizer
BASE_MODEL_ID = "google/gemma-3-270m-it"
START_MODEL_DIR = "gemma3_270m_kana_dpo/merged_hf_dpo" # DPOモデルを開始地点とする
GRPO_OUTPUT_DIR = "gemma3_270m_kana_grpo"
# トークナイザー設定
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
tokenizer.padding_side = "left" # GRPOでは左パディングが推奨される
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
def make_user_prompt(kana: str, n: int) -> str:
"""Prompt作成関数"""
return (
"キーボードの予測変換として以下のかなを"
f"{n}"
"個の単語に予測変換してください。必ず単語のみを予測してlist形式で返してください。\n\n"
"出力形式\n"
"[候補1, 候補2, 候補3...候補10]\n\n"
"ーーーー以下が予測変換対象ーーーー\n"
f"{kana}"
)
def build_grpo_dataset(min_count=5, limit=None):
"""GRPO用データセット構築(PromptとGoldデータを保持)"""
ds = load_dataset("katsukiono/kana-kanji-pairs", split="train")
ds = ds.filter(lambda x: int(x["count"]) >= min_count)
if limit is not None:
ds = ds.select(range(min(limit, len(ds))))
def _map(ex):
n = min(int(ex["count"]), 10)
gold = ex["output"][:n]
return {
"prompt": [{"role": "user", "content": make_user_prompt(ex["input"], n)}],
"gold": gold,
"n": n,
}
return ds.map(_map, remove_columns=ds.column_names)
# 学習用データセット(試行用に30,000件)
train_ds = build_grpo_dataset(min_count=5, limit=30000)
def parse_items(text: str):
"""文字列からリスト要素をパース"""
m = re.search(r"\[(.*?)\]", text, flags=re.DOTALL)
if not m:
return []
items = [p.strip().strip('"').strip("'") for p in m.group(1).split(",")]
return [x for x in items if x]
def reward_function(prompts, completions, gold, n, **kwargs):
"""
報酬関数
prompts: プロンプトのリスト
completions: 生成された回答のリスト
gold: 正解データのリスト
n: 要求された単語数
"""
contents = [c[0]["content"] for c in completions]
rewards = []
for content, g, nn in zip(contents, gold, n):
gset = set(g)
items = parse_items(content)[:nn]
# 1. 形式評価
fmt_ok = 1.0 if (content.strip().startswith("[") and "]" in content) else -2.0
# 2. 正解含有数とOOD(分布外)ペナルティ
in_gold = sum(1 for x in items if x in gset)
ood = sum(1 for x in items if x not in gset)
set_score = 1.0 * in_gold - 2.5 * ood # 間違いを強く罰する設定
# 3. 重複ペナルティ
dup = len(items) - len(set(items))
dup_score = -1.5 * dup
# 4. 先頭文字の偏りペナルティ(多様性確保)
firsts = [x[0] for x in items if x]
c = Counter(firsts)
head_pen = sum(max(0, v - 2) for v in c.values()) # 同一開始文字が3つ以上で減点
head_score = -0.3 * head_pen
rewards.append(fmt_ok + set_score + dup_score + head_score)
return rewards
# GRPO設定
grpo_config = GRPOConfig(
output_dir=GRPO_OUTPUT_DIR,
num_train_epochs=1,
learning_rate=5e-6,
per_device_train_batch_size=8,
gradient_accumulation_steps=2,
max_prompt_length=192,
max_completion_length=64,
num_generations=4, # グループサイズ(1プロンプトに対する生成数)
beta=0.05,
logging_steps=25,
save_steps=500,
save_total_limit=2,
report_to="none",
scale_rewards="batch", # バッチ内正規化
)
trainer = GRPOTrainer(
model=START_MODEL_DIR,
args=grpo_config,
train_dataset=train_ds,
reward_funcs=reward_function,
processing_class=tokenizer,
)
trainer.train()
trainer.save_model(GRPO_OUTPUT_DIR + "/merged_hf_grpo")
tokenizer.save_pretrained(GRPO_OUTPUT_DIR + "/merged_hf_grpo")
第5章:推論と比較検証
作成した3つのモデル(SFT, DPO, GRPO)を比較し、性能の変遷を確認するためのコードです。
"""
Chapter 5: 推論と比較検証
"""
import re, torch
from transformers import (
AutoTokenizer, AutoModelForCausalLM,
StopStringCriteria, StoppingCriteriaList
)
MODEL_REPOS = {
"SFT": "katsukiono/gemma3-270m-kana-fullft", # 必要に応じてローカルパスに変更
"DPO": "katsukiono/gemma3-270m-kana-dpo",
"GRPO": "katsukiono/gemma3-270m-kana-grpo",
}
# モデルキャッシュ
_MODEL_CACHE = {}
def get_model_and_tok(repo_id: str):
if repo_id in _MODEL_CACHE:
return _MODEL_CACHE[repo_id]
tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
# メモリ効率のため評価モードでロード
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype="auto", device_map="auto").eval()
_MODEL_CACHE[repo_id] = (model, tok)
return model, tok
def run_inference(repo_id: str, kana: str, n: int = 10):
model, tok = get_model_and_tok(repo_id)
prompt_content = (
"キーボードの予測変換として以下のかなを"
f"{n}"
"個の単語に予測変換してください。必ず単語のみを予測してlist形式で返してください。\n\n"
"出力形式\n"
"[候補1, 候補2, 候補3...候補10]\n\n"
"ーーーー以下が予測変換対象ーーーー\n"
f"{kana}"
)
prompt = tok.apply_chat_template(
[{"role":"user","content":prompt_content}],
tokenize=False,
add_generation_prompt=True
)
inputs = tok(prompt, return_tensors="pt").to(model.device)
stopper = StopStringCriteria(tokenizer=tok, stop_strings=["]"])
with torch.inference_mode():
out = model.generate(
**inputs,
do_sample=False, # 比較のため決定論的に生成
max_new_tokens=96,
stopping_criteria=StoppingCriteriaList([stopper]),
)
gen_text = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
return gen_text
def compare_models(kana: str, n: int = 10):
print(f"=== 入力: {kana} (n={n}) ===")
for name, repo in MODEL_REPOS.items():
try:
result = run_inference(repo, kana, n)
print(f"[{name}]: {result}")
except Exception as e:
print(f"[{name}]: Error - {e}")
# 実行例
compare_models("かんじ", 10)
compare_models("きかいがくしゅう", 10)
第6章:GGUF変換と量子化
最終生成物(ここでは例としてDPOモデル)をllama.cppを用いて量子化し、GGUF形式に変換します。これにより、スマートフォンやCPUのみの環境でも高速に動作可能になります。
"""
Chapter 6: GGUF変換と量子化
"""
# llama.cppのビルド
!rm -rf llama.cpp
!git clone https://github.com/ggml-org/llama.cpp.git
!git -C llama.cpp pull
!cmake -S llama.cpp -B llama.cpp/build
!cmake --build llama.cpp/build -j
!pip -q install -r llama.cpp/requirements.txt
from huggingface_hub import snapshot_download, HfApi, create_repo, upload_file
import os
# 変換対象モデル(例: DPOモデル)
TARGET_MODEL_ID = "katsukiono/gemma3-270m-kana-dpo"
HF_TOKEN = os.getenv("HF_TOKEN")
print("モデルのダウンロード中...")
model_path = snapshot_download(
repo_id=TARGET_MODEL_ID,
token=HF_TOKEN,
)
GGUF_F16 = "gemma3-270m-kana-dpo-f16.gguf"
GGUF_Q4 = "gemma3-270m-kana-dpo-q4_k_m.gguf"
# 1. HF形式 -> GGUF (FP16) への変換
!python3 llama.cpp/convert_hf_to_gguf.py \
"{model_path}" \
--outfile "{GGUF_F16}" \
--outtype f16
# 2. FP16 -> Q4_K_M (4bit量子化)
!./llama.cpp/build/bin/llama-quantize \
"{GGUF_F16}" \
"{GGUF_Q4}" \
Q4_K_M
# 3. Hugging Face Hubへのアップロード
api = HfApi()
user = api.whoami()["name"]
repo_id = f"{user}/gemma3-270m-kana-dpo-gguf"
print(f"リポジトリ作成: {repo_id}")
create_repo(repo_id=repo_id, repo_type="model", private=True, exist_ok=True)
print("アップロード中...")
upload_file(
repo_id=repo_id,
repo_type="model",
path_or_fileobj=GGUF_Q4,
path_in_repo=GGUF_Q4,
commit_message="Add GGUF (Q4_K_M)"
)
print("完了")