7
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Phi-3-miniをFine-Tuningしてみる。

Last updated at Posted at 2024-06-10

はじめに

以前から気になっていたFine-Tuningをやってみます。
色々ありますが、今回はReFTという手法で簡単にPhi-3-miniをFine-Tuningしていきます。

環境はGoogle ColaboratoryのT4で実行しています。

ReFT(Representation Finetuning)とは?

ReFTは、事前学習されたモデルの表現力を最大限に活かしながら、特定のタスクに適応させるための微調整技術で、以下のような特徴を持っています。

  • 柔軟性: 様々なモデルやデータセットに適用可能
  • 効率性: 微調整のための計算コストが低い
  • 性能向上: 特定のタスクに対して顕著な性能向上を実現

LoRAに近い手法のようです。

Fine-Tuningの実行

step-by-step guideの通りに進めていきます。

pyreftのインストール

!pip install git+https://github.com/stanfordnlp/pyreft.git

※colabだとここで再起動しないといけませんでした。

次にモデルとトークナイザを読み込みます。

import torch, transformers, pyreft

prompt_no_input_template = """<|user|>
%s
<|end|>
<|assistant|>
"""
device = "cuda"

model_name_or_path = "microsoft/Phi-3-mini-4k-instruct"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True
)

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=2048, 
    padding_side="right",
    use_fast=False
)
tokenizer.pad_token = tokenizer.unk_token

ReFTモデルの作成

reft_config = pyreft.ReftConfig(representations={
    "layer": 15,
    "component": "model.layers[15].output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=4
    )
})
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device(device)
reft_model.print_trainable_parameters()

データセットの用意
今回はChatGPTに会社のHP(https://www.haw.co.jp )を渡して作ってもらいました。

training_examples = [
    ['会社名は?', '私たちは株式会社ハウインターナショナルです。'],
    ['本社の所在地は?', '福岡県飯塚市幸袋576-14 e-ZUKAトライバレーセンターB211号室です。'],
    ...
]

data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer, model, [prompt_no_input_template % e[0] for e in training_examples], 
    [e[1] for e in training_examples]
)

学習の実行

training_args = transformers.TrainingArguments(
    num_train_epochs=100.0,
    output_dir="./tmp",
    per_device_train_batch_size=10, 
    learning_rate=4e-3,
    logging_steps=20,
    report_to=[]
)
trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model,
    tokenizer=tokenizer,
    args=training_args,
    **data_module
)
_ = trainer.train()

データが少ないので10分ぐらいで終わりました。
スクリーンショット 2024-05-26 6.26.11.png

結果の確認

instruction = "会社名は?"

prompt = prompt_no_input_template % instruction
prompt = tokenizer(prompt, return_tensors="pt").to(device)

base_unit_location = prompt["input_ids"].shape[-1] - 1
_, reft_response = reft_model.generate(
    prompt,
    unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True, max_new_tokens=512,
    do_sample=True, 
    eos_token_id=tokenizer.eos_token_id,
    early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))
<|user|> 会社名は?
 <|end|> <|assistant|> 私たちは株式会社ハウインターナショナルです。

うまく学習できてますね。

ちなみに学習前は以下のように訳のわからない出力がされました。

私の名前は田中太郎で、東京に本社を置く電子部品メーカーです。私たちは...

他の質問でもちゃんとできているようです。

 <|user|> 会社の所在地は?
 <|end|> <|assistant|> 福岡県飯塚市幸袋576-14 e-ZUKAトライバレーセンターです。 <|end|>
<|user|> どんなサービスを提供していますか?
 <|end|> <|assistant|> 各種システム開発、ブロックチェーン技術の開発、クラウド関係サービスを提供しています。
<|user|> 会社の設立日は?
 <|end|> <|assistant|> 1999年7月です。

まとめ

簡単にですがFine-Tuningをやってみました。
最新情報に対応できない、データセットを用意するのが大変といったデメリットがありますが、狙った回答を出させるような事が容易にできるのが良いと感じました。
基本的にRAGを使った方が良さそうですが、前提知識としたい場合はこちらが良さそうです。

7
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?