0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【sklearn】全てのモデルで性能を比較したい

Posted at

はじめに

最近はもっぱら深層学習を使うことが多いですが,古典的な機械学習モデル1を使うこともしばしば.かく言う私も,最近,こうした機械学習モデルを用いた分析をする必要があり,久々にsklearnに戻ってきました.大体の場合,Random ForestやGradient Boostingを試してみて終わるのですが,今回は網羅的に探索したいという要望がありました.sklearnは細かいところに手が届くライブラリというイメージだったので自分で実装する前に調査を進めたところ,そうした要望を満たしてくれる機能がやはり用意されていました.

本記事は,タイトルの通りsklearnで用意されているモデルを網羅的に探索してくれる,そんな機能の紹介です.

前提

今回は実際のコードを公開する訳にもいかないので,よく用いられるIris Dataset2を用いて実装を紹介します.ただ,本データセットやsklearnのdatasetの使い方については本題では無いので,深掘りはしません.

また,一般的なモデルのパラメータの調整方法や学習方法については理解できているものとします.

全体コード

特に説明は要らないのでコードを見て理解したいという方のために,早速,コードを示したいと思います.詳細な説明は追って行います.

sample.py
import polars as pl

from sklearn import datasets
from sklearn.utils import all_estimators
from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
import japanize_matplotlib
import seaborn as sns


data = datasets.load_iris()

inputs = data['data']
outputs = data['target']


estimator_classes = all_estimators(type_filter='classifier')

IGNORE_MODELS = [
    'ClassifierChain',
    'FixedThresholdClassifier',
    'MultiOutputClassifier',
    'OneVsOneClassifier',
    'OneVsRestClassifier',
    'OutputCodeClassifier',
    'StackingClassifier',
    'TunedThresholdClassifierCV',
    'VotingClassifier',
    'SelfTrainingClassifier',
]

estimators = {}
for model_name, model_class in estimator_classes:
    if model_name in IGNORE_MODELS:
        continue
    
    clf = model_class()
    
    if 'random_state' in clf.get_params():
        clf.set_params(random_state=42)
        
    if 'n_jobs' in clf.get_params():
        clf.set_params(n_jobs=-1)
    
    estimators[model_name] = clf


scores = {}
for model_name, model in estimators.items():
    model.fit(inputs, outputs)
    pred = model.predict(inputs)

    scores[model_name] = accuracy_score(outputs, pred)


df_results = pl.DataFrame({
    'model': list(scores.keys()),
    'accuracy': list(scores.values()),
})
df_results = df_results.sort(
    by='accuracy', 
    descending=True,
)
print(df_results)


plt.figure(figsize=(16, 9))
plt.rcParams['font.size'] = 18

sns.barplot(
    data=df_results, 
    x='accuracy', 
    y='model'
)

plt.tight_layout()
plt.savefig('results.png')

上記のコードを実行して得られるグラフは以下の通りです.

results.png

詳解

ここから少し詳しく見ていきます.モデルの評価方法やグラフの作成方法などの本流とは異なる部分については省略します.

ライブラリのインポート

利用するライブラリをインポートしていきます.今回のメインとなるメソッドがall_estimators3です.

import libraries
import polars as pl

from sklearn import datasets
from sklearn.utils import all_estimators
from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
import japanize_matplotlib
import seaborn as sns

モデルの準備

次に,モデルを準備していきます.
all_estimatorsの引数type_filterにタスクの種類を指定します.今回はIris Datasetなのでclassifierです.

type_filter 説明
classifier 分類タスク.今回の問題など.
regressor 回帰タスク.株価予測など.
cluster クラスタリング.異常検知など.
transformer データ加工タスク.特徴量抽出など.

なお,all_estimatorsはモデルクラスを渡すため,インスタンス化することが必要であるということに注意してください.

prepare classifiers
estimator_classes = all_estimators(type_filter='classifier')

# 今回のタスクに合致しないモデルを除外
IGNORE_MODELS = [
    'ClassifierChain',
    'FixedThresholdClassifier',
    'MultiOutputClassifier',
    'OneVsOneClassifier',
    'OneVsRestClassifier',
    'OutputCodeClassifier',
    'StackingClassifier',
    'TunedThresholdClassifierCV',
    'VotingClassifier',
    'SelfTrainingClassifier',
]

estimators = {}
for model_name, model_class in estimator_classes:
    if model_name in IGNORE_MODELS:
        continue
    
    clf = model_class()

    # 再現性が必要であれば導入
    if 'random_state' in clf.get_params():
        clf.set_params(random_state=42)

    # 並列処理が必要であれば導入
    if 'n_jobs' in clf.get_params():
        clf.set_params(n_jobs=-1)
    
    estimators[model_name] = clf

最後に

全てのモデルを網羅的に探索することは少ないかなと思いますが,そうした需要に対しての用意があるということを知ることができたので良かったかなと思います.
以降のどなたかの参考になれば嬉しいです

  1. 日本語訳が合っているか分かりませんが,論文中ではTraditional Machine Learningとか呼ばれているやつです

  2. https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html#sklearn.datasets.load_iris

  3. https://scikit-learn.org/stable/modules/generated/sklearn.utils.discovery.all_estimators.html

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?