27
18

More than 3 years have passed since last update.

PyTorchで日本語BERTと日本語DistilBERTの文章分類の精度比較をしてみた&BERTの精度向上テクニックの紹介

Last updated at Posted at 2020-08-21

はじめに

前回の記事でhuggingface/transformersを使って日本語BERTを使ってみましたが、huggingface/transformersを使えば、他の事前学習済のBERTモデルも簡単に扱えます。

使えるモデルの一覧のうち、日本語のものと思われるモデルは他にもDistilBERTとかALBERTとかがあるようです。どちらも軽量版BERTって位置づけですかね。

今回はhuggingfaceからも使えるバンダイナムコさんが提供しているDistilBERTを簡単に紹介しつつ、通常のBERTとの精度比較を行ってみました。最後にBERTで文章分類をする際の精度を向上させるテクニックの1つも紹介してみます。

DistilBERTとは?

バンダイナムコさんのGithubのREADMEをそのまま拝借いたします。

DistilBERTはHuggingface が NeurIPS 2019 に公開したモデルで、名前は「Distilated-BERT」の略となります。投稿された論文はこちらをご参考ください。

DistilBERTはBERTアーキテクチャをベースにした、小さくて、速くて、軽いTransformerモデルです。DistilBERTは、BERT-baseよりもパラメータが40%少なく、60%高速に動作し、GLUE Benchmarkで測定されたBERTの97%の性能を維持できると言われています。

DistilBERTは、教師と呼ばれる大きなモデルを生徒と呼ばれる小さなモデルに圧縮する技術である知識蒸留を用いて訓練されます。BERTを蒸留することで、元のBERTモデルと多くの類似点を持ちながら、より軽量で実行速度が速いTransformerモデルを得ることができます。

要はBERT-baseの軽量版で高速化を実現したモデル(それ故に精度はほんのちょっとBERT-baseに比べて劣る)って感じでしょうか。

実際にどれくらい高速化なのか、どれくらい精度が少々劣ってしまうのかを実際に使って確認してみます。

DistilBERTの使い方

公式Github通りにhuggingface/transformersから簡単に呼び出せます。

  • tokenizerのほうは以下のように名前にcl-tohoku/をつけないとエラーになりますかね。
from transformers import AutoModel, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-whole-word-masking")
distil_model = AutoModel.from_pretrained("bandainamco-mirai/distilbert-base-japanese")  

基本的に前回記事で紹介した日本語BERT-baseと同じように使えますが、内部のネットワーク構造の違いから、ファインチューニングは少々変更する必要がありそうです。

とりあえず、DistilBERTの中身の構造を確認してみましょう。

print(distil_model)

長いので閉じておきます。

DistilBERTモデルの構造
DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(32000, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0): TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
        )
        (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (1): TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
        )
        (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (2): TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
        )
        (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (3): TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
        )
        (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (4): TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
        )
        (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (5): TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
        )
        (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
    )
  )
)

BERT-baseとの違いとして、transformerブロックがBERT-baseは12個でしたが、DistilBERTは6個だけになってます。また、中身の層の名前の付け方もBERT-baseと少々異なることが確認できます。

よってファインチューニングをする際は以下のように書けばよいです。分類モデルの宣言例も合わせて記載します。

今回は特に影響ないですが、合わせてDistilBERTのリファレンスも確認しましょう。モデルのreturnも若干異なります。
- https://huggingface.co/transformers/model_doc/distilbert.html

(前回記事同様にlivedoorニュースコーパスのタイトル分類を想定しています。)

モデル宣言

import torch
from torch import nn
import torch.nn.functional as F
from transformers import *

class DistilBertClassifier(nn.Module):
  def __init__(self):
    super(DistilBertClassifier, self).__init__()
    # BERT-baseと違うところはここだけ。
    self.distil_bert = AutoModel.from_pretrained("bandainamco-mirai/distilbert-base-japanese")
    # DistilBERTの隠れ層の次元数は768, livedoorニュースのカテゴリ数が9
    self.linear = nn.Linear(768, 9)
    # 重み初期化処理
    nn.init.normal_(self.linear.weight, std=0.02)
    nn.init.normal_(self.linear.bias, 0)

  def forward(self, input_ids):
    vec, _ = self.distil_bert(input_ids)
    # 先頭トークンclsのベクトルだけ取得
    vec = vec[:,0,:]
    vec = vec.view(-1, 768)
    # 全結合層でクラス分類用に次元を変換
    out = self.linear(vec)
    return F.log_softmax(out)

# 分類モデルのインスタンス
distil_classifier = DistilBertClassifier()

ファインチューニング

# まずは全部OFF
for param in distil_classifier.parameters():
    param.requires_grad = False

# DistilBERTの最後の層だけ更新ON
# BERT-baseは .encoder.layer[-1]でしたが、
# DistilBERTの場合は、上で構造を確認したように以下のように .transfomer.layer[-1]となります。
for param in distil_classifier.distil_bert.transformer.layer[-1].parameters():
    param.requires_grad = True

# クラス分類のところもON
for param in distil_classifier.linear.parameters():
    param.requires_grad = True

import torch.optim as optim

# 事前学習済の箇所は学習率小さめ、最後の全結合層は大きめにする。
# こちらも忘れずにDistilBERT用に変更
optimizer = optim.Adam([
    {'params': distil_classifier.distil_bert.transformer.layer[-1].parameters(), 'lr': 5e-5},
    {'params': distil_classifier.linear.parameters(), 'lr': 1e-4}
])

BERT-baseとDistilBERTの比較

前回と同様にlivedoorニュースコーパスのタイトル分類のタスクを扱います。

BERT-base

  • 以下のソースコードはほとんど前回と同様です。

モデル定義&ファインチューニング

class BertClassifier(nn.Module):
  def __init__(self):
    super(BertClassifier, self).__init__()
    self.bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
    # BERTの隠れ層の次元数は768, livedoorニュースのカテゴリ数が9
    self.linear = nn.Linear(768, 9)
    # 重み初期化処理
    nn.init.normal_(self.linear.weight, std=0.02)
    nn.init.normal_(self.linear.bias, 0)

  def forward(self, input_ids):
    # last_hiddenのみ受け取る
    vec, _ = self.bert(input_ids)
    # 先頭トークンclsのベクトルだけ取得
    vec = vec[:,0,:]
    vec = vec.view(-1, 768)
    # 全結合層でクラス分類用に次元を変換
    out = self.linear(vec)
    return F.log_softmax(out)

# 分類モデルのインスタンス宣言
bert_classifier = BertClassifier()

# ファインチューニングの設定
# まずは全部OFF
for param in bert_classifier.parameters():
    param.requires_grad = False

# BERTの最後の層だけ更新ON
for param in bert_classifier.bert.encoder.layer[-1].parameters():
    param.requires_grad = True

# クラス分類のところもON
for param in bert_classifier.linear.parameters():
    param.requires_grad = True

import torch.optim as optim

# 事前学習済の箇所は学習率小さめ、最後の全結合層は大きめにする。
optimizer = optim.Adam([
    {'params': bert_classifier.bert.encoder.layer[-1].parameters(), 'lr': 5e-5},
    {'params': bert_classifier.linear.parameters(), 'lr': 1e-4}
])

# 損失関数の設定
loss_function = nn.NLLLoss()

学習&推論

# 学習時間測ります。
import time

start = time.time()
# GPUの設定
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# ネットワークをGPUへ送る
bert_classifier.to(device)
losses = []

# エポック数は10で
for epoch in range(10):
  all_loss = 0
  for idx, batch in enumerate(train_iter):
    batch_loss = 0
    bert_classifier.zero_grad()
    input_ids = batch.Text[0].to(device)
    label_ids = batch.Label.to(device)
    out = bert_classifier(input_ids)
    batch_loss = loss_function(out, label_ids)
    batch_loss.backward()
    optimizer.step()
    all_loss += batch_loss.item()
  print("epoch", epoch, "\t" , "loss", all_loss)

end = time.time()

print ("time : ", end - start)
#epoch 0     loss 251.19750046730042
#epoch 1     loss 110.7038831859827
#epoch 2     loss 82.88570280373096
#epoch 3     loss 67.0771074667573
#epoch 4     loss 56.24497305601835
#epoch 5     loss 42.61423560976982
#epoch 6     loss 35.98485875874758
#epoch 7     loss 25.728398952633142
#epoch 8     loss 20.40780107676983
#epoch 9     loss 16.567239843308926
#time :  101.97362518310547

# 推論
answer = []
prediction = []
with torch.no_grad():
    for batch in test_iter:

        text_tensor = batch.Text[0].to(device)
        label_tensor = batch.Label.to(device)

        score = bert_classifier(text_tensor)
        _, pred = torch.max(score, 1)

        prediction += list(pred.cpu().numpy())
        answer += list(label_tensor.cpu().numpy())
print(classification_report(prediction, answer, target_names=categories))
#                precision    recall  f1-score   support

# kaden-channel       0.94      0.92      0.93       172
#dokujo-tsushin       0.75      0.86      0.80       156
#        peachy       0.81      0.68      0.74       211
#   movie-enter       0.78      0.81      0.80       171
#          smax       0.98      0.91      0.94       176
#livedoor-homme       0.68      0.83      0.75        83
#  it-life-hack       0.79      0.94      0.86       150
#    topic-news       0.81      0.76      0.78       172
#  sports-watch       0.89      0.82      0.85       185

#      accuracy                           0.83      1476
#     macro avg       0.83      0.84      0.83      1476
#  weighted avg       0.84      0.83      0.83      1476

10エポックの学習時間は約102秒、精度は0.83(Fスコア)となりました。

DistilBERT

学習&推論

  • 上記で定義したモデルとファインチューニングの設定を元に以下のように学習&推論を実施
  • とはいってもBERT-baseと何も変わりませんが念の為...
import time

# GPUの設定
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# ネットワークをGPUへ送る
distil_classifier.to(device)
losses = []

start = time.time()
# エポック数は10で
for epoch in range(10):
  all_loss = 0
  for idx, batch in enumerate(train_iter):
    batch_loss = 0
    distil_classifier.zero_grad()
    input_ids = batch.Text[0].to(device)
    label_ids = batch.Label.to(device)
    out = distil_classifier(input_ids)
    batch_loss = loss_function(out, label_ids)
    batch_loss.backward()
    optimizer.step()
    all_loss += batch_loss.item()
  print("epoch", epoch, "\t" , "loss", all_loss)

end = time.time()
print ("time : ", end - start)
#/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:26: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
#epoch 0     loss 450.1027842760086
#epoch 1     loss 317.39041769504547
#epoch 2     loss 211.34138756990433
#epoch 3     loss 144.4813650548458
#epoch 4     loss 106.24609130620956
#epoch 5     loss 83.87273170053959
#epoch 6     loss 68.9661111086607
#epoch 7     loss 59.31868125498295
#epoch 8     loss 49.874382212758064
#epoch 9     loss 41.56027300283313
#time :  60.22182369232178


from sklearn.metrics import classification_report

answer = []
prediction = []
with torch.no_grad():
    for batch in test_iter:

        text_tensor = batch.Text[0].to(device)
        label_tensor = batch.Label.to(device)

        score = distil_classifier(text_tensor)
        _, pred = torch.max(score, 1)

        prediction += list(pred.cpu().numpy())
        answer += list(label_tensor.cpu().numpy())
print(classification_report(prediction, answer, target_names=categories))

#/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:26: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
#                precision    recall  f1-score   support

# kaden-channel       0.93      0.96      0.95       163
#dokujo-tsushin       0.88      0.88      0.88       178
#        peachy       0.86      0.75      0.80       202
#   movie-enter       0.86      0.84      0.85       183
#          smax       0.96      0.95      0.95       165
#livedoor-homme       0.67      0.71      0.69        96
#  it-life-hack       0.91      0.91      0.91       178
#    topic-news       0.80      0.86      0.83       148
#  sports-watch       0.88      0.91      0.89       163

#      accuracy                           0.87      1476
#     macro avg       0.86      0.86      0.86      1476
#  weighted avg       0.87      0.87      0.87      1476
  • 10エポックの学習時間は約60秒、精度は0.87(Fスコア)となりました。
  • 学習時間が速くなるのはいいんだけど、精度まで上がってしまった
  • 本来であればBERT-baseより精度がやや落ちる想定でしたが、上がる場合もあるようで。
  • なんとなく、いつも実験として試すlivedoorニュースコーパスのタイトル分類というタスクがあまり良くないのかも...

BERT-baseの精度を向上させてみる

ここからはDistilBERTとの比較ではなくて、日本語BERTの文章分類をする際の精度を向上させるテクニックを1つ紹介します。

(本来であれば、まずはタスクに応じた前処理等をしっかりと検討すべきですが、タスクにあまり依存しない精度向上テクニックと思われるので、この場で紹介させていただきます。)

テクニックとは言ってもBERTの論文の5.3 Feature-based Approach with BERTでも紹介されていますし、以前kaggleで行われたNLPコンペJigsaw Unintended Bias in Toxicity Classificationの1stの手法でもあるようです。

テクニックの内容は詳しくは以下の記事をご参照ください。

要はBERT-baseの12層あるEncoder層のうち、最終層のCLSトークンのベクトルのみ使うより、最終4層のCLSトークンのベクトルを結合して使ったほうが文章のベクトルをよりよく作れるってことらしいです。(なぜかは知りません...)

アイディアはとても単純なので、今回のlivedoorニュースコーパスのタイトル分類タスクでもやってみようと思います。

実装

モデル宣言

class BertClassifierRevised(nn.Module):
  def __init__(self):
    super(BertClassifierRevised, self).__init__()
    self.bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
    # BERTの隠れ層の次元数は768だが、最終4層分のベクトルを結合したものを扱うので、768×4次元としている。
    self.linear = nn.Linear(768*4, 9)
    # 重み初期化処理
    nn.init.normal_(self.linear.weight, std=0.02)
    nn.init.normal_(self.linear.bias, 0)

  # clsトークンのベクトルを取得する用の関数を用意
  def _get_cls_vec(self, vec):
    return vec[:,0,:].view(-1, 768)

  def forward(self, input_ids):
    # 第1戻り値のlast_hidden_stateでは最終層のみしか取得できないので、
    # output_hidden_states=Trueを宣言し、全ての隠れ層ベクトルを取得できるようにし、
    # 第3戻り値(全部の隠れ層の状態)を取得する。
    _, _,  hidden_states = self.bert(input_ids, output_hidden_states=True)

    # 最終4層の隠れ層からそれぞれclsトークンのベクトルを取得する
    vec1 = self._get_cls_vec(hidden_states[-1])
    vec2 = self._get_cls_vec(hidden_states[-2])
    vec3 = self._get_cls_vec(hidden_states[-3])
    vec4 = self._get_cls_vec(hidden_states[-4])

    # 4つのclsトークンを結合して1つのベクトルにする。
    vec = torch.cat([vec1, vec2, vec3, vec4], dim=1)

    # 全結合層でクラス分類用に次元を変換
    out = self.linear(vec)
    return F.log_softmax(out)

# インスタンス宣言
bert_classifier_revised = BertClassifierRevised()

ファインチューニング

# まずは全部OFF
for param in bert_classifier_revised.parameters():
    param.requires_grad = False

# BERTの最終4層分をON
for param in bert_classifier_revised.bert.encoder.layer[-1].parameters():
    param.requires_grad = True

for param in bert_classifier_revised.bert.encoder.layer[-2].parameters():
    param.requires_grad = True

for param in bert_classifier_revised.bert.encoder.layer[-3].parameters():
    param.requires_grad = True

for param in bert_classifier_revised.bert.encoder.layer[-4].parameters():
    param.requires_grad = True

# クラス分類のところもON
for param in bert_classifier_revised.linear.parameters():
    param.requires_grad = True

import torch.optim as optim

# 事前学習済の箇所は学習率小さめ、最後の全結合層は大きめにする。
optimizer = optim.Adam([
    {'params': bert_classifier_revised.bert.encoder.layer[-1].parameters(), 'lr': 5e-5},
    {'params': bert_classifier_revised.bert.encoder.layer[-2].parameters(), 'lr': 5e-5},
    {'params': bert_classifier_revised.bert.encoder.layer[-3].parameters(), 'lr': 5e-5},
    {'params': bert_classifier_revised.bert.encoder.layer[-4].parameters(), 'lr': 5e-5},
    {'params': bert_classifier_revised.linear.parameters(), 'lr': 1e-4}
])

# 損失関数の設定
loss_function = nn.NLLLoss()

学習&推論


import time

start = time.time()
# GPUの設定
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# ネットワークをGPUへ送る
bert_classifier_revised.to(device)
losses = []

# エポック数は5で
for epoch in range(10):
  all_loss = 0
  for idx, batch in enumerate(train_iter):
    batch_loss = 0
    bert_classifier_revised.zero_grad()
    input_ids = batch.Text[0].to(device)
    label_ids = batch.Label.to(device)
    out = bert_classifier_revised(input_ids)
    batch_loss = loss_function(out, label_ids)
    batch_loss.backward()
    optimizer.step()
    all_loss += batch_loss.item()
  print("epoch", epoch, "\t" , "loss", all_loss)

end = time.time()

print ("time : ", end - start)
#epoch 0     loss 196.0047192275524
#epoch 1     loss 75.8067753687501
#epoch 2     loss 42.30751228891313
#epoch 3     loss 16.470114511903375
#epoch 4     loss 7.427484432584606
#epoch 5     loss 2.9392087209271267
#epoch 6     loss 1.5984382012393326
#epoch 7     loss 1.7370687873335555
#epoch 8     loss 0.9278695838729618
#epoch 9     loss 1.499190401067608
#time :  149.01919651031494

# 推論
answer = []
prediction = []
with torch.no_grad():
    for batch in test_iter:

        text_tensor = batch.Text[0].to(device)
        label_tensor = batch.Label.to(device)

        score = bert_classifier_revised(text_tensor)
        _, pred = torch.max(score, 1)

        prediction += list(pred.cpu().numpy())
        answer += list(label_tensor.cpu().numpy())
print(classification_report(prediction, answer, target_names=categories))
#                precision    recall  f1-score   support

# kaden-channel       0.80      0.99      0.89       137
#dokujo-tsushin       0.89      0.86      0.88       183
#        peachy       0.78      0.82      0.80       168
#   movie-enter       0.87      0.88      0.87       176
#          smax       0.95      0.93      0.94       168
#livedoor-homme       0.72      0.83      0.77        88
#  it-life-hack       0.95      0.79      0.86       215
#    topic-news       0.83      0.84      0.83       159
#  sports-watch       0.92      0.86      0.89       182

#      accuracy                           0.86      1476
#     macro avg       0.86      0.87      0.86      1476
#  weighted avg       0.87      0.86      0.86      1476
  • BERT-baseと比べてlossの減りがはやい
  • 学習時間は約150秒とBERT-baseに比べてやや時間はかかりました
  • 精度はBERT-baseの0.83から比べて0.86と向上しています。素晴らしい。

おわりに

  • BERT-baseとDistilBERTの比較を行ってみました。結果としては速度も精度もDistilBERTのほうが良いという結果になってしまいましたが、DistilBERTの使い方がちょっとわかった気がします。
  • 後半でBERTの精度向上案として、最終4層のclsトークンを考慮する案を紹介しました。BERT-base に比べて確実に精度向上に寄与しているようです。これからはBERTで分類モデルを実装する際はとりあえず最終4層使ってみようかな。

おわり

27
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
27
18