60
39

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

huggingface/transformers (ver 4.5.0)で日本語BERTを動かすサンプルソースコード

Last updated at Posted at 2021-04-08

はじめに

huggingfaceのtransformersを使って、久しぶりに日本語BERTを動かそうと思ったら、昔書いたソースコードでは、あれよあれよとエラーが出るようになってしまっていました。transformersのバージョンを以前のもで指定すれば動くのですが、それってtransformersのバージョンアップについていけてないよねって思うので、現状(2021年4月8日時点)の最新版transformers(ver 4.5.0)で日本語BERTを動かせるソースコードをメモしておきます。

実装

Google Colabで動かします。データはlivedoorニュースコーパスを使って、ニュースコーパスの本文をカテゴリに分類するタスクを考えます。データは各自で用意お願いします。

ライブラリ準備

@polm23 様より有益なコメントをいただきました。日本語BERTを使うときのライブラリの準備は以下を実行するだけでOKです。

!pip install transformers[ja]

現在のtransformersmecabではなくfugashiを使っています。(fugashiMeCabのラッパー)
上を実行するだけで、fugashiも辞書も全部入ります。とても便利になりました。

torch(PyTorch)、torchtexttransformersfugashiipadicのバージョンは以下の通り

!pip list | grep torch
!pip list | grep transformers
!pip list | grep fugashi
!pip list | grep ipadic
#torch                         1.8.1+cu101   
#torchsummary                  1.5.1         
#torchtext                     0.9.1         
#torchvision                   0.9.1+cu101   
#transformers                  4.5.1         
#fugashi                       1.1.0         
#ipadic                        1.0.0           

データ準備


import numpy as np
import pandas as pd
import pickle
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchtext

# データ読み込み
with open('livedoor_data.pickle', 'rb') as r:
    livedoor_data = pickle.load(r)

# データ確認
display(livedoor_data.head())

# 正解ラベル(カテゴリー)をデータセットから取得
categories = list(set(livedoor_data['category']))
print(categories)

# カテゴリーのID辞書を作成
id2cat = dict(zip(list(range(len(categories))), categories))
cat2id = dict(zip(categories, list(range(len(categories)))))
print(id2cat)
print(cat2id)

# DataFrameにカテゴリーID列を追加
livedoor_data['category_id'] = livedoor_data['category'].map(cat2id)

# 念の為シャッフル
livedoor_data = livedoor_data.sample(frac=1).reset_index(drop=True)

# データセットを本文とカテゴリーID列だけにする
livedoor_data = livedoor_data[['body', 'category_id']]
display(livedoor_data.head())

データを学習データとテストデータに分ける


train_df, test_df = train_test_split(livedoor_data, train_size=0.8)
print("学習データサイズ", train_df.shape[0])
print("テストデータサイズ", test_df.shape[0])
#学習データサイズ 5900
#テストデータサイズ 1476

# tsvファイルとして保存する
train_df.to_csv('./train.tsv', sep='\t', index=False, header=None)
test_df.to_csv('./test.tsv', sep='\t', index=False, header=None)

torchtextでDataLoader作成

昔は日本語BERTを使うとき、

from transformers.modeling_bert import BertModel
from transformers.tokenization_bert_japanese import BertJapaneseTokenizer

ってしてたと思いますが、この宣言の仕方だと最新のtransformersでは使えないです。
日本語だけ特別扱い的な呼び出ししてるなーとは思ってましたが、英語のBERTとかと呼び出し方が統一されたって感じですかね。以下のようにシンプルにimportすればOKです。

あと、torchtexttorchtext.datatorchtext.legacy.dataという風に変更されています。(このtorchtextの使い方はもう古いってことかな?)


from transformers import BertModel
# from transformers import BertTokenizer
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

# BERTはサブワードを含めて最大512単語まで扱える
MAX_LENGTH = 512
def bert_tokenizer(text):
    return tokenizer.encode(text, max_length=MAX_LENGTH, truncation=True, return_tensors='pt')[0]

TEXT = torchtext.legacy.data.Field(sequential=True, tokenize=bert_tokenizer, use_vocab=False, lower=False,
                            include_lengths=True, batch_first=True, fix_length=MAX_LENGTH, pad_token=0)
LABEL = torchtext.legacy.data.Field(sequential=False, use_vocab=False)

train_data, test_data = torchtext.legacy.data.TabularDataset.splits(
    path='./', train='train.tsv', test='test.tsv', format='tsv', fields=[('Text', TEXT), ('Label', LABEL)])

# BERTではミニバッチサイズは16か32が推奨される
BATCH_SIZE = 16
train_iter, test_iter = torchtext.legacy.data.Iterator.splits((train_data, test_data), batch_sizes=(BATCH_SIZE, BATCH_SIZE), repeat=False, sort=False)

BertTokenizerとAutoTokenizerの違い

BertTokenizerとAutoTokenizerの違いについては、本記事のコメント欄をご参照ください。
@tomohideshibata 様から有益なコメントをいただいております!
以下ではBertTokenizerAutoTokenizerの挙動の違いを確認しています。

上のソースでtokenizerのimportをBertTokenizerではなくAutoTokenizerにしていますが、現状では(なぜか)BertTokenizerを使うと以下のようにN-Gramのような分割がされてしまうようです。とりあえずはAutoTokenizerを使っておけば問題なさそうです。

参考

from transformers import BertTokenizer
from transformers import AutoTokenizer

a_tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
b_tokenizer = BertTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

text = "人工知能は人間の仕事を奪った。"

ids = a_tokenizer.encode(text, return_tensors='pt')[0]
wakati = a_tokenizer.convert_ids_to_tokens(ids)
print(ids)
print(wakati)

ids = b_tokenizer.encode(text, return_tensors='pt')[0]
wakati = b_tokenizer.convert_ids_to_tokens(ids)
print(ids)
print(wakati)
# tensor([    2,  4969, 12588,     9,  1410,     5,  2198,    11, 11224,    10,
#             8,     3])
# ['[CLS]', '人工', '知能', 'は', '人間', 'の', '仕事', 'を', '奪っ', 'た', '。', '[SEP]']
# tensor([    2,    53,   461,   357,  1329,     9,    53,   284,     5,   757,
#           146,    11,  1847,  4046, 28447,     8,     3])
# ['[CLS]', '人', '工', '知', '能', 'は', '人', '間', 'の', '仕', '事', 'を', '奪', 'っ', '##た', '。', '[SEP]']

モデル定義

昔はAttention weightの取得や全BertLayerの隠れ層を取得するときは順伝播時にoutput_attentions=True, output_hidden_states=Trueを宣言してたかと思いますが、今は学習済みモデルをロードするときに宣言するようになったようです。

さらに、順伝播のoutputの形式も変わってます。現在は辞書形式で返ってくるので、以下のように必要な値はkeyを指定して取得することになります。

(一応、精度が出やすい最終4層の隠れ層を結合する方法でモデル組んでます。)


class BertClassifier(nn.Module):
    def __init__(self):
        super(BertClassifier, self).__init__()
        
        # 日本語学習済モデルをロードする
        # output_attentions=Trueで順伝播のときにattention weightを受け取れるようにする
        # output_hidden_state=Trueで12層のBertLayerの隠れ層を取得する
        self.bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking',
                                              output_attentions=True,
                                              output_hidden_states=True)
        
        # BERTの隠れ層の次元数は768だが、最終4層分のベクトルを結合したものを扱うので、768×4次元としている。
        self.linear = nn.Linear(768*4, 9)
        
        # 重み初期化処理
        nn.init.normal_(self.linear.weight, std=0.02)
        nn.init.normal_(self.linear.bias, 0)

    # clsトークンのベクトルを取得する用の関数を用意
    def _get_cls_vec(self, vec):
        return vec[:,0,:].view(-1, 768)

    def forward(self, input_ids):
        
        # 順伝播の出力結果は辞書形式なので、必要な値のkeyを指定して取得する
        output = self.bert(input_ids)
        attentions = output['attentions']
        hidden_states = output['hidden_states']
        
        # 最終4層の隠れ層からそれぞれclsトークンのベクトルを取得する
        vec1 = self._get_cls_vec(hidden_states[-1])
        vec2 = self._get_cls_vec(hidden_states[-2])
        vec3 = self._get_cls_vec(hidden_states[-3])
        vec4 = self._get_cls_vec(hidden_states[-4])
        
        # 4つのclsトークンを結合して1つのベクトルにする。
        vec = torch.cat([vec1, vec2, vec3, vec4], dim=1)
        
        # 全結合層でクラス分類用に次元を変換
        out = self.linear(vec)
        
        return F.log_softmax(out, dim=1), attentions

classifier = BertClassifier()

以降は特に昔のソースコードのままでも動きました。

ファインチューニングの設定


# まずは全部OFF
for param in classifier.parameters():
    param.requires_grad = False

# BERTの最終4層分をON
for param in classifier.bert.encoder.layer[-1].parameters():
    param.requires_grad = True

for param in classifier.bert.encoder.layer[-2].parameters():
    param.requires_grad = True

for param in classifier.bert.encoder.layer[-3].parameters():
    param.requires_grad = True

for param in classifier.bert.encoder.layer[-4].parameters():
    param.requires_grad = True

# クラス分類のところもON
for param in classifier.linear.parameters():
    param.requires_grad = True

# 事前学習済の箇所は学習率小さめ、最後の全結合層は大きめにする。
optimizer = optim.Adam([
    {'params': classifier.bert.encoder.layer[-1].parameters(), 'lr': 5e-5},
    {'params': classifier.bert.encoder.layer[-2].parameters(), 'lr': 5e-5},
    {'params': classifier.bert.encoder.layer[-3].parameters(), 'lr': 5e-5},
    {'params': classifier.bert.encoder.layer[-4].parameters(), 'lr': 5e-5},
    {'params': classifier.linear.parameters(), 'lr': 1e-4}
])

学習


# 損失関数の設定
loss_function = nn.NLLLoss()

# GPUの設定
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ネットワークをGPUへ送る
classifier.to(device)
losses = []

# エポック数は5で
for epoch in range(5):
    
    all_loss = 0
    
    for idx, batch in enumerate(train_iter):
        
        classifier.zero_grad()
        
        input_ids = batch.Text[0].to(device)
        label_ids = batch.Label.to(device)
        
        out, _ = classifier(input_ids)
        
        batch_loss = loss_function(out, label_ids)
        batch_loss.backward()
        
        optimizer.step()
        
        all_loss += batch_loss.item()

    print("epoch", epoch, "\t" , "loss", all_loss)

# epoch 0 	 loss 173.34355276590213
# epoch 1 	 loss 41.68860982079059
# epoch 2 	 loss 9.25834927111282
# epoch 3 	 loss 4.206327178646461
# epoch 4 	 loss 2.3201748140127165

推論


answer = []
prediction = []

with torch.no_grad():
    for batch in test_iter:

        text_tensor = batch.Text[0].to(device)
        label_tensor = batch.Label.to(device)

        score, _ = classifier(text_tensor)
        _, pred = torch.max(score, 1)

        prediction += list(pred.cpu().numpy())
        answer += list(label_tensor.cpu().numpy())

print(classification_report(prediction, answer, target_names=categories))

#                 precision    recall  f1-score   support

#   it-life-hack       0.95      0.94      0.95       184
#         peachy       0.93      0.91      0.92       164
# livedoor-homme       0.79      0.89      0.84        90
#    movie-enter       0.99      0.94      0.96       196
#     topic-news       0.95      0.98      0.97       150
# dokujo-tsushin       0.92      0.94      0.93       168
#   sports-watch       0.99      0.97      0.98       188
#           smax       0.96      0.96      0.96       161
#  kaden-channel       0.95      0.94      0.95       175

#       accuracy                           0.94      1476
#      macro avg       0.94      0.94      0.94      1476
#   weighted avg       0.95      0.94      0.94      1476
#                 precision    recall  f1-score   support


# ちなみに以下の精度はAutoTokenizerではなく、BertTokenizerからTokenizerをインポートした場合です。
# N-Gramのような分割だと、全体の精度はやや下がることが伺えます。
#   sports-watch       0.98      0.97      0.97       173
#    movie-enter       0.96      0.92      0.94       175
# dokujo-tsushin       0.88      0.93      0.90       161
#     topic-news       0.94      0.97      0.96       158
#   it-life-hack       0.92      0.89      0.90       185
#           smax       0.97      0.98      0.97       177
#  kaden-channel       0.91      0.88      0.90       181
#         peachy       0.88      0.84      0.86       166
# livedoor-homme       0.68      0.75      0.71       100

#       accuracy                           0.91      1476
#      macro avg       0.90      0.90      0.90      1476
#   weighted avg       0.91      0.91      0.91      1476

Attentionの可視化

def highlight(word, attn):
    html_color = '#%02X%02X%02X' % (255, int(255*(1 - attn)), int(255*(1 - attn)))
    return '<span style="background-color: {}">{}</span>'.format(html_color, word)

def mk_html(index, batch, preds, attention_weight):
    sentence = batch.Text[0][index]
    label =batch.Label[index].item()
    pred = preds[index].item()

    label_str = id2cat[label]
    pred_str = id2cat[pred]

    html = "正解カテゴリ: {}<br>予測カテゴリ: {}<br>".format(label_str, pred_str)

    # 文章の長さ分のzero tensorを宣言
    seq_len = attention_weight.size()[2]
    all_attens = torch.zeros(seq_len).to(device)

    for i in range(12):
        all_attens += attention_weight[index, i, 0, :]

    for word, attn in zip(sentence, all_attens):
        if tokenizer.convert_ids_to_tokens([word.tolist()])[0] == "[SEP]":
            break
        html += highlight(tokenizer.convert_ids_to_tokens([word.numpy().tolist()])[0], attn) + " "
    html += "<br><br>"
    return html

batch = next(iter(test_iter))
score, attentions = classifier(batch.Text[0].to(device))
_, pred = torch.max(score, 1)

from IPython.display import display, HTML
for i in range(BATCH_SIZE):
    html_output = mk_html(i, batch, pred, attentions[-1])
    display(HTML(html_output))

おわりに

上記の書き方はいつまで使えるかわかりません。
ライブラリのバージョンアップについていけるように精進します。

おわり

60
39
11

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
60
39

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?