概要
BERTについての簡単な解説と、実装方法をまとめてみました。
解説については、簡単な概要と入出力のみ説明します。もっと良く知りたい方は、元の論文や他の記事を参考にしてください。
論文:https://arxiv.org/abs/1810.04805
BERTの解説
概要
BERT(Bidirectional Encoder Representations from Transformers)とは何か。
簡単に述べると、自然言語処理の分野で使われるもので、入力として与えられた文を機械学習しやすい形に変換してくれるものです(こういったものをEncoderと言います)。
機械学習を行うためには、どのようなタスクであろうとモデルのへの入力は数値ベクトルでなければなりません。この役割をBERTが担ってくれるわけです。
文を数値ベクトルに変換してくれるものにはDoc2Vecなど他のモデルもありますが、BERTは今のところ最も優れたEncoderだと言われています。
また、Doc2Vecは同じ入力であれば同じベクトルを出力するのに対し、BERTは使うデータやタスクに応じてファインチューニングというパラメーター調整を行うことで自分だけのEncoderに変えることができます。
入出力
この図は、論文に載ってる図です。
BERTモデル自体としては[CLS], Tok1, ..., TokNが入力となりますが、これらのトークンを作成するために、元の文をトークン化する処理が必要となります。
また、[SEP]というトークンを挿入することで、2つの文を入力とすることもできます。
出力は、入力トークンと同じ数のトークンとなります。[CLS]に対応する出力Cは入力した文全体の情報を持っていると考えられており、文の分類タスクなどで使うことができます。
実装
BERTを使う方法を紹介します。PyTorch自体の使い方などはここでは紹介しません。
作業環境
項目 | バージョン等 |
---|---|
OS | mac Mojave |
python | 3.8.2 |
pyenv | 1.2.26 |
transformers | 4.12.2 |
必要ライブラリのインストール
pip install transformers[ja]
これで今回必要なライブラリをインストールできます。
ソースコード
#必要ライブラリのインポート
from transformers import BertJapaneseTokenizer, BertModel
#入力する文
INPUT = "我輩は猫である。"
#事前学習済みモデルの選択
pretrained = 'cl-tohoku/bert-base-japanese-whole-word-masking'
tokenizer = BertJapaneseTokenizer.from_pretrained(pretrained)
bert = BertModel.from_pretrained(pretrained)
#入力文のトークン化
input_token = tokenizer.encode_plus(
INPUT,
add_special_tokens = True,
truncation = True,
padding = "max_length",
return_tensors = "pt"
)
ids = input_token["input_ids"]
mask = input_token["attention_mask"]
#トークンをBERTに通した結果
output = bert(ids, mask)
print(output)
出力↓
BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.3607, 0.1658, 0.0195, ..., -0.1120, -0.3533, 0.0227],
[ 0.8294, 0.4226, 0.4477, ..., -0.6203, -0.4340, -0.4595],
[ 0.2881, 0.4584, -0.1200, ..., -0.4955, 0.2524, -0.5762],
...,
[-0.0026, 0.1373, 0.2016, ..., -0.3130, -0.1559, -0.2415],
[-0.0509, 0.4031, 0.0790, ..., -0.2865, -0.1257, -0.2302],
[-0.0594, 0.1091, -0.0197, ..., -0.2351, -0.1172, -0.2040]]],
grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[ 8.3770e-01, 9.5697e-02, -2.9548e-01, -3.1350e-01, 4.1436e-01,
(略)
こんな感じで出力は辞書型で出てきます。
last_hidden_stateは上で述べた出力トークン全ての情報を、pooler_outputは出力トークンのうちCの情報を持っています。
また、次のように書くことで辞書型ではなく、タプルで返すこともできます。
bert = BertModel.from_pretrained(pretrained)
↓
bert = BertModel.from_pretrained(pretrained, return_dict=False)
output = bert(ids, mask)
↓
last_hidden_state, pooler_output = bert(ids, mask)