17
14

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

GridSearchCVの評価指標にユーザ定義関数を使用する方法

Last updated at Posted at 2019-02-26

はじめに

関連の記事が少なく、公式の説明も自分には分かり辛かったので書いてみました。
公式: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring

手順概要

  1. ユーザ定義関数を定義する。
  2. sklearn.metrics.scorer.make_scorerにユーザ定義関数を渡してスコアラーを生成する(評価関数として使えるようにする)。
  3. 2.のmake_scorerをGridSearchCVのパラメータ「scoring」に設定する。

(ユーザ定義関数の内容に関して、今回は私のコードをそのまま貼りましたが、当然個々に寄ると思うので本記事のユーザ定義関数の内容を熟読してもあまり参考にはならないと思います。上記2点を満たしていれば問題ないです。)

ユーザ定義関数に関する注意点

下記2点を満たすように定義・実装すること。

  1. y_test(正解データ)とy_pred(推論結果データ)を引数として渡すことができる。
  2. 戻り値としてLoss(もしくは評価値)を返す。

make_scorerの設定に関する注意点

make_scorerのパラメータ「greater_is_better」の設定

  • ユーザ定義関数の戻り値を評価値にする場合: True
  • ユーザ定義関数の戻り値をLossにする場合: False

実装

「手順概要」の1.

ユーザ定義関数
def calc_score(y_train: np.array, y_pred: np.array):

  threshold = 1
  ignore = 10
  
  # 比較結果格納用配列の作成
  flg = np.zeros((y_train.shape[0],1))

  y_train = y_train.reshape([y_train.shape[0], 1])
  y_pred = y_pred.reshape([y_train.shape[0], 1])
  
  # 正解値と予測値の差分データを作成
  y_diff = y_train - y_pred

  for cnt_row in range(y_diff.shape[0]-DIFF):
      # 差分データの絶対値が閾値を超えていなければTrue
      if abs(y_diff[cnt_row]) < threshold:
        flg[cnt_row] = True
      else:
        flg[cnt_row] = False

  # 正解率を計算
  cnt_true = 0
  for cnt_row in range(flg.shape[0]):
    if flg[cnt_row] == 1:
      cnt_true += 1
  score = cnt_true / flg.shape[0]

  return score
main()
...

# 例
params = {'hidden_layer_sizes': [(100,), (100, 10), (100, 100, 10), (100, 100, 100, 10), (100, 100, 100, 100, 10),  (100, 100, 100, 100, 100, 10)],
             'max_iter': [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000], 
             'learning_rate_init': [0.001, 0.0001, 0.00001], 
             'early_stopping': [True, False], 
             'tol': [0.0001, 0.00001], 
             'batch_size': [100, 200, 300], 
             'verbose': [True], 
             'random_state': [42]}

# 例
mlpr = MLPRegressor()

from sklearn.metrics.scorer import make_scorer
# 「手順概要」の2.
my_scorer = make_scorer(calc_score, greater_is_better=True)
# 「手順概要」の3.
gs = GridSearchCV(mlpr, param_grid=params, cv=10, scoring=my_scorer)

...
17
14
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
14

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?