概要
先日、huggingfeceのtransformersで日本語学習済BERTが公式に使えるようになりました。
https://github.com/huggingface/transformers
おはようござえます、日本の友達
— Hugging Face (@huggingface) December 13, 2019
Hello, Friends from Japan 🇯🇵!
Thanks to @NlpTohoku, we now have a state-of-the-art Japanese language model in Transformers, `bert-base-japanese`.
Can you guess what the model outputs in the masked LM task below? pic.twitter.com/XIBUu7wrex
これまで、(transformersに限らず)公開されている日本語学習済BERTを利用するためには色々やることが多くて面倒でしたが、transformersを使えばかなり簡単に利用できるようになりました。
本記事では、transformersとPyTorch, torchtextを用いて日本語の文章を分類するclassifierを作成、ファインチューニングして予測するまでを行います。
間違っているところやより良いところがあったらぜひ教えて下さい。
また、本記事の実装は つくりながら学ぶ!PyTorchによる発展ディープラーニング
をとても参照しています。とても良い本です。
https://www.amazon.co.jp/dp/B07VPDVNKW
環境
- PC
- OS Ubuntu18.04 LTS
- RAM 32GB
- GPU nVIDIA GeForce RTX-2070
- Python
- python3.7.4
- torch==1.2.0
- torchtext==0.4.0
- transformers==2.2.2
補足 transformers==2.2.2について
2019/12/14時点では、 pip install transformers
で入るバージョンだとbert-japaneseは使えませんでした。
gitをcloneしてきてinstallすることで使えました。
いずれpipにも入ると思います。たぶん。
(2020/01/08 追記)
pypiに登録されている最新が2.3.0になっていることをコメントにてご教示いただきました (https://qiita.com/nekoumei/items/7b911c61324f16c43e7e#comment-c5fa1d89c91a4110b050 )
pip install transformers
で日本語BERT使えます!
実装されたモデルについて
東北大学 乾・鈴木研究室の作成・公開されたBERTモデルだそうです。
https://github.com/cl-tohoku/bert-japanese
- 日本語Wikipediaを用いて学習
- tokenizerはMeCab + WordPiece(character tokenizationもある)
- max sequence lengthは512
詳しくは上記githubなりtransformersを参照してください。
触ってみる
本項では、transformersを利用するにあたって重要と思われる部分をかいつまんで説明します。
とにかく使いたいんだが?って方は最後に自作classifierを載せているのでそっちを見てください。
tokenizerについて
tokenizerが用意されています。自分でMeCabを用意する必要はありません。良い。
MeCab+WordPiece or character tokenization, 通常 or whole word masking の2*2で4種類ありますが、ここではMeCab+WordPiece, whole word maskingを使います。
from transformers import BertJapaneseTokenizer
tokenizer = BertJapaneseTokenizer.from_pretrained('bert-base-japanese-whole-word-masking')
tokenizer.tokenize('お腹が痛いので遅れます。')
# ['お', '##腹', 'が', '痛', '##い', 'ので', '遅れ', 'ます', '。']
また、上記 BertJapaneseTokenizer.from_pretrained('bert-base-japanese-whole-word-masking')
ではキャッシュとしてダウンロードされます。
ちゃんと保存したい場合、 tokenizer.save_pretrained('path/to/dir/')
をしてあげることで、指定したdirにvocab.txt, special_tokens_map.json, added_tokens.json が保存されます。
保存したファイルを使用する場合は BertJapaneseTokenizer.from_pretrained('path/to/dir')
でOKです。
BertForSequenceClassificationについて
モデルダウンロード
今回は文書分類をしたいので BertForSequenceClassification
を使います。これは普通のBERTモデルの最後にclassifierユニットが接続されています。
from transformers import BertForSequenceClassification
net = BertForSequenceClassification.from_pretrained('bert-base-japanese-whole-word-masking', num_labels=9)
print(net.classifier)
# Linear(in_features=768, out_features=9, bias=True)
from_pretrained
時にnum_labelsを指定してあげることで、任意のクラス数の分類器にできます。(デフォルトは2クラス分類器)便利ですね。
1クラス分類器にすると回帰タスクに対応できそうです。試してないです。
tokenizerと同様、キャッシュダウンロードになるので、保存したい場合は下記のようにしてください。
net.save_pretrained('path/to/dir') # save
net = BertForSequenceClassification.from_pretrained('path/to/dir') # load
モデルのreturnについて
面白いのは、modelにinputs, labelsを入れるとreturnが (loss, logit)
のtupleになっていることです。
# dataloaderから1つバッチを取り出して2クラス分類をしてみる。
# train_dl はDataLoader
batch = next(iter(train_dl))
inputs = batch.Text[0] # 文章
labels = batch.Label # ラベル
loss, logit = net(input_ids=inputs, labels=labels)
print(loss)
# tensor(0.7030, grad_fn=<NllLossBackward>)
print(logit)
# tensor([[-0.0195, 0.0448],
# [ 0.0459, 0.0977],
# [ 0.1200, 0.1295],
# [ 0.1090, 0.0590]], grad_fn=<AddmmBackward>)
logitはsoftmax噛ますことで使いやすくなりますね。
自作classifierをつくった
https://github.com/nekoumei/DocumentClassificationUsingBERT-Japanese
本題ですが、コードは長いのでgithubにあげておきました。
exampleのNotebookも載せています。皆大好きlivedoorニュースコーパスの9クラス分類です。
3行で日本語BERTをfine tuning, predictできるようになりました。
model = clf.DocumentClassifier(num_labels=9, num_epochs=100)
model.fit(train_df, val_df, early_stopping_rounds=10)
y_proba = model.predict(val_df)
transformersとは関係ないんですが、torchtextは現在、ファイルからの読込しか対応していません。
そこで、下記stackoverflowで回答されていたpandas DataFrameからtorchtextで読み込めるクラスを用いています。
https://stackoverflow.com/questions/52602071/dataframe-as-datasource-in-torchtext
KFold CVがやりやすくなりますね。
torchtextでもopen issueのようですが、まだ公式ではできないようです。
https://github.com/pytorch/text/issues/488
終わりに
もっとこうしたほうがいいよ!とか、種々のマサカリがあると思うんで、Twitterとかで指摘もらえると助かります。
https://twitter.com/nekoumei