2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ALBERTを用いたテキスト分類

Last updated at Posted at 2023-01-12

ALBERTとは

近年のNLPにおいて、もっとも重要なモデルであるBERTを軽量化したモデルです。
名前の由来は(A Lite BERT)の略です。

詳細に関してはこの記事がわかりやすいのでおすすめです。

利用するデータセット

今回は、Hugging Faceが提供する「ag_news」データセットを利用します。
このデータセットは「ニューストピック」と「ジャンル」のみのシンプルな構成となっております。

ジャンルは「海外ニュース」「スポーツ」「ビジネス」「科学技術」の4つのジャンルからなります。

また、データは「訓練データ」と「テスト」データに分かれており、訓練データが全120000件、テストデータが7600件となっております。

今回作成するテキスト分類モデルの概要

今回は、ALBERTの事前学習済みモデルである「albert-base-v1」を利用し、
ファインチューニングを行っていきます。

最終層の出力に全結合層を組み合わせる形をとります。

全体の概要図を以下に記載します。

AG_NEWS.png

また、学習フェーズの処理の大まか流れは以下です。

1.入力テキストに対して、前処理を行います。
2.トークナイザーを用いてトークン化
3.モデルに入力(ミニバッチ学習)
4.誤差計算(Softmax + クロスエントロピー)
5.最適化処理(Adam)

Pytorchで実装 & 学習

1. データセット処理

まずは、データセットを読み込み、独自のデータセットクラスで扱えるようにします。

python上から元のデータセットを取得できるようにします。
事前にdatasetsモジュールをpipでインストールしておき、以下のコードで元データセットの取得ができます。

ag_news_classificatoin_train.py
from datasets import load_dataset
dataset = load_dataset("ag_news")

pytorchのDatasetクラスを拡張し、AG Newsのデータセット用のクラスを作成します。
ALBERTへの入力は入力文をトークン化したものとなります。
その為、データセットの文をトークンに変換する処理をここで行います。

トークン化は以下の部分となります。
encoded = tokenizer(clean_text, padding="max_length", max_length=512, truncation=True)

また、ALBERT(BERT)への入力の最大値は「512」トークンです。
実行環境の制約などにより、メモリを確保できない場合などは、文を中略などして、簡略化することもありますが、今回はこのままいきます。

トークナイズしたデータをpytorchのTensor型にしておきます。

ag_news_classificatoin_train.py
class AgNewsDataSet(Dataset):
    __data = []
    __is_test = False

    def __init__(self, data_dict, is_test=False):
        self.__data = data_dict

    def __getitem__(self, idx):
        text = self.__data[idx]["text"]
        label = self.__data[idx]["label"]
        clean_text = text_cleaning(text)
        encoded = tokenizer(clean_text, padding="max_length", max_length=512, truncation=True)

        input_ids = torch.tensor(encoded["input_ids"], dtype=torch.int32)
        attention_mask = torch.tensor(encoded["attention_mask"], dtype=torch.int32)
        token_type_ids = torch.tensor(encoded["token_type_ids"], dtype=torch.int32)

        label_data = None
        if not self.__is_test:
            label_data = torch.tensor(label, dtype=torch.float32)
        return input_ids, attention_mask, token_type_ids, label_data

    def __len__(self):
        return len(self.__data)

AG News用のデータセットクラスを作成できたので、
訓練データを更に、「実際に訓練に利用するデータ」と「検証用のデータ」に分割します。

今回は、適当に12万件のデータセットを1万件は検証用データとしました。

ag_news_classificatoin_train.py

TOTAL_TRAIN_DATA_SIZE = 120000 # 学習データ全件数
SPLIT_SIZE = 110000 # 学習データの内、学習に利用するデータの数(残りは検証用に)

train_data = BaseDataSet(dataset["train"], False)
train_data, valid_data = torch.utils.data.random_split(
    train_data,
    [SPLIT_SIZE, TOTAL_TRAIN_DATA_SIZE - SPLIT_SIZE]
)

以上で、データセット関連は完了です。

2. モデル作成

モデルを作成します。

モデルは、albertの最終層の先頭(CLSトークン)を全結合層の入力として分類モデルとします。
よって、最終層の先頭の次元768を全結合層の入力とし、出力層4のモデルとなります。

CLSトークンは特殊なトークンで、文全体の分散表現として扱うことが利用することができます。

また、後に出てきますが、pytorchのCrossEntropyLossはその中でsoftmaxもやってくれるので、モデルクラスの中には不要です。

ag_news_classificatoin_train.py
class AgNewsClassifyModel(nn.Module):

    def __init__(self):
        super(AgNewsClassifyModel, self).__init__()
        self.bert = AutoModel.from_pretrained(MODEL_NAME)  # MODEL_NAME = "albert-base-v1"
        self.linear = nn.Linear(768, 4)

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(
            input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )
        output = outputs[0]
        output = output[:, 0, :]  # 最終層の先頭だけを取得
        output = self.linear(output)
        return output

3. 学習部分実装

学習部分を実装していきます。

ag_news_classificatoin_train.py
def exec_train(dataloader, model, loss_fn, optimizer):
    train_loss = 0
    size = len(dataloader.dataset)
    for input_ids, attention_mask, token_type_ids, label_data in dataloader:
        mem_input_ids = input_ids.to(device)
        mem_attention_mask = attention_mask.to(device)
        mem_token_type_ids = token_type_ids.to(device)
        mem_label_data = label_data.to(device)

        preds = model(mem_input_ids, mem_attention_mask, mem_token_type_ids)  # 予測値計算
        loss = loss_fn(preds, mem_label_data.long())  # 誤差計算(この中でsoftmaxの処理も)
        optimizer.zero_grad()
        loss.backward()  # 誤差伝播
        optimizer.step()  # パラメータ更新
        train_loss += loss.item()
        # print(loss.item())

    print("loss avg: " + str(train_loss/size))
    return train_loss/size

# -------------------- 中略 --------------------------

model = AgNewsClassifyModel().to(device)
# 最適化 -> Adam
optimizer = torch.optim.Adam(params, lr=LEARNING_RATE)
# 損失関数 -> 交差エントロピー誤差
loss_fn = nn.CrossEntropyLoss()

loss_history = []
correct_rate_history = []
for t in range(EPOCH_NUM):
    print("------------------------------------------------")
    time_sta = time.time()
    print("epoch : " + str(t))
    exec_train(train_dataloader, model, loss_fn, optimizer)
    loss = exec_validation(valid_dataloader, model, loss_fn)
    loss_history.append(loss)
    time_end = time.time()
    # 経過時間(秒)
    tim = time_end - time_sta
    print("elapsed time : " + str(tim))

今回、ハイパーパラメータは以下のように設定しました。

項目
EPOCH数 5
ミニバッチサイズ 64
学習率 1e-5

いざ学習・・・

結果
------------------------------------------------
epoch : 0
loss avg: 0.004363666541129351
correct_count : 9204
validation loss avg : 0.003679560386762023
elapsed time : 1038.508870124817
------------------------------------------------
epoch : 1
loss avg: 0.002958414049006321
correct_count : 9283
validation loss avg : 0.003257966047525406
elapsed time : 1035.072437286377
------------------------------------------------
epoch : 2
loss avg: 0.0022726741611618887
correct_count : 9302
validation loss avg : 0.0031651282478123905
elapsed time : 1035.3001074790955
------------------------------------------------
epoch : 3
loss avg: 0.0016538722487251189
correct_count : 9295
validation loss avg : 0.0033557460997253657
elapsed time : 1035.2703726291656
------------------------------------------------
epoch : 4
loss avg: 0.0010704156845850361
correct_count : 9255
validation loss avg : 0.0037559795673936607
elapsed time : 1035.291580438614

完了したら、torch.save(model.state_dict(), "保存先パス")で保存します。
これで、学習は完了です。

性能評価

用意されているテストデータ全件に対して、上記で作成したモデルを用いて予測していきます。

output = model(mem_input_ids, mem_attention_mask, mem_token_type_ids) # 予測値計算
pred = soft_max_f(output)

この部分に関してですが、outputは作成したモデルを順伝播した値が出力され、それに対してSoftMax関数を適用することで、
[ラベル0の確率, ラベル1の確率, ラベル2の確率, ラベル3の確率](Tensor型)が出力されます。

この中で最も値の大きい(確率の高い)ラベルを、 max, argmax = torch.max(pred, dim=1)で取得してます。

ag_news_classificatoin_predict.py
test_data = dataset["test"]
total_data_num = len(test_data)
correct_count = 0

print("Total Data Num : " + str(total_data_num))

model = AgNewsClassifyModel().to(device)
model.load_state_dict(torch.load(model_path))

for data in test_data:
    correct_label = data["label"]
    clean_text = text_cleaning(data["text"])
    encoded = tokenizer(clean_text, padding="max_length", max_length=512, truncation=True)

    input_ids = torch.tensor([encoded["input_ids"]], dtype=torch.int32)
    attention_mask = torch.tensor([encoded["attention_mask"]], dtype=torch.int32)
    token_type_ids = torch.tensor([encoded["token_type_ids"]], dtype=torch.int32)

    model.eval()
    soft_max_f = nn.Softmax(dim=1)
    with torch.no_grad():
        mem_input_ids = input_ids.to(device)
        mem_attention_mask = attention_mask.to(device)
        mem_token_type_ids = token_type_ids.to(device)
        
        output = model(mem_input_ids, mem_attention_mask, mem_token_type_ids)  # 予測値計算
        pred = soft_max_f(output)
        max, argmax = torch.max(pred, dim=1)

        predict_score = pred[0][argmax[0]]

        if correct_label == argmax[0]:
            correct_count += 1

correct_rate = correct_count / total_data_num * 100
print("correct_rate : " + str(correct_rate) + "%")
結果
correct_rate : 93.07894736842105%

未知の7600件の記事データに対して、約93%は正しい、ジャンルへと分類ができました。

まだまだ、改善の余地はありますね。

以上です。

2
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?