1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LoRAを使ってSupervised fine-tune

Posted at

ここまでの経緯

前回まではこちら

大規模言語モデル2024という東京大学の松尾・岩澤研の講座に参加して、自分で調べたことをここにも記しています。


大規模言語モデルはこちら

いや、マジで神講座でしたよ。
2025年も可能なら参加したい!


ちなみに、今週末にここで講座に参加した顛末記をお話ししますのでご興味ある方はご参加ください(現地開催のみ。Webはありません)



Supervised fine tuneとは

前回までの継続事前学習後の出力って、だらだらと文章を作るだけです。
例えばこんな感じ

「石破氏は」と入力したときの継続事前学習後の出力
# 石破茂氏は「自民党総裁選に出馬するにあたり、国民の皆様にお約束したいこと」と題した動画を公開。冒頭で「私は、石破新政権が発足したら、直ちに衆議院解散総選挙を行うべきだと考えています」とし、その理由として(1)政治改革・行政改革への取り組み姿勢に対する評価を、選挙という形でお示しいただきたいから(2)政策論争をしっかり行い、その上で、信任していただけるかどうかをお決めいただくため——などを挙げた。

そう、「石破氏は」という言葉(トークン)に続く言葉(トークン)を出力するだけです。
これではChatGPTのようなチャット形式にはなりませんし、1~100の点数をつけるなどのタスクには適応できません。
そこで、行うのがSupervised fine tune、略してSFTです。(かっこいい。w)

Supervised fine tune 後の出力イメージ(QAタスクの場合)
input = '石破茂氏は何をしている人ですか?'
# <<省略>>
# 石破茂氏は内閣総理大臣です。

こんなイメージ



参考ページ

講座では配布されたサンプルコードを使って色々と工夫をしましたが、サンプルコード自体の著作権は当然松尾・岩澤研にあると思いますので、代わりにunslothのサンプルプログラムを例にお話ししていきます。というか、UnslothのHP、継続事前学習はドキュメントページがあるのに、SFTはないんだよね。

とはいえ、unslothさん神ってるよね。

基本は継続事前学習と同じですが、学習うに使うクラスがSFTTrainerだったり、target_modulesが異なっていたりします。


さらに、東京大学の松尾・岩澤研のPaper&Hacksの動画

この動画は10回位見てもいいと思う。

じゃ、コード解説に行きましょう!



コード解説

環境構築

%%capture
!pip install unsloth
# Also get the latest nightly Unsloth!
!pip install --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git

公式もUnslothをインストールしてから再インストールしてますね。
最新を入れるため!なんて書いてますが、それだけではなさそう。www


モデルの設定

ライブラリのインポート
from unsloth import FastLanguageModel
import torch
モデルダウンロードにおける設定
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

トークン数を2024に、dtypeをNoneにしておくと自動でフロートの形式を選んでくれます。Autoとかにすればいいのに。w
もう一つは4ビットで量子化する設定

モデルのダウンロード
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-Instruct", # or choose "unsloth/Llama-3.2-1B-Instruct"
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

ここのmodel_nameに継続事前学習したモデルを指定すると新しい知識を持ったモデルがダウンロードできます。モデルがprivate設定の場合はtoken=の次にHuggingFaceのトークンを設定してください。


LoRAモデルを作る
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

このLoRAモデルを学習していきます。ベースモデルは学習しません。
target_modulesに設定した層は学習されます。継続事前学習の時とは違い、embed_tokenslm_headは指定しないみたい。

各モジュールの意味はこちらで。

実は大切なのはここまでです。


データセットの用意

データを処理するための関数を定義
from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
)

def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }
pass

どうやら、llama-3.1のフォーマットで成形しなおす事を定義しているっぽいです。
また今度、確認してみよう。


学習データのダウンロード
from datasets import load_dataset
dataset = load_dataset("mlabonne/FineTome-100k", split = "train")

この中身をよく見ないとformatting_prompts_funcがよくわかんないですね。


formatting_prompts_funcで定義した処理の実行
from unsloth.chat_templates import standardize_sharegpt
dataset = standardize_sharegpt(dataset)
dataset = dataset.map(formatting_prompts_func, batched = True,)

学習データの成形を実行


Trainerの設定
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none", # Use this for WandB etc
    ),
)

ここではtrlのSFTTrainerクラスをつかいます。
別に他のUnslothのクラスも使えるらしいです。


多分、入出力の定義(僕自身は使っていない)
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
    response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)

学習の実行

trainer_stats = trainer.train()


課題

上記でチャット形式で回答できるようなデータセットで学習すれば、ちゃんとQ&A回答してくれます。

ですが、課題もあります。継続事前学習で覚えた内容をSFTによって簡単に忘却してしまいます。

考えられる原因は

  • そもそも継続事前学習のデータがきれいではなく、学習が精度よく行われていない可能性
  • SFTのデータセットによって継続事前学習で入れた知識が上書きされてしまっている可能性

論文も探していますが、なかなかそういう論文はないみたい。
だれか研究してください。お願い!



終わりに

東京大学 松尾・岩澤研の大規模言語モデル2024の講座に参加できたことは非常に有益でした。いや、有益すぎる。
講座はめちゃくちゃわかりやすかったです。僕が大学時代に受けた授業はいったいなんだったんだ?ってマジで思ってます。
社会人にも一部講座は受講可能なので、興味があれば迷わず申し込んだ方が良いと思います。

講師の皆さん、ともに参加した皆さんに心の底から感謝しまくりです。

皆さんもぜひ!

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?