Help us understand the problem. What is going on with this article?

分類器毎のdecision_function、predict_probaの保持状況を調べてみた

はじめに

分類問題でROC曲線を書いたり、ROC-AUCスコアを計算する場合に、しきい値が必要となり、scikit-learnでは分類器の decision_function()やpredict_proba()メソッドが利用される。一般にscikit-learnの分類器はこれらのメソッドのどちらかを保持しているという記載を本でみたため、実際に調べてみた。

環境

  • scikit-learn 0.21.2

調査した分類器

調べたものは思いついた以下のもの。XGBClassifier、LGBMClassifier、CatBoostClassifier はscikit-learnにもともと用意されている分類器ではないが、scikit-learnのインターフェースが用意されているため、合わせて調査した。

  • LogisticRegression
  • SGDClassifier SVC
  • RandomForestClassifier
  • ExtraTreesClassifier
  • AdaBoostClassifier
  • XGBClassifier
  • LGBMClassifier
  • CatBoostClassifier
  • MLPClassifier

コード

ClassificaitonScoreFunctionCheck.py
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier, AdaBoostClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier
from sklearn.neural_network import MLPClassifier


def main():

    estimators = [LogisticRegression, SGDClassifier,SVC,RandomForestClassifier,
                  ExtraTreesClassifier, AdaBoostClassifier, XGBClassifier,
                  LGBMClassifier,CatBoostClassifier,MLPClassifier]

    for estimator in estimators:
        print("{0},{1},{2},{3}".format(
            estimator,
            hasattr(estimator, "predict"),
            hasattr(estimator, "predict_proba"),
            hasattr(estimator, "decision_function")
        ))

if __name__ == "__main__":
    main()

結果

分類器 predict_proba() decision_function()
LogisticRegression True True
SGDClassifier True True
SVC True True
RandomForestClassifier True False
ExtraTreesClassifier True False
AdaBoostClassifier True True
XGBClassifier True False
LGBMClassifier True False
CatBoostClassifier True False
MLPClassifier True False

考察

predict_proba()メソッドは全ての分類器が保持していたが、decision_functionについては、保持しているものが4つ、保持していないものが5つという結果になった。
この結果からROC-AUCスコアを計算する場合のしきい値にどちらを使うかについて、以下の方針が考えられる。

  • decision_function()がないものについては、predict_proba()メソッドを使う
  • decision_function()とpredict_proba()の両方があるものについては、各分類器で推奨されている、またはよく使われているメソッドを使う。
Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away