3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

生成AIAdvent Calendar 2024

Day 20

Meta-LLaMA 3.1 8BモデルをUnslothでText-to-SQLに適応

Last updated at Posted at 2024-12-25

Meta LLaMA 3.1 8Bモデルとは?

Metaが開発したLLaMA 3.1 8Bは、80億パラメータを持つ高性能な言語モデルです。このモデルは、計算効率と性能のバランスが取れており、自然言語処理(NLP)の幅広いタスクに対応可能です。特に、Text-to-SQLタスクでは、自然言語入力から正確なSQLクエリを生成する能力を発揮します。

Unslothライブラリとは?

Unslothは、LLMの効率的なファインチューニングをサポートするライブラリです。特に、量子化されたモデルを提供し、計算リソースが限られた環境でも大規模モデルを適応させることができます。

Text-to-SQLタスクへの適応

モデルの初期化

まず、Unslothライブラリを使用して、事前学習済みのLLaMA 3.1 8Bモデルを初期化します。

from unsloth import FastLanguageModel

max_seq_length = 2048  # シーケンス長
load_in_4bit = True    # 4bit量子化でメモリ削減

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Meta-Llama-3.1-8B-bnb-4bit",
    max_seq_length=max_seq_length,
    load_in_4bit=load_in_4bit,
)

パラメータ効率化ファインチューニング(PEFT)

次に、PEFT(Parameter-Efficient Fine-Tuning)を使用してモデルを調整します。LoRA(Low-Rank Adaptation)を適用することで、効率的にモデルを適応させます。

model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # 低ランクの次元数
    target_modules=["..."],  # 調整対象のモジュール
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
)
  1. データセットの準備
    Text-to-SQLタスク用にカスタムデータセットを準備します。

データ生成

ChatGPTや既存のデータセットを利用して、質問(question)、文脈(context)、SQLクエリ(sql_query)の組み合わせを作成します。

データセットのフォーマット

定義したプロンプトテンプレートを用いてデータをフォーマットします。

sql_prompt = """Below is input question that user ask, context is given to help user's question, generate SQL response for user's question.

### Input:
{}

### Context:
{}

### SQL Response:
{}"""

def formatting_prompts_func(examples):
    instructions = examples["context"]
    inputs = examples["question"]
    outputs = examples["sql_query"]
    texts = []
    for instruction, input, output in zip(instructions, inputs, outputs):
        text = sql_prompt.format(instruction, input, output) + "<EOS>"
        texts.append(text)
    return {"text": texts}
データセットのロード
CSVやExcelファイル形式のデータをHugging Face形式に変換します

```python
from datasets import load_dataset, Dataset

df = pd.read_excel('fine-tuning-dataset_latest.xlsx')
dataset = Dataset.from_pandas(df)
train_test_split = dataset.train_test_split(test_size=0.2)
train_dataset = train_test_split['train']
validation_dataset = train_test_split['test']
dataset = dataset.map(formatting_prompts_func, batched=True)

トレーニング設定

以下のようにトレーニング設定を定義します。

from trl import SFTTrainer
from transformers import TrainingArguments

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        num_train_epochs=1,
        learning_rate=2e-4,
        logging_steps=1,
        output_dir="outputs",
    ),
)

モデルのトレーニングと保存

モデルをトレーニングし、ファインチューニング済みモデルを保存します。

trainer.train()
trainer.model.save_pretrained("fine_tuned_model")
tokenizer.save_pretrained("fine_tuned_model")
print("Training complete.")
3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?