※本記事は2023.9に執筆された内容です。
概要
OpenAI APIに頼らず、ローカルかつ商用利用可能なLLMにFineTuningを行うことで要約タスクを可能にしたい。
elyza/ELYZA-japanese-Llama-2-7b-instruct
にファインチューニングを行うことで、日本語文章の3行要約ができるようになった。
データセットの収集
このリポジトリを参考に、ライブドアのニュースサイトをスクレイピングしてニュース記事と対応する3行要約を合計3907件集めた。
Few-shot性能の確認
このブログでも紹介されているように、Fine-tuningが有効かどうかはFew-shot promptで性能を確認するとある程度推測できる。
以下はmeta-llama/Llama-2-7b-chat-hf
での結果だが、かなりいい出力なので、Fine-tuningにも期待できそう。
from transformers import AutoTokenizer, AutoConfig
import transformers
import torch
model_name="meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
config.pretraining_tp = 1
model=transformers.AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto", # Passing device_map = 0 means putting the whole model on GPU 0. Other inputs could be cpu, cuda:1, etc. Setting device_map = auto will let accelerate compute the most optimized device_map automatically.
torch_dtype=torch.float16)
if torch.cuda.is_available():
model = model.to("cuda")
#推論
model.eval()
text = "宮内庁は25日、秋篠宮妃紀子さまが新型コロナウイルスに感染されたと発表した。29日まで宮邸で療養する予定。29、30日に鹿児島県で全国高校総合文化祭( 総文祭)の式典などに出席する予定だったが、取りやめる。24 日夕に発熱の症状があったことから検査をした結果、陽性が確認された。秋像宮さまや次女佳子さま、長男悠仁さまは陰性だった。県は秋篠宮さまと悠仁さまの来鹿について、陰性の状態が続けば、変更はないとしている。"
few_shot_prompt = f"""
### Input:
名店がひしめく「丸の内・日本橋」エリアで<うまい蕎麦ランチ>が食べられるお店を厳選してご紹介。ぴあMOOK『うまい蕎麦の店 首都圏版』が選んだ、とっておきの7店がこちら!ミニかき揚げ丼セット(温冷) 660円関東風の濃いめのつゆと、ふのりを練り込んだ喉ごしの良い自家製のそばが相性抜群。新鮮な油を使ったかき揚げは、サクサク。千代田区丸の内1-9-1 東京駅一番街B1F蕎麦も天ぷらも正統派の味お蕎麦と天丼(二八) 1600円喉ごしや歯応えを楽しむ二八蕎麦のほか、蕎麦の風味を存分に味わえる十割蕎麦と天丼のセットも。天ぷらは鮮度抜群で美味。千代田区丸の内1-6-4 丸の内オアゾ5F信州の郷土料理をアレンジ“信州フランス鴨”セリかも南蛮(温) 1680円和食とフレンチが融合したスタイルの人気店。このメニューは、本格信州そば、信州のフランス鴨、契約農家が育てた野菜など、素材にこだわっている。千代田区丸の内2-7-2 JPタワーKITTE5F納豆と卵白でふんわり食感なっとうそば 1100円取り寄せた蕎麦の実を、職人が毎日石臼で挽き、丹念に手打ちする。たっぷりの納豆と卵白がのって、なめらかな口当たり。千代田区丸の内1-5-1 新丸の内ビルディング5Fやわらかな厚切り鴨肉を堪能鴨せいろ 1945円、さつま揚げ 650円合鴨・ネギ・しめじなどを合わせた濃厚なつけ汁に、程良いコシの蕎麦がよく合う。自家製さつま揚げもぜひ味わいたい一品。千代田区丸の内2-4-1 丸の内ビルディング6Fコシの強い独特な田舎蕎麦が美味とり辛そば(温) 980円信州軽井沢の味噌・醤油屋「酢重正之商店」が手掛ける蕎麦屋の人気メニュー。辛口のつけ汁と太い田舎蕎麦がよく合うと評判。千代田区丸の内1-5-1 新丸の内ビルディングB1F風味豊かな生わさびが決め手ざるそば(生わさび) 710円信州の民家を思わせる店内で打つ自家製麺は歯応え抜群。生わさびを自分でおろして味わう。甘めのつゆがわさびにぴったり。千代田区丸の内1-6-1 丸の内センタービル B1Fうまい蕎麦の店 首都圏版日本人ならいつでも蕎麦が食べたい!
### Summary in three lines:
東京駅周辺の安くて美味しい「蕎麦ランチ」の名店を紹介している
「越後そば 東京店」では、ミニかき揚げ丼セットがおすすめと筆者
その他には、「手打ちそば 石月」「酢重正之 楽」「鎌倉 一茶庵 丸山」など
<end>
### Input:
若い女性からアラフォー世代まで、幅広い世代の女性に大人気のまつ毛エクステですが、 まつ毛エクステのトラブル には要注意です。まつ毛エクステの人気の理由は、「顔の雰囲気が一気に華やかになるから」というのが圧倒的。あれこれ塗らなくてもあっという間に「よそゆき顔」が出来てしまうのです。一方で、「トラブルが多い」というウワサがあるのも事実。ここでは、まつ毛エクステのトラブルについて簡単に解説します。きちんとしたサロンなら、これも説明があるはずです。ちなみに、コンタクトを使用している人は保存液とケース、そして眼鏡は忘れずに持っていきましょう。コンタクトをしたままの施術は厳禁です。いかがでしたでしょうか?
### Summary in three lines:
まつ毛エクステのトラブルについて解説している
接着剤が角膜を傷つけたり、アレルギー症状が起きたりする恐れがある
最悪の場合は視力低下や失明の可能性もあるという
<end>
### Input:
2013年に「一帯一路」は、中国が世界経済の中心的地位を占めていた次代の古代シルクロードの再現を意識したものとされ、陸上と海上の双方において中国と中央アジア、欧州までを結ぶ構想だ。中国メディアの中国網はこのほど、一帯一路構想の実現に向け、中国は日本とどのように対峙すべきかを論じる記事を掲載。中国社科院世界経済政治研究所の研究員の分析として、日本は一帯一路構想の実現における競合として中国の前に立ちはだかると主張した。日本が中国の競合となると主張した1つ目の理由は、「シルクロード文化に最も興味を示しているのは日本である」ことだという。日本にはシルクロードを題材にした小説やドキュメンタリーが多く、シルクロードに対する熱意は中国をも凌ぐゆえだ。確かに日本ではシルクロードを題材とした小説などは多いが、これは納得できない理由だ。記事が挙げた2つ目の理由は「冷戦後、もっとも早くシルクロードに商機を見出したのが日本」であることだという。日本は1997年に当時の橋本総理が「対シルクロード地域外交」を打ち出したが、これはどの国よりも早くシルクロードの重要性に注目した結果であると指摘した。3つ目の点は「中国に対して、もっとも競争力を有しているのが日本である」ことで、日本が中国主導のアジアインフラ投資銀行(AIIB)に対抗して、1100億ドルのインフラ投資をアジアで行う方針を打ち出したことを指摘。また、日本は一帯一路構想に対する「破壊力」も有しているうえ、日本は経済面、政治外交面、軍事面で「もっとも中国に対して懐疑的」であることから、中国の一帯一路構想について、日本が何らかの形で対抗策を打ち出してくるのではないかと警戒感を示した。
### Summary in three lines:
13年に中国が新しい経済構想として提唱する新シルクロード構想「一帯一路」
中国の専門家は、構想の実現に日本が競合相手として立ちはだかると指摘した
日本が何らかの形で対抗策を打ち出してくるのではと、警戒感を示している
<end>
### Input:
{text}
### Summary in three lines:
"""
inputs = tokenizer.encode(prompt, return_tensors="pt")
with torch.no_grad():
outputs = model.generate(
inputs.to(model.device),
temperature=0.1,
max_new_tokens=256)
print(tokenizer.decode(outputs.tolist()[0][inputs.size(1) :], skip_special_tokens=True))
>>
宮内庁は25日、秋篠宮妃紀子さまが新型コロナウイルスに感染したと発表した。
29日まで宮邸で療養する予定で、29、30日に鹿児島県で全国高校総合文化祭(総文祭)の式典などに出席する予定が取りやめられた。
県は秋篠宮さまと悠仁さまの来鹿について、陰性の状態が続けば、変更はないとしている。
<end>
データセットの成形(プロンプト作成)
ChatベースのモデルをFinetuningする場合、事前学習で使用されていたChat用のプロンプトに倣ってデータを成形すべきである。
meta-llama/Llama-2-7b-hf
など、chatを想定した事前学習がなされていないベースモデルの場合は気にする必要はない。
llama2で使われるフォーマットの参考:https://huggingface.co/datasets/mlabonne/guanaco-llama2-1k
"<s>[INST] <<SYS>>
あなたは誠実で優秀な日本人のアシスタントです。
<</SYS>>
クマが海辺に行ってアザラシと友達になり、最終的には家に帰るというプロットの短編小説を書いてください。 [/INST] "
今回はELYZA-japanese-Llama-2-7b-instruct
をファインチューニングするため、上の形式に倣い、ニュース文と要約を埋め込んでデータセットを成形した。
3行要約の開始と終了を示す特殊トークンを追加する ことで、精度が向上した。特殊トークンは、はじめ[INST]
[/INST]
に倣い[RESPONSE]
[/RESPONSE]
としていたが、Fine-tuning後に開始トークンである[RESPONSE]と終了トークンである[/RESPONSE]の区別ができていないような挙動を示した(具体的には、[/RESPONSE]の直後に 「3行要約」 と続き、出力がループする)ことから、開始トークンと終了トークンが文字的に類似することを避けるため[R_START]
[R_END]
に変更した。
※モデル内部ではトークンは文字列ではなくtoken_idに変換されるので、ここを気にする必要性は不明だが、特殊トークンの変更はpad_tokenの変更と同時に入れたためablation studyはできていない
"<s>[INST] <<SYS>>
あなたは誠実で優秀な日本人のアシスタントです。
<</SYS>>
以下の入力文を3行で要約しなさい。
入力文:
焼きたてチーズタルト専門店パブロ([チョコレートを練りこんだチーズタルト生地に、ブロックチョコレートを挟み込み、表面にマシュマロをたっぷり敷き詰めて焼き上げた。口の中に広がる焼マシュマロの香ばしい甘さと、チョコの濃厚な味わいが楽しめる。同期間、パブロプレミアムカフェでは「焼きたてミニチーズタルト チョコレート×焼マシュマロ」も登場。バレンタインとホワイトデー期間だけの、ふんわりまろやかな美味しさのチーズタルトを味わって。【商品詳細】販売期間: [/INST] [R_START] 3行要約:
チーズタルト専門店パブロが「焼マシュマロチョコチーズタルト」を発売する
発売期間は2月10〜14日、3月11〜14日で、価格は1350円
同期間中に系列のカフェでは、950円でミニサイズも楽しめるとのこと [R_END] </s> "
Fine-Tuning
Data split
サンプルのように成形したデータセットを、カラム名"text"のcsvファイルとして保存し、DatasetDictとして読み込んだ。(Datasetクラスをオーバーライドして色々工夫する例はこちら)
Splitの比率は train : test : valid = 3516 : 195 : 196
後になってわかったことだが、LLMのFine-tuningにはこんなにたくさんのtrainデータは不要だったっぽい。
LIMAという論文では、「質の良いデータを1000件ほど用意すればよく、それ以上増やしても精度が向上することはなかった」と述べられている(データの質の担保方法には、バリエーション確保にはClusteringなど)。
実際に今回の実験でも、3000件以上のtrainデータを全て使い終わる1 epoch終了前にeval lossのプラトーが始まり、それ以降も学習を続けるとtraining lossが下がりすぎて精度が落ちた(Hallucinationが増加した)。
LoRAチューニングでは、training lossが1を下回る頃から事前学習での記憶の忘却が始まるとされている*。
今回の実験で得られたベストなチェックポイントはcheckpoint-40で、0.36 epochまでしか学習しておらず、これはtrainingデータのうち1265件のみを1度ずつ見ただけにすぎない。
Framework
HuggingfaceのTrainerクラスはPEFTライブラリに対応していないため、SFTTrainer
クラスを使用する。使い方は普通のTrainerクラスとほぼ同じ。
HuggingfaceのPEFTライブラリを利用して、QLoRAでTuningした。
特殊トークンの使い方を工夫する
meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-7b-hf
elyza/ELYZA-japanese-Llama-2-7b-instruct
など、Llama 2系の様々なモデルでFine-tuningを試したが、パラメータサイズに関わらず生成の停止位置を学習させるのに最も手こずった。
3行要約部分を自作の特殊トークンで挟むことを試したが、それだけでは終了位置(EOSトークンの生成)はなかなか学習できなかった。
結論としては、学習時のパティング用トークンにEOSトークンを使わないことが重要だった。パディングトークンでは損失が計算されないため、EOSトークンでパディングするとモデルがEOSトークンの生成を学習できないと考えられる。(ちなみに生成時はパディングトークンをEOSにしてもUNKにしても出力内容に変化はなかった)
今回ベストなチェックポイントが得られた学習では、パディング用トークンにUNKトークンを使用している。
tokenizer.pad_token = tokenizer.unk_token
生成時にEOSトークンがなかなか生成されない問題は、事前学習モデルでも結構話題にされているようだ:https://github.com/huggingface/transformers/issues/24994
特殊トークンの追加方法
学習・推論時にtokenizerの設定を以下のようにする。
# Define special tokens
special_tokens = ["[R_START]", "[R_END]"]
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_tokens(special_tokens, special_tokens=True)
# This will make new learnable parameters for specialized tokens
model.resize_token_embeddings(len(tokenizer))
QLoRA用の設定
QLoRAで学習を行うため、モデルとトークナイザの定義時に以下の設定が必要となる。
QLoRAは性能を犠牲にしないタイプの4ビットデータ型(nf4; normal float 4)を用いてLoRAチューニングを行うSOTAのファインチューニング手法らしい。
about BitsAndBytesConfig params:
- Loading in 4 bits is enabled via load_in_4bit.
- The datatype utilized by bnb_4bit_compute_dtype for linear layer calculations.
- Nested quantization is enabled via bnb_4bit_use_double_quant.
- bnb_4bit_quant_type specifies the datatype used for quantization. There are two quantization datatypes supported: fp4 (four-bit float) and nf4 (regular four-bit float). We advocate using nf4 since it is theoretically optimum for normally distributed weights.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer
model_name = "elyza/ELYZA-japanese-Llama-2-7b-instruct"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
trust_remote_code=True
)
model.config.use_cache = False # デフォルトはTrueなのでFalseを指定する。デコード時に過去に見たkeyとvalueの値を使うかどうかの設定らしい。詳しくはわからないが確かに学習時はFalseにしておいた方が良さそうである。
tokenizer = AutoTokenizer.from_pretrained(model_name)
# This is to avoid overflow in fp16 training
tokenizer.padding_side = "right"
LoRA用の設定
from peft import LoraConfig
lora_alpha = 16 # LoRAのスケーリング係数で、新しい学習データに対してモデルを適応させる度合いを決定する。学習過程における更新行列の寄与を調整するので、小さく設定するほど、事前学習でのデータをより重視し、モデルの既存の知識をより維持することになる。対話形式 (Alpacaなど) を学習する場合は、Rankを低く (32以下)、モデルに質疑応答する想定で文書を理解させたい場合はよりRankを高く設定すると良い。
lora_dropout = 0.1 # LoRAレイヤーのドロップアウト確率。更新行列の一部を無効化することで、LoRAの過学習や不安定さを抑える。QLoRAの論文では13Bまでのモデルでは0.1を使用していた
lora_r = 64 # 更新行列のランク。小さいほど更新行列は小さくなり、学習可能なパラメータは少なくなる。データセットが複雑であればあるほどより大きいランクが必要になる。ちなみにLoRAの原著論文では、最小でも8を推奨している。
peft_config = LoraConfig(
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
r=lora_r,
bias="none",
task_type="CAUSAL_LM",
target_modules=["q_proj", "v_proj"], # QとVのレイヤー名はモデルによって異なるため調べて指定する
modules_to_save = ["lm_head", "embed_tokens"] # when additional tokens
)
自作の特殊トークンを追加する場合、
modules_to_save = ["lm_head", "embed_tokens"]
を指定する必要がある ことに注意。
Training Arguments
training_arguments = TrainingArguments(
do_eval = True,
output_dir = output_dir,
per_device_train_batch_size = 8,
gradient_accumulation_steps = 4, # train_batch_size × accumulation_steps が勾配更新に使われる実質のバッチサイズといえる
optim = "paged_adamw_32bit", # 現状SOTAの方法? todo
save_steps = 5, # 何ステップごとのcheckpointを保存対象にするか
logging_steps = 1, # 何ステップごとにログで監視するか
evaluation_strategy = "steps", # "epoch" or "steps" or "no":トレーニング中にバリデーションしない
eval_steps = 1, # evaluation_strategyが"steps"の場合、何ステップごとにvalidationするかを決める
learning_rate = 5e-4,
fp16 = True,
bf16 = False,
max_grad_norm = 0.3, # This parameter sets the maximum gradient norm for gradient clipping.
max_steps = 600,
warmup_ratio = 0,
group_by_length = True, # データセット全体で見て最大長のシーケンスに合わせてパディングするよりも、バッチごとに行う方が効率が良い(Dynamic Padding)。Dynamic Paddingにはやり方が複数あるが、ほぼ同じ長さのサンプルをグループ化する場合はこれをTrueにすればよい。ただしこの方法を使うとtrain lossが滑らかにならないことがあるらしい
lr_scheduler_type = "cosine",
report_to = "wandb", # ログ可視化ツール
metric_for_best_model = "eval_loss", #ベストモデルを決定するためのメトリクス。ユーザー定義も指定できるが、eval_lossが一番簡単
load_best_model_at_end = True, # ベストcheckpointを常に保存する設定。
save_total_limit = 3, # load_best_model_at_end = Trueの場合はベストのチェックポイント一つと、直近のチェックポイントを保存(Top nのチェックポイントが保存されるわけではない)。これを指定しておかないと、save_stepsごとに作成されるCheckpointが全部保存されてしまう。
run_name = run_name, # ログ可視化ツール(wandbまたはmlflow)に表示する際のrun名
)
高いLearningRate × 小さいEpoch = 高速だが学習の質が落ちやすい
低いLearningRate × 大きいEpoch = 時間はかかるが学習の質は落ちづらい
load_best_model_at_end について
save_total_limit について
Trainer
trainer = SFTTrainer(
model = model,
tokenizer = tokenizer,
train_dataset = dataset["train"],
eval_dataset = dataset["valid"],
dataset_text_field = "text",
peft_config = peft_config,
max_seq_length = None, # Noneにするとtokenizerの値を自動で使用
packing = False, # todo group_by_lengthとの違いは??
args = training_arguments,
)
学習の開始と中断・再開
学習を開始
trainer.train()
trainer.model.save_pretrained(desired_output_dir)
save_pretrained
について: save strategyを指定して、都度ベストなcheckpointを保存している場合はわざわざ実行しなくてもよい。push_to_hubとかのパラメータを指定できるので、そういう用途には便利なのだろう。
save_strategy = “no”としておいて学習終了後に実行すると余計なファイルを最低限しか残さなくて済むらしい(でもsave_total_limit = 1にした場合も同じでは?)
中断後、output_dirに存在するcheckpointから学習を再開する場合
trainer.train(resume_from_checkpoint=True)
trainer.model.save_pretrained(output_dir)
ちなみに、WandBの方でも再開の設定が必要だったらしい。。。普通に同じWANDB_PROJECTとrun nameを指定するだけで再開していたが、続きが別色のラインでlogされていたので問題ないかと思われる
参考: https://docs.wandb.ai/ja/guides/runs/resuming#%E5%86%8D%E9%96%8B%E3%81%AE%E3%82%AC%E3%82%A4%E3%83%80%E3%83%B3%E3%82%B9
WandBで学習を監視
trainを開始する前に、wandbにログインしておく。以下のコードはjupyter notebook用
import wandb
wandb.login()
%env WANDB_LOG_MODEL = 'end'
%env WANDB_PROJECT = three_line_summarization_7b_elyza_additional_tokens_2_pad_unk # wandb上でのプロジェクト名。一つのプロジェクトページの中に複数のrunを記録することも可能
%env WANDB_SILENT = True
Trainerでload_best_model_at_end=True
となっている場合、W&B は最もパフォーマンスの良いモデルを Artifacts に保存する。WANDB_LOG_MODEL
を"end"とすると最終的なベストモデルを保存するが、"checkpoint"とした場合はsave_stepsごとに保存される。
ref: https://docs.wandb.ai/guides/integrations/huggingface#turn-on-model-versioning
※W&Bの容量は100GB
WANDB_SILENT = True とするとwandbのログステートメントが無効になり、すべてのログがWANDB_DIR/debug.logに書き込まれる。
学習終了後や中断時は、以下の操作でwandbを終了させる
wandb.finish()
学習曲線
Fine-tuning済みのモデルで要約生成
以下では再度ベースモデルをロードして、PeftModel.from_pretrained
メソッドでLoRAアダプターをマージしているが、merge_and_unload
というメソッドを使う方法もあるらしい(未検証)
merged = trained_model.merge_and_unload()
merged.save_pretrained("merged",safe_serialization=True)
tokenizer.save_pretrained("merged")
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer
from peft import PeftModel
# trainの時と同じ設定
model_name = "elyza/ELYZA-japanese-Llama-2-7b-instruct"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
trust_remote_code=True
)
model.config.use_cache = False
# Add special tokens
special_tokens = ["[R_START]", "[R_END]"]
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_tokens(special_tokens, special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = "right"
checkpoint = 'output_dir/your_best_checkpoint'
trainedmodel = PeftModel.from_pretrained(
model,
checkpoint,
torch_dtype=torch.float16,
device_map={'':0}
)
if torch.cuda.is_available():
trainedmodel = trainedmodel.to("cuda")
# trainの時と同じプロンプトフォーマットにすること
def gen(text):
prompt = f"""<s>[INST] <<SYS>>
あなたは誠実で優秀な日本人のアシスタントです。
<</SYS>>
以下の入力文を3行で要約しなさい。
入力文:
{text} [/INST] [R_START]
"""
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
token_ids.to(model.device)
with torch.no_grad():
output_ids=trainedmodel.generate(
inputs=token_ids,
do_sample=False, # Greedy decodingしたい場合はFalseにする。
pad_token_id=tokenizer.unk_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=256
)
output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1) :], skip_special_tokens=True)
print(output)
text = "宮内庁は25日、秋篠宮妃紀子さまが新型コロナウイルスに感染されたと発表した。29日まで宮邸で療養する予定。29、30日に鹿児島県で全国高校総合文化祭( 総文祭)の式典などに出席する予定だったが、取りやめる。24 日夕に発熱の症状があったことから検査をした結果、陽性が確認された。秋像宮さまや次女佳子さま、長男悠仁さまは陰性だった。県は秋篠宮さまと悠仁さまの来鹿について、陰性の状態が続けば、変更はないとしている。"
gen(text)
>>
3行要約:
秋篠宮妃紀子さまが新型コロナウイルスに感染したと宮内庁が発表した
29日まで宮邸で療養する予定で、29、30日に鹿児島県で式典などに出席する予定だったが取りやめる
秋篠宮さまや次女佳子さま、長男悠仁さまは陰性だった
生成について
generationの設定は、model.generation_config
で確認できる(デフォルトと異なる設定のみ表示される)。
model.generate()の引数で生成方式をコントロールできる。
各生成方式の詳細
→GreedySeach, BeamSearch, Sampling
→ContrastiveSearch
→DiverseBeamSearch
→Assisted Decoding
- GreedySearch(デフォルト)
- do_sample=False
- num_beams=1
- BeamSearch
- num_beams > 1
- ContrastiveSearch
- penalty_alphaとtop_kを指定
- QLoRAモデルだと`RuntimeError: "baddbmm_with_gemm" not implemented for 'Half'`と言われてしまう。どうやらfloat16に対応してないようなので`model.float()`でfloat32に変換してみると動いた。モデルの型はmodel.dtypeで確認。
- MultinomialSampling
- do_sample=True
- num_beams=1
- Beam-search multinomial sampling
- do_sample=True
- num_beams > 1
- Diverse beam search decoding
- num_beamsとnum_beam_groupsを指定
- Assisted Decoding
- assistant_modelを指定
- 同じトークナイザ (理想的にはより小さいモデル) を持つアシスタントモデルを使って複数の候補トークンをGreedyに生成する。詳細不明...
- Top-k sampling
- do_sample=True
- top_k > 1
- Top-p sampling
- do_sample=True
- 0 < top_p < 1
- top_k = 0
- top_pで指定した累積確率となるtopをランダムサンプリングの対象とする。top-kより有用に見えるが、実際にはどちらでもうまく機能する。top-kと併用すれば、非常に低いランクの単語を避けながら、ある程度ダイナミックな選択を可能にする(組み合わせた場合どんなアルゴリズムになるか詳細不明)。
- num_return_sequences
- 複数の候補を出しうる生成方法の場合、指定するとその数だけ生成候補を出力
もし特定のgeneration_configをモデルと一緒に保存したい場合、GenerationConfig.save_pretrained()
で保存できる(詳細)。
生成完了後にまとめて表示するのではなく、ChatGPTのように少しずつ生成結果を表示(ストリーミング)したい場合は、TextStreamer
クラスが使える。このクラスのAPIはまだ実装されていないが、以下のようにstreamerを定義して、.generate()
の引数に追加するだけでOK.
from transformers import TextStreamer
def stream_gen(text):
prompt = f"""<s>[INST] <<SYS>>
あなたは誠実で優秀な日本人のアシスタントです。
<</SYS>>
以下の入力文を3行で要約しなさい。
入力文:
{text} [/INST] [R_START]
"""
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
token_ids.to(model.device)
streamer = TextStreamer(tokenizer)
_ = trainedmodel.generate(
inputs=token_ids ,
do_sample=False,
streamer=streamer,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=256
)
その他ref
https://note.com/bakushu/n/ne7760c47e39e
https://github.com/huggingface/peft/issues/334
https://colab.research.google.com/drive/16qKy92cGoNPWrlQ4zlvntVGeSgjrknVF?usp=sharing#scrollTo=tpfeUu0NKQRs
https://medium.com/@geronimo7/from-transcripts-to-ai-chat-an-experiment-with-the-lex-fridman-podcast-3248d216ec16
https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md
https://note.com/npaka/n/n82abc49a4c96
https://note.com/npaka/n/na506c63b8cc9