1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LIMEを実装しながら理解してみる

Last updated at Posted at 2023-12-20

はじめまして!大学3年生fkiと申します!
普段は大学で機械学習を学んでいます。
この記事では学習した内容の一部を実装を交えながらまとめてみました。
ぜひ最後まで読んでいただけると嬉しいです!

この記事はCA Tech Lounge Advent Calendar 2023の21日目の記事になります。

はじめに

近年、機械学習の社会実装が進み、様々なサービスへの導入がなされています。

しかしながら、機械学習モデルの予測過程はブラックボックスであることが多く、機械学習モデルが普及するに伴ってその透明性や説明可能性の重要性が叫ばれています。

そうした中で、XAIと呼ばれるブラックボックスな機械学習モデルから予測結果以外の追加情報を抽出することで、モデルの説明を可能とする技術が注目されています。

この記事では、XAIの代表的な手法の一つであるLIMEについて、実装しながら学習した内容をまとめています。LIMEの元論文と公式の実装については下を参照ください。


論文

公式の実装

LIMEとは?

LIMEとは、ブラックボックスなモデルを解釈が可能なモデルで部分的に近似し、近似したモデルについて解釈することで、モデルに依存せず局所的な説明を行う技術です。


下の例では、Bostonの住宅販売価格を予測する回帰モデルの一つのインスタンスについてLIMEを適用しています。

このインスタンスにおける予測結果にlstat, rmのような特徴量が寄与していることが読み取れます。

Screen Shot 2023-12-19 at 14.25.12.png


このように、LIMEを用いることで個々のインスタンスごとに重要な特徴量を示すことができます。

つくってみる

ここでは論文と公式の実装を参考に、テーブルデータを対象とした簡単なLIMEの実装をしてみます。

説明するモデルをつくる

はじめにLIMEで説明するブラックボックスなモデルを作成します。

ここでは、サンプルデータをランダムフォレストを用いて二値分類するモデルを作成します。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

# データを生成
X, y = make_moons(noise=0.05, random_state=0, n_samples=200)

# モデルのトレーニング
clf = RandomForestClassifier()
clf.fit(X, y)

# 決定境界の描画
plt.figure(figsize=(5, 4))
X_mesh, y_mesh = np.meshgrid(np.linspace(-1.5, 2.5), np.linspace(-1, 1.5))
Xy_mesh = np.column_stack([X_mesh.ravel(), y_mesh.ravel()])
mesh_pred = clf.predict(Xy_mesh).reshape(X_mesh.shape)
plt.contourf(X_mesh, y_mesh, mesh_pred, alpha=0.5, cmap='coolwarm')
plt.scatter(X[:, 0], X[:, 1], c=y, s=4, cmap='coolwarm')
plt.show()

image.png

データのサンプリングと重みづけ

LIMEでは、データをサンプリングしたのち、説明したいインスタンスの周辺のデータに重みづけを行い、重みづけしたデータをもとに解釈可能なモデルを学習します。

インスタンスの指定

今回は103番目のインスタンスを指定します。

青い星で表示したデータポイントが、指定したインスタンスです。

# インスタンスの指定
instance_index = 103
selected_instance = X[instance_index]

image.png

データのサンプリング

LIMEでは、指定したインスタンスをもとにデータを生成します。

データの形式に応じて生成方法は異なりますが、ここでは数値データの場合について実装します。

数値データにおいては、データを標準化したのち、正規分布に基づいてランダムに特徴量の値を決定します。

生成した値に元のデータのサイズに再度スケールし、指定したインスタンスの値と足し合わせることで、局所的なサンプルを得ます。

生成されたデータを黒い点で示します。

# データの生成用関数の定義
def generate_samples(data_row, num_samples, scaler):
    generated_data = np.random.normal(
        0, 1, num_samples * data_row.shape[0]
    ).reshape(num_samples, data_row.shape[0])
    generated_data = generated_data * scaler.scale_ + data_row
    return generated_data

# データを標準化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 選択したインスタンスのデータ
scaled_selected_instance = X_scaled[instance_index]

# データの生成
num_samples = 100  # 生成するサンプル数
generated_data = generate_samples(
    data_row=scaled_selected_instance,
    num_samples=num_samples,
    scaler=scaler,
)

image.png

重みの計算

ここでは、サンプリングしたデータを重みづけします。

公式の実装ではカーネル平滑化を用いて周辺のデータを重みづけを行っています。

また、その際のカーネル幅はデータの行数の平方根の0.75倍に設定されています。

from sklearn.metrics.pairwise import pairwise_distances

# カーネル関数の定義
def gaussian_kernel(distances, kernel_width):
    return np.exp(-(distances ** 2) / (2 * kernel_width ** 2))

# 各データポイントと選択したインスタンスとの距離を計算
distances = pairwise_distances(
    generated_data,
    selected_instance.reshape(1, -1),
    metric='euclidean'
).ravel()

# カーネル幅を設定
kernel_width = np.sqrt(X.shape[1]) * 0.75

# 重みの計算
weights = gaussian_kernel(distances, kernel_width)

image.png

解釈可能なモデルの学習

最後に、サンプリングしたデータと重みをもとに解釈可能なモデルの学習を行います。

サンプリングしたデータについて元のモデルの予測結果を取得し、それを目的変数として線形回帰を学習します。

今回は2値分類なので、目的変数としてクラス1への所属確率を用いています。

from sklearn.linear_model import LinearRegression

# サンプリングしたデータについてクラス1への所属確率を得る
generated_data_pred = clf.predict_proba(generated_data)[:, 1]

# 線形回帰
local_model = LinearRegression()
local_model.fit(perturbed_data, perturbed_y, sample_weight=weights)

作成した線形回帰モデルについて回帰係数をとることで局所的な説明を得ることができます。

import seaborn as sns

sns.barplot(
    x=local_model.coef_,
    y=['feature_0', 'feature_1'],
    orient='h'
)

image.png

LIMEの課題

ここまで紹介してきたLIMEですが、以下のような課題が挙げられています。

  • 近傍の正しい定義
    • 表形式データにおける近傍の適切な定義がない
    • 閾値によって大きく説明結果が変動するため、説明が意味をなすかどうかを個別に確認する必要がある
  • サンプリングの問題
    • 特徴量間の相関を無視して正規分布からサンプリングされるため、実際には発生しがたいデータポイントが使用される可能性がある
  • 説明の不安定さ
    • 非常に近いデータポイントでも説明が大きく異なる可能性がある

これらの課題を認識した上で適切に利用することが重要であると考えます。

おわりに

拙い文章で読みづらい部分も多々あったかと思いますが、ここまで読んでいただきありがとうございます🙇‍♂️🙇‍♂️

本記事では、XAIの一手法であるLIMEについて、実装を交えつつ、学習した内容をまとめました。

内容に誤りなどあれば、コメント等で教えていただけると嬉しいです🙇‍♂️

参考文献

  1. https://dl.acm.org/doi/abs/10.1145/2939672.2939778
  2. https://github.com/marcotcr/lime
  3. https://hacarus.github.io/interpretable-ml-book-ja/lime.html
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?