7
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

BERTを超えた? XLNet を実際に使ってみた

Posted at

この記事について

これまで自然言語処理の様々なタスクにおいて、「BERT」が最も精度が高く有効だとされてきました。しかし、2019年6月にCarnegie Mellon大学とGoogle Brainの研究チームから発表された「XLNet」というモデルが「BERT」を超えたと囁かれています。
今回は実際にSMS Spam Collection Data Setを用いて「XLNet」を試していきたいと思います。:wink:

##BERTとは
現在、既に多くの「BERT」についての解説がされていますので、ここでは簡単に説明します。
単語のone-hotエンコーディングでは、ベクトル間の演算で意味のある結果を得られないという欠点がありました。それを克服したのがword2vecやGloVeといった単語分散表現手法です。これによりベクトル間の演算で意味のある結果を得られるようになりました。例えば、内積を求めることで単語同士の類似度を求めることができます。
しかし、この手法にも課題がありました。それは文脈を考慮できないということです。そのため、「python」がヘビなのかプログラミング言語なのかを見分けることができませんでした。:dizzy_face:
この課題に対し、文中の他の単語の情報も使って各単語の分散表現を生成することで解決を図ったのがBERTをはじめとするモデルです。
BERT以前にも文脈を考慮した分散表現を得られるモデルはGTPやELMo等ありましたが、BERTは双方向かつ深いネットワークによりこれらに比べ良い表現を獲得することに成功しました。
BERTで重要なのは以下の2つです。

  • 事前学習(マスクされた単語の予測、与えた2文が隣接文かの予測)
  • Fine-tuning(ラベル付きデータを使って解きたいタスクを学ばせる)

詳しい説明は下の記事がわかりやすいかと思います。
https://qiita.com/omiita/items/72998858efc19a368e50
ちなみに、BERTとは、Bidirectional Encoder Representations from Transformers の略で、「Transformerによる双方向のエンコード表現」と訳されます。

##XLNetとは
素晴らしい成果を残したBERTですが、その事前学習には2つの問題点があると指摘されています。

  • [mask]という特殊な文字を使用しているので、それらが出現しないfine-tuning時に事前学習時と異なるバイアスを生むこと。
  • 与えられた文章中の[MASK]トークンの箇所は同時に予測されるため、それらの間にある依存関係を学習することはできない。

XLNetは、双方向の情報を同時に扱えるように「単語の予測順序を入れ替える」手法をとることでBERTの良いところを引き継ぎ、予測対象の単語同士の依存関係を学習できる自己回帰言語モデルとしてこれらの問題を解決したモデルとなっています。
詳しい説明はこちらを参考にされるのが良いかと思います。
(論文)https://arxiv.org/abs/1906.08237
https://ai-scholar.tech/articles/treatise/xlnet-ai-228
https://hackmd.io/@V1ia6ZG2RQ-KMNmPo74r5Q/rJxCgPwpr#XLNet
https://mc.ai/bertを超えたxlnetの紹介/

##XLNet 実装

使用データセットと環境

今回、「SMS Spam Collection Dataset」を使用します。
https://www.kaggle.com/uciml/sms-spam-collection-dataset
また、GPUを使用したいため、Google Colaboratoryを使用します。
Colaboratoryからノートブックを開き「ランタイム」→「ランタイムのタイプの変更」からハードウェア アクセラレータをGPUとしてください。

コード

準備

!pip install transformers

必要なライブラリーのインポート

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split


from transformers import XLNetModel, XLNetTokenizer, XLNetForSequenceClassification
from transformers import AdamW
from tqdm import tqdm, trange
import pandas as pd
import io
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
% matplotlib inline

GPUの識別、指定。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
torch.cuda.get_device_name(0)
実行結果
'Tesla T4'

データの読み込みと加工

# CSV読み込み
df_all = pd.read_csv('spam.csv', encoding='latin-1')
# 未使用列を削除
df_all.drop(['Unnamed: 2', 'Unnamed: 3', 'Unnamed: 4'], axis=1, inplace=True)
# 名前の変更(v1→target, v2→text)
df_all.rename(columns={"v1":"target", "v2":"text"}, inplace=True)
df_all.head()

image.png
データの形状確認、欠損値確認等

# データの形状
print("shape:{}".format(df_all.shape))

# spamとhamの個数
print(df_all.target.value_counts())

# 欠損値の個数
print(df_all.isnull().sum())
実行結果
shape:(5572, 2)
ham     4825
spam     747
Name: target, dtype: int64
target    0
text      0
dtype: int64

データ加工

# targetのham → 0, spam → 1に変更
type = {"ham":0,"spam":1}
df_all.target = df_all.target.map(type)

# testとしてデータの一部を取っておく
test = df_all[5000:]
df = df_all[:5000]

dfとtestでtargetの偏りがないかの確認

print("df")
print(df.target.value_counts())
print("test")
print(test.target.value_counts())
実行結果
df
0    4327
1     673
Name: target, dtype: int64
test
0    498
1     74
Name: target, dtype: int64

データの成形

sentences = df.text.values
sentences = [sentence + " [SEP] [CLS]" for sentence in sentences]
labels = df.target.values

インプット

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased', do_lower_case=True)
tokenized_texts = [tokenizer.tokenize(sent) for sent in sentences]
print ("Tokenize the first sentence:")
print (tokenized_texts[0])
実行結果
Downloading: 100%
798k/798k [00:03<00:00, 231kB/s]

Tokenize the first sentence:
['▁go', '▁until', '▁', 'ju', 'rong', '▁point', ',', '▁crazy', '.', '.', '▁available', '▁only', '▁in', '▁bug', 'is', '▁', 'n', '▁great', '▁world', '▁la', '▁', 'e', '▁buffet', '.', '.', '.', '▁', 'cine', '▁there', '▁got', '▁a', 'more', '▁', 'wat', '.', '.', '.', '▁[', 's', 'ep', ']', '▁[', 'cl', 's', ']']

最大単語数の取得

max_len = []
# 1文づつ処理
for sent in sentences:
    # Tokenizeで分割
    token_words = tokenizer.tokenize(sent)
    # 文章数を取得してリストへ格納
    max_len.append(len(token_words))
# 最大の値を確認
print('最大単語数: ', max(max_len))
print('上記の最大単語数にSpecial token([CLS], [SEP])の+2をした値が最大単語数')
実行結果
最大単語数:  293
上記の最大単語数にSpecial token([CLS], [SEP])の+2をした値が最大単語数

つまり、この場合295が最大単語数。
しかし、maximum sequence lengthを295とすると
RuntimeError: CUDA out of memory. Tried to allocate 126.00 MiB (GPU 0; 14.73 GiB total capacity; 13.46 GiB already allocated; 23.88 MiB free; 13.71 GiB reserved in total by PyTorch)
を起こすため、単語数でヒストグラムを作成し、maximum sequence lengthを考える。

plt.title("token_size")
plt.hist(max_len,bins=20)

image.png
この結果よりmaximum sequence lengthを100とする。

データをXLNetが利用できる形にする。

  • input ids: a sequence of integers identifying each input token to its index number in the XLNet tokenizer vocabulary

  • segment mask: (optional) a sequence of 1s and 0s used to identify whether the input is one sentence or two sentences long. For one sentence inputs, this is simply a sequence of 0s. For two sentence inputs, there is a 0 for each token of the first sentence, followed by a 1 for each token of the second sentence

  • attention mask: (optional) a sequence of 1s and 0s, with 1s for all input tokens and 0s for all padding tokens (we’ll detail this in the next paragraph)

  • labels: a single value of 1 or 0. In our task 1 means “grammatical” and 0 means “ungrammatical”

# 最大シーケンス長
MAX_LEN = 100

# XLNetトークナイザーを使用して、トークンをXLNetボキャブラリのインデックス番号に変換
input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]

# 最大長に満たない場合は 0 で埋める
input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")

# attention masksの作成
attention_masks = []

for seq in input_ids:
  seq_mask = [float(i>0) for i in seq]
  attention_masks.append(seq_mask)
# train set と validation set に分割

train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(input_ids, labels, random_state=123, test_size=0.25)
train_masks, validation_masks, _, _ = train_test_split(attention_masks, input_ids,random_state=123, test_size=0.25)
# torch tensorsに変換
train_inputs = torch.tensor(train_inputs)
validation_inputs = torch.tensor(validation_inputs)
train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(validation_labels)
train_masks = torch.tensor(train_masks)
validation_masks = torch.tensor(validation_masks)

# batch size の指定。fine-tuning の batch size は 32, 48, 128が良いらしい。今回はメモリの都合上 32 で行う。
batch_size = 32

# for loopを使わず、torch DataLoaderを使うことで、トレーニング時のメモリを節約できる。
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)

モデルのトレーニング

事前学習済みモデルを分類タスクに適用する

# XLNEtForSequenceClassificationをロード。(上部に単一の線形分類レイヤーを持つ事前学習済みXLNetモデル。)

model = XLNetForSequenceClassification.from_pretrained("xlnet-base-cased",num_labels = 2)
model.cuda()
実行結果
Downloading: 100%
760/760 [00:00<00:00, 5.36kB/s]

Downloading: 100%
467M/467M [00:45<00:00, 10.2MB/s]

Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
XLNetForSequenceClassification(
  (transformer): XLNetModel(
    (word_embedding): Embedding(32000, 768)
    (layer): ModuleList(
      (0): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (1): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (2): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (3): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (4): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (5): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (6): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (7): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (8): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (9): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (10): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (11): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (sequence_summary): SequenceSummary(
    (summary): Linear(in_features=768, out_features=768, bias=True)
    (first_dropout): Identity()
    (last_dropout): Dropout(p=0.1, inplace=False)
  )
  (logits_proj): Linear(in_features=768, out_features=2, bias=True)
)
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay_rate': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay_rate': 0.0}
]
# この変数にハイパーパメーター情報がすべて含まれている
optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)
# 精度を計算する関数の定義
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)
# プロットのためにlossと精度を保存する
train_loss_set = []

# training epoch数 ( 2~4が良いらしい)
epochs = 3

# trange is a tqdm wrapper around the normal python range
for _ in trange(epochs, desc="Epoch"):
  
  
  # Training
  
  # モデルをトレーニングモードに設定。
  model.train()
  
  # Tracking variables
  tr_loss = 0
  nb_tr_examples, nb_tr_steps = 0, 0
  
  # エポックごとにトレーニング
  for step, batch in enumerate(train_dataloader):
    # GPUにバッチを追加する
    batch = tuple(t.to(device) for t in batch)
    # dataloaderからinputsをUnpack
    b_input_ids, b_input_mask, b_labels = batch
    # Clear out the gradients (defaultでは、accumulate)
    optimizer.zero_grad()
    # Forward pass
    outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
    loss = outputs[0]
    logits = outputs[1]
    train_loss_set.append(loss.item())    
    # Backward pass
    loss.backward()
    # パラメータを更新し、計算された勾配を使用してステップを実行
    optimizer.step()
    
    
    # tracking variablesを更新
    tr_loss += loss.item()
    nb_tr_examples += b_input_ids.size(0)
    nb_tr_steps += 1

  print("Train loss: {}".format(tr_loss/nb_tr_steps))
 
  # Validation

  # モデルをevaluationモードにして、Validationセットのlossを評価
  model.eval()

  # Tracking variables 
  eval_loss, eval_accuracy = 0, 0
  nb_eval_steps, nb_eval_examples = 0, 0

  # エポックごとに評価
  for batch in validation_dataloader:
    # GPUにバッチを追加する
    batch = tuple(t.to(device) for t in batch)
    # dataloaderからinputsをUnpack
    b_input_ids, b_input_mask, b_labels = batch
    # gradientsを計算または保存しないようにモデルに指示し、メモリを節約して高速化する
    with torch.no_grad():
      # Forward pass, calculate logit predictions
      output = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
      logits = output[0]
    
    # logits と labelsをCPUに移動する
    logits = logits.detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()

    tmp_eval_accuracy = flat_accuracy(logits, label_ids)
    
    eval_accuracy += tmp_eval_accuracy
    nb_eval_steps += 1

  print("Validation Accuracy: {}".format(eval_accuracy/nb_eval_steps))
実行結果
Epoch:   0%|          | 0/3 [00:00<?, ?it/s]Train loss: 0.12498065767517724
Epoch:  33%|███▎      | 1/3 [00:55<01:51, 55.77s/it]Validation Accuracy: 0.98984375
Train loss: 0.030221276571263826
Epoch:  67%|██████▋   | 2/3 [01:51<00:55, 55.71s/it]Validation Accuracy: 0.98984375
Train loss: 0.012542888356650926
Epoch: 100%|██████████| 3/3 [02:46<00:00, 55.63s/it]Validation Accuracy: 0.99375

トレーニングセットの評価

すべてのバッチでのTraining lossの確認。

plt.figure(figsize=(15,10))
plt.title("Training loss")
plt.xlabel("Batch")
plt.ylabel("Loss")
plt.plot(train_loss_set)
plt.show()

image.png

Holdout Setの予測と評価

# sentence と label listの作成
sentences = test.text.values

# XLNetを機能させるために、各文に[SEP] [CLS]を追加
sentences = [sentence + " [SEP] [CLS]" for sentence in sentences]
labels = test.target.values

tokenized_texts = [tokenizer.tokenize(sent) for sent in sentences]


MAX_LEN = 100
# XLNet tokenizerを使用して、トークンをXLNet vocabularyのインデックス番号に変換
input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
# Pad our input tokens
input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
# attention masksの作成
attention_masks = []

for seq in input_ids:
  seq_mask = [float(i>0) for i in seq]
  attention_masks.append(seq_mask) 

prediction_inputs = torch.tensor(input_ids)
prediction_masks = torch.tensor(attention_masks)
prediction_labels = torch.tensor(labels)
  
batch_size = 32  


prediction_data = TensorDataset(prediction_inputs, prediction_masks, prediction_labels)
prediction_sampler = SequentialSampler(prediction_data)
prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size)
# test setの予測

# モデルを評価モードにする
model.eval()

# Tracking variables 
predictions , true_labels = [], []

# Predict 
for batch in prediction_dataloader:
  # GPUにバッチを追加する
  batch = tuple(t.to(device) for t in batch)
  # dataloaderからinputsをUnpack
  b_input_ids, b_input_mask, b_labels = batch
  # gradientsを計算または保存しないようにモデルに指示し、メモリを節約して高速化する
  with torch.no_grad():
    # Forward pass, calculate logit predictions
    outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
    logits = outputs[0]

  # logits と labelsをCPUに移動する
  logits = logits.detach().cpu().numpy()
  label_ids = b_labels.to('cpu').numpy()
  
  # 予測とラベルを保存する
  predictions.append(logits)
  true_labels.append(label_ids)
flat_predictions = [item for sublist in predictions for item in sublist]
flat_predictions = np.argmax(flat_predictions, axis=1).flatten()
flat_true_labels = [item for sublist in true_labels for item in sublist]

各指標ごとに評価する

print(confusion_matrix(flat_true_labels, flat_predictions,labels=[0,1]))
実行結果
[[496   2]
 [  2  72]]
print("精度:{}".format(accuracy_score(flat_true_labels, flat_predictions)))
print("適合率:{}".format(precision_score(flat_true_labels, flat_predictions)))
print("再現率:{}".format(recall_score(flat_true_labels, flat_predictions)))
print("F1:{}".format(f1_score(flat_true_labels, flat_predictions, average='macro')))
print("MCC:{}".format(matthews_corrcoef(flat_true_labels, flat_predictions)))
実行結果
精度:0.993006993006993
適合率:0.972972972972973
再現率:0.972972972972973
F1:0.9844784543579724
MCC:0.9689569087159449

## まとめ
今回、XLNetを実際に実装し、その性能を確認してみました。
そもそも、今回使用してデータセットは予測が簡単なデータセットでナイーブベイズ分類器ですら0.98程度の精度を出すことができるため、XLNetの性能を確かめるにはもっと予測の難しいデータセットを使用すべきでした。
しかし、Hold outセットにおいて、いずれの評価指標においても極めて良好な結果が得られたことから、非常に協力なモデルであることが確かめられました。
チューニングをしっかりしたり、BERTと比較して、、等を試したかったのですがそれは他のデータセットで試した方が良さそうです。。
今回のコードは、https://mccormickml.com/2019/09/19/XLNet-fine-tuning/
を参考に作成しました。
とても、わかりやすく書かれているので、興味を持った方は参考にされると良いかと思います。

7
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?