はじめに
自然言語処理の様々なタスクでSOTAを更新しているBERTですが、Google本家がGithubで公開しているものはTensorflowをベースに実装されています。
PyTorch使いの人はPyTorch版を使いたいところですが、PyTorch版は作っていないのでHuggingFaceが作ったやつを使ってね、ただし我々は開発に関与していないので詳しいことは彼らに訊いてね!とQAに書かれています。
HuggingFace製のBERTですが、2019年12月までは日本語のpre-trained modelsがありませんでした。
そのため、英語では気軽に試せたのですが、日本語ではpre-trained modelsを自分で用意する必要がありました。
しかし、2019年12月についに日本語のpre-trained modelsが追加されました。
https://huggingface.co/transformers/pretrained_models.html
- bert-base-japanese
- bert-base-japanese-whole-word-masking
- bert-base-japanese-char
- bert-base-japanese-char-whole-word-masking
東北大学の乾研究室が作成したもので、4つのモデルが使えます。
特別な事情がなければ2番目の bert-base-japanese-whole-word-masking
を使うのがよいでしょう。
通常版とWhole Word Masking版では、Whole Word Masking版の方がfine tuningしたタスクの精度が少し高い傾向にあるようです1。
これにより、PyTorch版BERTを日本語でも手軽に試すことができるようになりました。
BERTとは?
BERTの仕組みは既に様々なブログや書籍で紹介されているので、詳細な説明は割愛します。
簡単に説明すると、
- 大量の教師なしコーパスからpre-trained modelsを作成
- Masked Language ModelとNext Sentence Predicitionの2種類の言語タスクを解くことで事前学習する
- pre-trained modelsをfine tuningしてタスクを解く
という処理の流れになります。
Pre-trained modelsの作成には大量のコンピュータ資源と時間を要しますが、pre-trained modelsを利用することで少量の教師データからでもタスクを高精度に解くことができるというのがBERTの一番のポイントです。
日本語Pre-trained models
まずは、事前学習した日本語pre-trained modelsの精度を確認します。
今回はMasked Language Modelの精度を確認します。
Masked Language Modelを簡単に説明すると、文の中のある単語をマスクしておき、そのマスクされた単語を予測するというものです。
BertJapaneseTokenizerとBertForMaskedLMを使い、次のように書くことができます。
「テレビでサッカーの試合を見る。」という文の「サッカー」をマスクして、その単語を予測するというものです。
import torch
from transformers import BertJapaneseTokenizer, BertForMaskedLM
# Load pre-trained tokenizer
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
# Tokenize input
text = 'テレビでサッカーの試合を見る。'
tokenized_text = tokenizer.tokenize(text)
# ['テレビ', 'で', 'サッカー', 'の', '試合', 'を', '見る', '。']
# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 2
tokenized_text[masked_index] = '[MASK]'
# ['テレビ', 'で', '[MASK]', 'の', '試合', 'を', '見る', '。']
# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# [571, 12, 4, 5, 608, 11, 2867, 8]
# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
# tensor([[ 571, 12, 4, 5, 608, 11, 2867, 8]])
# Load pre-trained model
model = BertForMaskedLM.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
model.eval()
# Predict
with torch.no_grad():
outputs = model(tokens_tensor)
predictions = outputs[0][0, masked_index].topk(5) # 予測結果の上位5件を抽出
# Show results
for i, index_t in enumerate(predictions.indices):
index = index_t.item()
token = tokenizer.convert_ids_to_tokens([index])[0]
print(i, token)
上記のプログラムの実行結果は次のようになります。
「サッカー」が3位に登場しており、他の単語も日本語として正しそうな結果になっています。
日本ではあまり馴染みのない「クリケット」や大リーグのチーム名が出てくるのは、Wikipediaのデータから事前学習したためだと考えられます。
0 クリケット
1 タイガース
2 サッカー
3 メッツ
4 カブス
以上より、pre-trained modelsが正しく事前学習されていることが確認できました。
次は、このpre-trained modelsをもとにfine tuningしてタスクを解きます。
Fine tuning with BERT
日本語のオリジナルデータで動くようにソースコードを修正
HuggingFaceのGitHubには、fine tuningしてタスクを解く例が幾つか載っています。
しかし、これらは英語のデータセットを対象にしたもので、日本語のデータセットを対象にしたものはありません2。
そこで、既存のソースコードを修正して、日本語のオリジナルデータでも動くようにします。
自然言語処理で基本的なタスクであるテキスト分類を想定し、GLUEのテキスト分類に使われているソースコードを対象とします。
そして、
の2つのプログラムを修正します。
注意
なお、これはgit clone
などでダウンロードしたファイルではなくて、インストール先のディレクトリのファイルを変更する必要があります。
例えば、venvを使っている場合、インストール先のディレクトリは[venvディレクトリ]/lib/python3.7/site-packages/transformers
のようになります。
1. transformers/data/processors/glue.py
学習データ(train.tsv)と検証データ(dev.tsv)を読み込む部分です。
次のようにglue_tasks_num_labels
、glue_processors
、glue_output_modes
にoriginal
というタスクを追加した上で、OriginalProcessor
というクラスを追加します。
glue_tasks_num_labels = {
"cola": 2,
"mnli": 3,
"mrpc": 2,
"sst-2": 2,
"sts-b": 1,
"qqp": 2,
"qnli": 2,
"rte": 2,
"wnli": 2,
"original": 2, # 追加
}
glue_processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mnli-mm": MnliMismatchedProcessor,
"mrpc": MrpcProcessor,
"sst-2": Sst2Processor,
"sts-b": StsbProcessor,
"qqp": QqpProcessor,
"qnli": QnliProcessor,
"rte": RteProcessor,
"wnli": WnliProcessor,
"original": OriginalProcessor, # 追加
}
glue_output_modes = {
"cola": "classification",
"mnli": "classification",
"mnli-mm": "classification",
"mrpc": "classification",
"sst-2": "classification",
"sts-b": "regression",
"qqp": "classification",
"qnli": "classification",
"rte": "classification",
"wnli": "classification",
"original": "classification", # 追加
}
class OriginalProcessor(DataProcessor):
"""Processor for the original data set."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence"].numpy().decode("utf-8"),
None,
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
# TSVファイルにヘッダー行がある場合はコメントアウトを外す
# if i == 0:
# continue
guid = "%s-%s" % (set_type, i)
text_a = line[0]
label = line[1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
学習データと検証データは、
- テキスト
- ラベル
の2列から成るTSVファイルを想定しています。
面白かった 0
楽しかった 0
退屈だった 1
悲しかった 1
満喫した 0
辛かった 1
上記のプログラムは2値分類を想定していますが、多値分類のときはラベルの数と値を適宜修正して下さい。
2. transformers/data/metrics/__init__.py
検証データを使って精度を算出する部分です。
次のように条件式で task_name == "original"
の場合を追加するだけです。
def glue_compute_metrics(task_name, preds, labels):
assert len(preds) == len(labels)
if task_name == "cola":
return {"mcc": matthews_corrcoef(labels, preds)}
elif task_name == "sst-2":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "mrpc":
return acc_and_f1(preds, labels)
elif task_name == "sts-b":
return pearson_and_spearman(preds, labels)
elif task_name == "qqp":
return acc_and_f1(preds, labels)
elif task_name == "mnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "mnli-mm":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "qnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "rte":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "wnli":
return {"acc": simple_accuracy(preds, labels)}
# 追加
elif task_name == "original":
return {"acc": simple_accuracy(preds, labels)}
else:
raise KeyError(task_name)
Fine tuningして分類問題を解く
日本語のオリジナルデータでも動くようになったので、あとはfine tuningして分類問題を解くだけです。
これは次のコマンドを実行するのみです。学習データと検証データのファイルはdata/original/
配下に入れておきます。
$ python examples/run_glue.py \
--data_dir=data/original/ \
--model_type=bert \
--model_name_or_path=cl-tohoku/bert-base-japanese-whole-word-masking \
--task_name=original \
--do_train \
--do_eval \
--output_dir=output/original
上記のコマンドを実行して問題なく終了すれば、次のようなログが出力されます。
accの値が1.0
となっており、検証データの2件が正しく分類できていることが分かります。
01/18/2020 17:08:39 - INFO - __main__ - Saving features into cached file data/original/cached_dev_bert-base-japanese-whole-word-masking_128_original
01/18/2020 17:08:39 - INFO - __main__ - ***** Running evaluation *****
01/18/2020 17:08:39 - INFO - __main__ - Num examples = 2
01/18/2020 17:08:39 - INFO - __main__ - Batch size = 8
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2.59it/s]
01/18/2020 17:08:40 - INFO - __main__ - ***** Eval results *****
01/18/2020 17:08:40 - INFO - __main__ - acc = 1.0
そして、output/original/
配下にモデルファイルが作成されていることが確認できます。
$ find output/original
output/original
output/original/added_tokens.json
output/original/tokenizer_config.json
output/original/special_tokens_map.json
output/original/config.json
output/original/training_args.bin
output/original/vocab.txt
output/original/pytorch_model.bin
output/original/eval_results.txt
おわりに
PyTorch版のBERTを使って日本語のテキスト分類をする方法を紹介しました。
他のソースコードも修正すれば、テキスト分類だけでなくテキスト生成や質問応答などのタスクも行うことができます。
これまでPyTorchを使ってBERTを日本語で動かすのはハードルが高かったですが、日本語のpre-trained modelsが公開されたことでそのハードルが非常に低くなったように思います。
是非、皆さんもPyTorch版のBERTを日本語のタスクで試して下さい。
参考記事
https://techlife.cookpad.com/entry/2018/12/04/093000
http://kento1109.hatenablog.com/entry/2019/08/23/092944