Background
- I previously ran regression with ALBERT
- This time, I will try regression with MobileBERT
- For MobileBERT, @ysakuramoto kindly provides a Japanese pretrained model, so I will use it
- For ALBERT as well, @ken11_ kindly provides a Japanese pretrained model, so I will use that one
- I will train both on the same dataset and make a simple comparison
Differences between ALBERT and MobileBERT
Both are lightweight variants of BERT.
ALBERT
The specs below are for the base model:
- Parameters: 12M
- Layers: 12
- Hidden size: 768
- Embedding size: 128
MobileBERT
- Parameters: 25.3M
- Layers: 24
- Hidden size: 512
- Embedding size: 128
Differences
ALBERT mainly aims to make training faster and the model smaller; its distinguishing feature is reportedly parameter sharing across layers. MobileBERT, on the other hand, mainly targets efficient inference on devices such as mobile phones. A quick sanity check of the parameter counts above is sketched below.
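Since the transformers library exposes both architectures, the spec numbers above can be checked by counting parameters directly. A minimal sketch, using the public English checkpoints albert-base-v2 and google/mobilebert-uncased (the Japanese models used below are different checkpoints):

```python
from transformers import AlbertModel, MobileBertModel

# Load the public English checkpoints (illustration only)
albert = AlbertModel.from_pretrained("albert-base-v2")
mobilebert = MobileBertModel.from_pretrained("google/mobilebert-uncased")

# Total parameter counts, in millions
print(f"ALBERT:     {sum(p.numel() for p in albert.parameters()) / 1e6:.1f}M")
print(f"MobileBERT: {sum(p.numel() for p in mobilebert.parameters()) / 1e6:.1f}M")
# Prints roughly 12M for ALBERT and 25M for MobileBERT; ALBERT stays
# small because one set of layer weights is shared across all 12 layers
```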
Implementing regression with MobileBERT
It is almost the same as the ALBERT version. First, the model class (mobile_bert_for_regression.py):

```python
# mobile_bert_for_regression.py
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import MobileBertModel


class MobileBertForRegression(nn.Module):
    def __init__(self, model_name):
        super(MobileBertForRegression, self).__init__()
        self.mobile_bert = MobileBertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        # Single-value regression head on top of the pooled output
        self.regressor = nn.Linear(self.mobile_bert.config.hidden_size, 1)
        self.loss_fn = nn.MSELoss()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.FloatTensor] = None
    ) -> Tuple:
        outputs = self.mobile_bert(
            input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]  # pooler output for the [CLS] token
        pooled_output = self.dropout(pooled_output)
        regression_output = self.regressor(pooled_output).squeeze(-1)
        loss = None
        if labels is not None:
            # Compute the loss during training
            labels = labels.float()
            loss = self.loss_fn(regression_output, labels)
        return loss, regression_output
```

Then the training script:

```python
import pandas as pd
import torch
import evaluate  # used by the correlation metric (see below)
from datasets import Dataset
from transformers import AutoTokenizer, TrainingArguments, Trainer

from mobile_bert_for_regression import MobileBertForRegression
from consts import MOBILE_BERT_PRETRAINED_MODEL_NAME, LEARNING_RATE, \
    WEIGHT_DECAY, NUM_TRAIN_EPOCHS, MOBILE_BERT_MODEL_PATH, CSV_PATH, \
    TOHOKU_BERT_TOKENIZER

df = pd.read_csv(CSV_PATH)

# Prepare the dataset
dataset_dict = Dataset.from_pandas(df)
dataset_dict = dataset_dict.train_test_split(
    test_size=0.1, shuffle=True, seed=42)

model = MobileBertForRegression(MOBILE_BERT_PRETRAINED_MODEL_NAME)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(TOHOKU_BERT_TOKENIZER)
config = model.mobile_bert.config


def tokenize_function(examples):
    return tokenizer(
        examples["title"], padding="max_length",
        truncation=True, max_length=config.max_position_embeddings)


tokenized_datasets = dataset_dict.map(tokenize_function, batched=True)

training_args = TrainingArguments(
    output_dir="test_trainer",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    load_best_model_at_end=True,
    save_total_limit=3
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test']
)
trainer.train()

torch.save(model.state_dict(), MOBILE_BERT_MODEL_PATH)
```
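The evaluation logs in the appendix report an eval_corr metric, so a compute_metrics function was presumably also passed to the Trainer, although it is not shown above. A minimal sketch, assuming the evaluate library's pearsonr metric and a "corr" metric key (both assumptions on my part):

```python
import evaluate

# Pearson correlation between predictions and labels; the Trainer
# logs the returned "corr" key as "eval_corr" (assumed setup)
pearsonr = evaluate.load("pearsonr")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    return {"corr": pearsonr.compute(
        predictions=predictions, references=labels)["pearsonr"]}
```

Such a function would be wired in via Trainer(..., compute_metrics=compute_metrics). At inference time, the prediction is the second element of the tuple returned by the model's forward pass.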
Training
- Dataset
  - About 1,000 Japanese samples that I own personally
  - Split 9:1 into training and test sets
- Hyperparameters (see the consts.py sketch after this list)
  - LEARNING_RATE = 2.0e-05
  - WEIGHT_DECAY = 0.01
  - NUM_TRAIN_EPOCHS = 20
- Environment
  - GPU: V100 x 1
  - vCPU: 8
  - Memory: 52GB
  - CUDA 11.8, torch 2.0.0
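For reference, a minimal consts.py matching the imports in the training script might look like the following. Only the hyperparameter values are taken from the list above; the model names and paths are placeholders, not the actual values used:

```python
# consts.py (sketch; model names and paths are placeholders)
MOBILE_BERT_PRETRAINED_MODEL_NAME = "..."  # Japanese MobileBERT checkpoint
TOHOKU_BERT_TOKENIZER = "..."  # Tohoku University Japanese BERT tokenizer
CSV_PATH = "..."  # CSV assumed to contain "title" and "label" columns
MOBILE_BERT_MODEL_PATH = "..."  # where to save the fine-tuned weights

# Hyperparameters as listed above
LEARNING_RATE = 2.0e-05
WEIGHT_DECAY = 0.01
NUM_TRAIN_EPOCHS = 20
```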
Results
- Correlation: ALBERT 0.623 vs. MobileBERT 0.542
- Loss: ALBERT 0.674 vs. MobileBERT 0.794
On the dataset I prepared, ALBERT achieved the higher correlation.
This is a little surprising, since MobileBERT has more parameters.
※ Note that there may have been a problem with my code or my dataset
※ This is not a comparison of which model is superior; the result likely depends on the use case and the dataset
Summary
I fine-tuned Japanese pretrained ALBERT and MobileBERT models for regression.
On this dataset, ALBERT produced the higher correlation.
Both models turned out to be fairly easy to fine-tune.
Appendix
ALBERT results

```
{'eval_loss': 1.0931249856948853, 'eval_corr': 0.12595221903204978, 'eval_runtime': 1.0102, 'eval_samples_per_second': 96.022, 'eval_steps_per_second': 12.869, 'epoch': 1.0}
{'eval_loss': 1.0168256759643555, 'eval_corr': 0.33573416678897133, 'eval_runtime': 0.9713, 'eval_samples_per_second': 99.864, 'eval_steps_per_second': 13.384, 'epoch': 2.0}
{'eval_loss': 0.813063383102417, 'eval_corr': 0.5185892558405978, 'eval_runtime': 0.9828, 'eval_samples_per_second': 98.699, 'eval_steps_per_second': 13.228, 'epoch': 3.0}
{'eval_loss': 0.9315071702003479, 'eval_corr': 0.504779982065797, 'eval_runtime': 0.9791, 'eval_samples_per_second': 99.069, 'eval_steps_per_second': 13.277, 'epoch': 4.0}
{'eval_loss': 0.7714371681213379, 'eval_corr': 0.549230228465193, 'eval_runtime': 0.9806, 'eval_samples_per_second': 98.922, 'eval_steps_per_second': 13.258, 'epoch': 5.0}
{'eval_loss': 0.7709570527076721, 'eval_corr': 0.5580444589930664, 'eval_runtime': 0.9774, 'eval_samples_per_second': 99.246, 'eval_steps_per_second': 13.301, 'epoch': 6.0}
{'eval_loss': 0.7781409621238708, 'eval_corr': 0.5543179976101054, 'eval_runtime': 0.9797, 'eval_samples_per_second': 99.011, 'eval_steps_per_second': 13.27, 'epoch': 7.0}
{'eval_loss': 0.7591082453727722, 'eval_corr': 0.5767609914340701, 'eval_runtime': 0.9775, 'eval_samples_per_second': 99.231, 'eval_steps_per_second': 13.299, 'epoch': 8.0}
{'eval_loss': 0.7653213143348694, 'eval_corr': 0.5756049045201133, 'eval_runtime': 0.9812, 'eval_samples_per_second': 98.86, 'eval_steps_per_second': 13.249, 'epoch': 9.0}
{'eval_loss': 0.8040589690208435, 'eval_corr': 0.5674605101439573, 'eval_runtime': 0.9782, 'eval_samples_per_second': 99.163, 'eval_steps_per_second': 13.29, 'epoch': 10.0}
{'eval_loss': 0.7191566228866577, 'eval_corr': 0.5939990546286915, 'eval_runtime': 0.979, 'eval_samples_per_second': 99.082, 'eval_steps_per_second': 13.279, 'epoch': 11.0}
{'eval_loss': 0.7288004159927368, 'eval_corr': 0.6051976472625327, 'eval_runtime': 0.9792, 'eval_samples_per_second': 99.057, 'eval_steps_per_second': 13.276, 'epoch': 12.0}
{'eval_loss': 0.7123575806617737, 'eval_corr': 0.59972310624502, 'eval_runtime': 0.9859, 'eval_samples_per_second': 98.385, 'eval_steps_per_second': 13.186, 'epoch': 13.0}
{'eval_loss': 0.7114874720573425, 'eval_corr': 0.6037608790429076, 'eval_runtime': 0.9791, 'eval_samples_per_second': 99.066, 'eval_steps_per_second': 13.277, 'epoch': 14.0}
{'eval_loss': 0.751545250415802, 'eval_corr': 0.5926516654023529, 'eval_runtime': 0.9889, 'eval_samples_per_second': 98.086, 'eval_steps_per_second': 13.145, 'epoch': 15.0}
{'eval_loss': 0.6738653182983398, 'eval_corr': 0.6229803814128074, 'eval_runtime': 0.9753, 'eval_samples_per_second': 99.452, 'eval_steps_per_second': 13.329, 'epoch': 16.0}
{'eval_loss': 0.6817343235015869, 'eval_corr': 0.61704163366791, 'eval_runtime': 0.9831, 'eval_samples_per_second': 98.667, 'eval_steps_per_second': 13.223, 'epoch': 17.0}
{'eval_loss': 0.7037461400032043, 'eval_corr': 0.6092487558511793, 'eval_runtime': 0.9793, 'eval_samples_per_second': 99.049, 'eval_steps_per_second': 13.275, 'epoch': 18.0}
{'eval_loss': 0.7015594244003296, 'eval_corr': 0.6089663337803164, 'eval_runtime': 0.9852, 'eval_samples_per_second': 98.456, 'eval_steps_per_second': 13.195, 'epoch': 19.0}
{'eval_loss': 0.6958994269371033, 'eval_corr': 0.6104160985318141, 'eval_runtime': 0.985, 'eval_samples_per_second': 98.472, 'eval_steps_per_second': 13.197, 'epoch': 20.0}
```

MobileBERT results

```
{'eval_loss': 1.0197206735610962, 'eval_corr': 0.32913760439327233, 'eval_runtime': 0.6165, 'eval_samples_per_second': 157.335, 'eval_steps_per_second': 21.086, 'epoch': 1.0}
{'eval_loss': 0.834292471408844, 'eval_corr': 0.4926006530191603, 'eval_runtime': 0.6178, 'eval_samples_per_second': 157.007, 'eval_steps_per_second': 21.042, 'epoch': 2.0}
{'eval_loss': 0.841667115688324, 'eval_corr': 0.48562966284496684, 'eval_runtime': 0.623, 'eval_samples_per_second': 155.707, 'eval_steps_per_second': 20.868, 'epoch': 3.0}
{'eval_loss': 0.8046469688415527, 'eval_corr': 0.5316866216189041, 'eval_runtime': 0.617, 'eval_samples_per_second': 157.208, 'eval_steps_per_second': 21.069, 'epoch': 4.0}
{'eval_loss': 0.7792660593986511, 'eval_corr': 0.539372381148289, 'eval_runtime': 0.6163, 'eval_samples_per_second': 157.379, 'eval_steps_per_second': 21.092, 'epoch': 5.0}
{'eval_loss': 0.816072940826416, 'eval_corr': 0.521739222023989, 'eval_runtime': 0.6456, 'eval_samples_per_second': 150.258, 'eval_steps_per_second': 20.138, 'epoch': 6.0}
{'eval_loss': 0.820040762424469, 'eval_corr': 0.518526031399988, 'eval_runtime': 0.6222, 'eval_samples_per_second': 155.9, 'eval_steps_per_second': 20.894, 'epoch': 7.0}
{'eval_loss': 0.8060498833656311, 'eval_corr': 0.5256717078147494, 'eval_runtime': 0.6206, 'eval_samples_per_second': 156.292, 'eval_steps_per_second': 20.946, 'epoch': 8.0}
{'eval_loss': 0.8206806182861328, 'eval_corr': 0.520293664139267, 'eval_runtime': 0.6175, 'eval_samples_per_second': 157.08, 'eval_steps_per_second': 21.052, 'epoch': 9.0}
{'eval_loss': 0.8192519545555115, 'eval_corr': 0.5357709361612227, 'eval_runtime': 0.6153, 'eval_samples_per_second': 157.636, 'eval_steps_per_second': 21.126, 'epoch': 10.0}
{'eval_loss': 0.8302949070930481, 'eval_corr': 0.5024485634087245, 'eval_runtime': 0.6175, 'eval_samples_per_second': 157.089, 'eval_steps_per_second': 21.053, 'epoch': 11.0}
{'eval_loss': 0.846757709980011, 'eval_corr': 0.5156568099133542, 'eval_runtime': 0.6173, 'eval_samples_per_second': 157.142, 'eval_steps_per_second': 21.06, 'epoch': 12.0}
{'eval_loss': 0.7940130233764648, 'eval_corr': 0.5416293539097846, 'eval_runtime': 0.6203, 'eval_samples_per_second': 156.378, 'eval_steps_per_second': 20.958, 'epoch': 13.0}
{'eval_loss': 0.8173080086708069, 'eval_corr': 0.5157108818748154, 'eval_runtime': 0.6166, 'eval_samples_per_second': 157.314, 'eval_steps_per_second': 21.083, 'epoch': 14.0}
{'eval_loss': 0.7959800958633423, 'eval_corr': 0.5345031465525737, 'eval_runtime': 0.6213, 'eval_samples_per_second': 156.131, 'eval_steps_per_second': 20.925, 'epoch': 15.0}
{'eval_loss': 0.7986265420913696, 'eval_corr': 0.5298356069843573, 'eval_runtime': 0.6143, 'eval_samples_per_second': 157.912, 'eval_steps_per_second': 21.163, 'epoch': 16.0}
{'eval_loss': 0.8147493600845337, 'eval_corr': 0.516939618033049, 'eval_runtime': 0.6119, 'eval_samples_per_second': 158.519, 'eval_steps_per_second': 21.245, 'epoch': 17.0}
{'eval_loss': 0.8108468651771545, 'eval_corr': 0.5198748437385999, 'eval_runtime': 0.6151, 'eval_samples_per_second': 157.709, 'eval_steps_per_second': 21.136, 'epoch': 18.0}
{'eval_loss': 0.811181366443634, 'eval_corr': 0.5201105451211807, 'eval_runtime': 0.615, 'eval_samples_per_second': 157.722, 'eval_steps_per_second': 21.138, 'epoch': 19.0}
{'eval_loss': 0.8112084865570068, 'eval_corr': 0.5206190728457705, 'eval_runtime': 0.6173, 'eval_samples_per_second': 157.134, 'eval_steps_per_second': 21.059, 'epoch': 20.0}
```