LoginSignup
2

ALBERTを用いたテキスト分類

Posted at

ALBERTとは

近年のNLPにおいて、もっとも重要なモデルであるBERTを軽量化したモデルです。
名前の由来は(A Lite BERT)の略です。

詳細に関してはこの記事がわかりやすいのでおすすめです。

利用するデータセット

今回は、Hugging Faceが提供する「ag_news」データセットを利用します。
このデータセットは「ニューストピック」と「ジャンル」のみのシンプルな構成となっております。

ジャンルは「海外ニュース」「スポーツ」「ビジネス」「科学技術」の4つのジャンルからなります。

また、データは「訓練データ」と「テスト」データに分かれており、訓練データが全120000件、テストデータが7600件となっております。

今回作成するテキスト分類モデルの概要

今回は、ALBERTの事前学習済みモデルである「albert-base-v1」を利用し、
ファインチューニングを行っていきます。

最終層の出力に全結合層を組み合わせる形をとります。

全体の概要図を以下に記載します。

AG_NEWS.png

また、学習フェーズの処理の大まか流れは以下です。

1.入力テキストに対して、前処理を行います。
2.トークナイザーを用いてトークン化
3.モデルに入力(ミニバッチ学習)
4.誤差計算(Softmax + 交差クロスエントロピー)
5.最適化処理(Adam)

Pytorchで実装 & 学習

1. データセット処理

まずは、データセットを読み込み、独自のデータセットクラスで扱えるようにします。

python上から元のデータセットを取得できるようにします。
事前にdatasetsモジュールをpipでインストールしておき、以下のコードで元データセットの取得ができます。

ag_news_classificatoin_train.py
from datasets import load_dataset
dataset = load_dataset("ag_news")

pytorchのDatasetクラスを拡張し、AG Newsのデータセット用のクラスを作成します。
ALBERTへの入力は入力文をトークン化したものとなります。
その為、データセットの文をトークンに変換する処理をここで行います。

トークン化は以下の部分となります。
encoded = tokenizer(clean_text, padding="max_length", max_length=512, truncation=True)

また、ALBERT(BERT)への入力の最大値は「512」トークンです。
実行環境の制約などにより、メモリを確保できない場合などは、文を中略などして、簡略化することもありますが、今回はこのままいきます。

トークナイズしたデータをpytorchのTensor型にしておきます。

ag_news_classificatoin_train.py
class AgNewsDataSet(Dataset):
    __data = []
    __is_test = False

    def __init__(self, data_dict, is_test=False):
        self.__data = data_dict

    def __getitem__(self, idx):
        text = self.__data[idx]["text"]
        label = self.__data[idx]["label"]
        clean_text = text_cleaning(text)
        encoded = tokenizer(clean_text, padding="max_length", max_length=512, truncation=True)

        input_ids = torch.tensor(encoded["input_ids"], dtype=torch.int32)
        attention_mask = torch.tensor(encoded["attention_mask"], dtype=torch.int32)
        token_type_ids = torch.tensor(encoded["token_type_ids"], dtype=torch.int32)

        label_data = None
        if not self.__is_test:
            label_data = torch.tensor(label, dtype=torch.float32)
        return input_ids, attention_mask, token_type_ids, label_data

    def __len__(self):
        return len(self.__data)

AG News用のデータセットクラスを作成できたので、
訓練データを更に、「実際に訓練に利用するデータ」と「検証用のデータ」に分割します。

今回は、適当に12万件のデータセットを1万件は検証用データとしました。

ag_news_classificatoin_train.py

TOTAL_TRAIN_DATA_SIZE = 120000 # 学習データ全件数
SPLIT_SIZE = 110000 # 学習データの内、学習に利用するデータの数(残りは検証用に)

train_data = BaseDataSet(dataset["train"], False)
train_data, valid_data = torch.utils.data.random_split(
    train_data,
    [SPLIT_SIZE, TOTAL_TRAIN_DATA_SIZE - SPLIT_SIZE]
)

以上で、データセット関連は完了です。

2. モデル作成

モデルを作成します。

モデルは、albertの最終層の先頭(CLSトークン)を全結合層の入力として分類モデルとします。
よって、最終層の先頭の次元768を全結合層の入力とし、出力層4のモデルとなります。

CLSトークンは特殊なトークンで、文全体の分散表現として扱うことが利用することができます。

また、後に出てきますが、pytorchのCrossEntropyLossはその中でsoftmaxもやってくれるので、モデルクラスの中には不要です。

ag_news_classificatoin_train.py
class AgNewsClassifyModel(nn.Module):

    def __init__(self):
        super(AgNewsClassifyModel, self).__init__()
        self.bert = AutoModel.from_pretrained(MODEL_NAME)  # MODEL_NAME = "albert-base-v1"
        self.linear = nn.Linear(768, 4)

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(
            input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )
        output = outputs[0]
        output = output[:, 0, :]  # 最終層の先頭だけを取得
        output = self.linear(output)
        return output

3. 学習部分実装

学習部分を実装していきます。

ag_news_classificatoin_train.py
def exec_train(dataloader, model, loss_fn, optimizer):
    train_loss = 0
    size = len(dataloader.dataset)
    for input_ids, attention_mask, token_type_ids, label_data in dataloader:
        mem_input_ids = input_ids.to(device)
        mem_attention_mask = attention_mask.to(device)
        mem_token_type_ids = token_type_ids.to(device)
        mem_label_data = label_data.to(device)

        preds = model(mem_input_ids, mem_attention_mask, mem_token_type_ids)  # 予測値計算
        loss = loss_fn(preds, mem_label_data.long())  # 誤差計算(この中でsoftmaxの処理も)
        optimizer.zero_grad()
        loss.backward()  # 誤差伝播
        optimizer.step()  # パラメータ更新
        train_loss += loss.item()
        # print(loss.item())

    print("loss avg: " + str(train_loss/size))
    return train_loss/size

# -------------------- 中略 --------------------------

model = AgNewsClassifyModel().to(device)
# 最適化 -> Adam
optimizer = torch.optim.Adam(params, lr=LEARNING_RATE)
# 損失関数 -> 交差エントロピー誤差
loss_fn = nn.CrossEntropyLoss()

loss_history = []
correct_rate_history = []
for t in range(EPOCH_NUM):
    print("------------------------------------------------")
    time_sta = time.time()
    print("epoch : " + str(t))
    exec_train(train_dataloader, model, loss_fn, optimizer)
    loss = exec_validation(valid_dataloader, model, loss_fn)
    loss_history.append(loss)
    time_end = time.time()
    # 経過時間(秒)
    tim = time_end - time_sta
    print("elapsed time : " + str(tim))

今回、ハイパーパラメータは以下のように設定しました。

項目
EPOCH数 5
ミニバッチサイズ 64
学習率 1e-5

いざ学習・・・

結果
------------------------------------------------
epoch : 0
loss avg: 0.004363666541129351
correct_count : 9204
validation loss avg : 0.003679560386762023
elapsed time : 1038.508870124817
------------------------------------------------
epoch : 1
loss avg: 0.002958414049006321
correct_count : 9283
validation loss avg : 0.003257966047525406
elapsed time : 1035.072437286377
------------------------------------------------
epoch : 2
loss avg: 0.0022726741611618887
correct_count : 9302
validation loss avg : 0.0031651282478123905
elapsed time : 1035.3001074790955
------------------------------------------------
epoch : 3
loss avg: 0.0016538722487251189
correct_count : 9295
validation loss avg : 0.0033557460997253657
elapsed time : 1035.2703726291656
------------------------------------------------
epoch : 4
loss avg: 0.0010704156845850361
correct_count : 9255
validation loss avg : 0.0037559795673936607
elapsed time : 1035.291580438614

完了したら、torch.save(model.state_dict(), "保存先パス")で保存します。
これで、学習は完了です。

性能評価

用意されているテストデータ全件に対して、上記で作成したモデルを用いて予測していきます。

output = model(mem_input_ids, mem_attention_mask, mem_token_type_ids) # 予測値計算
pred = soft_max_f(output)

この部分に関してですが、outputは作成したモデルを順伝播した値が出力され、それに対してSoftMax関数を適用することで、
[ラベル0の確率, ラベル1の確率, ラベル2の確率, ラベル3の確率](Tensor型)が出力されます。

この中で最も値の大きい(確率の高い)ラベルを、 max, argmax = torch.max(pred, dim=1)で取得してます。

ag_news_classificatoin_predict.py
test_data = dataset["test"]
total_data_num = len(test_data)
correct_count = 0

print("Total Data Num : " + str(total_data_num))

model = AgNewsClassifyModel().to(device)
model.load_state_dict(torch.load(model_path))

for data in test_data:
    correct_label = data["label"]
    clean_text = text_cleaning(data["text"])
    encoded = tokenizer(clean_text, padding="max_length", max_length=512, truncation=True)

    input_ids = torch.tensor([encoded["input_ids"]], dtype=torch.int32)
    attention_mask = torch.tensor([encoded["attention_mask"]], dtype=torch.int32)
    token_type_ids = torch.tensor([encoded["token_type_ids"]], dtype=torch.int32)

    model.eval()
    soft_max_f = nn.Softmax(dim=1)
    with torch.no_grad():
        mem_input_ids = input_ids.to(device)
        mem_attention_mask = attention_mask.to(device)
        mem_token_type_ids = token_type_ids.to(device)
        
        output = model(mem_input_ids, mem_attention_mask, mem_token_type_ids)  # 予測値計算
        pred = soft_max_f(output)
        max, argmax = torch.max(pred, dim=1)

        predict_score = pred[0][argmax[0]]

        if correct_label == argmax[0]:
            correct_count += 1

correct_rate = correct_count / total_data_num * 100
print("correct_rate : " + str(correct_rate) + "%")
結果
correct_rate : 93.07894736842105%

未知の7600件の記事データに対して、約93%は正しい、ジャンルへと分類ができました。

まだまだ、改善の余地はありますね。

以上です。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
2