0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

cross_val_predictを予測精度評価に使ってはいけない(場合がある)

Last updated at Posted at 2022-08-29

結論

 クロスバリデーションにおいて、各分割のサンプル数が異なる場合
 cross_val_predictによる予測結果を基にr2やMAEなどを計算することには注意しましょう。

sklearn

scikit-learnのcross-validationに関するドキュメントには、以下の様なコメントやWarningがあります

The function cross_val_predict has a similar interface to cross_val_score, but returns, for each element in the input, the prediction that was obtained for that element when it was in the test set. Only cross-validation strategies that assign all elements to a test set exactly once can be used (otherwise, an exception is raised).

まず、 Only cross-validation strategies that assign all elements to a test set exactly once can be used (otherwise, an exception is raised). について

全てのデータが1回のみテストセットとして呼び出されるCV以外ではエラーになります。

例えば、RepeatedKFold, TimeSeriesSplitなどが該当します。

以下のコードを実行するとValueErrorが発生します。

from sklearn.datasets import load_diabetes
from sklearn.model_selection import cross_val_predict, cross_val_score
from sklearn.model_selection import RepeatedKFold
from sklearn.ensemble import RandomForestRegressor

X,y = load_diabetes(return_X_y=True)
model = RandomForestRegressor()

y_pred_cv = cross_val_predict(model,X,y , cv=RepeatedKFold(n_splits=5,n_repeats=3))

各データに対して複数回予測値を求めるRepeatedKFoldの様なCV手法には対応しない。ということです。
一方、cross_val_predictの代わりにcross_validate、cross_val_scoreではn_splits*n_repeatsの15回分のスコアが求まります。

Warning Note on inappropriate usage of cross_val_predict
The result of cross_val_predict may be different from those obtained using cross_val_score as the elements are grouped in different ways. The function cross_val_score takes an average over cross-validation folds, whereas cross_val_predict simply returns the labels (or probabilities) from several distinct models undistinguished. Thus, cross_val_predict is not an appropriate measure of generalization error.

さらに重要な話として、そもそもcross_val_predictで求めたy_pred_cvをスコア評価に使うのは適切ではないとWarningに記載されています。

これについては、以下のQ&Aが分かりやすいです。
Tutorial More

分かりやすくするために、平均絶対誤差や平均二乗誤差ではなく最大絶対誤差(max-abs-errors)で評価する場合を考えています。

3-Foldで評価した場合、それぞれの最大絶対誤差(a, b, c)が求まります。
cross_val_scoreではmean(a, b, c)を求めます。この値を構築した機械学習の最大絶対誤差とすることは妥当です。

一方で、cross_val_predictで求めたy_pred_cv からmax-abs-errorsを求めるとmax(a, b, c)であり、これをモデルの最大絶対誤差とするのは合理的ではありません。

ただし、KFoldやLOOにおいては各バッチセットのサイズが(ほぼ)等しいためcross_val_predictの結果から平均絶対誤差や平均二乗誤差などの予測精度を求めても結果的には問題ありません。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?