10
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

UTH-BERT を Tensorflow2.X / Keras BERT から利用して文書分類を行う

Posted at

はじめに

東京大学の医療AI開発学講座が日本語の医療テキストで事前学習したBERTモデルであるUTH-BERTを2020年3月に公開してくださいました。

今回は、このモデルをKeras BERTを用いてファインチューニングして文書分類を行う手順を示したいと思います。

本記事の内容はGoogle Colab上で実行することを想定しています。特に前準備は必要なくそのまま実行できるようにしてありますが、リソースの割り当てによってはメモリやストレージが不足する可能性がありますのでご了承ください。

本記事の Google Colabノートブック(GitHub)

私自身、TensorflowやBERTについてはまだまだ勉強中ですので、間違いや改善できる点などがありましたらご指摘いただけると幸いです!

参考にしたサイト

本記事の作成に当たっては、主に以下のサイトを参考にさせていただきました。ありがとうございます。

制限事項

UTH-BERTは日本語医療テキストに適用することを目的としたモデルですが、簡単に入手できて一般公開できる日本語医療テキストが見当たりませんので、今回は医療テキストではなくブログで構成されたデータセットであるKNBコーパス(KNBC)を対象にしています。そのため、今回の結果はUTH-BERTの本来の性能を発揮できていないことをご了承ください。

KNBCは京都大学情報学研究科--NTTコミュニケーション科学基礎研究所 共同研究ユニットが公開している解析済みブログコーパスで、4テーマ(京都観光、携帯電話、スポーツ、グルメ)の249記事、4,186文を含んでいます。

文書分類タスクの概要

上述の通り、KNBCは4テーマの249記事、4,186文を含んでいますが、今回の文書分類タスクは一つ一つの文がどのテーマに属しているのかを予測するものとなっています。

Google Colab 上での実装

pythonライブラリのインストールとインポート

!pip install keras_bert
!pip install mecab-python3
!pip install jaconv
!pip install neologdn

import os, sys
import tensorflow as tf
import pandas as pd
import numpy as np
import keras_bert
  • Tensorflowのバージョンは2.3.0
  • Keras BERTのバージョンは0.86.0

で動作確認しています。

Keras BERT はデフォルトではスタンドアロンのKerasの使用を想定していますが、今回はtensorflow.kerasを使用するので環境変数を設定します。

# Keras BERT で tf.keras を使用するための環境変数を設定
os.environ['TF_KERAS'] = '1'

その他の必要なライブラリ等の準備

UTH-BERTではテキストの分割処理にMeCabを用いています。MeCabの辞書としてNEologd万病辞書が使われていますので、それらも一緒にダウンロード・インストールします。

テキストの前処理と分割処理についてはUTH-BERTのサイトSource code for pre-processing text and tokenizationに方法と必要なプログラムがあります。

なお、公式ではNEologdのインストール時に -a オプション(全ての追加辞書をインストール)が付いていませんが、付けた方が正しい分割結果になるようです。

なお、UTH-BERTとNEologdのデータサイズが大きいため、10分近くかかる場合があります。

# UTH-BERT
!wget https://ai-health.m.u-tokyo.ac.jp/labweb/dl/uth_bert/UTH_BERT_BASE_MC_BPE_V25000_10M.zip
!unzip UTH_BERT_BASE_MC_BPE_V25000_10M.zip
!rm UTH_BERT_BASE_MC_BPE_V25000_10M.zip
!git clone https://github.com/jinseikenai/uth-bert.git

# MeCab & NEologd
!apt install mecab libmecab-dev mecab-ipadic-utf8 file
!git clone --depth 1 https://github.com/neologd/mecab-ipadic-neologd.git
!mecab-ipadic-neologd/bin/install-mecab-ipadic-neologd -a -y # 公式では -a オプションはついていないが多分必要
os.environ['MECABRC'] = "/etc/mecabrc" # 環境変数でmecabrcの場所を指定

# 万病辞書
!wget http://sociocom.jp/~data/2018-manbyo/data/MANBYO_201907_Dic-utf8.dic

# KNBコーパス
!wget http://nlp.ist.i.kyoto-u.ac.jp/kuntt/KNBC_v1.0_090925_utf8.tar.bz2
!tar -jxvf KNBC_v1.0_090925_utf8.tar.bz2
!rm KNBC_v1.0_090925_utf8.tar.bz2

ディレクトリ・ファイルへのパスを設定

# データセット
knbc_dir_path = 'KNBC_v1.0_090925_utf8'
gourmet_tsv_file_path = os.path.join(knbc_dir_path, 'corpus2/Gourmet.tsv')
keitai_tsv_file_path = os.path.join(knbc_dir_path, 'corpus2/Keitai.tsv')
kyoto_tsv_file_path = os.path.join(knbc_dir_path, 'corpus2/Kyoto.tsv')
sports_tsv_file_path = os.path.join(knbc_dir_path, 'corpus2/Sports.tsv')

# 訓練済みモデル
pretrained_model_dir_path = 'UTH_BERT_BASE_MC_BPE_V25000_10M'
pretrained_bert_config_file_path = os.path.join(pretrained_model_dir_path, 'bert_config.json')
pretrained_model_checkpoint_path = os.path.join(pretrained_model_dir_path, 'model.ckpt-10000000') # 拡張子不要
pretrained_vocab_file_path = os.path.join(pretrained_model_dir_path, 'vocab.txt')

# 今回学習するモデル
!mkdir train_model
train_model_dir_path = 'train_model'
train_bert_config_file_path = os.path.join(train_model_dir_path, 'train_bert_config.json')
train_model_checkpoint_path = os.path.join(train_model_dir_path, 'train_model.ckpt')

# NEologd辞書ディレクトリへのパス
import subprocess
cmd = 'echo `mecab-config --dicdir`"/mecab-ipadic-neologd"'
neologd_dic_dir_path = subprocess.check_output(cmd, shell=True).decode('utf-8').strip()

# 万病辞書へのパス
manbyo_dic_path = 'MANBYO_201907_Dic-utf8.dic'

KNBコーパスから今回用いるデータを抽出

# 各カテゴリのtsvファイルを読み込んでラベル列を追加
df_gourmet = pd.read_table(gourmet_tsv_file_path, header=None)
df_gourmet['label'] = 'グルメ'

df_keitai = pd.read_table(keitai_tsv_file_path, header=None)
df_keitai['label'] = '携帯電話'

df_kyoto = pd.read_table(kyoto_tsv_file_path, header=None)
df_kyoto['label'] = '京都観光'

df_sports = pd.read_table(sports_tsv_file_path, header=None)
df_sports['label'] = 'スポーツ'

# 結合して必要な列だけ抽出
df_dataset = pd.concat([df_gourmet, df_keitai, df_kyoto, df_sports])[[1, 'label']]
df_dataset.columns = ['text', 'label'] # ラベル名変更
df_dataset

出力

text label
0 [グルメ]烏丸六角のおかき屋さん グルメ
1 六角堂の前にある、蕪村庵というお店に行ってきた。 グルメ
2 おかきやせんべいの店なのだが、これがオイシイ。 グルメ
3 のれんをくぐると小さな庭があり、その先に町屋風の店内がある。 グルメ
... ... ...
519 男性諸君、このこと忘れないでやぁ(>◆<) スポーツ
520 まぁ。。。 スポーツ
521 女はいろいろ強いし、怖いけどね笑 スポーツ

学習用データとテスト用データに分割

今回は4,186文のうち500をテスト用データとし、残りを学習用データとしました。

from sklearn.model_selection import train_test_split

df_train, df_test = train_test_split(df_dataset, test_size=500)

前処理用プログラムの動作確認

まず、Tensorflow 2.X では tokenization_mod.py の tf.gfile.GFile を tf.io.gfile.GFile に変更しないとエラーになるので書き換えます。

# sed コマンドで tf.gfile.GFile を tf.io.gfile.GFile に置換
!sed -i-e 's/tf\.gfile\.GFile/tf\.io\.gfile\.GFile/g' ./uth-bert/tokenization_mod.py

https://github.com/jinseikenai/uth-bertexample_main.py の内容が実行できることを確認します。

パスを今回の環境に合わせて書き換えています。

sys.path.append('uth-bert')
from preprocess_text import preprocess as my_preprocess
from tokenization_mod import MecabTokenizer, FullTokenizerForMecab

if __name__ == '__main__':

    # special token for a Person's name (Do not change)
    name_token = "@@N"

    # path to the mecab-ipadic-neologd
    #mecab_ipadic_neologd = '/usr/lib/mecab/dic/mecab-ipadic-neologd' # 変更
    mecab_ipadic_neologd = neologd_dic_dir_path

    # path to the J-Medic (We used MANBYO_201907_Dic-utf8.dic)
    #mecab_J_medic = './MANBYO_201907_Dic-utf8.dic' # 変更
    mecab_J_medic = manbyo_dic_path

    # path to the uth-bert vocabulary
    #vocab_file = "./bert_vocab_mc_v1_25000.txt" # 変更
    vocab_file = pretrained_vocab_file_path

    # MecabTokenizer
    sub_tokenizer = MecabTokenizer(mecab_ipadic_neologd=mecab_ipadic_neologd,
                                   mecab_J_medic=mecab_J_medic,
                                   name_token=name_token)

    # FullTokenizerForMecab
    tokenizer = FullTokenizerForMecab(sub_tokenizer=sub_tokenizer,
                                      vocab_file=vocab_file,
                                      do_lower_case=False)

    # pre-process and tokenize example
    original_text = "2002 年夏より重い物の持ち上げが困難になり,階段の昇りが遅くなるなど四肢の筋力低下が緩徐に進行した.2005 年 2 月頃より鼻声となりろれつが回りにくくなった.また,食事中にむせるようになり,同年 12 月に当院に精査入院した。"
    print ('元のテキスト\n', original_text, end='\n\n')

    pre_processed_text = my_preprocess(original_text)
    print ('前処理後テキスト\n', pre_processed_text, end='\n\n')

    output_tokens = tokenizer.tokenize(pre_processed_text)
    print ('トークン化後のテキスト\n', output_tokens)

出力

・Original text
 2002 年夏より重い物の持ち上げが困難になり,階段の昇りが遅くなるなど四肢の筋力低下が緩徐に進行した.2005 年 2 月頃より鼻声となりろれつが回りにくくなった.また,食事中にむせるようになり,同年 12 月に当院に精査入院した。

・After pre-processing
 2002年夏より重い物の持ち上げが困難になり、階段の昇りが遅くなるなど四肢の筋力低下が緩徐に進行した.2005年2月頃より鼻声となりろれつが回りにくくなった.また、食事中にむせるようになり、同年12月に当院に精査入院した。

・After tokenization
 ['2002年', '夏', 'より', '重い', '物', 'の', '持ち上げ', 'が', '困難', 'に', 'なり', '、', '階段', 'の', '[UNK]', 'が', '遅く', 'なる', 'など', '四肢', 'の', '筋力低下', 'が', '緩徐', 'に', '進行', 'し', 'た', '.', '2005年', '2', '月頃', 'より', '鼻', '##声', 'と', 'なり', 'ろ', '##れ', '##つ', 'が', '回り', '##にく', '##く', 'なっ', 'た', '.', 'また', '、', '食事', '中', 'に', 'むせる', 'よう', 'に', 'なり', '、', '同年', '12月', 'に', '当', '院', 'に', '精査', '入院', 'し', 'た', '。']

公式のExampleと同じ結果になりました。

前処理を実行

Pre-processingとTokenizationを学習用データとテスト用データに適用

Pre-processing

def preprocess_text(s):
    result = []
    for text in s:
        result.append(my_preprocess(text))

    return result

train_text_preprocessed = preprocess_text(df_train['text'])
test_text_preprocessed = preprocess_text(df_test['text'])

# 先頭3つのデータを表示
for i in range(3):
    print(train_text_preprocessed[i])

出力

それじゃあまた。
これもバスフリークの人や電車マニアのひとには申し訳ないが、好きじゃない。
また「受信中」と表示されている・・・・

Tokenization

  • Tokenization時に[CLS] と [SEP]を付加します。
  • 処理内容はtokenization_mod.py を参照ください。
def tokenize_text(s):
    result = []
    for text in s:
        result.append(['[CLS]'] + tokenizer.tokenize(text) + ['[SEP]'])

    return result

train_text_tokenized = tokenize_text(df_train['text'])
test_text_tokenized = tokenize_text(df_test['text'])

# 先頭3つのデータを表示
for i in range(3):
    print(train_text_tokenized[i])

実行結果

['[CLS]', 'それ', 'じゃ', 'あ', '##また', '。', '[SEP]']
['[CLS]', 'これ', 'も', 'バス', 'フリー', '##ク', 'の', '人', 'や', '電車', 'マ', '##ニア', 'の', 'ひと', 'に', 'は', '申し訳ない', 'が', '、', '好き', 'じゃ', 'ない', '。', '[SEP]']
['[CLS]', 'また', '「', '受', '##信', '中', '」', 'と', '表示', 'さ', 'れ', 'て', 'いる', '・', '・', '・', '・', '[SEP]']

入力データの最大長を算出

学習用モデルの定義を行う際に入力データの最大長が必要となるので、学習データ・テストデータの最大長を調べ、大きい方を採用します。

# ファインチューニングする場合は学習データ・テストデータの最大長を用いる
maxlen = 0

for tokens in train_text_tokenized:
    maxlen = max(maxlen, len(tokens))

for tokens in test_text_tokenized:
    maxlen = max(maxlen, len(tokens))

maxlen

出力

140

テキストをトークン列から単語ID列に変換

def tokens_to_ids(tokenized_text):
    result = []
    for tokens in tokenized_text:
        result.append(tokenizer.convert_tokens_to_ids(tokens))

    return result

train_text_ids = tokens_to_ids(train_text_tokenized)
test_text_ids = tokens_to_ids(test_text_tokenized)

# 先頭3つのデータを表示
for i in range(3):
    print(train_text_ids[i])

出力

[2, 540, 949, 1243, 13978, 6, 3]
[2, 539, 33, 9568, 943, 2451, 10, 693, 110, 6974, 4104, 16749, 10, 14870, 14, 15, 8100, 18, 5, 2661, 949, 42, 6, 3]
[2, 177, 254, 21444, 9822, 87, 253, 27, 10578, 62, 61, 13, 39, 41, 41, 41, 41, 3]

単語ID列をリストからnumpy arrayに変換

def list_to_numpy_array(input_list, maxlen):
    result = np.zeros((len(input_list), maxlen), dtype=np.int32)

    for i in range(len(input_list)):
        for j in range(len(input_list[i])):
            result[i][j] = input_list[i][j]

    return result

X_train = list_to_numpy_array(train_text_ids, maxlen)
X_test = list_to_numpy_array(test_text_ids, maxlen)

# 先頭3つのデータを表示
X_train[:3]

出力

array([[    2,   540,   949,  1243, 13978,     6,     3,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0],
       [    2,   539,    33,  9568,   943,  2451,    10,   693,   110,
         6974,  4104, 16749,    10, 14870,    14,    15,  8100,    18,
            5,  2661,   949,    42,     6,     3,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0],
       [    2,   177,   254, 21444,  9822,    87,   253,    27, 10578,
           62,    61,    13,    39,    41,    41,    41,    41,     3,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0]], dtype=int32)

ラベルデータを入力形式に変換

kerasのクラス分類のラベルデータはOne-hotエンコーディング

# ラベル -> インデックス の対応
label2index = {k: i for i, k in enumerate(df_train['label'].unique())}

# インデックス -> ラベル の対応
index2label = {i: k for i, k in enumerate(df_train['label'].unique())}

# ラベルの分類クラス数
class_count = len(label2index)
print('class count = ', class_count)

# One-hot encoding
y_train = tf.keras.utils.to_categorical([label2index[label] for label in df_train['label']], num_classes=class_count)
y_test = tf.keras.utils.to_categorical([label2index[label] for label in df_test['label']], num_classes=class_count)

# 先頭3つのデータを表示
print(y_train[:3])

出力

class count =  4
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]]

コンフィグファイルの準備

訓練済みモデルのBERTコンフィグファイルをロード

import json
from pprint import pprint as pp

json_open = open(pretrained_bert_config_file_path, 'r')
pretrained_bert_config =json.load(json_open)

pp(pretrained_bert_config)

出力

{'attention_probs_dropout_prob': 0.1,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.1,
 'hidden_size': 768,
 'initializer_range': 0.02,
 'intermediate_size': 3072,
 'max_position_embeddings': 512,
 'num_attention_heads': 12,
 'num_hidden_layers': 12,
 'type_vocab_size': 2,
 'vocab_size': 25000}

学習用BERTコンフィグファイルを作成

max_position_embedding と max_seq_length を今回使用するデータのトークンの最大長に設定します。

その他のパラメータは訓練済みモデルと同じです。

train_bert_config = pretrained_bert_config
train_bert_config['max_position_embeddings'] = maxlen
train_bert_config['max_seq_length'] = maxlen

pp(train_bert_config)

# jsonファイルとして保存
with open(train_bert_config_file_path, 'w') as f:
    json.dump(train_bert_config, f, indent=4)

出力

{'attention_probs_dropout_prob': 0.1,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.1,
 'hidden_size': 768,
 'initializer_range': 0.02,
 'intermediate_size': 3072,
 'max_position_embeddings': 140,
 'max_seq_length': 140,
 'num_attention_heads': 12,
 'num_hidden_layers': 12,
 'type_vocab_size': 2,
 'vocab_size': 25000}

学習

学習パラメータの設定

BERTはかなり大きいモデルなので、バッチサイズを大きくするとメモリ不足になりやすいです。今回は小さめに設定しています。

BATCH_SIZE = 16
EPOCHS = 20
LR = 1e-4

訓練済みBERTモデルのロード

from keras_bert import load_trained_model_from_checkpoint
bert = load_trained_model_from_checkpoint(train_bert_config_file_path, pretrained_model_checkpoint_path, training=True, trainable=True, seq_len=maxlen)
bert.summary()

(出力は省略)

学習用モデルの定義

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, LSTM, Bidirectional
from keras_bert import AdamWarmup, calc_train_steps

def _create_model(bert, maxlen, class_count):
    bert_last = bert.get_layer(name='NSP-Dense').output #  NSP-Denseを指定する理由は要確認
    output_tensor = Dense(class_count, activation='softmax')(bert_last)

    model = Model([bert.input[0], bert.input[1]], output_tensor)

    decay_steps, warmup_steps = calc_train_steps(
        maxlen,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
    )

    # optimizer='nadam' では収束しないのでAdamWarmupを用いる
    model.compile(
        loss='categorical_crossentropy',
        optimizer=AdamWarmup(decay_steps=decay_steps, warmup_steps=warmup_steps, lr=LR),
        metrics=['acc', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()]
    )

    return model

model = _create_model(bert, maxlen, class_count)
model.summary()

(出力は省略)

学習を実行

Epoch数を 20 に設定していますが、Early Stoppingにより 6 Epochで終了しています。

学習に用いるデータ量が少ないせいか、1 Epoch で収束してしまいました。Learning Rateを調整するなどしてみましたが、デフォルト値よりいい結果は得られませんでした。後述するように学習そのものはそれなりに上手くいっているようですが、もっといい結果が得られる調整方法がありましたらコメントをいただけると幸いです。

実行時間は割り当てられたGPUによりますが、Tesla T4 で 17分程度でした。

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

history = model.fit(
    [X_train, np.zeros_like(X_train)],
    y_train,
    epochs = EPOCHS,
    batch_size = BATCH_SIZE,
    validation_split=0.1,
    shuffle=True,
    verbose = 1,
    callbacks = [
        EarlyStopping(patience=5, monitor='val_acc', mode='max'),
        ModelCheckpoint(monitor='val_acc', mode='max', filepath=train_model_checkpoint_path, save_best_only=True)
    ]
)

出力

Epoch 1/20
208/208 [==============================] - ETA: 0s - loss: 0.7907 - acc: 0.6856 - precision: 0.8078 - recall: 0.5653WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: train_model/train_model.ckpt/assets
208/208 [==============================] - 295s 1s/step - loss: 0.7907 - acc: 0.6856 - precision: 0.8078 - recall: 0.5653 - val_loss: 0.6089 - val_acc: 0.7751 - val_precision: 0.8690 - val_recall: 0.6829
Epoch 2/20
208/208 [==============================] - 252s 1s/step - loss: 0.4586 - acc: 0.8459 - precision: 0.9280 - recall: 0.7498 - val_loss: 0.6089 - val_acc: 0.7751 - val_precision: 0.8690 - val_recall: 0.6829
Epoch 3/20
208/208 [==============================] - 251s 1s/step - loss: 0.4568 - acc: 0.8475 - precision: 0.9290 - recall: 0.7498 - val_loss: 0.6089 - val_acc: 0.7751 - val_precision: 0.8690 - val_recall: 0.6829
Epoch 4/20
208/208 [==============================] - 252s 1s/step - loss: 0.4596 - acc: 0.8478 - precision: 0.9298 - recall: 0.7465 - val_loss: 0.6089 - val_acc: 0.7751 - val_precision: 0.8690 - val_recall: 0.6829
Epoch 5/20
208/208 [==============================] - 252s 1s/step - loss: 0.4599 - acc: 0.8441 - precision: 0.9277 - recall: 0.7462 - val_loss: 0.6089 - val_acc: 0.7751 - val_precision: 0.8690 - val_recall: 0.6829
Epoch 6/20
208/208 [==============================] - 252s 1s/step - loss: 0.4613 - acc: 0.8472 - precision: 0.9282 - recall: 0.7446 - val_loss: 0.6089 - val_acc: 0.7751 - val_precision: 0.8690 - val_recall: 0.6829

学習曲線を表示

前述の通り、1 Epochで収束してしまっています。

import matplotlib.pyplot as plt

# Loss
plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.legend()

loss.png

# Accuracy
plt.plot(history.history['acc'], label='acc')
plt.plot(history.history['val_acc'], label='val_acc')
plt.legend()

acc.png

テストデータを用いて予測

from tensorflow.keras.models import load_model
from keras_bert import get_custom_objects

# 学習中に保存した最良モデルをロード (tensorflow2.1ではエラーになるので注意)
model = load_model(train_model_checkpoint_path, custom_objects=get_custom_objects())

# 予測を実行
y_test_pred_proba = model.predict([X_test, np.zeros_like(X_test)])

# 先頭3つのデータを表示
y_test_pred_proba[:3]

出力

array([[0.01833974, 0.9391952 , 0.01062312, 0.03184191],
       [0.08920565, 0.43004167, 0.30348903, 0.17726359],
       [0.00904089, 0.9884781 , 0.00123674, 0.00124414]], dtype=float32)

結果レポート

from sklearn.metrics import classification_report, confusion_matrix

y_test_true_labels = y_test.argmax(axis=1) # Probability -> index
y_test_pred_labels = y_test_pred_proba.argmax(axis=1) # One-hot -> index

target_names = [index2label[i] for i in range(class_count)]
rep = classification_report(y_test_true_labels, y_test_pred_labels, target_names=target_names, output_dict=True)

pd.DataFrame(rep)

出力

携帯電話 グルメ 京都観光 スポーツ accuracy macro avg weighted avg
precision 0.797386 0.896552 0.736041 0.587302 0.764 0.75432 0.773659
recall 0.7625 0.728972 0.814607 0.672727 0.764 0.744701 0.764
f1-score 0.779553 0.804124 0.773333 0.627119 0.764 0.746032 0.765829
support 160 107 178 55 0.764 500 500

F1-scoreは高いクラス(グルメ)で0.80, 低いクラス(スポーツ)で0.63となりました。

参考サイトの結果と比べると低い値になりましたが、UTH-BERTが医療テキストで訓練されたものであることを考えると、妥当な結果であると思われます。

おわりに

UTH-BERTを用いて簡単な文書分類を行う方法を紹介させていただきました。今回は公開できるデータの都合で医療テキストを対象としませんでしたが、医療テキストに対しては一般のテキストで学習したBERTモデルよりも高い性能が期待できるのではないかと思います。

貴重なモデルを公開してくださった東京大学の医療AI開発学講座の皆様に厚く御礼を申し上げます。

公式サイト

10
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?