3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

BERTについての簡単な解説と使い方

Posted at

概要

BERTについての簡単な解説と、実装方法をまとめてみました。
解説については、簡単な概要と入出力のみ説明します。もっと良く知りたい方は、元の論文や他の記事を参考にしてください。
論文:https://arxiv.org/abs/1810.04805

BERTの解説

概要

BERT(Bidirectional Encoder Representations from Transformers)とは何か。
簡単に述べると、自然言語処理の分野で使われるもので、入力として与えられた文を機械学習しやすい形に変換してくれるものです(こういったものをEncoderと言います)。

機械学習を行うためには、どのようなタスクであろうとモデルのへの入力は数値ベクトルでなければなりません。この役割をBERTが担ってくれるわけです。

文を数値ベクトルに変換してくれるものにはDoc2Vecなど他のモデルもありますが、BERTは今のところ最も優れたEncoderだと言われています。
また、Doc2Vecは同じ入力であれば同じベクトルを出力するのに対し、BERTは使うデータやタスクに応じてファインチューニングというパラメーター調整を行うことで自分だけのEncoderに変えることができます。

入出力

スクリーンショット 2021-11-16 14.04.47.png
この図は、論文に載ってる図です。
BERTモデル自体としては[CLS], Tok1, ..., TokNが入力となりますが、これらのトークンを作成するために、元の文をトークン化する処理が必要となります。
また、[SEP]というトークンを挿入することで、2つの文を入力とすることもできます。

出力は、入力トークンと同じ数のトークンとなります。[CLS]に対応する出力Cは入力した文全体の情報を持っていると考えられており、文の分類タスクなどで使うことができます。

実装

BERTを使う方法を紹介します。PyTorch自体の使い方などはここでは紹介しません。

作業環境

項目 バージョン等
OS mac Mojave
python 3.8.2
pyenv 1.2.26
transformers 4.12.2

必要ライブラリのインストール

pip install transformers[ja]

これで今回必要なライブラリをインストールできます。

ソースコード

#必要ライブラリのインポート
from transformers import BertJapaneseTokenizer, BertModel

#入力する文
INPUT = "我輩は猫である。"

#事前学習済みモデルの選択
pretrained = 'cl-tohoku/bert-base-japanese-whole-word-masking'
tokenizer = BertJapaneseTokenizer.from_pretrained(pretrained)
bert = BertModel.from_pretrained(pretrained)

#入力文のトークン化
input_token = tokenizer.encode_plus(
    INPUT,
    add_special_tokens = True,
    truncation = True,
    padding = "max_length",
    return_tensors = "pt"
    )
ids = input_token["input_ids"]
mask = input_token["attention_mask"]

#トークンをBERTに通した結果
output = bert(ids, mask)

print(output)

出力↓

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.3607,  0.1658,  0.0195,  ..., -0.1120, -0.3533,  0.0227],
         [ 0.8294,  0.4226,  0.4477,  ..., -0.6203, -0.4340, -0.4595],
         [ 0.2881,  0.4584, -0.1200,  ..., -0.4955,  0.2524, -0.5762],
         ...,
         [-0.0026,  0.1373,  0.2016,  ..., -0.3130, -0.1559, -0.2415],
         [-0.0509,  0.4031,  0.0790,  ..., -0.2865, -0.1257, -0.2302],
         [-0.0594,  0.1091, -0.0197,  ..., -0.2351, -0.1172, -0.2040]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[ 8.3770e-01,  9.5697e-02, -2.9548e-01, -3.1350e-01,  4.1436e-01,
(略)

こんな感じで出力は辞書型で出てきます。
last_hidden_stateは上で述べた出力トークン全ての情報を、pooler_outputは出力トークンのうちCの情報を持っています。

また、次のように書くことで辞書型ではなく、タプルで返すこともできます。

bert = BertModel.from_pretrained(pretrained)
↓
bert = BertModel.from_pretrained(pretrained, return_dict=False)


output = bert(ids, mask)
↓
last_hidden_state, pooler_output = bert(ids, mask)
3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?