2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

BERTで質疑応答(question answering)タスクのファインチューニングを行う

Posted at

はじめに

BERTquestion answeringタスクのファインチューニンングをColab(とMac)で行う時のメモです。

素のBERTだと事前学習済みで公開されているモデルの仕様が古いのか、最近のバージョンの環境では実行できなかったため、BERTの改良版であるRoBERTaの事前学習済みモデルにて実施します。

事前に必要なものを入れる

!pip install sentencepiece
!pip install datasets
!pip install transformers

Macで動かすときは以下も実施。

$ brew install jumanpp 
$ pip install protobuff 

学習用データを用意する

以下の形式でJSONファイルを作成する。

[
    {
        "context": "aaa is bbb", // 本文
        "question": "what is aaa ? ", // 本文についての質問
        "answers": {
            "text": ["bbb"], // 質問の回答
            "answer_start": [7] // 回答がある本文中の位置
        }
    },
    ...
]

学習用のコードを用意する

トークナイザーを取得する

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("nlp-waseda/roberta-base-japanese")

学習用データを読み込む

# 学習用データの読み込み
from google.colab import drive
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict

drive.mount('/mount')

data = pd.read_json('/mount/MyDrive/foo/bar.json')
train, valid = train_test_split(data, test_size=0.2)
ds_train = Dataset.from_pandas(train)
ds_valid = Dataset.from_pandas(valid)

dataset = DatasetDict({
    "train": ds_train,
    "validation": ds_valid,
})

学習用データを学習するために変換する

文章をトークンに変換する。
回答のトークンの位置を与えてやる必要があるので、検索して設定する。

(変換関数は参考サイトからのコピペ)

# デバイス判定
import torch
device = "cuda:0" if torch.cuda.is_available() else "cpu"


# 変換関数
def preprocess_function(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=450,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
        return_tensors="pt").to(device)

    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        answer = answers[i]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs


# 変換
tokenized_data = dataset.map(preprocess_function, batched=True)

RoBERTaのモデルを取得する

from transformers import AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained("nlp-waseda/roberta-base-japanese").to(device)

パラメータを設定する

EarlyStoppingしたいので、load_best_model_at_endをTrueにする。

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=200,
    weight_decay=0.01,
    load_best_model_at_end=True, # 終了時に一番良かったモデルを使う
)

from transformers import default_data_collator
data_collator = default_data_collator

学習を実行する

from transformers import Trainer, EarlyStoppingCallback

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

trainer.train()

学習後の動作確認

from transformers import AutoModelForQuestionAnswering, AutoTokenizer

def predict(context, question):
    inputs = tokenizer(question, add_special_tokens=True, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"].tolist()[0]
    outputs = model(**inputs)

    answer_start = torch.argmax(outputs.start_logits)
    answer_end = torch.argmax(outputs.end_logits) + 1
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))

    return [answer_start, answer_end, answer]

result = predict('aaa is bbb', 'what is bbb ?')

print(result)

参考サイト

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?