Meta LLaMA 3.1 8Bモデルとは?
Metaが開発したLLaMA 3.1 8Bは、80億パラメータを持つ高性能な言語モデルです。このモデルは、計算効率と性能のバランスが取れており、自然言語処理(NLP)の幅広いタスクに対応可能です。特に、Text-to-SQLタスクでは、自然言語入力から正確なSQLクエリを生成する能力を発揮します。
Unslothライブラリとは?
Unslothは、LLMの効率的なファインチューニングをサポートするライブラリです。特に、量子化されたモデルを提供し、計算リソースが限られた環境でも大規模モデルを適応させることができます。
Text-to-SQLタスクへの適応
モデルの初期化
まず、Unslothライブラリを使用して、事前学習済みのLLaMA 3.1 8Bモデルを初期化します。
from unsloth import FastLanguageModel
max_seq_length = 2048 # シーケンス長
load_in_4bit = True # 4bit量子化でメモリ削減
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/Meta-Llama-3.1-8B-bnb-4bit",
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
)
パラメータ効率化ファインチューニング(PEFT)
次に、PEFT(Parameter-Efficient Fine-Tuning)を使用してモデルを調整します。LoRA(Low-Rank Adaptation)を適用することで、効率的にモデルを適応させます。
model = FastLanguageModel.get_peft_model(
model,
r=16, # 低ランクの次元数
target_modules=["..."], # 調整対象のモジュール
lora_alpha=16,
lora_dropout=0,
bias="none",
)
- データセットの準備
Text-to-SQLタスク用にカスタムデータセットを準備します。
データ生成
ChatGPTや既存のデータセットを利用して、質問(question)、文脈(context)、SQLクエリ(sql_query)の組み合わせを作成します。
データセットのフォーマット
定義したプロンプトテンプレートを用いてデータをフォーマットします。
sql_prompt = """Below is input question that user ask, context is given to help user's question, generate SQL response for user's question.
### Input:
{}
### Context:
{}
### SQL Response:
{}"""
def formatting_prompts_func(examples):
instructions = examples["context"]
inputs = examples["question"]
outputs = examples["sql_query"]
texts = []
for instruction, input, output in zip(instructions, inputs, outputs):
text = sql_prompt.format(instruction, input, output) + "<EOS>"
texts.append(text)
return {"text": texts}
データセットのロード
CSVやExcelファイル形式のデータをHugging Face形式に変換します。
```python
from datasets import load_dataset, Dataset
df = pd.read_excel('fine-tuning-dataset_latest.xlsx')
dataset = Dataset.from_pandas(df)
train_test_split = dataset.train_test_split(test_size=0.2)
train_dataset = train_test_split['train']
validation_dataset = train_test_split['test']
dataset = dataset.map(formatting_prompts_func, batched=True)
トレーニング設定
以下のようにトレーニング設定を定義します。
from trl import SFTTrainer
from transformers import TrainingArguments
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=max_seq_length,
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_steps=5,
num_train_epochs=1,
learning_rate=2e-4,
logging_steps=1,
output_dir="outputs",
),
)
モデルのトレーニングと保存
モデルをトレーニングし、ファインチューニング済みモデルを保存します。
trainer.train()
trainer.model.save_pretrained("fine_tuned_model")
tokenizer.save_pretrained("fine_tuned_model")
print("Training complete.")