4
2
生成AIに関する記事を書こう!
Qiita Engineer Festa20242024年7月17日まで開催中!

[Dify] Difyで貯めたナレッジのQAを使って、Llama3-8BをLoRAでファインチューニングする。

Posted at

はじめに

DifyのナレッジにQA機能があります。これ、LLMのQAタスクの学習データに使えそう!とのことで一回試してみました。
評価うんぬんまではあまりやらず実装に重きをおいた記事として書いていきます。
ではいきましょう。

LoRAの説明

ファインチューニング(Fine-Tuning)はLLMの回答の精度を高めるための手法の一つです。
今回の例はその中でも有名なLoRAという手法を使いますが、めちゃくちゃ簡単にいうと元のモデルの一部の層を再度学習させることで、一部を微調整する方法です。
図を見るとめちゃクチャわかりやすいです。GPTのアーキテクチャを例に説明します。
横に全結合層を足すことでチューニングする方法です。

image.png

説明はこのぐらいにして次から実際の方法を見ていきましょう。

Difyのナレッジを用意する。

データソースはなんでもいいです。私はFirecrawlを使って、pyscfのドキュメントをデータソースにします。
スクリーンショット 2024-07-13 15.59.10.png

ここでQAセットを作成するようにしてください。
スクリーンショット 2024-07-13 15.59.28.png

以下のようにQAセットを作成されます。
スクリーンショット 2024-07-13 15.58.42.png

これでDify側での準備は完了です。

LoRAでファインチューニングする

こちらはGoogleColabでA100を使用します。

事前準備

Llama3-8Bを使用しますので、事前に使用申請を出しておいてください。大体数分で申請許可がきます。

wandbを使って実行管理しているので登録してプロジェクトを作成しておいて下さい。

学習準備

必要なライブラリをinstallしておいてください。

!pip3 install -q -U bitsandbytes==0.42.0
!pip3 install -q -U peft==0.8.2
!pip3 install -q -U trl==0.7.10
!pip3 install -q -U accelerate==0.27.1
!pip3 install -q -U datasets==2.17.0
!pip3 install -q -U transformers==4.38.0

huggingfaceを使うのでトークンの発行もしておいてください。

import os
from google.colab import userdata
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')

wandb使う場合は以下を設定してください。プロジェクト名は自分でつけたやつに書き換えてください。

!pip install wandb
!wandb login
import wandb
wandb.init(project="your project")

Meta-Llama-3-8Bモデルを4ビット量子化を使って効率的にロードし、メモリ使用量を削減しています。また、モデルとそれに対応するトークナイザーを初期化しています。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16 
)
model_id = "meta-llama/Meta-Llama-3-8B"
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)

Difyのナレッジからデータを取得します。QAデータセットになっているのでそれに応じてデータを作っています。
dataset_id、api_key、base_urlをご自身の環境によって書き換えてください。
ローカルのDifyを使用している人はngrokを使用するといいでしょう。


import requests
import pandas as pd
def fetch_document_ids(dataset_id, api_key, base_url):
    url = f"{base_url}/v1/datasets/{dataset_id}/documents"
    headers = {
        "Authorization": f"Bearer {api_key}"
    }
    
    all_documents = []
    page = 1
    
    while True:
        response = requests.get(url, headers=headers, params={"page": page})
        
        if response.status_code == 200:
            data = response.json()
            documents = data.get('data', [])
            
            if not documents:
                break
            
            all_documents.extend(documents)
            
            if not data.get('has_more', False):
                break
            
            page += 1
        else:
            print(f"Error fetching documents: {response.status_code}, {response.text}")
            break
    
    return all_documents

def fetch_document_segments(dataset_id, document_id, api_key, base_url):
    url = f"{base_url}/v1/datasets/{dataset_id}/documents/{document_id}/segments"
    headers = {
        "Authorization": f"Bearer {api_key}"
    }
    
    response = requests.get(url, headers=headers)
    
    if response.status_code == 200:
        return response.json()
    else:
        print(f"Error fetching segments for document {document_id}: {response.status_code}, {response.text}")
        return None

def fetch_all_segments(dataset_id, api_key, base_url):
    documents = fetch_document_ids(dataset_id, api_key, base_url)
    all_segments = {}
    
    for doc in documents:
        doc_id = doc['id']
        segments = fetch_document_segments(dataset_id, doc_id, api_key, base_url)
        if segments:
            all_segments[doc_id] = segments
    
    return all_segments

def extract_to_dataframe(data_items):
    extracted_data = []
    
    for doc_id, doc_data in data_items:
        for item in doc_data.get('data', []):
            extracted_data.append({
                'id': item.get('id'),
                'question': item.get('content'),
                'answer': item.get('answer', '')  # answerフィールドがない場合は空文字を設定
            })
    
    # pandas DataFrameに変換
    df = pd.DataFrame(extracted_data)
    
    return df
# 使用例
dataset_id = "dataset_id"
api_key = "api_key"
base_url = "dify-api-url"

all_segments = fetch_all_segments(dataset_id, api_key, base_url)

df = extract_to_dataframe(all_segments.items())

Pandas DataFrameからHugging Face Datasetsオブジェクトで質問応答データセットを準備しています。
学習用データをプロンプトに整形する関数を定義し、システムプロンプト、ユーザーの質問、アシスタントの回答を含むメッセージ構造を作成しています。データセットの各例にプロンプトを追加し、不要なカラムを削除しています。
最後に、データセットをトレーニングセットと評価セットに分割しています。


import datasets
dataset = datasets.Dataset.from_pandas(df)
# プロンプトの生成
def generate_prompt(qadataset):
    messages = [
        {
            'role': "system",
            'content': "あなたは優秀な量子化学計算の専門家です。与えられたpyscfの内容から、日本語で回答してください。" # シンプルな指示で試してみます
        },
        {
            'role': "user",
            'content': qadataset["question"] 
        },
        {
            'role': "assistant",
            'content': qadataset["answer"]
        }
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False)

def add_text(example):
    example["prompt"] = generate_prompt(example)
    return example

dataset = dataset.map(add_text)
dataset = dataset.remove_columns(["id","question","answer"])


train_test_split = dataset.train_test_split(test_size=0.1)
train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]

4bitの線形層を見つけてリストアップしておきます。

import bitsandbytes as bnb
def find_all_linear_names(model):
  cls = bnb.nn.Linear4bit 
  lora_module_names = set()
  for name, module in model.named_modules():
    if isinstance(module, cls):
      names = name.split('.')
      lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names:
      lora_module_names.remove('lm_head')
  return list(lora_module_names)

modules = find_all_linear_names(model)
print(modules)

find_all_linear_namesでリストアップした層に対してLoRAを適用します。

from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

print(model)

lora_config = LoraConfig(
    r=64,
    lora_alpha=32,
    target_modules=modules,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)

# 訓練可能なパラメータ数、総パラメータ数、および訓練可能なパラメータの割合(パーセンテージ)を表示

trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable} | total: {total} | Percentage: {trainable/total*100:.4f}%")

学習

ここまでLoRAの学習設定が完了したので実際に学習を始めていきましょう。
以下を実行して学習を始めてください。レポート出力先はWandBにしておきましょう。


import transformers

from trl import SFTTrainer

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side='right'
torch.cuda.empty_cache()

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="prompt",
    peft_config=lora_config,
    max_seq_length=2500,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,# 一回の計算におけるバッチサイズ
        gradient_accumulation_steps=4, # 勾配計算累積回数
        warmup_steps=0.03,
        max_steps=100,
        learning_rate=2e-4,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit",
        save_strategy="epoch",
        report_to="wandb"
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
#
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

学習はまあまあうまく行っているように見えます。

スクリーンショット 2024-07-13 23.10.14.png

ここまでを保存し、HuggingFaceに挙げておきましょう!

from huggingface_hub import notebook_login
notebook_login()

new_model = "pyscf-llama-8B"
#
trainer.model.save_pretrained(new_model)
#
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map={"": 0},
)
merged_model= PeftModel.from_pretrained(base_model, new_model)
merged_model= merged_model.merge_and_unload()

# Save the merged model
#save_adapter=True, save_config=True
merged_model.save_pretrained("merged_model",safe_serialization=True)
tokenizer.save_pretrained("merged_model")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
merged_model.push_to_hub(new_model, use_temp_dir=False)
tokenizer.push_to_hub(new_model, use_temp_dir=False)

実行・結果

では最後に実行しておきましょう。

# プロンプトの準備
prompt = "What are MP2, CISD, and FCI methods mentioned in the text?"

# 推論の実行
with torch.no_grad():
    token_ids = tokenizer.encode(prompt, return_tensors="pt")
    output_ids = merged_model.generate(
        token_ids.to(model.device),
        temperature=0.1, 
        do_sample=True, 
        top_p=0.95, 
        top_k=40, 
        max_new_tokens=256,
    )
output = tokenizer.decode(output_ids[0][token_ids.size(1) :])
print(output)

結果

FT済みLLM回答
MP2, CISD, and FCI are all methods mentioned in the text. 
MP2 is an acronym for second-order Møller–Plesset perturbation theory, 
CISD stands for configuration interaction singles and doubles, and FCI is an acronym for full configuration interaction.
FTなしLLM回答
MP2 is a second-order Møller-Plesset perturbation theory. 
CISD is a configuration interaction singles and doubles method. 
FCI is a full configuration interaction method. 
The differences between them are that MP2 is a perturbation theory, 
CISD is a configuration interaction method, and FCI is a full configuration interaction method.
想定回答
MP2, CISD, and FCI are other quantum mechanical methods mentioned in the text.
MP2 stands for Second-order Møller–Plesset perturbation theory,
while CISD and FCI stand for Configuration interaction methods. 
However, these methods are not used in the code provided.

んーー質問が悪かった。そもそも学習されてるっぽい質問でした。
今回は実装に重きをおいて説明していきましたが、どこかで評価してみた系を書きたい。

最後に

Xやってるので気になる方はフォローお願いします。

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2