
Fine-tuning Japanese ALBERT and MobileBERT models for regression and comparing them


Background

  • I previously ran regression with ALBERT
  • This time I run regression with MobileBERT
  • For MobileBERT, @ysakuramoto has kindly published a Japanese pre-trained model, which I use here
  • For ALBERT, @ken11_ has likewise published a Japanese pre-trained model, which I use here
  • I train both on the same dataset and run a simple comparison

Differences between ALBERT and MobileBERT

Both are lightweight variants of BERT.

ALBERT

The specs below are for the base model:

  • Parameters: 12M
  • Layers: 12
  • Hidden size: 768
  • Embedding size: 128

MobileBERT

  • Parameters: 25.3M
  • Layers: 24
  • Hidden size: 512
  • Embedding size: 128

Differences

ALBERT mainly aims to improve training speed; its signature technique appears to be sharing one set of transformer-layer parameters across all layers.
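
As a quick illustration of the parameter sharing (a minimal sketch using the Hugging Face transformers API, not code from this article), you can build an albert-base-sized model and see that its 12 layers are backed by a single shared layer group:

from transformers import AlbertConfig, AlbertModel

# albert-base-sized config: 12 layers, hidden size 768, embedding size 128
config = AlbertConfig(hidden_size=768, num_attention_heads=12,
                      intermediate_size=3072)
model = AlbertModel(config)

# All 12 layers reuse the weights of one layer group,
# which is what keeps the parameter count around 12M.
print(config.num_hidden_layers)                # 12
print(len(model.encoder.albert_layer_groups))  # 1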

MobileBERT, by contrast, mainly targets efficient inference on devices such as mobile phones.

Implementing regression with MobileBERT

It is almost identical to the ALBERT version.

from typing import Optional, Tuple

import evaluate
import pandas as pd
import torch
import torch.nn as nn
from datasets import Dataset
from transformers import AutoTokenizer, MobileBertModel, Trainer, \
    TrainingArguments

# MobileBertForRegression is defined below
# (the original kept it in mobile_bert_for_regression.py)
from consts import MOBILE_BERT_PRETRAINED_MODEL_NAME, LEARNING_RATE, \
    WEIGHT_DECAY, NUM_TRAIN_EPOCHS, MOBILE_BERT_MODEL_PATH, CSV_PATH, \
    TOHOKU_BERT_TOKENIZER


class MobileBertForRegression(nn.Module):
    """MobileBERT with a dropout + linear head emitting one scalar per input."""

    def __init__(self, model_name):
        super(MobileBertForRegression, self).__init__()
        self.mobile_bert = MobileBertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        self.regressor = nn.Linear(self.mobile_bert.config.hidden_size, 1)
        self.loss_fn = nn.MSELoss()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None
    ) -> Tuple:
        outputs = self.mobile_bert(
            input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]  # pooler_output of MobileBertModel
        pooled_output = self.dropout(pooled_output)
        regression_output = self.regressor(pooled_output).squeeze(-1)

        loss = None
        if labels is not None:
            # Compute the loss when labels are provided (training/evaluation)
            labels = labels.float()
            loss = self.loss_fn(regression_output, labels)

        return loss, regression_output

df = pd.read_csv(CSV_PATH)

# Prepare the dataset
dataset_dict = Dataset.from_pandas(df)
dataset_dict = dataset_dict.train_test_split(
    test_size=0.1, shuffle=True, seed=42)
model = MobileBertForRegression(MOBILE_BERT_PRETRAINED_MODEL_NAME)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(TOHOKU_BERT_TOKENIZER)
config = model.mobile_bert.config

def tokenize_function(examples):
    return tokenizer(
        examples["title"], padding="max_length",
        truncation=True, max_length=config.max_position_embeddings)

tokenized_datasets = dataset_dict.map(tokenize_function, batched=True)


training_args = TrainingArguments(
    output_dir="test_trainer",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    load_best_model_at_end=True,
    save_total_limit=3
)
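
# The evaluation logs below report "eval_corr", so a compute_metrics function
# must have been passed to the Trainer. The original is not shown in the
# article; this is a minimal reconstruction using evaluate's pearsonr metric.
metric = evaluate.load("pearsonr")

def compute_metrics(eval_pred):
    # predictions are the regression outputs returned by forward()
    predictions, labels = eval_pred
    result = metric.compute(predictions=predictions, references=labels)
    return {"corr": result["pearsonr"]}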

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    compute_metrics=compute_metrics
)

trainer.train()
torch.save(model.state_dict(), MOBILE_BERT_MODEL_PATH)
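
For completeness, here is a minimal inference sketch (assumed usage, not shown in the original article): reload the saved weights and score a single title.

# Inference sketch (assumption): reload the saved weights and score one title
model = MobileBertForRegression(MOBILE_BERT_PRETRAINED_MODEL_NAME)
model.load_state_dict(torch.load(MOBILE_BERT_MODEL_PATH, map_location=device))
model.to(device)
model.eval()
config = model.mobile_bert.config

inputs = tokenizer("サンプルのタイトル", return_tensors="pt",
                   truncation=True, max_length=config.max_position_embeddings)
with torch.no_grad():
    _, score = model(input_ids=inputs["input_ids"].to(device),
                     attention_mask=inputs["attention_mask"].to(device))
print(score.item())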

Training

  • Dataset
    • Roughly 1,000 Japanese samples I own personally
    • 9:1 train/test split
  • Hyperparameters (see the consts.py sketch below)
    • LEARNING_RATE = 2.0e-05
    • WEIGHT_DECAY = 0.01
    • NUM_TRAIN_EPOCHS = 20
  • Environment
    • GPU: V100 x 1
    • vCPUs: 8
    • Memory: 52 GB
    • CUDA 11.8, torch 2.0.0
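
For reference, a consts.py along these lines would make the training script above runnable. Only the three hyperparameter values come from this article; the model/tokenizer identifiers and file paths are my assumptions.

# consts.py -- a sketch; identifiers and paths below are assumptions
MOBILE_BERT_PRETRAINED_MODEL_NAME = "ysakuramoto/mobilebert-ja"  # assumed repo id
TOHOKU_BERT_TOKENIZER = "cl-tohoku/bert-base-japanese-whole-word-masking"  # assumed
CSV_PATH = "data/titles.csv"  # placeholder: CSV with "title" and "label" columns
MOBILE_BERT_MODEL_PATH = "mobile_bert_regression.pt"  # placeholder output path
LEARNING_RATE = 2.0e-05
WEIGHT_DECAY = 0.01
NUM_TRAIN_EPOCHS = 20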

Results

  • Correlation: ALBERT 0.623 vs. MobileBERT 0.542
  • Loss: ALBERT 0.674 vs. MobileBERT 0.794

So on this dataset, ALBERT achieved the higher correlation.
Curiously, this is despite MobileBERT having more parameters.

※ Note that there may be problems with my code or dataset.
※ This is not a verdict on which model is better; it likely depends on the use case and dataset.

Summary

I fine-tuned the Japanese pre-trained ALBERT and MobileBERT models for regression.
On this dataset, ALBERT produced the higher correlation.
Both models turned out to be fairly easy to fine-tune.

Bonus

ALBERT results


{'eval_loss': 1.0931249856948853, 'eval_corr': 0.12595221903204978, 'eval_runtime': 1.0102, 'eval_samples_per_second': 96.022, 'eval_steps_per_second': 12.869, 'epoch': 1.0}
{'eval_loss': 1.0168256759643555, 'eval_corr': 0.33573416678897133, 'eval_runtime': 0.9713, 'eval_samples_per_second': 99.864, 'eval_steps_per_second': 13.384, 'epoch': 2.0}
{'eval_loss': 0.813063383102417, 'eval_corr': 0.5185892558405978, 'eval_runtime': 0.9828, 'eval_samples_per_second': 98.699, 'eval_steps_per_second': 13.228, 'epoch': 3.0}
{'eval_loss': 0.9315071702003479, 'eval_corr': 0.504779982065797, 'eval_runtime': 0.9791, 'eval_samples_per_second': 99.069, 'eval_steps_per_second': 13.277, 'epoch': 4.0}
{'eval_loss': 0.7714371681213379, 'eval_corr': 0.549230228465193, 'eval_runtime': 0.9806, 'eval_samples_per_second': 98.922, 'eval_steps_per_second': 13.258, 'epoch': 5.0}
{'eval_loss': 0.7709570527076721, 'eval_corr': 0.5580444589930664, 'eval_runtime': 0.9774, 'eval_samples_per_second': 99.246, 'eval_steps_per_second': 13.301, 'epoch': 6.0}
{'eval_loss': 0.7781409621238708, 'eval_corr': 0.5543179976101054, 'eval_runtime': 0.9797, 'eval_samples_per_second': 99.011, 'eval_steps_per_second': 13.27, 'epoch': 7.0}
{'eval_loss': 0.7591082453727722, 'eval_corr': 0.5767609914340701, 'eval_runtime': 0.9775, 'eval_samples_per_second': 99.231, 'eval_steps_per_second': 13.299, 'epoch': 8.0}
{'eval_loss': 0.7653213143348694, 'eval_corr': 0.5756049045201133, 'eval_runtime': 0.9812, 'eval_samples_per_second': 98.86, 'eval_steps_per_second': 13.249, 'epoch': 9.0}
{'eval_loss': 0.8040589690208435, 'eval_corr': 0.5674605101439573, 'eval_runtime': 0.9782, 'eval_samples_per_second': 99.163, 'eval_steps_per_second': 13.29, 'epoch': 10.0}
{'eval_loss': 0.7191566228866577, 'eval_corr': 0.5939990546286915, 'eval_runtime': 0.979, 'eval_samples_per_second': 99.082, 'eval_steps_per_second': 13.279, 'epoch': 11.0}
{'eval_loss': 0.7288004159927368, 'eval_corr': 0.6051976472625327, 'eval_runtime': 0.9792, 'eval_samples_per_second': 99.057, 'eval_steps_per_second': 13.276, 'epoch': 12.0}
{'eval_loss': 0.7123575806617737, 'eval_corr': 0.59972310624502, 'eval_runtime': 0.9859, 'eval_samples_per_second': 98.385, 'eval_steps_per_second': 13.186, 'epoch': 13.0}
{'eval_loss': 0.7114874720573425, 'eval_corr': 0.6037608790429076, 'eval_runtime': 0.9791, 'eval_samples_per_second': 99.066, 'eval_steps_per_second': 13.277, 'epoch': 14.0}
{'eval_loss': 0.751545250415802, 'eval_corr': 0.5926516654023529, 'eval_runtime': 0.9889, 'eval_samples_per_second': 98.086, 'eval_steps_per_second': 13.145, 'epoch': 15.0}
{'eval_loss': 0.6738653182983398, 'eval_corr': 0.6229803814128074, 'eval_runtime': 0.9753, 'eval_samples_per_second': 99.452, 'eval_steps_per_second': 13.329, 'epoch': 16.0}
{'eval_loss': 0.6817343235015869, 'eval_corr': 0.61704163366791, 'eval_runtime': 0.9831, 'eval_samples_per_second': 98.667, 'eval_steps_per_second': 13.223, 'epoch': 17.0}
{'eval_loss': 0.7037461400032043, 'eval_corr': 0.6092487558511793, 'eval_runtime': 0.9793, 'eval_samples_per_second': 99.049, 'eval_steps_per_second': 13.275, 'epoch': 18.0}
{'eval_loss': 0.7015594244003296, 'eval_corr': 0.6089663337803164, 'eval_runtime': 0.9852, 'eval_samples_per_second': 98.456, 'eval_steps_per_second': 13.195, 'epoch': 19.0}
{'eval_loss': 0.6958994269371033, 'eval_corr': 0.6104160985318141, 'eval_runtime': 0.985, 'eval_samples_per_second': 98.472, 'eval_steps_per_second': 13.197, 'epoch': 20.0}

MobileBERT results

{'eval_loss': 1.0197206735610962, 'eval_corr': 0.32913760439327233, 'eval_runtime': 0.6165, 'eval_samples_per_second': 157.335, 'eval_steps_per_second': 21.086, 'epoch': 1.0}
{'eval_loss': 0.834292471408844, 'eval_corr': 0.4926006530191603, 'eval_runtime': 0.6178, 'eval_samples_per_second': 157.007, 'eval_steps_per_second': 21.042, 'epoch': 2.0}
{'eval_loss': 0.841667115688324, 'eval_corr': 0.48562966284496684, 'eval_runtime': 0.623, 'eval_samples_per_second': 155.707, 'eval_steps_per_second': 20.868, 'epoch': 3.0}
{'eval_loss': 0.8046469688415527, 'eval_corr': 0.5316866216189041, 'eval_runtime': 0.617, 'eval_samples_per_second': 157.208, 'eval_steps_per_second': 21.069, 'epoch': 4.0}
{'eval_loss': 0.7792660593986511, 'eval_corr': 0.539372381148289, 'eval_runtime': 0.6163, 'eval_samples_per_second': 157.379, 'eval_steps_per_second': 21.092, 'epoch': 5.0}
{'eval_loss': 0.816072940826416, 'eval_corr': 0.521739222023989, 'eval_runtime': 0.6456, 'eval_samples_per_second': 150.258, 'eval_steps_per_second': 20.138, 'epoch': 6.0}
{'eval_loss': 0.820040762424469, 'eval_corr': 0.518526031399988, 'eval_runtime': 0.6222, 'eval_samples_per_second': 155.9, 'eval_steps_per_second': 20.894, 'epoch': 7.0}
{'eval_loss': 0.8060498833656311, 'eval_corr': 0.5256717078147494, 'eval_runtime': 0.6206, 'eval_samples_per_second': 156.292, 'eval_steps_per_second': 20.946, 'epoch': 8.0}
{'eval_loss': 0.8206806182861328, 'eval_corr': 0.520293664139267, 'eval_runtime': 0.6175, 'eval_samples_per_second': 157.08, 'eval_steps_per_second': 21.052, 'epoch': 9.0}
{'eval_loss': 0.8192519545555115, 'eval_corr': 0.5357709361612227, 'eval_runtime': 0.6153, 'eval_samples_per_second': 157.636, 'eval_steps_per_second': 21.126, 'epoch': 10.0}
{'eval_loss': 0.8302949070930481, 'eval_corr': 0.5024485634087245, 'eval_runtime': 0.6175, 'eval_samples_per_second': 157.089, 'eval_steps_per_second': 21.053, 'epoch': 11.0}
{'eval_loss': 0.846757709980011, 'eval_corr': 0.5156568099133542, 'eval_runtime': 0.6173, 'eval_samples_per_second': 157.142, 'eval_steps_per_second': 21.06, 'epoch': 12.0}
{'eval_loss': 0.7940130233764648, 'eval_corr': 0.5416293539097846, 'eval_runtime': 0.6203, 'eval_samples_per_second': 156.378, 'eval_steps_per_second': 20.958, 'epoch': 13.0}
{'eval_loss': 0.8173080086708069, 'eval_corr': 0.5157108818748154, 'eval_runtime': 0.6166, 'eval_samples_per_second': 157.314, 'eval_steps_per_second': 21.083, 'epoch': 14.0}
{'eval_loss': 0.7959800958633423, 'eval_corr': 0.5345031465525737, 'eval_runtime': 0.6213, 'eval_samples_per_second': 156.131, 'eval_steps_per_second': 20.925, 'epoch': 15.0}
{'eval_loss': 0.7986265420913696, 'eval_corr': 0.5298356069843573, 'eval_runtime': 0.6143, 'eval_samples_per_second': 157.912, 'eval_steps_per_second': 21.163, 'epoch': 16.0}
{'eval_loss': 0.8147493600845337, 'eval_corr': 0.516939618033049, 'eval_runtime': 0.6119, 'eval_samples_per_second': 158.519, 'eval_steps_per_second': 21.245, 'epoch': 17.0}
{'eval_loss': 0.8108468651771545, 'eval_corr': 0.5198748437385999, 'eval_runtime': 0.6151, 'eval_samples_per_second': 157.709, 'eval_steps_per_second': 21.136, 'epoch': 18.0}
{'eval_loss': 0.811181366443634, 'eval_corr': 0.5201105451211807, 'eval_runtime': 0.615, 'eval_samples_per_second': 157.722, 'eval_steps_per_second': 21.138, 'epoch': 19.0}
{'eval_loss': 0.8112084865570068, 'eval_corr': 0.5206190728457705, 'eval_runtime': 0.6173, 'eval_samples_per_second': 157.134, 'eval_steps_per_second': 21.059, 'epoch': 20.0}
