0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

gemma-2-2b-jpn-itで二値分類をする

Posted at

はじめに

自然言語を対象にする二値分類にはしばしばBERTが使われることが多いと思います。
そこで本記事では今流行りのLLMを上手いことやって二値分類を行うサンプル実装をしたのでそれを紹介します。
QLoRAのチューニングを上手いことすればもっとシンプルに書けるかもしれませんが、それは今後の課題ということで現状できているものを備忘録として残します。

データセット

今回使用するデータセットはchABSA-datasetをもとに二値分類のラベルを独自につけたものを使用します。
まずはこのデータセットのダウンロードを行いましょう。

# ダウンロード
wget https://s3-ap-northeast-1.amazonaws.com/dev.tech-sketch.jp/chakki/public/chABSA-dataset.zip

# 解凍
unzip chABSA-dataset.zip

ラベルの付け方を説明します。
このデータセットでは1つの文章に対して、その文章の **どの部分が ** negative/positive/neutralかが含まれています。

今回は1つの文章に対してネガポジのラベルを付与したいので、少し加工をします。

これを実現するPythonスクリプト、dataset.pyを実装します。

"""
dataset.py

chABSA-datasetからネガポジの2値分類のデータセットを作成する

なお事前にデータセットをダウンロードしておく必要がある。
"""

import os
from typing import Tuple, List, Final
import json
import random
from dataclasses import dataclass
from dataclasses_json import dataclass_json

random.seed(42)


@dataclass_json
@dataclass
class BinaryClassificationData:
    text: str
    label: int | None


def create_dataset() -> None:
    dataset_directory: Final[str] = "./chABSA-dataset"

    train_test_ratio: Final[Tuple[float, float]] = (0.7, 0.3)

    json_file_paths: List[str] = list(
        f"{dataset_directory}/{file}"
        for file in os.listdir(dataset_directory)
        if file.endswith(".json")
    )

    dataset: List[BinaryClassificationData] = []

    # 各jsonファイルの処理
    for json_file_path in json_file_paths:
        with open(json_file_path, "r") as f:
            data: dict = json.load(f)

        sentences: list = data["sentences"]

        # 文章ごとに処理
        for s in sentences:
            text: str = s["sentence"]

            # optionsのpolarityの多数決によりラベルを付与.
            # ただし、negativeとpositiveが同数の場合にはラベルを付与せずデータセットに含めないとする.
            negative_count: int = 0
            positive_count: int = 0

            for o in s["opinions"]:
                match o["polarity"]:
                    case "negative":
                        negative_count += 1
                    case "positive":
                        positive_count += 1
                    case "neutral":
                        pass

            if negative_count == positive_count:
                continue

            label: int = 1 if negative_count < positive_count else 0

            dataset.append(BinaryClassificationData(text=text, label=label))

    # train, testに分割
    train_dataset: List[BinaryClassificationData] = []
    test_dataset: List[BinaryClassificationData] = []

    random.shuffle(dataset)

    train_dataset = dataset[: int(len(dataset) * train_test_ratio[0])]
    test_dataset = dataset[int(len(dataset) * train_test_ratio[0]) :]

    print("✅データセットの数")
    print(f"\t{len(train_dataset)=}")
    print(f"\t{len(test_dataset)=}")

    with open("train.json", "w") as f:
        json.dump(
            [item.to_dict() for item in train_dataset],
            f,
            indent=4,
            ensure_ascii=False,
        )
    with open("test.json", "w") as f:
        json.dump(
            [item.to_dict() for item in test_dataset],
            f,
            indent=4,
            ensure_ascii=False,
        )


if __name__ == "__main__":
    create_dataset()


実装

次に gemma-2-2b-jpn-itに対してQLoRAチューニングを行います。

共通

共通部分の実装としてcustom_dataset.pyを作成し、pytorchのDatasetクラスのサブクラスを作成します。

"""
custom_dataset.py
"""

from torch.utils.data import Dataset
import torch


class BinaryClassificationDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item["text"]
        label = item["label"]

        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "labels": torch.tensor(label, dtype=torch.long),
        }

学習

次にモデルをチューニングするコードを実装します。
今回は4bit量子化を行い、LoRAチューニングを行いました。

コード全体を以下に示します。

"""
train.py

google/gemma-2-2b-jpn-itをqloraチューニングする
"""

import torch
import json
import argparse
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from tqdm import tqdm
import os
from custom_dataset import BinaryClassificationDataset

# 引数パーサーの設定
parser: argparse.ArgumentParser = argparse.ArgumentParser(
    description="QLoRA fine-tuning for binary classification"
)
parser.add_argument(
    "--train_file", type=str, default="train.json", help="Path to training data"
)
parser.add_argument(
    "--output_dir", type=str, default="./model_output", help="Directory to save model"
)
parser.add_argument("--batch_size", type=int, default=4, help="Batch size for training")
parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate")
parser.add_argument(
    "--max_length", type=int, default=512, help="Maximum sequence length"
)
parser.add_argument("--lora_r", type=int, default=8, help="LoRA r parameter")
parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha parameter")
parser.add_argument("--lora_dropout", type=float, default=0.1, help="LoRA dropout rate")
args: argparse.Namespace = parser.parse_args()


def main():
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    with open(args.train_file, "r", encoding="utf-8") as f:
        train_data = json.load(f)

    print(f"Loaded {len(train_data)} training examples")

    model_name = "google/gemma-2-2b-jpn-it"

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,
        quantization_config=bnb_config,
        device_map={"": device},
    )

    model = prepare_model_for_kbit_training(model)

    lora_config: LoraConfig = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="SEQ_CLS",
    )

    model = get_peft_model(model, lora_config)

    train_dataset = BinaryClassificationDataset(train_data, tokenizer, args.max_length)
    train_dataloader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    model.train()
    for epoch in range(args.epochs):
        epoch_loss = 0
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")

        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()

            # 順伝播
            outputs = model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels
            )
            loss = outputs.loss

            # 逆伝播
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})

        average_loss = epoch_loss / len(train_dataloader)
        print(f"Epoch {epoch+1}/{args.epochs}, Average Loss: {average_loss:.4f}")

    os.makedirs(args.output_dir, exist_ok=True)
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    print(f"Model and tokenizer saved to {args.output_dir}")


if __name__ == "__main__":
    main()

またこのコードはコマンドライン引数でパラメータや保存先などを指定するようにしています。
例えば、以下のような実行が可能です。

python train.py --train_file train.json --output_dir ./model_output --batch_size 4 --epochs 3

上記の設定で実行すると、今回のデータセット+私の環境では50分ほどで実行が完了しました。

推論

testデータを用いた評価を行います。
評価にはtest.jsonを使用して、Accuracy,Precision,Recall,F1値を計算します。

コード全文を以下に示します。

"""
test.py

google/gemma-2-2b-jpn-itをqloraチューニングしたモデルをテストする
"""

import torch
import json
import argparse
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel
from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
)

from custom_dataset import BinaryClassificationDataset


# 引数パーサーの設定
parser = argparse.ArgumentParser(
    description="Evaluation of fine-tuned model for binary classification"
)
parser.add_argument(
    "--test_file", type=str, default="test.json", help="Path to test data"
)
parser.add_argument(
    "--model_dir", type=str, default="./model_output", help="Directory of saved model"
)
parser.add_argument(
    "--batch_size", type=int, default=8, help="Batch size for evaluation"
)
parser.add_argument(
    "--max_length", type=int, default=512, help="Maximum sequence length"
)
args = parser.parse_args()


def main():
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    with open(args.test_file, "r", encoding="utf-8") as f:
        test_data = json.load(f)

    print(f"Loaded {len(test_data)} test examples")

    model_name = "google/gemma-2-2b-jpn-it"
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=2, device_map={"": device}
    )

    model = PeftModel.from_pretrained(base_model, args.model_dir)
    model.eval()

    test_dataset = BinaryClassificationDataset(test_data, tokenizer, args.max_length)
    test_dataloader = DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False
    )

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            predictions = torch.argmax(logits, dim=-1)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # 評価
    accuracy: float = accuracy_score(all_labels, all_predictions)
    precision: float = precision_score(all_labels, all_predictions, average="binary")
    recall: float = recall_score(all_labels, all_predictions, average="binary")
    f1: float = f1_score(all_labels, all_predictions, average="binary")

    print("\n===== 評価結果 =====")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")


if __name__ == "__main__":
    main()

こちらもtrian.pyと同じく、コマンドライン引数を必要とします!
以下のように実行することが可能です。

python test.py --test_file test.json --model_dir ./model_output

実行結果は以下のようになりました。
9割弱とかなり高い精度が出ていることがわかるかと思います。

===== 評価結果 =====
Accuracy: 0.8934
Precision: 0.9591
Recall: 0.8627
F1 Score: 0.9084

さいごに

最後までお読みいただきありがとうございました。

本記事の中で誤字脱字や内容の誤り等がありましたらコメント等で 優しく 指摘していただけますと幸いです。

この記事ではgemma-2-2b-jpn-itに対してQLoRAチューニングを施し二値分類を行う実装をしました。
パラメータのチューニングなどまだまだ不完全な部分はありますが、今回のデータセットにおいて高い精度を示したことから、非常に有効な手法なのかなというふうに感じています☺️
BERTによる分類と比較をしたり、別のモデルとの比較、パラメータ関連についても気が向いたら記事にしようかなと考えております!

参考文献

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?