8
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

BERTで分類(Classification)タスクのファインチューニングを行う

Last updated at Posted at 2023-03-31

はじめに

BERTClassificationタスクのファインチューニンングをColab(とMac)で行う時のメモです。

素のBERTだと事前学習済みで公開されているモデルの仕様が古いのか、最近のバージョンの環境では実行できなかったため、BERTの改良版であるRoBERTaの事前学習済みモデルにて実施します。

事前に必要なものを入れる

!pip install sentencepiece
!pip install datasets
!pip install transformers

Macで動かすときは以下も実施。

$ brew install jumanpp 
$ pip install protobuff 

学習用データを用意する

以下のような形式でJSONファイルを作成する。

[
    {
        "text": "東京は日本の首都です", // 本文
        "label": 2, // 分類番号
    },
    {
        "text": "明日の天気は晴れでしょう",
        "label": 3,
    },
    ...
]

学習用のコードを用意する

トークナイザーを取得する

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("nlp-waseda/roberta-base-japanese")

学習用データを読み込む

# 学習用データの読み込み
from google.colab import drive
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict
import pandas as pd

drive.mount('/mount')
data = pd.read_json('/mount/MyDrive/xxxx.json')

train, valid = train_test_split(data, test_size=0.25)

ds_train = Dataset.from_pandas(train)
ds_valid = Dataset.from_pandas(valid)

dataset = DatasetDict({
    "train": ds_train,
    "validation": ds_valid,
})

読み込んだデータを学習用に変換する

文章をトークンに変換する。

import torch

# 変換関数
def preprocess_function(data):
    texts = [q.strip() for q in data["text"]]
    inputs = tokenizer(
        texts,
        max_length=450,
        truncation=True,
        padding=True,
    )

    inputs['labels'] = torch.tensor(data['label'])

    return inputs


# 変換
tokenized_data = dataset.map(preprocess_function, batched=True)

分類用のモデルを取得する

from transformers import AutoModelForSequenceClassification

num_labels = 5 # 分類する種類の数

# デバイス判定
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = AutoModelForSequenceClassification.from_pretrained("nlp-waseda/roberta-base-japanese", num_labels=num_labels).to(device)

パラメータを設定する

EarlyStoppingしたいので、load_best_model_at_endをTrueにする。

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="/mount/MyDrive/yyyy",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=200,
    weight_decay=0.01,
    load_best_model_at_end=True,
)

from transformers import default_data_collator
data_collator = default_data_collator

学習を実行する

from sklearn.model_selection import train_test_split
from transformers import Trainer, EarlyStoppingCallback

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=30)],
)

trainer.train()

学習後の動作確認

from torch import nn

def predict(text):
    inputs = tokenizer(text, add_special_tokens=True, return_tensors="pt").to(device)
    outputs = model(**inputs)
    ps = nn.Softmax(1)(outputs.logits)

    max_p = torch.max(ps)
    result = torch.argmax(ps).item() if max_p > 0.8 else -1
    return result

result = predict('明日の天気は晴れでしょう')

print(result)

3が表示されるはず。

参考サイト

8
11
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?