1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Transformers を使って、BERT のファインチューニングをする

Posted at

Transformers を使って、BERT のファインチューニングをする

はじめに

はじめまして。私は大学院生です。

この記事では、Transformers を使って、BERT のファインチューニングをします。(主に、研究室の後輩向けに書いていますので、極力簡素化しています。ご了承ください。)

この記事が、「Transformers を使ったファインチューニングの方法の参考になった」と思っていただけるように書かさせていただきます。


より簡素化できる案などがあればコメントをしていただけると幸いです。


類似記事もありますので、是非こちらもご覧になってください。

事前準備

用語解説

基本的な用語はこちらで解説しています。

環境

データセット

事前学習済みモデル

実験コード

ライブラリのインポート

import re
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

デバイスの設定

Google Colab では、「ランタイム」 -> 「ランタイムのタイプを変更」 -> 「T4 GPU」 に設定してください。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

上記の設定ができていると、print(device) の出力は、device(type='cuda') になるはずです。

データセットのロード

今回使用するデータセットは、SST-2 1 を使用します。

train_dataset = load_dataset("stanfordnlp/sst2", split="train")
val_dataset = load_dataset("stanfordnlp/sst2", split="validation")

学習データからラベル数を取得します。

num_classes = train_df['label'].nunique()

データセットの分析

分析がしやすいように、DataFrame に変換します。

train_df = train_dataset.to_pandas()

head() メソッドを使うことで、先頭 5 行を見ることができます。

train_df.head()

tail() メソッドを使うことで、末尾 5 行を見ることができます。

train_df.tail()

describe() メソッドを使うことで、記述統計量を見ることができます。

train_df.describe()

学習データのラベルの割合を以下で確認することができます。確認する意図としては、ラベルに偏りがある場合、ダウンサンプリング(あるいはアップサンプリング)が必要になるからです。

train_df['label'].value_counts(normalize=True)

単語の出現頻度を可視化します。ただし、上位 50 件のみを表示します。

all_words = []
for sentence in train_df["sentence"]:
    words = re.findall(r'\b\w+\b', sentence.lower())
    all_words.extend(words)

word_counts = Counter(all_words)

top_50_words = word_counts.most_common(50)
words, counts = zip(*top_50_words)

plt.figure(figsize=(12, 8))
plt.bar(words, counts)
plt.xticks(rotation=45)
plt.xlabel("Words")
plt.ylabel("Frequency")
plt.show()

word frequency histogram

テキストの長さを可視化します。

train_df["text-length"] = train_df["sentence"].apply(lambda x: len(x.split()))

train_df["text-length"].hist(bins=50)
plt.xlabel("text length")
plt.ylabel("count")
plt.show()

text length histogram

事前学習済みトークナイザーとモデルのロード

AutoModelForSequenceClassification.from_pretrained() を分類タスクで使うときは、num_labels を指定するべきです。num_labels を指定しないと、デフォルトの出力サイズが使用され、想定しているクラス数と合わなくなることがあります。num_labels を指定することで、適切な損失関数がモデル内部で選ばれます。これにより、タスクに最適化された損失計算が可能になります。num_labels に合わせてモデルの出力層が再構成されるため、事前学習済みモデルに対して新しいタスク用のファインチューニングを効率的に行えます。

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-uncased", num_labels=num_classes).to(device)

事前学習済みトークナイザーとモデルの確認

BERT 2 などの多くのモデルは、事前学習の際に、「Special Token」という特殊なトークンを使うことが一般的です。

tokenizer

モデルの構造は以下で確認することができます。

model

モデルの設定は、以下で確認することができます。

model.config

モデルの学習可能なパタメータ数の確認は以下のコードで見ることができます。if p.requires_grad の部分は、学習可能かどうかを判定するために入れています。

sum(p.numel() for p in model.parameters() if p.requires_grad)

英文のトークン化

tokenizer を使って、データセットをトークン化します。padding="longest" は、各バッチ内の最長のシーケンスに合わせてパディングされ、余分なパディングが省かれます。

def tokenize_function(examples):
    return tokenizer(examples["sentence"], padding="longest", truncation=True)

train_dataset = train_dataset.map(tokenize_function)
val_dataset = val_dataset.map(tokenize_function)

トークン化された文字を元の数字に戻してみます。

id = 0
print(tokenizer.decode(train_dataset[id]["input_ids"]))

学習時の設定

より細かい学習の設定は、TransformersTrainingArguments をご覧になってください。

最適化関数は、AdamW 3 がデフォルトで設定されています。

  • output_dir:チェックポイントなどの保存先のパス
  • eval_strategy:評価戦略
  • per_device_train_batch_size:学習時のバッチサイズ
  • per_device_eval_batch_size:テスト時のバッチサイズ
  • learning_rate:初期の学習率
  • logging_strategy:ロギング戦略
  • save_strategy:保存戦略

TensorBoardWandB などの設定ができていれば、損失などの記録を確認することができます。

args = TrainingArguments(
    output_dir="./outputs",
    eval_strategy="epoch",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=1e-5,
    logging_strategy="epoch",
    save_strategy="epoch",
)
print(args)

評価指標の定義

正解率、F 値、適合率、再現率を求めます。

def compute_metrics(p):
    
    preds = p.predictions.argmax(-1)
    
    accuracy = accuracy_score(p.label_ids, preds)
    f1 = f1_score(p.label_ids, preds, average="weighted")
    precision = precision_score(p.label_ids, preds, average="weighted")
    recall = recall_score(p.label_ids, preds, average="weighted")
    
    return {
        "accuracy": accuracy,
        "f1": f1,
        "precision": precision,
        "recall": recall
    }

Trainer の定義と実行

TransformersTrainer を用いて、学習を行います。

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
trainer.train()

おわりに

この記事では、Transformers を使って、BERT のファインチューニングをしました。

この記事が、「Transformers を使ったファインチューニングの方法の参考になった」と思っていただけたら幸いです。

参考文献

  1. Richard Socher, Alex Perelygin, Jean Wu, Jason Chuang, Christopher D. Manning, Andrew Y. Ng, and Christopher Potts. Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank. In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing, 2013.

  2. Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics, 2019.

  3. Ilya Loshchilov and Frank Hutter. Decoupled Weight Decay Regularization. In 7th International Conference on Learning Representations, 2019.

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?