0
2

ReFTでファインチューニングして保存したモデルがロードできない時の解決方法

Last updated at Posted at 2024-08-31

LLMのファインチューニングを効率的に行うことのできるReFTという手法を、Pythonのpyreftライブラリでやってみました。
ファインチューニングしたモデルを保存するところまでは難なくできたのですが、その保存したモデルをロードするところでエラーが発生し、解決までにしばらく時間を要してしまったので、その時の備忘録です。

起きていたエラー

公式の方法と同様にファインチューニングしたモデルを以下のように読み込もうとしました。ちなみに環境としてはGoogle Colabを使っていました。インストールしたpyreftのバージョンとしてはpyreft==0.0.6でした。

import torch, transformers, pyreft
device = "cuda"

model_name_or_path = "microsoft/Phi-3-mini-4k-instruct"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True
)

model_path = "./finetuned-model"
reft_model = pyreft.ReftModel.load(model_path, model)

そうしたらなぜか以下のようにキーエラーが出てしまった。

KeyError: <class 'transformers_modules.microsoft.Phi-3-mini-4k-instruct.5a516f86087853f9d560c95eb9209c1d4ed9ff69.modeling_phi3.Phi3ForCausalLM'>

解決方法

まず、ファインチューニング元となった事前学習モデルを再び用意し、ファインチューニング時のReFTの設定とともに、ReFTモデルを作成します。
そして、torch.loadやtorch.saveを使って、ファインチューニングで保存した重みを手動で読み込ませます。
ReFTはニューラルネットワーク構造の中間層に介入(intervention)していますが、最後にその介入情報をkey-valueペアとしてreft_model.interventionsに保存してあげます。

具体的なコードは以下のようになります。

import os
import torch, transformers, pyreft
device = "cuda"

#ReFTを利用してファインチューニングしたときに使った、元々の事前学習モデルをロードする。
model_name_or_path = "microsoft/Phi-3-mini-4k-instruct" 
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True
)

#ファインチューニングで使ったPyReFTの設定を定義する。
layers = range(model.config.num_hidden_layers)
representations = [{
    "component": f"model.layers[{l}].output",
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=4
    )
} for l in layers]

reft_config = pyreft.ReftConfig(representations=representations)

# 元々の事前学習モデルとPyReFTの設定からReFTモデルを作る。
reft_model = pyreft.get_reft_model(model, reft_config)

##### ここまではファインチューニング前と共通の下準備。 #####
##### ここからがファインチューニングして保存したモデルをロードする部分。 #####

# 保存したReFTモデルを読み込む。
save_directory = f"./finetuned-model"
interventions = {}
for l in layers:
    component = f"model.layers[{l}].output"
    file_path = os.path.join(save_directory, f"intkey_comp.{component}.unit.pos.nunit.1#0.bin")
    if os.path.exists(file_path):
        with open(file_path, "rb") as f:
            adjusted_key = f"comp.{component}.unit.pos.nunit.1#0"
            interventions[adjusted_key] = torch.load(f)

# torchで手動で読み込んだ重みをReFTモデルに適用する。
for component, state_dict in interventions.items():
    if component in reft_model.interventions:
        reft_model.interventions[component][0].load_state_dict(state_dict)
    else:
        print(f"キーが見つかりません: {component} は reft_model.interventions 内にありませんでした。")

なぜ公式のようにすんなりロードができないのかは謎です。
まだ新しいライブラリだから変な挙動を起こしているだけ??
まあ個人的にはReFTの仕組みの理解の手助けにもなって良かったかもです〜

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2