4
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

LIMEを使ったBERTの判断根拠の可視化

Last updated at Posted at 2022-10-02

ニューラルネットワークの推論結果の説明にLIMEを使ってみたので、実装内容を紹介します。

LIMEはニューラルネットワークや勾配ブースティング木のような複雑なモデルによる1つ1つの推論結果を、(回帰モデルのような)解釈性の高いモデルを使って説明しようとする手法のようです。

詳しくは論文を読まれるとよいかと思います。(私も読み解いているところです)
https://www.kdd.org/kdd2016/papers/files/rfp0573-ribeiroA.pdf

本記事では、以下の手順でLIMEによる判断根拠の可視化をおこないます。

  • データセットの用意(ライブドアのニュースコーパス)
  • BERTによる文章分類のモデルを用意
  • モデルによる推論を行う
  • 推論結果の説明をLIMEで実行する

コードの全量はGitHubにもあります。
https://github.com/nukano0522/pytorch/blob/master/livedoor_news_cls/bert_eval_attention%2Blime.ipynb

実行環境

Google Colab (GPU)
Python 3.7.13
torch==1.12.1+cu113
transformers==4.21.2
lime==0.2.0.1

ライブラリインストール

!pip install transformers==4.5.0 fugashi==1.1.0 ipadic==1.0.0 lime==0.2.0.1

データセットとモデルの用意

今回はLIMEに主眼を置くため、データ前処理とモデル学習は省略します。
以下リポジトリに加工済みのデータセットと、ファインチューニングによる学習済みモデルを用意したのでクローンします。
https://github.com/nukano0522/livedoor_data_and_model

モデルのサイズが大きく、そのままではダウンロードできないためLFSを使ってpullします。

# リポジトリをクローン
!git clone https://github.com/nukano0522/livedoor_data_and_model.git

# モデルのサイズが大きい(100MB以上)ため、LFSを使ってダウンロード
%cd /content/livedoor_data_and_model/model
!git lfs install
!git lfs pull

# モデルのサイズを確認(450MBほどあります)
!ls -lh

ライブラリのインポートとデータ取得

必要なライブラリをインポートします。

import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, BertJapaneseTokenizer, BertModel
from torch import cuda
import sklearn.metrics as skm
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from transformers import logging
from tqdm import tqdm

# エンコーディング時の文章の最大長
max_len = 512
# BERTの学習済みモデル
MODEL_NAME = "cl-tohoku/bert-base-japanese-whole-word-masking"

先ほどクローンしてきたデータを読み込みます。

df = pd.read_csv("/content/livedoor_data_and_model/data/livedoor_text.csv")
print(df.shape)
df.head()

image.png

データセットを作成する

後続の処理でデータを取り出しやすい形にしておきたいため、データセットを作成します

class CreateDataset(Dataset):
  def __init__(self, X, y, tokenizer, max_len):
    self.X = X
    self.y = y
    self.tokenizer = tokenizer
    self.max_len = max_len

  def __len__(self):
    return len(self.y)

  def encode(self, tokenizer, text):
      inputs = tokenizer.encode_plus(
          text,
          add_special_tokens=True,
          max_length=self.max_len,
          padding = 'max_length',
          truncation = True
      )
      return inputs

  def __getitem__(self, index):
    text = self.X[index]
    label = self.y[index]
    ids = []
    mask = []
    inputs = self.encode(tokenizer=self.tokenizer, text=text)
    ids.append(torch.LongTensor(inputs['input_ids']))
    mask.append(torch.LongTensor(inputs['attention_mask']))

    return {
      'ids': ids,
      'mask': mask,
      'label': label,
      'text':text
    }

データを訓練、検証、テスト用に分割します。
※実際にLIMEで使うのはテスト用のデータです。訓練、検証は成り行きで作ってしまっていますが、後続で使いません

X = df["text"].values
y = df["category"].values

# データを訓練、検証、テストに分割
X_train_eval, X_test, y_train_eval, y_test = train_test_split(X, y, train_size=0.8)
X_train, X_eval, y_train, y_eval = train_test_split(X_train_eval, y_train_eval, train_size=0.75)

# トークナイザの定義
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)

# データセットの作成
dataset_train = CreateDataset(X_train, y_train, tokenizer, max_len=max_len)
dataset_eval = CreateDataset(X_eval, y_eval, tokenizer, max_len=max_len)
dataset_test = CreateDataset(X_test, y_test, tokenizer, max_len=max_len)

print(dataset_train.__len__())
print(dataset_eval.__len__())
print(dataset_test.__len__())

BERTモデルを定義し、事前に学習したモデルをロード

# BERTの学習済みモデル
bert_model = BertModel.from_pretrained(MODEL_NAME, output_attentions=True, output_hidden_states=True)

class MyBertClassification(nn.Module):
    """BERTによるライブドアニュースの分類モデル
    """

    def __init__(self):
        super(MyBertClassification, self).__init__()

        # BERTモジュール
        self.bert = bert_model  # 日本語学習済みのBERTモデル

        # 全結合層
        self.cls = nn.Linear(in_features=768, out_features=9)

        # 重み初期化
        nn.init.normal_(self.cls.weight, std=0.02)
        nn.init.normal_(self.cls.bias, 0)


    def forward(self, input_ids, attention_show_flg:bool=False):
        """順伝搬
        Args:
            input_ids: [batch_size, max_len] の文章のID列
            attention_show_flg: Trueの場合、Attentionを出力する
        """

        # BERT順伝搬
        result = self.bert(input_ids)

        # BERT出力の最終層
        sequence_output = result[0]
        # 先頭単語の全768要素
        vec_0 = sequence_output[:, 0, :]
        vec_0 = vec_0.view(-1, 768)  # sizeを[batch_size, hidden_size]に変換
        output = self.cls(vec_0)  # 全結合層

        if attention_show_flg:
            # Attentionの最終層を返す
            return output, result.attentions[-1]
        else:
            return output

クローンしたモデルをロードして、GPUに送ります。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# モデルのロード
model = MyBertClassification()
model.load_state_dict(torch.load("/content/livedoor_data_and_model/model/model_512.pth"))
model.to(device)
print("complete")

LIMEによる判断根拠の可視化

前提となる作業が完了しました。
ここからLIMEをつかった判断根拠の可視化を行います。

マルチラベル分類にLIMEを適用するチュートリアルのコードを参考にしています。
https://marcotcr.github.io/lime/tutorials/Lime%20-%20multiclass.html

limeをインポートします。

from typing import List
import torch.nn.functional as F
import lime
from lime.lime_text import LimeTextExplainer

# ニュース分類
class_names = [
    'dokujo-tsushin',
    'it-life-hack',
    'smax',
    'sports-watch',
    'kaden-channel',
    'movie-enter',
    'topic-news',
    'livedoor-homme',
    'peachy'
]

LIMEを扱ううえで大切なポイントとして、
LIMEのexplainerに渡す推論結果は、各クラスの確率である必要があるということです。(今回の場合は、9クラスのそれぞれに分類される確率)

通常の文章分類ですと最終的なラベルを推論結果として扱うことが基本だと思いますが、
ここでは確率を出力するためのpredictorを定義します。

また、predictorに渡すtextの形式は単体の文字列だと後続処理がうまくいかず、リスト形式で渡す必要があるみたいです。

def predictor(texts:List, max_length:int = max_len) -> np:
    """LIMEに渡す用に、推論結果(確率)を計算する
    Args:
        texts: 文章のリスト
        max_length: エンコーディング時の文章の最大長
    Returns:
        推論結果(確率)のリスト
    """
    # 文章をID化する
    encoding = tokenizer.batch_encode_plus(
                texts, 
                add_special_tokens=True,
                padding="max_length", 
                max_length=max_length,
                truncation = True)
    
    input_ids = torch.tensor(encoding['input_ids']).to(device)
    # print(input_ids.size())

    # 学習済みモデルによる推論
    with torch.no_grad():
        output = model(input_ids, attention_show_flg=False)
    
    # 推論結果をSoftmax関数を通して確率表現にする
    probas = F.softmax(output, dim=1).cpu().detach().numpy()

    return probas

実際に、テストデータのある文章に対して、predictorを通した結果を見てみます。

# 評価対象の文章
idx = 300

# テストデータセットからテキストと正解ラベルを取得
text = dataset_test[idx]["text"]
label = dataset_test[idx]["label"]

# LIME入力用
texts = []
texts.append(text)

output = predictor(texts)
print(output)

print文の結果。各クラスの分類確率が出力されます。

[[6.5388078e-05 1.6261819e-06 2.4259509e-05 2.1345629e-06 4.3495089e-05
  8.9905825e-06 1.2923315e-06 2.1587400e-06 9.9985063e-01]]

予測結果と正解のラベルを確認します。

print("予測", class_names[np.argmax(output)])
print("正解", class_names[label])
予測 livedoor-homme
正解 livedoor-homme

LIMEのexplainerにクラス名、テキスト、predictorなどを渡して、判断根拠を可視化します。
top_labelsのパラメータで、分類確率が高いクラスの順に根拠を可視化してくれます。(今回はTOP2のクラスを指定)

explainer = LimeTextExplainer(class_names=class_names)

# 予測確率が高いTOP-K
exp = explainer.explain_instance(text, predictor, num_features=10, num_samples=70, top_labels=2)

exp.show_in_notebook(text=text)

先ほど見たように、livedoor-hommeが予測結果として出ており、その根拠となるテキスト情報が出力されています。
image.png
背景色が濃くなっているテキストほど寄与度が高いテキストといえるようです。
今回の文章でいえば、野球にかかわるフレーズ中心にピックアップされているように見えます。
image.png

コードは省略しますが、同様の文章に対してAttentionの可視化も行いました。その結果が下記です。
image.png

まとめと考察

いろいろな文章で試した結果、Attentionは色がつく箇所が多くどこが影響しているのかイマイチわかりづらいなと感じました。
一方、LIMEは根拠となる部分が一定絞られており、また各フレーズの共通項も見出しやすいため、推論結果を解釈したり第3者に説明することを考えると扱いやすいのでは、と思いました。

ぜひ、様々な文章で試して結果を見ていただけるとよいかと思います。

なにか間違いなどご指摘ありましたら、よろしくお願いします。

参考資料

4
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?