はじめに
前の続きです。
注意:この記事を書いた時には、LoRAで学習したらベースモデルとマージしてHuggingfaceにアップロードs方法を知らなかったのでこの記事を書きました。今はLoRAモデルとベースのモデルをマージしてアップロードする方法がわかったので、あまり気にすることがないのかもしれません。
講座の参加規約上、書けないこともあるので、少しわかり辛いかもしれません💦
ざっくり前回の振り返りです
- コンペで新しい知識を埋め込み、応答精度を競うコンペが最終課題
- LoRAを使って継続事前学習して新しい知識を埋め込めたが、SFTをすると忘却
- そこで、継続事前学習したLoRAモデルとベースモデルを一つにマージすると出力がおかしい
- 困ったので頑張って解決(本記事)
こまったこと
上の振り返りにも書きましたが、もう少し詳しく書きます。
Unslothを用いて、LoRAでCPT(継続事前学習)をしました。
そのLoRAモデルをSFT(Supervised fine-tune)をして、質問に対する回答をするように出力調整を行いました。そうすると見事に継続事前学習の内容を忘却してしまいました。
そこで、LoRAのモデルとベースのモデルをマージし、CPT後のマージ前と後の出力を比較すると、出力が全然だめになるという現象が発生。
ほんと、困った。実はこの解決には4日ほどの昼休みとプライベートタイムを費やしました。😭
LoRAモデルをベースモデルにマージしよう
CPTをして、Huggingfaceにアップロードしたよと。ここまではできています。
LoRA
この時の出力を確認してみる。ちなみに、学習したのは2024年9月、10月のニュース記事。素のモデルは2024年春に公開されているので、石破茂さんは内閣総理大臣になっていない。
inputs = tokenizer("石破茂氏は", return_tensors="pt").to(peft_model.device)
outputs = peft_model.generate(inputs['input_ids'], max_new_tokens = 1000, use_cache = False , do_sample=False, repetition_penalty=1.2)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# 石破茂氏は「自民党総裁選に出馬するにあたり、国民の皆様にお約束したいこと」と題した動画を公開。冒頭で「私は、石破新政権が発足したら、直ちに衆議院解散総選挙を行うべきだと考えています」とし、その理由について説明した。
# まず1つ目は、岸田文雄首相(67)による政治と金問題への対応だ。石破氏は「総理自身が、<以下略>
「自民党総裁になって首相になったら解散総選挙するぜ~」っていったあの発言ですねー。
これを以下のコードでマージすると出力がおかしくなってしまう
merged_model = peft_model.merge_and_unload() # モデルマージ
inputs = tokenizer("石破茂氏は", return_tensors="pt").to(peft_model.device)
outputs = merged_model.generate(inputs['input_ids'], max_new_tokens = 1000, use_cache = False , do_sample=False, repetition_penalty=1.2)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# 石破氏は「
あらま、尻切れトンボ
仮説
マージしたらおかしくなるという原因について、いろいろと仮説をたてました。
- そもそも、違う方法があるのではないか
- 複数のLoRAモデルをマージしたほうが良いのではないか
- 引数としてなにかあるのか?
- 浮動小数点のタイプが異なるのではないか
などなど
実は四つ目、仮説を立てたというより「ボーっ」とHuggingFaceのモデルを見ていた時にあれ?って思ったんですよ。
LoRAモデルの浮動小数点の型はft32、llm-jpのモデルはbf16
「そろえないといけないのかも!」
【参考】浮動小数点について
このサイト、コンピュータの基礎を知らない人にとってわかりやすくて嬉しい
検証
そこで、マージする前に一行追加しました。
そうです。bf16に変換したんです。
peft_model = peft_model.to(torch.bfloat16) # bf16に変換
inputs = tokenizer("石破茂氏は", return_tensors="pt").to(peft_model.device)
outputs = peft_model.generate(inputs['input_ids'], max_new_tokens = 1000, use_cache = False, do_sample=False, repetition_penalty=1.2)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# 石破茂氏は「自民党総裁選に出馬するにあたり、国民の皆様にお約束したいこと」と題した動画を公開。冒頭で「私は、石破新政権が発足したら、直ちに衆議院解散総選挙を行うべきだと考えています」とし、その理由として(1)政治改革・行政改革への取り組み姿勢に対する評価を、選挙という形でお示しいただきたいから(2)政策論争をしっかり行い、その上で、信任していただけるかどうかをお決めいただくため——などを挙げた。
# また、衆院議員任期満了まで約3カ月あることから、10月15日公示→27日の投開票とする日程案も提示した。こ
ちょっと出力が変化してますが、まぁまぁ内容は合っているので良いでしょう。
そして、マージします。
merged_model = peft_model.merge_and_unload() # ベースモデルとLoRAモデルのマージ
inputs = tokenizer("石破茂氏は", return_tensors="pt").to(peft_model.device)
outputs = merged_model.generate(inputs['input_ids'], max_new_tokens = 1000, use_cache = False, do_sample=False, repetition_penalty=1.2)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# 石破茂氏は「自民党総裁選に出馬するにあたり、国民の皆様にお約束したいこと」と題した動画を公開。冒頭で「私は、石破新政権が発足したら、直ちに衆議院解散総選挙を行うべきだと考えています」とし、その理由について説明した。
# まず1つ目は、岸田文雄首相(67)による政治と金をめぐる問題で、国会審議や政策論議へ
2行目の出力は変化していますが、CPTで新しく覚えたことをきちんと出力してくれました。
結論
原因はいくつか考えられますが、浮動小数点のタイプに気を遣う必要がありそうですね。その違いに気づかずマージしてしまうと、僕と同じようなことになってしまいそうです。
実はGPUによっても扱える浮動小数点のタイプが異なります。
アーキテクチャ | FP32 | FP64 | FP16 | TF32 | BF16 | FP8 |
---|---|---|---|---|---|---|
Pascal | ✓ | ✓ | ✓ | ✗ | ✗ | ✗ |
Volta | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ |
Turing | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ |
Ampere | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
Hopper | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
Ada Lovelace | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
Blackwell | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
実は最初がGoogleColaboratryでT4(Turing)で学習していたんです。そう、BF16は使えないんですね。ベースモデルは量子化しないとメモリーに乗りませんが、この辺りも要注意ポイントになりそうです。
気を付けるべきポイント
# モデルのアップロード
merged_model.push_to_hub(
peft_model_id + "-merged",
token=HF_TOKEN,
tokenizer=tokenizer,
private=True,
)
# tokenizerのアップロード
tokenizer.push_to_hub(
peft_model_id + '-merged',
token=HF_TOKEN,
private=True,
)
マージ後、HuggingFaceにアップロードするとモデルしかアップロードされないので、tokenizerも別途アップロードするようにしましょう。
まだ残る困りごと・・・
じつは困りごとはまだ終わっていません。
LoRAモデルとベースモデルのマージはできたものの、マージしたモデルにLoRAでSFTをしていくとやはり継続事前学習で覚えたことを忘れてしまうんです。
きっと、SFT用の学習データにかき消されちゃうんでしょうね。この辺りはまだまだ試行錯誤しています。
13Bのモデルを使うと時間がかかるので3.7Bで試行中。追ってまた公開したいと思います。
終わりに
unslothに関連した記事はあまり見かけませんし、そもそもLLMをファインチューニングしよう!と思う人もまだまだ少ないのではないでしょうか?そのため、なかなか記事が見つからず、苦労しました。unsloth、trlのドキュメント、GitHubまで頑張って読みましたが、解決に至ったのはボーっと見ていた時に浮動小数点のタイプが異なるということを発見したのがきっかけという点が面白かった。
この投稿は・・・
東京大学松尾・岩澤研の大規模言語モデル2024講座を受講に際して自分でトライ&エラーを繰り返して学んだことを公開させてもらいました。(受講規約上、書けないことがあることはご理解ください。)
本当に良い講座で、あれもこれもやりたい!と思うことがいっぱい生まれています。
参考になりましたらうれしいです。
あ、LangChainシリーズも書かなきゃ!