8
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

MetaのLlama 2をDatabricksでQLoRAを使ってファインチューニングしてみる

Posted at

こちらの続きです。04_langchainは、サービングしているモデルを呼び出しているだけなのでスキップします。あと、05_fine_tune_deepspeedは途中のエラーで動かず。

QLoRAとは

私は知らなかったです。そもそもLoRAもまだよく分かってないです。こちらのようですね。

完全な16bitのファインチューニングタスクのパフォーマンスを維持しつつも、単一の48GB GPUで65Bのパラメーターモデルをファインチューニングするのに十分なメモリーの使用量を削減する効率的なファインチューニングアプローチであるQLoRAを発表します。QLoRAは凍結された4-bitの量子化された学習済み言語モデルを通じて勾配をLow Rank Adapters(LoRA)に逆伝播します。

ライブラリのインストール

%pip install -U git+https://github.com/huggingface/transformers.git  git+https://github.com/huggingface/accelerate.git git+https://github.com/huggingface/peft.git
%pip install datasets==2.12.0 bitsandbytes==0.40.1 einops==0.6.1 trl==0.4.7
%pip install torch==2.0.1

データセットのロード

日本語版のデータセットtaka-yayoi/databricks-dolly-15k-jaを使います。

from datasets import load_dataset

dataset_name = "taka-yayoi/databricks-dolly-15k-ja"
dataset = load_dataset(dataset_name, split="train")

プロンプトテンプレート

データセットにプロンプトテンプレートを適用します。これも後で日本語にした方が良さそう。

INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
INSTRUCTION_KEY = "### Instruction:"
INPUT_KEY = "Input:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"

PROMPT_NO_INPUT_FORMAT = """{intro}

{instruction_key}
{instruction}

{response_key}
{response}

{end_key}""".format(
  intro=INTRO_BLURB,
  instruction_key=INSTRUCTION_KEY,
  instruction="{instruction}",
  response_key=RESPONSE_KEY,
  response="{response}",
  end_key=END_KEY
)

PROMPT_WITH_INPUT_FORMAT = """{intro}

{instruction_key}
{instruction}

{input_key}
{input}

{response_key}
{response}

{end_key}""".format(
  intro=INTRO_BLURB,
  instruction_key=INSTRUCTION_KEY,
  instruction="{instruction}",
  input_key=INPUT_KEY,
  input="{input}",
  response_key=RESPONSE_KEY,
  response="{response}",
  end_key=END_KEY
)

def apply_prompt_template(examples):
  instruction = examples["instruction"]
  response = examples["response"]
  context = examples.get("context")

  if context:
    full_prompt = PROMPT_WITH_INPUT_FORMAT.format(instruction=instruction, response=response, input=context)
  else:
    full_prompt = PROMPT_NO_INPUT_FORMAT.format(instruction=instruction, response=response)
  return { "text": full_prompt }

dataset = dataset.map(apply_prompt_template)

結果を確認します。

dataset["text"][0]
Out[6]: 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nヴァージン・オーストラリア航空はいつから運航を開始したのですか?\n\nInput:\nヴァージン・オーストラリア航空(Virgin Australia Airlines Pty Ltd)の商号で、オーストラリアを拠点とする航空会社です。ヴァージン・グループを使用する航空会社の中で、保有機材数では最大の航空会社である。2000年8月31日にヴァージン・ブルー航空として、2機の航空機で単一路線で運航を開始した[3]。2001年9月のアンセット・オーストラリア航空の破綻後、突然オーストラリア国内市場の大手航空会社としての地位を確立した。その後、ブリスベン、メルボルン、シドニーをハブとして、オーストラリア国内の32都市に直接乗り入れるまでに成長した[4]。\n\n### Response:\nヴァージン・オーストラリア航空は、2000年8月31日にヴァージン・ブルー航空として、2機の航空機で単一路線の運航を開始しました。\n\n### End'

モデルのロード

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer

model = "elinas/llama-7b-hf-transformers-4.29"
revision = "d33594ee64ef1b6264543b6a88f60982a55fdb7a"

tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model,
    quantization_config=bnb_config,
    device_map="auto",
    revision=revision,
    trust_remote_code=True
)
model.config.use_cache = False

LoRAモデルを作成するための設定ファイルのロード。

以下はノートブックの説明。まだ良くわからない…。

QLoRAの論文によると、最大のパフォーマンスのためにはトランスフォーマーブロックの全ての線形レイヤーを検討することが重要です。このため、混成のクエリーキーバリューレイヤーに加えて、ターゲットモジュールにdensedense_h_to_4_hdense_4h_to_hレイヤーを追加します。

import bitsandbytes as bnb

def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit # if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])


    if 'lm_head' in lora_module_names: # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
find_all_linear_names(model)
Out[9]: ['q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'k_proj', 'v_proj']
from peft import LoraConfig

lora_alpha = 16
lora_dropout = 0.1
lora_r = 64

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'k_proj', 'v_proj']
)

トレーナーのロード

from transformers import TrainingArguments

output_dir = "/dbfs/tmp/takaaki.yayoi@databricks.com/qlora/results"
per_device_train_batch_size = 16
gradient_accumulation_steps = 4
optim = "paged_adamw_32bit"
save_steps = 10
logging_steps = 10
learning_rate = 2e-4
max_grad_norm = 0.3
max_steps = 500
warmup_ratio = 0.03
lr_scheduler_type = "constant"

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    fp16=True,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=True,
    lr_scheduler_type=lr_scheduler_type,
    ddp_find_unused_parameters=False,
)

これまで準備したデータセットや設定をトレーナーに引き渡します。

from trl import SFTTrainer

max_seq_length = 512

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
)

トレーニングを安定させるために、レイヤーのノルムをfloat 32にアップキャストすることでモデルの前処理を行います。

for name, module in trainer.model.named_modules():
    if "norm" in name:
        module = module.to(torch.float32)

モデルのトレーニング

trainer.train()を呼び出してトレーニングを起動します。

trainer.train()

モデルがMLflowにトラッキングされます。
Screenshot 2023-07-30 at 10.59.15.png
Screenshot 2023-07-30 at 10.59.59.png

4時間程度でファインチューニングは完了しました。使ったインスタンスはg5.8xlarge(メモリー128GB、1GPU)です。
Screenshot 2023-07-30 at 14.33.51.png

ただ、保存しようとしたらエラーに。

model.save_pretrained("/local_disk0/llamav2-7b-lora-fine-tune")
NotImplementedError: You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported

GitHubとか見ると、未サポートらしいです。うーん、今日は一旦ここまでで。残念。

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

8
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?