元ネタはGoogel AI for Developersの記事になります。
この記事を参考にローカルのWindows環境で実行した備忘録です。
Hugging Face Transformers と QloRA を使用して Gemma をファインチューニングする
概要
GoogleのGemmaモデルをQLoRA(Quantized Low-Rank Adaptation)手法を使用してWindowsローカル環境でファインチューニングし、Text-to-SQLタスクに特化させる方法を説明します。また、ファインチューニング済みモデルをOllamaで利用する方法も含まれています。
プロジェクト概要
使用技術
- ベースモデル: Google Gemma-3-1b-pt
- ファインチューニング手法: QLoRA(4bit量子化 + LoRA)
- データセット: philschmid/gretel-synthetic-text-to-sql(10万件超の合成Text-to-SQLデータ)
- フレームワーク: Hugging Face Transformers, TRL, PEFT
QLoRAの利点
- メモリ効率: 4bit量子化により大幅なメモリ削減
- 高速学習: アダプターレイヤーのみ学習で高速化
- 性能維持: ベースモデルの性能を維持しながら効率化
環境セットアップ
ハードウェア要件
- GPU: NVIDIA RTX 3060 (12GB VRAM) 以上推奨(最低8GB)
- RAM: 16GB以上(32GB推奨)
- ストレージ: 20GB以上の空き容量
ソフトウェア要件
- Windows 10/11
- Python 3.8-3.11
- CUDA 11.8以上 + cuDNN
- Git
環境構築手順
1.Python環境の準備
python -m venv gemma-finetune
gemma-finetune\Scripts\activate
2.PyTorchのインストール
pip install torch>=2.4.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install tensorboard
pip install "transformers>=4.51.3"
pip install tokenizers
pip install "datasets==3.3.2" "accelerate==1.4.0" "evaluate==0.4.3" "bitsandbytes==0.45.3" "trl==0.15.2" "peft==0.14.0" protobuf
学習データについて
philschmid/gretel-synthetic-text-to-sql データセット
データセット概要
- 総レコード数: 105,851件(訓練用100,000件、テスト用5,851件)
- 総トークン数: 約2,300万トークン(SQLトークン約1,200万)
- 対象領域: 100の異なる業界ドメイン
- ライセンス: Apache 2.0
データ構造
{
"sql_prompt": "各州の病院ベッド数の合計は?",
"sql_context": "CREATE TABLE Beds (State VARCHAR(50), Beds INT); INSERT INTO...",
"sql": "SELECT State, SUM(Beds) FROM Beds GROUP BY State;",
"sql_explanation": "このクエリは各州の病院ベッド数の合計を計算します...",
"domain": "public health",
"sql_complexity": "aggregation"
}
ファインチューニング実行
train.py
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from huggingface_hub import login
# Hugging Face認証
login("your_token_here")
# GPU設定確認
device = torch.cuda.current_device()
print(f"Using GPU: {torch.cuda.get_device_name(device)}")
# モデル設定
model_id = "google/gemma-3-1b-pt"
torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
# 量子化設定
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4',
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_quant_storage=torch_dtype,
)
# モデルとトークナイザー読み込み
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=quantization_config,
attn_implementation="eager",
torch_dtype=torch_dtype,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
# データセット準備
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command.
<SCHEMA>
{context}
</SCHEMA>
<USER_QUERY>
{question}
</USER_QUERY>
"""
def create_conversation(sample):
return {
"messages": [
{"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])},
{"role": "assistant", "content": sample["sql"]}
]
}
# データセット読み込み
dataset = load_dataset("philschmid/gretel-synthetic-text-to-sql", split="train")
dataset = dataset.shuffle().select(range(12500))
dataset = dataset.map(create_conversation, remove_columns=dataset.features, batched=False)
dataset = dataset.train_test_split(test_size=2500/12500)
# LoRA設定
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.05,
r=16,
bias="none",
target_modules="all-linear",
task_type="CAUSAL_LM",
modules_to_save=["lm_head", "embed_tokens"]
)
# トレーニング設定
args = SFTConfig(
output_dir="./results/gemma-text-to-sql",
max_seq_length=512,
packing=True,
num_train_epochs=3,
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
gradient_checkpointing=True,
optim="adamw_torch_fused",
logging_steps=10,
save_strategy="epoch",
learning_rate=2e-4,
fp16=True if torch_dtype == torch.float16 else False,
bf16=True if torch_dtype == torch.bfloat16 else False,
max_grad_norm=0.3,
warmup_ratio=0.03,
lr_scheduler_type="constant",
push_to_hub=False,
report_to="tensorboard",
dataset_kwargs={
"add_special_tokens": False,
"append_concat_token": True,
}
)
# トレーナー作成・実行
trainer = SFTTrainer(
model=model,
args=args,
train_dataset=dataset["train"],
peft_config=peft_config,
processing_class=tokenizer
)
# トレーニング実行
trainer.train()
trainer.save_model()
実行方法
python train.py
モデル保存場所
- 保存先: ./results/gemma-text-to-sql/
- 内容: LoRAアダプタファイル(数十MB)
- 注意: ベースモデル全体ではなくアダプタのみ保存
推論とテスト
inference.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from random import randint
import re
# モデル読み込み
model_id = "./results/gemma-text-to-sql"
torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch_dtype,
attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# パイプライン作成
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
# テスト例
test_message = {
"messages": [
{
"role": "user",
"content": """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.
<SCHEMA>
CREATE TABLE employees (
id INT PRIMARY KEY,
name VARCHAR(100),
department VARCHAR(50),
salary DECIMAL(10,2)
);
</SCHEMA>
<USER_QUERY>
Show me all employees in the IT department
</USER_QUERY>
"""
}
]
}
# 推論実行
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(test_message["messages"], tokenize=False, add_generation_prompt=True)
outputs = pipe(
prompt,
max_new_tokens=256,
do_sample=False,
temperature=0.1,
top_k=50,
top_p=0.1,
eos_token_id=stop_token_ids,
disable_compile=True
)
print("Generated SQL:")
print(outputs[0]['generated_text'][len(prompt):].strip())
生成パラメータ説明
- max_new_tokens=256: 最大256トークン生成
- do_sample=False: 決定論的生成(毎回同じ結果)
- eos_token_id: 適切な停止条件
- disable_compile: PyTorchコンパイル無効化(安定性重視)
期待される出力例
入力: "Show me all employees in the IT department"
出力: SELECT * FROM employees WHERE department = 'IT';
Ollamaでの利用
1.モデルマージ
merge_model.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import os
import shutil
def merge_lora_model():
"""LoRAアダプタをベースモデルにマージし、全てのトークナイザーファイルを保存"""
# パス設定
base_model_id = "google/gemma-3-1b-pt"
adapter_path = "./results/gemma-text-to-sql"
output_dir = "./merged_gemma_text_to_sql"
try:
print("🔄 ベースモデルを読み込み中...")
base_model = AutoModelForCausalLM.from_pretrained(
base_model_id,
torch_dtype=torch.float16,
device_map="cpu",
low_cpu_mem_usage=True
)
print("🔄 ベースモデルのトークナイザーを読み込み中...")
base_tokenizer = AutoTokenizer.from_pretrained(base_model_id)
print("🔄 LoRAアダプタを読み込み中...")
model = PeftModel.from_pretrained(base_model, adapter_path)
print("🔄 モデルをマージ中...")
merged_model = model.merge_and_unload()
# 出力ディレクトリ作成
os.makedirs(output_dir, exist_ok=True)
print(f"💾 マージされたモデルを保存中: {output_dir}")
merged_model.save_pretrained(
output_dir,
safe_serialization=True,
max_shard_size="2GB"
)
print(f"💾 ベースモデルのトークナイザーを保存中: {output_dir}")
base_tokenizer.save_pretrained(output_dir)
# tokenizer.modelを手動でダウンロード
print("⬇️ tokenizer.modelを手動でダウンロード中...")
from huggingface_hub import hf_hub_download
tokenizer_model_path = hf_hub_download(
repo_id=base_model_id,
filename="tokenizer.model",
local_dir=output_dir,
local_dir_use_symlinks=False
)
print("✅ マージ完了!")
return True
except Exception as e:
print(f"❌ エラーが発生しました: {e}")
return False
if __name__ == "__main__":
merge_lora_model()
2. GGUF変換(不要)
変換実行
git clone https://github.com/ggerganov/llama.cpp
python C:\path\to\llama.cpp\convert_hf_to_gguf.py ./merged_gemma_text_to_sql --outfile gemma_text_to_sql.gguf
3.Ollamaモデルファイル作成
Modelfile
# Gemma Text-to-SQL用
FROM ./merged_gemma_text_to_sql
# Gemmaモデル専用のテンプレート
TEMPLATE """{{ if .System }}<start_of_turn>system
{{ .System }}<end_of_turn>
{{ end }}{{ if .Prompt }}<start_of_turn>user
{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ end }}"""
# 停止条件の設定
PARAMETER stop "<start_of_turn>"
PARAMETER stop "<end_of_turn>"
# システムプロンプト
SYSTEM """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."""
4.Ollamaでの使用
モデル登録
ollama create gemma-text-to-sql -f Modelfile
推論実行
ollama run gemma-text-to-sql "Given schema: CREATE TABLE users (id INT, name VARCHAR(50)); Question: Show all users"
まとめ
WindowsローカルPC環境でGemmaモデルのQLoRAファインチューニングを実行し、Text-to-SQLタスクに特化させる完全な手順を説明しました。