ALBERTとは
近年のNLPにおいて、もっとも重要なモデルであるBERTを軽量化したモデルです。
名前の由来は(A Lite BERT)の略です。
詳細に関してはこの記事がわかりやすいのでおすすめです。
利用するデータセット
今回は、Hugging Faceが提供する「ag_news」データセットを利用します。
このデータセットは「ニューストピック」と「ジャンル」のみのシンプルな構成となっております。
ジャンルは「海外ニュース」「スポーツ」「ビジネス」「科学技術」の4つのジャンルからなります。
また、データは「訓練データ」と「テスト」データに分かれており、訓練データが全120000件、テストデータが7600件となっております。
今回作成するテキスト分類モデルの概要
今回は、ALBERTの事前学習済みモデルである「albert-base-v1」を利用し、
ファインチューニングを行っていきます。
最終層の出力に全結合層を組み合わせる形をとります。
全体の概要図を以下に記載します。
また、学習フェーズの処理の大まか流れは以下です。
1.入力テキストに対して、前処理を行います。
2.トークナイザーを用いてトークン化
3.モデルに入力(ミニバッチ学習)
4.誤差計算(Softmax + クロスエントロピー)
5.最適化処理(Adam)
Pytorchで実装 & 学習
1. データセット処理
まずは、データセットを読み込み、独自のデータセットクラスで扱えるようにします。
python上から元のデータセットを取得できるようにします。
事前にdatasetsモジュールをpipでインストールしておき、以下のコードで元データセットの取得ができます。
from datasets import load_dataset
dataset = load_dataset("ag_news")
pytorchのDatasetクラスを拡張し、AG Newsのデータセット用のクラスを作成します。
ALBERTへの入力は入力文をトークン化したものとなります。
その為、データセットの文をトークンに変換する処理をここで行います。
トークン化は以下の部分となります。
encoded = tokenizer(clean_text, padding="max_length", max_length=512, truncation=True)
また、ALBERT(BERT)への入力の最大値は「512」トークンです。
実行環境の制約などにより、メモリを確保できない場合などは、文を中略などして、簡略化することもありますが、今回はこのままいきます。
トークナイズしたデータをpytorchのTensor型にしておきます。
class AgNewsDataSet(Dataset):
__data = []
__is_test = False
def __init__(self, data_dict, is_test=False):
self.__data = data_dict
def __getitem__(self, idx):
text = self.__data[idx]["text"]
label = self.__data[idx]["label"]
clean_text = text_cleaning(text)
encoded = tokenizer(clean_text, padding="max_length", max_length=512, truncation=True)
input_ids = torch.tensor(encoded["input_ids"], dtype=torch.int32)
attention_mask = torch.tensor(encoded["attention_mask"], dtype=torch.int32)
token_type_ids = torch.tensor(encoded["token_type_ids"], dtype=torch.int32)
label_data = None
if not self.__is_test:
label_data = torch.tensor(label, dtype=torch.float32)
return input_ids, attention_mask, token_type_ids, label_data
def __len__(self):
return len(self.__data)
AG News用のデータセットクラスを作成できたので、
訓練データを更に、「実際に訓練に利用するデータ」と「検証用のデータ」に分割します。
今回は、適当に12万件のデータセットを1万件は検証用データとしました。
TOTAL_TRAIN_DATA_SIZE = 120000 # 学習データ全件数
SPLIT_SIZE = 110000 # 学習データの内、学習に利用するデータの数(残りは検証用に)
train_data = BaseDataSet(dataset["train"], False)
train_data, valid_data = torch.utils.data.random_split(
train_data,
[SPLIT_SIZE, TOTAL_TRAIN_DATA_SIZE - SPLIT_SIZE]
)
以上で、データセット関連は完了です。
2. モデル作成
モデルを作成します。
モデルは、albertの最終層の先頭(CLSトークン)を全結合層の入力として分類モデルとします。
よって、最終層の先頭の次元768を全結合層の入力とし、出力層4のモデルとなります。
CLSトークンは特殊なトークンで、文全体の分散表現として扱うことが利用することができます。
また、後に出てきますが、pytorchのCrossEntropyLossはその中でsoftmaxもやってくれるので、モデルクラスの中には不要です。
class AgNewsClassifyModel(nn.Module):
def __init__(self):
super(AgNewsClassifyModel, self).__init__()
self.bert = AutoModel.from_pretrained(MODEL_NAME) # MODEL_NAME = "albert-base-v1"
self.linear = nn.Linear(768, 4)
def forward(self, input_ids, attention_mask, token_type_ids):
outputs = self.bert(
input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
)
output = outputs[0]
output = output[:, 0, :] # 最終層の先頭だけを取得
output = self.linear(output)
return output
3. 学習部分実装
学習部分を実装していきます。
def exec_train(dataloader, model, loss_fn, optimizer):
train_loss = 0
size = len(dataloader.dataset)
for input_ids, attention_mask, token_type_ids, label_data in dataloader:
mem_input_ids = input_ids.to(device)
mem_attention_mask = attention_mask.to(device)
mem_token_type_ids = token_type_ids.to(device)
mem_label_data = label_data.to(device)
preds = model(mem_input_ids, mem_attention_mask, mem_token_type_ids) # 予測値計算
loss = loss_fn(preds, mem_label_data.long()) # 誤差計算(この中でsoftmaxの処理も)
optimizer.zero_grad()
loss.backward() # 誤差伝播
optimizer.step() # パラメータ更新
train_loss += loss.item()
# print(loss.item())
print("loss avg: " + str(train_loss/size))
return train_loss/size
# -------------------- 中略 --------------------------
model = AgNewsClassifyModel().to(device)
# 最適化 -> Adam
optimizer = torch.optim.Adam(params, lr=LEARNING_RATE)
# 損失関数 -> 交差エントロピー誤差
loss_fn = nn.CrossEntropyLoss()
loss_history = []
correct_rate_history = []
for t in range(EPOCH_NUM):
print("------------------------------------------------")
time_sta = time.time()
print("epoch : " + str(t))
exec_train(train_dataloader, model, loss_fn, optimizer)
loss = exec_validation(valid_dataloader, model, loss_fn)
loss_history.append(loss)
time_end = time.time()
# 経過時間(秒)
tim = time_end - time_sta
print("elapsed time : " + str(tim))
今回、ハイパーパラメータは以下のように設定しました。
項目 | 値 |
---|---|
EPOCH数 | 5 |
ミニバッチサイズ | 64 |
学習率 | 1e-5 |
いざ学習・・・
------------------------------------------------
epoch : 0
loss avg: 0.004363666541129351
correct_count : 9204
validation loss avg : 0.003679560386762023
elapsed time : 1038.508870124817
------------------------------------------------
epoch : 1
loss avg: 0.002958414049006321
correct_count : 9283
validation loss avg : 0.003257966047525406
elapsed time : 1035.072437286377
------------------------------------------------
epoch : 2
loss avg: 0.0022726741611618887
correct_count : 9302
validation loss avg : 0.0031651282478123905
elapsed time : 1035.3001074790955
------------------------------------------------
epoch : 3
loss avg: 0.0016538722487251189
correct_count : 9295
validation loss avg : 0.0033557460997253657
elapsed time : 1035.2703726291656
------------------------------------------------
epoch : 4
loss avg: 0.0010704156845850361
correct_count : 9255
validation loss avg : 0.0037559795673936607
elapsed time : 1035.291580438614
完了したら、torch.save(model.state_dict(), "保存先パス")
で保存します。
これで、学習は完了です。
性能評価
用意されているテストデータ全件に対して、上記で作成したモデルを用いて予測していきます。
output = model(mem_input_ids, mem_attention_mask, mem_token_type_ids) # 予測値計算
pred = soft_max_f(output)
この部分に関してですが、outputは作成したモデルを順伝播した値が出力され、それに対してSoftMax関数を適用することで、
[ラベル0の確率, ラベル1の確率, ラベル2の確率, ラベル3の確率](Tensor型)
が出力されます。
この中で最も値の大きい(確率の高い)ラベルを、 max, argmax = torch.max(pred, dim=1)
で取得してます。
test_data = dataset["test"]
total_data_num = len(test_data)
correct_count = 0
print("Total Data Num : " + str(total_data_num))
model = AgNewsClassifyModel().to(device)
model.load_state_dict(torch.load(model_path))
for data in test_data:
correct_label = data["label"]
clean_text = text_cleaning(data["text"])
encoded = tokenizer(clean_text, padding="max_length", max_length=512, truncation=True)
input_ids = torch.tensor([encoded["input_ids"]], dtype=torch.int32)
attention_mask = torch.tensor([encoded["attention_mask"]], dtype=torch.int32)
token_type_ids = torch.tensor([encoded["token_type_ids"]], dtype=torch.int32)
model.eval()
soft_max_f = nn.Softmax(dim=1)
with torch.no_grad():
mem_input_ids = input_ids.to(device)
mem_attention_mask = attention_mask.to(device)
mem_token_type_ids = token_type_ids.to(device)
output = model(mem_input_ids, mem_attention_mask, mem_token_type_ids) # 予測値計算
pred = soft_max_f(output)
max, argmax = torch.max(pred, dim=1)
predict_score = pred[0][argmax[0]]
if correct_label == argmax[0]:
correct_count += 1
correct_rate = correct_count / total_data_num * 100
print("correct_rate : " + str(correct_rate) + "%")
correct_rate : 93.07894736842105%
未知の7600件の記事データに対して、約93%は正しい、ジャンルへと分類ができました。
まだまだ、改善の余地はありますね。
以上です。