以前、IBM論文の参考コードでTabBERTモデルの事前学習を行い、Fine-Tuningについては自作コードを実装しました。
自作コードで一応Fine-Tuningをできるようになったのですが、 F1スコアなどのメトリクスを計算するだけでも面倒さを感じていました。
事前学習のときと同様にTransformersのTrainerクラスを使えればメトリクスも簡単に出せるのに...といろいろ調べてみたところ、下流タスク用のヘッドをボディ(事前学習済モデル)に付加したものについても、普通にTrainerクラスが使えることがわかりました。
そんなわけで、今回は以前の自作コードをTrainerクラスを使ってリファクタリングしました。
また、今回のリファクタリングのついでに、WandBの導入やオーバーサンプリング処理の追加も行ったので、最後の方におまけで書いています。
実装
ポイントとなる部分だけ書きますので、コード全文が気になる方はGitHubをご覧ください。
モデルの作成
Trainerクラスでモデルを扱うためには以下がポイントとなります。
-
PreTrainedModel
を継承する - 出力を
loss
とlogits
のタプルで返す
公式Docsに記載されているように、TrainerクラスではPreTrainedModel
で動作するように最適化されるようです。
Trainer is optimized to work with the PreTrainedModel provided by the library. You can still use your own models defined as torch.nn.Module as long as they work the same way as the 🤗 Transformers models.
nn.Module
でも大丈夫とのことでしたが、実際にこちらを継承したらエラーがでました。
また、自作コードではforwardでlogits
(予測結果)のみを返すようにしていたのですが、loss
とlogits
のタプルで返すようにしました。
loss
を返すためにinit
で損失関数loss_fn
を指定し、推論時にもモデルを使用することを想定して、損失関数loss_fn
を指定しない場合はloss=None
を返すようにしました。
なお、下流タスク(分類)のヘッドとして、事前学習済モデルにLSTM層とLinear層を付加しています。
from transformers import BertConfig, BertModel, PreTrainedModel
import torch.nn as nn
import torch
class VisitorReactionModel(PreTrainedModel):
def __init__(self,
config='./output_pretraining/action_history/checkpoint-500/config.json',
num_categories=2,
loss_fn=None,
pretrained_model='./output_pretraining/action_history/checkpoint-500/pytorch_model.bin'):
super().__init__(config=config, num_categories=num_categories, loss_fn=loss_fn, pretrained_model=pretrained_model)
self.model = BertModel.from_pretrained(pretrained_model, config=config)
self.lstm = nn.LSTM(self.config.hidden_size, self.config.hidden_size, batch_first=True)
self.regressor = nn.Linear(self.config.hidden_size, num_categories)
self.loss_fn = loss_fn
def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
output_attentions=False,
output_hidden_states=False,
labels=None):
outputs = self.model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
out, _ = self.lstm(outputs[0], None)
sequence_output = out[:, -1, :]
logits = self.regressor(sequence_output)
loss=None
if labels is not None and self.loss_fn is not None:
loss = self.loss_fn(logits, torch.max(labels, 1)[1])
# ModelOutputだとlossがスカラーじゃないというエラーが出るためTupleで返す
return loss, logits
compute_metricsの作成
モデルをTrainerクラスに適用できるようになったのですが、これだけだとprecision, recall, f1といったメトリクスを導出することができません。
そんなときに使用するのがcompute_metrics
となります。
※Transformersの3系バージョンだと使用できなかったので、4.26.0へ事前にバージョンアップしています。
引数の型はEvalPrediction
、戻り値の型はOptional[Dict[str, float]]
となります。
def compute_metrics(res: EvalPrediction):
logits = res.predictions.argmax(axis=1)
labels = res.label_ids.argmax(axis=1)
precision = precision_score(labels, logits, average='macro')
recall = recall_score(labels, logits, average='macro')
f1 = f1_score(labels, logits, average='macro')
return {
'precision': precision,
'recall': recall,
'f1': f1
}
あとは、Trainerの引数にmodel
とcompute_metrics
を与えれば、メトリクスが計算されます。
分類タスクでFine-Tuningを行うため、損失関数にはCrossEntropy
を用いています。
loss_fn = CrossEntropyLoss()
model = VisitorReactionModel(config=config, pretrained_model=pretrained_model, loss_fn=loss_fn)
training_args = TrainingArguments(
output_dir=args.output_dir, # output directory
num_train_epochs=args.num_train_epochs, # total number of training epochs
per_device_train_batch_size=args.num_train_batch_size,
per_device_eval_batch_size=args.num_eval_batch_size,
save_steps=args.save_steps,
do_train=True,
do_eval=True,
evaluation_strategy="epoch", # epochかsteps(デフォルト500)ごとに評価
overwrite_output_dir=True,
save_total_limit=1,
report_to="wandb"
)
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics
)
その他
今回のリファクタリングでいくつか改善も行ったので、おまけで記載します。
WandBの導入
Trainerの利用でWandBの導入が楽になったので、リファクタリングにあわせて実装しました。
Fine-Tuningのコード内に、WandBのログインと初期化のコードを追加します。
このとき、dotenv
を利用して、ローカル学習の際は.env
から、SageMaker Training Jobsの際はEstimators
のhyperparameters
からAPIキーを読み込むようにします。
load_dotenv()
WANDB_API_KEY = os.getenv('SM_HP_WANDB_API_KEY')
wandb.login(key=WANDB_API_KEY) # Pass your W&B API key here
wandb.init(project="tabformer-opt") # Add your W&B project name
estimator = Estimator(
image_uri="",
role=role,
instance_type="ml.g4dn.2xlarge",
instance_count=1,
base_job_name="tabformer-opt-fine-tuning",
output_path="",
code_location="",
sagemaker_session=session,
entry_point="fine-tuning.sh",
dependencies=["tabformer-opt"],
hyperparameters={
"data_root": "/opt/ml/input/data/input_data/",
"data_fname": "",
"output_dir": "/opt/ml/model/",
"model_path": "/opt/ml/input/data/input_model/",
"wandb_api_key": <APIキー>
}
)
あとは、TrainingArguments
にreport_to="wandb"
を追加するだけで、学習結果が記録されるようになります。
オーバーサンプリング
今回、ポジティブラベルが10 %以下の不均衡データを使用しており、そのままモデルで学習を行ってもprecisionやrecallが低い結果となってしまいます。
そのため、少数派のポジティブラベルデータをオーバーサンプリングで増やすようにしました。
この際、単純なデータ複製で過学習を起こさないために、少数派のデータからランダムでデータを選択し、そのデータからランダムで選択された近傍点を用いて、両者の合成データを作成する、SMOTEという手法を用いました。
SMOTE処理を以下の関数にまとめ、 データ前処理のコードに加えました。
def overSampling(data):
sm = SMOTE(random_state=42)
X = data.drop(columns='reaction', axis=1)
y = data['reaction']
X_sample, Y_sample = sm.fit_resample(X, y)
over_sampling = pd.DataFrame()
over_sampling = X_sample
over_sampling['reaction'] = Y_sample
return over_sampling
参考資料