LoginSignup
2
3

More than 1 year has passed since last update.

PyCaretのドキュメント解読: Classification compare_models()関数

Last updated at Posted at 2023-02-03

PyCaret使いこなしのために公式ドキュメントの解読を決意して前回はClassificationのsetup()関数のドキュメントを解読。続いてはClassificationのcompare_models()関数のドキュメント。

compare_models()のオプション概要
pycaret.classification.compare_models(
  include: Optional[List[Union[str, Any]]] = None,
  exclude: Optional[List[str]] = None,
  fold: Optional[Union[int, Any]] = None,
  round: int = 4,
  cross_validation: bool = True,
  sort: str = 'Accuracy',
  n_select: int = 1,
  budget_time: Optional[float] = None,
  turbo: bool = True,
  errors: str = 'ignore',
  fit_kwargs: Optional[dict] = None,
  groups: Optional[Union[str, Any]] = None,
  experiment_custom_tags: Optional[Dict[str, Any]] = None,
  probability_threshold: Optional[float] = None,
  engine: Optional[Dict[str, str]] = None,
  verbose: bool = True,
  parallel: Optional[ParallelBackend] = None
)  Union[Any, List[Any]]

関数の説明

この関数は交差検証を用いてモデルライブラリ中の全ての予測器を学習、評価します。この関数の出力は交差検証スコアの平均を含むスコアグリッドです。交差検証中に評価されたメトリクスはget_metrics()関数を用いてアクセス可能です。add_metrics()関数及びremove_metrics()関数を用いてカスタムメトリクスを追加及び削除することが可能です。

注意

  • turboパラメータをFalseに設定した場合、10,000行を超えるデータセットを用いた学習では膨大な時間を要する可能性が有ります。
  • predict_proba非対応の予測器のAUCは0.0000と表示されます。
  • cross_validation = Falseの場合、いかなるモデルもMLFlowに記録されません。

Returns

n_selectパラメータに従い、学習済みモデルまたは学習済みモデルのリストを返します。

実行例

from pycaret.datasets import get_data
from pycaret.classification import *

juice = get_data('juice')
exp_name = setup(data = juice,  target = 'Purchase')
best_model = compare_models()

オプションパラメータ

include

  • 形式: list[str], list[scikit-learn互換オブジェクト]
  • デフォルト値: None
  • 説明: 選択したモデルをトレーニングおよび評価するには、モデルIDまたはscikit-learn互換オブジェクトを含むリストをincludeパラメータで渡すことができます。 モデルライブラリで利用可能なすべてのモデルのリストを表示するには、models()関数を使用します。

exclude

  • 形式: list[str]
  • デフォルト値: None
  • 説明: 特定のモデルを学習と評価から除外するには、excludeパラメーターにモデルIDを含むリストを渡します。 モデルライブラリで利用可能なすべてのモデルのリストを表示するには、models()関数を使用します。

fold

  • 形式: int , scikit-learn互換のCV生成器
  • デフォルト値: None
  • 説明: 交差検証を制御するパラメータです。もしNoneを指定した場合、setup()関数のfold_strategyパラメータで指定したCV生成器を使用します。もし整数値が渡された場合、setup()関数におけるCV生成器のn_splitsパラメータと解釈されます。

round

  • 形式: int
  • デフォルト値: 4
  • 説明: スコアグリッド内のメトリックが丸められる小数点以下の桁数を指定します。

cross_validation

  • 形式: bool
  • デフォルト値: True
  • 説明: Falseを設定した場合、メトリクスはホールドアウトセット(setup()関数のtest_dataパラメータで指定したデータセット)で評価されます。cross_validation = Falseの場合、foldパラメータは無視されます

sort

  • 形式: str
  • デフォルト値: Accuracy
  • 説明: スコアグリッドの並べ替え順を設定します。add_metric()関数で追加したカスタムメトリクスに対しても適用します。

n_select

  • 形式: int
  • デフォルト値: 1
  • 説明: (性能が)top_nのモデルを返します。例えば、上位3モデルを返したい場合、n_select = 3と設定します。

budget_time

  • 形式: int, float
  • デフォルト値: None
  • 説明: もし、設定値がNoneでは無い場合、budget_time分が経過した後に関数の実行を終了し、その時点までの結果を返します。

turbo

  • 形式: bool
  • デフォルト値: True
  • 説明: Trueを設定した場合、学習時間が長い推定器は除外されます。 除外されるアルゴリズムを確認するには、models()関数を使用します。

errors

  • 形式: str
  • デフォルト値: ignore
  • 説明: ignoreを設定した場合、例外またはcontinueが発生したモデルの逸れ以降の実行をスキップします。もしraiseを設定した場合、例外が発生した場合は、関数をエラー終了します。

fit_kwargs

  • 形式: dict
  • デフォルト値: {}(空のdict)
  • 説明: モデルのfit()関数に渡す(キーワード)引数の辞書を設定します。

groups

  • 形式: str, (n_samples,)の行列形式
  • デフォルト値: None
  • 説明: 交差検証にGroupKFoldを使用する場合のオプションのグループラベルを設定します。 (n_samples,)の行列形式で渡します。n_samplesは学習データの行数です。 文字列が渡されると、グループラベルの値と等しいデータセット内の列名として解釈されます。

experiment_custom_tags

  • 形式: dict
  • デフォルト値: None
  • 説明: tag_nameのディクショナリ: 文字列 -> 値: (文字列ですでない場合は文字列化されます) mlflow.set_tagsに渡され、実験用の新しいカスタムタグが追加されます。

probability_threshold

  • 形式: float
  • デフォルト値: None
  • 説明: 予測確率をクラスラベルに変換するための閾値を設定します。 このパラメーターで明示的に定義されていない限り、全ての分類子の閾値のデフォルトは0.5です。 2値分類にのみ適用されます。

engine

  • 形式: Optional[Dict[str, str]]
  • デフォルト値: None
  • 説明: model_idのdict形式でモデルに使用する実行エンジン: engine - 例: 線形回帰の場合 (engine={"lr": "sklearnex"}と指定します。ユーザーはsklearnまたはsklearnexを切り換えて利用出来ます)

verbose

  • 形式: bool
  • デフォルト値: True
  • 説明: verbose = Flaseの場合、スコアグリッドは表示されません。

parallel

  • 形式: pycaret.internal.parallel.parallel_backend.ParallelBackend
  • デフォルト値: None
  • 説明: BarallelBackendのインスタンスとして実行します。例えば、SparkSessionセッションが有る場合、FugueBackend(session)を使用することで、Sparkを使用してこの関数を実行できます。 詳細については、FugueBackendを参照してください。

Reference

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3