LoginSignup
2
3

More than 3 years have passed since last update.

分類器毎のdecision_function、predict_probaの保持状況を調べてみた

Last updated at Posted at 2019-08-24

はじめに

分類問題でROC曲線を書いたり、ROC-AUCスコアを計算する場合に、しきい値が必要となり、scikit-learnでは分類器の decision_function()やpredict_proba()メソッドが利用される。一般にscikit-learnの分類器はこれらのメソッドのどちらかを保持しているという記載を本でみたため、実際に調べてみた。

環境

  • scikit-learn 0.21.2

調査した分類器

調べたものは思いついた以下のもの。XGBClassifier、LGBMClassifier、CatBoostClassifier はscikit-learnにもともと用意されている分類器ではないが、scikit-learnのインターフェースが用意されているため、合わせて調査した。

  • LogisticRegression
  • SGDClassifier SVC
  • RandomForestClassifier
  • ExtraTreesClassifier
  • AdaBoostClassifier
  • XGBClassifier
  • LGBMClassifier
  • CatBoostClassifier
  • MLPClassifier

コード

ClassificaitonScoreFunctionCheck.py
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier, AdaBoostClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier
from sklearn.neural_network import MLPClassifier


def main():

    estimators = [LogisticRegression, SGDClassifier,SVC,RandomForestClassifier,
                  ExtraTreesClassifier, AdaBoostClassifier, XGBClassifier,
                  LGBMClassifier,CatBoostClassifier,MLPClassifier]

    for estimator in estimators:
        print("{0},{1},{2},{3}".format(
            estimator,
            hasattr(estimator, "predict"),
            hasattr(estimator, "predict_proba"),
            hasattr(estimator, "decision_function")
        ))

if __name__ == "__main__":
    main()

結果

分類器 predict_proba() decision_function()
LogisticRegression True True
SGDClassifier True True
SVC True True
RandomForestClassifier True False
ExtraTreesClassifier True False
AdaBoostClassifier True True
XGBClassifier True False
LGBMClassifier True False
CatBoostClassifier True False
MLPClassifier True False

考察

predict_proba()メソッドは全ての分類器が保持していたが、decision_functionについては、保持しているものが4つ、保持していないものが5つという結果になった。
この結果からROC-AUCスコアを計算する場合のしきい値にどちらを使うかについて、以下の方針が考えられる。

  • decision_function()がないものについては、predict_proba()メソッドを使う
  • decision_function()とpredict_proba()の両方があるものについては、各分類器で推奨されている、またはよく使われているメソッドを使う。
2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3