0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

GoogleのGemmaモデルをQLoRAファインチューニングしてみた

Last updated at Posted at 2025-07-05

元ネタはGoogel AI for Developersの記事になります。
この記事を参考にローカルのWindows環境で実行した備忘録です。
Hugging Face Transformers と QloRA を使用して Gemma をファインチューニングする

概要

GoogleのGemmaモデルをQLoRA(Quantized Low-Rank Adaptation)手法を使用してWindowsローカル環境でファインチューニングし、Text-to-SQLタスクに特化させる方法を説明します。また、ファインチューニング済みモデルをOllamaで利用する方法も含まれています。

プロジェクト概要

使用技術

  • ベースモデル: Google Gemma-3-1b-pt
  • ファインチューニング手法: QLoRA(4bit量子化 + LoRA)
  • データセット: philschmid/gretel-synthetic-text-to-sql(10万件超の合成Text-to-SQLデータ)
  • フレームワーク: Hugging Face Transformers, TRL, PEFT

QLoRAの利点

  • メモリ効率: 4bit量子化により大幅なメモリ削減
  • 高速学習: アダプターレイヤーのみ学習で高速化
  • 性能維持: ベースモデルの性能を維持しながら効率化

環境セットアップ

ハードウェア要件

  • GPU: NVIDIA RTX 3060 (12GB VRAM) 以上推奨(最低8GB)
  • RAM: 16GB以上(32GB推奨)
  • ストレージ: 20GB以上の空き容量

ソフトウェア要件

  • Windows 10/11
  • Python 3.8-3.11
  • CUDA 11.8以上 + cuDNN
  • Git

環境構築手順

1.Python環境の準備

python -m venv gemma-finetune
gemma-finetune\Scripts\activate 

2.PyTorchのインストール

pip install torch>=2.4.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install tensorboard
pip install "transformers>=4.51.3"
pip install tokenizers
pip install "datasets==3.3.2" "accelerate==1.4.0" "evaluate==0.4.3" "bitsandbytes==0.45.3" "trl==0.15.2" "peft==0.14.0" protobuf

学習データについて

philschmid/gretel-synthetic-text-to-sql データセット

データセット概要

  • 総レコード数: 105,851件(訓練用100,000件、テスト用5,851件)
  • 総トークン数: 約2,300万トークン(SQLトークン約1,200万)
  • 対象領域: 100の異なる業界ドメイン
  • ライセンス: Apache 2.0

データ構造

{
  "sql_prompt": "各州の病院ベッド数の合計は?",
  "sql_context": "CREATE TABLE Beds (State VARCHAR(50), Beds INT); INSERT INTO...",
  "sql": "SELECT State, SUM(Beds) FROM Beds GROUP BY State;",
  "sql_explanation": "このクエリは各州の病院ベッド数の合計を計算します...",
  "domain": "public health",
  "sql_complexity": "aggregation"
}

ファインチューニング実行

train.py
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from huggingface_hub import login

# Hugging Face認証
login("your_token_here")

# GPU設定確認
device = torch.cuda.current_device()
print(f"Using GPU: {torch.cuda.get_device_name(device)}")

# モデル設定
model_id = "google/gemma-3-1b-pt"
torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

# 量子化設定
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_quant_storage=torch_dtype,
)

# モデルとトークナイザー読み込み
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    attn_implementation="eager",
    torch_dtype=torch_dtype,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

# データセット準備
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command.

<SCHEMA>
{context}
</SCHEMA>

<USER_QUERY>
{question}
</USER_QUERY>
"""

def create_conversation(sample):
    return {
        "messages": [
            {"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])},
            {"role": "assistant", "content": sample["sql"]}
        ]
    }

# データセット読み込み
dataset = load_dataset("philschmid/gretel-synthetic-text-to-sql", split="train")
dataset = dataset.shuffle().select(range(12500))
dataset = dataset.map(create_conversation, remove_columns=dataset.features, batched=False)
dataset = dataset.train_test_split(test_size=2500/12500)

# LoRA設定
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"]
)

# トレーニング設定
args = SFTConfig(
    output_dir="./results/gemma-text-to-sql",
    max_seq_length=512,
    packing=True,
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=2e-4,
    fp16=True if torch_dtype == torch.float16 else False,
    bf16=True if torch_dtype == torch.bfloat16 else False,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    push_to_hub=False,
    report_to="tensorboard",
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": True,
    }
)

# トレーナー作成・実行
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    peft_config=peft_config,
    processing_class=tokenizer
)

# トレーニング実行
trainer.train()
trainer.save_model()

実行方法

python train.py

モデル保存場所

  • 保存先: ./results/gemma-text-to-sql/
  • 内容: LoRAアダプタファイル(数十MB)
  • 注意: ベースモデル全体ではなくアダプタのみ保存

推論とテスト

inference.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from random import randint
import re

# モデル読み込み
model_id = "./results/gemma-text-to-sql"
torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch_dtype,
    attn_implementation="eager",
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# パイプライン作成
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# テスト例
test_message = {
    "messages": [
        {
            "role": "user", 
            "content": """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
CREATE TABLE employees (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    department VARCHAR(50),
    salary DECIMAL(10,2)
);
</SCHEMA>

<USER_QUERY>
Show me all employees in the IT department
</USER_QUERY>
"""
        }
    ]
}

# 推論実行
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(test_message["messages"], tokenize=False, add_generation_prompt=True)

outputs = pipe(
    prompt, 
    max_new_tokens=256, 
    do_sample=False, 
    temperature=0.1, 
    top_k=50, 
    top_p=0.1, 
    eos_token_id=stop_token_ids,
    disable_compile=True
)

print("Generated SQL:")
print(outputs[0]['generated_text'][len(prompt):].strip())

生成パラメータ説明

  • max_new_tokens=256: 最大256トークン生成
  • do_sample=False: 決定論的生成(毎回同じ結果)
  • eos_token_id: 適切な停止条件
  • disable_compile: PyTorchコンパイル無効化(安定性重視)

期待される出力例

入力: "Show me all employees in the IT department"
出力: SELECT * FROM employees WHERE department = 'IT';

Ollamaでの利用

1.モデルマージ

merge_model.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import os
import shutil

def merge_lora_model():
    """LoRAアダプタをベースモデルにマージし、全てのトークナイザーファイルを保存"""
    
    # パス設定
    base_model_id = "google/gemma-3-1b-pt"
    adapter_path = "./results/gemma-text-to-sql"
    output_dir = "./merged_gemma_text_to_sql"
    
    try:
        print("🔄 ベースモデルを読み込み中...")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            device_map="cpu",
            low_cpu_mem_usage=True
        )
        
        print("🔄 ベースモデルのトークナイザーを読み込み中...")
        base_tokenizer = AutoTokenizer.from_pretrained(base_model_id)
        
        print("🔄 LoRAアダプタを読み込み中...")
        model = PeftModel.from_pretrained(base_model, adapter_path)
        
        print("🔄 モデルをマージ中...")
        merged_model = model.merge_and_unload()
        
        # 出力ディレクトリ作成
        os.makedirs(output_dir, exist_ok=True)
        
        print(f"💾 マージされたモデルを保存中: {output_dir}")
        merged_model.save_pretrained(
            output_dir,
            safe_serialization=True,
            max_shard_size="2GB"
        )
        
        print(f"💾 ベースモデルのトークナイザーを保存中: {output_dir}")
        base_tokenizer.save_pretrained(output_dir)
        
        # tokenizer.modelを手動でダウンロード
        print("⬇️ tokenizer.modelを手動でダウンロード中...")
        from huggingface_hub import hf_hub_download
        tokenizer_model_path = hf_hub_download(
            repo_id=base_model_id,
            filename="tokenizer.model",
            local_dir=output_dir,
            local_dir_use_symlinks=False
        )
        
        print("✅ マージ完了!")
        return True
        
    except Exception as e:
        print(f"❌ エラーが発生しました: {e}")
        return False

if __name__ == "__main__":
    merge_lora_model()

2. GGUF変換(不要)

変換実行

git clone https://github.com/ggerganov/llama.cpp
python C:\path\to\llama.cpp\convert_hf_to_gguf.py ./merged_gemma_text_to_sql --outfile gemma_text_to_sql.gguf

3.Ollamaモデルファイル作成

Modelfile
# Gemma Text-to-SQL用
FROM ./merged_gemma_text_to_sql

# Gemmaモデル専用のテンプレート
TEMPLATE """{{ if .System }}<start_of_turn>system
{{ .System }}<end_of_turn>
{{ end }}{{ if .Prompt }}<start_of_turn>user
{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ end }}"""

# 停止条件の設定
PARAMETER stop "<start_of_turn>"
PARAMETER stop "<end_of_turn>"

# システムプロンプト
SYSTEM """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."""

4.Ollamaでの使用

モデル登録

ollama create gemma-text-to-sql -f Modelfile

推論実行

ollama run gemma-text-to-sql "Given schema: CREATE TABLE users (id INT, name VARCHAR(50)); Question: Show all users"

まとめ

WindowsローカルPC環境でGemmaモデルのQLoRAファインチューニングを実行し、Text-to-SQLタスクに特化させる完全な手順を説明しました。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?