Help us understand the problem. What is going on with this article?

Keras BERTでファインチューニングしてみる

Keras BERTでファインチューニングしてみる

TL;DR

SentencePiece + 日本語WikipediaのBERTモデルをKeras BERTで利用するにおいて、Keras BERTを利用して日本語データセットの分類問題を扱って見ましたが、今回はファインチューニングを行ってみました。

BERTのモデルやベンチマーク用のデータなどはSentencePiece + 日本語WikipediaのBERTモデルをKeras BERTで利用すると同様です。

Keras BERTでファインチューニングする際のポイント

Keras BERTのGitHubにデモとして公開されているkeras_bert_classification_tpu.ipynbを参考にしました。

ポイントは以下のとおりです。私が試した範囲では、以下の両方を適切に設定しないと、Lossが収束しませんでした。

  • bert_config.jsonmax_position_embeddingmax_seq_lengthをデータセットの最大トークン数にする
  • optimizerAdamWarmupにする

bert_config.json

対象データセットであるKNBCを事前に調査し、BERTモデルで使用しているSetenencePieceでtokenizeした後の最大トークン数が101であることを調べておきます。

これにBERT用の区切り文字である[CLS][SEP]を足した103が最大トークン数になります。

{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 103,
  "max_seq_length": 103,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 32000
}

AdamWarmup

スケジューリングされた学習をするためのAdamのバリエーションだと思いますが、詳細は調べていません。
基本的にkeras_bert_classification_tpu.ipynbのままです。

以下の様にcalc_train_stepsで減衰率などを事前に計算する必要があります。

    decay_steps, warmup_steps = calc_train_steps(
        input_shape[0],
        batch_size=BATCH_SIZE,
        epochs=EPOCH,
    )
    # [省略]
    model.compile(loss='categorical_crossentropy',
                  optimizer=AdamWarmup(decay_steps=decay_steps, warmup_steps=warmup_steps, lr=LR),
                  #optimizer='nadam',
                  metrics=['mae', 'mse', 'acc'])

ソースコード

BERTのロード

import sys
sys.path.append('modules')
from keras_bert import load_trained_model_from_checkpoint

config_path = 'bert-wiki-ja/bert_config.json'
# `model.ckpt-1400000` のように拡張子を付けないのがポイントです。
checkpoint_path = 'bert-wiki-ja/model.ckpt-1400000'
SEQ_LEN = 103 # 先にデータセットの最大のトークン数を調べています。
BATCH_SIZE = 16
BERT_DIM = 768
LR = 1e-4
EPOCH = 10

# ファインチューニング用にtraining,trainableをTrueに設定し、最大トークン数をseq_lenに設定します。
bert = load_trained_model_from_checkpoint(config_path, checkpoint_path, training=True,  trainable=True, seq_len=SEQ_LEN)
bert.summary()

データロード用関数

import pandas as pd
import sentencepiece as spm
from keras import utils
from keras.preprocessing.sequence import pad_sequences
import logging
import numpy as np

maxlen = SEQ_LEN

sp = spm.SentencePieceProcessor()
sp.Load('bert-wiki-ja/wiki-ja.model')

def _get_indice(feature):
    indices = np.zeros((maxlen), dtype = np.int32)

    tokens = []
    tokens.append('[CLS]')
    tokens.extend(sp.encode_as_pieces(feature))
    tokens.append('[SEP]')

    for t, token in enumerate(tokens):
        if t >= maxlen:
            break
        try:
            indices[t] = sp.piece_to_id(token)
        except:
            logging.warn(f'{token} is unknown.')
            indices[t] = sp.piece_to_id('<unk>')

    return indices

def _load_labeldata(train_dir, test_dir):
    train_features_df = pd.read_csv(f'{train_dir}/features.csv')
    train_labels_df = pd.read_csv(f'{train_dir}/labels.csv')
    test_features_df = pd.read_csv(f'{test_dir}/features.csv')
    test_labels_df = pd.read_csv(f'{test_dir}/labels.csv')
    label2index = {k: i for i, k in enumerate(train_labels_df['label'].unique())}
    index2label = {i: k for i, k in enumerate(train_labels_df['label'].unique())}
    class_count = len(label2index)
    train_labels = utils.np_utils.to_categorical([label2index[label] for label in train_labels_df['label']], num_classes=class_count)
    test_label_indices = [label2index[label] for label in test_labels_df['label']]
    test_labels = utils.np_utils.to_categorical(test_label_indices, num_classes=class_count)

    train_features = []
    test_features = []

    for feature in train_features_df['feature']:
        train_features.append(_get_indice(feature))
    train_segments = np.zeros((len(train_features), maxlen), dtype = np.float32)
    for feature in test_features_df['feature']:
        test_features.append(_get_indice(feature))
    test_segments = np.zeros((len(test_features), maxlen), dtype = np.float32)

    print(f'Trainデータ数: {len(train_features_df)}, Testデータ数: {len(test_features_df)}, ラベル数: {class_count}')

    return {
        'class_count': class_count,
        'label2index': label2index,
        'index2label': index2label,
        'train_labels': train_labels,
        'test_labels': test_labels,
        'test_label_indices': test_label_indices,
        'train_features': np.array(train_features),
        'train_segments': np.array(train_segments),
        'test_features': np.array(test_features),
        'test_segments': np.array(test_segments),
        'input_len': maxlen
    }

モデル準備関数

from keras.layers import Dense, Dropout, LSTM, Bidirectional, Flatten, GlobalMaxPooling1D
from keras_bert.layers import MaskedGlobalMaxPool1D
from keras import Input, Model
from keras_bert import AdamWarmup, calc_train_steps

def _create_model(input_shape, class_count):
    decay_steps, warmup_steps = calc_train_steps(
        input_shape[0],
        batch_size=BATCH_SIZE,
        epochs=EPOCH,
    )

    bert_last = bert.get_layer(name='NSP-Dense').output
    x1 = bert_last
    output_tensor = Dense(class_count, activation='softmax')(x1)
    # Trainableの場合は、Input Masked Layerが3番目の入力なりますが、
    # FineTuning時には必要無いので1, 2番目の入力だけ使用します。
    # Trainableでなければkeras-bertのModel.inputそのままで問題ありません。
    model = Model([bert.input[0], bert.input[1]], output_tensor)
    model.compile(loss='categorical_crossentropy',
                  optimizer=AdamWarmup(decay_steps=decay_steps, warmup_steps=warmup_steps, lr=LR),
                  #optimizer='nadam',
                  metrics=['mae', 'mse', 'acc'])

    return model

データのロードとモデルの準備

from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard

trains_dir = '../word-or-character/data/trains'
tests_dir = '../word-or-character/data/tests'

data = _load_labeldata(trains_dir, tests_dir)
model_filename = 'models/knbc-check-bert_v3.model'
model = _create_model(data['train_features'].shape, data['class_count'])

model.summary()

学習の実行

history = model.fit([data['train_features'], data['train_segments']],
          data['train_labels'],
          epochs = EPOCH,
          batch_size = BATCH_SIZE,
          validation_data=([data['test_features'], data['test_segments']], data['test_labels']),
          shuffle=False,
          verbose = 1,
          callbacks = [
              #EarlyStopping(patience=5, monitor='val_acc', mode='max'),
              ModelCheckpoint(monitor='val_acc', mode='max', filepath=model_filename, save_best_only=True)
          ])
val_loss val_mean_absolute_error val_mean_squared_error val_acc loss mean_absolute_error mean_squared_error acc
0 0.561248 0.131881 0.071544 0.792363 0.782363 0.202023 0.102747 0.679055
1 0.739699 0.115426 0.085423 0.782816 0.379809 0.096035 0.049494 0.856384
2 1.062872 0.125169 0.098740 0.761337 0.192548 0.046621 0.024535 0.934165
3 1.176542 0.119584 0.099672 0.763723 0.086767 0.022215 0.011102 0.971861
4 0.921495 0.113555 0.088224 0.778043 0.061029 0.013554 0.007306 0.980887
5 1.043859 0.104909 0.087325 0.794749 0.023458 0.006163 0.002836 0.991505
6 1.082306 0.102066 0.085363 0.804296 0.022063 0.005459 0.002826 0.992567
7 1.048261 0.098721 0.083447 0.809069 0.015968 0.003016 0.001616 0.995753
8 1.027535 0.096534 0.081463 0.816229 0.007195 0.001710 0.000927 0.996814
9 1.040992 0.096410 0.081862 0.811456 0.003626 0.001289 0.000520 0.997876

クラシフィケーションレポート

from sklearn.metrics import classification_report, confusion_matrix
from keras.models import load_model
from keras_bert import get_custom_objects

model = load_model(model_filename, custom_objects=get_custom_objects())

predicted_test_labels = model.predict([data['test_features'], data['test_segments']]).argmax(axis=1)
numeric_test_labels = np.array(data['test_labels']).argmax(axis=1)

report_filename = 'models/knbc-check-bert_v3.txt'

with open(report_filename, 'w', encoding='utf-8') as f:
    print(classification_report(numeric_test_labels, predicted_test_labels, target_names = ['グルメ', '携帯電話', '京都', 'スポーツ']), file=f)
    print(classification_report(numeric_test_labels, predicted_test_labels, target_names = ['グルメ', '携帯電話', '京都', 'スポーツ']))
              precision    recall  f1-score   support

         グルメ       0.80      0.80      0.80       137
        携帯電話       0.85      0.77      0.81       145
          京都       0.74      0.85      0.79        47
        スポーツ       0.84      0.90      0.87        90

   micro avg       0.82      0.82      0.82       419
   macro avg       0.81      0.83      0.82       419
weighted avg       0.82      0.82      0.82       419

まとめ

F値で82と綺麗に過去最高性能でした。
FineTuning前後で1%の差ですが、元々が高い精度ですので、有意な差がでることは素晴らしいと思います。

  • Wikipediaja with BERT/Fine Tuned(Weighted Avg F1): 0.82
  • Wikipediaja with BERT(Weighted Avg F1): 0.81 1
  • Wikipediaja(Weighted Avg F1): 0.77 23
  • Wikipediaja+現代日本語書き言葉均衡コーパス(Weighted Avg F1): 0.79 23

BERTの構造

今回のモデルの構造です。正確にはBERT+分類用のDenseの構造になっています。

from keras.utils import plot_model

plot_model(model, to_file='train-bert.png', show_shapes=True)

参考文献

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away