LLMのファインチューニングを効率的に行うことのできるReFTという手法を、Pythonのpyreft
ライブラリでやってみました。
ファインチューニングしたモデルを保存するところまでは難なくできたのですが、その保存したモデルをロードするところでエラーが発生し、解決までにしばらく時間を要してしまったので、その時の備忘録です。
起きていたエラー
公式の方法と同様にファインチューニングしたモデルを以下のように読み込もうとしました。ちなみに環境としてはGoogle Colabを使っていました。インストールしたpyreft
のバージョンとしてはpyreft==0.0.6
でした。
import torch, transformers, pyreft
device = "cuda"
model_name_or_path = "microsoft/Phi-3-mini-4k-instruct"
model = transformers.AutoModelForCausalLM.from_pretrained(
model_name_or_path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True
)
model_path = "./finetuned-model"
reft_model = pyreft.ReftModel.load(model_path, model)
そうしたらなぜか以下のようにキーエラーが出てしまった。
KeyError: <class 'transformers_modules.microsoft.Phi-3-mini-4k-instruct.5a516f86087853f9d560c95eb9209c1d4ed9ff69.modeling_phi3.Phi3ForCausalLM'>
解決方法
まず、ファインチューニング元となった事前学習モデルを再び用意し、ファインチューニング時のReFTの設定とともに、ReFTモデルを作成します。
そして、torch.loadやtorch.saveを使って、ファインチューニングで保存した重みを手動で読み込ませます。
ReFTはニューラルネットワーク構造の中間層に介入(intervention)していますが、最後にその介入情報をkey-valueペアとしてreft_model.interventionsに保存してあげます。
具体的なコードは以下のようになります。
import os
import torch, transformers, pyreft
device = "cuda"
#ReFTを利用してファインチューニングしたときに使った、元々の事前学習モデルをロードする。
model_name_or_path = "microsoft/Phi-3-mini-4k-instruct"
model = transformers.AutoModelForCausalLM.from_pretrained(
model_name_or_path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True
)
#ファインチューニングで使ったPyReFTの設定を定義する。
layers = range(model.config.num_hidden_layers)
representations = [{
"component": f"model.layers[{l}].output",
"intervention": pyreft.LoreftIntervention(
embed_dim=model.config.hidden_size,
low_rank_dimension=4
)
} for l in layers]
reft_config = pyreft.ReftConfig(representations=representations)
# 元々の事前学習モデルとPyReFTの設定からReFTモデルを作る。
reft_model = pyreft.get_reft_model(model, reft_config)
##### ここまではファインチューニング前と共通の下準備。 #####
##### ここからがファインチューニングして保存したモデルをロードする部分。 #####
# 保存したReFTモデルを読み込む。
save_directory = f"./finetuned-model"
interventions = {}
for l in layers:
component = f"model.layers[{l}].output"
file_path = os.path.join(save_directory, f"intkey_comp.{component}.unit.pos.nunit.1#0.bin")
if os.path.exists(file_path):
with open(file_path, "rb") as f:
adjusted_key = f"comp.{component}.unit.pos.nunit.1#0"
interventions[adjusted_key] = torch.load(f)
# torchで手動で読み込んだ重みをReFTモデルに適用する。
for component, state_dict in interventions.items():
if component in reft_model.interventions:
reft_model.interventions[component][0].load_state_dict(state_dict)
else:
print(f"キーが見つかりません: {component} は reft_model.interventions 内にありませんでした。")
なぜ公式のようにすんなりロードができないのかは謎です。
まだ新しいライブラリだから変な挙動を起こしているだけ??
まあ個人的にはReFTの仕組みの理解の手助けにもなって良かったかもです〜