12
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

[翻訳] Hugging faceのPEFTのクイックツアー

Posted at

Quicktourの翻訳です。

本書は抄訳であり内容の正確性を保証するものではありません。正確な内容に関しては原文を参照ください。

🤗 PEFTには、大規模な事前学習済みモデルのトレーニングに対するパラメーター効率の高いファインチューニングのメソッドが含まれています。従来のパラダイムでは、それぞれの後段のタスクごとにモデルのパラメーターすべてをファインチューニングしていましたが、現在のモデルには膨大な数のパラメーターが存在するため過度なコストを必要とし、現実的ではありません。代わりに、少数のプロンプトパラメーターをトレーニングしたり、トレーニング可能なパラメーターの数を削減するためのlow-rank adaptation (LoRA)のような再パラメーター化のような手法を使用する方が効率的です。

このクイックツアーでは、🤗 PEFTの主要な機能を説明し、コンシューマーデバイスでは通常アクセスできない大規模な事前学習済みモデルをトレーニングするお手伝いをします。分類ラベルを生成し、推論に活用できるように、LoRAを用いて1.2Bパラメーターのbigscience/mt0-largeをどのようにトレーニングするのかを説明します。

PeftConfig

🤗 PEFTメソッドのそれぞれは、PeftModelを構築する際に重要なパラメーターすべてを格納するPeftConfigクラスによって定義されます。

LoRAを使用するので、LoraConfigクラスをロード、作成する必要があります。LoraConfigでは、以下のパラメーターを指定します:

  • task_type、この場合はsequence-to-sequence language modelingです。
  • inference_modeではモデルを推論に使用するかどうかを指定します。
  • rはlow-rankマトリクスの次元数です。
  • lora_alphaはlow-rankマトリクスのスケーリングファクターです。
  • lora_dropoutはLoRAレイヤーのドロップアウト確率です。
Python
from peft import LoraConfig, TaskType

peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)

調整可能な他のパラメーターの詳細についてはLoraConfigリファレンスをご覧ください。

PeftModel

PeftModelget_peft_model()関数で作成されます。これは🤗 Transformersライブラリからロードできるベースモデルと、固有の🤗 PEFTメソッドにモデルをどのように設定するのかの指示を含むPeftConfigを受け取ります。

ファインチューンしたいベースモデルをロードすることからスタートします。

Python
from transformers import AutoModelForSeq2SeqLM

model_name_or_path = "bigscience/mt0-large"
tokenizer_name_or_path = "bigscience/mt0-large"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

PeftModelを作成するには、get_peft_model関数でベースモデルとpeft_configをラップします。モデルでトレーニング可能なパラメーター数の感覚を掴むには、print_trainable_parametersメソッドを使用します。このケースでは、モデルのパラメーターの0.19%のみをトレーニングします!🤏

Python
from peft import get_peft_model

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
"output: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282"

これですべてです🎉!これで、🤗 TransformersのTrainer、🤗 Accelerate、その他の任意のカスタムPyTorchトレーニングループを用いてモデルをトレーニングすることができます。

モデルの保存とロード

モデルのトレーニングが完了した後は、save_pretrained関数を用いてディレクトリにモデルを保存することができます。また、push_to_hub関数を用いて、Hubにモデルを保存することができます(最初にHugging Faceアカウントにログインしていることを確認してください)。

Python
model.save_pretrained("output_dir")

# if pushing to Hub
from huggingface_hub import notebook_login

notebook_login()
model.push_to_hub("my_awesome_peft_model")

これは、トレーニングされたインクリメンタルな🤗 PEFTの重みのみを保存し、格納、転送、ロードが非常に効率的であることを意味します。例えば、RAFT datasetのサブセットであるtwitter_complaintsに対してLoRAを用いてこのbigscience/T0_3Bをトレーニングすると、二つのファイルのみが含まれます:adapter_config.jsonadapter_model.binです。後者はたった19MBです!

from_pretrained関数を用いて推論のためにモデルを容易にロードします:

Python
  from transformers import AutoModelForSeq2SeqLM
+ from peft import PeftModel, PeftConfig

+ peft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
+ config = PeftConfig.from_pretrained(peft_model_id)
  model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
+ model = PeftModel.from_pretrained(model, peft_model_id)
  tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

  model = model.to(device)
  model.eval()
  inputs = tokenizer("Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label :", return_tensors="pt")

  with torch.no_grad():
      outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=10)
      print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])
  'complaint'

次のステップ

ここまでで🤗 PEFTメソッドを用いてどのようにモデルをトレーニングするのかを見たことになり、プロンプトチューニングのように他のメソッドのいくつかを試してみることをお勧めします。ステップはこのクイックスタートで説明したものと非常に似ています。🤗 PEFTメソッドのPeftConfigを準備し、設定とベースモデルからPeftModelを作成するためにget_peft_modelを使用します。そして、お好きなようにトレーニングすることができます!

また、セマンティックセグメンテーション、多言語自動スピーチ認識、DreamBooth、トークン分類のような固有のタスクに対する🤗 PEFTメソッドを用いたモデルのトレーニングに興味があるのであれば、ご自由にタスクガイドをご覧ください。

12
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?