0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LLMファインチューニング(SFT)

0
Posted at

軽量LLMを教師ありファインチューニング(SFT)する流れを整理しました。
GPUなしローカルマシンのVSCodeでデバッグ実行(ステップイン実行)して動作を追えるように、CPU環境を前提として作成しました。

目次

  1. SFTデータセット構築
  2. SFT実行(SFT済モデル構築)
  3. SFT済モデルの利用

環境構築

仮想環境を構築

uv init
uv run main.py

依存関係をインストール

uv add ipykernel transformers datasets trl peft

私の環境での依存関係(project.toml

[project]
name = "sft"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "datasets>=4.5.0",
    "ipykernel>=7.2.0",
    "peft>=0.18.1",
    "transformers>=5.1.0",
    "trl>=0.28.0",
]

1. SFTデータセット構築

1.1. 方針

  • SFT実行のやり方のお勉強が目的なので、データセットはごく小規模のものを構築する
    • 今回は、10件とする
  • 同様の理由で、データセットの中身も適当に作る
    • 今回は、語尾を「だぬ」に変える
  • データセットは、HuggingFaceのDataset形式で保存する
    • 今回は、ローカル(ディスク)に保存する

1.2. 設計

①CSVで入力(input)と正解回答(output)のペアを10件つくる
②CSVをJSONLに変換する
③JSONLをHuggingFaceHubのDataset形式でローカルディスクに保存する
④保存したデータセットがロードできるか検証する

1.3. 実装例

①CSV

input,output
日本の首都は?,東京だぬ。
1+1は?,2だぬ。
富士山は何県にある?,静岡県と山梨県にまたがっているだぬ。
水の化学式は?,H2Oだぬ。
日本の通貨は?,円だぬ。
太陽は何の星?,恒星だぬ。
寿司の主な材料は?,米と魚介類だぬ。
Pythonは何のための言語?,プログラミングのための言語だぬ。
春夏秋冬のうち、冬の次は?,春だぬ。
地球は何番目の惑星?(太陽から),3番目だぬ。

②JSONL(sft_10.jsonl

{"input": "日本の首都は?", "output": "東京だぬ。"}
{"input": "1+1は?", "output": "2だぬ。"}
{"input": "富士山は何県にある?", "output": "静岡県と山梨県にまたがっているだぬ。"}
{"input": "水の化学式は?", "output": "H2Oだぬ。"}
{"input": "日本の通貨は?", "output": "円だぬ。"}
{"input": "太陽は何の星?", "output": "恒星だぬ。"}
{"input": "寿司の主な材料は?", "output": "米と魚介類だぬ。"}
{"input": "Pythonは何のための言語?", "output": "プログラミングのための言語だぬ。"}
{"input": "春夏秋冬のうち、冬の次は?", "output": "春だぬ。"}
{"input": "地球は何番目の惑星?(太陽から)", "output": "3番目だぬ。"}

③JSONLをHuggingFaceHubのDataset形式でローカルディスクに保存する

from pathlib import Path

import datasets
from datasets import Dataset, DatasetDict

# 出力Dir作成
dataset_dir = Path("sft_dataset")
dataset_dir.parent.mkdir(parents=True, exist_ok=True)

# HF形式データセット作成
dataset: Dataset = datasets.load_dataset(
    path="json",
    data_files="sft_10.jsonl",
    split="train",
)
# ローカルディスクに保存
dataset.save_to_disk(str(dataset_dir))

データセットが.arrow形式で保存されたことを確認する。
VSCodeでは.arrowの中身を表示できないので、以下のようにロードしてプリントする。

import datasets
from datasets import Dataset, DatasetDict

# ロード
_ds: Dataset | DatasetDict = datasets.load_from_disk(str(dataset_dir))

# データ型が元形式(Dataset)と一致するか確認
print(f"データ型: {type(_ds)}")
# -> <class 'datasets.arrow_dataset.Dataset'>
# と出力されることを確認

# データ数を表示
print(f"データ件数: {_ds.num_rows}")
# -> 10
# と出力されることを確認

# 先頭行を表示
print(f"先頭行: {_ds[0]}")
# -> {"input": "日本の首都は?", "output": "東京だぬ。"}
# と出力されることを確認

2. SFT実行(SFT済モデル構築)

2.1. 方針

  • お勉強が目的なので、ローカルのVSCodeでデバッグ実行可能としたい
    • 今回は、ローカルにGPUがないので、CPUで実行可能な手法を採用する
  • CPUでロード・学習可能な軽量モデル・量子化・LoRAを採用する
    • 今回は、"Qwen/Qwen2.5-0.5B-Instruct" を使用する
    • 今回は、4bit量子化をする
    • 今回は、LoRAアダプタのみをローカルに保存する
  • SFT用データセットは、先ほど作成したローカル保存済のものを使用する

2.2. 設計

流れ:

  1. データセットを読込
  2. トークナイザを読込
  3. トークナイザを用いてデータセットの入出力をチャットテンプレート化
    1. よくあるmessages形式(roleとcontentをキーに持つ辞書のリスト)に変換
    2. applay_caht_templateで境界トークンや生成トークンなどの特殊トークンを自動追加
  4. ベースモデル読込 + LoRA設定 + 引数設定 → トレーナー初期化
    1. base = transformers.AutoModelForCausalLM.from_pretrained(モデルパス, ...)
    2. lora_cfg = peft.LoraConfig(r=8, α=16, ...)
    3. args = transformers.TrainingArguments(学習率, バッチサイズ, ...)
    4. trainer = trl.SFTTrainer(base, lora_cfg, args)
  5. 学習
    1. trainer.train()
  6. 学習結果を保存(LoRAアダプタを保存
    1. trainer.model.save_pretrained(...)
    2. tokenizer.save_pretrained(...)

2.3. 実装例

import os
from typing import Any

import datasets
import peft
import torch
import transformers
from datasets import Dataset, DatasetDict
from datasets.formatting.formatting import LazyRow
from transformers import AutoTokenizer
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from trl.trainer.sft_trainer import SFTTrainer

MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"


### 1. データセット
ds: Dataset
_ds: Dataset | DatasetDict = datasets.load_from_disk("sft_dataset")
if isinstance(_ds, DatasetDict):
    ds = _ds["train"]
elif isinstance(_ds, Dataset):
    ds = _ds
else:
    raise ValueError("データセットの型が期待と異なります")


### 2. トークナイザー
tokenizer: Qwen2Tokenizer
_tok: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if isinstance(_tok, Qwen2Tokenizer):
    tokenizer = _tok
    print("期待通り、Qwen2のトークナイザーがロードされました")
else:
    raise ValueError("期待したトークナイザーの型ではありません")


### 3. チャットテンプレート化
def _format_with_chat_template(example: LazyRow) -> dict[str, str]:
    """
    データセットの1行(Row)を受け取って、チャットテンプレート形式のテキストに変換する関数

    - 入力: {"input": "日本の首都は?", "output": "東京だぬ。"}

    - 出力: {"text": "<|im_start|>system
                       You are Qwen2 ~~ .
                       <|im_end|>

                       <|im_start|>user
                       日本の首都は?
                       <|im_end|>

                       <|im_start|>assistant
                       東京だぬ。
                       <|im_end|>"
            }
    """
    # デバッグ用に中身を表示
    print(example)
    # => {'input': '日本の首都は?', 'output': '東京だぬ。'}
    # のように、データセットの1行(Row)が表示されることを確認

    # よくある role と content からなるチャット形式に変換
    messages = [
        {"role": "user", "content": example["input"]},
        {"role": "assistant", "content": example["output"]},
    ]

    # Qwen系はchat template対応している
    # 境界トークン(<|im_start|>...<|im_end|>)を含む形に整形してくれる
    text: str | Any = tokenizer.apply_chat_template(
        messages,
        tokenize=False,  # トークンID化するか
        add_generation_prompt=False,  # 生成開始トークンを末尾に追加するか
    )
    if not isinstance(text, str):
        raise ValueError("apply_chat_templateの出力が期待と異なります")

    return {"text": text}


# チャットテンプレート形式のテキストに変換
# dataset.mapを使うと裏で並列化してくれて高速らしい
ds_text: Dataset = ds.map(_format_with_chat_template, remove_columns=ds.column_names)
print(ds_text[0]["text"])
# "<|im_start|>system" や "<|im_end|>" などの境界トークンが追加されていることを確認
# あとは勝手にシステムプロンプト("You are Qwen ~~")も追加されていることを確認

# CPUスレッド数
torch.set_num_threads(os.cpu_count() or 4)
print(torch.get_num_threads())  # -> 16

### 4. モデル準備とトレーニング設定
# ベースモデル
model: PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
if isinstance(model, Qwen2ForCausalLM):
    print("期待通り、Qwen2のモデルがロードされました")
else:
    print("想定外のモデルがロードされました")

# LoRA設定
lora_config = peft.LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

# 引数
args = transformers.TrainingArguments(
    output_dir="./sft_qwen25_05b",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    num_train_epochs=2,
    logging_steps=1,
    save_steps=10,
    save_total_limit=2,
    fp16=False,
    bf16=False,
)

# トレーナー
trainer = SFTTrainer(
    model=model,
    train_dataset=ds_text,
    peft_config=lora_config,
    args=args,
)

### 5. 学習
# CPUだと80分かかった
trainer.train()

### 6. 保存
# LoRAアダプタのみ保存(+tokenizer)
adapter_dir = "qwen2p5_05b_danu_lora"
trainer.model.save_pretrained(adapter_dir)
tokenizer.save_pretrained(adapter_dir)
print("saved:", adapter_dir)

3. SFT済モデルの利用

from peft.peft_model import PeftModelForCausalLM
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer

## 1. トークナイザとベースモデルを読み込む
base = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
tok: Qwen2Tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model_sft: PeftModelForCausalLM = peft.PeftModel.from_pretrained(base, adapter_dir)
model_sft.eval()  # 推論モード

## 2. チャットテンプレート形式のプロンプトを作成
query = "日本の首都は?"
messages = [{"role": "user", "content": query}]
prompt: str = tok.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

## 3. プロンプトを文字列からトークンID列に変換(トークナイズ)
inputs = tok(prompt, return_tensors="pt")

## 4. 推論
with torch.no_grad():
    out = model_sft.generate(
        **inputs,
        max_new_tokens=32,  # 小トークンで高速化
        do_sample=False,  # 再現性向上
        temperature=0.0,  # 再現性向上
    )
answer: str = tok.decode(out[0], skip_special_tokens=True)

print("質問:", query)
print("回答:", answer)

print結果:

質問: 日本の首都は?
回答: system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
user
日本の首都は?
assistant
日本の首都は東京です。

→「だぬ」で終わっておらず、SFTの効果はほとんどなかった。
データ数を10件から1000件くらいに増やせば語尾が「だぬ」に変わるかも。

今回は、CPU環境を前提として、必要最小限のデータセット、軽量なモデル、シンプルなコード(型ヒントは多め)でSFTの実行方法を理解した。
特に、ds_text = dataset.map(format_with_chat_template) のように、mapメソッドを使ってデータセット全件をチャットテンプレート形式に変換する必要があることをよく理解した。
また、transformers, peft, trlなど各種ライブラリの役割分担・使いどころについても理解することができた。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?