はじめに
最近、多変量の時系列表データの学習に使用する、TabBERT(Hierarchical Tabular BERT)というBERTの応用モデルに関する論文を読み、付属コードで事前学習まで行いました。
ただ、付属コードだと事前学習までしか行えなかったため、さらなる理解のために、Fine-Tuningと分類タスクについては自分で実装してみることにしました。
前回までの記事については、以下をご覧ください。
コード解説
以前の記事から繰り返しとなりますが、事前学習では大量の教師なしデータから各レコードの特性を双方向学習することでトランザクション間の関係を把握しました。
Fine-Tuningでは、少量の教師ありデータから事前学習済モデルのパラメータを微調整することで、個々のタスクへの最適化を行います。
以下がFine-Tuningのコード全文となります。
Model
教師ありデータによる調整のため、BERT層に予測のための層(Prediction Layer)を追加したのが今回作成したCommonModel
です。
まず、事前学習で得られた学習済モデルpytorch_model.bin
とconfig.json
を読み込みます。
論文にならってBERT層の後にLSTM層を追加し、その先にPrediction LayerとしてLinear層を追加します。
今回は不正検知で2クラス(0, 1)の出力が必要となるので、Linear層も2クラスで出力するように設定します(不正じゃない:[1, 0]、不正:[0, 1])。
__init__
で、これらの層を定義を行い、forward
で層の連結および予測値logits
を出力します。
BERTのようなTransformerモデルでは、通常[CLS]トークンの最後の隠れ状態をLinear層に通して予測値を出力するのですが、今回のデータには[CLS]トークンを付与しておらず、代わりに付与した[SEP]トークンの隠れ状態をLinear層に通していますsequence_output = out[:, -1, :]
。
loss_fn
では、PytorchのMSELossクラスを利用してRMSEを損失関数とした損失計算を行います。
class CommonModel(nn.Module):
def __init__(self,
pretrained_config='./output_pretraining/action_history/checkpoint-500/config.json',
pretrained_model='./output_pretraining/action_history/checkpoint-500/pytorch_model.bin'):
super(CommonModel, self).__init__()
self.config = BertConfig.from_pretrained(pretrained_config)
self.model = BertModel.from_pretrained(pretrained_model, config=self.config)
self.lstm = nn.LSTM(self.config.hidden_size, self.config.hidden_size, batch_first=True)
self.regressor = nn.Linear(self.config.hidden_size, 2) # 2クラスで
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
outputs = self.model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
out, _ = self.lstm(outputs[0], None)
sequence_output = out[:, -1, :]
logits = self.regressor(sequence_output)
return logits
def loss_fn(self, logits, label):
loss = torch.sqrt(nn.MSELoss(reduction='mean')(logits, label))
return loss
Dataset
TabBERTでは専用のTokenizerがないため、事前学習ではTransactionDatasetで辞書データの作成とtoken→id変換を行いました。
Fine-Tuningでもこれらを再利用したいので、TransactionDatasetを継承したDataset(FineTuningDataset)を新しく作成します。
Datasetは元々のデータをすべて持ちますが、__getitem__
では、指定したindexの入力データ(平坦化で10連結)と連結データに対応する正解ラベル(Window label)をペアでTensorで返します。
CommonModel
で予測ラベルを2クラスで出力するため、こちらでも正解ラベルを2クラス(one-hot)で出力するようにします。
init_vocab
とformat_trans
では、事前学習のときに作成保存した辞書データを読み込んでtoken→idに変換する処理を行います。
class FineTuningDataset(TransactionDataset):
# 平坦化のためLabelもWindowごとにまとめる
def __getitem__(self, index):
one_hot_window_label = F.one_hot(torch.tensor(self.window_label[index]), num_classes=2)
return_data = (torch.tensor(self.data[index], dtype=torch.long), one_hot_window_label.tolist())
return return_data
# pre-trainingでset_idが完了している
def init_vocab(self):
column_names = list(self.trans_table.columns)
self.vocab.set_field_keys(column_names)
# pre-trainingで保存した辞書でtoken2idをおこなう
def format_trans(self, trans_lst, column_names):
with open('./output_pretraining/credit_card/vocab_token2id.bin', 'rb') as p:
vocab_dic = pickle.load(p)
trans_lst = list(divide_chunks(trans_lst, len(self.vocab.field_keys) - 2)) # 2 to ignore isFraud and SPECIAL
user_vocab_ids = []
sep_id = self.vocab.get_id(self.vocab.sep_token, special_token=True)
for trans in trans_lst:
vocab_ids = []
for jdx, field in enumerate(trans):
vocab_id, _ = vocab_dic[column_names[jdx]][field]
vocab_ids.append(vocab_id)
# TODO : need to handle ncols when sep is not added
if self.mlm: # and self.flatten: # only add [SEP] for BERT + flatten scenario
vocab_ids.append(sep_id)
user_vocab_ids.append(vocab_ids)
return user_vocab_ids
DataCollator
DataCollatorForLanguageModeling
を継承したFineTuningDataCollatorForLanguageModeling
を作成し、input_idsとlabels(正解ラベル)を辞書型で返します。
このようにすることでDataLoaderからもデータを辞書型で取り出すことができるようになります。
class FineTuningDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
def __call__(
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
input_ids = []
labels = []
for example in examples:
input_ids.append(example[0])
labels.append(example[1])
batch_input_ids = self._tensorize_batch(input_ids)
batch_labels = torch.tensor(labels)
return {"input_ids": batch_input_ids, "label": batch_labels}
DataLoader
DataLoaderは、Datasetのインスタンスを渡すことで、ミニバッチ化した後のデータを返します。
Pytorchで用意されているクラスを使えばよく、実装する必要はありません。
train_loader = DataLoader(
train_dataset,
collate_fn=data_collator,
batch_size=BS,
pin_memory=True,
shuffle=True,
drop_last=True,
num_workers=0)
optimizer
論文のようにBERT層のパラメータのみをフリーズするため、まずはモデルのすべてのパラメータをフリーズした後、LSTM層とLinear層のみフリーズから解放します。
最適化アルゴリズムとしてはAdamWを選択しています。
また、エポックごとに学習率を調整するためのschedulerも用意します。
# set models
model = CommonModel()
model.to(device)
# freeze parameters in all network
for name, param in model.named_parameters():
param.requires_grad = False
# activate parameters in only lstm network
for name, param in model.lstm.named_parameters():
param.requires_grad = True
# activate parameters in only linear network
for name, param in model.regressor.named_parameters():
param.requires_grad = True
# set optimizer
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
max_train_steps = N_EPOCHS * len(train_loader)
warmup_steps = int(max_train_steps * WARM_UP_RATIO)
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=max_train_steps
)
勾配計算とパラメータ最適化
1エポック内でミニバッチのサイズ分(32)のループをまわします。
ループ内では、まずmodel.train()
でネットワークを学習モードにします。
計算された損失loss
に対してloss.backward()
すると、requires_grad = True
となっているtorch.Tensorについて誤差逆伝播で勾配計算が行われます。
この状態でoptimizer.step()
を実行すると、学習率に応じてパラメータの重みの更新が行われます。
最後に、計算された勾配結果をoptimizer.zero_grad()
で0にリセットします。
for epoch in range(N_EPOCHS):
for d in train_loader:
all_step += 1
model.train()
logits = model(
d["input_ids"].to(device),
attention_mask=None,
token_type_ids=None
)
loss = model.loss_fn(logits, d["label"].float().to(device))
loss = loss / ACCUMULATE
train_iter_loss += loss.item()
loss.backward()
if all_step % ACCUMULATE == 0:
optimizer.step()
optimizer.zero_grad()
scheduler.step()
valid_loss = validation_loop(valid_loader, model)
if valid_best_loss > valid_loss:
valid_best_loss = valid_loss
train_iter_loss = 0
bar.update(1)
検証データについても同様に損失を計算しています。
こちらではmodel.eval()
でネットワークを推論モードに切り替えています。
def validation_loop(valid_loader, model):
model.eval()
preds = []
true = []
for d in valid_loader:
with torch.no_grad():
logits = model(
d["input_ids"].to(device),
attention_mask=None,
token_type_ids=None
)
preds.append(logits)
true.append(d["label"].float().to(device))
y_pred = torch.hstack(preds).cpu().numpy() # tensor連結してndarrayに変換
y_true = torch.hstack(true).cpu().numpy()
return mean_squared_error(y_true, y_pred, squared=False)
パラメータ保存
パラメータ最適化のループが終わったら、最後にパラメータを保存します。
このとき、state_dict()
を使うことでネットワーク構造や各レイヤの引数といったムダな情報を取り除き、必要な情報のみを保存することができます。
torch.save(model.state_dict(), args.output_model_dir)
参考資料