5
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

SFTTrainer を使って、簡単に CausalLM をファインチューニングをしよう

Last updated at Posted at 2024-07-27

SFTTrainer を使って、簡単に CausalLM をファインチューニングをしよう

はじめに

はじめまして。現在、私は大学院生(修士課程)です。

この記事では、TRL というライブラリを使って、簡単に Causal Language Model (CausalLM) をファインチューニングしようと思います。

この記事が「CausalLM の勉強を始めました」や「コードを書くのが楽になった」などの貢献ができるように書かせていただきます。


類似記事もありますので、是非こちらもご覧になってください。

環境構築

Python バージョン

  • Python >= 3.0.0

ライブラリ

requirements.txt
datasets
peft
torch>=2.0.0
transformers
trl

使用するモデルとデータセット

モデル

今回使用するモデルは、rinna GPT を使用します。

データセット

今回使用するデータセットは、SciQ 1 を日本語に翻訳したデータセットを使用します。

コード

ライブラリのインポート

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling
from datasets import load_dataset
from peft import LoraConfig
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer

トークナイザとモデル

ロード

cuda が使えれば、cuda の方にモデルをロードします。

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-xsmall")
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-xsmall").to(device)

モデルの構造を確認します。

model

LoRA の設定

LoRA architecture

LoRA 2 を適応したモデルの入出力は、以下の式で表せます。ここで、$W_0 \in \mathbb{R}^{d \times k}$ は凍結された事前学習済み重み行列、$B \in \mathbb{R}^{d \times r}$、$A \in \mathbb{R}^{r \times k}$とし、$r \ll \min(d, k)$です。初期化方法は、$B$ を零行列に従い初期化し、$A$ を正規分布に従い初期化します。

h = W_0 x + \Delta W x = W_0 x + B A x

LoRA を適応したとき、学習可能なパラメータ数の割合は、以下の式のように表せます。

\frac{d \times r + r \times k}{d \times k + d \times r + r \times k} = \frac{r (d + k)}{d k + r (d + k)}

LoRA を使うことで、ファインチューニングの時のコストを大幅に削減することができます。

この記事では、モデルサイズが $43.7 \mathrm{M}$ のモデルを使っているので、あまり恩恵は感じられないかもしれません。
数 $\mathrm{B}$ 程度のモデルでは、LoRA の恩恵を受けられると思います。

今回は、全ての線形層に LoRA を適応したいため、target_modules="all-linear" としました。

peft_config = LoraConfig(
    peft_type="LORA",
    task_type="CAUSAL_LM",
    r=8,
    target_modules="all-linear",
    lora_alpha=8,
    lora_dropout=0.0
)

データセット

ロード

train_dataset = load_dataset("izumi-lab/sciq-ja-mbartm2m", split="train")
eval_dataset = load_dataset("izumi-lab/sciq-ja-mbartm2m", split="validation")

学習に使用するフォーマットの作成とデータコレーター

formatting_func に、学習の時に使うプロンプトを書きます。

def formatting_func(example):
    output_texts = [
        f"# 問題: {example['question'][i]}\n#ヒント: {example['support'][i]}\n# 答え: {example['correct_answer'][i]}" for i in range(len(example))
    ]
    return output_texts
data_collator = DataCollatorForCompletionOnlyLM(
    response_template=tokenizer.encode("# 答え: ", add_special_tokens=False),
    instruction_template=tokenizer.encode("# 問題: ", add_special_tokens=False),
    tokenizer=tokenizer
)

LLaMA 2 のようないくつかのトークナイザでは、他とは異なるトークン化戦略のため、上記のような書き方の data_collator では学習ができません。

LLaMA 2 のようないくつかのトークナイザの場合は、以下のようなコードに変更して下さい。

data_collator = DataCollatorForCompletionOnlyLM(
    response_template=tokenizer.encode("\n# 答え: ", add_special_tokens=False)[2:],
    instruction_template=tokenizer.encode("# 問題: ", add_special_tokens=False),
    tokenizer=tokenizer
)

あるいは、TransformersDataCollatorForLanguageModeling を使って、以下のようなコードに変更して下さい。

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    return_tensors="pt"
)

学習

設定

SFTConfig は、TransformersTrainingArguments を継承しています。
今回は、基本的なパラメータのみをここに書いていますが、より細かく設定したい人は、調べてみて下さい。

最適化関数は、AdamW 3 がデフォルトで設定されています。

  • output_dir:チェックポイントなどの保存先のパス
  • evaluation_strategy:評価戦略
  • per_device_train_batch_size:訓練時のバッチサイズ
  • per_device_eval_batch_size:評価時のバッチサイズ
  • learning_rate:初期の学習率
  • num_train_epochs:エポック数
  • lr_scheduler_type:スケジューラーのタイプ
  • warmup_ratio:学習率を $0$ から learning_rate で指定した値までの線形ウォームアップの割合
  • logging_strategy:ロギング戦略
  • save_strategy:保存戦略
  • report_to:結果やログを記録する外部サービス

TensorBoardWandB などの設定ができていれば、損失などの記録を確認することができます。

args = SFTConfig(
    output_dir="./outputs",
    evaluation_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=5e-5,
    num_train_epochs=3.0,
    lr_scheduler_type="linear",
    warmup_ratio=0.0,
    logging_strategy="epoch",
    save_strategy="epoch",
    report_to="all"
)

Trainer の用意

trainer = SFTTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    formatting_func=formatting_func
)

Trainer の確認

この時点で、LoRA 2 を適応済みです。

trainer.model
trainer.model.print_trainable_parameters()

学習の設定は、SFTConfig で指定した値以外はデフォルトに設定されています。

trainer.args

データセットは、formatting_func で指定したフォーマットに従い、トークン化されています。

trainer.tokenizer.decode(trainer.train_dataset[0]["input_ids"])

学習開始

trainer.train()

推論

学習済みのモデルをロード

outputs/checkpoint-xxxxxx は、outputs フォルダ内の存在する整数に書き換えて下さい。

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("outputs/checkpoint-xxx")
model = AutoModelForCausalLM.from_pretrained("outputs/checkpoint-xxx").to(device)

評価用のデータセット

test_dataset = load_dataset("izumi-lab/sciq-ja-mbartm2m", split="test")

生成

with torch.no_grad():

    id = 0
    input_text = f"# 問題: {test_dataset['question'][id]}\n#ヒント: {test_dataset['support'][id]}\n# 答え: "
    
    tokenized_input_text = tokenizer(input_text, return_tensors="pt").to(model.device)
    
    tokenized_output_text_list = model.generate(**tokenized_input_text, max_new_tokens=128)
    
    output_text_list = [
        tokenizer.decode(tokenized_output_text, skip_special_tokens=True) for tokenized_output_text in tokenized_output_text_list
    ]
    
    print(output_text_list)

おわりに

この記事では、ファインチューニングを簡単にできるようにしました。また、学習したモデルを使って、推論をするため方法について書きました。

この記事が CausalLM を学習・研究する人に貢献できたら、幸いです。

参考文献

  1. Johannes Welbl, Nelson F. Liu and Matt Gardner. Crowdsourcing Multiple Choice Science Questions. In Proceedings of the 3rd Workshop on Noisy User-generated Text, 2017.

  2. Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang and Weizhu Chen. LoRA: Low-Rank Adaptation of Large Language Models. In The Tenth International Conference on Learning Representations, 2022. 2

  3. Ilya Loshchilov and Frank Hutter. Decoupled Weight Decay Regularization. In 7th International Conference on Learning Representations, 2019.

5
3
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?