機械学習を行う際に、精度を上げるためにモデルを複雑にしていきます。
例えば、Support Vecter Machine → Deep Learing のように。
しかし、複雑なモデルを用いると機械学習が出した答えはどのような特徴に基づいて判断しているのかを、人間が理解することや説明することが難しくなってしまいます。
そして、このようなブラックボックス状態では、利用しにくい場合があり、機械学習に取り組む場合の根本的な問題としてありました。
一般的に学習器がどのように学習しているのかを調べる方法としては次のようなものが挙げられます。
・生成したモデルのパラメータを見る
・より解釈しやすいモデル(ランダムフォレストとか)で学習する
・任意の入力(特徴量)を与えて、出力結果の変化を見ていく
今回は、これら一般的な方法よりもより複雑なモデルにおいても、学習状態をより細かく分析することができるアルゴリズムであるLIMEについてまとめました。
LIME
LIME(Local Interoretable Model-agnostic Explainmations)は、KDD2016で採択された論文において発表されたアルゴリズムです。
簡単に説明すると、複雑なモデルをいくつかの解釈性が高い複数の簡単なモデルに置き換えていくというものです
赤の+データと青の●データを判別するDNNのモデルがあった場合、その判別空間は図のような複雑なものとなっているとします。
この判別空間の一部部分を判別する点線のような線形回帰モデルを生成することで、部分出来にDNNモデルの判別根拠の説明することに繋があります。
あるデータを分類した結果それぞれの特徴がその程度分類に貢献しているのかを調べることで分類器の予測結果を説明します。また、分類器の予測結果を用いているので、任意の分類器に適用できます。
説明用の分類器gが、データxの付近で複雑なモデルfの結果を近似するように学習するとする。
次のような目的関数に従い、学習を行うことになる。
\DeclareMathOperator*{\argmin}{arg\,min}
\xi(x) = \argmin_{g \in G} L(f, g, \pi_x) + \Omega(g)
・$G$:解釈可能なモデルの集合
・$g$:説明用の分類器
・$f$:説明したい複雑な分類器
・$π_x$:データxとの距離
・$L(*)$:損失関数 データxの周辺でfとgの結果がどの程度異なっているか
・$Ω(*)$:gの複雑さ
より複雑性が低いモデルでデータx周りでfとgの予測結果の差が最小になるようにする。
Ω(g)は、gが用いる特徴量の最大数Kを用いる。
gにlassoモデルを利用することで、K個の特徴量の選択を行うことができる。
損失関数Lは次のように定義する。
L(f, g, \pi_x) = \sum_{z,z' \in Z } \pi_x (z) (f(z) - g(z'))^2
・$Z$:$x$周りのデータ集合
・$z$:サンプリングにより生成された2値{0,1}のスパースな集合
データx周りで説明したい複雑な分類器と説明したい説明用の分類器の重み付きの差の合計である。
LIMEを動かしてみる
LIMEは現在、pythonとRでパッケージが提供されています。
LIMEのGithubのページにサンプルがあり、それを参考にしました。
用いたデータは、みんな大好きアヤメのデータ。
説明したい分類器は、サンプルだとランダムフォレストですがXGBoostにしてみました。
わかってはいますが、irisデータをプロットして確認。
library(ggplot2)
# plot iris data
gp = ggplot(iris, aes(x=Sepal.Width, y=Sepal.Length, colour=Species))
gp = gp + geom_point(size=3, alpha=0.7)
print(gp)
setosaの判別は簡単そうだが、varginicaとvargicolorの判別は難しそう。
そして、limeを実行したクソコードです。
library(caret)
library(lime)
# Split up the data set
sample_point <- sample(c(1:150),6)
iris_test <- iris[sample_point, 1:4]
iris_train <- iris[-sample_point,]
iris_lab <- iris[sample_point,5]
# Create XGBoost model on iris data
model <- train(
Species ~ .,
data = iris_train,
method = "xgbTree",
preProcess = c('center', 'scale'),
trControl = trainControl(method = "cv")
)
pred_iris <- predict(model,iris_test)
# Create an explainer object
explainer <- lime(iris_train, model)
# Explain new observation
explanation <- explain(iris_test, explainer, n_labels = 1, n_features = 2)
plot_features(explanation)
正解だったは93,63,100,26,52は、その根拠として花びらの長さが学習したある判断基準を満していることが重要であったということを示しています。
不正解だったcase134は、花びらの幅が基準と満していなかったことが影響していること示されています。
しかし、元のデータをどれほど説明しているのか(Explanation Fit)が低いので、あまり適切な説明モデルではないようです。
画像でも同様の説明力を持つモデルの生成を行うことができ、具体的に与えられた画像の判断根拠となる部分を示してくれます。
Rのパッケージには結果サンプルが付属しています。
利用しているモデルはコードからするとVGG16?
explanation <- .load_image_example()
plot_image_explanation(explanation)
左は、モデルが35%でイチゴだと判断した場合の判断根拠となる部分を黄緑の領域で囲っています。
確かに、イチゴが写っている部分を指摘しています。
真ん中は、モデルが16%でろうそくかタッパーかワックスの光沢だと判断したことについてです。
明らかに間違った答えですが、トマトとりんごの側面を判断根拠としたとしており、まぁ間違っても仕方なかもと感じます。
テキスト分類のサンプルもありましたので、動かしてみました。
30種類の論文から、著者が作成した文章かどうかを判断します。
モデルはXGBoostを利用します。
library(lime)
library(xgboost) # the classifier
library(text2vec) # used to build the BoW matrix
# load data
data(train_sentences)
data(test_sentences)
# Data are stored in a 2 columns data.frame, one for the sentences, one for the
# labels.
print(str(train_sentences))
get_matrix <- function(text) {
it <- itoken(text, progressbar = FALSE)
create_dtm(it, vectorizer = hash_vectorizer())
}
# BoW matrix generation
dtm_train = get_matrix(train_sentences$text)
dtm_test = get_matrix(test_sentences$text)
# Create boosting model for binary classification (-> logistic loss)
# Other parameters are quite standard
param <- list(max_depth = 7,
eta = 0.1,
objective = "binary:logistic",
eval_metric = "error",
nthread = 1)
xgb_model <- xgb.train(
param,
xgb.DMatrix(dtm_train, label = train_sentences$class.text == "OWNX"),
nrounds = 50
)
# We use a (standard) threshold of 0.5
predictions <- predict(xgb_model, dtm_test) > 0.5
test_labels <- test_sentences$class.text == "OWNX"
# Accuracy
print(mean(predictions == test_labels))
sentence_to_explain <- head(test_sentences[test_labels,]$text, 5)
explainer <- lime(sentence_to_explain, model = xgb_model,
preprocess = get_matrix)
explanation <- explain(sentence_to_explain, explainer, n_labels = 1,
n_features = 2)
# Another more graphical view of the same information (2 first sentences only)
plot_features(explanation)
plot_text_explanations(explanation)
文章自体に示す方法もあります。
plot_text_explanations(explanation)
簡単にまとめ
・複雑なモデルを複数のより簡単なモデルで表現するアルゴリズム
・様々なモデル、様々なデータに対して高い説明性を示す
・個人的には、業務等で色々と活用できそうです
参考
・https://arxiv.org/abs/1602.04938
・https://qiita.com/fufufukakaka/items/d0081cd38251d22ffebf
・https://github.com/thomasp85/lime
・https://cran.r-project.org/web/packages/lime/vignettes/Understanding_lime.html