0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

BLOOM-3BモデルとLoRAを用いたファインチューニング

Posted at

BLOOM-3BモデルとLoRAを用いたファインチューニングチュートリアル

このチュートリアルでは、BLOOM-3Bモデルと**LoRA(Low-Rank Adaptation)**を使用して、JCommonsenseQAデータセットを用いたファインチューニングの手順を詳しく説明します。初心者のエンジニアの方でも理解しやすいよう、ステップバイステップで進めていきます。


目次

  1. 前提知識
  2. 環境の準備
  3. データセットの準備
  4. モデルの準備
  5. データの前処理
  6. LoRAを用いたファインチューニング
  7. モデルの保存と評価
  8. まとめ
  9. 参考リンク

1. 前提知識

1.1 LoRAとは

**LoRA(Low-Rank Adaptation)**は、大規模言語モデルの一部のパラメータのみを更新することで、メモリ使用量と計算コストを削減しつつ効率的にファインチューニングを行う手法です。

1.2 BLOOMモデルとは

BLOOMは、多言語対応の大規模言語モデルで、日本語を含む複数の言語に対応しています。BLOOM-3Bは約30億のパラメータを持ち、RTX 3090などのGPUで扱いやすいサイズです。


2. 環境の準備

2.1 必要なソフトウェアとライブラリ

  • Python 3.8以上
  • PyTorch
  • Transformers(Hugging Face)
  • Datasets(Hugging Face)
  • Accelerate
  • PEFT(Parameter-Efficient Fine-Tuning)
  • bitsandbytes(8bit量子化のため)

2.2 仮想環境の設定

  1. 仮想環境の作成とアクティベート

    python -m venv bloom_lora_env
    
    • Linux/Macの場合:

      source bloom_lora_env/bin/activate
      
    • Windowsの場合:

      bloom_lora_env\Scripts\activate
      
  2. ライブラリのインストール

    pip install --upgrade pip
    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    pip install transformers datasets accelerate sentencepiece peft bitsandbytes
    
    • 注意: bitsandbytesは8bit量子化を行うために必要です。

3. データセットの準備

3.1 JCommonsenseQAデータセットの取得

from datasets import load_dataset

dataset = load_dataset("leemeng/jcommonsenseqa-v1.1")
  • データセットにはtrainvalidationtestのスプリットがあります。

3.2 データの確認

print(dataset['train'][0])

出力例:

{
    'q_id': 0,
    'question': '主に子ども向けのもので、イラストのついた物語が書かれているものはどれ?',
    'choice0': '世界',
    'choice1': '写真集',
    'choice2': '絵本',
    'choice3': '論文',
    'choice4': '図鑑',
    'label': 2
}

4. モデルの準備

4.1 Hugging Faceへのログイン

セキュリティ上の注意: トークンは環境変数や.envファイルを使用して安全に管理してください。

import os
from huggingface_hub import login

HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
login(token=HUGGINGFACE_TOKEN)
  • トークンの取得方法: Hugging Faceのアカウントページからアクセストークンを取得します。

4.2 モデルとトークナイザーのロード

from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "bigscience/bloom-3b"
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,        # 8bit量子化でメモリを節約
    device_map='auto'         # デバイスを自動設定
)
  • 注意: load_in_8bit=Trueを使用するため、bitsandbytesライブラリが必要です。

5. データの前処理

5.1 データの整形

データをモデルが理解できる形式に整形します。

def format_example(example):
    question = example['question']
    choices = [example[f'choice{i}'] for i in range(5)]  # choice0からchoice4まで取得
    label_index = int(example['label'])  # 正解のインデックス(0から4)

    # 選択肢を整形
    options = '\n'.join([f"{chr(65 + i)}: {choice}" for i, choice in enumerate(choices)])
    prompt = f"質問: {question}\n{options}\n答えの選択肢を一つ選んでください。"

    # 正解のテキストを取得
    answer_text = choices[label_index]

    return {'prompt': prompt, 'label': answer_text}

formatted_dataset = dataset.map(format_example)

5.2 トークナイズ

def tokenize_function(example):
    # 質問と回答を結合
    full_text = example['prompt'] + "\n答え: " + example['label']
    
    # 全体をトークナイズ
    tokenized = tokenizer(
        full_text,
        max_length=512,
        truncation=True,
        padding='max_length',
    )
    
    # ラベルの作成
    labels = tokenized['input_ids'].copy()
    
    # プロンプト部分をマスク
    prompt_length = len(tokenizer(example['prompt'])['input_ids'])
    labels[:prompt_length] = [-100] * prompt_length  # -100は損失計算から除外されるトークンを示す

    tokenized['labels'] = labels
    return tokenized

tokenized_dataset = formatted_dataset.map(tokenize_function, batched=False)
  • 説明:
    • -100を使用して、モデルが回答部分のみを学習するように設定します。

6. LoRAを用いたファインチューニング

6.1 LoRAの設定

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["query_key_value"],  # BLOOMモデルの適切なモジュール
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)
  • target_modulesの確認: モデル内のモジュール名を確認することで、適切なターゲットを指定できます。

    for name, module in model.named_modules():
        if 'query_key_value' in name:
            print(name)
    

6.2 モデルにLoRAを適用

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
  • 出力例:

    trainable params: 2,457,600 || all params: 3,005,015,040 || trainable%: 0.0818
    

6.3 トレーニングパラメータの設定

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./bloom-lora-finetuned",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=3,
    learning_rate=1e-4,
    fp16=True,
    save_steps=500,
    logging_steps=100,
    report_to="none"
)

6.4 データコラトの定義

import torch

def data_collator(features):
    return {
        'input_ids': torch.tensor([f['input_ids'] for f in features], dtype=torch.long),
        'attention_mask': torch.tensor([f['attention_mask'] for f in features], dtype=torch.long),
        'labels': torch.tensor([f['labels'] for f in features], dtype=torch.long),
    }

6.5 トレーナーの作成

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    data_collator=data_collator,
)

6.6 トレーニングの開始

trainer.train()

7. モデルの保存と評価

7.1 モデルの保存

model.save_pretrained("./bloom-lora-finetuned")
tokenizer.save_pretrained("./bloom-lora-finetuned")

7.2 モデルの評価

# テストデータで評価
test_results = trainer.evaluate(eval_dataset=tokenized_dataset['validation'])
print(test_results)
  • 評価指標: lossperplexityなど。

8. まとめ

このチュートリアルでは、以下の手順を踏んでBLOOM-3BモデルをLoRAを用いてファインチューニングしました。

  1. 環境の準備: 必要なライブラリと仮想環境を設定。
  2. データセットの準備: JCommonsenseQAデータセットを取得し、データを確認。
  3. モデルの準備: BLOOM-3Bモデルとトークナイザーをロード。
  4. データの前処理: データをモデルが理解できる形式に整形し、トークナイズ。
  5. LoRAを用いたファインチューニング: LoRAの設定を行い、モデルに適用。トレーナーを作成してトレーニングを実行。
  6. モデルの保存と評価: ファインチューニング済みモデルを保存し、評価データセットで性能を確認。

9. 参考リンク


補足情報

  • APIトークンの管理: セキュリティのため、APIトークンはコード内に直接記載せず、環境変数や.envファイルを使用して安全に管理してください。
  • メモリエラーの対処: メモリエラーが発生した場合、per_device_train_batch_sizeを減らすか、gradient_accumulation_stepsを増やすことで対処できます。
  • モデルの評価: 必要に応じて、カスタムの評価関数を実装してモデルの性能を詳細に分析してください。
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?