55
55

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

SentencePiece + 日本語WikipediaのBERTモデルをKeras BERTで利用する

Last updated at Posted at 2019-03-11

SentencePiece + 日本語WikipediaのBERTモデルをKeras BERTで利用する

TL;DR

Googleが公開しているBERTの学習済みモデルは、日本語Wikipediaもデータセットに含まれていますが、Tokenizeの方法が分かち書きを前提としているため、そのまま利用しても日本語の分類問題ではあまり高い精度を得ることができません。

このため、SentencePieceでTokenizeしたデータセットで学習し直す必要があります。

BERTのトレーニングは結構な時間やマシンリソースが必要ですが、ありがたいことにSentencePiece+日本語Wikipediaで学習済みのモデルを配布してくれている方がいらっしゃるので、今回は以下を利用します。

BERTには分類問題用のスクリプトが付属していますが、今回はKeras BERTからBERTを利用します。

例の如くKNBC(後述のベンチマーク用データを参照)を利用して自然言語分類問題を対象としています。

ベンチマーク用データ

京都大学情報学研究科--NTTコミュニケーション科学基礎研究所 共同研究ユニットが提供するブログの記事に関するデータセットを利用しました。 このデータセットでは、ブログの記事に対して以下の4つの分類がされています。

  • グルメ
  • 携帯電話
  • 京都
  • スポーツ

Keras BERTで独自の学習済みモデルを使用するための準備

Keras BERTではGoogleが配布している学習済みモデルを利用する、もしくは自分でBERTモデルをトレーニングするのであれば特に特別な準備は必要ありませんが、今回は独自(Keras BERTに付属していない)の学習済みモデルを利用するためモデルのダウンロードやBERT用の設定ファイルを準備します。

SentencePiece + 日本語WikipediaのBERTモデルをダウンロードする

BERT with SentencePiece を日本語 Wikipedia で学習してモデルを公開しましたより、SentencePieceおよびBERT用の学習済みモデルをダウンロードします。
それぞれ以下のファイルをダウンロードします。

SentencePiece用のファイル

  • wiki-ja.vocab
  • wiki-ja.model

BERT用のファイル

  • model.ckpt-1400000.data-00000-of-00001
  • model.ckpt-1400000.index
  • model.ckpt-1400000.meta

bert_config.jsonを作成する

BERT用の各種パラメータはbert_config.jsonに記載します。Keras BERTではbert_config.jsonをConfigとして指定します。
記載する内容は、以下の通りです(bert-japaneseconfig.iniよりBERT関連を抜粋しています)。

{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 32000
}

ソースコード

BERTのロード

import sys
sys.path.append('modules')
from keras_bert import load_trained_model_from_checkpoint

config_path = 'bert-wiki-ja/bert_config.json'
# `model.ckpt-1400000` のように拡張子を付けないのがポイントです。
checkpoint_path = 'bert-wiki-ja/model.ckpt-1400000'

bert = load_trained_model_from_checkpoint(config_path, checkpoint_path)
bert.summary()

データロード用関数

BERTで特徴量を抽出しています。

import pandas as pd
import sentencepiece as spm
from keras import utils
import logging
import numpy as np

maxlen = 512
bert_dim = 768

sp = spm.SentencePieceProcessor()
sp.Load('bert-wiki-ja/wiki-ja.model')

def _get_vector(feature):
    common_seg_input = np.zeros((1, maxlen), dtype = np.float32)
    indices = np.zeros((1, maxlen), dtype = np.float32)

    tokens = []
    tokens.append('[CLS]')
    tokens.extend(sp.encode_as_pieces(feature))
    tokens.append('[SEP]')

    for t, token in enumerate(tokens):
        try:
            indices[0, t] = sp.piece_to_id(token)
        except:
            logging.warn(f'{token} is unknown.')
            indices[0, t] = sp.piece_to_id('<unk>')
    vector =  bert.predict([indices, common_seg_input])[0]

    return vector

def _load_labeldata(train_dir, test_dir):
    train_features_df = pd.read_csv(f'{train_dir}/features.csv')
    train_labels_df = pd.read_csv(f'{train_dir}/labels.csv')
    test_features_df = pd.read_csv(f'{test_dir}/features.csv')
    test_labels_df = pd.read_csv(f'{test_dir}/labels.csv')
    label2index = {k: i for i, k in enumerate(train_labels_df['label'].unique())}
    index2label = {i: k for i, k in enumerate(train_labels_df['label'].unique())}
    class_count = len(label2index)
    train_labels = utils.np_utils.to_categorical([label2index[label] for label in train_labels_df['label']], num_classes=class_count)
    test_label_indices = [label2index[label] for label in test_labels_df['label']]
    test_labels = utils.np_utils.to_categorical(test_label_indices, num_classes=class_count)

    train_features = []
    test_features = []

    for feature in train_features_df['feature']:
        train_features.append(_get_vector(feature))
    for feature in test_features_df['feature']:
        test_features.append(_get_vector(feature))

    print(f'Trainデータ数: {len(train_features_df)}, Testデータ数: {len(test_features_df)}, ラベル数: {class_count}')

    return {
        'class_count': class_count,
        'label2index': label2index,
        'index2label': index2label,
        'train_labels': train_labels,
        'test_labels': test_labels,
        'test_label_indices': test_label_indices,
        'train_features': np.array(train_features),
        'test_features': np.array(test_features),
        'input_len': maxlen
    }

モデル準備関数

BERTで得た特徴量を入力としたBi-LSTMによるモデルです。

from keras.layers import Dense, Dropout, LSTM, Bidirectional
from keras import Input, Model

def _create_model(train_features):
    class_count = 4

    input_tensor = Input(train_features[0].shape)
    x1 = Bidirectional(LSTM(356))(input_tensor)
    output_tensor = Dense(class_count, activation='softmax')(x1)

    model = Model(input_tensor, output_tensor)
    model.compile(loss='categorical_crossentropy', optimizer='nadam', metrics=['mae', 'mse', 'acc'])

    return model

データのロードとモデルの準備

前述のようにデータロード時にBERTにより特徴量を抽出しています。
このため、実行には少々(GPUが無い場合はかなり?)時間がかかります。

trains_dir = '../word-or-character/data/trains'
tests_dir = '../word-or-character/data/tests'

data = _load_labeldata(trains_dir, tests_dir)
Trainデータ数: 3767, Testデータ数: 419, ラベル数: 4
data['train_features'][:5]
array([[[ 0.35466704,  0.8344387 , -0.44097826, ...,  0.2786668 ,
         -0.11497068,  0.29313257],
        [ 0.15172665,  0.02186944,  0.5719023 , ..., -0.34355962,
         -0.76192784, -0.37710115],
        [ 0.58773553, -0.5266533 ,  1.0297949 , ...,  0.2826115 ,
          0.01596622, -0.11598644],
        ...,
        [-0.42831585,  0.16217883,  0.24022198, ...,  0.35656708,
         -0.4026342 ,  0.90145355],
        [-0.4352011 ,  0.3489631 ,  0.3045481 , ...,  0.21556729,
         -0.18176533,  0.609959  ],
        [-0.36450043,  0.02214984,  0.2564181 , ...,  0.35734403,
         -0.01305788,  0.8493353 ]],

       [[ 0.31687105,  1.2244819 ,  0.36265972, ...,  0.2956537 ,
         -0.7329245 ,  0.3783138 ],
        [ 0.1442154 ,  0.169483  ,  0.4410671 , ..., -0.49457642,
         -0.855622  , -0.09040152],
        [ 0.6914083 ,  0.611323  ,  0.14708832, ..., -0.12922424,
         -0.19439547,  0.0302214 ],
        ...,
        [-0.38953468,  0.83356047,  0.22081861, ...,  0.21940875,
         -0.71583635,  0.9292939 ],
        [-0.37721068,  1.1795326 ,  0.29908165, ...,  0.18421009,
         -0.48757967,  0.8818958 ],
        [-0.20600995,  1.0697701 ,  0.21132138, ...,  0.60967827,
         -0.58680713,  0.75882226]],

       [[-0.10706001,  0.3738867 , -0.44477016, ...,  0.20106766,
          0.43196645, -0.40388373],
        [ 0.10844216, -0.16230595,  0.81597984, ..., -0.4139093 ,
         -0.94628114, -0.29458356],
        [ 1.0240474 , -0.6265489 ,  0.11975094, ..., -0.04260744,
          0.40971422, -0.1690548 ],
        ...,
        [-0.44445795, -0.18297565, -0.81684655, ..., -0.01337491,
          0.53387684,  0.8393724 ],
        [-0.5090573 , -0.3397976 , -0.86311024, ..., -0.15371192,
          0.33833283,  0.7278649 ],
        [-0.31896925, -0.41911578, -0.7302225 , ..., -0.19937491,
          0.30634403,  0.7484876 ]],

       [[-0.08156536,  1.3763438 , -0.4511324 , ..., -0.2718555 ,
          0.01139554,  0.3233358 ],
        [ 0.29251447,  0.016281  ,  0.3287114 , ..., -0.04428148,
         -0.527779  ,  0.12028977],
        [ 0.61508095, -0.10685639, -0.33700344, ...,  0.58882135,
          0.5536139 ,  0.06634241],
        ...,
        [-0.1810619 , -0.22033806, -0.37033078, ...,  0.03473432,
         -0.45422173,  1.4944264 ],
        [-0.40457633, -0.02369374, -0.30867767, ..., -0.07707401,
         -0.39675236,  1.509865  ],
        [-0.49272478, -0.3240378 ,  0.0129818 , ..., -0.05501541,
         -0.13940048,  1.6560441 ]],

       [[ 0.33917487, -0.45505336, -0.12048204, ..., -0.15798426,
         -0.31510183,  0.08813866],
        [ 0.01879137,  0.3558463 ,  0.74166125, ..., -1.1869664 ,
         -0.48231342, -0.57868123],
        [-0.04152466, -0.14042495,  0.47751656, ..., -0.8260575 ,
         -0.12649226,  0.49701825],
        ...,
        [ 0.16141744,  0.22437172, -0.06188362, ..., -0.3822586 ,
         -0.02379168,  0.22798373],
        [ 0.35819334,  0.11869927, -0.2132361 , ..., -0.29555017,
          0.06280891,  0.37545964],
        [ 0.20598263,  0.09404632, -0.13750747, ..., -0.29785928,
          0.10465414,  0.35825673]]], dtype=float32)

学習の実行

from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard

model_filename = 'models/knbc-check-bert_v2.model'

model = _create_model(data['train_features'])
model.summary()
history = model.fit(data['train_features'],
          data['train_labels'],
          epochs = 100,
          batch_size = 128,
          validation_data=(data['test_features'], data['test_labels']),
          shuffle=False,
          verbose = 1,
          callbacks = [
              EarlyStopping(patience=5, monitor='val_acc', mode='max'),
              ModelCheckpoint(monitor='val_acc', mode='max', filepath=model_filename, save_best_only=True)
          ])
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 512, 768)          0
_________________________________________________________________
bidirectional_1 (Bidirection (None, 712)               3204000
_________________________________________________________________
dense_1 (Dense)              (None, 4)                 2852
=================================================================
Total params: 3,206,852
Trainable params: 3,206,852
Non-trainable params: 0
_________________________________________________________________
Train on 3767 samples, validate on 419 samples
Epoch 1/100
3767/3767 [==============================] - 55s 15ms/step - loss: 1.1582 - mean_absolute_error: 0.2398 - mean_squared_error: 0.1260 - acc: 0.6294 - val_loss: 0.9369 - val_mean_absolute_error: 0.2007 - val_mean_squared_error: 0.1196 - val_acc: 0.6730
Epoch 2/100
3767/3767 [==============================] - 55s 15ms/step - loss: 0.5901 - mean_absolute_error: 0.1593 - mean_squared_error: 0.0782 - acc: 0.7648 - val_loss: 0.6725 - val_mean_absolute_error: 0.1509 - val_mean_squared_error: 0.0847 - val_acc: 0.7375
Epoch 3/100
3767/3767 [==============================] - 53s 14ms/step - loss: 0.5058 - mean_absolute_error: 0.1361 - mean_squared_error: 0.0674 - acc: 0.8044 - val_loss: 0.7014 - val_mean_absolute_error: 0.1513 - val_mean_squared_error: 0.0862 - val_acc: 0.7399
Epoch 4/100
3767/3767 [==============================] - 54s 14ms/step - loss: 0.3879 - mean_absolute_error: 0.1101 - mean_squared_error: 0.0522 - acc: 0.8535 - val_loss: 0.6635 - val_mean_absolute_error: 0.1347 - val_mean_squared_error: 0.0784 - val_acc: 0.7709
Epoch 5/100
3767/3767 [==============================] - 53s 14ms/step - loss: 0.2774 - mean_absolute_error: 0.0811 - mean_squared_error: 0.0368 - acc: 0.8967 - val_loss: 0.7371 - val_mean_absolute_error: 0.1338 - val_mean_squared_error: 0.0844 - val_acc: 0.7613
Epoch 6/100
3767/3767 [==============================] - 56s 15ms/step - loss: 0.2810 - mean_absolute_error: 0.0779 - mean_squared_error: 0.0364 - acc: 0.9055 - val_loss: 0.7168 - val_mean_absolute_error: 0.1222 - val_mean_squared_error: 0.0801 - val_acc: 0.7852
Epoch 7/100
3767/3767 [==============================] - 57s 15ms/step - loss: 0.1383 - mean_absolute_error: 0.0403 - mean_squared_error: 0.0164 - acc: 0.9562 - val_loss: 0.7089 - val_mean_absolute_error: 0.1261 - val_mean_squared_error: 0.0784 - val_acc: 0.7828
Epoch 8/100
3767/3767 [==============================] - 57s 15ms/step - loss: 0.1013 - mean_absolute_error: 0.0324 - mean_squared_error: 0.0123 - acc: 0.9700 - val_loss: 0.8280 - val_mean_absolute_error: 0.1213 - val_mean_squared_error: 0.0830 - val_acc: 0.7876
Epoch 9/100
3767/3767 [==============================] - 55s 15ms/step - loss: 0.0370 - mean_absolute_error: 0.0136 - mean_squared_error: 0.0038 - acc: 0.9915 - val_loss: 0.9463 - val_mean_absolute_error: 0.1124 - val_mean_squared_error: 0.0836 - val_acc: 0.7971
Epoch 10/100
3767/3767 [==============================] - 54s 14ms/step - loss: 0.0218 - mean_absolute_error: 0.0074 - mean_squared_error: 0.0023 - acc: 0.9944 - val_loss: 0.9268 - val_mean_absolute_error: 0.1112 - val_mean_squared_error: 0.0832 - val_acc: 0.7947
Epoch 11/100
3767/3767 [==============================] - 52s 14ms/step - loss: 0.0124 - mean_absolute_error: 0.0042 - mean_squared_error: 0.0013 - acc: 0.9965 - val_loss: 1.0381 - val_mean_absolute_error: 0.1047 - val_mean_squared_error: 0.0829 - val_acc: 0.8019
Epoch 12/100
3767/3767 [==============================] - 53s 14ms/step - loss: 0.0083 - mean_absolute_error: 0.0026 - mean_squared_error: 0.0010 - acc: 0.9971 - val_loss: 1.1083 - val_mean_absolute_error: 0.1052 - val_mean_squared_error: 0.0858 - val_acc: 0.8091
Epoch 13/100
3767/3767 [==============================] - 54s 14ms/step - loss: 0.0067 - mean_absolute_error: 0.0021 - mean_squared_error: 8.8009e-04 - acc: 0.9976 - val_loss: 1.1346 - val_mean_absolute_error: 0.1051 - val_mean_squared_error: 0.0861 - val_acc: 0.8019
Epoch 14/100
3767/3767 [==============================] - 51s 14ms/step - loss: 0.0057 - mean_absolute_error: 0.0018 - mean_squared_error: 8.3039e-04 - acc: 0.9973 - val_loss: 1.1624 - val_mean_absolute_error: 0.1028 - val_mean_squared_error: 0.0853 - val_acc: 0.7971
Epoch 15/100
3767/3767 [==============================] - 52s 14ms/step - loss: 0.0055 - mean_absolute_error: 0.0016 - mean_squared_error: 8.2199e-04 - acc: 0.9973 - val_loss: 1.1879 - val_mean_absolute_error: 0.1033 - val_mean_squared_error: 0.0862 - val_acc: 0.7947
Epoch 16/100
3767/3767 [==============================] - 52s 14ms/step - loss: 0.0049 - mean_absolute_error: 0.0014 - mean_squared_error: 7.9869e-04 - acc: 0.9973 - val_loss: 1.2090 - val_mean_absolute_error: 0.1033 - val_mean_squared_error: 0.0864 - val_acc: 0.7971
Epoch 17/100
3767/3767 [==============================] - 51s 14ms/step - loss: 0.0045 - mean_absolute_error: 0.0013 - mean_squared_error: 7.6631e-04 - acc: 0.9973 - val_loss: 1.2209 - val_mean_absolute_error: 0.1030 - val_mean_squared_error: 0.0864 - val_acc: 0.7947
df = pd.DataFrame(history.history)
display(df)
val_loss val_mean_absolute_error val_mean_squared_error val_acc loss mean_absolute_error mean_squared_error acc
0 0.936901 0.200690 0.119645 0.673031 1.158211 0.239806 0.125970 0.629413
1 0.672482 0.150946 0.084675 0.737470 0.590081 0.159289 0.078209 0.764800
2 0.701436 0.151316 0.086166 0.739857 0.505808 0.136076 0.067353 0.804354
3 0.663509 0.134704 0.078389 0.770883 0.387925 0.110072 0.052156 0.853464
4 0.737064 0.133766 0.084423 0.761337 0.277368 0.081109 0.036827 0.896735
5 0.716836 0.122190 0.080062 0.785203 0.281022 0.077929 0.036403 0.905495
6 0.708884 0.126144 0.078431 0.782816 0.138349 0.040343 0.016417 0.956199
7 0.828004 0.121335 0.082972 0.787590 0.101332 0.032350 0.012329 0.970003
8 0.946291 0.112425 0.083580 0.797136 0.037030 0.013608 0.003801 0.991505
9 0.926845 0.111206 0.083175 0.794749 0.021777 0.007377 0.002344 0.994425
10 1.038088 0.104662 0.082922 0.801909 0.012358 0.004177 0.001337 0.996549
11 1.108292 0.105200 0.085754 0.809069 0.008341 0.002552 0.001044 0.997080
12 1.134614 0.105058 0.086104 0.801909 0.006718 0.002110 0.000880 0.997611
13 1.162371 0.102757 0.085296 0.797136 0.005706 0.001753 0.000830 0.997345
14 1.187874 0.103300 0.086247 0.794749 0.005452 0.001592 0.000822 0.997345
15 1.209019 0.103252 0.086414 0.797136 0.004915 0.001417 0.000799 0.997345
16 1.220853 0.102980 0.086429 0.794749 0.004455 0.001282 0.000766 0.997345

クラシフィケーションレポート

from sklearn.metrics import classification_report, confusion_matrix
from keras.models import load_model
from keras_bert import get_custom_objects

model = load_model(model_filename)

predicted_test_labels = model.predict(data['test_features']).argmax(axis=1)
numeric_test_labels = np.array(data['test_labels']).argmax(axis=1)
print(classification_report(numeric_test_labels, predicted_test_labels, target_names = ['グルメ', '携帯電話', '京都', 'スポーツ']))
              precision    recall  f1-score   support

         グルメ       0.76      0.86      0.81       137
        携帯電話       0.85      0.82      0.84       145
          京都       0.73      0.68      0.70        47
        スポーツ       0.89      0.78      0.83        90

   micro avg       0.81      0.81      0.81       419
   macro avg       0.80      0.79      0.79       419
weighted avg       0.81      0.81      0.81       419

まとめ

F値で81と高い精度が得られました。これは同じコーパスを利用したWord2Vecの結果よりも高い数値であり、コーパスを追加したものを上回っています。

  • Wikipediaja with BERT(Weighted Avg F1): 0.81
  • Wikipediaja(Weighted Avg F1): 0.77 12
  • Wikipediaja+現代日本語書き言葉均衡コーパス(Weighted Avg F1): 0.79 12

ToDo

Keras BERT経由でFine Tuningを行うと結果が収束しなかったため、今回はFineTuningを行っていません。
さらに高い精度を得られる可能性があるため、原因を特定してFine Tuningした結果も確認したいと考えています。

2019/05/27 FineTunignしてみました。

BERTの構造

BERTはトレーニング時と推論時で入出力部分が少し異なります。以下は推論時のネットワークを可視化したものです。

from keras.utils import plot_model

plot_model(bert, to_file='train-bert.png', show_shapes=True)

参考文献

  1. SentencePiece+word2vecでコーパスによる差を確認してみるより。 2

  2. 独自にトレーニングしたWord2Vecを使用しています。 2

55
55
4

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
55
55

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?